## Imports

In [None]:
from __future__ import division

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.utils.prune as prune

import numpy as np

import os
import random
import json

# from statistics import geometric_mean
from tqdm import tqdm
import pickle as pkl
from itertools import accumulate

from scipy.stats import gmean
from task2 import generate_trials
from ml_decorr import decorr_criterion
from model import Model

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using '{device}' device")

Using 'cuda' device
Using 'cuda' device


## misc helper funcs

In [2]:

# create population vector for performance calculation
def popvec(y):
    """Decode an angle from ring-unit activity with a population vector.

    The last axis of ``y`` is treated as a ring of units whose preferred
    directions are evenly spaced starting at 0 with step ``2*pi / n_units``.
    The activity-weighted (and activity-normalised) cosine and sine of those
    preferences are combined with ``arctan2``.

    Args:
        y: numpy array of activities; the last axis indexes the ring units.

    Returns:
        Decoded angle(s) in radians wrapped into [0, 2*pi); same shape as
        ``y`` without its last axis.
    """
    n_units = y.shape[-1]
    preferred = np.arange(0, 2 * np.pi, 2 * np.pi / n_units)

    # Normalising by total activity keeps the two components on a common scale.
    total = y.sum(axis=-1)
    cos_part = (y * np.cos(preferred)).sum(axis=-1) / total
    sin_part = (y * np.sin(preferred)).sum(axis=-1) / total

    # arctan2 yields (-pi, pi]; fold it onto [0, 2*pi).
    return np.mod(np.arctan2(sin_part, cos_part), 2 * np.pi)


# get model performance
def get_perf(y_hat, y_loc):
    """Score model outputs against target locations at the final time step.

    Only the last time point of ``y_hat`` and ``y_loc`` is used. Output
    channel 0 is read as the fixation unit and the remaining channels as ring
    units, which are decoded to an angle with ``popvec``.

    A trial with a negative target location is a "should fixate" trial and
    scores 1 when the fixation output exceeds 0.5. Any other trial scores a
    graded value ``(2*pi - |target - decoded|) / (2*pi)`` when the model is
    not fixating, and 0 when it is.

    Args:
        y_hat: numpy array of shape (Time, Batch, Unit).
        y_loc: target locations indexed by time on the first axis; negative
            values mark trials where the model should keep fixating.

    Returns:
        Per-trial performance for the last time step.

    Raises:
        ValueError: if ``y_hat`` is not 3-dimensional.
    """
    if len(y_hat.shape) != 3:
        raise ValueError('y_hat must have shape (Time, Batch, Unit)')

    # Only look at the last time point.
    y_loc = y_loc[-1]
    y_hat = y_hat[-1]

    # Fixation output and decoded saccade location.
    y_hat_fix = y_hat[..., 0]
    y_hat_loc = popvec(y_hat[..., 1:])

    fixating = y_hat_fix > 0.5

    # Graded location score. The divisor must be parenthesised: the previous
    # `x / 2*np.pi` evaluated as `(x / 2) * np.pi`, giving scores up to ~9.87
    # instead of a value normalised by the full circle.
    # NOTE(review): the distance is not wrapped around the circle (e.g. via
    # min(|d|, 2*pi - |d|)), so targets near 0 and decodes near 2*pi score
    # low even though they are close — confirm whether that is intended.
    original_dist = y_loc - y_hat_loc
    corr_loc = (2 * np.pi - abs(original_dist)) / (2 * np.pi)

    should_fix = y_loc < 0

    perf = should_fix * fixating + (1 - should_fix) * corr_loc * (1 - fixating)
    return perf


# generate orthogonal matrix for weight init
def gen_ortho_matrix(dim, rng=None):
    """Draw a random ``dim x dim`` orthogonal matrix.

    Uses the Householder-reflection construction of
    ``scipy.stats.ortho_group``; it is reproduced here for compatibility
    with older versions of scipy.

    Args:
        dim: size of the square matrix.
        rng: optional object exposing ``normal(size=...)`` (e.g. a
            ``numpy.random.RandomState``); the global ``np.random`` state is
            used when it is None.

    Returns:
        numpy array of shape (dim, dim).
    """
    sampler = np.random if rng is None else rng
    result = np.eye(dim)

    for k in range(1, dim):
        block = dim - k + 1
        v = sampler.normal(size=(block,))

        # Random sign, 50/50, chosen so that adding the norm cannot cancel.
        sign = np.sign(v[0])
        v[0] += sign * np.sqrt((v * v).sum())

        # Householder reflection acting on the trailing `block` coordinates.
        reflection = -sign * (np.eye(block) - 2. * np.outer(v, v) / (v * v).sum())

        embedded = np.eye(dim)
        embedded[k-1:, k-1:] = reflection
        result = np.dot(result, embedded)

    return result


## model running funcs

#### pruning func

