In [1]:
import os
import torch
import numpy as np
import pandas as pd
import sensorium
import warnings
warnings.filterwarnings('ignore')
from nnfabrik.builder import get_data, get_model, get_trainer
import matplotlib.pyplot as plt

In [2]:
# Zipped static-image recordings consumed by the sensorium data loaders below.
sensorium_dataPath = "../data/sensorium_data2022/static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"
autistic_mouse_dataPath = "../data/new_data2023/static29027-6-17-1-6-5-GrayImageNetFrame2-7bed7f7379d99271be5d144e5e59a8e7.zip"

# Seed handed to every get_model(...) and trainer(...) call in this notebook.
# NOTE(review): it is not passed to get_data, so any loader-side randomness is not seeded here.
seed = 31_415

In [3]:
# Dotted path of the loader factory, passed to nnfabrik's get_data for both recordings.
dataset_fn = "sensorium.datasets.static_loaders"

In [4]:
# Loaders for the new_data2023 recording that the pretrained core is transferred to.
# Same loader factory, normalisation and batch size as the Sensorium loaders, but scale=1.
# NOTE(review): the Sensorium loaders use scale=0.25. If both archives store images at a
# comparable native resolution, the transferred core would see inputs at a different
# spatial scale than it was trained on — confirm this is intended.
filenames_autistic = [autistic_mouse_dataPath]

dataset_config_autistic = dict(
    paths=filenames_autistic,
    normalize=True,
    include_behavior=False,
    include_eye_position=False,
    batch_size=128,
    scale=1,
)

dataloaders_autistic = get_data(dataset_fn, dataset_config_autistic)

In [5]:
# Loaders for the Sensorium 2022 recording that is used to pretrain the shared core.
# NOTE(review): scale=0.25 presumably rescales the stimuli before they reach the model — confirm.
filenames_sens = [sensorium_dataPath]

dataset_config_sens = dict(
    paths=filenames_sens,
    normalize=True,
    include_behavior=False,
    include_eye_position=False,
    batch_size=128,
    scale=0.25,
)

dataloaders_sens = get_data(dataset_fn, dataset_config_sens)

In [6]:
model_fn = "sensorium.models.stacked_core_full_gauss_readout"

# Architecture / regularisation hyper-parameters. Both models (model_s and model_sa) are
# built from this one config, so their cores share the same hyper-parameters — the
# weight transfer from model_s to model_sa relies on that.
model_config_sens = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=6.3831,
    gamma_readout=0.0076,
    hidden_kern=7,
    hidden_channels=64,
    depth_separable=True,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=1,
        hidden_features=30,
        final_tanh=True,
    ),
    init_sigma=0.1,
    init_mu_range=0.3,
    gauss_type='full',
    shifter=False,
)

# Model built against the Sensorium loaders; trained end-to-end below and used as the core donor.
model_s = get_model(model_fn=model_fn, model_config=model_config_sens,
                    dataloaders=dataloaders_sens, seed=seed)

In [7]:
# Same architecture config as model_s, but built against the new recording's loaders;
# it later receives model_s's core weights and gets its own readout trained.
model_sa = get_model(model_fn=model_fn, model_config=model_config_sens,
                     dataloaders=dataloaders_autistic, seed=seed)

In [8]:
trainer_fn = 'sensorium.training.standard_trainer'

# Optimisation settings for the pretraining run (no detach_core, so the core is trained too).
trainer_config_sens = dict(
    max_iter=200,
    verbose=True,
    lr_decay_steps=4,
    avg_loss=False,
    lr_init=0.009,
)

# Trainer callable for the Sensorium pretraining run.
trainer_s = get_trainer(trainer_fn=trainer_fn, trainer_config=trainer_config_sens)

In [9]:
# Pretrain model_s on the Sensorium recording; the trainer returns a 3-tuple that is
# unpacked into (validation score, trainer output, state dict).
(validation_score_s,
 trainer_output_s,
 state_dict_s) = trainer_s(model_s, dataloaders_sens, seed=seed)

