In [1]:
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from tqdm.notebook import tqdm

from src.binary_engine import GFlowBinaryEngine
from src.continuous_engine import GFlowContinuousEngine
from src.functional import correlation_for_all_neurons, sigmoid
from src.model import FlowModel
from src.simulators import NetworkSystemSimulator
from src.utils import plot_neural_activity

In [9]:
# Load the flow model checkpoint from training/binary/exp3 and switch it to
# inference mode (eval() is the last expression, so the module repr is shown).
F_sa = FlowModel(3, 512)
# map_location="cpu" lets a checkpoint that was saved from a GPU run be loaded
# on a CPU-only machine; weights_only=True avoids unpickling arbitrary objects.
F_sa.load_state_dict(
    torch.load("training/binary/exp3/model.pt", map_location="cpu", weights_only=True)
)
F_sa.to("cpu")
F_sa.eval()

FlowModel(
  (mlp): Sequential(
    (0): Linear(in_features=9, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=512, out_features=2, bias=True)
  )
)

In [10]:
# Wrap the loaded flow model and a fresh simulator in the continuous GFlow engine on CPU.
# NOTE(review): the checkpoint was loaded from training/binary/, yet it is wrapped in
# GFlowContinuousEngine (GFlowBinaryEngine is imported but unused) — confirm intended.
device = "cpu"
simulator = NetworkSystemSimulator()
engine = GFlowContinuousEngine(F_sa, simulator, device=device)

In [11]:
# Baseline run: simulate with an identity matrix as the connectivity argument,
# starting from a random initial state drawn right after seeding with 0.
N_STEPS = 100  # number of simulation steps

torch.manual_seed(0)

matrix = torch.eye(3)
x0 = torch.rand(3)
print(f"{x0 = }")
print(f"{matrix = }")
simulator = NetworkSystemSimulator()
simulation = simulator.simulate_neurons(matrix, N_STEPS, x0)

# One subplot per row of the result (one row per neuron); x axis is the step index.
# detach() guards plotly's numpy conversion in case the result tracks gradients.
fig = make_subplots(rows=simulation.shape[0], cols=1)
for row, trace in enumerate(simulation.detach(), start=1):
    fig.add_trace(
        go.Scatter(x=list(range(len(trace))), y=trace, mode="lines", name=f"neuron {row - 1}"),
        row=row,
        col=1,
    )
fig

x0 = tensor([0.4963, 0.7682, 0.0885])
matrix = tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])


In [12]:
# Flow-model run: sample a 3x3 matrix from the engine and simulate from the same
# initial state as the seed-0 identity baseline, so only the matrix differs.
N_STEPS = 100  # number of simulation steps

torch.manual_seed(10)

# torch.empty would return uninitialised memory, which torch.manual_seed does not
# control; a zero tensor of the same (3, N_STEPS) shape keeps the cell reproducible.
# NOTE(review): confirm whether the engine is meant to receive observed activity
# (e.g. the baseline `simulation`) rather than a placeholder of this shape.
matrix = engine(torch.zeros(3, N_STEPS)).reshape(3, 3)
# Regenerate the exact seed-0 draw with a private generator (the global RNG stream is
# untouched). Hard-coding the printed 4-decimal values would perturb the initial state.
x0 = torch.rand(3, generator=torch.Generator().manual_seed(0))
print(f"{x0 = }")
print(f"{matrix = }")
simulator = NetworkSystemSimulator()
simulation2 = simulator.simulate_neurons(matrix, N_STEPS, x0).detach()

# One subplot per row of the result (one row per neuron); x axis is the step index.
fig = make_subplots(rows=simulation2.shape[0], cols=1)
for row, trace in enumerate(simulation2, start=1):
    fig.add_trace(
        go.Scatter(x=list(range(len(trace))), y=trace, mode="lines", name=f"neuron {row - 1}"),
        row=row,
        col=1,
    )
fig

x0 = tensor([0.4963, 0.7682, 0.0885])
matrix = tensor([[0.0021, 0.0021, 0.0021],
        [0.0018, 0.0018, 0.0018],
        [0.0018, 0.0021, 0.0018]], grad_fn=<ViewBackward0>)


In [13]:
# Mean squared error between the baseline and flow-model trajectories.
nn.functional.mse_loss(simulation, simulation2, reduction="mean")

tensor(0.1023)

In [15]:
# Absolute per-step error between the baseline and flow-model trajectories,
# one subplot per row (neuron). The sorted copy previously computed here was
# never used, so it has been dropped.
diff = (simulation - simulation2).abs()
fig = make_subplots(rows=diff.shape[0], cols=1)
for row, trace in enumerate(diff.detach(), start=1):
    fig.add_trace(
        go.Scatter(x=list(range(len(trace))), y=trace, mode="lines", name=f"neuron {row - 1}"),
        row=row,
        col=1,
    )
fig

In [16]:
# Total squared error between the two trajectories, summed over all elements.
torch.sum((simulation - simulation2) ** 2)

tensor(30.6800)