In [None]:
import numpy as np
from tqdm import tqdm
import datetime 
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr, poisson
from scipy.special import softmax
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import pairwise_distances
from matplotlib import animation
from IPython.display import HTML
import random

def generate(num_predators, num_venomous_prey, num_mimics, d=2, venom_const=0.5):
    """Draw an initial population of predators, venomous prey and mimics.

    With c = 1/sqrt(2), the Gaussian clouds (isotropic, std 0.4) are centred at:
      * predator detectors: (+c, ..., +c)
      * venomous prey signals: (-c, ..., -c)
      * mimic signals: (-c, ..., -c, +c), which differs from venomous prey
        only in the last coordinate.
    Note that these centres have unit norm only when d == 2.

    Returns:
        detectors: array of shape (num_predators, d).
        signals: array of shape (num_venomous_prey + num_mimics, d), with
            venomous rows first and mimic rows after them.
        risk_tols: array of shape (num_predators,), exponential with scale 1000.
        venom_levels: venom_const for each venomous prey, then 0.0 for each mimic.
    """
    c = 1 / np.sqrt(2)
    spread = 0.4**2 * np.eye(d)  # same isotropic covariance for every group

    detector_centre = np.full(d, c)
    venomous_centre = np.full(d, -c)
    mimic_centre = np.full(d, -c)
    mimic_centre[-1] = c

    # Keep this draw order: it fixes which random numbers each group gets under a seed.
    detectors = np.random.multivariate_normal(mean=detector_centre, cov=spread, size=num_predators)
    venomous = np.random.multivariate_normal(mean=venomous_centre, cov=spread, size=num_venomous_prey)
    mimics = np.random.multivariate_normal(mean=mimic_centre, cov=spread, size=num_mimics)
    signals = np.concatenate([venomous, mimics], axis=0)

    risk_tols = np.random.exponential(scale=1000., size=num_predators)
    venom_levels = np.concatenate([
        np.full(num_venomous_prey, venom_const, dtype=float),
        np.zeros(num_mimics),
    ])

    return detectors, signals, risk_tols, venom_levels

def similarity(detectors, signals, phenotype_type='vector'):
    """Pairwise similarity between every detector and every signal.

    Result has one row per detector and one column per signal.

    * 'vector': negative squared Euclidean distance. It is 0 for identical
      phenotypes and grows more negative as they move apart.
    * 'bitstring': fraction of matching positions, ``1 - hamming / d`` where
      d = signals.shape[1].

    Raises NotImplementedError for any other phenotype_type.
    """
    if phenotype_type == 'vector':
        # detectors[:, None] - signals -> one difference vector per (detector, signal) pair
        pairwise_diff = detectors[:, None] - signals
        distances = np.linalg.norm(pairwise_diff, axis=2)
        return -(distances ** 2)
    if phenotype_type == 'bitstring':
        n_bits = signals.shape[1]
        mismatches = (detectors[:, None] != signals).sum(axis=2)
        # NOTE(review): unlike the 'vector' branch this similarity is >= 0, so
        # calculate_preference_matrix's 1 - exp(sim / risk_tol) comes out <= 0
        # here for positive tolerances — confirm this is intended.
        return 1 - mismatches / n_bits
    raise NotImplementedError

def calculate_preference_matrix(detectors, signals, risk_tols, phenotype_type='vector'):
    """Attack preference of each predator (rows) for each prey (columns).

    ``preference = 1 - exp(similarity / risk_tol)``, where each predator's
    risk tolerance divides its entire row of similarities.

    With the 'vector' similarity (always <= 0) and positive tolerances, the
    preference lies in [0, 1). It is near 0 when a prey's signal matches the
    predator's detector and approaches 1 as they diverge. A larger tolerance
    lowers the preference at any given distance.
    """
    sim = similarity(detectors, signals, phenotype_type=phenotype_type)
    per_predator_tol = risk_tols[:, None]  # column vector -> broadcasts along each row
    return 1 - np.exp(sim / per_predator_tol)
    
def calculate_predation_matrix(detectors, signals, risk_tols, handling_time, attack_freq, phenotype_type='vector'):
    """Per-predator, per-prey predation rates for one generation.

    Each predator's rate per unit of preference follows a saturating
    (Holling type-II form) functional response:

        intake = attack_freq / (1 + attack_freq * handling_time * N_eff)

    Here N_eff is the sum of that predator's preferences over all prey. The
    returned entry [i, j] is ``intake[i] * preference[i, j]``. Each
    predator's predation is therefore spread across prey in proportion to
    preference.

    Returns an array with the same (n_predators, n_prey) layout as the
    preference matrix.
    """
    preference_matrix = calculate_preference_matrix(detectors, signals, risk_tols, phenotype_type)
    # Preference-weighted amount of prey each predator effectively "sees".
    n_effective_prey = preference_matrix.sum(axis=1)
    intake_rates = attack_freq / (1 + attack_freq * handling_time * n_effective_prey)
    return intake_rates[:, None] * preference_matrix