Epoch 1: 100%|██████████| 36/36 [00:08<00:00,  4.41it/s]
Epoch 2: 100%|██████████| 36/36 [00:03<00:00,  9.17it/s]
Epoch 3: 100%|██████████| 36/36 [00:03<00:00,  9.24it/s]
Epoch 4: 100%|██████████| 36/36 [00:03<00:00,  9.14it/s]
Epoch 5: 100%|██████████| 36/36 [00:03<00:00,  9.16it/s]
Epoch 6: 100%|██████████| 36/36 [00:03<00:00,  9.12it/s]
Epoch 7: 100%|██████████| 36/36 [00:03<00:00,  9.22it/s]
Epoch 8: 100%|██████████| 36/36 [00:03<00:00,  9.07it/s]
Epoch 9: 100%|██████████| 36/36 [00:03<00:00,  9.28it/s]
Epoch 10: 100%|██████████| 36/36 [00:03<00:00,  9.26it/s]
Epoch 11: 100%|██████████| 36/36 [00:03<00:00,  9.09it/s]
Epoch 12: 100%|██████████| 36/36 [00:03<00:00,  9.20it/s]
Epoch 13: 100%|██████████| 36/36 [00:03<00:00,  9.24it/s]
Epoch 14: 100%|██████████| 36/36 [00:03<00:00,  9.26it/s]
Epoch 15: 100%|██████████| 36/36 [00:03<00:00,  9.26it/s]
Epoch 16: 100%|██████████| 36/36 [00:03<00:00,  9.21it/s]
Epoch 17: 100%|██████████| 36/36 [00:03<00:00,  9.12it/s]
Epoch 18: 100%|████████

Epoch    37: reducing learning rate of group 0 to 2.7000e-03.


Epoch 38: 100%|██████████| 36/36 [00:04<00:00,  8.83it/s]
Epoch 39: 100%|██████████| 36/36 [00:04<00:00,  8.79it/s]
Epoch 40: 100%|██████████| 36/36 [00:04<00:00,  8.90it/s]
Epoch 41: 100%|██████████| 36/36 [00:04<00:00,  8.96it/s]
Epoch 42: 100%|██████████| 36/36 [00:03<00:00,  9.06it/s]
Epoch 43: 100%|██████████| 36/36 [00:03<00:00,  9.09it/s]
Epoch 44: 100%|██████████| 36/36 [00:03<00:00,  9.25it/s]
Epoch 45: 100%|██████████| 36/36 [00:03<00:00,  9.21it/s]
Epoch 46: 100%|██████████| 36/36 [00:03<00:00,  9.02it/s]
Epoch 47: 100%|██████████| 36/36 [00:03<00:00,  9.25it/s]
Epoch 48:   6%|▌         | 2/36 [00:00<00:03, 10.56it/s]

Epoch    47: reducing learning rate of group 0 to 8.1000e-04.


Epoch 48: 100%|██████████| 36/36 [00:03<00:00,  9.26it/s]
Epoch 49: 100%|██████████| 36/36 [00:03<00:00,  9.23it/s]
Epoch 50: 100%|██████████| 36/36 [00:03<00:00,  9.11it/s]
Epoch 51: 100%|██████████| 36/36 [00:04<00:00,  8.94it/s]
Epoch 52: 100%|██████████| 36/36 [00:03<00:00,  9.24it/s]
Epoch 53: 100%|██████████| 36/36 [00:03<00:00,  9.25it/s]
Epoch 54: 100%|██████████| 36/36 [00:03<00:00,  9.24it/s]
Epoch 55: 100%|██████████| 36/36 [00:03<00:00,  9.24it/s]
Epoch 56: 100%|██████████| 36/36 [00:03<00:00,  9.24it/s]
Epoch 57: 100%|██████████| 36/36 [00:03<00:00,  9.25it/s]
Epoch 58:   6%|▌         | 2/36 [00:00<00:03, 10.50it/s]

