In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import networkx
from loguru import logger
from tqdm.notebook import tqdm

from pim.models.network import Network
from pim.models.new.stone import StoneExperiment, StoneResults
from pim.models.new.stone.rate import CXRatePontin, CPU4PontinLayer

from pim.models.stone import analysis

logger.remove()  # loguru: remove every handler (including the default stderr sink) so log output is silenced — presumably to keep the pim models quiet during long sweeps


In [None]:
# Visualise the floor quantization used by QuantizedCPU4PontinLayer: a value in
# [0, 1] is mapped to the lower edge of one of `n_levels` equal-width bins, so the
# largest representable value is (n_levels - 1) / n_levels, never 1.0.
# A single constant drives both the bin count and the divisor so they cannot drift.
n_levels = 10
x = np.linspace(0, 1, 100)
bins = np.linspace(0, 1, n_levels, endpoint=False)
print(bins)
y = (np.digitize(x, bins) - 1) / n_levels
plt.plot(x, x, label="identity")
plt.plot(x, y, label=f"quantized ({n_levels} levels)")
plt.xlabel("input value")
plt.ylabel("quantized value")
plt.legend();

In [None]:
def create_quantized_layer(N):
    """Return a callable that constructs ``QuantizedCPU4PontinLayer(N, ...)``.

    The returned factory prepends ``N`` to the positional arguments it receives
    and forwards all keyword arguments untouched. It can therefore be handed to
    ``create_experiment`` wherever a CPU4 layer class is expected.
    """
    resolution = N
    return lambda *layer_args, **layer_kwargs: QuantizedCPU4PontinLayer(
        resolution, *layer_args, **layer_kwargs
    )

class QuantizedCPU4PontinLayer(CPU4PontinLayer):
    """CPU4 (Pontin) memory layer whose memory is quantized after every update.

    After the usual CPU4 memory update, every memory value is floored onto a
    grid of ``N`` equally spaced levels ``0, 1/N, ..., (N-1)/N``. Values below 0
    map to 0, and values at or above ``(N-1)/N`` map to ``(N-1)/N``, so 1.0 itself
    is never stored.

    Args:
        N: number of quantization levels; must be >= 1.
        *args, **kwargs: forwarded unchanged to ``CPU4PontinLayer``.

    Raises:
        ValueError: if ``N < 1`` (there would be no bins and the level
            division would divide by zero).
    """

    def __init__(self, N, *args, **kwargs):
        if N < 1:
            raise ValueError(f"N must be a positive number of levels, got {N!r}")
        super().__init__(*args, **kwargs)
        self.N = N
        # Lower edges of the N bins: [0, 1/N, ..., (N-1)/N].
        self.bins = np.linspace(0, 1, self.N, endpoint=False)

    def _quantize(self, values):
        """Floor ``values`` onto the level grid defined by ``self.bins``."""
        # np.digitize returns i with bins[i-1] <= v < bins[i]; subtracting 1 yields
        # the bin index (-1 for v < 0, which the clip maps back to 0).
        return np.clip((np.digitize(values, self.bins) - 1) / self.N, 0.0, 1.0)

    def step(self, network: Network, dt: float):
        """Memory neurons update, followed by quantization of the memory.

        cpu4[0-7] store optic flow peaking at left 45 deg
        cpu4[8-15] store optic flow peaking at right 45 deg.

        Args:
            network: network whose TB1 and TN2 outputs drive the update.
            dt: time step; scales the TN2 drive and the constant decay.
        """
        tb1 = network.output(self.TB1)
        tn2 = network.output(self.TN2) * dt

        mem_update = np.dot(self.W_TN, tn2)
        # NOTE(review): the TB1 inhibition is not scaled by dt while the TN2 drive
        # and the decay below are -- confirm this matches CPU4PontinLayer.step.
        mem_update -= np.dot(self.W_TB1, tb1)
        mem_update = np.clip(mem_update, 0, 1)
        mem_update *= self.gain
        self.memory += mem_update
        self.memory -= 0.125 * self.gain * dt
        # Snap memory to the N-level grid every step. Because this is a floor,
        # net increments smaller than one level (1/N) can be lost entirely.
        self.memory = self._quantize(self.memory)

