In [1]:
import os
from collections import OrderedDict
from itertools import repeat

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import optim
from tqdm import tqdm

from nips2018.utils import set_seed
from nips2018.movie import data
from nips2018.movie.parameters import DataConfig, Seed
from nips2018.architectures.readouts import SpatialTransformerPooled3dReadout, ST3dSharedGridStopGradientReadout
from nips2018.architectures.cores import StackedFeatureGRUCore, Stacked3dCore
from nips2018.architectures.shifters import StaticAffineShifter
from nips2018.architectures.modulators import GateGRUModulator 
from nips2018.movie import parameters
from nips2018.movie.models import Encoder
from nips2018.architectures.base import CorePlusReadout3d

from attorch.layers import elu1, Elu1
from attorch.train import early_stopping, cycle_datasets
from attorch.losses import PoissonLoss3d

import datajoint as dj

Connecting nikoskar@at-database1.ad.bcm.edu:3306


In [2]:
def save_checkpoint(model, filename):
    """Write ``model.state_dict()`` to ``filename`` using ``torch.save``.

    Only the parameters/buffers are stored, not the module object itself; restore
    them into an identically constructed model with ``load_state_dict``.
    """
    weights = model.state_dict()
    torch.save(weights, filename)
    
def load_checkpoint(model, filename, map_location=None):
    """Load a state_dict written by ``save_checkpoint`` into ``model`` in place.

    Args:
        model: module whose parameters/buffers are overwritten via ``load_state_dict``.
        filename: path of the file written with ``torch.save``.
        map_location: forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to load a
            checkpoint saved from CUDA tensors on a machine without a GPU. The
            default ``None`` matches the previous behaviour (tensors are restored
            to the device they were saved from).
    """
    state_dict = torch.load(filename, map_location=map_location)
    model.load_state_dict(state_dict)

In [9]:
# Show the data configurations; the data_hash of the one to train on goes into `key`.
# Only a cell's last expression is rendered automatically, so the first table is
# passed to display() explicitly (previously its value was silently discarded).
display(DataConfig.AreaLayerClipRawInputResponse())
DataConfig.AreaLayerClip()

data_hash  unique identifier for configuration,stats_source  normalization source,train_seq_len  training sequence length in frames,layer  short name for cortical area,brain_area  area name
29793eec715efd8898f2cba46dd6982a,stimulus.Clip,150,L2/3,V1


In [5]:
# Look up the dataset group containing scan 17797-8-5 (segmentation 6); its group_id is used in `key`.
data.MovieMultiDataset().Member() & {'animal_id': 17797, 'session': 8, 'scan_idx': 5, 'segmentation_method': 6}

group_id  index of group,animal_id  id number,session  session index for the mouse,scan_idx  number of TIFF stack file,pipe_version,segmentation_method,spike_method  spike inference method,preproc_id  preprocessing ID,name  string description to be used for training
67,17797,8,5,1,6,5,0,group067-17797-8-5-pre0-seg6-spi5-pip1


In [10]:
# Seeds used for this model family: 880, 990, 1009, 1215, 2606, 7797, 8142
#
# Alternative configuration (labelled "V1 and L2/3" in earlier runs):
#   {'data_hash': '5253599d3dceed531841271d6eeba9c5', 'group_id': 67, 'seed': 880}
#
# Active configuration, labelled "all neurons in scan (natural video)".
# NOTE(review): the AreaLayerClip table and the loader log show this hash filters
# brain_area='V1', layer='L2/3' with stats from stimulus.Clip — confirm that this
# covers all neurons of this scan.
key = {
    'data_hash': '29793eec715efd8898f2cba46dd6982a',
    'group_id': 67,
    'seed': 1009,
}

In [11]:
# Run configuration.
N_GPU = 1             # pinned to one GPU (the torch.cuda.device_count() result was always overwritten)
batch_size = 3        # training batch size passed to load_data
val_subsample = None  # forwarded as subsamp_size to get_stop_closure; 1000 was used previously
n_subsample = None    # NOTE(review): not referenced elsewhere in this notebook

In [12]:
# Seed numpy and torch (see set_seed's log) before the loaders are built.
set_seed(key['seed'])

trainsets, trainloaders = DataConfig().load_data(
    key, tier='train', batch_size=batch_size)

