In [1]:
import torch
import numpy as np
import pandas as pd
import sensorium
import warnings
warnings.filterwarnings('ignore')
from nnfabrik.builder import get_data, get_model, get_trainer
import matplotlib.pyplot as plt

In [2]:
seed=31415
autistic_mouse_dataPath = "../data/new_data2023/static29027-6-17-1-6-5-GrayImageNetFrame2-7bed7f7379d99271be5d144e5e59a8e7.zip"
wildtype_mouse_dataPath = "../data/new_data2023/static29028-1-17-1-6-5-GrayImageNetFrame2-7bed7f7379d99271be5d144e5e59a8e7.zip"

## Initialize Dataloaders

In [3]:
dataset_fn = 'sensorium.datasets.static_loaders'

In [4]:
filenames_autistic = [autistic_mouse_dataPath, ]

dataset_config_autistic = {'paths': filenames_autistic,
                 'normalize': True,
                 'include_behavior': False,
                 'include_eye_position': False,
                 'batch_size': 128,
                 'scale':1,
                 }

dataloaders_autistic = get_data(dataset_fn, dataset_config_autistic)

In [5]:
filenames_wildtype = [wildtype_mouse_dataPath, ]

dataset_config_wildtype = {'paths': filenames_wildtype,
                 'normalize': True,
                 'include_behavior': False,
                 'include_eye_position': False,
                 'batch_size': 128,
                 'scale':1,
                 }

dataloaders_wildtype = get_data(dataset_fn, dataset_config_wildtype)

## Model pretrained on autistic mousedata and readout fine tuned on wild-type mouse data.

In [6]:
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'

In [7]:
model_config_autistic = {'pad_input': False,
  'stack': -1,
  'layers': 4,
  'input_kern': 9,
  'gamma_input': 6.3831,
  'gamma_readout': 0.0076,
  'hidden_kern': 7,
  'hidden_channels': 64,
  'depth_separable': True,
  'grid_mean_predictor': {'type': 'cortex',
   'input_dimensions': 2,
   'hidden_layers': 1,
   'hidden_features': 30,
   'final_tanh': True},
  'init_sigma': 0.1,
  'init_mu_range': 0.3,
  'gauss_type': 'full',
  'shifter': False,
}

model_a= get_model(model_fn=model_fn,
                  model_config=model_config_autistic,
                  dataloaders=dataloaders_autistic,
                  seed=seed,
                  )

In [8]:
model_awt = get_model(model_fn=model_fn,
                       model_config=model_config_autistic,
                       dataloaders=dataloaders_wildtype,
                       seed=seed)

In [9]:
trainer_fn = "sensorium.training.standard_trainer"

trainer_config_autistic = {
    'max_iter': 200,
    'detach_core' : False,
    'verbose': True,
    'lr_decay_steps': 4,
    'avg_loss': False,
    'lr_init': 0.009,
    }


trainer_a = get_trainer(trainer_fn=trainer_fn, 
                     trainer_config=trainer_config_autistic)

In [10]:
validation_score_a, trainer_output_a, state_dict_a = trainer_a(model_a, dataloaders_autistic, seed=seed)

Epoch 1: 100%|██████████| 38/38 [00:09<00:00,  4.14it/s]
Epoch 2: 100%|██████████| 38/38 [00:04<00:00,  8.74it/s]
Epoch 3: 100%|██████████| 38/38 [00:04<00:00,  9.03it/s]
Epoch 4: 100%|██████████| 38/38 [00:04<00:00,  9.21it/s]
Epoch 5: 100%|██████████| 38/38 [00:04<00:00,  9.03it/s]
Epoch 6: 100%|██████████| 38/38 [00:04<00:00,  9.17it/s]
Epoch 7: 100%|██████████| 38/38 [00:04<00:00,  8.94it/s]
Epoch 8: 100%|██████████| 38/38 [00:04<00:00,  9.19it/s]
Epoch 9: 100%|██████████| 38/38 [00:04<00:00,  9.20it/s]
Epoch 10: 100%|██████████| 38/38 [00:04<00:00,  9.02it/s]
Epoch 11: 100%|██████████| 38/38 [00:04<00:00,  9.09it/s]
Epoch 12: 100%|██████████| 38/38 [00:04<00:00,  9.00it/s]
Epoch 13: 100%|██████████| 38/38 [00:04<00:00,  8.98it/s]
Epoch 14: 100%|██████████| 38/38 [00:04<00:00,  9.12it/s]
Epoch 15: 100%|██████████| 38/38 [00:04<00:00,  9.09it/s]
Epoch 16: 100%|██████████| 38/38 [00:04<00:00,  9.02it/s]
Epoch 17: 100%|██████████| 38/38 [00:04<00:00,  9.16it/s]
Epoch 18: 100%|████████