In [3]:

def apply_layer_pruning(model, region, amount=0.2):
    """Apply L1 unstructured pruning to one region's recurrent and readout weights.

    Args:
        model: object exposing ``h2h_ctx``/``ctx2out`` and ``h2h_hpc``/``hpc2out``
            modules.
        region: 'ctx' or 'hpc'; any other value leaves the model untouched.
        amount: pruning amount forwarded to ``prune.l1_unstructured``.
    """
    if region == 'ctx':
        layer_names = ('h2h_ctx', 'ctx2out')
    elif region == 'hpc':
        layer_names = ('h2h_hpc', 'hpc2out')
    else:
        return

    # Recurrent weights first, then the readout, each pruned independently.
    for layer_name in layer_names:
        prune.l1_unstructured(getattr(model, layer_name), name='weight', amount=amount)


#### training loop

In [4]:
def train_interleaved(model, hp, task_list_unique, task_list, 
                      epoch_list=None, task_probs=None, save_hids=False, save_weights=False):
    """Train ``model`` for one optimisation step per entry of ``task_list``.

    After every training step the model is evaluated on every task in
    ``task_list_unique`` via ``run_eval``.

    When ``hp['mixed_batch']`` is truthy, ``task_list`` is ignored. The run
    length and per-step task mixture come from ``task_probs`` instead: each
    key is ``'<n_steps><suffix char>'`` (e.g. ``'1000a'``) and each value is a
    list of sampling weights aligned with ``task_list_unique``. Every step
    samples ``hp['batch_size']`` tasks and a random ``hp['tdim']`` in
    [75, 125].

    Side effects on the arguments (mixed-batch mode only):
        * ``task_probs['task_list_unique']`` is set and ``hp['task_probs']``
          points at ``task_probs`` (``save_data`` reads it).
        * ``hp['tdim']`` is overwritten on every step.
        * ``hp['batch_size']`` is 1 while training runs and is restored to its
          original value on exit, including on error or KeyboardInterrupt.
          Previously it stayed at 1, so a second call silently trained with
          batch size 1.

    Args:
        model: network exposing ``optimizer``, ``compute_loss``,
            ``reconstruction_loss`` and the ``h2h_*``/``*2out`` layers.
        hp: hyperparameter dict.
        task_list_unique: tasks to sample from (mixed batch) and to evaluate.
        task_list: one task per training step (ignored in mixed-batch mode).
        epoch_list: unused here; kept for interface compatibility.
        task_probs: see above; required in mixed-batch mode.
        save_hids: store per-timestep hidden states (detached, on CPU).
        save_weights: snapshot weights every ``hp['weight_save_int']`` steps.

    Returns:
        A 7-tuple of tuples:
        ``(train_losses, hpc_losses, train_perfs, mixed_batch_tasks)``,
        ``(eval_losses, eval_perfs)``,
        ``(yhats, trials, eval_yhats, eval_trials)``,
        ``(hids_ctx, hids_hpc)``,
        ``(ctx_h2h_weights, hpc_h2h_weights)``,
        ``(ctx2hpc_weights, hpc2ctx_weights)``,
        ``(ctx2out_weights, hpc2out_weights)``.
        ``yhats`` holds detached CPU tensors, so logging does not keep the
        autograd graph or GPU memory alive.

    Raises:
        ValueError: mixed-batch mode without ``task_probs``, or an
            unrecognised ``hp['hpc_loss']``.
    """
    all_train_losses, all_train_perfs = [], []
    all_losses_hpc = []
    all_eval_losses, all_eval_perfs = [], []

    all_hids_ctx, all_hids_hpc = [], []

    # Weight snapshots keyed by their state_dict entry.
    weight_logs = {
        'h2h_ctx.weight': [], 'h2h_hpc.weight': [],
        'ctx2hpc.weight': [], 'hpc2ctx.weight': [],
        'ctx2out.weight': [], 'hpc2out.weight': [],
    }

    all_yhats, all_eval_yhats = [], []
    all_trials, all_eval_trials = [], []

    batch_size = hp['batch_size']
    mixed_batch = hp['mixed_batch']
    mixed_batch_tasks = []

    ## --- sort out mixed batching --- ##
    if mixed_batch:
        if task_probs is None:
            raise ValueError("ERROR: NO TASK PROBABILITIES")

        # Snapshot the interval keys as a list. A live `.keys()` view would
        # also contain the 'task_list_unique' entry added below, which broke
        # repeated calls with the same dict.
        epoch_intervals = [k for k in task_probs.keys() if k != 'task_list_unique']
        epoch_accum = list(accumulate(int(k[:-1]) for k in epoch_intervals))
        train_paradigm = 0
        task_list = ['hi'] * epoch_accum[-1]  # filler list for iterating over

        hp['batch_size'] = 1  # each trial within the mixed batch is singular
        task_probs['task_list_unique'] = task_list_unique
        hp['task_probs'] = task_probs

    try:
        for i, task in enumerate(tqdm(task_list)):

            ## --- task generation --- ##
            if mixed_batch:
                if i == epoch_accum[train_paradigm]:
                    train_paradigm += 1

                task = random.choices(task_list_unique,
                                      task_probs[epoch_intervals[train_paradigm]],
                                      k=batch_size)
                hp['tdim'] = random.randint(75, 125)  # min and max tdims here
                mixed_batch_tasks.append(task)

            trial = generate_trials(task, hp, 'random', batch_size=hp['batch_size'])
            all_trials.append(trial)
            inputs = torch.tensor(trial.x, dtype=torch.float32).to(device)
            mask = torch.tensor(trial.c_mask, dtype=torch.float32).to(device)
            y = torch.tensor(trial.y, dtype=torch.float32).to(device)

            ctx_hid = torch.zeros(batch_size, hp['hid_size_ctx']).to(device)
            hpc_hid = torch.zeros(batch_size, hp['hid_size_hpc']).to(device)

            yhats = []
            outputs = []
            ctx_hids = []
            hpc_hids = []

            loss = 0      # full loss
            loss_hpc = 0  # just hpc loss

            ## --- training step --- ##
            model.train()
            model.optimizer.zero_grad()

            n_timesteps = inputs.shape[0]
            for t in range(n_timesteps):

                y_hat, ctx_hid, hpc_hid = model(inputs[t], ctx_hid, hpc_hid)

                if hp['hpc_loss'] is not None:
                    if hp['hpc_loss'] == 'recon':
                        # In this mode the model's first output is a pair.
                        y_hat, reconstruction = y_hat
                        l_hpc = model.reconstruction_loss(inputs[t], reconstruction)
                    elif hp['hpc_loss'] == 'decorr':
                        l_hpc = decorr_criterion(hpc_hid)
                    else:
                        raise ValueError(f"unknown hp['hpc_loss']: {hp['hpc_loss']!r}")
                    loss_hpc += l_hpc

                # assumes c_mask is flattened time-major, batch_size rows per
                # time step — TODO confirm against generate_trials
                loss += model.compute_loss(y[t], y_hat, mask[t*batch_size:(t+1)*batch_size])

                # Log a detached CPU copy, as is done for the hidden states.
                y_hat_cpu = y_hat.detach().cpu()
                yhats.append(y_hat_cpu)
                outputs.append(y_hat_cpu.numpy())

                if save_hids:
                    ctx_hids.append(ctx_hid.detach().cpu())
                    hpc_hids.append(hpc_hid.detach().cpu())

            perf = np.mean(get_perf(np.asarray(outputs), trial.y_loc))

            if hp['force_sparsity']:
                l1_norm_h2h_ctx = model.h2h_ctx.weight.abs().sum()
                l1_norm_h2h_hpc = model.h2h_hpc.weight.abs().sum()
                l1_norm_ctx2out = model.ctx2out.weight.abs().sum()
                l1_norm_hpc2out = model.hpc2out.weight.abs().sum()

                loss += (hp['l1_lambda_ctx'] * l1_norm_h2h_ctx) + (hp['l1_lambda_hpc'] * l1_norm_h2h_hpc) + \
                        (hp['l1_lambda_ctx'] * l1_norm_ctx2out) + (hp['l1_lambda_hpc'] * l1_norm_hpc2out)

            if hp['hpc_loss'] is not None:
                loss = loss + (hp['hpc_loss_alpha'] * loss_hpc)

            loss.backward()
            model.optimizer.step()

            ## ---- pruning ---- ##
            if hp['force_sparsity'] and i % hp['prune_int'] == 0 and i != 0:
                apply_layer_pruning(model, 'ctx', amount=hp['prune_ctx'])
                apply_layer_pruning(model, 'hpc', amount=hp['prune_hpc'])

                # Fold the masks into the weights so 'weight' is a plain
                # parameter again.
                prune.remove(model.h2h_ctx, 'weight')
                prune.remove(model.h2h_hpc, 'weight')
                prune.remove(model.ctx2out, 'weight')
                prune.remove(model.hpc2out, 'weight')

            ## ---- logging ---- ##
            all_train_perfs.append(perf)
            all_train_losses.append(float(loss.detach()))
            all_yhats.append(yhats)

            if hp['hpc_loss'] is not None:
                all_losses_hpc.append(float(loss_hpc.detach()))

            if save_hids:
                all_hids_ctx.append(ctx_hids)
                all_hids_hpc.append(hpc_hids)

            if save_weights and i % hp['weight_save_int'] == 0:
                state = model.state_dict()
                for key, log in weight_logs.items():
                    # .copy(): on a CPU model `.cpu().numpy()` shares memory
                    # with the live parameter, so without it every snapshot
                    # would track later updates.
                    log.append(state[key].detach().cpu().numpy().copy())

            ## ---- eval step --- ##
            eval_perfs, eval_losses, eval_yhats, eval_trial, eval_hids_ctx, eval_hids_hpc = \
                run_eval(model, hp, task_list_unique)

            all_eval_perfs.append(eval_perfs)
            all_eval_losses.append(eval_losses)

            all_eval_trials.append(eval_trial)
            all_eval_yhats.append(eval_yhats)
    finally:
        if mixed_batch:
            # Undo the temporary override so later runs (or a re-run after an
            # interrupt) see the configured batch size again.
            hp['batch_size'] = batch_size

    return (all_train_losses, all_losses_hpc, all_train_perfs, mixed_batch_tasks), \
            (all_eval_losses, all_eval_perfs), \
            (all_yhats, all_trials, all_eval_yhats, all_eval_trials), \
            (all_hids_ctx, all_hids_hpc), \
            (weight_logs['h2h_ctx.weight'], weight_logs['h2h_hpc.weight']), \
            (weight_logs['ctx2hpc.weight'], weight_logs['hpc2ctx.weight']), \
            (weight_logs['ctx2out.weight'], weight_logs['hpc2out.weight'])

