In [1]:
# base_model_params = "../data/pretrained_models/pretrained_lda_20ng_2000vocab.pt"
# base_param_store = "../data/pretrained_models/pretrained_params_20ng_2000vocab.pt"

base_model_params = "../data/pretrained_models/pretrained_lda_movies.pt"
# base_param_store = "../data/pretrained_models/pretrained_params_movies.pt"

# trained_model_params = "../data/20ng_checkpoints/lastmodel_kld2_20ng_a=1000.pt"
# trained_param_store = "../data/bestmodelparams_kld2_20ng_noparams.pt"

trained_model_params = "../data/movie_checkpoints/lastmodel_kld2_movies_a=1000.pt"
# trained_param_store = "../data/bestmodelparams_kld2_movies_a45lr0.01.pt"

In [2]:
from imdb_dataset import PairwiseData, DocumentPairData
from model_kld import Encoder, Decoder, ProdLDA, LossClass
import os
import pickle as pkl
from typing import List

import numpy as np
import torch
import graphviz
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import pandas as pd


import pyro
from pyro import poutine
import pyro.distributions as dist

import pandas as pd
import matplotlib.pyplot as plt
import umap
import umap.plot
from sklearn.metrics import roc_auc_score

plt.rcParams["figure.figsize"] = (10,5)

In [3]:
pyro.distributions.enable_validation(True)
pyro.set_rng_seed(0)
pd.set_option('display.max_rows', 100)

In [4]:
PAIR_PERCENTAGE = 1.0
BATCH_SIZE = 128
VAL_BATCH_SIZE = 2000
NUM_WORKERS = 0

NUM_TOPICS = 50
EMBED_DIM  = 64
HIDDEN_DIM = 128
DROPOUT_RATE = 0.2

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# pairwise data
pwd = PairwiseData()
train_pairs, val_pairs, test_pairs = [pwd.get_pairs_table(d) for d in [pwd.train, pwd.val, pwd.test]]

# datasets
data_train, data_val, data_test = [
    DocumentPairData(bows=pwd.bows, index_table=ix_table, prob=PAIR_PERCENTAGE)
    for ix_table in [train_pairs, val_pairs, test_pairs]
]

