In [1]:
import torch
import torch.optim as optim

import numpy as np
import pandas as pd
import pickle
import os
from utils.vmf_batch import vMF

from models import SeqEncoder, SeqDecoder, Seq2Seq_VAE, PoolingClassifier, init_weights
from utils.training_utils import train, evaluate

## plotting ###

import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

In [2]:
SEED = 17
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
with open('./data/Farrow_data/iterator/soma_centered/train_iterator.pkl', 'rb') as f:
    train_iterator = pickle.load(f)

with open('./data/Farrow_data/iterator/soma_centered/val_iterator.pkl', 'rb') as f:
    val_iterator = pickle.load(f)
   

In [4]:
src_data, trg_data, seq_len, indices, labels = list(train_iterator)[0]
bs, n_walks, walk_length, output_dim = src_data.shape

N_train = len(train_iterator.sampler.indices)
N_val = len(val_iterator.sampler.indices)

 
MASKING_ELEMENT = train_iterator.dataset.masking_el

# get number of labels, ignore -100 index
l = list(np.unique(labels))
if -100 in l:
    l.remove(-100)
NUM_CLASSES = len(l)

In [5]:
np.unique(train_iterator.dataset.labels[train_iterator.sampler.indices])

array([-100,    0,    1,    2,    3,    4,    5,    6,    7,    8,    9,
         10,   11,   12,   13])

### load model

In [6]:
emb_dim = 32
latent_dim = 32
NUM_LAYERS = 2
dpout = .1
kap = 500
pool = 'max'
lr = 0.01

enc = SeqEncoder(output_dim, emb_dim, emb_dim, NUM_LAYERS, dpout)
dec = SeqDecoder(output_dim, emb_dim, emb_dim, NUM_LAYERS, dpout)
dist = vMF(latent_dim, kappa=kap)
model = Seq2Seq_VAE(enc, dec, dist, device).to(device)
classifier = PoolingClassifier(latent_dim, NUM_CLASSES, n_walks,dpout,pooling=pool).to(device)

KLD: 45.709938049316406


In [7]:
def calculate_loss(x, reconstructed_x, ignore_el=MASKING_ELEMENT):
    # reconstruction loss
    # x = [trg len, batch size * n walks, output dim]

    seq_len , bs, output_dim = x.shape
    mask = x[:,:,0] != ignore_el
    RCL = 0
    for d in range(output_dim):
        RCL += mse_loss(reconstructed_x[:,:,d][mask], x[:,:,d][mask])
    RCL /= output_dim
    
    return RCL

In [8]:
torch.cuda.empty_cache()

In [9]:
path = "./models/Farrow/finetuned/soma_centered"
import os 
os.makedirs(path, exist_ok=True)

In [None]:

N_EPOCHS= 50
save_path_model= './models/Farrow/finetuned/soma_centered/finetuned_scaled_vae_frac%.1f_best_run%i.pt'
save_path_losses = './models/Farrow/finetuned/soma_centered/finetuned_scaled_losses_frac%.1f_run%i.npy'
save_path_elapsed_time = './models/Farrow/finetuned/soma_centered/finetuned_scaled_elapsed_time_frac%.1f_run%i.npy'
# state_dict = torch.load('./models/5_populations/emb32_hid32_lat32_dp0.1_k500_max_frac1.0_scaled_sum_run1_best.pt')

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

