In [36]:
import numpy as np
import torch

In [37]:
from torch import nn, Tensor
from dataclasses import dataclass


# hormone = 20 [id], 
# neuron = [activation_threshold, activation_decay, signal_multiplier, signal_delay, connection_affinity, spontaneity, ambient_hormone_release_rate, activation_hormone_release, hormone_decay_rate, hormone_range]
#          [offsets...] x 8 max simultaneous hormone effects
#          [decay_rates...] x 8
#          (d x 10) NL (10 x 20) - offsets
#          

# neuron_division_matrix = (d x d)

@dataclass
class Neuron:
  """Field-by-field description of a single neuron's properties.

  NOTE(review): the Specimen code in this notebook stores neurons as packed rows of one tensor
  (column layout given by ``Data`` / ``_DATA_SIZES``) rather than as instances of this class,
  so this dataclass currently serves as documentation; keep the two in sync.
  """
  # ACTIVATION PARAMETERS
  activation_threshold: float  # positive
  signal_strength: float  # can be excitatory or inhibitory (positive or negative)
  hormone_emission: Tensor  # presumably the emitted hormone code; _HORMONE_DIM columns in the packed layout
  hormone_range: float  # drops off to zero at this distance
  
  # INTERNAL STATE
  activation_warmup: float  # if >=1, activation can occur, incremented by tanh(update); reset to 0 on firing
  cell_damage: float  # if >= 1, cell death. min 0 (not clamped yet in Specimen). incremented/decremented alongside activation parameter calculation
  mitosis_stage: float  # if >= 1, mitosis. min 0 (not clamped yet in Specimen). incremented/decremented alongside activation parameter calculation
  total_receptivity: float  # summed connectivity coefficients of incoming connections
  total_emissivity: float  # summed connectivity coefficients of outgoing connections
  activation_progress: float  # incremented by signals; the neuron can fire once this reaches activation_threshold
  hormone_influence: Tensor  # influence by hormones emitted from other neurons
  latent_state: Tensor  # produced exclusively by transforms defined in the genome
  
  # HIDDEN STATE (for implementation purposes only)
  position: Tensor  # 3d global position
  # (MAX_CONNECTIONS x 2), each connection includes a neuron ID and a connectivity coefficient;
  # the packed layout stores these as separate INPUT_INDICES and INPUT_CONNECTIVITY blocks
  inputs: Tensor
  outputs: Tensor  # same structure as inputs (OUTPUT_INDICES / OUTPUT_CONNECTIVITY in the packed layout)
  
@dataclass
class Genome:
  """Heritable data and transforms that drive every neuron of a Specimen.

  The modules that Specimen calls in this notebook (``derive_parameters_from_state``,
  ``activation_transform``, ``passive_transform``) receive a batch of neurons' packed
  ``Data.STATE`` columns, one row per neuron.

  NOTE(review): the dataclass-generated ``__eq__`` compares the Tensor fields, which raises
  for multi-element tensors; consider ``@dataclass(eq=False)`` if genomes are ever compared.
  """
  pluripotent_latent_state: Tensor  # initial latent state for the first neuron(s)
  # takes in internal state; Specimen.add_neurons writes its output into the DERIVED_PARAMETERS
  # columns (activation threshold, signal strength, hormone emission, hormone range).
  # NOTE(review): "health parameter changes" were also planned here but nothing visible consumes them.
  derive_parameters_from_state: nn.Module
  # applied every step to neurons that did not fire; output columns are
  # [warmup/damage/mitosis increments (squashed with tanh), new latent state]
  passive_transform: nn.Module
  # applied to neurons that fire this step; same output layout as passive_transform
  activation_transform: nn.Module
  # hormone_transform: nn.Module  # takes in internal state and hormone, outputs new latent state
  # hormone_emission: nn.Module  # takes in internal state, outputs hormone code and emission range
  hormone_decay: Tensor  # values in [0, 1] multiplied element-wise into hormone_influence every step
  connectivity_coefficient: nn.Module  # takes in internal state pair and relative position (emitter and receiver and emitter to receiver), outputs receptivity of latter [0,1]
  mitosis_results: nn.Module  # takes in internal state, outputs new latent state and daughter latent state and mitosis direction
  # sigmoid(mitosis_damage) / 2 + 0.5 is intended to be the cell damage increment upon mitosis
  # (the original note called this value `mitosis_health_penalty` - presumably a former name; not used yet)
  mitosis_damage: float
  
  # MUTABILITY
  # ... (modules of the same dimensions as genome which define the rate of genetic drift)
  

