In [26]:
import numpy as np
import torch

In [27]:
from torch import nn, Tensor
from dataclasses import dataclass


# hormone = 20 [id], 
# neuron = [activation_threshold, activation_decay, signal_multiplier, signal_delay, connection_affinity, spontaneity, ambient_hormone_release_rate, activation_hormone_release, hormone_decay_rate, hormone_range]
#          [offsets...] x 8 max simultaneous hormone effects
#          [decay_rates...] x 8
#          (d x 10) NL (10 x 20) - offsets
#          

# neuron_division_matrix = (d x d)

@dataclass
class Neuron:
  """Unpacked, per-field description of a single neuron's data.

  Within this section nothing instantiates this class: ``Specimen`` stores the
  same kinds of data packed into rows of a single tensor (see ``_Properties`` /
  ``_DATA_SIZES``). NOTE(review): presumably this dataclass is the readable
  schema for that packed layout — confirm. Field order here differs from the
  packed order (latent_state / hormone_influence are swapped there).
  """
  # FUNCTIONAL PARAMETERS
  # - activation parameters
  activation_threshold: float  # positive
  activation_cooldown: float  # [0, 1], decremented by sigmoid(update)
  signal_strength: float  # can be excitatory or inhibitory (positive or negative)
  
  # INTERNAL STATE
  cell_damage: float  # if >= 1, cell death. min 0. incremented/decremented alongside activation parameter calculation
  mitosis_stage: float  # if >= 1, mitosis. min 0. incremented/decremented alongside activation parameter calculation
  total_receptivity: float  # summed connectivity coefficients of incoming connections
  total_emissivity: float  # summed connectivity coefficients of outgoing connections
  activation_progress: float  # incremented by signals
  latent_state: Tensor  # produced exclusively by transforms defined in the genome
  hormone_influence: Tensor  # influence by hormones emitted from other neurons
  
  # HIDDEN STATE (for implementation purposes only)
  position: Tensor  # 3d global position
  inputs: Tensor  # (MAX_CONNECTIONS x 2), each connection includes a neuron ID and a connectivity coefficient
  outputs: Tensor  # same structure as inputs
  
@dataclass
class Genome:
  """Heritable blueprint shared by every neuron of a ``Specimen``.

  Holds the initial latent state plus the transforms (``nn.Module``s) that map a
  neuron's internal state to its parameters, new latent states, hormone
  emissions, connectivity and mitosis outcomes.
  """
  pluripotent_latent_state: Tensor  # initial latent state for the first neuron(s)
  # Specimen.add_neurons feeds it the packed STATE block and uses all but the last
  # two output columns as the functional parameters.
  state_to_parameters_and_health: nn.Module  # takes in internal state, outputs activation parameters and health parameter changes
  passive_transform: nn.Module  # applied continuously
  activation_transform: nn.Module  # takes in internal state, outputs new latent state
  # hormone_transform: nn.Module  # takes in internal state and hormone, outputs new latent state
  hormone_emission: nn.Module  # takes in internal state, outputs hormone code and emission range
  hormone_decay: Tensor  # values in [0, 1] that are continuously multiplied against hormone_influence
  connectivity_coefficient: nn.Module  # takes in internal state pair and relative position (emitter and receiver and emitter to receiver), outputs receptivity of latter [0,1]
  mitosis_results: nn.Module  # takes in internal state, outputs new latent state and daughter latent state and mitosis direction
  # NOTE(review): the formula below names `mitosis_health_penalty`, which is not a field;
  # presumably it refers to this value — confirm.
  mitosis_damage: float  # sigmoid(mitosis_health_penalty) / 2 + 0.5 is the cell damage increment upon mitosis
  
  # MUTABILITY
  # ... (modules of the same dimensions as genome which define the rate of genetic drift)
  

In [28]:
from typing import Union, Tuple
# Data layout strategy (for efficiency and compactness):
# 1. Pack all of a neuron's data into a single 1-D tensor, then stack the per-neuron
#    tensors into one 2-D buffer. A neuron's row index in that buffer is its ID,
#    except row 0, which is reserved as a dummy row for computational purposes.
# 2. The Genome needs no packing; it can stay an ordinary class.
from enum import Enum, IntEnum