#### eval func

In [5]:
### EVAL FUNC

def run_eval(model, hp, task_list):
    """Evaluate ``model`` on every task in ``task_list`` without tracking gradients.

    For each task, ``hp['eval_epochs']`` random batches of size
    ``hp['eval_batch_size']`` are generated and run through the model from
    zero hidden states. The model is left in eval mode on return.

    Args:
        model: network exposing ``compute_loss``; called as
            ``model(x_t, ctx_hid, hpc_hid)``.
        hp: hyperparameter dict; reads 'eval_batch_size', 'device',
            'eval_epochs', 'hid_size_ctx', 'hid_size_hpc', 'save_eval_hids'.
        task_list: tasks to evaluate.

    Returns:
        ``(perfs, losses, yhats, trials, hid_ctx, hid_hpc)``. Each is a list
        with one entry per task, and each entry is a list with one item per
        eval epoch. ``yhats`` items are lists of per-timestep numpy arrays.
        The two hidden-state lists hold the final hidden states and stay
        empty unless ``hp['save_eval_hids']`` is truthy.
    """
    model.eval()

    batch_size = hp['eval_batch_size']
    device = hp['device']

    all_hid_ctx = []
    all_hid_hpc = []

    all_perfs = []
    all_losses = []

    all_yhats = []
    all_trials = []

    # Evaluation never backpropagates, so skip building autograd graphs.
    with torch.no_grad():
        for task in task_list:
            task_hid_ctx = []
            task_hid_hpc = []

            task_perfs = []
            task_losses = []

            task_yhats = []
            task_trials = []

            for _ in range(hp['eval_epochs']):

                trial = generate_trials(task, hp, 'random', batch_size=batch_size)
                inputs = torch.tensor(trial.x, dtype=torch.float32).to(device)
                mask = torch.tensor(trial.c_mask, dtype=torch.float32).to(device)
                y = torch.tensor(trial.y, dtype=torch.float32).to(device)

                ctx_hid = torch.zeros(batch_size, hp['hid_size_ctx']).to(device)
                hpc_hid = torch.zeros(batch_size, hp['hid_size_hpc']).to(device)

                yhats = []
                loss = 0
                n_timesteps = inputs.shape[0]
                for t in range(n_timesteps):
                    y_hat, ctx_hid, hpc_hid = model(inputs[t], ctx_hid, hpc_hid)

                    # The training loop unpacks `(y_hat, reconstruction)` when
                    # hp['hpc_loss'] == 'recon'; accept that output shape here
                    # too and keep only the task output.
                    if isinstance(y_hat, (tuple, list)):
                        y_hat = y_hat[0]

                    # assumes c_mask is flattened time-major, batch_size rows
                    # per time step — TODO confirm against generate_trials
                    loss += model.compute_loss(y[t], y_hat, mask[t*batch_size:(t+1)*batch_size])

                    yhats.append(y_hat.cpu().detach().numpy())

                perf = np.mean(get_perf(np.asarray(yhats), trial.y_loc))

                task_perfs.append(perf)
                task_losses.append(float(loss))

                if hp['save_eval_hids']:
                    task_hid_ctx.append(ctx_hid)
                    task_hid_hpc.append(hpc_hid)

                task_yhats.append(yhats)
                task_trials.append(trial)

            all_perfs.append(task_perfs)
            all_losses.append(task_losses)

            all_yhats.append(task_yhats)
            all_trials.append(task_trials)

            if hp['save_eval_hids']:
                all_hid_ctx.append(task_hid_ctx)
                all_hid_hpc.append(task_hid_hpc)

    return all_perfs, all_losses, all_yhats, all_trials, all_hid_ctx, all_hid_hpc

