In [1]:
import torch

L = 5
N = 1000
C = 10
J_D = 0.2
LAMBDA_LEFT = [0, 2.0, 2.0, 2.0, 2.0, 1.0]
LAMBDA_RIGHT = [4.0, 4.0, 4.0, 4.0, 1.0, 4.0]
DEVICE = "cpu"
SEED = 1232
lr = torch.tensor([0.1, 0.1, 0.1, 0.1, 0.1, 0.01, 0.01])
threshold = torch.tensor([2.5, 2.5, 2.5, 2.5, 2.5, 2.5])
weight_decay = torch.tensor([0.001, 0.001, 0.001, 0.001, 0.001, 0.0, 0.0])

In [2]:
from src.batch_me_if_u_can import BatchMeIfUCan
from src.handler import Handler

new = BatchMeIfUCan(
    num_layers=L,
    N=N,
    C=C,
    J_D=J_D,
    lambda_left=LAMBDA_LEFT,
    lambda_right=LAMBDA_RIGHT,
    device=DEVICE,
    seed=SEED,
    lr=lr,
    threshold=threshold,
    weight_decay=weight_decay,
)
handler = Handler(new)

In [3]:
from src.classifier import Classifier

old = Classifier(
    num_layers=L,
    N=N,
    C=C,
    J_D=J_D,
    lambda_left=LAMBDA_LEFT,
    lambda_right=LAMBDA_RIGHT,
    device=DEVICE,
    seed=SEED,
)

In [4]:
from src.data import prepare_mnist

P = 10
P_eval = 10
binarize = True

train_inputs, train_targets, eval_inputs, eval_targets, projection_matrix = (
    prepare_mnist(
        P * C,
        P_eval * C,
        N,
        binarize,
        SEED,
        shuffle=True,
    )
)

In [5]:
B = 16
x = train_inputs[:B]
y = train_targets[:B]

non_diagonal_mask = torch.ones((N, N), dtype=torch.bool)
non_diagonal_mask.fill_diagonal_(False)

tensor([[False,  True,  True,  ...,  True,  True,  True],
        [ True, False,  True,  ...,  True,  True,  True],
        [ True,  True, False,  ...,  True,  True,  True],
        ...,
        [ True,  True,  True,  ..., False,  True,  True],
        [ True,  True,  True,  ...,  True, False,  True],
        [ True,  True,  True,  ...,  True,  True, False]])

In [6]:
old_coup = new.couplings.clone()

# for idx in range(new.L):
#     new.couplings[idx, :, N : 2 * N] = old.couplings[idx, :, :].clone()
#     new.couplings[idx, :, :N][non_diagonal_mask] = 0
#     if idx != new.L - 1:
#         new.couplings[idx, :, 2 * N : 3 * N][non_diagonal_mask] = 0
new.couplings[-2, :, 2 * N : 2 * N + C] = old.W_back.T.clone()
new.couplings[-1, :C, :N] = old.W_forth.clone()

In [6]:
epochs = 8
new_train_acc_history, new_eval_acc_history, new_representations = handler.train_loop(
    epochs,
    train_inputs,
    train_targets,
    5,
    B,
    1,
    eval_inputs,
    eval_targets,
)
old_train_acc_history, old_eval_acc_history, old_representations = old.train_loop(
    epochs,
    train_inputs,
    train_targets,
    5,
    lr,
    threshold,
    weight_decay,
    B,
    1,
    eval_inputs,
    eval_targets,
)

0.692062497138977
0.4657750129699707
0.3788124918937683
0.3241249918937683
0.3245750069618225
0.3207874894142151
0.303849995136261
0.2727625072002411
0.28404998779296875
0.2829124927520752
0.2865374982357025
0.29413750767707825
0.2879124879837036
0.2770499885082245
0.24018749594688416
0.26365000009536743
0.26637500524520874
0.2770000100135803
0.2695874869823456
0.2530750036239624
0.27344998717308044
0.2442374974489212
0.24742500483989716
0.24501250684261322
0.23623749613761902
0.23742499947547913
0.23999999463558197
0.23395000398159027
0.22152499854564667
0.21359999477863312
0.21236249804496765
0.22027499973773956
0.22163750231266022
0.22644999623298645
0.23485000431537628
0.1992875039577484
0.19876250624656677
0.19853749871253967
0.19636249542236328
0.18952499330043793
0.2027125060558319
0.17970000207424164
0.17361250519752502
0.17560000717639923
0.17333750426769257
0.17499999701976776
0.1833374947309494
0.16967499256134033
0.15809999406337738
0.14955000579357147
0.14436249434947968
0

