In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from warnings import simplefilter
from sklearn.exceptions import ConvergenceWarning
simplefilter("ignore", category=ConvergenceWarning)
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation as LDA
from sklearn.decomposition import NMF, PCA
from sklearn.linear_model import Ridge, LogisticRegression, Lasso
from sklearn.metrics import mean_squared_error as mse, roc_auc_score as roc, accuracy_score as acc, log_loss
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.datasets import fetch_20newsgroups
import numpy as np
import pandas as pd
from data.dataset import BaseDataset
from scipy.sparse import csr_matrix
from importlib import reload
import itertools as it
import torch
from model.models import SparseVAESpikeSlab, VAE
import scipy
from torch.nn import functional as F
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.ticker import PercentFormatter
import seaborn as sns
from scipy.special import expit
import os

# Seed NumPy's legacy global RNG; make_simulated_theta draws its noise from np.random.randn.
SEED = 42
np.random.seed(SEED)

In [2]:
def negative_log_likelihood(logits, count):
    """Mean (over rows) negative log-likelihood of ``count`` under row-normalised weights.

    Despite the parameter name, no softmax is applied: each row of ``logits`` is
    divided by its own sum to form a distribution, so the values are expected to be
    non-negative weights. A 1e-7 floor keeps ``log`` finite for zero probabilities.
    """
    row_totals = logits.sum(axis=1, keepdims=True)
    log_p = np.log(logits / row_totals + 1e-7)
    per_row_ll = (log_p * count).sum(axis=1)
    return -per_row_ll.mean()

In [23]:
def col_normalize(x):
    """Scale every row of the 2-D array ``x`` so that it sums to one.

    Note: despite the name, the sum is taken along axis 1, i.e. each *row* is
    normalised (e.g. turning per-document topic weights into proportions).
    """
    totals = x.sum(axis=1, keepdims=True)
    return x / totals

In [24]:
def run_correlation_study(theta_sim, beta, true_expected, true_data, doc_n):
    """Fit NMF to counts simulated from ``theta_sim`` and score the fit.

    Simulated counts are produced by ``make_expected_counts`` (rounded expected word
    counts with all-zero documents removed). An NMF with as many components as
    ``beta`` has rows is fit to them. Its topics are scored by NPMI against
    ``true_data``, and the fitted model is used to reconstruct ``true_expected``.

    Args:
        theta_sim: per-document topic proportions used to simulate training counts.
        beta: topic-word matrix; its row count sets the number of NMF components.
        true_expected: held-out count matrix to reconstruct with the fitted NMF.
        true_data: count matrix used as the co-occurrence reference for NPMI.
        doc_n: per-document token totals used to scale the simulated counts.

    Returns:
        ``(npmi, tr_mse, te_mse, expected_counts)``, where ``expected_counts`` is the
        simulated training matrix. Rows whose rounded counts are all zero are dropped,
        so ``expected_counts`` may have fewer rows than ``theta_sim`` and its rows are
        then not aligned one-to-one with it.
    """
    # Same simulation as the standalone helper, so both code paths stay in sync.
    expected_counts = make_expected_counts(theta_sim, beta, doc_n)

    n_topics = beta.shape[0]
    nmf = NMF(n_components=n_topics)
    theta_hat = nmf.fit_transform(expected_counts)
    beta_hat = nmf.components_
    pred_tr = theta_hat.dot(beta_hat)

    npmi = get_normalized_pmi(beta_hat, true_data)

    # Project the held-out counts onto the learned topics and reconstruct them.
    theta_new = nmf.transform(true_expected)
    pred_new = theta_new.dot(beta_hat)

    tr_mse = mse(pred_tr, expected_counts)
    te_mse = mse(pred_new, true_expected)

    return npmi, tr_mse, te_mse, expected_counts

In [25]:
def make_expected_counts(theta_sim, beta, doc_n):
    """Return rounded expected word counts per document, dropping all-zero rows.

    Each document's word distribution is ``theta_sim @ beta``. Scaling it by the
    document length ``doc_n`` and rounding with ``np.around`` (ties to even) gives
    the expected counts. Documents whose counts all round to zero are removed, so
    the result may have fewer rows than ``theta_sim``.
    """
    word_probs = theta_sim.dot(beta)
    counts = np.around(word_probs * doc_n[:, np.newaxis])
    keep = counts.sum(axis=1) != 0
    return counts[keep]

