In [1]:
# Global random seed: reused below for RNG seeding, model construction and
# the trainer config.
seed = 42

In [2]:
# Uncomment if imports of sensorium do not work 
# !ln -s ../sensorium

import torch
from nnfabrik.utility.nn_helpers import set_random_seed
set_random_seed(seed)

from sensorium.datasets.mouse_video_loaders import mouse_video_loader
from sensorium.utility.scores import get_correlations
from nnfabrik.builder import get_trainer
from sensorium.models.make_model import make_video_model

# Device string handed to the data loaders, the trainer config and the
# submission code in later cells. Set to 'cpu' to run without a GPU.
device = 'cuda'

# torch.cuda.set_device only accepts CUDA devices and raises for 'cpu'.
# Later cells already branch on `device != 'cpu'`, so apply the same guard
# here instead of crashing when a CPU run is requested.
if device != 'cpu':
    torch.cuda.set_device(device)

In [4]:
# `paths` holds the paths to the data folders, like
#
#     paths = [
#     '/my_path/dynamic29515-10-12-Video-9b4f6a1a067fe51e15306b9628efea20/',
#     '/my_path/dynamic29623-4-9-Video-9b4f6a1a067fe51e15306b9628efea20/',
#     '/my_path/dynamic29647-19-8-Video-9b4f6a1a067fe51e15306b9628efea20/',
#     '/my_path/dynamic29712-5-9-Video-9b4f6a1a067fe51e15306b9628efea20/',
#     '/my_path/dynamic29755-2-8-Video-9b4f6a1a067fe51e15306b9628efea20/'
#     ]
#
# The "/" at the end of each path is important.
#
# NOTE(review): `paths` is not assigned in any cell visible in this notebook;
# it has to be defined before this cell is run.

print("Loading data..")
data_loaders = mouse_video_loader(
    paths=paths,
    batch_size=8,
    scale=1,
    max_frame=None,
    frames=80,  # frames has to be > 50.
    offset=-1,
    include_behavior=True,
    include_pupil_centers=True,
    cuda=(device != 'cpu'),  # True for any device other than 'cpu'
)
print('Data loaded')

Loading data..
Data loaded


## GRU Benchmark

In [5]:
# Keyword arguments for the '2D_equivariant' core, passed to make_video_model
# as `core_dict` when the GRU benchmark is built.
equivar_2D_core_dict = {
    'input_channels': 3,
    'hidden_channels': 8,
    'input_kern': 9,
    'hidden_kern': 7,
    'layers': 4,
    'num_rotations': 8,
    'gamma_input': 500,
    'skip': 0,
    'pad_input': False,
    'final_nonlinearity': False,
    'bias': True,
    'momentum': 0.9,
    'batch_norm': True,
    'hidden_dilation': 1,
    'laplace_padding': None,
    'input_regularizer': "LaplaceL2norm",
    'stack': -1,
    'depth_separable': False,
    'linear': False,
    'attention_conv': False,
    'hidden_padding': None,
    'use_avg_reg': False,
    'final_batchnorm_scale': True,
    'gamma_hidden': 500_000,
}

# Keyword arguments for the GRU stage, passed to make_video_model as `gru_dict`.
gru_dict = {
    # input channels should be the last hidden channels from the core_dict
    'input_channels': 64,
    # rec channels should be the input channels to the readouts
    'rec_channels': 64,
    'input_kern': 9,
    'rec_kern': 9,
    'gamma_rec': 0,
}

In [6]:
# Keyword arguments for the shifter, passed to make_video_model as
# `shifter_dict` by both benchmark models.
shifter_dict = {
    'gamma_shifter': 0,
    'shift_layers': 3,
    'input_channels_shifter': 2,
    'hidden_channels_shifter': 5,
}


# Keyword arguments for the 'gaussian' readout. Both benchmark models receive
# a `.copy()` of this dict as `readout_dict`.
readout_dict = {
    'bias': True,
    'init_mu_range': 0.2,
    'init_sigma': 1.0,
    'gamma_readout': 0.0,
    'gauss_type': 'full',
    'grid_mean_predictor': dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=1,
        hidden_features=30,
        final_tanh=True,
    ),
    'share_features': False,
    'share_grid': False,
    'shared_match_ids': None,
    'gamma_grid_dispersion': 0.0,
}

In [7]:
# Build the GRU benchmark model from the configs defined in the cells above:
# '2D_equivariant' core, GRU enabled, 'gaussian' readout and 'MLP' shifter.
gru_2d_model_equivariant = make_video_model(
    data_loaders,
    seed,
    core_type='2D_equivariant',
    core_dict=equivar_2D_core_dict,
    use_gru=True,
    gru_dict=gru_dict,
    readout_type='gaussian',
    # a copy is handed over rather than the shared readout_dict itself
    readout_dict=readout_dict.copy(),
    use_shifter=True,
    shifter_type='MLP',
    shifter_dict=shifter_dict,
    # TODO: set this to True if you are using deeplake;
    # the first connection to deeplake may take up to 10 minutes
    deeplake_ds=False,
)

  xavier_normal(m.weight.data)
  init.constant(m.bias.data, 0.0)


In [8]:
# Bare expression: shows the model's repr (its module hierarchy) as cell output.
gru_2d_model_equivariant

