In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader,TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

import GPUtil
import numpy as np
from utils.vmf_batch import vMF

from models import SeqEncoder, SeqDecoder, Seq2SeqDataSet, Seq2Seq_VAE, PoolingClassifier, init_weights
from itertools import product
from utils.training_utils import train, evaluate
from datetime import datetime
## plotting ###

import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

In [4]:
# from importlib import reload
# import training_utils
# reload(training_utils)
# from training_utils import train, evaluate


In [5]:
SEED = 17  # random seed for reproducible runs

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [6]:
import pickle

# Dataset / experiment folder name (also used in the model output paths).
folder = '3_populations'

val_path = './data/toy_data/%s/iterator/shuffled_val_iterator.pkl' % folder
train_path = './data/toy_data/%s/iterator/shuffled_train_iterator.pkl' % folder

# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load iterator pickles produced by a trusted source.
# The pickled objects presumably are DataLoaders (later cells read .sampler and .dataset) — confirm.
with open(val_path, 'rb') as fh:
    val_iterator = pickle.load(fh)

with open(train_path, 'rb') as fh:
    train_iterator = pickle.load(fh)

In [None]:
# --- fixed hyper-parameters ---
INPUT_DIM = 3           # feature dimension passed to SeqEncoder / SeqDecoder
lr = 1e-2               # learning rate for the Adam optimizer
NUM_LAYERS = 2          # layer count passed to SeqEncoder / SeqDecoder
NUM_CLASSES = 3         # number of classes predicted by PoolingClassifier
N_EPOCHS = 150          # training epochs per run
MASKING_ELEMENT = 0     # padding value skipped by calculate_loss

# --- sizes derived from the loaded data ---
# Assumes each iterator's sampler exposes `.indices` (e.g. SubsetRandomSampler) — TODO confirm.
N_train, N_val = (len(it.sampler.indices) for it in (train_iterator, val_iterator))
n_walks = train_iterator.dataset.n_walks

In [None]:
def calculate_loss(x, reconstructed_x, ignore_el=MASKING_ELEMENT, loss_fn=None):
    """Masked reconstruction loss, averaged over the feature dimensions.

    A position (time step, sequence) is excluded from every dimension's loss
    when its first feature equals ``ignore_el`` (the padding value).

    Args:
        x: target tensor, ``[trg len, batch size * n walks, output dim]``.
        reconstructed_x: reconstruction with the same shape as ``x``.
        ignore_el: padding value; positions where ``x[:, :, 0] == ignore_el``
            are masked out.
        loss_fn: callable ``loss_fn(pred, target) -> scalar tensor`` applied per
            feature dimension. Defaults to the summed squared error (same as
            ``nn.MSELoss(reduction='sum')``), so the function does not depend on
            a global ``mse_loss`` being defined elsewhere.

    Returns:
        Scalar tensor: the sum over dimensions of ``loss_fn`` on the unmasked
        entries, divided by the number of dimensions.

    Raises:
        ValueError: if ``x`` is not 3-dimensional.
    """
    if loss_fn is None:
        def loss_fn(pred, target):
            return F.mse_loss(pred, target, reduction='sum')

    # Unpacking also enforces the expected rank-3 layout.
    _, _, output_dim = x.shape
    # Padding is detected on the first feature only and applied to all features.
    mask = x[:, :, 0] != ignore_el

    RCL = 0
    for d in range(output_dim):
        RCL += loss_fn(reconstructed_x[:, :, d][mask], x[:, :, d][mask])
    RCL /= output_dim

    return RCL


In [None]:
# Release memory held by PyTorch's CUDA caching allocator before training starts.
torch.cuda.empty_cache()

In [None]:
np.random.seed(SEED)
torch.manual_seed(SEED)

# Architecture / regularisation hyper-parameters shared by every run.
emb_dim = 32
latent_dim = 32
dpout = .1
kappa = 500
pool = 'max'
LR_HALVING_PERIOD = 50  # halve the learning rate every this many epochs


def _save_losses(path, training, validation):
    """Stack per-epoch losses into an ``(n_epochs, 4)`` array and ``np.save`` it to ``path``.

    Columns: [:,0] training loss, [:,1] training classification loss,
    [:,2] validation loss, [:,3] validation classification loss.

    Args:
        path: destination ``.npy`` file.
        training: list of ``[loss, class_loss]`` pairs, one per finished epoch.
        validation: list of ``[loss, class_loss]`` pairs, one per finished epoch.

    Returns:
        ``(training_arr, validation_arr, losses)``: the two ``(n_epochs, 2)``
        arrays and the concatenated ``(n_epochs, 4)`` array that was written.
    """
    # reshape(-1, 2) keeps the concatenation valid even when no epoch has finished.
    training_arr = np.asarray(training).reshape(-1, 2)
    validation_arr = np.asarray(validation).reshape(-1, 2)
    losses = np.concatenate((training_arr, validation_arr), axis=1)
    with open(path, 'wb') as fh:
        np.save(fh, losses)
    return training_arr, validation_arr, losses