#### save data func

In [6]:
def _dump_pickle(obj, path):
    """Pickle ``obj`` to the file at ``path``."""
    with open(path, 'wb') as f:
        pkl.dump(obj, f)


def save_data(model, hp, task_list_unique, all_eval_perfs, all_eval_losses, 
              epoch_list=None,
              all_trials=None, all_eval_trials=None,
              all_yhats=None, all_eval_yhats=None,
              all_train_perfs=None, all_train_losses=None, 
              all_losses_hpc=None, mixed_batch_tasks=None,
              all_hids_ctx=None, all_hids_hpc=None, 
              all_ctx_h2h_weights=None, all_hpc_h2h_weights=None,
              all_ctx2hpc_weights=None, all_hpc2ctx_weights=None, 
              all_ctx2out_weights=None, all_hpc2out_weights=None,
              results_dir='/home/jason/dev/schema1/results'):
    """Write the model, hyperparameters and training logs to a fresh run directory.

    Layout: ``<results_dir>/<model name>/<date>_lr..._c2h..._h2c...[_vN]/``
    with sub-folders ``perfs``, ``losses``, ``hids``, ``weights``, ``trials``
    and ``yhats``. The model name is built from the task names (plus epoch
    counts when not in mixed-batch mode), the hidden sizes, and labels for
    mixed batching, the hippocampal loss and sparsity.

    Always saved: the model ``state_dict`` (``.pth``), ``hp.pkl``, ``hp.json``
    (without the 'rng' and 'device' entries), eval perfs and eval losses.
    Every optional argument is pickled only when it is not None;
    ``mixed_batch_tasks`` is pickled whenever ``hp['mixed_batch']`` is truthy.

    Args:
        model: module whose ``state_dict()`` is saved.
        hp: hyperparameter dict. It is not modified.
        task_list_unique: task names used in the model name.
        all_eval_perfs, all_eval_losses: evaluation logs (always saved).
        epoch_list: per-task epoch counts; required unless mixed-batch mode.
        results_dir: root output directory, created if missing. The default
            keeps the original hard-coded location.
        (remaining keyword arguments): optional logs, see above.

    Raises:
        ValueError: not in mixed-batch mode and ``epoch_list`` is None.
    """
    from datetime import datetime
    import glob

    today = datetime.today().strftime('%Y-%m-%d')
    mixed_batch = bool(hp.get('mixed_batch'))
    n_tasks = len(task_list_unique)

    ## --- create the name string based on tasks --- ##
    if mixed_batch:
        name_parts = [str(task) for task in task_list_unique]
    else:
        if epoch_list is None:
            raise ValueError("epoch_list is required when hp['mixed_batch'] is not set")
        name_parts = [str(task_list_unique[i]) + str(epoch_list[i]) for i in range(n_tasks)]
    model_name_str = '-'.join(name_parts)

    ## --- add network size label  --- ##
    model_name_str += '_ctx' + str(hp['hid_size_ctx']) + '-hpc' + str(hp['hid_size_hpc'])

    ## --- add mixed batch label  --- ##
    if mixed_batch:
        # train_interleaved stores the task names under 'task_list_unique' in
        # the same dict; drop that entry by name rather than by position.
        epoch_intervals = [k for k in hp['task_probs'].keys() if k != 'task_list_unique']
        epoch_accum = list(accumulate(int(k[:-1]) for k in epoch_intervals))
        model_name_str += '_mixedbatch' + str(epoch_accum[-1]) + 'epochs'

    ## --- add hpc loss label  --- ##
    # This used to be gated on hp['hpc_reconstruct'], a key the notebook never
    # sets, so saving raised KeyError after training had finished.
    hpc_loss = hp.get('hpc_loss')
    if hpc_loss == 'decorr':
        model_name_str += '_decorrloss' + '{:.0e}'.format(hp['hpc_loss_alpha'])
    elif hpc_loss == 'recon':
        model_name_str += '_reconloss' + str(hp['hpc_loss_alpha'])
    if hp['force_sparsity']:
        model_name_str += '_sparse'

    results_dir = os.path.join(results_dir, model_name_str)
    os.makedirs(results_dir, exist_ok=True)

    ## --- get directory str  --- ##
    this_dir = today + '_lrctx' + '{:.0e}'.format(hp['lr_ctx']) + '_lrhpc' + '{:.0e}'.format(hp['lr_hpc']) \
              + '_c2h' + '{:.1e}'.format(hp['lr_c2h']) + '_h2c' + '{:.1e}'.format(hp['lr_h2c'])
    this_dir = os.path.join(results_dir, this_dir)

    if os.path.isdir(this_dir):
        # Start from the historical numbering (existing run count + 1), then
        # skip any suffix that is already taken so mkdir cannot collide.
        version = len(glob.glob(glob.escape(this_dir) + '*')) + 1
        while os.path.exists(this_dir + '_v' + str(version)):
            version += 1
        this_dir = this_dir + '_v' + str(version)
    os.mkdir(this_dir)
    print(this_dir)

    subdirs = {name: os.path.join(this_dir, name)
               for name in ('perfs', 'losses', 'hids', 'weights', 'trials', 'yhats')}
    for path in subdirs.values():
        os.mkdir(path)

    ## --- always save model  --- ##
    model_save_path = os.path.join(this_dir, model_name_str + '.pth')
    torch.save(model.state_dict(), model_save_path)
    print('model saved at:' + model_save_path)

    ## --- always hp dict  --- ##
    _dump_pickle(hp, os.path.join(this_dir, 'hp.pkl'))

    # Serialise a filtered copy instead of popping and re-inserting keys, so
    # hp stays intact even if json.dumps fails.
    hp_for_json = {k: v for k, v in hp.items() if k not in ('rng', 'device')}
    with open(os.path.join(this_dir, 'hp.json'), "w") as outfile:
        outfile.write(json.dumps(hp_for_json, indent=4))
    print('hyperparameters saved')

    ## --- always perfs and losses  --- ##
    _dump_pickle(all_eval_perfs, os.path.join(subdirs['perfs'], 'eval_perfs.pkl'))
    _dump_pickle(all_eval_losses, os.path.join(subdirs['losses'], 'eval_losses.pkl'))
    print("eval perfs and losses saved")

    ## --- save raw trials  --- ##
    if all_trials is not None:
        _dump_pickle(all_trials, os.path.join(subdirs['trials'], 'train_trials.pkl'))
        print("train trials saved")
    if all_eval_trials is not None:
        _dump_pickle(all_eval_trials, os.path.join(subdirs['trials'], 'eval_trials.pkl'))
        print("eval trials saved")

    ## --- save yhats  --- ##
    if all_yhats is not None:
        _dump_pickle(all_yhats, os.path.join(subdirs['yhats'], 'train_yhats.pkl'))
        print("train yhats saved")
    if all_eval_yhats is not None:
        _dump_pickle(all_eval_yhats, os.path.join(subdirs['yhats'], 'eval_yhats.pkl'))
        print("eval yhats saved")

    ## --- save train perfs and losses  --- ##
    if all_train_perfs is not None:
        _dump_pickle(all_train_perfs, os.path.join(subdirs['perfs'], 'train_perfs.pkl'))
        print("train perfs saved")
    if all_train_losses is not None:
        _dump_pickle(all_train_losses, os.path.join(subdirs['losses'], 'train_losses.pkl'))
        print("train losses saved")

    ## --- save hpc loss  --- ##
    if all_losses_hpc is not None:
        _dump_pickle(all_losses_hpc, os.path.join(subdirs['losses'], 'losses_sep.pkl'))

    ## --- if mixed batch --- ##
    if mixed_batch:
        _dump_pickle(mixed_batch_tasks, os.path.join(this_dir, 'mixed_batch_tasks.pkl'))

    ## --- save hids  --- ##
    if all_hids_ctx is not None:
        _dump_pickle(all_hids_ctx, os.path.join(subdirs['hids'], 'ctx-hids.pkl'))
        _dump_pickle(all_hids_hpc, os.path.join(subdirs['hids'], 'hpc-hids.pkl'))
        print("hids saved")

    ## --- save weights  --- ##
    if all_ctx_h2h_weights is not None:
        weight_files = (
            ('ctx-h2h-weights.pkl', all_ctx_h2h_weights),
            ('hpc-h2h-weights.pkl', all_hpc_h2h_weights),
            ('ctx2hpc-weights.pkl', all_ctx2hpc_weights),
            ('hpc2ctx-weights.pkl', all_hpc2ctx_weights),
            ('ctx2out-weights.pkl', all_ctx2out_weights),
            ('hpc2out-weights.pkl', all_hpc2out_weights),
        )
        for filename, weights in weight_files:
            _dump_pickle(weights, os.path.join(subdirs['weights'], filename))
        print("weights saved")