# neurons per dataset name, in training-set order
n_neurons = OrderedDict(
    (name, dataset.n_neurons) for name, dataset in trainsets.items())

# validation uses batch_size=1; key_order=trainsets presumably aligns dataset order with training
valsets, valloaders = DataConfig().load_data(
    key, tier='validation', batch_size=1, key_order=trainsets)

# input shape taken from the first training dataset
img_shape = [*trainloaders.values()][0].dataset.img_shape

Setting numpy and torch seed to 1009
---------------------------+--------------------------------------------------------------------------------
AreaLayerClip              | Loading stimulus.Clip dataset with tier= train
---------------------------+--------------------------------------------------------------------------------
MovieMultiDataset          | Fetching data for
                           |  {         'brain_area': 'V1',
                           |           'data_hash': '29793eec715efd8898f2cba46dd6982a',
                           |           'group_id': 67,
                           |           'layer': 'L2/3',
                           |           'seed': 1009,
                           |           'stats_source': 'stimulus.Clip',
                           |           'train_seq_len': 150}
MovieMultiDataset          | Data will be (inputs,behavior,eye_position,responses)
MovieMultiDataset          | Loading dataset group067-17797-8-5-pre0-seg6-spi5-pip1 -->
      

In [13]:
# Per group, keep the Encoder entry with the maximal validation correlation, restrict
# to group 22, and show its StackedFeatureGRU core settings (copied into the core below).
best = (
    Encoder
    * dj.U('group_id').aggr(Encoder, best='max(val_corr)')
    & 'val_corr >= best'
    & 'group_id=22'
)
best * parameters.CoreConfig.StackedFeatureGRU

group_id  index of group,core_hash  unique identifier for configuration,ro_hash  unique identifier for configuration,shift_hash  unique identifier for configuration,mod_hash  unique identifier for configuration,data_hash  unique identifier for configuration,train_hash  unique identifier for configuration,seed  random seed,val_corr  validation correlation (single trial),model  stored model,best  calculated attribute,hidden_channels  hidden channels,rec_channels  recurrent hidden channels,input_kern  kernel size at input convolutional layers,hidden_kern  kernel size at hidden convolutional layers,rec_kern  kernel size at hidden convolutional layers,layers  layers,gamma_rec  regularization constant for recurrent bias term,gamma_hidden  regularization constant for hidden layers in CNN,gamma_input  regularization constant for input convolutional layers,skip  use skip connection to previous `skip` layers,bias  use bias,pad_input  use padding,momentum  use padding
22,22d11147b37e3947e7d1034cc00d402c,bf00321c11e46d68d4a42653a725969d,64add03e1462b7413b59812d446aee9f,4954311aa3bebb347ebf411ab5198890,6c0290da908317e55c4baf92e379d651,624f62a2ef01d39f6703f3491bb9242b,2606,0.141586,=BLOB=,0.141586,12,36,7,3,3,3,0.0,0.1,50.0,2,0,1,0.1


In [16]:
# Model = core + readout + shifter + modulator, assembled by CorePlusReadout3d.
# Core hyperparameters copy the best group-22 StackedFeatureGRU row shown above
# (bias 0 -> False, pad_input 1 -> True).
# NOTE(review): the construction log reports ConvGRUCell ignoring bias/skip/batch_norm;
# confirm these settings are consumed by the feature stack instead.
core = StackedFeatureGRUCore(
    input_channels=img_shape[1],
    # convolutional layers
    layers=3,
    hidden_channels=12,
    input_kern=7,
    hidden_kern=3,
    skip=2,
    bias=False,
    batch_norm=True,
    momentum=0.1,
    pad_input=True,
    # recurrent layer
    rec_channels=36,
    rec_kern=3,
    # regularization constants
    gamma_input=50,
    gamma_hidden=0.1,
    gamma_rec=0.0,
)

ro_in_shape = CorePlusReadout3d.get_readout_in_shape(core, img_shape)

readout = ST3dSharedGridStopGradientReadout(
    ro_in_shape,
    n_neurons,
    positive=False,
    gamma_features=1.0,
    pool_steps=2,
    kernel_size=4,
    stride=4,
    gradient_pass_mod=3,
)

shifter = StaticAffineShifter(
    n_neurons,
    input_channels=2,
    hidden_channels=2,
    bias=True,
    gamma_shifter=0.001,
)

