# Train a POYO model with IBL Brain Wide Map data

References:
- [Dandiset 000409 (IBL Brain Wide Map)](https://dandiarchive.org/dandiset/000409/draft)
- [Explore existing sessions](https://viz.internationalbrainlab.org/app)

In [None]:
from pathlib import Path
import numpy as np
import torch
import torch_brain
from torch_brain.models import POYO
from torch_brain.registry import MODALITY_REGISTRY, ModalitySpec, DataType

from aux_functions import (
    download_model,
    get_dataset_config,
    get_loaders,
    get_unit_ids,
    Transform,
    finetune,
    plot_training_curves,
    compute_r2,
    run_test,
    plot_test_intervals,
)

In [None]:
# Download pre-trained model weights.
# NOTE(review): the destination is decided inside aux_functions.download_model;
# presumably it produces the "poyo_1.ckpt" file that load_model() is given
# later in this notebook — confirm.
download_model()

In [None]:
# Describe what to load: the brainset, the behavioral readout to decode,
# and the session file(s) to include.
cfg = get_dataset_config(
    brainset="ibl_processed",
    readout_id="wheel_velocity",
    session_ids=[
        "sub-CSHL059_ses-d2f5a130-b981-4546-8858-c94ae1da75ff_desc-processed_behavior+ecephys.h5",
        # Other sessions; uncomment to include them as well.
        #"sub-NYU-21_ses-8c33abef-3d3e-4d42-9f27-445e9def08f9_desc-processed_behavior+ecephys.h5",
        #"sub-UCLA035_ses-6f36868f-5cc1-450c-82fa-6b9829ce0cfe_desc-processed_behavior+ecephys.h5",
    ],
)

# Build the train / validation / test datasets together with their loaders.
# dir_path is the absolute path of the current working directory.
dir_path = Path(".").resolve()
(train_dataset, train_loader,
 val_dataset, val_loader,
 test_dataset, test_loader) = get_loaders(
    dir_path=dir_path, cfg=cfg, window_length=1.0, batch_size=16
)

In [None]:
# Register the decoding target with torch_brain: a 1-dimensional continuous
# signal read from the "wheel_velocity.timestamps" / "wheel_velocity.values"
# keys and trained with an MSE loss.
# NOTE(review): re-running this cell in the same kernel may fail if the
# registry rejects duplicate names — confirm against torch_brain.
torch_brain.register_modality(
    name="wheel_velocity",
    dim=1,
    type=DataType.CONTINUOUS,
    timestamp_key="wheel_velocity.timestamps",
    value_key="wheel_velocity.values",
    loss_fn=torch_brain.nn.loss.MSELoss(),
)

# Look up the registered spec; load_model() hands it to POYO.load_pretrained.
readout_spec = MODALITY_REGISTRY["wheel_velocity"]

## Train a model with Motor cortex units

In [None]:
def load_model(checkpoint: str, device: "torch.device | str | None" = None) -> "POYO":
    """Load a pre-trained POYO model and place it on a compute device.

    The model is built with the notebook-level ``readout_spec`` and with
    ``skip_readout=True`` (presumably so the checkpoint's own readout weights
    are not loaded — confirm in ``POYO.load_pretrained``).

    Args:
        checkpoint: Path to the pre-trained checkpoint file.
        device: Device to move the model to, as a ``torch.device`` or a string
            such as ``"cpu"``. When ``None`` (the default) the device is
            auto-selected: MPS if available, otherwise ``cuda:0`` if
            available, otherwise CPU.

    Returns:
        The loaded model, moved to the chosen device and cast with ``.float()``.
    """
    model = POYO.load_pretrained(
        checkpoint_path=checkpoint,
        readout_spec=readout_spec,
        skip_readout=True,
    )

    if device is None:
        # Same preference order as before: MPS, then the first CUDA GPU, then CPU.
        if torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")
    else:
        # Accept both strings and torch.device instances from the caller.
        device = torch.device(device)

    # float() is important on MPS (the MPS backend does not support float64).
    model.to(device).float()

    return model

# Load pre-trained weights into the model that will be finetuned on motor units
model_motor = load_model(checkpoint="poyo_1.ckpt")

In [None]:
# String handed to get_unit_ids to choose which units this model uses
# (presumably matched against each unit's brain-region label — confirm in aux_functions).
filter_str = "motor"

# Reinitialize the vocabs for the new units and sessions: as the method names
# suggest, the new ids are first added to the pre-trained vocabulary and the
# vocabulary is then restricted to those ids.
units_ids = get_unit_ids(train_dataset, filter_str=filter_str)
try:
    model_motor.unit_emb.extend_vocab(units_ids)
    model_motor.unit_emb.subset_vocab(units_ids)

    model_motor.session_emb.extend_vocab(train_dataset.get_session_ids())
    model_motor.session_emb.subset_vocab(train_dataset.get_session_ids())
except Exception as e:
    # NOTE(review): any failure is only printed, and a failure in an earlier
    # call skips the remaining ones, so later cells may run on a partially
    # updated model. Presumably this is here so the cell can be re-run —
    # confirm, and consider narrowing the exception type.
    print(e)

# Connect tokenizers to Datasets: the train and validation datasets get a
# Transform built around model_motor.
train_dataset.transform = Transform(model=model_motor)
val_dataset.transform = Transform(model=model_motor)

In [None]:
# Setup Optimizer: AdamW over all parameters of model_motor, learning rate 1e-3.
# NOTE(review): presumably this must be re-created whenever the vocab cell
# changes the embedding parameters — confirm against torch_brain.
optimizer = torch.optim.AdamW(model_motor.parameters(), lr=1e-3)

In [None]:
# Finetune model_motor for 8 epochs on train_loader, with val_loader passed
# alongside (presumably for validation). The three return values are
# unpacked, by name, as R2 scores, losses and training outputs
# (see aux_functions.finetune).
# NOTE(review): the meaning of epoch_to_unfreeze=-1 is defined by finetune();
# presumably it disables the freeze/unfreeze schedule — confirm.
poyo_motor_r2, poyo_motor_loss, poyo_motor_train_outputs = finetune(
    model_motor,
    optimizer,
    train_loader,
    val_loader,
    num_epochs=8,
    epoch_to_unfreeze=-1,
)

In [None]:
# Visualize the results: plot the R2 and loss values returned by finetune()
plot_training_curves(poyo_motor_r2, poyo_motor_loss)

In [None]:
# Save the finetuned model to "poyo_motor.ckpt" (a relative path)
model_motor.save_checkpoint(checkpoint_path="poyo_motor.ckpt")

## Train a model with Caudoputamen units

In [None]:
# Load pre-trained weights into a second, independent model; this cell repeats
# the motor-cortex pipeline (vocab update, transforms, optimizer, finetune,
# plot, save) for caudoputamen units.
model_caudoputamen = load_model(checkpoint="poyo_1.ckpt")

# String handed to get_unit_ids to choose which units this model uses
# (presumably matched against each unit's brain-region label — confirm in aux_functions).
filter_str = "caudoputamen"

# Reinitialize the vocabs for the new units and sessions
units_ids = get_unit_ids(train_dataset, filter_str=filter_str)
try:
    model_caudoputamen.unit_emb.extend_vocab(units_ids)
    model_caudoputamen.unit_emb.subset_vocab(units_ids)

    model_caudoputamen.session_emb.extend_vocab(train_dataset.get_session_ids())
    model_caudoputamen.session_emb.subset_vocab(train_dataset.get_session_ids())
except Exception as e:
    # NOTE(review): any failure is only printed, and a failure in an earlier
    # call skips the remaining ones, so the steps below may run on a partially
    # updated model. Presumably this is here so the cell can be re-run —
    # confirm, and consider narrowing the exception type.
    print(e)

# Connect tokenizers to Datasets.
# This rebinds the transforms on the shared train/val datasets, replacing the
# ones that were built around model_motor earlier in the notebook.
train_dataset.transform = Transform(model=model_caudoputamen)
val_dataset.transform = Transform(model=model_caudoputamen)

# Setup Optimizer: a fresh AdamW over model_caudoputamen's parameters
# (rebinds the notebook-level `optimizer` name).
optimizer = torch.optim.AdamW(model_caudoputamen.parameters(), lr=1e-3)

# Finetune for 8 epochs.
# NOTE(review): the meaning of epoch_to_unfreeze=-1 is defined by finetune() — confirm.
poyo_caudoputamen_r2, poyo_caudoputamen_loss, poyo_caudoputamen_train_outputs = finetune(
    model_caudoputamen,
    optimizer,
    train_loader,
    val_loader,
    num_epochs=8,
    epoch_to_unfreeze=-1,
)

# Visualize the results
plot_training_curves(poyo_caudoputamen_r2, poyo_caudoputamen_loss)

# Save the finetuned model to "poyo_caudoputamen.ckpt" (a relative path)
model_caudoputamen.save_checkpoint(checkpoint_path="poyo_caudoputamen.ckpt")

## Run inference against Test set

In [None]:
# Run each finetuned model over the same test dataset and loader.
test_motor = run_test(test_dataset, test_loader, model_motor)
test_caudoputamen = run_test(test_dataset, test_loader, model_caudoputamen)

# Gather both results in one mapping, keyed by region name, for plot_test_intervals.
test_results = {
    "motor": test_motor,
    "caudoputamen": test_caudoputamen,
}

In [None]:
# Plot 10 test intervals for both models, selected with order="top"
# (presumably the best-predicted intervals — confirm in aux_functions.plot_test_intervals).
plot_test_intervals(
    test_results=test_results,
    n_intervals=10,
    order="top",
)

In [None]:
# Plot 10 test intervals for both models, selected with order="bottom"
# (presumably the worst-predicted intervals — confirm in aux_functions.plot_test_intervals).
plot_test_intervals(
    test_results=test_results,
    n_intervals=10,
    order="bottom",
)