#### load model func (unused)

In [7]:
# Opt-in switch for loading a previously saved object into `model`.
# Disabled by default, so the branch below does not run.
load_model = False

if load_model:
    # Hard-coded checkpoint location and file name.
    model_dir = '/home/jason/dev/schema1/models/'
    model_name = 'fdgo1000-reactgo1000-delayanti1000_a0.2_wd1e-05_ctx1e-04_hpc1e-03.pth'
    model_load_path = os.path.join(model_dir, model_name)
    # weights_only=False unpickles arbitrary Python objects: only load files
    # from a trusted source.
    # NOTE(review): save_data() writes `model.state_dict()`; if this file was
    # produced the same way, torch.load returns a state dict rather than a
    # Model instance — confirm how this checkpoint was saved.
    # NOTE(review): the training cell assigns a fresh `Model(hp)` to `model`,
    # so whatever is loaded here is replaced when that cell runs.
    model = torch.load(model_load_path, weights_only=False)

## Train model

### hyperparameters

In [None]:

# Named rule sets; each entry lists the task rules that belong to it.
rules_dict = {
    'all': [
        'fdgo', 'reactgo', 'delaygo', 'fdanti', 'reactanti', 'delayanti',
        'dm1', 'dm2', 'contextdm1', 'contextdm2', 'multidm',
        'delaydm1', 'delaydm2', 'contextdelaydm1', 'contextdelaydm2', 'multidelaydm',
        'dmsgo', 'dmsnogo', 'dmcgo', 'dmcnogo',
    ],
    'mante': ['contextdm1', 'contextdm2'],
    'oicdmc': ['oic', 'dmc'],
}

