In [1]:
%load_ext autoreload
%autoreload 2

import os
import random

import numpy as np
import torch
from torch.optim import AdamW
import scipy.io

from experiment import activation
from experiment import experiment
from experiment import michaels_load
from experiment import mRNN
from experiment import stim
from experiment import utils

import cpn_model
import stim_model

CUDA = None
if isinstance(CUDA, str):
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = CUDA
    CUDA = torch.device(0)

In [2]:
cfg = experiment.get_config(stim_retain_grad=True, coadapt=False, cuda=CUDA)

obs_dim, stim_dim, out_dim, cuda = cfg.unpack()
per_mod_obs_dim = cfg.observer_instance.out_dim

cpn = cpn_model.CPNModelLSTM(
    obs_dim,
    stim_dim,
    num_neurons=obs_dim,
    activation_func=cfg.cpn_activation,
    cuda=cuda,
)
for param in cpn.parameters():
    param.requires_grad = True


mike = mRNN.MichaelsRNN(
    init_data_path=michaels_load.get_default_path(),
    stimulus=cfg.stim_instance,
    cuda=CUDA
)
mike.set_lesion(cfg.lesion_instance)
for param in mike.parameters():
    param.requires_grad = True

In [3]:
dataloader = cfg.loader_train
data = next(iter(dataloader))
batch_size, trial_len, out_dim = data[0].shape

In [8]:
# One entry for each time step

din, trial_end, _, dout, labels = data

def forward(cpn, din, trial_end, dout, labels, en=None):
    stim_params = []
    stims = []
    preds = []
    preds_en = []
    mike.reset()
    
    cpn.reset()
    
    if en is not None:
        en.reset()

    # First time step: no stimulation; just priming the mRNN
    mike_in = din[:, 0, :].T
    mike_out = mike(mike_in)

    for tidx in range(1, trial_len):
        # Observe current activity
        obs_raw = mike.observe(cfg.observer_instance)

        brain_data = obs_raw + (trial_end[:, tidx - 1, :],)
        cpn_in = torch.cat(brain_data, axis=1)

        stims.append(mike.last_stimulus)
        stims[-1].retain_grad()

        stim = cpn(cpn_in)
        stim_params.append(stim)
        stim_params[-1].retain_grad()
        
        if en is not None:
            # en receives (obs, stims, trial_end)
            en_data = obs_raw + (stim_params[-1], trial_end[:, tidx - 1, :])
            en_in = torch.cat(en_data, axis=1)
            en_out = en(en_in)
            
            preds_en.append(en_out.unsqueeze(dim=1))

        mike.stimulate(stim)

        mike_in = din[:, tidx, :].T
        mike_out = mike(mike_in)

        preds.append(mike_out.unsqueeze(dim=1))
    
    return stims, stim_params, preds, preds_en

In [5]:
# Used for CPN training
def lr_sched(opt, rtl, eidx):
        """
        Args:
            rtl - recent training loss, which we use to determine the learning rate
        """
        if rtl is None or eidx < 4000:
            for p in opt.param_groups:
                p["lr"] = 1e-3
        elif rtl >= 0.008:
            for p in opt.param_groups:
                p["lr"] = 1e-3
        elif rtl >= 0.006:
            for p in opt.param_groups:
                p["lr"] = 5e-4
        elif rtl >= 0.005:
            for p in opt.param_groups:
                p["lr"] = 1e-5
        elif rtl >= 0.004:
            for p in opt.param_groups:
                p["lr"] = 2e-6
        elif rtl >= 0.0025:
            for p in opt.param_groups:
                p["lr"] = 1e-6
        else:
            for p in opt.param_groups:
                p["lr"] = 1e-7

In [12]:
# We need an EN to compare to...
# We make it, then train it.

# Obs, stim, and trial end indicator (trial indicator included in obs_dim)
en_in_dim = obs_dim + stim_dim
en_out_dim = dout.shape[-1]

en, opt = stim_model.get_stim_model(en_in_dim, en_out_dim, cuda=CUDA)

rtl = 1
eidx = 0

while rtl > 0.003:
    cpn_noisy = cpn_model.CPNNoiseyLSTMCollection(cpn, noise_var=0.1,
                                             white_noise_pct=0.3,
                                             white_noise_var=6,
                                             cuda=CUDA)
    cpn_noisy.setup(batch_size)
    
    _, _, _, en_pred_vec = forward(
        cpn_noisy, din, trial_end, dout, labels, en=en)
    
    en_preds = torch.cat(en_pred_vec, axis=1)
    actuals = utils.trunc_to_trial_end(en_preds, trial_end[:, :-1, :])
    loss = torch.nn.MSELoss()(en_preds, dout[:, 1:, :])
    loss.backward(inputs=list(en.parameters()))
    
    rtl = loss.item()

    print(eidx, rtl)

    for p in opt.param_groups:
        if rtl < 0.0007:
            p["lr"] = 1e-4
        elif rtl < 0.005:
            p["lr"] = 3e-3
        else:
            p["lr"] = 4e-3

    opt.step()
    
    eidx += 1
# In loop: train like usual.

torch.Size([402, 340, 50]) torch.Size([402, 340, 1]) <built-in function len>
0 0.02537582814693451
torch.Size([402, 340, 50]) torch.Size([402, 340, 1]) <built-in function len>
1 0.02138046734035015
torch.Size([402, 340, 50]) torch.Size([402, 340, 1]) <built-in function len>
2 0.02006768248975277
torch.Size([402, 340, 50]) torch.Size([402, 340, 1]) <built-in function len>
3 0.019693749025464058
torch.Size([402, 340, 50]) torch.Size([402, 340, 1]) <built-in function len>
4 0.01915150322020054
torch.Size([402, 340, 50]) torch.Size([402, 340, 1]) <built-in function len>
5 0.018800601363182068
torch.Size([402, 340, 50]) torch.Size([402, 340, 1]) <built-in function len>
6 0.0183712225407362
torch.Size([402, 340, 50]) torch.Size([402, 340, 1]) <built-in function len>
7 0.017801620066165924


KeyboardInterrupt: 

In [None]:
# TODO: experiment where we take the on-policy CPN outputs,
#       similar CPN outputs, and random CPN outputs, and compare grads
#       between the EN and backprop-through-the-brain.

In [30]:
opt = AdamW(cpn.parameters())

rtl = 1
eidx = 0

while rtl > 0.0003:
    _, stim_params, actuals_vec, _ = forward(cpn, din, trial_end, dout, labels,
                                         en=None)
    
    actuals = torch.cat(actuals_vec, axis=1)
    actuals = utils.trunc_to_trial_end(actuals, trial_end[:, :-1, :])
    loss = torch.nn.MSELoss()(actuals, dout[:, 1:, :])
    loss.backward(inputs=list(cpn.parameters()))
    
    rtl = loss.item()
    
    print(eidx, rtl)

    lr_sched(opt, rtl, eidx)
    opt.step()
    
    eidx += 1

0 0.023088084533810616
1 0.02138981781899929
2 0.019765663892030716


KeyboardInterrupt: 

In [None]:
# Now we have a partially trained CPN. Let's repeat the experiment here, starting with a new EN...

In [None]:
# Again, far later in training, we make another EN and compare...