In [7]:
new_train_acc_history, old_train_acc_history

([0.23000000417232513,
  0.3700000047683716,
  0.5,
  0.5299999713897705,
  0.6700000166893005,
  0.6299999952316284,
  0.7200000286102295,
  0.8100000023841858],
 [0.23999999463558197,
  0.36000001430511475,
  0.5600000023841858,
  0.6100000143051147,
  0.7400000095367432,
  0.7400000095367432,
  0.8100000023841858,
  0.8199999928474426])

In [7]:
def check_equality_couplings(new, old):
    for idx in range(new.L):
        assert torch.allclose(
            new.internal_couplings[idx],
            old.couplings[idx, :, :],
            atol=1e-5,
        )
    assert torch.allclose(
        new.W_back,
        old.W_back.T,
        atol=1e-5,
    )
    assert torch.allclose(
        new.W_forth,
        old.W_forth,
        atol=1e-5,
    )

In [6]:
old.generator.manual_seed(0)
new.generator.manual_seed(0)

<torch._C.Generator at 0x111403f90>

In [7]:
sweeps_old, updates_old = old.train_step(x, y, 5, lr, threshold, weight_decay)
sweeps_new = new.train_step(x, y, 5)

0.6969500184059143


In [8]:
old_logits, _, _ = old.inference(x, 5)
new_logits, _, _ = new.inference(x, 5)
torch.allclose(old_logits, new_logits, atol=1e-5)

False

In [9]:
handler.train_epoch(
    train_inputs,
    train_targets,
    5,
    B,
)
old.train_epoch(
    train_inputs,
    train_targets,
    5,
    lr,
    threshold,
    weight_decay,
    B,
)

0.44362500309944153
0.382099986076355
0.3532249927520752
0.32231250405311584
0.3210124969482422
0.3241625130176544
0.31095001101493835


([5, 5, 5, 5, 5, 5, 5],
 [0.44362500309944153,
  0.382099986076355,
  0.3532249927520752,
  0.32231250405311584,
  0.3210124969482422,
  0.3241625130176544,
  0.31095001101493835])

In [12]:
check_equality_couplings(new, old)

AssertionError: 

In [13]:
check_equality_couplings(new, old)

AssertionError: 

In [8]:
new_train_acc_history, old_train_acc_history

([0.11999999731779099, 0.15000000596046448, 0.20999999344348907],
 [0.23999999463558197, 0.36000001430511475, 0.5600000023841858])

## State Initialization

In [5]:
import numpy as np

B = 2
i = np.random.randint(0, len(train_inputs) - B)
x = train_inputs[i : i + B]
y = train_targets[i : i + B]
new_state = new.initialize_state(B, x, y)
old_state, old_readout = old.initialize_neurons_state(B, x)

## Set couplings and states to be equal

In [6]:
new_state[:, 1:-2, :] = old_state.permute(1, 0, 2)
new_state[:, -2, :C] = old_readout

In [7]:
non_diagonal_mask = torch.ones((N, N), dtype=torch.bool)
non_diagonal_mask.fill_diagonal_(False)

for idx in range(new.L):
    new.couplings[idx, :, N : 2 * N] = old.couplings[idx, :, :]
    new.couplings[idx, :, :N][non_diagonal_mask] = 0
    if idx != new.L - 1:
        new.couplings[idx, :, 2 * N : 3 * N][non_diagonal_mask] = 0
new.couplings[-2, :, 2 * N : 2 * N + C] = old.W_back.T
new.couplings[-1, :C, :N] = old.W_forth

## Field Computation

In [8]:
field_new = new.fields(new_state, ignore_right=0)
field_old, readout_field_old = old.local_field(
    old_state, old_readout, ignore_right=False, x=x, y=y
)

In [9]:
different_mask = ~torch.isclose(field_new[0, :-1], field_old[:, 0])
field_new[0, :-1][different_mask]