ruleset = 'all'
num_ring = 2
n_eachring = 32
n_rule = len(rules_dict[ruleset])

# Input width: 1 + every ring unit + one unit per rule.
# Output width: 1 + the units of a single ring.
# Both naming schemes (in_size_model/out_size and n_input/n_output) carry the
# same values.
in_size_model = n_input = 1 + (num_ring * n_eachring) + n_rule
out_size = n_output = n_eachring + 1


In [None]:
## HP DICT 

# Hyperparameter dict passed to the model, the trial generator and the
# training / evaluation / saving helpers.
hp = {
    # --- input / output sizes ---
    'in_size_model': in_size_model,
    'out_size': out_size,
    'n_eachring': n_eachring,

    # --- hidden layer sizes ---
    'hid_size_ctx': 64,
    'hid_size_hpc': 64,

    # --- learning rates ---
    'learning_rate': 0.001,  # optimizer lr
    'lr_ctx': 1e-4,
    'lr_hpc': 1e-3,

    # --- network dynamics / regularisation ---
    'activation': 'relu',  # 'softplus', 'tanh', 'relu'
    'alpha': 0.2,  # for forward pass / rnn gating
    'weight_decay': 1e-4,

    # --- batching, evaluation and logging cadence ---
    'batch_size': 128,
    'eval_batch_size': 8,
    'eval_step': 1,
    'eval_epochs': 1,
    'weight_save_int': 100,

    # --- task / trial settings ---
    'dt': 20,
    'in_type': 'normal',
    'loss_type': 'lsq',
    'n_input': n_input,
    'n_output': n_output,
    'num_ring': num_ring,
    'optimizer': 'adam',
    'ruleset': ruleset,
    'rule_start': 1 + num_ring * n_eachring,
    'save_name': 'test',
    'sigma_rec': 0.05,  # for noise calc
    'sigma_x': 0.01,  # for noise calc
    'target_perf': 1.,
    'tau': 100,
    'use_separate_input': False,
    'w_rec_init': 'randortho',
}