def sample_predators(predation_matrix, venom_levels, pred_conversion_ratio, death_rate=0.01):
    """Sample predator offspring and return the parent index of each child.

    For predator i, the expected number of offspring is

        sum_j P[i, j] * (1 - v_j) * pred_conversion_ratio
          - sum_j P[i, j] * v_j
          - death_rate

    That is, food converted from the non-venomous fraction of each prey, minus
    a venom penalty, minus a fixed death rate. This mean is clamped at 0, and
    the realised count is Poisson with that mean.

    Parameters
    ----------
    predation_matrix : np.ndarray
        Predation rates, one row per predator and one column per prey.
    venom_levels : np.ndarray
        Venom level of each prey (broadcast along the prey axis).
    pred_conversion_ratio : float
        Offspring produced per unit of non-venomous prey consumed.
    death_rate : float
        Constant subtracted from every predator's expected offspring count.

    Returns
    -------
    np.ndarray of int
        Predator indices repeated once per child, in ascending order.
    """
    num_predators = predation_matrix.shape[0]
    gains = predation_matrix * (1 - venom_levels) * pred_conversion_ratio
    venom_costs = predation_matrix * venom_levels
    means = (gains - venom_costs).sum(1) - death_rate
    # Clamp before sampling: np.random.poisson raises on a negative mean.
    means[means < 0] = 0
    counts = np.random.poisson(means)
    return np.repeat(np.arange(num_predators), counts)

def sample_prey(predation_matrix, prey_conversion_ratio, venom_levels, death_rate=0.01,
                carrying_capacity=1000):
    """Sample prey offspring and return the parent index of each child.

    Each prey's birth draw is Poisson with a logistic mean,
    ``prey_conversion_ratio * (1 - n_type / carrying_capacity)``. Here n_type
    is the number of prey of its own type: venomous (venom > 0) or mimic
    (venom == 0). The mean is clamped at 0, so a type at or above capacity
    produces no births; without the clamp np.random.poisson would raise on a
    negative mean. The total predation suffered by that prey (column sum of
    ``predation_matrix``) is then subtracted. The result is truncated toward
    zero to an integer and floored at 0.

    Parameters
    ----------
    predation_matrix : np.ndarray
        Predation rates, one row per predator and one column per prey.
    prey_conversion_ratio : float
        Maximum expected births per prey (at zero density).
    venom_levels : np.ndarray
        Venom level per prey; assumed non-negative.
    death_rate : float
        Currently unused; kept for signature compatibility.
    carrying_capacity : float
        Per-type density at which the expected birth count reaches 0.
        Defaults to 1000, the previous hard-coded value.

    Returns
    -------
    np.ndarray of int
        Prey indices repeated once per child, in ascending order.
    """
    is_venomous = venom_levels > 0
    is_mimic = venom_levels == 0
    nv = is_venomous.sum()
    nm = is_mimic.sum()
    # Density dependence is per type: each prey is crowded only by its own kind.
    num_prey = is_venomous * nv + is_mimic * nm
    predation = predation_matrix.sum(0)
    birth_means = np.maximum(prey_conversion_ratio * (1 - num_prey / carrying_capacity), 0)
    n_children = (np.random.poisson(birth_means) - predation).astype(np.int64)
    n_children[n_children < 0] = 0
    return np.repeat(np.arange(nv + nm), n_children)