tensor([-0.0026])

In [10]:
field_old[:, 0][different_mask]

tensor([-0.0026])

In [11]:
field_new[:, -1, :C]

tensor([[-3.2000, -4.4000, -3.8000, -5.0000, -2.4000, -5.0000, -4.2000, -3.6000,
          3.2000, -3.4000],
        [-4.0000,  1.2000, -2.6000, -3.8000, -4.0000, -4.2000, -5.4000, -3.2000,
         -4.0000, -5.8000]])

In [12]:
readout_field_old

tensor([[-3.2000, -4.4000, -3.8000, -5.0000, -2.4000, -5.0000, -4.2000, -3.6000,
          3.2000, -3.4000],
        [-4.0000,  1.2000, -2.6000, -3.8000, -4.0000, -4.2000, -5.4000, -3.2000,
         -4.0000, -5.8000]])

In [13]:
torch.all(old.W_forth == new.couplings[-1, :C, :N])

tensor(True)

In [14]:
new_state[:, -3]

tensor([[ 1., -1., -1., -1., -1.,  1., -1.,  1.,  1.,  1., -1., -1.,  1., -1.,
         -1.,  1.,  1., -1.,  1., -1.,  1.,  1., -1.,  1., -1., -1.,  1., -1.,
          1., -1., -1., -1., -1., -1.,  1., -1.,  1.,  1., -1., -1., -1.,  1.,
         -1.,  1., -1.,  1., -1., -1.,  1., -1.,  1.,  1., -1., -1., -1., -1.,
         -1., -1., -1., -1.,  1., -1., -1.,  1., -1.,  1.,  1., -1., -1., -1.,
         -1., -1.,  1.,  1., -1.,  1., -1., -1.,  1.,  1., -1.,  1., -1.,  1.,
         -1.,  1., -1., -1.,  1.,  1.,  1., -1.,  1., -1.,  1.,  1., -1., -1.,
          1.,  1.],
        [ 1., -1., -1.,  1., -1.,  1., -1., -1., -1.,  1.,  1.,  1.,  1., -1.,
          1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1., -1.,  1.,  1., -1., -1.,
          1.,  1.,  1.,  1., -1.,  1., -1., -1., -1.,  1.,  1., -1., -1., -1.,
          1.,  1., -1., -1., -1.,  1.,  1., -1.,  1.,  1.,  1., -1., -1.,  1.,
         -1.,  1., -1., -1.,  1., -1., -1.,  1., -1., -1., -1.,  1.,  1.,  1.,
          1., -1.,  1.,  1.,  1.

In [15]:
new_state[:, -3] @ new.couplings[-1, :C, :N].T

tensor([[ 8.0000e-01, -4.0000e-01,  2.0000e-01, -1.0000e+00,  1.6000e+00,
         -1.0000e+00, -2.0000e-01,  4.0000e-01, -8.0000e-01,  6.0000e-01],
        [ 5.9605e-08, -2.8000e+00,  1.4000e+00,  2.0000e-01, -8.9407e-08,
         -2.0000e-01, -1.4000e+00,  8.0000e-01, -1.4901e-08, -1.8000e+00]])

In [16]:
old.left_field(old_state, x)[1]

tensor([[ 8.0000e-01, -4.0000e-01,  2.0000e-01, -1.0000e+00,  1.6000e+00,
         -1.0000e+00, -2.0000e-01,  4.0000e-01, -8.0000e-01,  6.0000e-01],
        [ 5.9605e-08, -2.8000e+00,  1.4000e+00,  2.0000e-01, -8.9407e-08,
         -2.0000e-01, -1.4000e+00,  8.0000e-01, -1.4901e-08, -1.8000e+00]])

In [17]:
new.left_couplings[-1, :C, :] == old.W_forth

tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, 

In [18]:
new_state[:, -2] @ new.couplings[-1, :, N : 2 * N].T

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.]])

In [19]:
new_state[:, -1] @ new.couplings[-1, :C, 2 * N : 3 * N].T

tensor([[-4., -4., -4., -4., -4., -4., -4., -4.,  4., -4.],
        [-4.,  4., -4., -4., -4., -4., -4., -4., -4., -4.]])

