# Demo notebook: how to load the transfer core and train a model

In [None]:
# Notebook setup: render matplotlib figures inline and automatically reload
# imported modules (e.g. the `lurz2020` package) before each cell runs, so
# edits to library code are picked up without restarting the kernel.
%matplotlib inline
%load_ext autoreload
%autoreload 2 

In [None]:
import torch
from collections import OrderedDict
import neuralpredictors as neur

## Build the dataloaders

In [None]:
from lurz2020.datasets.mouse_loaders import static_loaders

# List of dataset folders handed to `static_loaders` (relative to the notebook).
paths = ['data/Lurz2020/static20457-5-9-preproc0']

# Keyword arguments for `static_loaders`: the data paths, the mini-batch
# size and the random seed.
dataset_config = dict(
    paths=paths,
    batch_size=64,
    seed=1,
)

dataloaders = static_loaders(**dataset_config)

## Build the model

If you want to load the transfer core later on, do not change the model-config arguments that define the model architecture: the core's parameters must keep the same names and shapes as the saved weights, otherwise they cannot be loaded.

In [None]:
from lurz2020.models.models import se2d_fullgaussian2d

# Keyword arguments for `se2d_fullgaussian2d`. Entries that define the
# architecture must stay as they are, otherwise the transfer core weights
# will not fit this model.
model_config = dict(
    init_mu_range=0.55,
    init_sigma=0.4,
    input_kern=15,
    hidden_kern=13,
    gamma_input=1.0,
    # nested options passed through as the `grid_mean_predictor` argument
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=0,
        hidden_features=0,
        final_tanh=False,
    ),
    gamma_readout=2.439,
)

model = se2d_fullgaussian2d(**model_config, dataloaders=dataloaders, seed=1)

## Load the weights of the transfer core

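This loads the saved weights of the core into the model and discards the weights of the readout. Because `strict=False` is passed, keys that do not match between the checkpoint and the model are skipped instead of raising an error.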

In [None]:
# Load the saved transfer-model state dict onto the CPU so the notebook also
# runs on machines without the GPU the checkpoint may have been saved from;
# `load_state_dict` then copies the tensors into the model's existing
# parameters, wherever those live.
transfer_model = torch.load('models/transfer_model.pth.tar', map_location='cpu')

# strict=False skips keys present in only one of checkpoint / model. That also
# means a complete key mismatch would load nothing without any error, so make
# sure at least one checkpoint entry was actually used.
load_result = model.load_state_dict(transfer_model, strict=False)
n_loaded = len(transfer_model) - len(load_result.unexpected_keys)
assert n_loaded > 0, 'no parameters from the transfer checkpoint matched the model'

# Display missing / unexpected keys as the cell output.
load_result

## Build the trainer

In [None]:
from lurz2020.training.trainers import standard_trainer as trainer

# True keeps the transfer core fixed during training; set to False to
# fine-tune the core together with the readout.
detach_core = True

print('Core is fixed and will not be fine-tuned' if detach_core
      else 'Core will be fine-tuned')

# Keyword arguments forwarded to the trainer.
trainer_config = dict(track_training=True, detach_core=detach_core)

## Run training

In [None]:
# Train the model; the trainer's result is unpacked into `score`, `output`
# and `model_state`.
score, output, model_state = trainer(
    model=model,
    dataloaders=dataloaders,
    seed=1,
    **trainer_config
)