In [38]:
from typing import Union, Tuple
# data formatting strategy for efficiency and compactness:
# 1. slot all of a neuron's data into a single 1d tensor, and then stack the neuron tensors.
#    a neuron's row index in the stacked tensor is its ID, except for row 0, which is a dummy
#    for computational purposes
# 2. Genome can just live in a class
from enum import Enum, IntEnum


# Width of each hormone vector (the HORMONE_EMISSION and HORMONE_INFLUENCE blocks of a neuron row).
_HORMONE_DIM = 10
# Number of input slots and, separately, of output slots per neuron.
_MAX_CONNECTIONS = 8
# Width of each neuron's latent state.
_LATENT_DIM = 24


class _Properties(IntEnum):
    """Names of the properties packed into each neuron row.

    Member values mirror each property's position in ``_DATA_SIZES``; the actual column
    layout is derived from ``_DATA_SIZES`` insertion order, not from these values.
    """
    # state-derived parameters
    ACTIVATION_THRESHOLD     = 0
    SIGNAL_STRENGTH          = 1
    HORMONE_EMISSION         = 2
    HORMONE_RANGE            = 3

    # incremented/decremented state
    ACTIVATION_WARMUP        = 4
    CELL_DAMAGE              = 5
    MITOSIS_STAGE            = 6

    # directly set state
    LATENT_STATE             = 7
    TOTAL_RECEPTIVITY        = 8
    TOTAL_EMISSIVITY         = 9
    ACTIVATION_PROGRESS      = 10
    HORMONE_INFLUENCE        = 11

    # hidden (implementation-only) state
    POSITION                 = 12
    INPUT_INDICES            = 13
    INPUT_CONNECTIVITY       = 14
    OUTPUT_INDICES           = 15
    OUTPUT_CONNECTIVITY      = 16



# Number of columns each property occupies in a packed neuron row.
# The insertion order of this dict IS the column layout: _get_indexing/_get_block_segment place
# each property right after the previous one, so reordering entries moves columns.
_DATA_SIZES = {
    _Properties.ACTIVATION_THRESHOLD:     1,
    _Properties.SIGNAL_STRENGTH:          1,
    _Properties.HORMONE_EMISSION:         _HORMONE_DIM,
    _Properties.HORMONE_RANGE:            1,
  
    _Properties.ACTIVATION_WARMUP:        1,
    _Properties.CELL_DAMAGE:              1,
    _Properties.MITOSIS_STAGE:            1,
  
    _Properties.LATENT_STATE:             _LATENT_DIM,
    _Properties.TOTAL_RECEPTIVITY:        1,
    _Properties.TOTAL_EMISSIVITY:         1,
    _Properties.ACTIVATION_PROGRESS:      1,
    _Properties.HORMONE_INFLUENCE:        _HORMONE_DIM,
  
    _Properties.POSITION:                 3,
    # connection slots: neuron IDs are stored as floats in the same tensor as everything else
    _Properties.INPUT_INDICES:            _MAX_CONNECTIONS,
    _Properties.INPUT_CONNECTIVITY:       _MAX_CONNECTIONS,
    _Properties.OUTPUT_INDICES:           _MAX_CONNECTIONS,
    _Properties.OUTPUT_CONNECTIVITY:      _MAX_CONNECTIONS,
}


def _get_indexing(property_name: _Properties) -> Union[int, slice]:
  """Return where ``property_name`` lives inside a packed neuron row.

  Columns are assigned in ``_DATA_SIZES`` insertion order, each property starting right after
  the one before it. A property one column wide maps to that column's ``int`` index; a wider
  property maps to a ``slice`` over its columns.

  :raises ValueError: if ``property_name`` is not a key of ``_DATA_SIZES``
  """
  ordered_names = list(_DATA_SIZES)
  position = ordered_names.index(property_name)
  widths = [_DATA_SIZES[name] for name in ordered_names]
  first_column = sum(widths[:position])
  width = widths[position]
  if width == 1:
    return first_column
  return slice(first_column, first_column + width)
  

