In [38]:
import torch 
import numpy as np
import pickle 
import os
from src.data import prepare_mnist
model_path = "/Users/willinki/GIT/Biological-Learning/outputs/prova/long_epochs/model"
models = os.listdir(model_path)

def load_model(epoch: int):
    model_file = os.path.join(model_path, f"model_epoch_{epoch}.pkl")
    if not os.path.exists(model_file):
        raise FileNotFoundError(f"Model file {model_file} does not exist.")
    with open(model_file, 'rb') as f:
        model = pickle.load(f)
    return model

def relaxation_trajectory(classifier, x, y, max_steps, state=None):
    states = []
    unsats = []
    if state is None:
        state = classifier.initialize_state(x, y, "zeros")
    for step in range(max_steps):
        state, _, unsat = classifier.relax(
            state,
            max_steps=1,
            ignore_right=0,
        )
        states.append(state.clone())
        unsats.append(unsat.clone())
    for step in range(max_steps):
        state, _, unsat = classifier.relax(
            state,
            max_steps=1,
            ignore_right=1,
        )
        states.append(state.clone())
        unsats.append(unsat.clone())
    states = torch.stack(states, dim=0)  # T, B, L, N
    states = states.permute(1, 0, 2, 3)  # B, T, L, N
    unsats = torch.stack(unsats, dim=0)  # T, B, L, N
    unsats = unsats.permute(1, 0, 2, 3)  # B, T, L, N
    return states, unsats

def pairwise_overlap(state_1, state_2):
    """
    Computes the pairwise overlap between two states.
    """
    state_1, state_2 = state_1.to(torch.float16), state_2.to(torch.float16)
    overlaps = (state_1 * state_2).sum(dim=-1) / state_1.shape[-1]
    return overlaps

P = 100
C = 10
P_eval = 100
N = 100
binarize = True
seed = 17
device="mps"
train_inputs, train_targets, eval_inputs, eval_targets, projection_matrix = (
    prepare_mnist(
        P * C,
        P_eval * C,
        N,
        binarize,
        seed,
        shuffle=True,
    )
)
train_inputs = train_inputs.to(device)
train_targets = train_targets.to(device)
eval_inputs = eval_inputs.to(device)
eval_targets = eval_targets.to(device)
print("found models: ", models)

found models:  ['model_epoch_17.pkl', 'model_epoch_16.pkl', 'model_epoch_14.pkl', 'model_epoch_15.pkl', 'model_epoch_11.pkl', 'model_epoch_10.pkl', 'model_epoch_9.pkl', 'model_epoch_12.pkl', 'model_epoch_13.pkl', 'model_epoch_8.pkl', 'model_epoch_5.pkl', 'model_epoch_4.pkl', 'model_epoch_6.pkl', 'model_epoch_7.pkl', 'model_epoch_3.pkl', 'model_epoch_18.pkl', 'model_epoch_19.pkl', 'model_epoch_2.pkl', 'model_epoch_0.pkl', 'model_epoch_1.pkl', 'model_epoch_-1.pkl']


In [113]:
classifier = load_model(16)

In [None]:
import seaborn as sns
from src.handler import Handler
from itertools import combinations
from typing import Dict
import pandas as pd

def compute_overlap_evolution(states, times) -> Dict[str, torch.Tensor]:
    # data, time, state
    overlaps_stats = {}
    for time1, time2 in combinations(times, 2):
        state_1 = states[:, time1, :]
        state_2 = states[:, time2, :]
        overlaps = (state_1 * state_2).sum(dim=-1) / state_1.shape[-1]
        overlaps_stats[f"{time1}-{time2}"] = overlaps
    return overlaps_stats

def plot_overlap_from_key(overlaps_stats, key):
    xy = [
        (
            float(k.split('-')[1]),
            overlaps_stats[k]
        ) 
        for k in overlaps_stats.keys() if k.startswith(key)
    ]
    print(len(xy), "overlaps found for key", key)
    x = [item[0] for item in xy]
    y = [item[1].mean().item() for item in xy]
    y_err = [item[1].std().item() for item in xy]
    return x, y, y_err

def table_overlap_evolution(overlaps_stats, keys):
    rows = []
    for key in keys:
        x, y, y_err = plot_overlap_from_key(overlaps_stats, key)
        for xi, yi, yerri in zip(x, y, y_err):
            rows.append({'key': f"{key}-{int(xi)}", 'y': yi, 'y_err': yerri})
    df_table = pd.DataFrame(rows)
    display(df_table)
     

num_layers = 1
states, unsats = relaxation_trajectory(
    classifier,
    train_inputs,
    train_targets,
    max_steps=10,
)

In [115]:
states_1 = states[:, :, 1, :]  # B, T, L, N

In [120]:
overlap_stats = compute_overlap_evolution(states_1, [0, 4, 9, 14, 19])

In [121]:
table_overlap_evolution(overlap_stats, ["0-9", "9-19"])

1 overlaps found for key 0-9
1 overlaps found for key 9-19


Unnamed: 0,key,y,y_err
0,0-9-9,0.939308,0.025686
1,9-19-19,1.0,0.0
