# Full Test of imported model

Here we run through a full test of the imported model, including loading the model, defining the input data, and running the model to get the output. We will also include some additional checks to ensure the model is working correctly. This is a hard task to train on and it's worth noting that the model will not appear to be working correctly AT ALL. 

In order to convince you that the model is actually working, try out the blinking dataset, which has neurons that are blinking at different frequencies. Furthermore if you believe that the problem is that the model does diffusion, there is a secret ```disable_diffusion``` flag that you can use to disable the diffusion process. I don't want people shooting themselves in the foot so the flag must be enabled manually after a ```DifusssionWrapper``` is instantiated.

Disabling Weights and Biases logging requires setting it to false in ```prepare_data_and_model``` and ```train_epoch```.


In [1]:
import torch
import numpy as np
import pickle

from sandman.data_loading.data_loader_map import make_map_loader
from sandman.data_loading.chaotic_rnn_loader import make_chaotic_rnn_loader
from sandman.data_loading.blinking import make_blinking_toy_loader
from sandman.data_loading.data_loader_ibl import make_ibl_loader
from sandman.paths import DATA_DIR

from sandman.models.training import train_epoch, prepare_data_and_model

from sandman.models.utils import TargetSpec, MaskingPolicy
from sandman.models.utils import SpikeCountPoissonTarget, SpikeCountMSETarget, mask_one_region, mask_one_region_some_times

from sandman.models.sampling import sample_region_latent_ddpm
from diffusers import DDPMScheduler

### Dataset Loading
Only uncomment one!

In [2]:
#========================== Import of MAP data loading

# session_order = pickle.load(open(DATA_DIR / "tables_and_infos/session_order.pkl", "rb"))
# eids = np.sort(session_order[:5]) # originall :40
# print("Using eids:", eids)

# data_loader, num_neurons, datasets, areaoi_ind, area_ind_list_list, heldout_info_list, trial_type_dict = make_map_loader(
#     eids,
#     batch_size=2,
#     include_opto=False
# )

# ========================== Import of IBL data loading

# with open(DATA_DIR / "tables_and_infos/ibl_eids.txt") as file:
#     eids = [line.rstrip() for line in file]

# eids = eids[:2]
# data_loader, num_neurons, datasets, areaoi_ind, area_ind_list_list, heldout_info_list, trial_type_dict = make_ibl_loader(
#     eids,
#     batch_size=1)

#========================== Import of Chaotic RNN data loading

# eids = np.arange(5)
# data_loader, num_neurons, _, area_ind_list_list, record_info_list = make_chaotic_rnn_loader(
#     eids,
#     batch_size=1,
# )

#========================== Import of Blinking Toy data loading

data_loader = make_blinking_toy_loader(T=100, batch_size=1, device="mps")



### Preview of the Dataset

In [3]:
sample_batch = next(iter(data_loader['train']))
for key, value in sample_batch.items():
    print(f"{key}: {value.shape}")
    print(f"The first (at most) 10 elements along each dimension of the {key} tensor are:")
    new_shape = tuple(slice(0, min(dim, 10)) for dim in value.shape)
    print(value[new_shape])
    print()

spikes_data: torch.Size([1, 100, 8])
The first (at most) 10 elements along each dimension of the spikes_data tensor are:
tensor([[[1., 0., 0., 1., 1., 0., 0., 1.],
         [0., 1., 0., 1., 0., 1., 0., 1.],
         [1., 0., 1., 0., 1., 0., 1., 0.],
         [0., 1., 1., 0., 0., 1., 1., 0.],
         [1., 0., 0., 1., 1., 0., 0., 1.],
         [0., 1., 0., 1., 0., 1., 0., 1.],
         [1., 0., 1., 0., 1., 0., 1., 0.],
         [0., 1., 1., 0., 0., 1., 1., 0.],
         [1., 0., 0., 1., 1., 0., 0., 1.],
         [0., 1., 0., 1., 0., 1., 0., 1.]]], device='mps:0')

neuron_regions: torch.Size([1, 8])
The first (at most) 10 elements along each dimension of the neuron_regions tensor are:
tensor([[0, 0, 1, 1, 2, 2, 2, 2]], device='mps:0')

eid: torch.Size([1])
The first (at most) 10 elements along each dimension of the eid tensor are:
tensor([0], device='mps:0')



### Training/Model Objective!

I'm proud of these. This is a generalization of the neural activation modelling process. Play around with the multitude I made in the ```sandman.models.utils``` module. Make your own too! They only take a couple of lines to make!

In [4]:
target = SpikeCountMSETarget()
masking_policy = mask_one_region

### Model Architecture

Good luck! It's tough to make something that works well. You are going to have a lot of ```n_neurons x d_model``` matrices for encoding and decoding so parameter count is mostly dependent on ```d_model```. Don't let it drop below 16 or it will fundementally break RoPE embeddings.

In [5]:
model_args = {
    "d_model": 64,
    "n_layers": 8,
    "n_heads": 16,
}

noise_scheduler_args = {
    "num_train_timesteps": 100,
    "beta_start": 1e-4,
    "beta_end": 0.02,
    "beta_schedule": "linear",
}

diffusion, optimizer, lr_scheduler, accelerator, data_loader = prepare_data_and_model(
    data_loader=data_loader,
    model_args=model_args,
    noise_scheduler_args=noise_scheduler_args,
    max_epochs=100,
    lr=3e-4,
    weight_decay=1e-4,
    target_spec=target,
    masking_policy=masking_policy,
    reconstruct_loss_weight=0.0,
)

[34m[1mwandb[0m: Network error (HTTPError), entering retry loop.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


Model has 0.60 million parameters.


In [6]:
sample_batch = next(iter(data_loader["train"]))
print(sample_batch['spikes_data'].shape)

torch.Size([1, 100, 8])


In [None]:
num_epochs = 20

for epoch in range(num_epochs):
    train_loss = train_epoch(
        loader=data_loader["train"],
        diffusion=diffusion,
        optimizer=optimizer,
        accelerator=accelerator,
        epoch=epoch,
        log_every=5 # Lots of logging for debugging purposes
    )
    print(f"Epoch {epoch} train loss: {train_loss}")

  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 0 train loss: 1.9944816876649856


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 1 train loss: 1.9946530561447144


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 2 train loss: 1.9924277589321135


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 3 train loss: 1.9936228736639023


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 4 train loss: 1.9915680581331252


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 5 train loss: 1.9917154278755187


  0%|          | 0/1000 [00:00<?, ?it/s]

### Inference
This is where you start to wonder what you set reconstruction loss to.

In [8]:
pred_full, x0_latent = sample_region_latent_ddpm(
    model=diffusion.model,
    masking_policy=masking_policy,
    scheduler=DDPMScheduler(
        num_train_timesteps=1000,
        beta_schedule="linear",  # or cosine if you trained with that
    ),
    batch=sample_batch,
    num_inference_steps=20,
    device="mps",
)

In [9]:
softplus = torch.nn.Softplus()
pred = softplus(pred_full)
print(pred>1)

tensor([[[False, False, False, False, False, False, False, False],
         [False,  True, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [False,  True, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [False,  True, False, False, False, False, False, False],
         [ True, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, False],
         [False, False, False, False, False, False, False, Fal

## Analysis:

The above result is completely amazing. I trained for 10 epochs, then turn reconstruction to 1 and trained for another 2 epochs. The model clearly learned to reconstruct the input data, but is failing on one region. I must check if this is a problem with the masking or the model architecture, or the training