In [26]:
def get_normalized_pmi(topics, counts, num_words=10, eps=1e-7):
    """Mean normalised PMI (NPMI) of each topic's top words.

    Args:
        topics: 2-D array of topic-word weights, one topic per row. The
            ``num_words`` highest-weighted words of each row are scored.
        counts: document-term count matrix (dense ndarray or scipy sparse matrix),
            one document per row.
        num_words: number of top words per topic; must allow at least two words.
        eps: smoothing added inside every ``log`` to keep it finite.

    Probabilities come from ``counts``. ``p(w)`` is the word's share of all counts,
    and ``p(w1, w2)`` is the ``(w1, w2)`` entry of ``counts.T @ counts`` divided by
    that matrix's total. Each pair scores ``PMI / -log p(w1, w2)``. Pair scores are
    averaged within each topic, and the per-topic scores are averaged.

    Raises:
        ValueError: if fewer than two top words can be selected per topic.
    """
    if min(num_words, topics.shape[1]) < 2:
        raise ValueError("NPMI needs at least two top words per topic")

    tf = csr_matrix(counts)
    cooccurence = tf.T.dot(tf).toarray()
    cooccurence_prob = cooccurence / cooccurence.sum()

    # scipy.sparse .sum(axis=0) returns a 1xV np.matrix; asarray().ravel() flattens
    # it to a vector (and is a no-op reshape for dense ndarrays).
    word_count = np.asarray(counts.sum(axis=0)).ravel()
    log_prob = np.log(word_count / word_count.sum() + eps)

    per_topic_npmi = np.zeros(topics.shape[0])
    for k, beta in enumerate(topics):
        top_words = (-beta).argsort()[:num_words]
        pair_scores = []
        for w1, w2 in it.combinations(top_words, 2):
            log_joint_prob = np.log(cooccurence_prob[w1, w2] + eps)
            pmi = log_joint_prob - (log_prob[w1] + log_prob[w2])
            pair_scores.append(pmi / -log_joint_prob)
        per_topic_npmi[k] = np.mean(pair_scores)
    return per_topic_npmi.mean()

In [27]:
def make_simulated_theta(sigma, theta):
    """Build topic proportions whose second half are noisy copies of the first half.

    Let ``h = theta.shape[1] // 2``. The first ``h`` columns of ``log(theta)`` are
    kept unchanged. The next ``h`` columns are the same columns plus
    ``sigma * N(0, 1)`` noise drawn from NumPy's global RNG. Any remaining original
    columns are discarded. The result is exponentiated and row-normalised.

    Also prints how many topics have a covariance row whose maximum is not the
    diagonal (variance) entry.
    """
    n_half = theta.shape[1] // 2
    n_docs = theta.shape[0]
    log_theta = np.log(theta)

    # randn(n_half, n_docs) fills row-major, consuming the global RNG stream in the same
    # order as n_half successive randn(n_docs) calls; transpose so column k is draw k.
    noise = sigma * np.random.randn(n_half, n_docs).T
    log_copies = log_theta[:, :n_half] + noise
    theta_corr = col_normalize(np.exp(np.hstack([log_theta[:, :n_half], log_copies])))

    cov = np.cov(theta_corr.T)
    n_violations = int(np.count_nonzero(cov.max(axis=1) != np.diag(cov)))
    print("No. violations of cov condition:", n_violations)

    return theta_corr

In [28]:
# Load the preprocessed document-term matrix and its vocabulary.
proc_file = '../../dat/proc/peerread_small_proc.npz'
# The context manager closes the NpzFile handle once the arrays are read into memory.
with np.load(proc_file) as arr:
    data = arr['data']
    vocab = arr['metadata']
# Per-document token totals (row sums), used to scale expected counts.
terms_total = data.sum(axis=1)
data.shape, vocab.shape

((11777, 500), (500,))

In [30]:
# Load the pretrained topic model: theta (documents x topics) and beta (topics x vocab),
# combined later as theta.dot(beta).
pretrained_file = '../../dat/proc/peerread_small_pretraining.npz'
# Close the NpzFile handle as soon as both arrays have been read.
with np.load(pretrained_file) as arr:
    theta = arr['theta']
    beta = arr['beta']
theta.shape, beta.shape

((11777, 20), (20, 500))

In [31]:
# Show the 20 highest-weighted vocabulary words of each pretrained topic, then their NPMI.
N_TOP_WORDS = 20
for topic_idx, topic in enumerate(beta):
    ranked = np.argsort(-topic)[:N_TOP_WORDS]
    words = [vocab[w] for w in ranked]
    print('Topic {}: {}'.format(topic_idx, words))

print("NPMI:", get_normalized_pmi(beta, data))