# Maximum number of connection slots per neuron, applied separately to inputs and
# outputs; each slot takes two floats (neuron ID, connectivity coefficient).
_MAX_CONNECTIONS = 8


class _Properties(IntEnum):
    """Names of the properties packed into each neuron row.

    The integer values are labels only: column offsets are derived from the
    insertion order and widths in ``_DATA_SIZES`` (see ``_get_indexing``),
    not from these numbers.
    """
    # functional parameters
    ACTIVATION_THRESHOLD     = 0
    ACTIVATION_COOLDOWN      = 1
    SIGNAL_STRENGTH          = 2
    
    # internal state
    CELL_DAMAGE              = 3
    MITOSIS_STAGE            = 4
    TOTAL_RECEPTIVITY        = 5
    TOTAL_EMISSIVITY         = 6
    ACTIVATION_PROGRESS      = 7
    HORMONE_INFLUENCE        = 8
    LATENT_STATE             = 9
    
    # hidden (implementation-only) state
    POSITION                 = 10
    INPUTS                   = 11
    OUTPUTS                  = 12
    
    # NOTE(review): FUNCTIONAL_PARAMETERS has no entry in _DATA_SIZES, so
    # _get_indexing() would raise ValueError for it; Data.FUNCTIONAL_PARAMETERS
    # is built with _get_block_segment() instead. Possibly unused — confirm.
    FUNCTIONAL_PARAMETERS    = 13


# Width (number of float columns) of each property in a packed neuron row.
# Insertion order IS the storage order: a property's starting column is the sum
# of the widths of every entry listed before it.
_DATA_SIZES = dict([
    # functional parameters
    (_Properties.ACTIVATION_THRESHOLD, 1),
    (_Properties.ACTIVATION_COOLDOWN, 1),
    (_Properties.SIGNAL_STRENGTH, 1),
    # internal state
    (_Properties.CELL_DAMAGE, 1),
    (_Properties.MITOSIS_STAGE, 1),
    (_Properties.TOTAL_RECEPTIVITY, 1),
    (_Properties.TOTAL_EMISSIVITY, 1),
    (_Properties.ACTIVATION_PROGRESS, 1),
    (_Properties.HORMONE_INFLUENCE, 10),
    (_Properties.LATENT_STATE, 24),
    # hidden state: position, then (neuron ID, coefficient) pairs per connection slot
    (_Properties.POSITION, 3),
    (_Properties.INPUTS, _MAX_CONNECTIONS * 2),
    (_Properties.OUTPUTS, _MAX_CONNECTIONS * 2),
])


def _get_indexing(property_name: _Properties) -> Union[int, slice]:
  """Locate a single property inside a packed neuron row.

  :param property_name: a property that has an entry in ``_DATA_SIZES``.
  :return: the column index for width-1 properties, otherwise a ``slice``
    covering all of the property's columns.
  :raises ValueError: if ``property_name`` has no entry in ``_DATA_SIZES``.
  """
  ordered_names = list(_DATA_SIZES)
  position = ordered_names.index(property_name)  # ValueError when absent
  start = sum(_DATA_SIZES[name] for name in ordered_names[:position])
  width = _DATA_SIZES[ordered_names[position]]
  if width == 1:
    return start
  return slice(start, start + width)
  

def _get_block_segment(start_property: _Properties, end_property: _Properties) -> slice:
  """Return the column slice spanning a contiguous run of properties.

  The run starts at ``start_property`` and includes ``end_property``
  (inclusive), following the storage order of ``_DATA_SIZES``.

  :param start_property: first property of the run.
  :param end_property: last property of the run (inclusive).
  :return: ``slice`` from the first column of ``start_property`` to one past
    the last column of ``end_property``.
  :raises ValueError: if either property has no entry in ``_DATA_SIZES``, or if
    ``end_property`` is stored before ``start_property`` (which would otherwise
    yield a backwards slice that silently selects nothing).
  """
  property_names, property_sizes = list(_DATA_SIZES.keys()), list(_DATA_SIZES.values())
  start_index = property_names.index(start_property)
  end_index = property_names.index(end_property)
  if end_index < start_index:
    raise ValueError(
        f"end property {end_property!r} is stored before start property {start_property!r}")
  data_format_start_index = sum(property_sizes[:start_index])
  data_format_end_index = sum(property_sizes[:end_index + 1])
  return slice(data_format_start_index, data_format_end_index)