Epoch    57: reducing learning rate of group 0 to 2.4300e-04.


Epoch 58: 100%|██████████| 36/36 [00:03<00:00,  9.24it/s]
Epoch 59: 100%|██████████| 36/36 [00:03<00:00,  9.23it/s]
Epoch 60: 100%|██████████| 36/36 [00:03<00:00,  9.17it/s]
Epoch 61: 100%|██████████| 36/36 [00:03<00:00,  9.13it/s]


In [10]:
# Persist the full state_dict of the Sensorium-trained model; its core is loaded into
# model_sa from this file in the next cell.
# The checkpoint directory is not guaranteed to exist on a fresh checkout (git does not
# track empty directories), and torch.save fails if the parent directory is missing.
os.makedirs('./model_checkpoints', exist_ok=True)
torch.save(model_s.state_dict(), './model_checkpoints/sensoriumI_model.pth')

In [11]:
# Transfer the Sensorium-pretrained weights into the model built for the new recording.
# strict=False skips keys present in only one of the two state dicts. The readout
# parameters presumably differ per recording (different neurons), so only the core
# is expected to match — confirm against the model definition.
sens_state_dict = torch.load("./model_checkpoints/sensoriumI_model.pth")
load_result = model_sa.load_state_dict(sens_state_dict, strict=False)

# strict=False never complains about non-matching keys, so a prefix mismatch would
# silently transfer nothing. Verify explicitly that the core actually came across.
assert any(key.startswith("core.") for key in sens_state_dict), \
    "checkpoint has no 'core.*' entries - check the model's attribute names"
missing_core_keys = [key for key in load_result.missing_keys if key.startswith("core.")]
assert not missing_core_keys, f"core weights not transferred: {missing_core_keys}"

In [12]:
# Same schedule as the pretraining run, plus detach_core=True so the transferred core
# presumably receives no gradient updates and only the new readout is fit.
# NOTE(review): detaching presumably only blocks gradients; any BatchNorm running
# statistics in the core may still update while the model is in train mode — confirm.
trainer_config_autistic = dict(
    max_iter=200,
    detach_core=True,
    verbose=True,
    lr_decay_steps=4,
    avg_loss=False,
    lr_init=0.009,
)
trainer_sa = get_trainer(trainer_fn=trainer_fn,
                         trainer_config=trainer_config_autistic)

# Fit model_sa (carrying the Sensorium-pretrained core) on the new recording; the returned
# 3-tuple is unpacked into (validation score, trainer output, state dict).
(validation_score_sa,
 trainer_output_sa,
 state_dict_sa) = trainer_sa(model_sa, dataloaders_autistic, seed=seed)

Epoch 1: 100%|██████████| 38/38 [00:07<00:00,  5.15it/s]
Epoch 2: 100%|██████████| 38/38 [00:03<00:00, 12.18it/s]
Epoch 3: 100%|██████████| 38/38 [00:03<00:00, 12.32it/s]
Epoch 4: 100%|██████████| 38/38 [00:03<00:00, 12.27it/s]
Epoch 5: 100%|██████████| 38/38 [00:03<00:00, 12.33it/s]
Epoch 6: 100%|██████████| 38/38 [00:03<00:00, 12.23it/s]
Epoch 7: 100%|██████████| 38/38 [00:03<00:00, 12.26it/s]
Epoch 8: 100%|██████████| 38/38 [00:03<00:00, 12.25it/s]
Epoch 9: 100%|██████████| 38/38 [00:03<00:00, 12.24it/s]
Epoch 10: 100%|██████████| 38/38 [00:03<00:00, 12.27it/s]
Epoch 11: 100%|██████████| 38/38 [00:03<00:00, 12.26it/s]
Epoch 12: 100%|██████████| 38/38 [00:03<00:00, 12.28it/s]
Epoch 13: 100%|██████████| 38/38 [00:03<00:00, 12.25it/s]
Epoch 14: 100%|██████████| 38/38 [00:03<00:00, 12.26it/s]
Epoch 15: 100%|██████████| 38/38 [00:03<00:00, 12.28it/s]
Epoch 16:   5%|▌         | 2/38 [00:00<00:02, 13.09it/s]

