In [1]:
import torch
import torch.optim as optim

import numpy as np
import pandas as pd
import pickle
from utils.vmf_batch import vMF

from models import SeqEncoder, SeqDecoder, Seq2Seq_VAE, PoolingClassifier, init_weights
from utils.training_utils import train, evaluate

## plotting ###

import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

%matplotlib inline

In [2]:
SEED = 17
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [37]:
# setting directories for the dataset's data iterators
iterator_path = Path('./data/Farrow_data/iterator/soma_centered')
train_iterator_path = iterator_path / 'train_iterator.pkl'
val_iterator_path = iterator_path / 'val_iterator.pkl'

with open(train_iterator_path, 'rb') as f:
    train_iterator = pickle.load(f)

with open(val_iterator_path, 'rb') as f:
    val_iterator = pickle.load(f)
   

In [38]:
model_path = Path('./models/Farrow/finetuned/soma_centered')
state_dict_path = Path('./models/parameter_search')

In [39]:
src_data, trg_data, seq_len, indices, labels = list(train_iterator)[0]
bs, n_walks, walk_length, output_dim = src_data.shape

N_train = len(train_iterator.sampler.indices)
N_val = len(val_iterator.sampler.indices)

 
MASKING_ELEMENT = train_iterator.dataset.masking_el

# get number of labels, ignore -100 index
l = list(np.unique(labels))
if -100 in l:
    l.remove(-100)
NUM_CLASSES = len(l)

### load model

In [40]:
emb_dim = 16
hid_dim = 16 # idk if i need this or not
latent_dim = 8
NUM_LAYERS = 2
dpout = .1
kap = 500
pool = 'avg'
lr = 0.01

enc = SeqEncoder(output_dim, emb_dim, emb_dim, NUM_LAYERS, dpout)
dec = SeqDecoder(output_dim, emb_dim, emb_dim, NUM_LAYERS, dpout)
dist = vMF(latent_dim, kappa=kap)
model = Seq2Seq_VAE(enc, dec, dist, device).to(device)
classifier = PoolingClassifier(latent_dim, NUM_CLASSES, n_walks,dpout,pooling=pool).to(device)

KLD: 18.465579986572266


In [41]:
def calculate_loss(x, reconstructed_x, ignore_el=MASKING_ELEMENT):
    # reconstruction loss
    # x = [trg len, batch size * n walks, output dim]

    seq_len , bs, output_dim = x.shape
    mask = x[:,:,0] != ignore_el
    RCL = 0
    for d in range(output_dim):
        RCL += mse_loss(reconstructed_x[:,:,d][mask], x[:,:,d][mask])
    RCL /= output_dim
    
    return RCL

In [42]:
torch.cuda.empty_cache()

In [43]:
#path = "./models/5_populations"
#import os 
#os.makedirs(path, exist_ok=True)

In [44]:
param_combo = 'emb%s_hid%s_lat%s_dp%s_k%s_%s'%(emb_dim, hid_dim, latent_dim,
                                                dpout, kap, pool)
state_dict_pt = '%s_run1_best.pt'%param_combo
#state_dict = torch.load(state_dict_path / state_dict_pt)

In [45]:
state_dict_path / state_dict_pt

PosixPath('models/parameter_search/emb16_hid16_lat8_dp0.1_k500_avg_run1_best.pt')

In [49]:

N_EPOCHS= 200
save_path_model= model_path / 'finetuned_vae_frac%.1f_best_run%i.pt'
save_path_losses = model_path / 'finetuned_losses_frac%.1f_run%i.npy'
save_path_elapsed_time = model_path / 'finetuned_elapsed_time_frac%.1f_run%i.npy'

param_combo = 'emb%s_hid%s_lat%s_dp%s_k%s_%s'%(emb_dim, hid_dim, latent_dim,
                                                dpout, kap, pool)
#state_dict_pt = '%s_frac1.0_run1_best.pt'%param_combo
#state_dict_pt = '%s_run1_best.pt'%param_combo
state_dict_pt = '%s_run1.pt'%param_combo
state_dict = torch.load(state_dict_path / state_dict_pt)
state_dict_npy = '%s_run1.npy'%param_combo

# start with the 5 pop toy data, "true" finetuning on the Farrow data
# loads in parameters from 5 pop toy data, updates from what it learns from Farrow
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

for frac in [1., .9, .5, .1, 0.]:
    
  
    runs = range(1,4)
        
    for run in runs:
        
        # load pre-trained model
        # the first run was the best
        model.load_state_dict(state_dict['model_state_dict'])
        classifier.apply(init_weights)
        cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=-100)
        mse_loss = torch.nn.MSELoss(reduction='sum')


        #optimizer
        optimizer = optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=lr)
       
        best_test_loss = np.infty

        losses = np.load(state_dict_path / state_dict_npy)
        elapsed_time = np.zeros((N_EPOCHS))
        training = list(losses[:,:2])
        validation = list(losses[:,2:])
        
        for e in range(N_EPOCHS):
            start.record()
            train_loss, train_class_loss = train(model, classifier, train_iterator, optimizer, 
                                               calculate_loss,cross_entropy_loss, clip=1, norm_p=None,
                                                 class_fraction=frac)
            val_loss, val_class_loss = evaluate(model,classifier, val_iterator,
                                                 calculate_loss, cross_entropy_loss, norm_p=None)

            train_loss /= N_train
            train_class_loss /= N_train
            val_loss /= N_val
            val_class_loss /=N_val
            
            end.record()

            # Waits for everything to finish running
            torch.cuda.synchronize()
            elapsed_time[e] = start.elapsed_time(end) # milliseconds
            
            training += [[train_loss,train_class_loss]]
            validation += [[val_loss, val_class_loss]]
            print(f'Epoch {e}, Train Loss: {train_loss:.2f}, Val Loss: {val_loss:.2f}, Time elapsed [s]: {elapsed_time[e]/1000:.2f}')


            if e % 50 == 0 and e > 0:
                optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr']/2

            if best_test_loss > val_loss:
                best_test_loss = val_loss
                torch.save({'epoch': e,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'classifier_state_dict': classifier.state_dict()
                               },save_path_model%(frac, run))

                validation_ = np.array(validation)
                training_ = np.array(training)
                # [:,0] = training loss, [:,1] = training classification loss 
                # [:,2] validation loss, [:,3] validation classification loss
                losses = np.hstack((training_, validation_))
                np.save(save_path_losses%(frac, run),losses)
                np.save(save_path_elapsed_time%(frac,run),elapsed_time)
        validation = np.array(validation)
        training = np.array(training)
        losses = np.hstack((training, validation))
        np.save(save_path_losses%(frac, run), losses)
        np.save(save_path_elapsed_time%(frac,run),elapsed_time)

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [None]:
plt.plot(np.array(validation)[:,0])

In [None]:
elapsed_time.mean()/1000