Epoch    41: reducing learning rate of group 0 to 2.7000e-03.


Epoch 42: 100%|██████████| 38/38 [00:04<00:00,  9.07it/s]
Epoch 43: 100%|██████████| 38/38 [00:04<00:00,  9.04it/s]
Epoch 44: 100%|██████████| 38/38 [00:04<00:00,  8.97it/s]
Epoch 45: 100%|██████████| 38/38 [00:04<00:00,  9.05it/s]
Epoch 46: 100%|██████████| 38/38 [00:04<00:00,  9.06it/s]
Epoch 47: 100%|██████████| 38/38 [00:04<00:00,  9.02it/s]
Epoch 48: 100%|██████████| 38/38 [00:04<00:00,  9.03it/s]
Epoch 49: 100%|██████████| 38/38 [00:04<00:00,  8.94it/s]
Epoch 50: 100%|██████████| 38/38 [00:04<00:00,  9.06it/s]
Epoch 51: 100%|██████████| 38/38 [00:04<00:00,  9.19it/s]
Epoch 52: 100%|██████████| 38/38 [00:04<00:00,  8.92it/s]
Epoch 53: 100%|██████████| 38/38 [00:04<00:00,  9.10it/s]
Epoch 54: 100%|██████████| 38/38 [00:04<00:00,  9.01it/s]
Epoch 55: 100%|██████████| 38/38 [00:04<00:00,  9.17it/s]
Epoch 56: 100%|██████████| 38/38 [00:04<00:00,  9.04it/s]
Epoch 57: 100%|██████████| 38/38 [00:04<00:00,  8.99it/s]
Epoch 58: 100%|██████████| 38/38 [00:04<00:00,  9.05it/s]
Epoch 59: 100%

Epoch    60: reducing learning rate of group 0 to 8.1000e-04.


Epoch 61: 100%|██████████| 38/38 [00:04<00:00,  9.08it/s]
Epoch 62: 100%|██████████| 38/38 [00:04<00:00,  9.02it/s]
Epoch 63: 100%|██████████| 38/38 [00:04<00:00,  8.97it/s]
Epoch 64: 100%|██████████| 38/38 [00:04<00:00,  8.99it/s]
Epoch 65: 100%|██████████| 38/38 [00:04<00:00,  9.06it/s]
Epoch 66: 100%|██████████| 38/38 [00:04<00:00,  8.99it/s]
Epoch 67:   5%|▌         | 2/38 [00:00<00:03, 10.27it/s]

Epoch    66: reducing learning rate of group 0 to 2.4300e-04.


Epoch 67: 100%|██████████| 38/38 [00:04<00:00,  8.99it/s]
Epoch 68: 100%|██████████| 38/38 [00:04<00:00,  9.09it/s]
Epoch 69: 100%|██████████| 38/38 [00:04<00:00,  9.07it/s]


In [11]:
torch.save(model_a.state_dict(), './model_checkpoints/pretrained_autistic_model.pth')

In [12]:
model_awt.load_state_dict(torch.load("./model_checkpoints/pretrained_autistic_model.pth"), strict=False)

_IncompatibleKeys(missing_keys=['readout.29028-1-17-1-6-5.sigma', 'readout.29028-1-17-1-6-5._features', 'readout.29028-1-17-1-6-5.bias', 'readout.29028-1-17-1-6-5.source_grid', 'readout.29028-1-17-1-6-5.mu_transform.0.weight', 'readout.29028-1-17-1-6-5.mu_transform.0.bias', 'readout.29028-1-17-1-6-5.mu_transform.2.weight', 'readout.29028-1-17-1-6-5.mu_transform.2.bias'], unexpected_keys=['readout.29027-6-17-1-6-5.sigma', 'readout.29027-6-17-1-6-5._features', 'readout.29027-6-17-1-6-5.bias', 'readout.29027-6-17-1-6-5.source_grid', 'readout.29027-6-17-1-6-5.mu_transform.0.weight', 'readout.29027-6-17-1-6-5.mu_transform.0.bias', 'readout.29027-6-17-1-6-5.mu_transform.2.weight', 'readout.29027-6-17-1-6-5.mu_transform.2.bias'])

In [14]:
trainer_config_wildtype = {
    'max_iter': 200,
    'detach_core' : True,
    'verbose': True,
    'lr_decay_steps': 4,
    'avg_loss': False,
    'lr_init': 0.009,
    }
trainer_awt = get_trainer(trainer_fn=trainer_fn, trainer_config=trainer_config_wildtype)

validation_score_awt, trainer_output_awt, state_dict_awt = trainer_awt(model_awt, dataloaders_wildtype, seed=seed)