Epoch    15: reducing learning rate of group 0 to 2.7000e-03.


Epoch 16: 100%|██████████| 38/38 [00:03<00:00, 12.26it/s]
Epoch 17: 100%|██████████| 38/38 [00:03<00:00, 12.00it/s]
Epoch 18: 100%|██████████| 38/38 [00:03<00:00, 11.72it/s]
Epoch 19: 100%|██████████| 38/38 [00:03<00:00, 11.99it/s]
Epoch 20: 100%|██████████| 38/38 [00:03<00:00, 12.01it/s]
Epoch 21: 100%|██████████| 38/38 [00:03<00:00, 12.02it/s]
Epoch 22: 100%|██████████| 38/38 [00:03<00:00, 12.05it/s]
Epoch 23: 100%|██████████| 38/38 [00:03<00:00, 11.99it/s]
Epoch 24: 100%|██████████| 38/38 [00:03<00:00, 11.86it/s]
Epoch 25: 100%|██████████| 38/38 [00:03<00:00, 12.06it/s]
Epoch 26: 100%|██████████| 38/38 [00:03<00:00, 12.11it/s]
Epoch 27: 100%|██████████| 38/38 [00:03<00:00, 12.12it/s]
Epoch 28: 100%|██████████| 38/38 [00:03<00:00, 12.12it/s]
Epoch 29: 100%|██████████| 38/38 [00:03<00:00, 12.26it/s]
Epoch 30: 100%|██████████| 38/38 [00:03<00:00, 12.22it/s]
Epoch 31: 100%|██████████| 38/38 [00:03<00:00, 12.19it/s]
Epoch 32:   5%|▌         | 2/38 [00:00<00:02, 13.13it/s]

Epoch    31: reducing learning rate of group 0 to 8.1000e-04.


Epoch 32: 100%|██████████| 38/38 [00:03<00:00, 12.12it/s]
Epoch 33: 100%|██████████| 38/38 [00:03<00:00, 12.29it/s]
Epoch 34: 100%|██████████| 38/38 [00:03<00:00, 12.28it/s]
Epoch 35: 100%|██████████| 38/38 [00:03<00:00, 12.31it/s]
Epoch 36: 100%|██████████| 38/38 [00:03<00:00, 12.33it/s]
Epoch 37: 100%|██████████| 38/38 [00:03<00:00, 12.27it/s]
Epoch 38:   5%|▌         | 2/38 [00:00<00:02, 13.31it/s]

Epoch    37: reducing learning rate of group 0 to 2.4300e-04.


Epoch 38: 100%|██████████| 38/38 [00:03<00:00, 12.26it/s]
Epoch 39: 100%|██████████| 38/38 [00:03<00:00, 12.21it/s]
Epoch 40: 100%|██████████| 38/38 [00:03<00:00, 12.13it/s]
Epoch 41: 100%|██████████| 38/38 [00:03<00:00, 12.06it/s]
Epoch 42: 100%|██████████| 38/38 [00:03<00:00, 12.06it/s]
Epoch 43: 100%|██████████| 38/38 [00:03<00:00, 12.08it/s]
Epoch 44: 100%|██████████| 38/38 [00:03<00:00, 12.03it/s]


In [13]:
# Persist the transferred-core model with its readout fit on the new recording.
# Ensure the checkpoint directory exists: torch.save fails if the parent directory is missing.
os.makedirs('./model_checkpoints', exist_ok=True)
torch.save(model_sa.state_dict(), './model_checkpoints/sensoriumI_core_autistic_readout.pth')