# dataloaders
dl_train = DataLoader(data_train, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
dl_val = DataLoader(data_val, batch_size=VAL_BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)
dl_test = DataLoader(data_test, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

In [6]:
NUM_MINI_TEST = 5
TOP_N = 20
tok2ix = pwd.vectorizer.vocabulary_
ix2tok = {v:k for k,v in tok2ix.items()}

NUM_MINI_TOPICS = 5
START_TOPIC_I = 5
TOP_N_TOPIC = 15

In [7]:
def evaluateAUROC(topic_model):
    lossClass = LossClass(num_topics=NUM_TOPICS, alpha=1.0, device=DEVICE)
    
    klds = []
    klds_similar = []
    klds_dissimilar = []
    labels = []
    
    for x in dl_test:
        # if on GPU put mini-batch into CUDA memory
        x_a = x['a'].to(DEVICE).squeeze()
        x_b = x['b'].to(DEVICE).squeeze()
        x_label = x['label'].to(DEVICE).type(torch.int)
        x_observed = x['observed'].to(DEVICE).type(torch.bool)

        model_trace = pyro.poutine.trace(topic_model.model).get_trace(x_a, x_b, x_label, x_observed)
        
        theta_a = model_trace.nodes['theta_a']['value']
        theta_b = model_trace.nodes['theta_b']['value']
        c = model_trace.nodes['c']['value'].type(torch.bool)
        
        kld = lossClass.get_kld(theta_a, theta_b).cpu()
        klds += kld
        labels += (-x_label.cpu())
        
        for ix, label in enumerate(x_label.cpu().numpy()): 
            print(f"LABEL: {label} : {kld[ix]}")
            if label:
                klds_similar.append(kld[ix])
            else:
                klds_dissimilar.append(kld[ix])
                
    print(f"Average KLD for similar pairs = {sum(klds_similar) / len(klds_similar)}")
    print(f"Average KLD for dissimilar pairs = {sum(klds_dissimilar) / len(klds_dissimilar)}")
    
    return roc_auc_score(labels, klds)

In [8]:
def getPretrainedEvaluations(model_params):
    pyro.clear_param_store()
    # load in the pretrained model
    topic_model = ProdLDA(pwd.vocab_size, NUM_TOPICS, HIDDEN_DIM, DROPOUT_RATE, DEVICE).to(DEVICE)
    saved_model_dict = torch.load(model_params)
    topic_model.load_state_dict(saved_model_dict['model'])
    # pyro.get_param_store().load(param_store)
    
    # get topics
    betas = topic_model.beta()
    top_term_ixs = betas.argsort(axis=-1, descending=True)[:,:TOP_N_TOPIC]

    auroc = evaluateAUROC(topic_model)
    
    topic_terms = []
    for jx, topic in enumerate(top_term_ixs):
        terms = [(ix2tok[ix.item()]) for ix in topic]
        topic_terms.append(terms)
    
    x = torch.tensor(pwd.bows_val.toarray()).to(DEVICE).type(torch.float)
    with torch.no_grad():
        z_loc, z_scale = topic_model.encoder(x)

    features = z_loc.cpu()
    features = torch.softmax(features, axis=-1)
    features = features.numpy()

    trans = umap.UMAP(n_neighbors=15, random_state=42, min_dist=0.1).fit(features)
    
    return betas, topic_terms, trans, auroc

In [9]:
# betas_before, topic_terms_before, trans_before, auroc_before = getPretrainedEvaluations(base_model_params)
betas_after, topic_terms_after, trans_after, auroc_after = getPretrainedEvaluations(trained_model_params)

word_dists_before = torch.softmax(betas_before, axis=-1)
word_dists_after = torch.softmax(betas_after, axis=-1)

kld = -(word_dists_after*torch.log(word_dists_before)-word_dists_after*torch.log(word_dists_after)).sum(axis=-1)
kld = kld.tolist()
kld_sorted_indices = np.argsort(kld)[::-1]

df_topics_before = pd.DataFrame(topic_terms_before)
df_topics_after = pd.DataFrame(topic_terms_after)

df_compare = pd.DataFrame()
for ix in kld_sorted_indices:
    b_temp = df_topics_before.iloc[[ix]]
    b_temp.insert(0, "ORIGIN", 'BEFORE', True)
    a_temp = df_topics_after.iloc[[ix]]
    a_temp.insert(0, "ORIGIN", 'AFTER', True)

    df_compare = df_compare.append(b_temp)
    df_compare = df_compare.append(a_temp)

LABEL: 1 : 0.26854103803634644
LABEL: 1 : 0.24817663431167603
LABEL: 0 : 0.2604092061519623
LABEL: 0 : 0.22099290788173676
LABEL: 0 : 0.21837346255779266
LABEL: 0 : 0.28948384523391724
LABEL: 1 : 0.24880598485469818
LABEL: 0 : 0.24018672108650208
LABEL: 0 : 0.2273484170436859
LABEL: 1 : 0.16579052805900574
LABEL: 0 : 0.2792280614376068
LABEL: 0 : 0.246687114238739
LABEL: 0 : 0.24541828036308289
LABEL: 0 : 0.23688894510269165
LABEL: 1 : 0.16825422644615173
LABEL: 1 : 0.2979174852371216
LABEL: 0 : 0.22195672988891602
LABEL: 0 : 0.17204587161540985
LABEL: 0 : 0.3143978416919708
LABEL: 0 : 0.27041274309158325
LABEL: 0 : 0.15214745700359344
LABEL: 0 : 0.23942776024341583
LABEL: 0 : 0.28354811668395996
LABEL: 0 : 0.2760602831840515
LABEL: 0 : 0.21377235651016235
LABEL: 0 : 0.2652031183242798
LABEL: 0 : 0.2643260955810547
LABEL: 1 : 0.16743895411491394
LABEL: 1 : 0.273166298866272
LABEL: 0 : 0.2278013378381729
LABEL: 0 : 0.2902238965034485
LABEL: 1 : 0.23565812408924103
LABEL: 0 : 0.331967025

KeyboardInterrupt: 

In [None]:
auroc_before

In [None]:
auroc_after

In [None]:
df_compare

In [None]:
df_topics_after

In [None]:
# df_compare.to_csv(trained_model_params.split("/")[3].split(".")[0] + '.csv')

In [None]:
umap.plot.points(trans_before, labels=pwd.val.stars_category,theme='fire')

In [None]:
umap.plot.points(trans_after, labels=pwd.val.stars_category,theme='fire')