VideoFiringRateEncoder(
  (core): RotationEquivariant2dCore(
    (_input_weights_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (features): Sequential(
      (layer0): Sequential(
        (hermite_conv): HermiteConv2D(
          (rotate_hermite): RotateHermite(
            (Rs): ParameterList(
                (0): Parameter containing: [torch.float32 of size 45x45]
                (1): Parameter containing: [torch.float32 of size 45x45]
                (2): Parameter containing: [torch.float32 of size 45x45]
                (3): Parameter containing: [torch.float32 of size 45x45]
                (4): Parameter containing: [torch.float32 of size 45x45]
                (5): Parameter containing: [torch.float32 of size 45x45]
                (6): Parameter containing: [torch.float32 of size 45x45]
                (7): Parameter containing: [torch.float32 of size 45x45]
            )
          )
        )
        (norm): RotationEquivariantBatchNorm2D(
          (batch_n

## Factorized Benchmark

In [9]:
# Keyword arguments for the '3D_factorised' core, passed to make_video_model
# as `core_dict` when the factorized benchmark is built.
factorised_3D_core_dict = {
    'input_channels': 3,
    'hidden_channels': [32, 64, 128],
    'spatial_input_kernel': (11, 11),
    'temporal_input_kernel': 11,
    'spatial_hidden_kernel': (5, 5),
    'temporal_hidden_kernel': 5,
    'stride': 1,
    'layers': 3,
    'gamma_input_spatial': 10,
    'gamma_input_temporal': 0.01,
    'bias': True,
    'hidden_nonlinearities': 'elu',
    'x_shift': 0,
    'y_shift': 0,
    'batch_norm': True,
    'laplace_padding': None,
    'input_regularizer': 'LaplaceL2norm',
    'padding': False,
    'final_nonlin': True,
    'momentum': 0.7,
}

In [10]:
# Build the factorized benchmark model: '3D_factorised' core, GRU disabled,
# with the same 'gaussian' readout and 'MLP' shifter configs as the GRU model.
factorised_3d_model = make_video_model(
    data_loaders,
    seed,
    core_type='3D_factorised',
    core_dict=factorised_3D_core_dict,
    # GRU is switched off for this benchmark, so no gru_dict is supplied
    use_gru=False,
    gru_dict=None,
    readout_type='gaussian',
    readout_dict=readout_dict.copy(),
    use_shifter=True,
    shifter_type='MLP',
    shifter_dict=shifter_dict,
    deeplake_ds=False,
)



In [11]:
# Bare expression: shows the model's repr (its module hierarchy) as cell output.
factorised_3d_model

VideoFiringRateEncoder(
  (core): Factorized3dCore(
    (_input_weight_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (temporal_regularizer): DepthLaplaceL21d(
      (laplace): Laplace1d()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv_spatial): Conv3d(3, 32, kernel_size=(1, 11, 11), stride=(1, 1, 1))
        (conv_temporal): Conv3d(32, 32, kernel_size=(11, 1, 1), stride=(1, 1, 1))
        (norm): BatchNorm3d(32, eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0)
      )
      (layer1): Sequential(
        (conv_spatial_1): Conv3d(32, 64, kernel_size=(1, 5, 5), stride=(1, 1, 1))
        (conv_temporal_1): Conv3d(64, 64, kernel_size=(5, 1, 1), stride=(1, 1, 1))
        (norm): BatchNorm3d(64, eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0)
      )
      (layer2): Sequential(
        (conv_spatial_2): Conv3d(64, 128, kernel_size=(1, 5, 5), stride=

## Training

In [None]:
# Dotted path of the training function handed to get_trainer.
trainer_fn = "sensorium.training.video_training_loop.standard_trainer"

# Settings bound to the trainer; same data loaders, seed and device as above.
trainer_config = dict(
    dataloaders=data_loaders,
    seed=seed,
    use_wandb=False,
    verbose=True,
    lr_decay_steps=4,
    lr_init=0.005,
    device=device,
    detach_core=False,
    # TODO: set this to True if you are using deeplake;
    # the first connection to deeplake may take up to 10 minutes
    deeplake_ds=False,
    checkpoint_save_path='benchmarks/',
)

trainer = get_trainer(
    trainer_fn=trainer_fn,
    trainer_config=trainer_config,
)

In [None]:
# Train the GRU benchmark. To train the factorized benchmark instead, pass
# factorised_3d_model here. The trainer's result is unpacked into three values.
validation_score, trainer_output, state_dict = trainer(gru_2d_model_equivariant)

## Make submission

In [None]:
from sensorium.utility.submission import generate_submission

In [None]:
# Data loaders used to generate the submission: batch size 1 and uncut
# recordings (`to_cut=False`).
#
# The original passed `paths=mice`, but `mice` is not defined in any visible
# cell of this notebook, so the cell failed with a NameError on a fresh kernel.
# Use the same `paths` list that the training loaders were built from.
data_loaders2 = mouse_video_loader(
    paths=paths,
    batch_size=1,
    scale=1,
    max_frame=None,
    frames=350,  # take all frames
    offset=-1,
    include_behavior=True,
    include_pupil_centers=True,
    cuda=device != 'cpu',
    to_cut=False,  # take all frames for submission
)

In [None]:
# Generate the submission from the GRU model using the submission loaders.
# Pass factorised_3d_model instead to submit the factorized benchmark.
generate_submission(
    data_loaders2,
    gru_2d_model_equivariant,
    deeplake_ds=False,
    device=device,
    path='benchmarks/',
)