### Installs, Imports, and Pyro Setup 

In [1]:
# 1# Run first time
# ! module load cuda/9.2.88-gcc/7.1.0 cudnn/7.6.5.32-9.2-linux-x64-gcc/7.1.0-cuda9_2 anaconda3/2019.10-gcc/8.3.1
! source activate pytorch_env
# !pip install --up
# !pip install pyro-ppl
# !pip install torchvision
# !pip install --upgrade git+https://github.com/dhudsmith/clean-the-text
# !pip install umap-learn

In [2]:
from dataset import PairwiseData, DocumentPairData
from model import Encoder, Decoder, ProdLDA, custom_elbo
import os
import pickle as pkl
from typing import List

import numpy as np
import torch
import graphviz
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import pandas as pd


import pyro
from pyro import poutine
import pyro.distributions as dist
# import pyro.contrib.examples.util  # patches torchvision
from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
from pyro.optim import Adam, ClippedAdam
import torch.nn.functional as F
from scipy.io import loadmat

import scipy
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.datasets import fetch_20newsgroups
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
import pandas as pd
import matplotlib.pyplot as plt
import wandb
import umap
import umap.plot
import sklearn.datasets

plt.rcParams["figure.figsize"] = (10,5)

In [3]:
pyro.distributions.enable_validation(True)

In [4]:
pyro.set_rng_seed(0)

# Setup

In [5]:
# Data
NUMBER_PAIRS = 100000
PAIR_PERCENTAGE = 1.0


NUM_WORKERS = 0

# Dropout rate
NUM_TOPICS = 50
NUM_PROTOTYPES = 7
EMBED_DIM  = 64
HIDDEN_DIM = 128
DROPOUT_RATE = 0.2

# Training
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 25
TEST_FREQUENCY = 1
BATCH_SIZE = 128
VAL_BATCH_SIZE = 2000
OBS_SAMPLES = 10

# gamma = 0.25
# num_steps = (NUMBER_PAIRS // BATCH_SIZE) * NUM_EPOCHS
# lrd = np.exp(np.log(gamma)/num_steps)
adam_args = {"lr": 0.003, 'clip_norm':10.0, 'betas': (0.99, 0.999)}

latest_plotted_p = torch.zeros(NUM_PROTOTYPES, NUM_TOPICS)

# Data Prep