class Data(Enum):
    """Column locations of each property within a packed neuron row.

    Each member's ``.value`` is either an ``int`` column index (width-1
    properties) or a ``slice`` (wider properties and multi-property blocks).
    ``Data`` is a plain ``Enum``, not an ``IntEnum``, so index tensors with
    ``Data.X.value`` rather than ``Data.X``.
    """
    # functional parameters (single columns)
    ACTIVATION_THRESHOLD     = _get_indexing(_Properties.ACTIVATION_THRESHOLD)
    ACTIVATION_COOLDOWN      = _get_indexing(_Properties.ACTIVATION_COOLDOWN)
    SIGNAL_STRENGTH          = _get_indexing(_Properties.SIGNAL_STRENGTH)
    
    # internal state
    CELL_DAMAGE              = _get_indexing(_Properties.CELL_DAMAGE)
    MITOSIS_STAGE            = _get_indexing(_Properties.MITOSIS_STAGE)
    TOTAL_RECEPTIVITY        = _get_indexing(_Properties.TOTAL_RECEPTIVITY)
    TOTAL_EMISSIVITY         = _get_indexing(_Properties.TOTAL_EMISSIVITY)
    ACTIVATION_PROGRESS      = _get_indexing(_Properties.ACTIVATION_PROGRESS)
    HORMONE_INFLUENCE        = _get_indexing(_Properties.HORMONE_INFLUENCE)
    LATENT_STATE             = _get_indexing(_Properties.LATENT_STATE)
    
    # hidden state
    POSITION                 = _get_indexing(_Properties.POSITION)
    INPUTS                   = _get_indexing(_Properties.INPUTS)
    OUTPUTS                  = _get_indexing(_Properties.OUTPUTS)
    
    # Composite blocks spanning several consecutive properties (end inclusive):
    # FUNCTIONAL_PARAMETERS = threshold..signal strength, STATE = cell damage..latent state.
    FUNCTIONAL_PARAMETERS    = _get_block_segment(_Properties.ACTIVATION_THRESHOLD, _Properties.SIGNAL_STRENGTH)
    STATE                    = _get_block_segment(_Properties.CELL_DAMAGE, _Properties.LATENT_STATE)


# Width of one packed neuron row: the sum of every property's width in _DATA_SIZES.
NEURON_DATA_DIM = sum(_DATA_SIZES.values())