Topic 0: ['user', 'feature', 'task', 'online', 'item', 'group', 'method', 'based', 'approach', 'propose', 'detection', 'selection', 'signal', 'system', 'proposed', 'information', 'preference', 'problem', 'paper', 'model']
Topic 1: ['system', 'event', 'process', 'paper', 'information', 'ontology', 'application', 'tool', 'web', 'present', 'knowledge', 'approach', 'software', 'research', 'used', 'ha', 'development', 'semantic', 'based', 'domain']
Topic 2: ['graph', 'constraint', 'program', 'programming', 'set', 'problem', 'solver', 'node', 'map', 'variable', 'structure', 'algorithm', 'show', 'model', 'instance', 'technique', 'paper', 'approach', 'solving', 'present']
Topic 3: ['method', 'function', 'algorithm', 'problem', 'matrix', 'data', 'learning', 'kernel', 'optimization', 'linear', 'proposed', 'sparse', 'approach', 'propose', 'show', 'regression', 'paper', 'point', 'result', 'loss']
Topic 4: ['model', 'learning', 'data', 'machine', 'topic', 'approach', 'latent', 'framework', 'applica

In [32]:
# Row-normalise the pretrained topic weights so each document's mixture sums to one.
theta = col_normalize(theta)
true_p_words = theta.dot(beta)

# Rounded expected word counts for every document under the pretrained model.
true_expected = np.around(terms_total[:, np.newaxis] * true_p_words)
# Keep only documents whose rounded expected counts are not all zero.
valid = np.flatnonzero(true_expected.sum(axis=1) != 0)
true_expected = true_expected[valid]

In [33]:
# For each noise level, simulate correlated topic proportions, fit NMF to the simulated
# counts, report NPMI / train / intervened-data MSE and save the simulated dataset.
N_EXPERIMENTS = 1
SIGMAS = [0.5, 1.0, 2.0]
# NOTE(review): the other data paths in this notebook are rooted at '../../dat';
# confirm '../dat' here is intentional and not a missing '../'.
OUT_ROOT = '../dat/intervened_n=5'

for i in range(N_EXPERIMENTS):
    print("Working on experiment", i)
    outdir = os.path.join(OUT_ROOT, str(i))
    os.makedirs(outdir, exist_ok=True)
    for sigma in SIGMAS:
        print("Working on sigma", sigma, "...")
        theta_sim = make_simulated_theta(sigma, theta)

        npmi, tr_mse, te_mse, sim_data = run_correlation_study(theta_sim, beta, true_expected, data, terms_total)

        # All-zero documents are dropped from sim_data and from true_expected independently,
        # so rows may no longer line up with theta_sim (or with each other).
        if not (theta_sim.shape[0] == sim_data.shape[0] == true_expected.shape[0]):
            print("WARNING: row counts differ (theta_sim, sim_obs, orig_obs):",
                  theta_sim.shape[0], sim_data.shape[0], true_expected.shape[0])

        fname = os.path.join(outdir, 'sigma=' + str(round(sigma, 2)))
        np.savez_compressed(fname, theta_sim=theta_sim, sim_obs=sim_data, orig_obs=true_expected, features=vocab)

        print("NPMI, Training loss, intervened data loss:", npmi, tr_mse, te_mse)

Working on experiment 0
Working on sigma 0.5 ...
No. violations of cov condition: 0




NPMI, Training loss, intervened data loss: 0.015518007274422232 0.005455749828197064 0.011101625423968116
Working on sigma 1.0 ...
No. violations of cov condition: 0
NPMI, Training loss, intervened data loss: 0.022688744095815037 0.006191162399490585 0.0079658154641498
Working on sigma 2.0 ...
No. violations of cov condition: 0
NPMI, Training loss, intervened data loss: 0.02921255305518599 0.006422207465129802 0.007144053227221853


In [37]:
# Held-out intervened datasets: training and evaluation counts come from two independent
# noise draws at the same sigma.
HOLDOUT_SIGMAS = [0.5, 1.0, 3.0]
HOLDOUT_DIR = '../../dat/intervened_holdout'
os.makedirs(HOLDOUT_DIR, exist_ok=True)

# The output path is not keyed by an experiment index, so a single pass is run; repeated
# experiments would overwrite each other's files.
for sigma in HOLDOUT_SIGMAS:
    print("Working on sigma", sigma, "...")
    tr_theta_sim = make_simulated_theta(sigma, theta)
    te_theta_sim = make_simulated_theta(sigma, theta)

    tr_data = make_expected_counts(tr_theta_sim, beta, terms_total)
    te_data = make_expected_counts(te_theta_sim, beta, terms_total)

    # make_expected_counts drops all-zero documents, so rows may stop lining up.
    if not (tr_data.shape[0] == te_data.shape[0] == te_theta_sim.shape[0]):
        print("WARNING: row counts differ (sim_obs, orig_obs, theta_sim):",
              tr_data.shape[0], te_data.shape[0], te_theta_sim.shape[0])

    fname = os.path.join(HOLDOUT_DIR, 'sigma=' + str(round(sigma, 2)))
    # NOTE(review): 'theta_sim' holds the *test* proportions while 'sim_obs' holds the
    # training counts. The training proportions are also saved as 'tr_theta_sim' so they
    # are not lost; confirm which one downstream loaders expect under 'theta_sim'.
    np.savez_compressed(fname, theta_sim=te_theta_sim, tr_theta_sim=tr_theta_sim,
                        sim_obs=tr_data, orig_obs=te_data, features=vocab)

Working on experiment 0
Working on sigma 0.5 ...
No. violations of cov condition: 0
No. violations of cov condition: 0
Working on sigma 1.0 ...
No. violations of cov condition: 0
No. violations of cov condition: 0
Working on sigma 3.0 ...
No. violations of cov condition: 0
No. violations of cov condition: 0