Epoch 1: 100%|██████████| 38/38 [00:07<00:00,  5.05it/s]
Epoch 2: 100%|██████████| 38/38 [00:03<00:00, 11.69it/s]
Epoch 3: 100%|██████████| 38/38 [00:03<00:00, 11.99it/s]
Epoch 4: 100%|██████████| 38/38 [00:03<00:00, 11.87it/s]
Epoch 5: 100%|██████████| 38/38 [00:03<00:00, 11.96it/s]
Epoch 6: 100%|██████████| 38/38 [00:03<00:00, 11.61it/s]
Epoch 7: 100%|██████████| 38/38 [00:03<00:00, 11.81it/s]
Epoch 8: 100%|██████████| 38/38 [00:03<00:00, 11.93it/s]
Epoch 9: 100%|██████████| 38/38 [00:03<00:00, 11.95it/s]
Epoch 10: 100%|██████████| 38/38 [00:03<00:00, 12.15it/s]
Epoch 11: 100%|██████████| 38/38 [00:03<00:00, 11.88it/s]
Epoch 12: 100%|██████████| 38/38 [00:03<00:00, 12.03it/s]
Epoch 13: 100%|██████████| 38/38 [00:03<00:00, 12.04it/s]
Epoch 14: 100%|██████████| 38/38 [00:03<00:00, 11.99it/s]
Epoch 15: 100%|██████████| 38/38 [00:03<00:00, 12.08it/s]
Epoch 16: 100%|██████████| 38/38 [00:03<00:00, 12.16it/s]
Epoch 17:   5%|▌         | 2/38 [00:00<00:02, 13.09it/s]

Epoch    16: reducing learning rate of group 0 to 2.7000e-03.


Epoch 17: 100%|██████████| 38/38 [00:03<00:00, 12.00it/s]
Epoch 18: 100%|██████████| 38/38 [00:03<00:00, 11.86it/s]
Epoch 19: 100%|██████████| 38/38 [00:03<00:00, 11.98it/s]
Epoch 20: 100%|██████████| 38/38 [00:03<00:00, 12.02it/s]
Epoch 21: 100%|██████████| 38/38 [00:03<00:00, 11.93it/s]
Epoch 22: 100%|██████████| 38/38 [00:03<00:00, 12.21it/s]
Epoch 23: 100%|██████████| 38/38 [00:03<00:00, 12.02it/s]
Epoch 24: 100%|██████████| 38/38 [00:03<00:00, 11.96it/s]
Epoch 25: 100%|██████████| 38/38 [00:03<00:00, 12.04it/s]
Epoch 26: 100%|██████████| 38/38 [00:03<00:00, 11.90it/s]
Epoch 27: 100%|██████████| 38/38 [00:03<00:00, 12.14it/s]
Epoch 28: 100%|██████████| 38/38 [00:03<00:00, 12.00it/s]
Epoch 29: 100%|██████████| 38/38 [00:03<00:00, 12.01it/s]
Epoch 30:   5%|▌         | 2/38 [00:00<00:02, 12.39it/s]

Epoch    29: reducing learning rate of group 0 to 8.1000e-04.


Epoch 30: 100%|██████████| 38/38 [00:03<00:00, 11.98it/s]
Epoch 31: 100%|██████████| 38/38 [00:03<00:00, 11.95it/s]
Epoch 32: 100%|██████████| 38/38 [00:03<00:00, 11.98it/s]
Epoch 33: 100%|██████████| 38/38 [00:03<00:00, 11.99it/s]
Epoch 34: 100%|██████████| 38/38 [00:03<00:00, 12.18it/s]
Epoch 35: 100%|██████████| 38/38 [00:03<00:00, 11.98it/s]
Epoch 36:   3%|▎         | 1/38 [00:00<00:05,  6.97it/s]

Epoch    35: reducing learning rate of group 0 to 2.4300e-04.


Epoch 36: 100%|██████████| 38/38 [00:03<00:00, 11.74it/s]
Epoch 37: 100%|██████████| 38/38 [00:03<00:00, 11.99it/s]
Epoch 38: 100%|██████████| 38/38 [00:03<00:00, 12.19it/s]


In [15]:
torch.save(model_awt.state_dict(), './model_checkpoints/autistic_core_wildtype_readout.pth')

## Model pretrained on normal mouse data and fine tuned on autistic mouse data

In [17]:
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config_wildtype = {'pad_input': False,
  'stack': -1,
  'layers': 4,
  'input_kern': 9,
  'gamma_input': 6.3831,
  'gamma_readout': 0.0076,
  'hidden_kern': 7,
  'hidden_channels': 64,
  'depth_separable': True,
  'grid_mean_predictor': {'type': 'cortex',
   'input_dimensions': 2,
   'hidden_layers': 1,
   'hidden_features': 30,
   'final_tanh': True},
  'init_sigma': 0.1,
  'init_mu_range': 0.3,
  'gauss_type': 'full',
  'shifter': False,
}