class Specimen:
  """A population of neurons grown from a single ``Genome``.

  All neuron data lives in one 2-D buffer, ``self.neurons``, holding one packed
  row of ``NEURON_DATA_DIM`` floats per neuron (column layout given by ``Data``).
  A neuron's ID is its row index. Row 0 is a reserved dummy row that the
  allocator never hands out.

  Attributes:
    genome: the genome whose modules drive every neuron in this specimen.
    living_neuron_indices: row indices currently occupied by living neurons.
    dead_neurons_indices: free row indices, handed out in list order.
    neurons: the ``(buffer_size, NEURON_DATA_DIM)`` packed buffer; grows on demand.
  """

  def __init__(self, genome: Genome):
    self.genome = genome

    initial_neuron_buffer_size = 16
    self.living_neuron_indices = []
    # Row 0 is the reserved dummy row, so the free pool starts at row 1.
    self.dead_neurons_indices = list(range(1, initial_neuron_buffer_size))
    self.neurons = torch.zeros((initial_neuron_buffer_size, NEURON_DATA_DIM))

  def step(self):
    """Advance the simulation by one tick (work in progress).

    So far this only determines which living neurons have accumulated more
    activation progress than their activation threshold.

    :return: boolean tensor aligned with ``living_neuron_indices``; ``True``
      where that neuron's activation progress exceeds its threshold.
    """
    indices = torch.tensor(self.living_neuron_indices, dtype=torch.long)
    # Index with `.value`: `Data` is a plain Enum, so its members are not valid tensor indices.
    activation_progress = self.neurons[indices, Data.ACTIVATION_PROGRESS.value]
    activation_threshold = self.neurons[indices, Data.ACTIVATION_THRESHOLD.value]
    activated_neurons = activation_progress > activation_threshold
    # TODO: act on `activated_neurons` (emit signals, reset progress, apply cooldown, ...).
    return activated_neurons

  def add_neurons(self, positions: Tensor, latent_states: Tensor):
    """Allocate and initialise one new neuron per row of ``positions``.

    Each new row is zeroed first, so cell damage, mitosis stage, total receptivity,
    total emissivity, activation progress, hormone influence and the connection
    slots all start at zero. Position and latent state are then written, and the
    functional parameters are computed by ``genome.state_to_parameters_and_health``
    from the new neurons' STATE block. All but the last two output columns are used;
    NOTE(review): the last two (presumably the health changes) are currently discarded.

    :param positions: ``(n, 3)`` global positions; ``n`` is the number of neurons added.
    :param latent_states: initial latent states written into the LATENT_STATE columns
      of the new rows (one row per neuron, or anything that broadcasts to that).
    """
    neuron_indices = self._allocate_neurons(positions.size(0))

    # Rows may be recycled from dead neurons, so clear any stale data first.
    self.neurons[neuron_indices, :] = 0
    self.neurons[neuron_indices, Data.POSITION.value] = positions
    self.neurons[neuron_indices, Data.LATENT_STATE.value] = latent_states

    parameters = self.genome.state_to_parameters_and_health(self.neurons[neuron_indices, Data.STATE.value])
    functional_parameters = parameters[:, :-2]
    self.neurons[neuron_indices, Data.FUNCTIONAL_PARAMETERS.value] = functional_parameters

  def _allocate_neurons(self, neuron_count: int) -> Tensor:
    """Reserve ``neuron_count`` free rows, growing the buffer if needed.

    When too few free rows remain, the buffer's total size is repeatedly
    doubled until enough exist; the added rows are zero-filled with the
    buffer's own dtype and device. Allocated rows move from the free pool to
    ``living_neuron_indices``.

    :param neuron_count: number of rows to reserve (0 is allowed).
    :return: ``int64`` tensor of the allocated row indices, in allocation order.
    :raises ValueError: if ``neuron_count`` is negative.
    """
    if neuron_count < 0:
      raise ValueError(f"neuron_count must be non-negative, got {neuron_count}")

    # allocate more space if necessary
    if len(self.dead_neurons_indices) < neuron_count:
      current_neuron_buffer_size = self.neurons.size(0)

      # keep doubling the total size of the buffer until we have enough space
      extension_size = current_neuron_buffer_size
      extension_factor = 2
      while len(self.dead_neurons_indices) + extension_size < neuron_count:
        extension_size += extension_factor * current_neuron_buffer_size
        extension_factor *= 2
      extension = self.neurons.new_zeros((extension_size, NEURON_DATA_DIM))
      self.neurons = torch.cat((self.neurons, extension))
      self.dead_neurons_indices.extend(range(current_neuron_buffer_size, current_neuron_buffer_size + extension_size))

    allocation = self.dead_neurons_indices[:neuron_count]
    self.dead_neurons_indices = self.dead_neurons_indices[neuron_count:]
    self.living_neuron_indices.extend(allocation)
    return torch.tensor(allocation, dtype=torch.long)

  def _deallocate_neurons(self, indices: Tensor):
    """Return the given neurons' rows to the free pool.

    Accepts a tensor of any shape or any iterable of integer row indices.
    Indices that are not currently living (already free, repeated, or the
    reserved row 0) are ignored, so a row is never freed twice and later handed
    out to two neurons. Row contents are left as-is; ``add_neurons`` zeroes a
    row when it is reused.

    :param indices: row indices of the neurons to free.
    """
    # Convert to plain ints: tensor elements hash by identity, so a set of them
    # would never match the int entries of `living_neuron_indices`.
    if isinstance(indices, Tensor):
      indices = indices.reshape(-1).tolist()
    requested = dict.fromkeys(int(i) for i in indices)  # dedupe, keep request order

    living = set(self.living_neuron_indices)
    freed = [i for i in requested if i in living]
    freed_set = set(freed)

    self.living_neuron_indices = [i for i in self.living_neuron_indices if i not in freed_set]
    self.dead_neurons_indices.extend(freed)
          