In [70]:
import numpy as np
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib
matplotlib.use("ipympl")
from IPython.display import HTML

In [98]:
class Organism(object):
    """A single individual whose phenotype is simply its genotype vector.

    Parameters
    ----------
    n_genes : int
        Length of a freshly sampled (standard-normal) genotype; only used
        when ``genotype`` is None.
    genotype : torch.Tensor, optional
        Genotype to adopt. It is stored as given (not copied), so callers
        that must keep their tensor unchanged should pass a copy.
    """

    def __init__(self, n_genes=0, genotype=None):
        if genotype is not None:
            self.genotype = genotype
        else:
            self.genotype = torch.randn((n_genes,))
        self.phenotype = self.get_phenotype()
        # Placeholder until update_fitness() is called.
        self.fitness = 0

    def get_phenotype(self):
        """Return the phenotype: the identity map, i.e. the genotype tensor itself."""
        return self.genotype

    def update_fitness(self, optimal_phenotype, selection_coefficient):
        """Recompute, store and return fitness relative to ``optimal_phenotype``.

        fitness = exp(-(||phenotype - optimal_phenotype|| / 2) * selection_coefficient)
        """
        distance = torch.linalg.vector_norm(self.phenotype - optimal_phenotype)
        # NOTE(review): the original expression ``d / 2*s`` parses as
        # ``(d / 2) * s``; that value is kept here with explicit grouping.
        # If ``exp(-d / (2 * s))`` was intended, change this line.
        self.fitness = torch.exp((-distance / 2) * selection_coefficient)
        return self.fitness

    def mutate(self, mutation_rate, mutation_magnitude):
        """With probability ``mutation_rate``, perturb one random gene in place.

        The perturbation is drawn from a normal distribution with mean 0 and
        standard deviation ``mutation_magnitude``.
        """
        a = torch.rand(1)
        if a < mutation_rate:
            i = torch.randint(low=0, high=self.genotype.shape[0], size=(1,))
            # In-place edit: this is why offspring must own a separate tensor.
            self.genotype[i] += torch.normal(torch.Tensor([0]), torch.Tensor([mutation_magnitude]))
            self.phenotype = self.get_phenotype()

    def reproduce(self, rate, mutation_rate, mutation_magnitude):
        """Return a mutated offspring if fitness exceeds ``1 - rate``, else None.

        The offspring receives a *copy* of this genotype; previously the
        tensor was shared, so ``mutate`` on the child silently altered the
        parent (and every other organism in the same lineage).
        """
        if self.fitness > (1 - rate):
            descendant = Organism(genotype=self.genotype.clone())
            descendant.mutate(mutation_rate, mutation_magnitude)
            return descendant
        return None


class Env(object):
    """A population of ``Organism``s evolving toward a drifting optimal phenotype.

    Each ``step`` performs the following, in order:

    1. Mutate every organism.
    2. Keep the ``n_organisms`` fittest.
    3. Let each survivor whose fitness exceeds ``1 - reproduction_rate`` add
       one mutated offspring.
    4. Move the optimum by a Gaussian step.
    """

    def __init__(self, n_organisms, n_genes, optimal_phenotype, selection_coefficient, mutation_rate, mutation_magnitude, reproduction_rate,
                 optimal_phenotype_mean, optimal_phenotype_std):
        self.n_organisms = n_organisms
        self.n_genes = n_genes
        self.mutation_rate = mutation_rate
        self.mutation_magnitude = mutation_magnitude
        self.optimal_phenotype_mean = optimal_phenotype_mean
        self.optimal_phenotype_std = optimal_phenotype_std
        self.reproduction_rate = reproduction_rate
        self.organisms = [Organism(n_genes=self.n_genes) for _ in range(n_organisms)]
        self.optimal_phenotype = torch.as_tensor(optimal_phenotype)
        self.selection_coefficient = selection_coefficient

    def update_optimal_phenotype(self):
        """Move the optimum by an independent Gaussian step in every dimension.

        Previously a single scalar sample was broadcast to all dimensions, so
        the optimum could only drift along the all-ones diagonal.
        """
        drift = torch.normal(
            self.optimal_phenotype_mean,
            self.optimal_phenotype_std,
            size=self.optimal_phenotype.shape,
        )
        # Rebind rather than ``+=`` so the tensor passed to __init__ is never mutated.
        self.optimal_phenotype = self.optimal_phenotype + drift

    def step(self):
        """Advance the simulation by one generation."""
        # Mutation
        for organism in self.organisms:
            organism.mutate(self.mutation_rate, self.mutation_magnitude)

        # Selection: fittest first, then truncate to the carrying capacity.
        # (An ascending sort here would keep the *least* fit organisms.)
        self.organisms.sort(
            key=lambda organism: organism.update_fitness(self.optimal_phenotype, self.selection_coefficient),
            reverse=True,
        )
        del self.organisms[self.n_organisms:]

        # Reproduction: iterate over a snapshot so offspring appended below are
        # not visited themselves (which never terminates if reproduction_rate > 1).
        survivors = list(self.organisms)
        for organism in survivors:
            descendant = organism.reproduce(self.reproduction_rate, self.mutation_rate, self.mutation_magnitude)
            if descendant is not None:
                self.organisms.append(descendant)

        # Update environment
        self.update_optimal_phenotype()

    def _snapshot(self):
        """Project current phenotypes and the optimum onto two PCA components.

        PCA is refit on every call, so axes are not comparable across frames.
        """
        phenotypes = torch.stack([organism.get_phenotype() for organism in self.organisms]).numpy()
        pca = PCA(n_components=2)
        projected = pca.fit_transform(phenotypes)
        optimum = pca.transform(self.optimal_phenotype.reshape((1, -1)).numpy())
        return projected, optimum

    def run(self, n_steps, logging_interval):
        """Simulate ``n_steps`` generations and return the animation as HTML video.

        A frame is recorded after every step whose index is a multiple of
        ``logging_interval``. Encoding goes through
        ``Animation.to_html5_video``, which relies on matplotlib's configured
        movie writer.
        """
        fig, ax = plt.subplots()
        ax.set_xlim(-10, 10)
        ax.set_ylim(-10, 10)

        plots = []

        def animate(i):
            phenotypes, optimal_phenotype = plots[i]
            ax.clear()
            ax.scatter(phenotypes[:, 0], phenotypes[:, 1])
            ax.scatter(optimal_phenotype[:, 0], optimal_phenotype[:, 1])
            # ax.clear() resets the limits, so restore them each frame.
            ax.set_xlim(-10, 10)
            ax.set_ylim(-10, 10)
            return ax,

        for i in range(n_steps):
            self.step()
            if i % logging_interval == 0:
                plots.append(self._snapshot())

        anim = animation.FuncAnimation(fig, animate, repeat=True, frames=len(plots), interval=100)
        video = anim.to_html5_video()
        plt.close(fig)
        return HTML(video)
                
        

In [None]:
# Run 300 generations of 100 ten-gene organisms, starting from a zero optimum,
# and record a frame every 10 generations.
env = Env(
    n_organisms=100,
    n_genes=10,
    optimal_phenotype=torch.zeros(10),
    selection_coefficient=0.5,
    mutation_rate=1.0,
    mutation_magnitude=0.05,
    reproduction_rate=0.5,
    optimal_phenotype_mean=0,
    optimal_phenotype_std=0.1,
)
env.run(300, logging_interval=10)