model_wt = get_model(model_fn=model_fn,
                  model_config=model_config_wildtype,
                  dataloaders=dataloaders_wildtype,
                  seed=seed,
                  )

In [18]:
model_wta = get_model(model_fn=model_fn,
                    model_config=model_config_wildtype,
                    dataloaders=dataloaders_autistic,
                    seed=seed,
                    )

In [19]:
trainer_fn = "sensorium.training.standard_trainer"

trainer_config_wildtype = {'max_iter': 200,
                 'verbose': True,
                 'lr_decay_steps': 4,
                 'avg_loss': False,
                 'lr_init': 0.009,
                 }

trainer_wt = get_trainer(trainer_fn=trainer_fn, 
                     trainer_config=trainer_config_wildtype)

In [20]:
validation_score_wt, trainer_output_wt, state_dict_wt = trainer_wt(model_wt, dataloaders_wildtype, seed=seed)

Epoch 1: 100%|██████████| 38/38 [00:04<00:00,  9.08it/s]
Epoch 2: 100%|██████████| 38/38 [00:04<00:00,  9.17it/s]
Epoch 3: 100%|██████████| 38/38 [00:04<00:00,  9.28it/s]
Epoch 4: 100%|██████████| 38/38 [00:04<00:00,  9.17it/s]
Epoch 5: 100%|██████████| 38/38 [00:04<00:00,  9.21it/s]
Epoch 6: 100%|██████████| 38/38 [00:04<00:00,  8.99it/s]
Epoch 7: 100%|██████████| 38/38 [00:04<00:00,  9.15it/s]
Epoch 8: 100%|██████████| 38/38 [00:04<00:00,  9.03it/s]
Epoch 9: 100%|██████████| 38/38 [00:04<00:00,  9.27it/s]
Epoch 10: 100%|██████████| 38/38 [00:04<00:00,  9.16it/s]
Epoch 11: 100%|██████████| 38/38 [00:04<00:00,  9.19it/s]
Epoch 12: 100%|██████████| 38/38 [00:04<00:00,  9.20it/s]
Epoch 13: 100%|██████████| 38/38 [00:04<00:00,  9.17it/s]
Epoch 14: 100%|██████████| 38/38 [00:04<00:00,  9.21it/s]
Epoch 15: 100%|██████████| 38/38 [00:04<00:00,  9.12it/s]
Epoch 16: 100%|██████████| 38/38 [00:04<00:00,  9.12it/s]
Epoch 17: 100%|██████████| 38/38 [00:04<00:00,  9.09it/s]
Epoch 18: 100%|████████

Epoch    31: reducing learning rate of group 0 to 2.7000e-03.


Epoch 32: 100%|██████████| 38/38 [00:04<00:00,  9.11it/s]
Epoch 33: 100%|██████████| 38/38 [00:04<00:00,  9.07it/s]
Epoch 34: 100%|██████████| 38/38 [00:04<00:00,  9.21it/s]
Epoch 35: 100%|██████████| 38/38 [00:04<00:00,  9.10it/s]
Epoch 36: 100%|██████████| 38/38 [00:04<00:00,  9.17it/s]
Epoch 37: 100%|██████████| 38/38 [00:04<00:00,  9.15it/s]
Epoch 38: 100%|██████████| 38/38 [00:04<00:00,  9.02it/s]
Epoch 39: 100%|██████████| 38/38 [00:04<00:00,  9.15it/s]
Epoch 40: 100%|██████████| 38/38 [00:04<00:00,  9.09it/s]
Epoch 41: 100%|██████████| 38/38 [00:04<00:00,  9.12it/s]
Epoch 42: 100%|██████████| 38/38 [00:04<00:00,  9.10it/s]
Epoch 43: 100%|██████████| 38/38 [00:04<00:00,  9.09it/s]
Epoch 44: 100%|██████████| 38/38 [00:04<00:00,  9.12it/s]
Epoch 45: 100%|██████████| 38/38 [00:04<00:00,  9.14it/s]
Epoch 46: 100%|██████████| 38/38 [00:04<00:00,  9.13it/s]
Epoch 47: 100%|██████████| 38/38 [00:04<00:00,  9.06it/s]
Epoch 48: 100%|██████████| 38/38 [00:04<00:00,  9.09it/s]
Epoch 49: 100%

In [21]:
torch.save(model_wt.state_dict(), './model_checkpoints/pretrained_wildtype_model.pth')

In [22]:
model_wta.load_state_dict(torch.load("./model_checkpoints/pretrained_wildtype_model.pth"), strict=False);