for frac in [1., .9, .5, .1, 0.]:
    
  
    runs = range(1,4)
        
    for run in runs:
        
        #optimizer
        optimizer = optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=lr)
        
        if os.path.exists(save_path_model%(frac,run)):
            state_dict = torch.load(save_path_model%(frac,run))
            
            # load model
            model.load_state_dict(state_dict['model_state_dict'])
            
            # overwrite optimizer if the model had been trained already
            optimizer.load_state_dict(state_dict['optimizer_state_dict'])
            classifier.load_state_dict(state_dict['classifier_state_dict'])
            losses = np.load(save_path_losses%(frac, run))
            elapsed_time = np.load(save_path_elapsed_time%(frac, run))
            
            last_epoch = state_dict['epoch']
            training = list(losses[:last_epoch,:2])
            validation = list(losses[:last_epoch,2:])
            elapsed_time = elapsed_time[:last_epoch]
            elapsed_time = np.hstack((elapsed_time, np.zeros((N_EPOCHS))))
            best_test_loss = losses[:,2].min()
            
        else:
            # load pre-trained model
            state_dict = torch.load('./models/Farrow/scratch/soma_centered/vae_frac0.0_scaled_best_run%i.pt'%run)
            # the first run was the best
            model.load_state_dict(state_dict['model_state_dict'])
            classifier.apply(init_weights)
            best_test_loss = np.infty

            losses = np.load('./models/Farrow/scratch/soma_centered/losses_frac0.0_scaled_run%i.npy'%run)
            elapsed_time = np.load('./models/Farrow/scratch/soma_centered/elapsed_time_frac0.0_scaled_run%i.npy'%run)
            last_epoch = len(elapsed_time)
            elapsed_time = np.hstack((elapsed_time, np.zeros((N_EPOCHS))))
            training = list(losses[:,:2])
            validation = list(losses[:,2:])
        
        
        cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=-100)
        mse_loss = torch.nn.MSELoss(reduction='sum')


        
        
        for e in range(N_EPOCHS):
            start.record()
            train_loss, train_class_loss = train(model, classifier, train_iterator, optimizer, 
                                               calculate_loss,cross_entropy_loss, clip=1, norm_p=None,
                                                 class_fraction=frac)
            val_loss, val_class_loss = evaluate(model,classifier, val_iterator,
                                                 calculate_loss, cross_entropy_loss, norm_p=None)

            train_loss /= N_train
            train_class_loss /= N_train
            val_loss /= N_val
            val_class_loss /=N_val
            
            end.record()

            # Waits for everything to finish running
            torch.cuda.synchronize()
            elapsed_time[e+last_epoch] = start.elapsed_time(end) # milliseconds
            
            training += [[train_loss,train_class_loss]]
            validation += [[val_loss, val_class_loss]]
            print(f'Epoch {e}, Train Loss: {train_loss:.2f}, Val Loss: {val_loss:.2f}, Time elapsed [s]: {elapsed_time[e]/1000:.2f}')


            if e % 50 == 0 and e > 0:
                optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr']/2

            if best_test_loss > val_loss:
                best_test_loss = val_loss
                torch.save({'epoch': e + last_epoch,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'classifier_state_dict': classifier.state_dict()
                               },save_path_model%(frac, run))

                validation_ = np.array(validation)
                training_ = np.array(training)
                # [:,0] = training loss, [:,1] = training classification loss 
                # [:,2] validation loss, [:,3] validation classification loss
                losses = np.hstack((training_, validation_))
                np.save(save_path_losses%(frac, run),losses)
                np.save(save_path_elapsed_time%(frac,run),elapsed_time)
        validation = np.array(validation)
        training = np.array(training)
        losses = np.hstack((training, validation))
        np.save(save_path_losses%(frac, run), losses)
        np.save(save_path_elapsed_time%(frac,run),elapsed_time)

Epoch 0, Train Loss: 65.46, Val Loss: 393.71, Time elapsed [s]: 19.29
Epoch 1, Train Loss: 62.30, Val Loss: 412.80, Time elapsed [s]: 19.15
Epoch 2, Train Loss: 67.77, Val Loss: 391.99, Time elapsed [s]: 19.16
Epoch 3, Train Loss: 63.52, Val Loss: 380.41, Time elapsed [s]: 19.13
Epoch 4, Train Loss: 64.58, Val Loss: 376.85, Time elapsed [s]: 19.10
Epoch 5, Train Loss: 64.51, Val Loss: 386.56, Time elapsed [s]: 19.13
Epoch 6, Train Loss: 59.44, Val Loss: 382.96, Time elapsed [s]: 19.18
Epoch 7, Train Loss: 56.94, Val Loss: 401.25, Time elapsed [s]: 19.07
Epoch 8, Train Loss: 61.06, Val Loss: 393.43, Time elapsed [s]: 19.15
Epoch 9, Train Loss: 59.21, Val Loss: 399.39, Time elapsed [s]: 19.11
Epoch 10, Train Loss: 65.00, Val Loss: 422.77, Time elapsed [s]: 19.03
Epoch 11, Train Loss: 61.88, Val Loss: 424.27, Time elapsed [s]: 19.15
Epoch 12, Train Loss: 64.01, Val Loss: 433.07, Time elapsed [s]: 19.16
Epoch 13, Train Loss: 63.72, Val Loss: 451.85, Time elapsed [s]: 19.20
Epoch 14, Train 

In [None]:
plt.plot(np.array(validation)[:,0])

In [23]:
elapsed_time.mean()/1000


26.72294771484375