# Both cross-region learning rates use the geometric mean of the two
# within-region rates.
hp['lr_c2h'] = hp['lr_h2c'] = gmean([hp['lr_ctx'], hp['lr_hpc']])

seed = 0
trainables = 'all'

# Run-time entries: evaluation logging switch, seed, seeded numpy generator
# and the compute device.
hp.update({
    'save_eval_hids': False,
    'seed': seed,
    'rng': np.random.RandomState(seed),
    'device': device,
})


In [9]:
### SPARSITY VARS

# Sparsity settings read by train_interleaved: the on/off switch, the pruning
# interval (in training steps), the per-region pruning amounts and the
# per-region L1 penalty weights.
hp.update({
    'force_sparsity': False,
    'prune_int': 100,
    'prune_ctx': 0.1,
    'prune_hpc': 0.2,
    'l1_lambda_ctx': 1e-5,
    'l1_lambda_hpc': 1e-5,
})

In [None]:
### HPC LOSS VARS

# Auxiliary hippocampal loss: 'decorr', 'recon' or None.
hp['hpc_loss'] = 'decorr'

# Weight of the auxiliary loss; stays at 1 unless one of the modes below
# replaces it with its own value.
hp['hpc_loss_alpha'] = 1
decorr_alpha = 1e-4  # 1e-3
recon_alpha = 0.5

if hp['hpc_loss'] == 'decorr':
    # The decorrelation loss also switches mixed batching on.
    hp['mixed_batch'] = True
    hp['hpc_loss_alpha'] = decorr_alpha
elif hp['hpc_loss'] == 'recon':
    hp['hpc_loss_alpha'] = recon_alpha


### initiate training

In [None]:
# Use the device recorded in hp so the model, the training loop (global
# `device`) and run_eval (hp['device']) all agree. Hard-coding "cuda" here
# could disagree with hp['device'] and fails on machines without a GPU.
device = torch.device(hp['device'])
save_all = True
save_hids = False
save_weights = True