def _get_block_segment(start_property: _Properties, end_property: _Properties, offset_from: _Properties = None) -> slice:
  """Return the slice covering ``start_property`` through ``end_property`` (both inclusive).

  The slice spans every property lying between the two in ``_DATA_SIZES`` order.

  :param offset_from: when given, shift the slice so that the first column of this property
      becomes column 0, for addressing a tensor that holds only part of a neuron row
  :raises ValueError: if any given property is not a key of ``_DATA_SIZES``
  """
  ordered_names = list(_DATA_SIZES)
  widths = list(_DATA_SIZES.values())

  def columns_before(position: int) -> int:
    # total width of every property that precedes `position` in the layout
    return sum(widths[:position])

  begin = columns_before(ordered_names.index(start_property))
  end = columns_before(ordered_names.index(end_property) + 1)
  shift = 0 if offset_from is None else columns_before(ordered_names.index(offset_from))
  return slice(begin - shift, end - shift)


# NOTE: when editing the data format, be careful about reordering properties,
#       as some definitions rely on subsets of properties being in a continuous chunk
class Data(Enum):
    """Column locations of every neuron property inside a packed neuron row.

    Each member's ``value`` is an ``int`` (single column) or a ``slice`` (several columns),
    meant to be used as the column index in ``neurons[:, Data.X.value]``.

    Because this is an ``Enum``, two members with equal values silently become aliases of one
    another (only the first is a real member); the assertion after the class guards against it.
    """
    # State-derived Parameters -----------
    # - activation parameters
    ACTIVATION_THRESHOLD     = _get_indexing(_Properties.ACTIVATION_THRESHOLD)
    SIGNAL_STRENGTH          = _get_indexing(_Properties.SIGNAL_STRENGTH)
    # - hormone emission
    HORMONE_EMISSION         = _get_indexing(_Properties.HORMONE_EMISSION)
    HORMONE_RANGE            = _get_indexing(_Properties.HORMONE_RANGE)
    
    # State Parameters -------------------
    # - incremented/decremented parameters
    ACTIVATION_WARMUP        = _get_indexing(_Properties.ACTIVATION_WARMUP)
    CELL_DAMAGE              = _get_indexing(_Properties.CELL_DAMAGE)
    MITOSIS_STAGE            = _get_indexing(_Properties.MITOSIS_STAGE)
    
    # - direct parameters
    LATENT_STATE             = _get_indexing(_Properties.LATENT_STATE)
    TOTAL_RECEPTIVITY        = _get_indexing(_Properties.TOTAL_RECEPTIVITY)
    TOTAL_EMISSIVITY         = _get_indexing(_Properties.TOTAL_EMISSIVITY)
    ACTIVATION_PROGRESS      = _get_indexing(_Properties.ACTIVATION_PROGRESS)
    HORMONE_INFLUENCE        = _get_indexing(_Properties.HORMONE_INFLUENCE)
    
    # Hidden State Parameters -------------
    POSITION                 = _get_indexing(_Properties.POSITION)
    INPUT_INDICES            = _get_indexing(_Properties.INPUT_INDICES)
    INPUT_CONNECTIVITY       = _get_indexing(_Properties.INPUT_CONNECTIVITY)
    OUTPUT_INDICES           = _get_indexing(_Properties.OUTPUT_INDICES)
    OUTPUT_CONNECTIVITY      = _get_indexing(_Properties.OUTPUT_CONNECTIVITY)
    
    # Shortcuts
    # - shortcut slices to subsections of above parameters
    # columns filled from genome.derive_parameters_from_state
    DERIVED_PARAMETERS       = _get_block_segment(_Properties.ACTIVATION_THRESHOLD, _Properties.HORMONE_RANGE)
    # columns fed as input to the genome transforms
    STATE                    = _get_block_segment(_Properties.ACTIVATION_WARMUP, _Properties.HORMONE_INFLUENCE)
    # warmup, damage and mitosis stage - updated by adding transform outputs
    INCREMENTED_PARAMETERS   = _get_block_segment(_Properties.ACTIVATION_WARMUP, _Properties.MITOSIS_STAGE)
    
    # - slices into a transform's output, whose column 0 lines up with the start of STATE
    TRANSFORM_INCREMENTED    = _get_block_segment(
      _Properties.ACTIVATION_WARMUP, _Properties.MITOSIS_STAGE, offset_from=_Properties.ACTIVATION_WARMUP)
    TRANSFORM_LATENT         = _get_block_segment(
      _Properties.LATENT_STATE, _Properties.LATENT_STATE, offset_from=_Properties.ACTIVATION_WARMUP)