In [23]:
trainer_config_autistic = {
    'max_iter': 200,
    'detach_core' : True,
    'verbose': True,
    'lr_decay_steps': 4,
    'avg_loss': False,
    'lr_init': 0.009,
    } 
trainer_wta = get_trainer(trainer_fn=trainer_fn, trainer_config=trainer_config_autistic)

validation_score_wta, trainer_output_wta, state_dict_wta = trainer_wta(model_wta, dataloaders_autistic, seed=seed)

Epoch 1: 100%|██████████| 38/38 [00:03<00:00, 12.05it/s]
Epoch 2: 100%|██████████| 38/38 [00:03<00:00, 12.07it/s]
Epoch 3: 100%|██████████| 38/38 [00:03<00:00, 11.87it/s]
Epoch 4: 100%|██████████| 38/38 [00:03<00:00, 11.83it/s]
Epoch 5: 100%|██████████| 38/38 [00:03<00:00, 12.06it/s]
Epoch 6: 100%|██████████| 38/38 [00:03<00:00, 11.88it/s]
Epoch 7: 100%|██████████| 38/38 [00:03<00:00, 12.03it/s]
Epoch 8: 100%|██████████| 38/38 [00:03<00:00, 11.71it/s]
Epoch 9: 100%|██████████| 38/38 [00:03<00:00, 11.93it/s]
Epoch 10: 100%|██████████| 38/38 [00:03<00:00, 11.84it/s]
Epoch 11: 100%|██████████| 38/38 [00:03<00:00, 11.85it/s]
Epoch 12: 100%|██████████| 38/38 [00:03<00:00, 12.02it/s]
Epoch 13: 100%|██████████| 38/38 [00:03<00:00, 11.85it/s]
Epoch 14: 100%|██████████| 38/38 [00:03<00:00, 12.12it/s]
Epoch 15: 100%|██████████| 38/38 [00:03<00:00, 11.90it/s]
Epoch 16: 100%|██████████| 38/38 [00:03<00:00, 11.99it/s]
Epoch 17: 100%|██████████| 38/38 [00:03<00:00, 12.09it/s]
Epoch 18: 100%|████████

Epoch    26: reducing learning rate of group 0 to 2.7000e-03.


Epoch 27: 100%|██████████| 38/38 [00:03<00:00, 11.85it/s]
Epoch 28: 100%|██████████| 38/38 [00:03<00:00, 11.94it/s]
Epoch 29: 100%|██████████| 38/38 [00:03<00:00, 11.72it/s]
Epoch 30: 100%|██████████| 38/38 [00:03<00:00, 11.84it/s]
Epoch 31: 100%|██████████| 38/38 [00:03<00:00, 11.85it/s]
Epoch 32: 100%|██████████| 38/38 [00:03<00:00, 11.90it/s]
Epoch 33: 100%|██████████| 38/38 [00:03<00:00, 11.92it/s]
Epoch 34: 100%|██████████| 38/38 [00:03<00:00, 12.05it/s]
Epoch 35: 100%|██████████| 38/38 [00:03<00:00, 11.94it/s]
Epoch 36: 100%|██████████| 38/38 [00:03<00:00, 11.95it/s]
Epoch 37: 100%|██████████| 38/38 [00:03<00:00, 11.93it/s]
Epoch 38: 100%|██████████| 38/38 [00:03<00:00, 11.80it/s]
Epoch 39:   5%|▌         | 2/38 [00:00<00:02, 12.63it/s]

Epoch    38: reducing learning rate of group 0 to 8.1000e-04.


Epoch 39: 100%|██████████| 38/38 [00:03<00:00, 11.68it/s]
Epoch 40: 100%|██████████| 38/38 [00:03<00:00, 11.91it/s]
Epoch 41: 100%|██████████| 38/38 [00:03<00:00, 11.84it/s]
Epoch 42: 100%|██████████| 38/38 [00:03<00:00, 12.01it/s]
Epoch 43: 100%|██████████| 38/38 [00:03<00:00, 11.75it/s]
Epoch 44: 100%|██████████| 38/38 [00:03<00:00, 11.90it/s]
Epoch 45:   5%|▌         | 2/38 [00:00<00:02, 12.98it/s]

Epoch    44: reducing learning rate of group 0 to 2.4300e-04.


Epoch 45: 100%|██████████| 38/38 [00:03<00:00, 11.79it/s]
Epoch 46: 100%|██████████| 38/38 [00:03<00:00, 11.85it/s]
Epoch 47: 100%|██████████| 38/38 [00:03<00:00, 11.69it/s]


In [25]:
torch.save(model_wta.state_dict(), './model_checkpoints/wildtype_core_autistic_readout.pth')