training_type = 'interleaved'
hp['mode'] = 'random'
hp['mixed_batch'] = True

hp['hid_size_ctx'] = 64
hp['hid_size_hpc'] = 64

if training_type == 'interleaved':

    if hp['mixed_batch']:
        task_list_unique = ['fdgo', 'reactgo', 'delaygo', 'reactoffgo', 'reactoffanti']
        # '<n epochs><suffix char>': sampling weights aligned with
        # task_list_unique. The suffix character only keeps the keys unique.
        task_probs = {
            '1000a':  [1/3, 1/3, 1/3, 0.0, 0.0],  # "pre-training"
            '2000a':  [1/7, 1/7, 1/7, 2/7, 2/7],
        }

        # A list, not a live keys() view, so later additions to task_probs
        # cannot leak into it.
        epoch_intervals = list(task_probs.keys())
        # Cumulative epoch count at each paradigm switch.
        epoch_accum = list(accumulate([int(a[:-1]) for a in epoch_intervals]))
        task_list = [0] * epoch_accum[-1]
        epoch_list = None

    else:
        ### --- tasks fully listed along with number of epochs, entire paradigm explicitly coded --- ###
        # A staged paradigm can be built by concatenating (optionally
        # shuffled) per-stage lists of task names into task_list.
        task_list_unique = ['reactoffanti']
        epoch_list = [500]
        task_list_pre = [task_list_unique[0]] * 500
        task_list = task_list_pre

        task_probs = None

else:
    raise Exception("invalid training type")

n_tasks = len(task_list)
n_runs = 1  # number of runs of this paradigm
print("TOTAL RUNS: ", n_runs)

for i in range(n_runs):

    model = Model(hp).to(device)
    opt = optim.Adam([  # setting the learning rates of the specific layers
            {'params': model.params.cortical.parameters(), 'lr': hp['lr_ctx'],},
            {'params': model.params.hippocampal.parameters(), 'lr': hp['lr_hpc'],},
            {'params': model.params.ctx2hpc.parameters(), 'lr': hp['lr_c2h'],},
            {'params': model.params.hpc2ctx.parameters(), 'lr': hp['lr_h2c'],},
        ], weight_decay=hp['weight_decay'])
    model.set_optimizer(opt)
    model.train()

    print('\nTraining run #'+str(i))

    if training_type == 'interleaved':
        # Training adds a bookkeeping key to the task_probs dict it receives
        # and overrides hp['batch_size']. Pass a per-run copy and restore the
        # batch size afterwards (also on KeyboardInterrupt), so that repeated
        # runs and cell re-runs start from the configured values.
        run_task_probs = None if task_probs is None else dict(task_probs)
        configured_batch_size = hp['batch_size']
        try:
            (all_train_losses, all_losses_hpc, all_train_perfs, mixed_batch_tasks), \
            (all_eval_losses, all_eval_perfs), \
            (all_yhats, all_trials, all_eval_yhats, all_eval_trials), \
            (all_hids_ctx, all_hids_hpc), \
            (all_ctx_h2h_weights, all_hpc_h2h_weights), \
            (all_ctx2hpc_weights, all_hpc2ctx_weights), \
            (all_ctx2out_weights, all_hpc2out_weights) \
                = train_interleaved(model, hp, task_list_unique, task_list, 
                                    epoch_list=epoch_list, task_probs=run_task_probs,
                                    save_hids=save_hids, save_weights=save_weights)
        finally:
            hp['batch_size'] = configured_batch_size

    else:
        raise Exception("invalid training type")

    print("Training finished!")

    if save_all:
        save_data(model, hp, task_list_unique, all_eval_perfs, all_eval_losses, 
                    epoch_list=epoch_list, 
                    all_trials=all_trials, all_eval_trials=all_eval_trials,
                    all_yhats=all_yhats, all_eval_yhats=all_eval_yhats,
                    all_train_perfs=all_train_perfs, all_train_losses=all_train_losses,
                    all_losses_hpc=all_losses_hpc, mixed_batch_tasks=mixed_batch_tasks,
                    all_hids_ctx=all_hids_ctx, all_hids_hpc=all_hids_hpc, 
                    all_ctx_h2h_weights=all_ctx_h2h_weights, all_hpc_h2h_weights=all_hpc_h2h_weights,
                    all_ctx2hpc_weights=all_ctx2hpc_weights, all_hpc2ctx_weights=all_hpc2ctx_weights,
                    all_ctx2out_weights=all_ctx2out_weights, all_hpc2out_weights=all_hpc2out_weights)

# Printed once, after every run has finished.
print("ALL DONE")

TOTAL RUNS:  1

Training run #0


5it [00:03,  1.47it/s]


KeyboardInterrupt: 