# Enum members with equal values collapse into aliases; make sure every entry kept its own location.
assert len(Data.__members__) == len(Data), "Data members must have distinct column locations"


# Total number of columns in a packed neuron row (width of Specimen.neurons).
NEURON_DATA_DIM = sum(_DATA_SIZES.values())


class Specimen:
  """An organism whose neurons are packed as rows of one tensor and driven by a Genome.

  ``self.neurons`` has one row per slot and ``NEURON_DATA_DIM`` columns laid out by ``Data``;
  a neuron's row index is its ID. Row 0 is a reserved dummy: freshly zeroed connection slots
  point at index 0 (with zero connectivity), so it is never handed out as a real neuron.
  ``living_neuron_indices`` and ``dead_neurons_indices`` record which rows are in use and
  which are free for allocation.
  """

  def __init__(self, genome: Genome):
    self.genome = genome

    initial_neuron_buffer_size = 16
    self.living_neuron_indices = []
    # start at 1: row 0 is the reserved dummy row and must never be allocated
    self.dead_neurons_indices = list(range(1, initial_neuron_buffer_size))
    self.neurons = torch.zeros((initial_neuron_buffer_size, NEURON_DATA_DIM))

  @torch.no_grad()
  def step(self):
    """Advance every living neuron by one tick.

    The living rows are gathered into a working copy, updated, and written back to
    ``self.neurons``. Runs without autograd: if the genome modules have parameters that
    require grad (the ``nn.Module`` default), each step would otherwise chain a new graph
    onto ``self.neurons`` and memory would grow without bound.
    """
    if not self.living_neuron_indices:
      return

    indices = torch.tensor(self.living_neuron_indices, dtype=torch.long)
    # advanced indexing returns a copy, so the result must be written back explicitly below
    previous_neurons = self.neurons[indices]
    updated_neurons = previous_neurons.clone()

    # handle firing and passive neuron state updates
    self._step_activations(previous_neurons, updated_neurons, indices)

    # decay hormones
    updated_neurons[:, Data.HORMONE_INFLUENCE.value] *= self.genome.hormone_decay

    # TODO: absorb hormones
    # TODO: handle cell death
    # TODO: handle cell division
    # TODO: update connectivity
    # TODO: update direct parameters
    # TODO: update derived parameters

    self.neurons[indices] = updated_neurons

  def _step_activations(self, previous_neurons: Tensor, updated_neurons: Tensor, indices: Tensor):
    """Fire ready neurons, apply the genome state transforms, and deliver signals.

    A neuron fires when its activation progress has reached its activation threshold and its
    activation warmup is >= 1. Firing neurons get their warmup reset to 0 and go through
    ``genome.activation_transform``; all other living neurons go through
    ``genome.passive_transform``. Each firing neuron then adds
    ``signal_strength * connectivity`` along every outgoing slot to the activation progress
    of that slot's destination row.

    :param previous_neurons: living rows at the start of the step (read only)
    :param updated_neurons: working copy of the same rows; modified in place
    :param indices: rows of ``self.neurons`` the two tensors above were gathered from
    """
    # locate neurons that are ready to fire signals - threshold reached and firing warmup completed
    activation_threshold_reached = (
        previous_neurons[:, Data.ACTIVATION_PROGRESS.value]
        >= previous_neurons[:, Data.ACTIVATION_THRESHOLD.value]
    )
    activation_ready = previous_neurons[:, Data.ACTIVATION_WARMUP.value] >= 1
    activated = activation_threshold_reached & activation_ready

    # reset warmup of neurons that have just fired (before the transform increments are added)
    updated_neurons[activated, Data.ACTIVATION_WARMUP.value] = 0

    # firing and resting neurons go through different genome transforms
    self._apply_state_transform(self.genome.activation_transform, activated, previous_neurons, updated_neurons)
    self._apply_state_transform(self.genome.passive_transform, ~activated, previous_neurons, updated_neurons)

    # send signals to destination neurons; indices are stored as floats in the neuron tensor,
    # but scatter_add_ requires an int64 index
    signal_destinations = previous_neurons[activated, Data.OUTPUT_INDICES.value].long().reshape(-1)
    signal_connectivity = previous_neurons[activated, Data.OUTPUT_CONNECTIVITY.value]
    # (k,) -> (k, 1) so each neuron's strength scales every one of its connection slots
    signal_strength = previous_neurons[activated, Data.SIGNAL_STRENGTH.value].unsqueeze(-1)
    signal_strengths = (signal_connectivity * signal_strength).reshape(-1)

    cumulative_signals = torch.zeros_like(self.neurons[:, Data.ACTIVATION_PROGRESS.value])
    cumulative_signals.scatter_add_(0, signal_destinations, signal_strengths)
    updated_neurons[:, Data.ACTIVATION_PROGRESS.value] += cumulative_signals[indices]

  def _apply_state_transform(self, transform: nn.Module, mask: Tensor, previous_neurons: Tensor, updated_neurons: Tensor):
    """Run ``transform`` on the state of the rows selected by ``mask`` and apply its output.

    Output columns ``Data.TRANSFORM_INCREMENTED`` are squashed with tanh and added to the
    incremented parameters; columns ``Data.TRANSFORM_LATENT`` replace the latent state.
    """
    if not bool(mask.any()):
      return  # nothing selected; don't feed an empty batch to the transform
    state_update = transform(previous_neurons[mask, Data.STATE.value])
    increments = torch.tanh(state_update[:, Data.TRANSFORM_INCREMENTED.value])
    # boolean-mask indexing yields a copy, so `.add_()` / `.copy_()` on it would be lost;
    # indexed assignment writes through to `updated_neurons`
    updated_neurons[mask, Data.INCREMENTED_PARAMETERS.value] += increments
    updated_neurons[mask, Data.LATENT_STATE.value] = state_update[:, Data.TRANSFORM_LATENT.value]

  @torch.no_grad()
  def add_neurons(self, positions: Tensor, latent_states: Tensor, set_parameters: bool = True) -> Tensor:
    """Allocate rows for new neurons and initialize them.

    Every column of each new row is zeroed first, so activation warmup, cell damage, mitosis
    stage, total receptivity, total emissivity, activation progress, hormone influence and all
    connection slots start at zero.

    :param positions: one 3d position per new neuron, shape (n, 3)
    :param latent_states: one latent state per new neuron, shape (n, _LATENT_DIM)
    :param set_parameters: if True, fill the derived parameters by running
        ``genome.derive_parameters_from_state`` on the new neurons' state columns
    :return: the row indices assigned to the new neurons (int64)
    """
    neuron_indices = self._allocate_neurons(positions.size(0))

    self.neurons[neuron_indices, :] = 0
    self.neurons[neuron_indices, Data.POSITION.value] = positions
    self.neurons[neuron_indices, Data.LATENT_STATE.value] = latent_states

    if set_parameters:
      parameters = self.genome.derive_parameters_from_state(self.neurons[neuron_indices, Data.STATE.value])
      self.neurons[neuron_indices, Data.DERIVED_PARAMETERS.value] = parameters

    return neuron_indices

  def _allocate_neurons(self, neuron_count: int) -> Tensor:
    """Reserve ``neuron_count`` free rows, growing the buffer if needed, and mark them living.

    The buffer is doubled until enough free rows exist; new rows are zeros.

    :return: the reserved row indices as an int64 tensor (the dtype torch indexing/scatter expect)
    """
    shortfall = neuron_count - len(self.dead_neurons_indices)
    if shortfall > 0:
      current_size = self.neurons.size(0)
      new_size = max(current_size, 1) * 2
      while new_size - current_size < shortfall:
        new_size *= 2
      extension = torch.zeros((new_size - current_size, NEURON_DATA_DIM), dtype=self.neurons.dtype)
      self.neurons = torch.cat((self.neurons, extension))
      self.dead_neurons_indices.extend(range(current_size, new_size))
    allocation = self.dead_neurons_indices[:neuron_count]
    del self.dead_neurons_indices[:neuron_count]
    self.living_neuron_indices.extend(allocation)
    return torch.tensor(allocation, dtype=torch.long)

  def _deallocate_neurons(self, indices: Tensor):
    """Return rows to the free list.

    Accepts a tensor or any iterable of ints. Iterating a tensor yields 0-d tensors, which hash
    by identity and do not match the ints in ``living_neuron_indices``, so tensors are
    converted to Python ints first. Rows that are not currently living are ignored, so a row
    is never listed as free twice.
    """
    if isinstance(indices, Tensor):
      indices = indices.reshape(-1).tolist()
    released = {int(i) for i in indices}
    freed = [i for i in self.living_neuron_indices if i in released]
    self.living_neuron_indices = [i for i in self.living_neuron_indices if i not in released]
    self.dead_neurons_indices.extend(freed)
          