In [20]:
old.right_field(old_state, old_readout, y)[1]

tensor([[-4., -4., -4., -4., -4., -4., -4., -4.,  4., -4.],
        [-4.,  4., -4., -4., -4., -4., -4., -4., -4., -4.]])

## Relaxation

In [21]:
old_final_state, old_final_readout, _ = old.relax(
    old_state, old_readout, x, y, 5, ignore_right=False
)

In [22]:
new_final_state, _ = new.relax(new_state, 5, ignore_right=0)

In [23]:
(
    torch.all(old_final_state[:, 0, :] == new_final_state[0, 1:-2, :]),
    torch.all(old_final_readout == new_final_state[0, -2, :C]),
)

(tensor(True), tensor(False))

## Perceptron Rule

In [None]:
old.perceptron_rule_update(
    old_final_state,
    old_final_readout,
    x,
    y,
    lr=lr,
    threshold=threshold,
    weight_decay=weight_decay,
)

In [None]:
new.perceptron_rule(new_final_state)

In [None]:
idx = 3
torch.all(old.couplings[idx, :, :] == new.couplings[idx, :, N : 2 * N])

In [None]:
for idx in range(new.L):
    print(torch.all(new.couplings[idx, :, N : 2 * N] == old.couplings[idx, :, :]))
    # print(torch.all(new.couplings[idx, :, :N][non_diagonal_mask] == 0))
    # if idx != new.L - 1:
    #     new.couplings[idx, :, 2 * N : 3 * N][non_diagonal_mask] = 0
print(torch.all(new.couplings[-2, :, 2 * N : 2 * N + C] == old.W_back.T))
print(torch.all(new.couplings[-1, :C, :N] == old.W_forth))

## Inference

In [None]:
new_logits, _, _ = new.inference(x, 5)

In [None]:
old_logits, _, _ = old.inference(x, 5)

In [None]:
new_logits

In [None]:
old_logits

## Multi-Layer

In [1]:
import torch

seed = 17
device = "mps"

# Data
N = 100
P = 600
P_eval = 300
C = 10
binarize = True

# Model
H = 400
num_layers = 2
max_steps = 4
init_mode = "zeros"
init_noise = 0.0
fc_left = True
fc_right = False
fc_input = False
symmetric_W = "buggy"
double_dynamics = False
double_update = False
use_local_ce = False
beta_ce = 10.0

# Couplings
lambda_x = 1000.0
lambda_y = 1000.0
lambda_l = 2.0
lambda_r = 2.0
lambda_wback = 5.0
lambda_internal = 1.0
lambda_fc = 1.0
J_D = 0.0

lambda_cylinder = 1000.0
lambda_wback_skip = 1.0
lambda_wforth_skip = 1.0
lr_wforth_skip = 0.1
weight_decay_wforth_skip = 0.005

# Training
num_epochs = 50
batch_size = 16
# lr_J = 0.05
# lr_W = 0.1
# threshold_hidden = 2.5
# threshold_readout = 5.0
# weight_decay_J = 0.1
# weight_decay_W = 0.02

# Evaluation
eval_interval = 1
skip_representations = True
skip_couplings = False

In [2]:
# Assemble the pieces
# lr = [lr_J] * num_layers + [lr_W] * 2
# weight_decay = [weight_decay_J] * num_layers + [weight_decay_W] * 2
# threshold = [threshold_hidden] * num_layers + [threshold_readout]
lambda_left = [lambda_x] + [lambda_l] * (num_layers - 1) + [1.0]
lambda_right = [lambda_r] * (num_layers - 1) + [lambda_wback] + [lambda_y]

lr = [
    0.03,
    0.03,
    0.0,
    0.1,
]
threshold = [
    1.0,
    1.0,
    3.0,
]
weight_decay = [
    0.005,
    0.005,
    0.0,
    0.005,
]

In [3]:
from src.data import prepare_mnist

train_inputs, train_targets, eval_inputs, eval_targets, projection_matrix = (
    prepare_mnist(
        P * C,
        P_eval * C,
        N,
        binarize,
        seed,
        shuffle=True,
    )
)
train_inputs = train_inputs.to(device)
train_targets = train_targets.to(device)
eval_inputs = eval_inputs.to(device)
eval_targets = eval_targets.to(device)

