# Full Test of imported model

Here we run through a full test of the imported model: loading the data, building the model, training it, sampling from it, and saving it. We also include some additional checks to ensure the model is working correctly. This is a hard task to train on, and it is worth noting up front that the model will not appear to be working correctly AT ALL.

To convince yourself that the model is actually working, try the blinking dataset, in which neurons blink at different frequencies. Furthermore, if you suspect the diffusion process itself is the problem, there is a secret ```disable_diffusion``` flag that turns it off. I don't want people shooting themselves in the foot, so the flag must be enabled manually after a ```DifusssionWrapper``` is instantiated.
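A blinking raster is a useful sanity check because it is perfectly periodic, so a working model's samples can be judged by eye. The sketch below is a guess at what such data can look like and assumes nothing about how `make_blinking_toy_loader` actually builds it: the function `make_blinking_spikes`, the periods, and the on/off rule are all hypothetical. The 8 neurons over 3 regions match `num_neurons = [2, 2, 4]` in the blinking option further down.

```python
import numpy as np

def make_blinking_spikes(T=100, periods=(2, 3, 4, 5, 6, 7, 8, 9)):
    """Hypothetical blinking raster: neuron i is "on" (1) while t % periods[i]
    falls in the first half of its period and "off" (0) otherwise.
    Returns a float32 array of shape [T, N]."""
    t = np.arange(T)[:, None]              # [T, 1] time bins
    p = np.asarray(periods)[None, :]       # [1, N] one period per neuron
    on = (t % p) < (p / 2)                 # [T, N] square wave per neuron
    return on.astype(np.float32)

# Eight neurons split over three regions, like num_neurons = [2, 2, 4]
toy_num_neurons = [2, 2, 4]
spikes = make_blinking_spikes()
region_ids = np.repeat(np.arange(len(toy_num_neurons)), toy_num_neurons)  # [N]
print(spikes.shape, region_ids)
```

Neuron 0 alternates every bin while neuron 7 stays on for several bins in a row, so each neuron has a distinct, easily recognizable frequency.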

To disable Weights & Biases logging, pass ```wandb_enabled=False``` to both ```prepare_data_and_model``` and ```train_epoch```.


In [1]:
import torch
import numpy as np
import pickle

from sandman.data_loading.data_loader_map import make_map_loader
from sandman.data_loading.chaotic_rnn_loader import make_chaotic_rnn_loader
from sandman.data_loading.blinking import make_blinking_toy_loader
from sandman.data_loading.data_loader_ibl import make_ibl_loader
from sandman.paths import DATA_DIR

from sandman.models.training import train_epoch, prepare_data_and_model

from sandman.models.utils import TargetSpec, MaskingPolicy
from sandman.models.utils import SpikeCountPoissonTarget, SpikeCountMSETarget, mask_one_region, mask_one_region_some_times, no_mask_policy

from diffusers import DDPMScheduler

### Dataset Loading
Only uncomment one!

In [2]:
#========================== Import of MAP data loading

# session_order = pickle.load(open(DATA_DIR / "tables_and_infos/session_order.pkl", "rb"))
# eids = np.sort(session_order[:5]) # originally :40
# print("Using eids:", eids)

# data_loader, num_neurons, datasets, areaoi_ind, area_ind_list_list, heldout_info_list, trial_type_dict = make_map_loader(
#     eids,
#     batch_size=2,
#     include_opto=False
# )

# ========================== Import of IBL data loading

# with open(DATA_DIR / "tables_and_infos/ibl_eids.txt") as file:
#     eids = [line.rstrip() for line in file]

# eids = eids[:2]
# data_loader, num_neurons, datasets, areaoi_ind, area_ind_list_list, heldout_info_list, trial_type_dict = make_ibl_loader(
#     eids,
#     batch_size=1)

#========================== Import of Chaotic RNN data loading

eids = np.arange(5)
data_loader, num_neurons, _, area_ind_list_list, record_info_list = make_chaotic_rnn_loader(
    eids,
    batch_size=1,
)

#========================== Import of Blinking Toy data loading

# data_loader = make_blinking_toy_loader(T=100, batch_size=1, device="mps")
# num_neurons = [2, 2, 4]

#========================== Print number of neurons

print("Number of neurons:", num_neurons)



Loading existing data session  0
Loading existing data session  1
Loading existing data session  2
Loading existing data session  3
Loading existing data session  4
num_neurons:  [174, 157, 104, 112, 138]
num_trials:  {'train': [167, 165, 147, 136, 172], 'val': [56, 55, 49, 46, 58], 'test': [56, 55, 49, 46, 58]}
Succesfully constructing the dataloader for  train
Succesfully constructing the dataloader for  val
Succesfully constructing the dataloader for  test
Number of neurons: [174, 157, 104, 112, 138]


### Preview of the Dataset

In [3]:
sample_batch = next(iter(data_loader['train']))
for key, value in sample_batch.items():
    print(f"{key}: {value.shape}")
    print(f"The first (at most) 10 elements along each dimension of the {key} tensor are:")
    new_shape = tuple(slice(0, min(dim, 10)) for dim in value.shape)
    print(value[new_shape])
    print()

spikes_data: torch.Size([1, 400, 157])
The first (at most) 10 elements along each dimension of the spikes_data tensor are:
tensor([[[ 4.,  1.,  0.,  0.,  3.,  0.,  0., 15.,  2.,  2.],
         [ 5.,  0.,  0.,  0.,  1.,  0.,  0., 16.,  2.,  1.],
         [ 3.,  0.,  1.,  1.,  3.,  0.,  1., 10.,  4.,  2.],
         [ 7.,  1.,  1.,  1.,  3.,  0.,  0., 14.,  5.,  3.],
         [ 4.,  2.,  0.,  0.,  4.,  1.,  0.,  5.,  3.,  2.],
         [ 4.,  1.,  1.,  0.,  1.,  0.,  0., 10.,  3.,  4.],
         [ 6.,  1.,  0.,  0.,  0.,  1.,  0., 13.,  2.,  0.],
         [ 3.,  1.,  1.,  0.,  0.,  0.,  1., 12.,  4.,  0.],
         [ 7.,  0.,  2.,  0.,  1.,  0.,  1., 14., 13.,  3.],
         [ 6.,  0.,  3.,  0.,  0.,  1.,  2., 17.,  8.,  0.]]])

spikes_timestamps: torch.Size([1, 400])
The first (at most) 10 elements along each dimension of the spikes_timestamps tensor are:
tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])

target: torch.Size([1, 400, 1])
The first (at most) 10 elements along each dimension of the 

### Training/Model Objective!

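I'm proud of these. Together they generalize the neural activity modelling objective: a target spec picks what to predict (here the ```spikes_data``` counts) and how to score it, while a masking policy decides what the model is allowed to see. Play around with the many I made in the ```sandman.models.utils``` module, and make your own too! Each takes only a couple of lines.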
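As an illustration of how small such an object can be, here is a minimal sketch of a Poisson spike-count target. It is NOT the `TargetSpec` interface from `sandman.models.utils`, which I have not reproduced; the class name `HypotheticalPoissonTarget` and its `target`/`loss` methods are assumptions. The softplus link mirrors the `Softplus` applied to model outputs in the Inference section, which is also an assumption about how `SpikeCountPoissonTarget` maps outputs to rates.

```python
from dataclasses import dataclass

import torch
import torch.nn.functional as F

@dataclass
class HypotheticalPoissonTarget:
    """Hypothetical target spec (NOT the sandman TargetSpec interface).

    Reads spike counts from batch[key] and scores raw model outputs with a
    Poisson negative log-likelihood, using softplus to get positive rates."""
    key: str = "spikes_data"

    def target(self, batch):
        return batch[self.key]

    def loss(self, raw_output, batch):
        rate = F.softplus(raw_output)
        # log_input=False: the first argument is the rate itself, not its log
        return F.poisson_nll_loss(rate, self.target(batch), log_input=False)

spec = HypotheticalPoissonTarget()
batch = {"spikes_data": torch.tensor([[[1.0, 2.0, 5.0]]])}   # [B=1, T=1, N=3]
matched = torch.log(torch.expm1(batch["spikes_data"]))       # inverse softplus
print(spec.loss(matched, batch).item(), spec.loss(torch.zeros(1, 1, 3), batch).item())
```

The Poisson NLL `rate - y * log(rate)` is minimized at `rate = y`, so outputs whose softplus equals the observed counts score better than any other outputs.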

In [4]:
target_spec = SpikeCountPoissonTarget(key="spikes_data")
masking_policy = no_mask_policy

### Model Architecture

Good luck! It's tough to make something that works well. You are going to have a lot of ```n_neurons x d_model``` matrices for encoding and decoding, so the parameter count depends mostly on ```d_model```. Don't let it drop below 16, or it will fundamentally break the RoPE embeddings.
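To see why a small width hurts rotary embeddings, the sketch below implements standard RoPE on a `[T, d]` input. Each pair of dimensions is rotated at its own frequency, so a width of `d` provides only `d // 2` frequencies to encode position. Whether sandman applies RoPE per attention head (width `d_model / n_heads`, which is 16 in the config below) or over the full `d_model` is an assumption here, and `rope` is an illustrative helper, not sandman's implementation.

```python
import torch

def rope(x, pos, base=10000.0):
    """Rotary position embedding for x of shape [T, d] at integer positions pos [T].

    Dimension pairs (2i, 2i+1) are rotated by angle pos * base**(-2i/d), so a
    width-d input only gets d // 2 distinct rotation frequencies."""
    d = x.shape[-1]
    assert d % 2 == 0, "RoPE needs an even dimension"
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=x.dtype) / d)   # [d/2]
    angles = pos[:, None].to(x.dtype) * inv_freq[None, :]              # [T, d/2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = torch.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

for d in (4, 8, 16):
    print(f"width {d:2d}: {d // 2} rotation frequencies")

# Key property: the query/key score depends only on the offset m - n
torch.manual_seed(0)
q = torch.randn(1, 16, dtype=torch.float64)
k = torch.randn(1, 16, dtype=torch.float64)
s1 = (rope(q, torch.tensor([5])) * rope(k, torch.tensor([2]))).sum()
s2 = (rope(q, torch.tensor([13])) * rope(k, torch.tensor([10]))).sum()
print(torch.allclose(s1, s2))
```

Because rotations preserve norms, RoPE changes only how scores depend on relative position; with very few frequency pairs, that dependence becomes coarse.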

In [5]:
# model_args = {
#     "d_model": 8,        # enough to encode parity + region ID
#     "n_heads": 1,        # single head is sufficient
#     "n_layers": 1,       # one transformer block
#     "K": 1,              # ONE channel per region per time
#     "max_neurons": 8,    # exact
#     "max_regions": 3,    # exact
#     "max_eids": 1,       # exact
#     "use_eid": False,    # eid carries no information here
# }

model_args = {
    # -------- Region STATE dimension --------
    # This is the entire latent the diffusion model operates on.
    # Previously this was "per channel"; now it is the full region state.
    "d_model": 128,

    # -------- Diffusion backbone --------
    "n_layers": 8,
    "n_heads": 8,

    # -------- Local encoder pooling --------
    # Kenc = number of local pooling queries per region-time
    # These DO NOT enter diffusion; they are just an encoder bottleneck.
    "K": 4,

    # -------- Identity / scale parameters --------
    "max_neurons": 4096,
    "max_regions": 512,
    "max_eids": 4096,
    "use_eid": True,

    # -------- Conditioning --------
    # Controls AdaLN conditioning width; small is fine.
    "cond_dim": 64,

    # -------- Decoder identity width --------
    # Width of neuron identity embedding before hypernetwork.
    # Smaller than d_model is good.
    "d_neur_id": 64,
}



noise_scheduler_args = {
    "num_train_timesteps": 1000,
    "beta_start": 1e-4,
    "beta_end": 0.02,
    "beta_schedule": "linear",
}

diffusion, optimizer, lr_scheduler, accelerator, data_loader = prepare_data_and_model(
    data_loader=data_loader,
    model_args=model_args,
    noise_scheduler_args=noise_scheduler_args,
    target_spec=target_spec,
    masking_policy=masking_policy,
    max_epochs=50,
    lr=3e-4,
    weight_decay=1e-4,
    num_warmup_steps=200,
    wandb_enabled=True,
    reconstruct_loss_weight=0.1,
)

[34m[1mwandb[0m: Currently logged in as: [33mnoah[0m to [32mhttp://localhost:8080[0m. Use [1m`wandb login --relogin`[0m to force relogin


Model has 4.59 million parameters.
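To get a feel for what `noise_scheduler_args` implies, the sketch below recomputes the linear beta schedule and the cumulative signal fraction with plain PyTorch. To my understanding, diffusers' `DDPMScheduler` builds its `"linear"` schedule the same way (evenly spaced betas, then a cumulative product of `1 - beta`), but this is an independent re-derivation, not a call into diffusers.

```python
import torch

# Values copied from noise_scheduler_args above
num_train_timesteps, beta_start, beta_end = 1000, 1e-4, 0.02

# "linear" schedule: betas evenly spaced from beta_start to beta_end
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float64)
# alpha_bar_t = prod_{s <= t} (1 - beta_s)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
for t in (0, 250, 500, 750, 999):
    signal = alphas_cumprod[t].sqrt().item()
    noise = (1.0 - alphas_cumprod[t]).sqrt().item()
    print(f"t={t:4d}  signal scale {signal:.4f}  noise scale {noise:.4f}")
```

By the last timestep `alpha_bar` is on the order of `1e-5`, so less than 1% of the original signal amplitude survives and sampling effectively starts from pure noise.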


In [6]:
sample_batch = next(iter(data_loader["train"]))
print(sample_batch['spikes_data'].shape)

torch.Size([1, 400, 112])


In [None]:
num_epochs = 5

diffusion.set_stage("ae")  # stages: "ae", "diffusion", "joint"
for epoch in range(num_epochs):
    avg_loss = train_epoch(
        loader=data_loader["train"],
        diffusion=diffusion,
        optimizer=optimizer,
        accelerator=accelerator,
        epoch=epoch,
        log_every=10,
        wandb_enabled=True,
    )

    lr_scheduler.step()

    if accelerator.is_main_process:
        print(f"[Epoch {epoch:03d}] avg loss = {avg_loss:.6f}")

num_epochs = 10

diffusion.set_stage("diffusion")  # stages: "ae", "diffusion", "joint"
for epoch in range(num_epochs):
    avg_loss = train_epoch(
        loader=data_loader["train"],
        diffusion=diffusion,
        optimizer=optimizer,
        accelerator=accelerator,
        epoch=epoch,
        log_every=10,
        wandb_enabled=True,
    )

    lr_scheduler.step()

    if accelerator.is_main_process:
        print(f"[Epoch {epoch:03d}] avg loss = {avg_loss:.6f}")


100%|██████████| 787/787 [00:37<00:00, 20.87it/s]


[Epoch 000] avg loss = 3.816432


100%|██████████| 787/787 [00:35<00:00, 22.17it/s]


[Epoch 001] avg loss = 2.946109


100%|██████████| 787/787 [00:37<00:00, 20.95it/s]


[Epoch 002] avg loss = 2.701451


100%|██████████| 787/787 [00:35<00:00, 22.43it/s]


[Epoch 003] avg loss = 2.654687


100%|██████████| 787/787 [00:34<00:00, 22.74it/s]


[Epoch 004] avg loss = 2.613941


100%|██████████| 787/787 [01:17<00:00, 10.14it/s]


[Epoch 000] avg loss = 4.397157


100%|██████████| 787/787 [01:39<00:00,  7.94it/s]


[Epoch 001] avg loss = 4.357480


100%|██████████| 787/787 [01:21<00:00,  9.68it/s]


[Epoch 002] avg loss = 3.854009


100%|██████████| 787/787 [01:24<00:00,  9.30it/s]


[Epoch 003] avg loss = 3.689104


100%|██████████| 787/787 [01:19<00:00,  9.86it/s]


[Epoch 004] avg loss = 3.707832


100%|██████████| 787/787 [01:19<00:00,  9.88it/s]


[Epoch 005] avg loss = 3.180093


100%|██████████| 787/787 [01:23<00:00,  9.41it/s]


[Epoch 006] avg loss = 2.636528


100%|██████████| 787/787 [01:17<00:00, 10.19it/s]


[Epoch 007] avg loss = 2.355665


 79%|███████▉  | 624/787 [01:03<00:15, 10.58it/s]
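I have not reproduced how `set_stage` works inside sandman. One common way to implement such a switch is to toggle which parameter groups receive gradients, as in the hypothetical toy below (`StagedToy`, its layer sizes, and its three sub-modules are all illustrative assumptions). PyTorch optimizers skip parameters whose `.grad` is `None`, so a single optimizer built over all parameters keeps working across a switch, provided gradients are cleared with `zero_grad()` (which sets them to `None` by default in recent PyTorch).

```python
import torch
import torch.nn as nn

class StagedToy(nn.Module):
    """Hypothetical model with an autoencoder and a denoiser, showing one way a
    set_stage("ae" | "diffusion" | "joint") switch could work: by toggling
    which parameters receive gradients."""

    def __init__(self, n_neurons=8, d_model=16):
        super().__init__()
        self.encoder = nn.Linear(n_neurons, d_model)
        self.decoder = nn.Linear(d_model, n_neurons)
        self.denoiser = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, d_model)
        )

    def set_stage(self, stage):
        train_ae = stage in ("ae", "joint")
        train_diff = stage in ("diffusion", "joint")
        if not (train_ae or train_diff):
            raise ValueError(f"unknown stage {stage!r}")
        for p in list(self.encoder.parameters()) + list(self.decoder.parameters()):
            p.requires_grad_(train_ae)
        for p in self.denoiser.parameters():
            p.requires_grad_(train_diff)
        return self  # allow chaining, like nn.Module.train()

toy = StagedToy()
for stage in ("ae", "diffusion", "joint"):
    toy.set_stage(stage)
    n_train = sum(p.numel() for p in toy.parameters() if p.requires_grad)
    print(f"{stage:9s} trainable params: {n_train}")
```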

### Inference
This is where you start to wonder what you set the reconstruction loss weight (```reconstruct_loss_weight```) to.

In [None]:
from sandman.models.sampling import sample_with_diffusion

In [None]:
# output = infer_batch(
#     diffusion=diffusion,
#     batch=sample_batch,
#     accelerator=accelerator,
#     masking_policy=masking_policy,
#     use_diffusion=True,
# )  # [B,T,N]

output = sample_with_diffusion(
    diffusion=diffusion,
    batch=sample_batch,
    accelerator=accelerator,
    masking_policy=masking_policy,
    num_steps=1000,
)  # [B,T,N]
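Since `num_steps=1000` matches `num_train_timesteps`, sampling presumably runs the network once per step, which is why it is slow. The sketch below is the generic DDPM ancestral sampler for a noise-predicting model (Ho et al., 2020), written in plain PyTorch. It is not the implementation of `sample_with_diffusion`, which may differ (for example by sampling latent region states and then decoding them). `ddpm_sample`, `zero_eps`, and the latent width of 16 are hypothetical.

```python
import torch

@torch.no_grad()
def ddpm_sample(eps_model, shape, betas, generator=None):
    """Generic DDPM ancestral sampler for an epsilon-predicting model.

    eps_model(x_t, t) must return predicted noise shaped like x_t. Runs one
    reverse step per entry of betas, from pure noise down to t = 0."""
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    x = torch.randn(shape, generator=generator, dtype=betas.dtype)
    for t in reversed(range(len(betas))):
        eps = eps_model(x, t)
        # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
        mean = (x - betas[t] / (1.0 - alphas_cumprod[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            # Add fresh noise with variance beta_t (one of the DDPM variance choices)
            noise = torch.randn(shape, generator=generator, dtype=betas.dtype)
            x = mean + betas[t].sqrt() * noise
        else:
            x = mean
    return x

# Stand-in for the trained network: always predicts zero noise
def zero_eps(x, t):
    return torch.zeros_like(x)

betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
out = ddpm_sample(zero_eps, (1, 400, 16), betas, generator=torch.Generator().manual_seed(0))
print(out.shape)
```

Fewer steps are only possible with a sampler designed for it (for example DDIM-style schedulers); simply stopping this loop early returns a still-noisy sample.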

In [None]:
print("Output shape:", output.shape)

Output shape: torch.Size([1, 400, 112])


In [None]:
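# Softplus maps the raw model output to non-negative rates
softplus = torch.nn.Softplus()
pred = softplus(output)
# pred = output
# Boolean map of the bins whose predicted rate exceeds 0.2
print(pred>0.2)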

tensor([[[False, False,  True,  ..., False, False, False],
         [False, False,  True,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]], device='mps:0')
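If the softplus output is read as a Poisson rate per bin (an assumption consistent with `SpikeCountPoissonTarget`, not something this notebook verifies), the 0.2 threshold means a `1 - exp(-0.2) ≈ 18%` chance of at least one spike in that bin. The hypothetical helper `rates_to_counts` below turns raw outputs into rates, sampled spike counts, and that probability; comparing the sampled counts with `sample_batch['spikes_data']` (for example per-neuron means) is a quick sanity check. If your backend lacks `torch.poisson`, move the tensor to the CPU first.

```python
import torch
import torch.nn.functional as F

def rates_to_counts(raw_output, generator=None):
    """Map raw outputs to non-negative rates with softplus, then draw one
    Poisson spike count per (batch, time, neuron) bin. Also returns
    P(count >= 1) = 1 - exp(-rate) for each bin."""
    rates = F.softplus(raw_output)
    counts = torch.poisson(rates, generator=generator)
    p_any_spike = 1.0 - torch.exp(-rates)
    return rates, counts, p_any_spike

# Hypothetical raw output with the same shape as `output` above: [B, T, N]
raw = torch.randn(1, 400, 112, generator=torch.Generator().manual_seed(0))
rates, counts, p_any_spike = rates_to_counts(raw, generator=torch.Generator().manual_seed(1))
print(counts.shape, counts.mean().item(), rates.mean().item())
```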


In [None]:
save_model = True
if save_model:
    model_dir = DATA_DIR / "models"
    model_dir.mkdir(parents=True, exist_ok=True)  # create the folder if missing
    # Weights only; the architecture is rebuilt from model_args
    torch.save(diffusion.state_dict(), model_dir / "rnn_example_model.pth")
    # Everything else needed to rebuild and interpret the model
    other_info = {
        "model_args": model_args,
        "noise_scheduler_args": noise_scheduler_args,
        "target_spec": target_spec,
        "masking_policy": masking_policy,
        "reconstruct_loss_weight": diffusion.reconstruct_loss_weight,
    }
    torch.save(other_info, model_dir / "rnn_example_model_info.pth")
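To reuse the checkpoint, you would presumably rebuild the model from the saved `model_args` (for example through `prepare_data_and_model`) and then load the state dict. The sketch below shows only the generic PyTorch round trip, with an `nn.Linear` standing in for `diffusion` and a temporary directory standing in for `DATA_DIR / "models"`. One practical detail: since PyTorch 2.6, `torch.load` defaults to `weights_only=True`, which refuses arbitrary pickled Python objects. The real info file stores `target_spec` and `masking_policy` objects, so loading it needs `weights_only=False`, which should only be used on files you trust.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

# Stand-ins: the real objects are `diffusion` and `other_info` from above
stand_in = nn.Linear(4, 2)
info = {"model_args": {"d_model": 128}, "reconstruct_loss_weight": 0.1}

with tempfile.TemporaryDirectory() as tmp:
    tmp = Path(tmp)
    torch.save(stand_in.state_dict(), tmp / "model.pth")
    torch.save(info, tmp / "model_info.pth")

    # The real info file pickles arbitrary objects, so the safe loader would
    # reject it; weights_only=False is needed (trusted files only).
    loaded_info = torch.load(tmp / "model_info.pth", weights_only=False)

    # A state dict is just tensors, so the safe loader is fine here
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(tmp / "model.pth", weights_only=True))

print(loaded_info)
```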