## One-dataset models: wild-type (Core + Readout) and autistic (Core + Readout)

In [26]:
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config_wildtype = {'pad_input': False,
  'stack': -1,
  'layers': 4,
  'input_kern': 9,
  'gamma_input': 6.3831,
  'gamma_readout': 0.0076,
  'hidden_kern': 7,
  'hidden_channels': 64,
  'depth_separable': True,
  'grid_mean_predictor': {'type': 'cortex',
   'input_dimensions': 2,
   'hidden_layers': 1,
   'hidden_features': 30,
   'final_tanh': True},
  'init_sigma': 0.1,
  'init_mu_range': 0.3,
  'gauss_type': 'full',
  'shifter': False,
}

model_wt = get_model(model_fn=model_fn,
                  model_config=model_config_wildtype,
                  dataloaders=dataloaders_wildtype,
                  seed=seed,
                  )
trainer_fn = "sensorium.training.standard_trainer"

trainer_config_wildtype = {'max_iter': 200,
                 'verbose': True,
                 'lr_decay_steps': 4,
                 'avg_loss': False,
                 'lr_init': 0.009,
                 }

trainer_wildtype = get_trainer(trainer_fn=trainer_fn, 
                     trainer_config=trainer_config_wildtype)

In [27]:
validation_score_wt, trainer_output_wt, state_dict_wt = trainer_wt(model_wt, dataloaders_wildtype, seed=seed+1)

Epoch 1: 100%|██████████| 38/38 [00:04<00:00,  8.65it/s]
Epoch 2: 100%|██████████| 38/38 [00:04<00:00,  9.16it/s]
Epoch 3: 100%|██████████| 38/38 [00:04<00:00,  9.02it/s]
Epoch 4: 100%|██████████| 38/38 [00:04<00:00,  9.12it/s]
Epoch 5: 100%|██████████| 38/38 [00:04<00:00,  9.09it/s]
Epoch 6: 100%|██████████| 38/38 [00:04<00:00,  9.13it/s]
Epoch 7: 100%|██████████| 38/38 [00:04<00:00,  9.16it/s]
Epoch 8: 100%|██████████| 38/38 [00:04<00:00,  9.16it/s]
Epoch 9: 100%|██████████| 38/38 [00:04<00:00,  9.18it/s]
Epoch 10: 100%|██████████| 38/38 [00:04<00:00,  9.14it/s]
Epoch 11: 100%|██████████| 38/38 [00:04<00:00,  9.08it/s]
Epoch 12: 100%|██████████| 38/38 [00:04<00:00,  9.28it/s]
Epoch 13: 100%|██████████| 38/38 [00:04<00:00,  9.25it/s]
Epoch 14: 100%|██████████| 38/38 [00:04<00:00,  9.19it/s]
Epoch 15: 100%|██████████| 38/38 [00:04<00:00,  9.18it/s]
Epoch 16: 100%|██████████| 38/38 [00:04<00:00,  9.13it/s]
Epoch 17: 100%|██████████| 38/38 [00:04<00:00,  9.14it/s]
Epoch 18: 100%|████████

Epoch    39: reducing learning rate of group 0 to 2.7000e-03.


Epoch 40: 100%|██████████| 38/38 [00:04<00:00,  9.14it/s]
Epoch 41: 100%|██████████| 38/38 [00:04<00:00,  9.12it/s]
Epoch 42: 100%|██████████| 38/38 [00:04<00:00,  9.06it/s]
Epoch 43: 100%|██████████| 38/38 [00:04<00:00,  9.07it/s]
Epoch 44: 100%|██████████| 38/38 [00:04<00:00,  9.14it/s]
Epoch 45: 100%|██████████| 38/38 [00:04<00:00,  9.15it/s]
Epoch 46: 100%|██████████| 38/38 [00:04<00:00,  9.16it/s]
Epoch 47: 100%|██████████| 38/38 [00:04<00:00,  9.22it/s]
Epoch 48: 100%|██████████| 38/38 [00:04<00:00,  9.14it/s]
Epoch 49: 100%|██████████| 38/38 [00:04<00:00,  9.14it/s]
Epoch 50: 100%|██████████| 38/38 [00:04<00:00,  9.18it/s]
Epoch 51: 100%|██████████| 38/38 [00:04<00:00,  9.18it/s]
Epoch 52: 100%|██████████| 38/38 [00:04<00:00,  9.21it/s]
Epoch 53: 100%|██████████| 38/38 [00:04<00:00,  9.10it/s]
Epoch 54: 100%|██████████| 38/38 [00:04<00:00,  9.10it/s]
Epoch 55: 100%|██████████| 38/38 [00:04<00:00,  9.11it/s]
Epoch 56:   5%|▌         | 2/38 [00:00<00:03, 10.23it/s]