In [None]:
parameters = {
    "model": "stone",
    "T_outbound": 1500,
    "T_inbound": 1000,
    "time_subdivision": 1,
    "noise": 0.1,
    "cx": "pontin"
}

def create_experiment(cpu4, time_subdivision=None):
    """Build a StoneExperiment driven by a Pontin CX that uses ``cpu4`` for CPU4 memory.

    Args:
        cpu4: layer class (or a factory such as ``create_quantized_layer(N)``)
            passed to ``CXRatePontin`` as ``CPU4LayerClass``.
        time_subdivision: if given, overrides ``parameters["time_subdivision"]``
            for this experiment only.

    Returns:
        A ``StoneExperiment`` whose ``cx`` has already been set up.

    Note:
        No random seed is set here. If the CX noise is random, repeated calls
        will give different trajectories.
    """
    # Work on a per-run copy so overrides never leak into the shared `parameters`
    # dict, and so the experiment sees the override from construction onward.
    experiment_parameters = dict(parameters)
    if time_subdivision is not None:
        experiment_parameters["time_subdivision"] = time_subdivision
    cx = CXRatePontin(CPU4LayerClass=cpu4, noise=experiment_parameters["noise"])
    cx.setup()
    experiment = StoneExperiment(experiment_parameters)
    experiment.cx = cx
    return experiment

def run_experiment(cpu4, N = 0, ts = 1, report = False):
    """Run one trial and return the norm of the closest position reached.

    Args:
        cpu4: CPU4 layer class or factory (see ``create_experiment``).
        N: unused; kept so existing calls that pass it keep working. The
            quantization resolution is chosen via ``create_quantized_layer(N)``.
        ts: value used for ``"time_subdivision"`` in this run.
        report: if True, call ``results.report()`` after the run.

    Returns:
        Euclidean norm of ``results.closest_position()``, i.e. its distance
        from the origin (presumably the nest).
    """
    experiment = create_experiment(cpu4, time_subdivision=ts)
    results = experiment.run("test")
    if report:
        results.report()
    return np.linalg.norm(results.closest_position())


In [None]:
# Sanity check: a single reported run of the unquantized baseline layer with
# time_subdivision=10. (`run_experiment` ignores its `N` argument, so none is passed.)
run_experiment(CPU4PontinLayer, ts=10, report=True)

In [None]:
# Baseline: average closest-approach distance of the unquantized CPU4 layer over
# repeated runs (default time_subdivision of 1).
benchmark_distances = [run_experiment(CPU4PontinLayer) for _ in tqdm(range(10))]
mean_benchmark = np.mean(benchmark_distances)
print(f"Benchmark mean: {mean_benchmark}")

In [None]:
# Sweep the quantization resolution N for several time subdivisions.
# NOTE: 999 resolutions x 3 subdivisions = ~3000 full experiment runs; this cell is slow.
Ns = range(10, 10000, 10)
TIME_SUBDIVISIONS = (1, 2, 10)
sweep_results = {
    ts: [run_experiment(create_quantized_layer(N), ts=ts) for N in tqdm(Ns, desc=f"ts={ts}")]
    for ts in TIME_SUBDIVISIONS
}
# Keep the per-subdivision names used by later cells.
results1, results2, results10 = (sweep_results[ts] for ts in TIME_SUBDIVISIONS)

In [None]:
# Plot every sweep (not only ts=1) against the unquantized benchmark mean.
sweep_by_ts = {"1": results1, "2": results2, "10": results10}
fig = px.scatter(
    x=[N for _ in sweep_by_ts for N in Ns],
    y=[distance for series in sweep_by_ts.values() for distance in series],
    color=[label for label, series in sweep_by_ts.items() for _ in series],
    labels={"x": "resolution", "y": "smallest distance from nest", "color": "time subdivision"},
)
fig.add_hline(y=mean_benchmark, line_dash="dash", annotation_text="unquantized benchmark mean")
fig