In [4]:
import os

from src.batch_me_if_u_can import BatchMeIfUCan
from src.handler import Handler

output_dir = "prova"
os.makedirs(output_dir, exist_ok=True)

model_kwargs = {
    "num_layers": num_layers,
    "N": N,
    "C": C,
    "lambda_left": lambda_left,
    "lambda_right": lambda_right,
    "lambda_internal": lambda_internal,
    "J_D": J_D,
    "device": device,
    "seed": seed,
    "lr": torch.tensor(lr),
    "threshold": torch.tensor(threshold),
    "weight_decay": torch.tensor(weight_decay),
    "init_mode": init_mode,
    "init_noise": init_noise,
    "symmetric_W": symmetric_W,
    "double_dynamics": double_dynamics,
    "double_update": double_update,
    "use_local_ce": use_local_ce,
    "beta_ce": beta_ce,
    "fc_left": fc_left,
    "fc_right": fc_right,
    "fc_input": fc_input,
    "lambda_fc": lambda_fc,
    "lambda_cylinder": lambda_cylinder,
    "lambda_wback_skip": lambda_wback_skip,
    "lambda_wforth_skip": lambda_wforth_skip,
    "lr_wforth_skip": lr_wforth_skip,
    "weight_decay_wforth_skip": weight_decay_wforth_skip,
    "H": H,
}
model_cls = BatchMeIfUCan  # noqa: F821
model = model_cls(**model_kwargs)
handler = Handler(
    model,
    init_mode,
    skip_representations,
    skip_couplings,
    "prova",
)

In [5]:
layer_idx = 1
i = 0

torch.set_printoptions(precision=1, sci_mode=True)
print("Fraction of learnable weights:", end=" ")
print(
    model.is_learnable[layer_idx][:, H * i : H * (i + 1)]
    .mean(dtype=torch.float32)
    .item()
)
print("Number of learnable weights:", end=" ")
print(
    model.is_learnable[layer_idx][:, H * i : H * (i + 1)].sum().item(), f"out of {H**2}"
)

print("Couplings:")
print(model.couplings[layer_idx][:, H * i : H * (i + 1)])
print("Learning rates:")
print(model.lr[layer_idx][:, H * i : H * (i + 1)])
print("Weight decay:")
print(model.weight_decay[layer_idx][:, H * i : H * (i + 1)])
print("Is learnable:")
print(model.is_learnable[layer_idx][:, H * i : H * (i + 1)])

Fraction of learnable weights: 0.9993749856948853
Number of learnable weights: 159900 out of 160000
Couplings:
tensor([[ 1.0e+03, -3.6e-02,  5.5e-02,  ..., -4.0e-02, -9.9e-03,  3.8e-03],
        [ 4.8e-02,  1.0e+03, -2.1e-02,  ...,  2.8e-02, -1.5e-02,  5.7e-02],
        [ 6.4e-02,  6.5e-02,  1.0e+03,  ..., -3.7e-02, -1.7e-02,  1.3e-02],
        ...,
        [ 4.3e-02,  7.7e-02, -2.1e-02,  ...,  2.0e+00, -2.8e-03, -1.6e-01],
        [ 2.8e-02, -2.5e-02, -6.4e-02,  ...,  3.5e-02,  2.0e+00, -4.0e-02],
        [-5.0e-02, -1.6e-02, -1.6e-02,  ..., -4.6e-02, -4.9e-02,  2.0e+00]],
       device='mps:0')
Learning rates:
tensor([[0.0e+00, 1.5e-03, 1.5e-03,  ..., 1.5e-03, 1.5e-03, 1.5e-03],
        [1.5e-03, 0.0e+00, 1.5e-03,  ..., 1.5e-03, 1.5e-03, 1.5e-03],
        [1.5e-03, 1.5e-03, 0.0e+00,  ..., 1.5e-03, 1.5e-03, 1.5e-03],
        ...,
        [1.5e-03, 1.5e-03, 1.5e-03,  ..., 1.5e-03, 1.5e-03, 1.5e-03],
        [1.5e-03, 1.5e-03, 1.5e-03,  ..., 1.5e-03, 1.5e-03, 1.5e-03],
        [1.5e-03,