# Full Test of the Imported Model

Here we run through a full test of the imported model: loading the model, defining the input data, and running the model to get the output. We also include some additional checks to ensure the model is working correctly. This is a hard task to train on, so expect the model to look like it is not working AT ALL.

To convince yourself that the model is actually working, try out the blinking dataset, whose neurons blink at different frequencies. Furthermore, if you believe the problem is that the model does diffusion, there is a secret ```disable_diffusion``` flag that turns the diffusion process off. I don't want people shooting themselves in the foot, so the flag must be set manually after a ```DiffusionWrapper``` is instantiated.
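To get a feel for what "blinking at different frequencies" could look like, here is a minimal pure-Python sketch. It is not the generator behind ```make_blinking_toy_loader```; the function name `make_blinking_sketch` and the square-wave rule are assumptions made up for illustration, and only the sizes (100 time steps, 8 neurons) are taken from this notebook.

```python
def make_blinking_sketch(n_neurons=8, T=100):
    """Return a T x n_neurons nested list of 0/1 values.

    Hypothetical stand-in for the real dataset: neuron i is on for (i + 1)
    steps, then off for (i + 1) steps, so it blinks with period 2 * (i + 1).
    """
    return [
        [1 if (t // (i + 1)) % 2 == 0 else 0 for i in range(n_neurons)]
        for t in range(T)
    ]


activity = make_blinking_sketch()
print(len(activity), len(activity[0]))       # time steps, neurons
print([row[0] for row in activity[:6]])      # fastest neuron
print([row[1] for row in activity[:6]])      # next neuron, half the rate
```

A signal this regular is easy to eyeball, which is what makes a blinking dataset a useful sanity check before moving on to harder data.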

To disable Weights and Biases logging, set it to false in both ```prepare_data_and_model``` and ```train_epoch```.
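As a possible alternative, the `wandb` client itself reads a `WANDB_MODE` environment variable, and the value `disabled` turns its calls into no-ops. Whether the `sandman` training helpers behave well with this setting is an assumption, not something this notebook confirms, so treat the snippet as a sketch to try rather than the supported route.

```python
import os

# Must run before any code calls wandb.init(), so put it in the first cell.
# Assumption: sandman does not override the mode when it initialises wandb.
os.environ["WANDB_MODE"] = "disabled"
```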


In [None]:
import torch
import numpy as np
import pickle

from sandman.data_loading.data_loader_map import make_map_loader
from sandman.data_loading.chaotic_rnn_loader import make_chaotic_rnn_loader
from sandman.data_loading.blinking import make_blinking_toy_loader
from sandman.paths import DATA_DIR

from sandman.models.training import train_epoch, prepare_data_and_model

from sandman.models.utils import TargetSpec, MaskingPolicy
from sandman.models.utils import SpikeCountTarget, mask_one_region_policy, SpikeCountMSETarget

### Dataset Loading
Only uncomment one!

In [None]:
#========================== Import of MAP data loading

# session_order = pickle.load(open(DATA_DIR / "tables_and_infos/session_order.pkl", "rb"))
# eids = np.sort(session_order[:5]) # originally :40
# print("Using eids:", eids)

# data_loader, num_neurons, datasets, areaoi_ind, area_ind_list_list, heldout_info_list, trial_type_dict = make_map_loader(
#     eids,
#     batch_size=2,
#     include_opto=False
# )

#========================== Import of Chaotic RNN data loading

# eids = np.arange(5)
# data_loader, num_neurons, _, area_ind_list_list, record_info_list = make_chaotic_rnn_loader(
#     eids,
#     batch_size=1,
# )

#========================== Import of Blinking Toy data loading

data_loader = make_blinking_toy_loader(T=100, batch_size=1, device="mps")
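The loader call above hardcodes `device="mps"`, which only exists on Apple Silicon builds of PyTorch. A small helper like the one below could pick a device that is actually present. The name `pick_device` is hypothetical, and it is an assumption that ```make_blinking_toy_loader``` accepts `"cuda"` or `"cpu"` as well as `"mps"`.

```python
def pick_device():
    """Return "mps", "cuda", or "cpu" depending on what this machine offers."""
    try:
        import torch
    except ImportError:
        # No PyTorch at all: report the CPU so callers can still decide what to do.
        return "cpu"
    # torch.backends.mps is missing on older PyTorch versions, hence getattr.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


device = pick_device()
print(device)
```

The result could then be passed as `device=device` in place of the hardcoded string.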



### Training/Model Objective!

I'm proud of these. Targets and masking policies generalize the neural activation modelling process. Play around with the ones I made in the ```sandman.models.utils``` module, or make your own! They only take a couple of lines each.

In [3]:
target = SpikeCountMSETarget()
masking_policy = mask_one_region_policy

### Model Architecture

Good luck! It's tough to make something that works well. Encoding and decoding use a lot of ```n_neurons x d_model``` matrices, so the parameter count depends mostly on ```d_model```. Don't let it drop below 16, or it will fundamentally break the RoPE embeddings.
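The arithmetic behind that warning can be sketched as follows. This assumes the model splits ```d_model``` evenly across attention heads and applies RoPE per head, rotating dimensions in pairs as in the standard RoPE formulation; the notebook does not confirm how `sandman` does it, and both helper names are hypothetical.

```python
def rope_budget(d_model, n_heads):
    """Return (head_dim, n_rotation_pairs) under the assumptions above."""
    if d_model % n_heads != 0:
        raise ValueError("d_model must be divisible by n_heads")
    head_dim = d_model // n_heads
    if head_dim % 2 != 0:
        raise ValueError("RoPE rotates dimensions in pairs, so head_dim must be even")
    return head_dim, head_dim // 2


def projection_params(n_neurons, d_model):
    """Weights in one n_neurons x d_model encoding or decoding matrix."""
    return n_neurons * d_model


for d_model in (32, 16, 8):
    head_dim, n_pairs = rope_budget(d_model, n_heads=4)
    print(f"d_model={d_model}: head_dim={head_dim}, rotation pairs per head={n_pairs}")

print(projection_params(n_neurons=8, d_model=32))
```

With 4 heads, dropping ```d_model``` to 8 would leave each head a single rotation pair, which gives a rough intuition for why small values hurt the positional encoding.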

In [4]:
model_args = {
    "d_model": 32,
    "n_layers": 8,
    "n_heads": 4,
    "encoder_n_heads": 2,
}

noise_scheduler_args = {
    "num_train_timesteps": 20,
    "beta_start": 1e-4,
    "beta_end": 0.02,
    "beta_schedule": "linear",
}

diffusion, optimizer, lr_scheduler, accelerator, data_loader = prepare_data_and_model(
    data_loader=data_loader,
    model_args=model_args,
    noise_scheduler_args=noise_scheduler_args,
    max_epochs=100,
    lr=3e-4,
    weight_decay=1e-4,
    target_spec=target,
    masking_policy=masking_policy
)

[34m[1mwandb[0m: Currently logged in as: [33mnoah[0m to [32mhttp://localhost:8080[0m. Use [1m`wandb login --relogin`[0m to force relogin


Model has 1.10 million parameters.
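The `noise_scheduler_args` in the cell above describe a linear beta schedule. The sketch below reproduces the usual DDPM arithmetic for those values in plain Python: betas spaced evenly from `beta_start` to `beta_end`, and the running product of `1 - beta`. That `sandman` computes its schedule exactly this way is an assumption, and the two function names are hypothetical.

```python
def linear_beta_schedule(num_train_timesteps, beta_start, beta_end):
    """Evenly spaced betas from beta_start to beta_end, endpoints included."""
    step = (beta_end - beta_start) / (num_train_timesteps - 1)
    return [beta_start + i * step for i in range(num_train_timesteps)]


def cumulative_alphas(betas):
    """Running product of (1 - beta), usually written alpha-bar in DDPM."""
    products, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        products.append(running)
    return products


betas = linear_beta_schedule(num_train_timesteps=20, beta_start=1e-4, beta_end=0.02)
alpha_bars = cumulative_alphas(betas)
print(f"first beta: {betas[0]:.4f}, last beta: {betas[-1]:.4f}")
print(f"alpha-bar at the final timestep: {alpha_bars[-1]:.3f}")
```

Under these assumptions the final alpha-bar stays around 0.8, so with only 20 timesteps the most heavily noised sample still retains most of the original signal. That is worth keeping in mind when judging how much the diffusion process contributes.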


In [5]:
sample_batch = next(iter(data_loader["train"]))
print(sample_batch['spikes_data'].shape)

torch.Size([1, 100, 8])


In [6]:
num_epochs = 1

for epoch in range(num_epochs):
    train_loss = train_epoch(
        loader=data_loader["train"],
        diffusion=diffusion,
        optimizer=optimizer,
        accelerator=accelerator,
        epoch=epoch,
        log_every=50 # Lots of logging for debugging purposes
    )
    print(f"Epoch {epoch} train loss: {train_loss}")

  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 0 train loss: 0.2818349611312151