def phenotype_crossover(phenotypes, parents, phenotype_type='vector'):

    # phenotypes can be signals or detectors
    assert len(parents) % 2 == 0, 'Crossover not implemented yet for odd numbers of parents'
    parent_phenotypes = phenotypes[parents]
    child_phenotypes = np.zeros_like(parent_phenotypes)
    match phenotype_type:
        case 'vector':
            interpolation_values = np.random.rand(parent_phenotypes.shape[0] // 2)[:, np.newaxis]
            child_phenotypes[::2]  = interpolation_values * parent_phenotypes[::2] \
                                    + (1 - interpolation_values) * parent_phenotypes[1::2]
            child_phenotypes[1::2] = (1 - interpolation_values) * parent_phenotypes[::2] \
                                    + interpolation_values * parent_phenotypes[1::2]
            return child_phenotypes
        case 'bitstring':
            raise NotImplementedError
        case _:
            raise NotImplementedError
        

def phenotype_mutate(phenotypes, mutation_rate=0.01, phenotype_type='vector'):
    """Return a new array equal to ``phenotypes`` plus i.i.d. Gaussian noise.

    Only the 'vector' representation is supported. Every entry receives
    independent noise with standard deviation ``mutation_rate``. Any other
    phenotype_type, including 'bitstring', raises NotImplementedError.
    """
    if phenotype_type != 'vector':
        raise NotImplementedError
    noise = np.random.normal(scale=mutation_rate, size=phenotypes.shape)
    return phenotypes + noise

def impose_periodic_boundary(vectors, boundary=5):
    """
    Wrap coordinates into the box [-boundary, boundary) with periodic edges.

    Each entry is shifted by the multiple of 2 * boundary that brings it into
    range, so the space behaves like a torus of side 2 * boundary. This holds
    for positive boundary, up to floating-point rounding at the upper edge.
    The operation is element-wise, so any array shape is accepted (e.g. an
    (n, d) phenotype matrix), not only (n, 2).

    Parameters:
    vectors (array_like): Coordinates to wrap; converted to a NumPy array.
    boundary (float): Half-width of the periodic box. Default is 5.

    Returns:
    np.ndarray: Wrapped coordinates with the same shape as ``vectors``.
    """
    period = 2 * boundary
    shifted = np.asarray(vectors) + boundary
    return shifted % period - boundary


def update(detectors, signals, risk_tols, venom_levels, num_venomous, 
           handling_time=1, attack_freq=2, predator_conversion_ratio=1000, prey_conversion_ratio=10000,
           mutation_rate=0.001, phenotype_type='vector', boundary=2):
    """Advance the predator / venomous-prey / mimic system by one generation.

    Steps:
      1. Compute the predation matrix.
      2. Sample offspring (as parent indices) for predators and prey.
      3. Copy each child's parent phenotype and mutate it.
      4. Wrap detectors and signals into the periodic box.
      5. Let children inherit their parent's venom level unchanged.

    Parameters
    ----------
    detectors, signals, risk_tols, venom_levels :
        Current population. Prey rows are ordered venomous first, then mimics.
    num_venomous : int
        Number of venomous prey. Every venom level from this index onward
        must be 0.
    handling_time, attack_freq :
        Functional-response parameters forwarded to calculate_predation_matrix.
    predator_conversion_ratio, prey_conversion_ratio :
        Reproduction scale factors forwarded to sample_predators / sample_prey.
    mutation_rate : float
        Std of the Gaussian mutation noise applied to detectors, signals and
        risk tolerances.
    phenotype_type : str
        Phenotype representation passed to the helpers.
    boundary : float
        Half-width of the periodic box that detectors and signals are wrapped
        into. Default is 2.

    Returns
    -------
    tuple
        (detectors, signals, risk_tols, venom_levels, num_venomous) for the
        next generation.
    """
    assert np.all(venom_levels[num_venomous:] == 0), 'A mimic has a non-zero venom level'

    predation_matrix = calculate_predation_matrix(detectors, signals, risk_tols, handling_time, attack_freq, phenotype_type=phenotype_type)

    # Parent indices, one entry per offspring.
    predator_parents = sample_predators(predation_matrix, venom_levels, predator_conversion_ratio)
    prey_parents = sample_prey(predation_matrix, prey_conversion_ratio, venom_levels)

    new_detectors = phenotype_mutate(detectors[predator_parents], mutation_rate=mutation_rate, phenotype_type=phenotype_type)
    new_signals = phenotype_mutate(signals[prey_parents], mutation_rate=mutation_rate, phenotype_type=phenotype_type)

    new_detectors = impose_periodic_boundary(new_detectors, boundary)
    new_signals = impose_periodic_boundary(new_signals, boundary)

    # Risk tolerances are kept non-negative after mutation.
    new_risk_tols = np.abs(phenotype_mutate(risk_tols[predator_parents], mutation_rate=mutation_rate, phenotype_type=phenotype_type))

    # Venom levels are inherited without mutation.
    new_venom_levels = venom_levels[prey_parents]

    new_num_venomous = (new_venom_levels > 0).sum()

    return new_detectors, new_signals, new_risk_tols, new_venom_levels, new_num_venomous