### train the model(s)
for frac in [1.]:
    # The file-name suffix depends only on hyper-parameters, so build it up front;
    # this keeps it defined for the final loss dump even if validation never improves.
    suffix = 'shuffled_emb%i_hid%i_lat%i_dp%.1f_k%i_%s'%(emb_dim,emb_dim,latent_dim,dpout,kappa,pool)
    suffix += '_frac%.1f'%frac
    suffix += '_unscaled'
    suffix += '_sum'

    # NOTE(review): k runs over 1..3 and files/prints use k+1, so runs are labelled 2..4 — confirm intended.
    for k in range(1,4):
        start = datetime.now()
        checkpoint_path = './models/%s/%s_run%i_best.pt'%(folder,suffix,(k+1))
        losses_path = './models/%s/shuffled_losses_%s_%i.npy'%(folder, suffix, (k+1))

        # model
        enc = SeqEncoder(INPUT_DIM, emb_dim, emb_dim, NUM_LAYERS, dpout)
        dec = SeqDecoder(INPUT_DIM, emb_dim, emb_dim, NUM_LAYERS, dpout)
        dist = vMF(latent_dim, kappa=kappa)
        model = Seq2Seq_VAE(enc, dec, dist, device).to(device)
        classifier = PoolingClassifier(latent_dim, NUM_CLASSES, n_walks,dpout,pooling=pool).to(device)

        # initialize model
        model.apply(init_weights)
        classifier.apply(init_weights)

        # losses (summed over elements). `mse_loss` stays a module-level name for
        # any code that looks up the global.
        cross_entropy_loss = nn.CrossEntropyLoss(reduction='sum')
        mse_loss = nn.MSELoss(reduction='sum')

        #optimizer
        optimizer = optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=lr)

        # np.inf rather than np.infty: the latter alias was removed in NumPy 2.0.
        best_test_loss = np.inf
        training = []
        validation = []
        for e in range(N_EPOCHS):

            train_loss, train_class_loss = train(model, classifier, train_iterator, optimizer,
                                                 calculate_loss, cross_entropy_loss,
                                                 clip=1, norm_p=None, class_fraction=frac)
            val_loss, val_class_loss = evaluate(model, classifier, val_iterator,
                                                calculate_loss, cross_entropy_loss, norm_p=None)

            # The losses use reduction='sum'; normalise by the number of sampled items.
            train_loss /= N_train
            train_class_loss /= N_train
            val_loss /= N_val
            val_class_loss /= N_val

            training.append([train_loss, train_class_loss])
            validation.append([val_loss, val_class_loss])
            print(f'Epoch {e}, Train Loss: {train_loss:.9f}, Test Loss: {val_loss:.9f}')

            # Step decay: halve the learning rate of every parameter group.
            if e > 0 and e % LR_HALVING_PERIOD == 0:
                for group in optimizer.param_groups:
                    group['lr'] /= 2

            if val_loss < best_test_loss:
                best_test_loss = val_loss
                torch.save({'epoch': e,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'classifier_state_dict': classifier.state_dict()
                            }, checkpoint_path)
                # Snapshot the loss history alongside the best checkpoint.
                _save_losses(losses_path, training, validation)

        # Final loss history for this run (overwrites the last snapshot).
        training_, validation_, losses = _save_losses(losses_path, training, validation)
        end = datetime.now()
        print('Time to fit model %i : '%(k+1), end-start)
    torch.cuda.empty_cache()

KLD: 45.709938049316406
Epoch 0, Train Loss: 1360.344875000, Test Loss: 8337.006000000
Epoch 1, Train Loss: 1003.453680556, Test Loss: 6842.868500000
Epoch 2, Train Loss: 839.720486111, Test Loss: 7049.564500000
Epoch 3, Train Loss: 720.715430556, Test Loss: 6775.543000000
Epoch 4, Train Loss: 676.795076389, Test Loss: 7422.435000000
Epoch 5, Train Loss: 633.520416667, Test Loss: 7328.762500000
Epoch 6, Train Loss: 628.223395833, Test Loss: 8124.971000000
Epoch 7, Train Loss: 604.718465278, Test Loss: 6883.512000000
Epoch 8, Train Loss: 605.018402778, Test Loss: 7300.894500000
Epoch 9, Train Loss: 593.012541667, Test Loss: 7053.825000000
Epoch 10, Train Loss: 585.497750000, Test Loss: 7017.565000000
