In [1]:
# standard library
import os
from datetime import datetime
from importlib import reload
from itertools import product

# third-party
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler
from tqdm.auto import tqdm
import GPUtil
import numpy as np

# local
from utils.vmf_batch import vMF
from models import SeqEncoder, SeqDecoder, Seq2SeqDataSet, Seq2Seq_VAE, PoolingClassifier, init_weights
from utils.training_utils import create_Seq2SeqDataset, train, evaluate
import utils.training_utils
## plotting ###

import warnings
# NOTE(review): this silences *every* warning for the whole kernel session, including
# deprecation warnings from numpy/torch; consider filtering specific categories instead.
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
import seaborn as sns


In [2]:
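# To pick up edits to utils/training_utils.py without restarting the kernel, run
#   reload(utils.training_utils)
# and then re-run the `from utils.training_utils import ...` line, because reload() does not
# rebind names that were imported with `from ... import`. (The bare name `training_utils`
# is never bound in this notebook, so `reload(training_utils)` would raise NameError.)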


In [3]:
# Global random seed; applied to numpy and torch at the start of the grid search.
SEED = 17
# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
# Display which device was selected (the saved output shows CUDA was available for this run).
device

device(type='cuda')

In [5]:
# Fixed settings shared by every configuration in the grid search.
INPUT_DIM = 3       # passed as the input size of SeqEncoder / SeqDecoder (features per step — assumed)
NUM_LAYERS = 2      # passed as the layer count of both SeqEncoder and SeqDecoder
NUM_CLASSES = 3     # number of classes predicted by the PoolingClassifier
N_EPOCHS = 150      # epochs per training run
lr = 0.01           # initial Adam learning rate

In [6]:
import pickle

# Load the pre-built validation and training iterators (presumably torch DataLoaders)
# for the 3-population toy data set.
# NOTE(review): pickle.load can execute arbitrary code — only load files you created yourself.
with open('./data/toy_data/3_populations/iterator/val_iterator.pkl', 'rb') as fh:
    val_iterator = pickle.load(fh)
with open('./data/toy_data/3_populations/iterator/train_iterator.pkl', 'rb') as fh:
    train_iterator = pickle.load(fh)

In [7]:
# Number of samples in each split; the training loop divides summed losses by these.
N_train, N_val = (len(loader.sampler.indices) for loader in (train_iterator, val_iterator))
# Value marking masked positions (ignored by calculate_loss) and the dataset's n_walks setting.
MASKING_ELEMENT, n_walks = train_iterator.dataset.masking_el, train_iterator.dataset.n_walks

In [8]:
def calculate_loss(x, reconstructed_x, ignore_el=MASKING_ELEMENT):
    """Masked reconstruction loss: summed squared error averaged over the feature dimension.

    Parameters
    ----------
    x : torch.Tensor
        Target sequences, shape [trg len, batch size * n walks, output dim].
    reconstructed_x : torch.Tensor
        Reconstruction of ``x``; assumed to have the same shape (TODO confirm).
    ignore_el : float
        Masking value. Positions where the first feature of ``x`` equals it are
        excluded from the loss for every feature.

    Returns
    -------
    torch.Tensor
        Scalar: for each feature, the sum of squared errors over unmasked positions,
        summed over features and divided by the number of features.

    Raises
    ------
    ValueError
        If ``x`` is not 3-dimensional (from the shape unpacking).
    """
    # Unpacking also enforces that x is 3-D.
    _, _, output_dim = x.shape
    # The mask is derived from feature 0 only and applied to all features.
    mask = x[:, :, 0] != ignore_el

    RCL = 0
    for d in range(output_dim):
        # Functional MSE with reduction='sum' is identical to nn.MSELoss(reduction='sum'), but does
        # not depend on the notebook-global `mse_loss`, which only exists once the training cell runs.
        RCL += F.mse_loss(reconstructed_x[:, :, d][mask], x[:, :, d][mask], reduction='sum')
    RCL /= output_dim

    return RCL


In [9]:
# Release unoccupied cached memory held by PyTorch's CUDA caching allocator before training.
torch.cuda.empty_cache()

In [10]:
# Hyper-parameter search space. Configurations recorded from earlier sweeps:
#   EMD_D = HID_D = 16, LAT_D = 8, Dropout = 0.1 / 0.3 / 0.5, kappa = 500, pooling = avg
embedding_dims = [16, 32]      # used as both embedding and hidden size in the training loop
latent_dims = [8, 16, 32]
dropout = [0.5]
kappa = [500]                  # passed to vMF(latent_dim, kappa=...)
pooling = ['avg', 'max']

# Cartesian product, last axis varying fastest; each entry is
# (emb_dim, latent_dim, dropout, kappa, pooling).
parameter_grid = [
    combo for combo in product(embedding_dims, latent_dims, dropout, kappa, pooling)
]

In [11]:
# Create the output directory for checkpoints, loss curves and the resumable parameter grid.
# The training loop opens files there for writing, which fails if the directory is missing.
os.makedirs('./models/parameter_search', exist_ok=True)

In [12]:
# NOTE(review): leftover inspection — this only displays the bound `list.index` method
# (see the saved output); `len(parameter_grid)` was probably intended.
parameter_grid.index

<function list.index(value, start=0, stop=9223372036854775807, /)>

In [13]:
# Sanity check that a grid entry casts to int. The training loop reloads the remaining grid
# with np.load, so entries may come back as strings; it casts them with int()/float().
int(parameter_grid[0][0])

16

In [None]:
# Grid search: train `n_runs` model(s) for every parameter combination in `parameter_grid`.
# Before a combination is trained, the remaining grid (without it) is written to
# ./models/parameter_search/parameter_grid.txt and read back at the end of the iteration,
# so an interrupted search can be resumed from that file.
np.random.seed(SEED)
torch.manual_seed(SEED)
n_runs = 1

with tqdm(total=n_runs*len(parameter_grid)) as pbar:
    while len(parameter_grid) > 0:

        # Next parameter set. Explicit casts are needed because from the second iteration on
        # the grid comes from np.load and its entries may be strings.
        emb_dim = int(parameter_grid[0][0])
        latent_dim = int(parameter_grid[0][1])
        dpout = float(parameter_grid[0][2])
        kappa = int(parameter_grid[0][3])  # NOTE: rebinds the `kappa` search-space list
        pool = parameter_grid[0][4]

        print('Fitting model with parameters: '
              '    EMD_D = HID_D = %i, LAT_D = %i, Dropout= %.1f, kappa= %i, pooling=%s'
              % (emb_dim, latent_dim, dpout, kappa, pool))

        # File-name tag for this parameter set. Built up front: it used to be assigned only when
        # the validation loss improved, so if it never did (e.g. NaN losses) the losses file was
        # written under the previous set's name, or a NameError was raised on the first set.
        suffix = 'emb%i_hid%i_lat%i_dp%.1f_k%i_%s'%(emb_dim,emb_dim,latent_dim,dpout,kappa,pool)

        # save the file without this parameter set
        with open('./models/parameter_search/parameter_grid.txt', 'wb') as f:
            np.save(f, parameter_grid[1:])

        for k in range(n_runs):
            start = datetime.now()

            # model: embedding and hidden size are both emb_dim
            enc = SeqEncoder(INPUT_DIM, emb_dim, emb_dim, NUM_LAYERS, dpout)
            dec = SeqDecoder(INPUT_DIM, emb_dim, emb_dim, NUM_LAYERS, dpout)
            dist = vMF(latent_dim, kappa=kappa)
            model = Seq2Seq_VAE(enc, dec, dist, device).to(device)
            classifier = PoolingClassifier(latent_dim, NUM_CLASSES, n_walks, dpout, pooling=pool).to(device)

            # initialize model
            model.apply(init_weights)
            classifier.apply(init_weights)

            # losses, summed over samples (divided by the dataset sizes below).
            # `mse_loss` is kept as a notebook-level name because `calculate_loss` may read it.
            cross_entropy_loss = nn.CrossEntropyLoss(reduction='sum')
            mse_loss = nn.MSELoss(reduction='sum')

            # one optimizer over both the VAE and the classifier
            optimizer = optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=lr)

            ### train the model(s)
            best_test_loss = np.inf  # the `np.infty` alias was removed in NumPy 2.0
            training = []
            validation = []
            for e in range(N_EPOCHS):
                # NOTE(review): the meaning of the trailing `1, 1.` arguments is defined by
                # utils.training_utils.train — confirm there.
                train_loss, train_class_loss = train(model, classifier, train_iterator, optimizer,
                                                     calculate_loss, cross_entropy_loss, 1, 1.)
                val_loss, val_class_loss = evaluate(model, classifier, val_iterator,
                                                    calculate_loss, cross_entropy_loss)

                train_loss /= N_train
                train_class_loss /= N_train
                val_loss /= N_val
                val_class_loss /= N_val

                training.append([train_loss, train_class_loss])
                validation.append([val_loss, val_class_loss])

                # halve the learning rate after epochs 50, 100, ...
                if e % 50 == 0 and e > 0:
                    optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr']/2

                # checkpoint whenever the validation loss improves
                if best_test_loss > val_loss:
                    best_test_loss = val_loss
                    torch.save({'epoch': e,
                                'model_state_dict': model.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                'classifier_state_dict': classifier.state_dict()
                               }, './models/parameter_search/%s_run%i_best.pt'%(suffix,(k+1)))

            # save training and validation loss
            # losses[:,0] = training loss, [:,1] = training classification loss,
            # [:,2] = validation loss, [:,3] = validation classification loss
            validation_ = np.array(validation)
            training_ = np.array(training)
            losses = np.concatenate((training_, validation_), axis=1)
            with open('./models/parameter_search/losses_%s_%i.npy'%(suffix, (k+1)), 'wb') as f:
                np.save(f, losses)

            end = datetime.now()
            print('Time to fit model %i : '%(k+1), end-start)
            # count the run once it has finished, not when it starts
            pbar.update()

        torch.cuda.empty_cache()
        # continue with the remaining grid as persisted above
        with open('./models/parameter_search/parameter_grid.txt', 'rb') as f:
            parameter_grid = np.load(f)

  0%|          | 0/12 [00:00<?, ?it/s]

Fitting model with parameters:     EMD_D = HID_D = 16, LAT_D = 8, Dropout= 0.5, kappa= 500, pooling=avg
KLD: 18.465579986572266