In [6]:
def get_data(pair_percentage=PAIR_PERCENTAGE):
    # pairwise data
    pwd = PairwiseData()
    train_pairs, val_pairs, test_pairs = [pwd.get_pairs_table(d) for d in [pwd.train, pwd.val, pwd.test]]

    # datasets
    data_train, data_val, data_test = [
        DocumentPairData(bows=pwd.bows, index_table=ix_table, prob=PAIR_PERCENTAGE)
        for ix_table in [train_pairs, val_pairs, test_pairs]
    ]

    # dataloaders
    dl_train = DataLoader(data_train, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
    dl_val = DataLoader(data_val, batch_size=VAL_BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)
    dl_test = DataLoader(data_test, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)
    
    return pwd, dl_train, dl_val, dl_test

# Training and Evaluation Functions 

In [7]:
def train(svi, topic_model, train_loader, device, progress_interval = 100, print_debug = False, epoch = 0, latest_plotted_p = None, trans = None):
    topic_model.train()
    
    # initialize loss accumulator
    epoch_loss = 0.
    logits_epoch = []
    labels_epoch = []
    batches = 0
    
    
    # do a training epoch over each mini-batch x returned
    # by the data loader
    for i, x in enumerate(train_loader):
        batches += 1
        # if on GPU put mini-batch into CUDA memory
        x_a = x['a'].to(device).squeeze()
        x_b = x['b'].to(device).squeeze()
        x_label = x['label'].to(device).type(torch.float32)
        x_observed = x['observed'].to(device).type(torch.bool)
        
        if print_debug and i % progress_interval == 0:
            print(f"Step {i}; avg. loss {epoch_loss/(i+1)}", end='\r')
            # p = pyro.param('p').detach().cpu()
            # phi = torch.softmax(p, axis=-1)
            # p_2d = trans.transform(p)
            # latest_plotted_p[:] = p
            # umap.plot.points(trans, labels=pwd.val.category,theme='fire')
            # for ix, marker in enumerate(['$1$', '$2$', '$3$', '$4$', '$5$', '$6$', '$7$']):
            #     plt.scatter(x=p_2d[ix,0], y=p_2d[ix,1], s=7**3, c='white', marker=marker)
            # plt.title(f"Frozen Enc/Dec : Epoch {epoch}, Step {i}")
            # plt.savefig(f'../data/figures_pretrained_unfrozen/epoch{epoch}_step{i}')
            # plt.close()
        
        # do ELBO gradient and accumulate loss
        epoch_loss += svi.step(x_a, x_b, x_label, x_observed)
        
        with torch.no_grad():
            x = svi.guide(x_a, x_b, x_label, x_observed)

        x_label = x_label.cpu()
        logits = x.cpu()
        logits_epoch += logits
        labels_epoch += x_label
        
    logits_binned = np.digitize(logits_epoch, [0.45], right=False)
    train_acc = accuracy_score(labels_epoch, logits_binned)
    train_auroc = roc_auc_score(labels_epoch, logits_epoch)
    
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = epoch_loss / normalizer_train
    return total_epoch_loss_train, train_acc, train_auroc

In [8]:
def evaluate(svi, etm, test_loader, device): 
    etm.eval()
    
    # initialize loss accumulator
    test_loss = 0.0
    test_ce_loss = 0.0
    test_kl_loss = 0.0
    test_acc = 0.0
    test_auroc = 0.0
    batches = 0
    logits_epoch = []
    labels_epoch = []
    
    # compute the loss over the entire test set
    for x in test_loader:
        batches += 1
        # if on GPU put mini-batch into CUDA memory
        x_a = x['a'].to(device).squeeze()
        x_b = x['b'].to(device).squeeze()
        x_label = x['label'].to(device).type(torch.float32)
        x_observed = x['observed'].to(device).type(torch.bool)
            
        # compute ELBO estimate and accumulate loss
        test_loss += svi.evaluate_loss(x_a, x_b, x_label, x_observed)
        
        # generate reconstruction of batch documents and move to cuda if designated
        with torch.no_grad():
            recon_x_a = etm.reconstruct_doc(x_a)
            recon_x_b = etm.reconstruct_doc(x_b)
            logits = svi.guide(x_a, x_b, x_label, x_observed)

        x_label = x_label.cpu()
        logits = logits.cpu()
        logits_epoch += logits
        labels_epoch += x_label

        # calculate and sum cross entropy loss and kl divergence
        x_a=x_a.squeeze()
        x_b=x_b.squeeze()
        log_probs_a = torch.log(recon_x_a)
        log_probs_b = torch.log(recon_x_b)
        
        if log_probs_a.isnan().any() | log_probs_b.isnan().any() | (log_probs_a.abs() > 30).any() | (log_probs_b.abs() > 30).any():
            raise ValueError("nan or very large log probs")
            
        targets_a = x_a/(x_a.sum(axis=-1)[:,None])
        targets_b = x_b/(x_b.sum(axis=-1)[:,None])
        ce_loss_a = F.cross_entropy(log_probs_a, targets_a, reduction='sum')
        ce_loss_b = F.cross_entropy(log_probs_b, targets_b, reduction='sum')
        test_ce_loss += ce_loss_a + ce_loss_b
        test_kl_loss += etm.calc_kl_divergence(x_a.squeeze(1)) + etm.calc_kl_divergence(x_b.squeeze(1))

        
    logits_binned = np.digitize(logits_epoch, [0.45], right=False)
    test_acc = accuracy_score(labels_epoch, logits_binned)
    test_auroc = roc_auc_score(labels_epoch, logits_epoch)
    
    normalizer_test = len(test_loader.dataset)
    epoch_elbo = test_loss / normalizer_test
    epoch_ce_loss = test_ce_loss / normalizer_test
    epoch_kl_loss = test_kl_loss / normalizer_test
    
    return epoch_elbo, epoch_ce_loss.item(), epoch_kl_loss.item(), test_acc, test_auroc

## Building Model

In [9]:
def init_model(pretrained, frozen, pwd):
    try:
        if frozen and not pretrained:
            raise Exception
    except:
        print('ERROR: attempting to freeze an non-pretrained model!')
        
    # clear param store
    pyro.clear_param_store()
    # setup the VAE
    topic_model = ProdLDA(pwd.vocab_size, NUM_TOPICS, NUM_PROTOTYPES, HIDDEN_DIM, DROPOUT_RATE, DEVICE, frozen=frozen).to(DEVICE)
    
    if pretrained:
        saved_model_dict = torch.load("../data/pretrained_lda_5.pt")
        topic_model.load_state_dict(saved_model_dict['model'])
        # svi_guide = saved_model_dict['guide']
        pyro.get_param_store().load("../data/pretrained_params_5.pt")
    
    if frozen:
        for param in topic_model.encoder.parameters():
            param.requires_grad=False
        for param in topic_model.decoder.parameters():
            param.requires_grad=False
        topic_model.frozen = True
        
    # setup the optimizer
    optimizer = ClippedAdam(adam_args)
    # setup the inference algorithm
    svi = SVI(topic_model.model, topic_model.guide, optimizer, loss=custom_elbo)
    

    
    return topic_model, svi

In [10]:
def trace_model(topic_model, dl_train):
    batch_1 = next(iter(dl_train))
    dat_a = batch_1['a'].to(DEVICE).squeeze()
    dat_b = batch_1['b'].to(DEVICE).squeeze()
    dat_label = batch_1['label'].to(DEVICE).squeeze().type(torch.float32)
    dat_observed = batch_1['observed'].to(DEVICE).squeeze().type(torch.bool)

    print(pyro.poutine.trace(topic_model.model).get_trace(dat_a, dat_b, dat_label, dat_observed).format_shapes())
    print(pyro.poutine.trace(topic_model.guide).get_trace(dat_a, dat_b, dat_label, dat_observed).format_shapes())

In [11]:
def get_map(topic_model):
    x = torch.tensor(pwd.bows_val.toarray()).to(DEVICE).type(torch.float)
    with torch.no_grad():
        z_loc, z_scale = topic_model.encoder(x)

    features = z_loc.cpu()
    features = torch.softmax(features, axis=-1)
    features = features.numpy()

    trans = umap.UMAP(n_neighbors=15, random_state=42, min_dist=0.1).fit(features)
    return trans

## Training

In [12]:
train_elbo = []
test_elbo = []
test_celoss = []
test_klloss = []

print("Beginning Training")

# training loop

for p_percentage in [1.0]:
    print(f"Training for percentage: {p_percentage}")
    best_elbo = 10**10
    run = wandb.init(project="prodLDA_pairs", entity="witw", )
    wandb.run.name = f"pretrained_unfrozen_2000vocab_percentage{p_percentage}_2" 
    pwd, dl_train, dl_val, dl_test = get_data(p_percentage)
    
    topic_model, svi = init_model(pretrained=True, frozen=False, pwd=pwd)
    wandb.watch(topic_model, log=['gradients', 'parameters'], log_freq=100)
    trace_model(topic_model, dl_train=dl_train)
    for epoch in range(NUM_EPOCHS):
        trans = get_map(topic_model)
        total_epoch_loss_train, _, _ = train(svi, topic_model, dl_train, device=DEVICE, print_debug=True, epoch=epoch, latest_plotted_p=latest_plotted_p, trans = trans)
        train_elbo.append(-total_epoch_loss_train)

        wandb.log({'epoch': epoch,
                   'train_elbo': total_epoch_loss_train})

        print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

        if epoch % TEST_FREQUENCY == 0:
            # report test diagnostics
            total_epoch_loss_test, total_epoch_celoss_test, total_epoch_klloss_test, test_acc, test_auroc = evaluate(svi, topic_model, dl_val, device=DEVICE)
            test_elbo.append(-total_epoch_loss_test)
            test_celoss.append(total_epoch_celoss_test)
            test_klloss.append(total_epoch_klloss_test)
    #         print(x)
            wandb.log({'epoch': epoch,
                       'test_acc': test_acc,
                       'test_auroc': test_auroc,
                       'test_elbo': total_epoch_loss_test,
                       'test_entropy': total_epoch_celoss_test,
                       'test_kl': total_epoch_klloss_test})

            print("\nEvaluation: ")
            print("[epoch %03d]  average elbo loss: %.4f" % (epoch, total_epoch_loss_test))
            print("              average ce loss:   %.4f" % (total_epoch_celoss_test))
            print("              average kld loss:  %.4f" % (total_epoch_klloss_test))
            print("              average accuracy:     %.4f" % (test_acc))
            print("              average auroc:  %.4f\n" % (test_auroc))
            if total_epoch_loss_test < best_elbo:
                best_elbo = total_epoch_loss_test
                torch.save({"model" : topic_model.state_dict()}, f"../data/bestmodel2000P{p_percentage}_unfrozen_2.pt")
                pyro.get_param_store().save(f"../data/bestmodelparams2000P{p_percentage}_unfrozen_2.pt")


Beginning Training
Training for percentage: 1.0


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwitw[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.13.4 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade

CondaEnvException: Unable to determine environment

Please re-run this command with one of the following options:

* Provide an environment name via --name or -n
* Re-run this command inside an activated conda environment.



        Trace Shapes:             
         Param Sites:             
decoder$$$beta.weight 2000 50     
        Sample Sites:             
     documents_a dist       |     
                value  128  |     
      logtheta_a dist  128  |   50
                value  128  |   50
           obs_a dist  128  | 2000
                value  128  | 2000
     documents_b dist       |     
                value  128  |     
      logtheta_b dist  128  |   50
                value  128  |   50
           obs_b dist  128  | 2000
                value  128  | 2000
      prototypes dist    7  |   50
                value    7  |   50
        Trace Shapes:             
         Param Sites:             
 encoder$$$fc1.weight 128 2000    
   encoder$$$fc1.bias      128    
 encoder$$$fc2.weight 128  128    
   encoder$$$fc2.bias      128    
encoder$$$fcmu.weight  50  128    
  encoder$$$fcmu.bias       50    
encoder$$$fclv.weight  50  128    
  encoder$$$fclv.bias       50    
                    

In [13]:
torch.save({"model" : topic_model.state_dict()}, f"../data/bestmodel2000P{p_percentage}_unfrozen_2.pt")

In [14]:
svi.guide

<bound method ProdLDA.guide of ProdLDA(
  (encoder): Encoder(
    (drop): Dropout(p=0.2, inplace=False)
    (fc1): Linear(in_features=2000, out_features=128, bias=True)
    (fc2): Linear(in_features=128, out_features=128, bias=True)
    (fcmu): Linear(in_features=128, out_features=50, bias=True)
    (fclv): Linear(in_features=128, out_features=50, bias=True)
    (bnmu): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (bnlv): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
  )
  (decoder): Decoder(
    (beta): Linear(in_features=50, out_features=2000, bias=False)
    (bn): BatchNorm1d(2000, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    (drop): Dropout(p=0.2, inplace=False)
  )
)>