modulator = GateGRUModulator(
    n_neurons,
    hidden_channels=50,
    offset=1,
    bias=True,
    gamma_modulator=0.0,
)

model = CorePlusReadout3d(
    core,
    readout,
    nonlinearity=Elu1(),
    shifter=shifter,
    modulator=modulator,
    burn_in=15,
)

---------------------------+--------------------------------------------------------------------------------
Stacked2dCore              | Ignoring input {} when creating Stacked2dCore
---------------------------+--------------------------------------------------------------------------------
ConvGRUCell                | Ignoring input {'bias': False, 'skip': 2, 'batch_norm': True} when creating
                           |      ConvGRUCell
ConvGRUCell                | 	Initializing first hidden state
---------------------------+--------------------------------------------------------------------------------
ST3dSharedGridStopGrad...  | Ignoring input {} when creating ST3dSharedGridStopGradientReadout
ST3dSharedGridStopGrad...  | 	Neuron change detected from -1 to 9941 ! Resetting grid!
ST3dSharedGridStopGrad...  | 	Gradient for group067-17797-8-5-pre0-seg6-spi5-pip1 will pass
---------------------------+--------------------------------------------------------------------------------
St

In [17]:
criterion = PoissonLoss3d()

n_datasets = len(trainloaders)
acc = 1  # gradient-accumulation factor, passed to train(accumulate_gradient=...)
# number of readouts with stop_grad == False
grad_passes = sum(int(not ro.stop_grad) for ro in model.readout.values())

stop_closure = Encoder().get_stop_closure(valloaders, subsamp_size=val_subsample)

# Readout is initialised from each training set's mean-trial responses.
mu_dict = OrderedDict(
    (name, loader.dataset.mean_trial().responses)
    for name, loader in trainloaders.items())
model.readout.initialize(mu_dict)
model.core.initialize()

# Shifter bias starts at the negated mean-trial eye position of each training set.
if model.shifter is not None:
    biases = OrderedDict(
        (name, -loader.dataset.mean_trial().eye_position)
        for name, loader in trainloaders.items())
    model.shifter.initialize(bias=biases)

if model.modulator is not None:
    model.modulator.initialize()

model = model.cuda()

---------------------------+--------------------------------------------------------------------------------
ST3dSharedGridStopGrad...  | Initializing with mu_dict: group067-17797-8-5-pre0-seg6-spi5-pip1: 9941
---------------------------+--------------------------------------------------------------------------------
StaticAffineShifter        | Initializing affine weights
---------------------------+--------------------------------------------------------------------------------
StaticAffine               | Setting bias to predefined value [-0. -0.]
---------------------------+--------------------------------------------------------------------------------
GateGRUModulator           | Initializing GateGRUModulator


In [15]:
def full_objective(
    model,
    readout_key,
    inputs,
    beh,
    eye_pos,
    targets):
    """Per-batch training objective.

    Runs ``model`` on ``inputs`` for readout ``readout_key`` (eye position and
    behavior passed as keywords) and scores the prediction against ``targets``
    with the notebook-global ``criterion`` (``PoissonLoss3d`` in this notebook).

    NOTE(review): the regularization terms and the division by the accumulation
    factor ``acc`` are currently disabled, so the ``gamma_*`` values set when the
    model was built do not enter this loss. The disabled expression was:

        (criterion(outputs, targets)
         + (model.core.regularizer() / grad_passes if not model.readout[readout_key].stop_grad else 0)
         + model.readout.regularizer(readout_key).cuda(0)
         + (model.shifter.regularizer(readout_key) if model.shift else 0)
         + (model.modulator.regularizer(readout_key) if model.modulate else 0)) / acc
    """
    predictions = model(inputs, readout_key, behavior=beh, eye_pos=eye_pos)
    return criterion(predictions, targets)