Epoch    55: reducing learning rate of group 0 to 8.1000e-04.


Epoch 56: 100%|██████████| 38/38 [00:04<00:00,  9.16it/s]
Epoch 57: 100%|██████████| 38/38 [00:04<00:00,  9.15it/s]
Epoch 58: 100%|██████████| 38/38 [00:04<00:00,  9.16it/s]
Epoch 59: 100%|██████████| 38/38 [00:04<00:00,  9.13it/s]
Epoch 60: 100%|██████████| 38/38 [00:04<00:00,  9.16it/s]
Epoch 61: 100%|██████████| 38/38 [00:04<00:00,  8.86it/s]
Epoch 62: 100%|██████████| 38/38 [00:04<00:00,  9.16it/s]
Epoch 63: 100%|██████████| 38/38 [00:04<00:00,  9.04it/s]
Epoch 64: 100%|██████████| 38/38 [00:04<00:00,  9.10it/s]
Epoch 65: 100%|██████████| 38/38 [00:04<00:00,  9.17it/s]
Epoch 66: 100%|██████████| 38/38 [00:04<00:00,  9.09it/s]
Epoch 67: 100%|██████████| 38/38 [00:04<00:00,  9.11it/s]
Epoch 68:   5%|▌         | 2/38 [00:00<00:03, 10.26it/s]

Epoch    67: reducing learning rate of group 0 to 2.4300e-04.


Epoch 68: 100%|██████████| 38/38 [00:04<00:00,  9.11it/s]
Epoch 69: 100%|██████████| 38/38 [00:04<00:00,  9.08it/s]
Epoch 70: 100%|██████████| 38/38 [00:04<00:00,  9.22it/s]
Epoch 71: 100%|██████████| 38/38 [00:04<00:00,  9.14it/s]
Epoch 72: 100%|██████████| 38/38 [00:04<00:00,  9.05it/s]
Epoch 73: 100%|██████████| 38/38 [00:04<00:00,  9.07it/s]
Epoch 74: 100%|██████████| 38/38 [00:04<00:00,  9.00it/s]
Epoch 75: 100%|██████████| 38/38 [00:04<00:00,  9.10it/s]
Epoch 76: 100%|██████████| 38/38 [00:04<00:00,  9.08it/s]
Epoch 77: 100%|██████████| 38/38 [00:04<00:00,  9.22it/s]
Epoch 78: 100%|██████████| 38/38 [00:04<00:00,  9.12it/s]
Epoch 79: 100%|██████████| 38/38 [00:04<00:00,  9.05it/s]
Epoch 80: 100%|██████████| 38/38 [00:04<00:00,  9.02it/s]
Epoch 81: 100%|██████████| 38/38 [00:04<00:00,  9.07it/s]
Epoch 82: 100%|██████████| 38/38 [00:04<00:00,  9.10it/s]
Epoch 83: 100%|██████████| 38/38 [00:04<00:00,  9.04it/s]
Epoch 84: 100%|██████████| 38/38 [00:04<00:00,  8.99it/s]


In [28]:
torch.save(model_wt.state_dict(), './model_checkpoints/wildtype_model.pth')

In [29]:
model_config_autistic = {'pad_input': False,
  'stack': -1,
  'layers': 4,
  'input_kern': 9,
  'gamma_input': 6.3831,
  'gamma_readout': 0.0076,
  'hidden_kern': 7,
  'hidden_channels': 64,
  'depth_separable': True,
  'grid_mean_predictor': {'type': 'cortex',
   'input_dimensions': 2,
   'hidden_layers': 1,
   'hidden_features': 30,
   'final_tanh': True},
  'init_sigma': 0.1,
  'init_mu_range': 0.3,
  'gauss_type': 'full',
  'shifter': False,
}

model_a= get_model(model_fn=model_fn,
                  model_config=model_config_autistic,
                  dataloaders=dataloaders_autistic,
                  seed=seed,
                  )
trainer_fn = "sensorium.training.standard_trainer"

trainer_config_autistic = {'max_iter': 200,
                 'verbose': True,
                 'lr_decay_steps': 4,
                 'avg_loss': False, 
                 'lr_init': 0.009,
                 }

trainer_a = get_trainer(trainer_fn=trainer_fn, 
                     trainer_config=trainer_config_autistic)

In [30]:
validation_score_a, trainer_output_a, state_dict_a = trainer_a(model_a, dataloaders_autistic, seed=seed+1)