def train(
    model,
    objective,
    optimizer,
    stop_closure,
    trainloaders,
    epoch=0,
    post_epoch_hook=None,
    interval=1,
    patience=10,
    max_iter=10,
    maximize=True,
    tolerance=1e-6,
    cuda=True,
    restore_best=True,
    accumulate_gradient=1):
    """Train ``model`` under ``early_stopping`` driven by ``stop_closure``.

    For each epoch yielded by ``early_stopping``, every batch
    ``(readout_key, *data)`` from ``cycle_datasets(trainloaders)`` is scored with
    ``objective(model, readout_key, *data)`` and back-propagated. The optimizer
    steps once every ``accumulate_gradient`` batches; the counter runs across
    epochs.

    Args:
        model: the network being trained.
        objective: callable ``(model, readout_key, *data) -> scalar loss tensor``.
        optimizer: torch optimizer over the parameters to train.
        stop_closure, interval, patience, max_iter, maximize, tolerance,
        restore_best: forwarded to ``early_stopping``.
        trainloaders: mapping consumed by ``cycle_datasets``.
        epoch: starting epoch, forwarded to ``early_stopping`` as ``start``.
        post_epoch_hook: optional ``(model, epoch) -> model``, called after each epoch.
        cuda: forwarded to ``cycle_datasets``.
        accumulate_gradient: number of batches per optimizer step (>= 1).

    Returns:
        ``(model, epoch, losses)``. ``model`` is possibly replaced by
        ``post_epoch_hook``. ``epoch`` is the last epoch yielded (the ``epoch``
        argument if none were). ``losses`` holds one detached 0-d tensor per
        epoch: the mean training loss over that epoch's batches. Epochs without
        batches add no entry.

    Raises:
        ValueError: if ``accumulate_gradient`` < 1.
    """
    if accumulate_gradient < 1:
        raise ValueError('accumulate_gradient must be >= 1, got {}'.format(accumulate_gradient))

    optimizer.zero_grad()
    iteration = 0
    losses = []

    for epoch, _val_obj in early_stopping(
        model,
        stop_closure,
        interval=interval,
        patience=patience,
        start=epoch,
        max_iter=max_iter,
        maximize=maximize,
        tolerance=tolerance,
        restore_best=restore_best):

        loss_sum = 0.0
        n_batches = 0
        for readout_key, *data in tqdm(
                cycle_datasets(
                    trainloaders,
                    requires_grad=False,
                    cuda=cuda),
                desc='Training  | Epoch {}'.format(epoch)):

            obj = objective(model, readout_key, *data)
            obj.backward()
            if iteration % accumulate_gradient == accumulate_gradient - 1:
                optimizer.step()
                optimizer.zero_grad()
            iteration += 1
            # detach so the recorded value does not keep the autograd graph alive
            loss_sum = loss_sum + obj.detach()
            n_batches += 1

        if post_epoch_hook is not None:
            model = post_epoch_hook(model, epoch)
        # previously only the last batch was recorded, and an empty epoch raised UnboundLocalError
        if n_batches:
            losses.append(loss_sum / n_batches)
    return model, epoch, losses

In [18]:
# --- train with Adam, one early-stopped phase per learning rate; the epoch count
# carries over between phases.
# NOTE(review): the optimizer receives model.parameters(), which presumably also
# contains the shifter's parameters, so the shifter is trained as well. If it is
# meant to stay frozen, filter its parameters out of the optimizer.
epoch = 0
schedule = [0.001, 0.00075]

for lr in schedule:
    print('Training with learning rate', lr)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    model, epoch, losses = train(
        model, full_objective, optimizer, stop_closure, trainloaders,
        epoch=epoch,
        max_iter=200,
        interval=4,
        patience=4,
        accumulate_gradient=acc,
    )

  0%|          | 0/9 [00:00<?, ?it/s]

Training with learning rate 0.001


100%|██████████| 9/9 [00:03<00:00,  2.33it/s]
Training  | Epoch 1: 0it [00:00, ?it/s]

---------------------------+--------------------------------------------------------------------------------
Encoder                    | 	group067-17797-8-5-pre0-seg6-spi5-pip1 correlation 0.012990778


Training  | Epoch 1: 101it [01:50,  1.09s/it]
Training  | Epoch 2: 101it [01:12,  1.40it/s]
Training  | Epoch 3: 101it [01:11,  1.41it/s]
Training  | Epoch 4: 17it [00:12,  1.41it/s]

KeyboardInterrupt: 

In [13]:
# Persist the trained weights with the save_checkpoint helper defined at the top of
# the notebook (it is no longer redefined here).
save_path = './saved_models/'
os.makedirs(save_path, exist_ok=True)  # torch.save does not create missing directories
save_checkpoint(model, os.path.join(save_path, '17797_8_5_v3.pt'))