Epoch 1: 100%|██████████| 38/38 [00:04<00:00,  8.95it/s]
Epoch 2: 100%|██████████| 38/38 [00:04<00:00,  9.02it/s]
Epoch 3: 100%|██████████| 38/38 [00:04<00:00,  8.95it/s]
Epoch 4: 100%|██████████| 38/38 [00:04<00:00,  8.86it/s]
Epoch 5: 100%|██████████| 38/38 [00:04<00:00,  9.04it/s]
Epoch 6: 100%|██████████| 38/38 [00:04<00:00,  9.01it/s]
Epoch 7: 100%|██████████| 38/38 [00:04<00:00,  8.97it/s]
Epoch 8: 100%|██████████| 38/38 [00:04<00:00,  9.00it/s]
Epoch 9: 100%|██████████| 38/38 [00:04<00:00,  8.94it/s]
Epoch 10: 100%|██████████| 38/38 [00:04<00:00,  9.00it/s]
Epoch 11: 100%|██████████| 38/38 [00:04<00:00,  9.00it/s]
Epoch 12: 100%|██████████| 38/38 [00:04<00:00,  9.09it/s]
Epoch 13: 100%|██████████| 38/38 [00:04<00:00,  9.01it/s]
Epoch 14: 100%|██████████| 38/38 [00:04<00:00,  8.98it/s]
Epoch 15: 100%|██████████| 38/38 [00:04<00:00,  8.99it/s]
Epoch 16: 100%|██████████| 38/38 [00:04<00:00,  9.00it/s]
Epoch 17: 100%|██████████| 38/38 [00:04<00:00,  9.01it/s]
Epoch 18: 100%|████████

Epoch    53: reducing learning rate of group 0 to 2.7000e-03.


Epoch 54: 100%|██████████| 38/38 [00:04<00:00,  8.99it/s]
Epoch 55: 100%|██████████| 38/38 [00:04<00:00,  9.04it/s]
Epoch 56: 100%|██████████| 38/38 [00:04<00:00,  8.99it/s]
Epoch 57: 100%|██████████| 38/38 [00:04<00:00,  9.02it/s]
Epoch 58: 100%|██████████| 38/38 [00:04<00:00,  8.85it/s]
Epoch 59: 100%|██████████| 38/38 [00:04<00:00,  8.87it/s]
Epoch 60: 100%|██████████| 38/38 [00:04<00:00,  8.90it/s]
Epoch 61: 100%|██████████| 38/38 [00:04<00:00,  9.06it/s]
Epoch 62: 100%|██████████| 38/38 [00:04<00:00,  8.68it/s]
Epoch 63: 100%|██████████| 38/38 [00:04<00:00,  8.91it/s]
Epoch 64: 100%|██████████| 38/38 [00:04<00:00,  9.10it/s]
Epoch 65: 100%|██████████| 38/38 [00:04<00:00,  9.02it/s]
Epoch 66: 100%|██████████| 38/38 [00:04<00:00,  9.05it/s]
Epoch 67: 100%|██████████| 38/38 [00:04<00:00,  9.04it/s]
Epoch 68: 100%|██████████| 38/38 [00:04<00:00,  9.01it/s]
Epoch 69:   5%|▌         | 2/38 [00:00<00:03, 10.82it/s]

Epoch    68: reducing learning rate of group 0 to 8.1000e-04.


Epoch 69: 100%|██████████| 38/38 [00:04<00:00,  9.03it/s]
Epoch 70: 100%|██████████| 38/38 [00:04<00:00,  8.97it/s]
Epoch 71: 100%|██████████| 38/38 [00:04<00:00,  8.95it/s]
Epoch 72: 100%|██████████| 38/38 [00:04<00:00,  8.99it/s]
Epoch 73: 100%|██████████| 38/38 [00:04<00:00,  9.04it/s]
Epoch 74: 100%|██████████| 38/38 [00:04<00:00,  9.08it/s]
Epoch 75: 100%|██████████| 38/38 [00:04<00:00,  9.04it/s]
Epoch 76: 100%|██████████| 38/38 [00:04<00:00,  9.10it/s]
Epoch 77: 100%|██████████| 38/38 [00:04<00:00,  9.05it/s]
Epoch 78: 100%|██████████| 38/38 [00:04<00:00,  8.95it/s]
Epoch 79:   5%|▌         | 2/38 [00:00<00:03, 10.69it/s]

Epoch    78: reducing learning rate of group 0 to 2.4300e-04.


Epoch 79: 100%|██████████| 38/38 [00:04<00:00,  9.08it/s]
Epoch 80: 100%|██████████| 38/38 [00:04<00:00,  9.00it/s]
Epoch 81: 100%|██████████| 38/38 [00:04<00:00,  9.02it/s]
Epoch 82: 100%|██████████| 38/38 [00:04<00:00,  8.99it/s]


In [31]:
torch.save(model_a.state_dict(), './model_checkpoints/autistic_model.pth')