From 2966493f86a6f808f0dfa71d590e3403a840befc Mon Sep 17 00:00:00 2001 From: Dhanya Sridhar Date: Thu, 5 Mar 2020 14:41:33 -0500 Subject: [PATCH] adding causal amortized topic model --- .../add_split_to_simulations.ipynb | 61 ++ src/supervised_lda/compute_estimates.py | 51 + src/supervised_lda/helpers.py | 87 ++ src/supervised_lda/peerread_output_att.py | 133 +++ src/supervised_lda/reddit_output_att.py | 153 +++ src/supervised_lda/run_supervised_tm.py | 95 ++ .../peerread-exps/run_peerread_simulation.sh | 18 + .../peerread-exps/submit_no_sup.sh | 23 + .../peerread-exps/submit_no_unsup.sh | 23 + .../peerread-exps/submit_nonlinear.sh | 23 + .../submit_peerread_simulation.sh | 23 + .../reddit-exps/run_reddit_simulation.sh | 20 + .../reddit-exps/submit_no_sup.sh | 29 + .../reddit-exps/submit_no_unsup.sh | 30 + .../reddit-exps/submit_nonlinear.sh | 29 + .../reddit-exps/submit_reddit_simulation.sh | 29 + .../reddit-exps/submit_reddit_test.sh | 29 + src/supervised_lda/supervised_topic_model.py | 193 ++++ src/supervised_lda/test_slda.ipynb | 870 ++++++++++++++++++ 19 files changed, 1919 insertions(+) create mode 100644 src/supervised_lda/add_split_to_simulations.ipynb create mode 100644 src/supervised_lda/compute_estimates.py create mode 100644 src/supervised_lda/helpers.py create mode 100644 src/supervised_lda/peerread_output_att.py create mode 100644 src/supervised_lda/reddit_output_att.py create mode 100644 src/supervised_lda/run_supervised_tm.py create mode 100755 src/supervised_lda/submit_scripts/peerread-exps/run_peerread_simulation.sh create mode 100755 src/supervised_lda/submit_scripts/peerread-exps/submit_no_sup.sh create mode 100755 src/supervised_lda/submit_scripts/peerread-exps/submit_no_unsup.sh create mode 100755 src/supervised_lda/submit_scripts/peerread-exps/submit_nonlinear.sh create mode 100755 src/supervised_lda/submit_scripts/peerread-exps/submit_peerread_simulation.sh create mode 100755 src/supervised_lda/submit_scripts/reddit-exps/run_reddit_simulation.sh create mode 100755 src/supervised_lda/submit_scripts/reddit-exps/submit_no_sup.sh create mode 100755 src/supervised_lda/submit_scripts/reddit-exps/submit_no_unsup.sh create mode 100755 src/supervised_lda/submit_scripts/reddit-exps/submit_nonlinear.sh create mode 100755 src/supervised_lda/submit_scripts/reddit-exps/submit_reddit_simulation.sh create mode 100755 src/supervised_lda/submit_scripts/reddit-exps/submit_reddit_test.sh create mode 100644 src/supervised_lda/supervised_topic_model.py create mode 100644 src/supervised_lda/test_slda.ipynb diff --git a/src/supervised_lda/add_split_to_simulations.ipynb b/src/supervised_lda/add_split_to_simulations.ipynb new file mode 100644 index 0000000..a059c80 --- /dev/null +++ b/src/supervised_lda/add_split_to_simulations.ipynb @@ -0,0 +1,61 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "base_sim_dir = '../../dat/sim/'\n", + "datasets = ['reddit_subreddit_based/subreddits[13, 6, 8]', 'peerread_buzzytitle_based']\n", + "mode = 'modesimple'\n", + "\n", + "for dataset in datasets:\n", + " simdir = os.path.join(base_sim_dir, dataset, mode)\n", + " for simfile in os.listdir(simdir):\n", + " df = pd.read_csv(os.path.join(simdir, simfile), sep='\\t')\n", + " df['split'] = np.random.randint(0, 10, size=df.shape[0])\n", + " 
df.to_csv(os.path.join(simdir, simfile),sep='\\t')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/supervised_lda/compute_estimates.py b/src/supervised_lda/compute_estimates.py new file mode 100644 index 0000000..1d21277 --- /dev/null +++ b/src/supervised_lda/compute_estimates.py @@ -0,0 +1,51 @@ +from semi_parametric_estimation.att import att_estimates +import numpy as np +import os +import argparse +import pandas as pd + +def main(): + outdir = os.path.join('..', 'out', args.data, args.experiment) + for sim in os.listdir(outdir): + mean_estimates = {'very_naive': [], 'q_only': [], 'plugin': [], 'one_step_tmle': [], 'aiptw': []} + for split in os.listdir(os.path.join(outdir, sim)): + if args.num_splits is not None: + # print("ignoring split", split) + if int(split) >= int(args.num_splits): + continue + array = np.load(os.path.join(outdir, sim, split, 'predictions.npz')) + g = array['g'] + q0 = array['q0'] + q1 = array['q1'] + y = array['y'] + t = array['t'] + estimates = att_estimates(q0, q1, g, t, y, t.mean(), truncate_level=0.03) + for est, att in estimates.items(): + mean_estimates[est].append(att) + + if args.data == 'reddit': + sim = sim.replace('beta01.0.', '') + options = sim.split('.0.') + p2 = options[0].replace('beta1', '') + p3 = options[1].replace('gamma', '') + + print("------ Simulation setting: Confounding strength =", p2, "; Variance:", p3, "------") + print("True effect = 1.0") + else: + ground_truth_map = {'1.0':0.06, '5.0':0.06, '25.0':0.03} + print("------ Simulation setting: Confounding strength =", sim) + print("True effect = ", ground_truth_map[sim]) + + + for est, atts in mean_estimates.items(): + print('\t', est, np.round(np.mean(atts), 3), "+/-", np.round(np.std(atts),3)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--data", action="store", default="reddit") + parser.add_argument("--experiment", action="store", default="base_model") + parser.add_argument("--num-splits", action="store", default=None) + args = parser.parse_args() + + main() \ No newline at end of file diff --git a/src/supervised_lda/helpers.py b/src/supervised_lda/helpers.py new file mode 100644 index 0000000..f10653f --- /dev/null +++ b/src/supervised_lda/helpers.py @@ -0,0 +1,87 @@ +from nltk.tokenize import word_tokenize +from nltk.stem import WordNetLemmatizer +from nltk.corpus import stopwords +from sklearn.feature_extraction.text import CountVectorizer +import numpy as np +import pandas as pd +from sklearn.decomposition import LatentDirichletAllocation + +class LemmaTokenizer(object): + def __init__(self): + self.wnl = WordNetLemmatizer() + def __call__(self, articles): + stop = stopwords.words('english') + return [self.wnl.lemmatize(t) for t in word_tokenize(articles) if t.isalpha() and t not in stop] + +def filter_by_subreddit(reddit, subs=None): + if not subs: + return reddit.index.values + else: + return reddit[reddit.subreddit.isin(subs)].index.values + +def tokenize_documents(documents,max_df0=0.9, min_df0=0.0005): + from nltk.corpus import 
stopwords + ''' + From a list of documents raw text build a matrix DxV + D: number of docs + V: size of the vocabulary, i.e. number of unique terms found in the whole set of docs + ''' + count_vect = CountVectorizer(tokenizer=LemmaTokenizer(), max_df=max_df0, min_df=min_df0) + corpus = count_vect.fit_transform(documents) + vocabulary = count_vect.get_feature_names() + + return corpus,vocabulary,count_vect + +def assign_dev_split(num_docs, percentage=0.05): + indices = np.arange(num_docs) + np.random.shuffle(indices) + size = int(indices.shape[0]*percentage) + dev = indices[:size] + return dev + +def learn_topics(X, X_dev, K=50): + lda = LatentDirichletAllocation(n_components=K, learning_method='online', verbose=1) + print("Fitting", K, "topics...") + lda.fit(X) + score = lda.perplexity(X_dev) + print("Log likelihood:", score) + topics = lda.components_ + return score, lda, topics + +def show_topics(vocab, topics, n_words=20): + topic_keywords = [] + for topic_weights in topics: + top_keyword_locs = (-topic_weights).argsort()[:n_words] + topic_keywords.append(vocab.take(top_keyword_locs)) + + df_topic_keywords = pd.DataFrame(topic_keywords) + df_topic_keywords.columns = ['Word '+str(i) for i in range(df_topic_keywords.shape[1])] + df_topic_keywords.index = ['Topic '+str(i) for i in range(df_topic_keywords.shape[0])] + return df_topic_keywords + +def filter_document_embeddings(filtered_df, doc_embeddings, index_mapping, on='post_index'): + filtered_indices = filtered_df[on].values + doc_idx = [index_mapping[idx] for idx in filtered_indices] + embeddings = doc_embeddings[doc_idx, :] + return embeddings + +def filter_document_terms(filtered_df, counts, index_mapping, on='post_index'): + filtered_indices = filtered_df[on].values + doc_idx = [index_mapping[idx] for idx in filtered_indices] + filtered_counts = counts[doc_idx, :] + return filtered_counts + +def make_index_mapping(df, on='post_index', convert_to_int=True): + if on=='index': + indices = df.index.values + else: + indices = df[on].values + + if convert_to_int: + return {int(ind):i for (i,ind) in enumerate(indices)} + + return {ind:i for (i,ind) in enumerate(indices)} + +def assign_split(df, num_splits=10, col_to_add='split'): + df[col_to_add] = np.random.randint(0, num_splits, size=df.shape[0]) + return df diff --git a/src/supervised_lda/peerread_output_att.py b/src/supervised_lda/peerread_output_att.py new file mode 100644 index 0000000..85bde84 --- /dev/null +++ b/src/supervised_lda/peerread_output_att.py @@ -0,0 +1,133 @@ +from semi_parametric_estimation.att import att_estimates +from supervised_lda.helpers import filter_document_terms, make_index_mapping, assign_split, tokenize_documents +import numpy as np +import pandas as pd +import os +from sklearn.metrics import mean_squared_error as mse +import argparse +import sys +from supervised_lda.supervised_topic_model import SupervisedTopicModel +from supervised_lda import run_supervised_tm +from scipy import sparse +from sklearn.linear_model import LogisticRegression, Ridge +from scipy.special import logit + +def load_peerread(path='../dat/PeerRead/'): + return pd.read_csv(path + 'proc_abstracts.csv') + +def load_term_counts(df, path='../dat/PeerRead/', force_redo=False, text_col='abstract_text'): + count_filename = path + 'term_counts' + vocab_filename = path + 'vocab' + + if os.path.exists(count_filename + '.npz') and not force_redo: + return sparse.load_npz(count_filename + '.npz').toarray(), np.load(vocab_filename + '.npy') + + post_docs = df[text_col].values + counts, vocab, _ = 
tokenize_documents(post_docs) + sparse.save_npz(count_filename, counts) + np.save(vocab_filename, vocab) + return counts.toarray(), np.array(vocab) + +def compute_ground_truth_treatment_effect(df): + y1 = df['y1'] + y0 = df['y0'] + return y1.mean() - y0.mean() + +def load_simulated_data(): + sim_df = pd.read_csv(simulation_file, delimiter='\t') + return sim_df + +def fit_model(doc_embeddings, labels, is_binary=False): + if is_binary: + model = LogisticRegression(solver='liblinear') + else: + model = Ridge() + model.fit(doc_embeddings, labels) + return model + +def main(): + if dat_dir: + peerread = load_peerread(path=dat_dir) + counts,vocab = load_term_counts(peerread,path=dat_dir) + else: + peerread = load_peerread() + counts,vocab = load_term_counts(peerread) + + indices = peerread['paper_id'].values + index_mapping = make_index_mapping(peerread, on='index') + + sim_df = load_simulated_data() + + train_df = sim_df[sim_df.split != split] + predict_df = sim_df[sim_df.split == split] + tr_treatment_labels = train_df.treatment.values + tr_outcomes = train_df.outcome.values + predict_treatment = predict_df.treatment.values + predict_outcomes = predict_df.outcome.values + + tr_counts = filter_document_terms(train_df, counts, index_mapping, on='id') + predict_counts = filter_document_terms(predict_df, counts, index_mapping, on='id') + + num_documents = tr_counts.shape[0] + vocab_size = tr_counts.shape[1] + model = SupervisedTopicModel(num_topics, vocab_size, num_documents, outcome_linear_map=linear_outcome_model) + + run_supervised_tm.train(model, tr_counts, tr_treatment_labels, tr_outcomes, dtype='binary', + num_epochs=num_iters, use_recon_loss=use_recon_loss, use_sup_loss=use_supervised_loss) + + if use_supervised_loss: + propensity_score, expected_outcome_treat, expected_outcome_no_treat = run_supervised_tm.predict(model, predict_counts, dtype='binary') + else: + tr_doc_embeddings = run_supervised_tm.get_representation(model, tr_counts) + treated = tr_treatment_labels == 1 + out_treat = tr_outcomes[treated] + out_no_treat = tr_outcomes[~treated] + q0_embeddings = tr_doc_embeddings[~treated,:] + q1_embeddings = tr_doc_embeddings[treated,:] + q0_model = fit_model(q0_embeddings, out_no_treat, is_binary=True) + q1_model = fit_model(q1_embeddings, out_treat, is_binary=True) + g_model = fit_model(tr_doc_embeddings, tr_treatment_labels, is_binary=True) + + pred_doc_embeddings = run_supervised_tm.get_representation(model, predict_counts) + propensity_score = g_model.predict_proba(pred_doc_embeddings)[:,1] + expected_outcome_no_treat = q0_model.predict_proba(pred_doc_embeddings)[:,1] + expected_outcome_treat = q1_model.predict_proba(pred_doc_embeddings)[:,1] + + out = os.path.join(outdir, str(split)) + os.makedirs(out, exist_ok=True) + outfile = os.path.join(out, 'predictions') + np.savez_compressed(outfile, g=propensity_score, q0=expected_outcome_no_treat, q1=expected_outcome_treat, t=predict_treatment, y=predict_outcomes) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--dat-dir", action="store", default=None) + parser.add_argument("--outdir", action="store", default='../out/') + parser.add_argument("--sim-dir", action="store", default='../dat/sim/peerread_buzzytitle_based/') + parser.add_argument("--mode", action="store", default="simple") + parser.add_argument("--params", action="store", default="1.0") + parser.add_argument("--verbose", action='store_true') + parser.add_argument("--split", action='store', default=0) + parser.add_argument("--num-iters", 
action="store", default=3000) + parser.add_argument("--num-topics", action='store', default=100) + parser.add_argument("--linear-outcome-model", action='store', default="t") + parser.add_argument("--use-recon-loss", action='store', default="t") + parser.add_argument("--use-supervised-loss", action='store', default="t") + args = parser.parse_args() + + sim_dir = args.sim_dir + outdir = args.outdir + dat_dir = args.dat_dir + verbose = args.verbose + params = args.params + sim_setting = 'beta00.25' + '.beta1' + params + '.gamma0.0' + mode = args.mode + simulation_file = sim_dir + '/mode' + mode + '/' + sim_setting + ".tsv" + num_topics = args.num_topics + split = int(args.split) + linear_outcome_model = True if args.linear_outcome_model == "t" else False + use_supervised_loss = True if args.use_supervised_loss == "t" else False + use_recon_loss = True if args.use_recon_loss == "t" else False + num_iters = int(args.num_iters) + print(use_supervised_loss, use_recon_loss, linear_outcome_model) + + main() \ No newline at end of file diff --git a/src/supervised_lda/reddit_output_att.py b/src/supervised_lda/reddit_output_att.py new file mode 100644 index 0000000..18119d7 --- /dev/null +++ b/src/supervised_lda/reddit_output_att.py @@ -0,0 +1,153 @@ +from semi_parametric_estimation.att import att_estimates +from reddit.data_cleaning.reddit_posts import load_reddit_processed +from supervised_lda.helpers import filter_document_terms, make_index_mapping, assign_split, tokenize_documents +import numpy as np +import pandas as pd +import os +from supervised_lda.supervised_topic_model import SupervisedTopicModel +from sklearn.linear_model import LogisticRegression, Ridge +from supervised_lda import run_supervised_tm +from sklearn.metrics import mean_squared_error as mse +import argparse +import sys +from scipy.special import logit +from scipy import sparse + +def load_term_counts(reddit, path='../dat/reddit/', force_redo=False): + count_filename = path + 'term_counts' + vocab_filename = path + 'vocab' + + if os.path.exists(count_filename + '.npz') and not force_redo: + return sparse.load_npz(count_filename + '.npz').toarray(), np.load(vocab_filename + '.npy') + + post_docs = reddit['post_text'].values + counts, vocab, _ = tokenize_documents(post_docs) + sparse.save_npz(count_filename, counts) + np.save(vocab_filename, vocab) + return counts.toarray(), np.array(vocab) + +def load_simulated_data(): + sim_df = pd.read_csv(simulation_file, delimiter='\t') + sim_df = sim_df.rename(columns={'index':'post_index'}) + return sim_df + +def drop_empty_posts(counts): + doc_terms = counts.sum(axis=1) + return doc_terms >= 5 + +def fit_model(doc_embeddings, labels, is_binary=False): + if is_binary: + model = LogisticRegression(solver='liblinear') + else: + model = Ridge() + model.fit(doc_embeddings, labels) + return model + +def main(): + if dat_dir: + reddit = load_reddit_processed(path=dat_dir) + else: + reddit = load_reddit_processed() + + if subs: + reddit = reddit[reddit.subreddit.isin(subs)] + reddit = reddit.dropna(subset=['post_text']) + + + index_mapping = make_index_mapping(reddit, on='orig_index') + if not dat_dir: + counts, vocab = load_term_counts(reddit) + else: + counts, vocab = load_term_counts(reddit, path=dat_dir) + + sim_df = load_simulated_data() + + train_df = sim_df[sim_df.split != split] + predict_df = sim_df[sim_df.split == split] + + tr_treatment_labels = train_df.treatment.values + tr_outcomes = train_df.outcome.values + predict_treatment = predict_df.treatment.values + predict_outcomes = 
predict_df.outcome.values + + tr_counts = filter_document_terms(train_df, counts, index_mapping) + predict_counts = filter_document_terms(predict_df, counts, index_mapping) + tr_valid = drop_empty_posts(tr_counts) + pred_valid = drop_empty_posts(predict_counts) + tr_counts = tr_counts[tr_valid, :] + predict_counts = predict_counts[pred_valid, :] + + tr_treatment_labels = tr_treatment_labels[tr_valid] + tr_outcomes = tr_outcomes[tr_valid] + predict_treatment = predict_treatment[pred_valid] + predict_outcomes = predict_outcomes[pred_valid] + + num_documents = tr_counts.shape[0] + vocab_size = tr_counts.shape[1] + model = SupervisedTopicModel(num_topics, vocab_size, num_documents, outcome_linear_map=linear_outcome_model) + + run_supervised_tm.train(model, tr_counts, tr_treatment_labels, tr_outcomes, num_epochs=num_iters, use_recon_loss=use_recon_loss, use_sup_loss=use_supervised_loss) + + if use_supervised_loss: + propensity_score, expected_outcome_treat, expected_outcome_no_treat = run_supervised_tm.predict(model, predict_counts) + else: + tr_doc_embeddings = run_supervised_tm.get_representation(model, tr_counts) + treated = tr_treatment_labels == 1 + out_treat = tr_outcomes[treated] + out_no_treat = tr_outcomes[~treated] + q0_embeddings = tr_doc_embeddings[~treated,:] + q1_embeddings = tr_doc_embeddings[treated,:] + q0_model = fit_model(q0_embeddings, out_no_treat) + q1_model = fit_model(q1_embeddings, out_treat) + g_model = fit_model(tr_doc_embeddings, tr_treatment_labels, is_binary=True) + + pred_doc_embeddings = run_supervised_tm.get_representation(model, predict_counts) + propensity_score = g_model.predict_proba(pred_doc_embeddings)[:,1] + expected_outcome_no_treat = q0_model.predict(pred_doc_embeddings) + expected_outcome_treat = q1_model.predict(pred_doc_embeddings) + + out = os.path.join(outdir, str(split)) + os.makedirs(out, exist_ok=True) + outfile = os.path.join(out, 'predictions') + np.savez_compressed(outfile, g=propensity_score, q0=expected_outcome_no_treat, q1=expected_outcome_treat, t=predict_treatment, y=predict_outcomes) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--dat-dir", action="store", default=None) + parser.add_argument("--outdir", action="store", default='../out/') + parser.add_argument("--sim-dir", action="store", default='../dat/sim/reddit_subreddit_based/') + parser.add_argument("--subs", action="store", default='13,6,8') + parser.add_argument("--mode", action="store", default="simple") + parser.add_argument("--params", action="store", default="1.0,1.0,1.0") + parser.add_argument("--verbose", action='store_true') + parser.add_argument("--num-topics", action='store', default=100) + parser.add_argument("--split", action='store', default=0) + parser.add_argument("--num-iters", action="store", default=4000) + # parser.add_argument("--num_splits", action='store', default=10) + parser.add_argument("--linear-outcome-model", action='store', default="t") + parser.add_argument("--use-recon-loss", action='store', default="t") + parser.add_argument("--use-supervised-loss", action='store', default="t") + args = parser.parse_args() + + sim_dir = args.sim_dir + dat_dir = args.dat_dir + outdir = args.outdir + subs = None + if args.subs != '': + subs = [int(s) for s in args.subs.split(',')] + verbose = args.verbose + params = args.params.split(',') + sim_setting = 'beta0' + params[0] + '.beta1' + params[1] + '.gamma' + params[2] + subs_string = ', '.join(args.subs.split(',')) + mode = args.mode + simulation_file = sim_dir + 
'subreddits['+ subs_string + ']/mode' + mode + '/' + sim_setting + ".tsv" + num_iters = int(args.num_iters) + num_topics = int(args.num_topics) + split = int(args.split) + # num_splits = args.num_splits + linear_outcome_model = True if args.linear_outcome_model == "t" else False + use_supervised_loss = True if args.use_supervised_loss == "t" else False + use_recon_loss = True if args.use_recon_loss == "t" else False + + main() \ No newline at end of file diff --git a/src/supervised_lda/run_supervised_tm.py b/src/supervised_lda/run_supervised_tm.py new file mode 100644 index 0000000..c9388b1 --- /dev/null +++ b/src/supervised_lda/run_supervised_tm.py @@ -0,0 +1,95 @@ +from torch import nn, optim +from torch.nn import functional as F +import torch +# from torch.utils.tensorboard import SummaryWriter +import numpy as np +import argparse +from scipy.special import expit + +def visualize_topics(model, vocab, num_topics, num_words=10): + model.eval() + with torch.no_grad(): + print('#'*100) + print('Visualize topics...') + betas = model.alphas.t() #model.get_beta() + for k in range(num_topics): + beta = betas[k].detach().numpy() + top_words = beta.argsort()[-num_words:] + topic_words = vocab[top_words] + print('Topic {}: {}'.format(k, topic_words)) + +def get_representation(model, docs): + normalized = docs/docs.sum(axis=-1)[:,np.newaxis] + normalized_bow = torch.tensor(normalized, dtype=torch.float) + num_documents = docs.shape[0] + model.eval() + with torch.no_grad(): + doc_representation,_ = model.get_theta(normalized_bow) + embeddings = doc_representation.detach().numpy() + return embeddings + + +def predict(model, docs, dtype='real'): + normalized = docs/docs.sum(axis=-1)[:,np.newaxis] + normalized_bow = torch.tensor(normalized, dtype=torch.float) + num_documents = docs.shape[0] + + treatment_ones = torch.ones(num_documents) + treatment_zeros = torch.zeros(num_documents) + + model.eval() + with torch.no_grad(): + doc_representation,_ = model.get_theta(normalized_bow) + propensity_score = model.predict_treatment(doc_representation).squeeze().detach().numpy() + propensity_score = expit(propensity_score) + expected_outcome_treat = model.predict_outcome_st_treat(doc_representation, treatment_ones).squeeze().detach().numpy() + expected_outcome_no_treat = model.predict_outcome_st_no_treat(doc_representation, treatment_zeros).squeeze().detach().numpy() + + if dtype == 'binary': + expected_outcome_treat = expit(expected_outcome_treat) + expected_outcome_no_treat = expit(expected_outcome_no_treat) + + return propensity_score, expected_outcome_treat, expected_outcome_no_treat + +def train(model, docs, treatment_labels, outcomes, dtype='real', num_epochs=20000, lr=0.005, wdecay=1.2e-5,batch_size=1000, use_recon_loss=True, use_sup_loss=True): + optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wdecay) + num_documents = docs.shape[0] + indices = np.arange(num_documents) + np.random.shuffle(indices) + + for e_idx in range(num_epochs): + model.train() + k = e_idx%(num_documents//batch_size) + start_index = k*batch_size + end_index = (k+1)*batch_size + batch = indices[start_index:end_index] + docs_batch = docs[batch,:] + treatment_labels_batch = treatment_labels[batch] + outcomes_batch = outcomes[batch] + normalized_batch = docs_batch/docs_batch.sum(axis=1)[:,np.newaxis] + + outcome_labels = torch.tensor(outcomes_batch, dtype=torch.float) + treat_labels = torch.tensor(treatment_labels_batch, dtype=torch.float) + bow = torch.tensor(docs_batch, dtype=torch.float) + normalized_bow = 
torch.tensor(normalized_batch, dtype=torch.float) + + optimizer.zero_grad() + model.zero_grad() + + recon_loss, supervised_loss, kld_theta = model(bow, normalized_bow, treat_labels, outcome_labels,dtype=dtype, use_supervised_loss=use_sup_loss) + acc_kl_theta_loss = torch.sum(kld_theta).item() + acc_sup_loss = 0. + acc_loss = 0. + + total_loss = kld_theta #+ recon_loss + supervised_loss + if use_recon_loss: + acc_loss = torch.sum(recon_loss).item() + total_loss += 0.1*recon_loss + if use_sup_loss: + acc_sup_loss = torch.sum(supervised_loss).item() + total_loss += supervised_loss + + total_loss.backward() + optimizer.step() + + print("Acc. loss:", acc_loss, "KL loss.:", acc_kl_theta_loss, "Supervised loss:", acc_sup_loss) \ No newline at end of file diff --git a/src/supervised_lda/submit_scripts/peerread-exps/run_peerread_simulation.sh b/src/supervised_lda/submit_scripts/peerread-exps/run_peerread_simulation.sh new file mode 100755 index 0000000..07b2485 --- /dev/null +++ b/src/supervised_lda/submit_scripts/peerread-exps/run_peerread_simulation.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +#SBATCH -A sml +#SBATCH -c 8 +#SBATCH --mail-user=dhanya.sridhar@columbia.edu +#SBATCH --mail-type=ALL + +source activate py3.6 + +python -m supervised_lda.peerread_output_att \ +--dat-dir=${DIR} \ +--mode=${MODE} \ +--params=${BETA1} \ +--sim-dir=${SIMDIR} \ +--outdir=${OUT}/${BETA1} \ +--split=${SPLIT} \ +--linear-outcome-model=${LINOUTCOME} \ +--use-recon-loss=${RECONLOSS} \ +--use-supervised-loss=${SUPLOSS} \ \ No newline at end of file diff --git a/src/supervised_lda/submit_scripts/peerread-exps/submit_no_sup.sh b/src/supervised_lda/submit_scripts/peerread-exps/submit_no_sup.sh new file mode 100755 index 0000000..54e3cbe --- /dev/null +++ b/src/supervised_lda/submit_scripts/peerread-exps/submit_no_sup.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +BASE_OUT=/proj/sml_netapp/projects/causal-text/PeerRead/supervised_lda_baseline/out/ + +export DIR=/proj/sml_netapp/projects/causal-text/PeerRead/supervised_lda_baseline/proc/ +export SIMDIR=/proj/sml_netapp/projects/causal-text/sim/peerread_buzzytitle_based/ + +export MODE=simple +export LINOUTCOME=t +export RECONLOSS=t +export SUPLOSS=f + +declare -a BETA1S=(5.0) + +for BETA1j in "${BETA1S[@]}"; do + for SPLITi in $(seq 0 9); do + export BETA1=${BETA1j} + export SPLIT=${SPLITi} + export OUT=${BASE_OUT}/no_sup/ + sbatch --job-name=peerread_supervised_lda_sim_${BETA1j}_${SPLITi} \ + --output=peerread_supervised_lda_sim_${BETA1j}_${SPLITi}.out \ + supervised_lda/submit_scripts/peerread-exps/run_peerread_simulation.sh + done +done diff --git a/src/supervised_lda/submit_scripts/peerread-exps/submit_no_unsup.sh b/src/supervised_lda/submit_scripts/peerread-exps/submit_no_unsup.sh new file mode 100755 index 0000000..4dfa39f --- /dev/null +++ b/src/supervised_lda/submit_scripts/peerread-exps/submit_no_unsup.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +BASE_OUT=/proj/sml_netapp/projects/causal-text/PeerRead/supervised_lda_baseline/out/ + +export DIR=/proj/sml_netapp/projects/causal-text/PeerRead/supervised_lda_baseline/proc/ +export SIMDIR=/proj/sml_netapp/projects/causal-text/sim/peerread_buzzytitle_based/ + +export MODE=simple +export LINOUTCOME=t +export RECONLOSS=f +export SUPLOSS=t + +declare -a BETA1S=(1.0 5.0 25.0) + +for BETA1j in "${BETA1S[@]}"; do + for SPLITi in $(seq 0 9); do + export BETA1=${BETA1j} + export SPLIT=${SPLITi} + export OUT=${BASE_OUT}/no_unsup/ + sbatch --job-name=peerread_supervised_lda_sim_${BETA1j}_${SPLITi} \ + 
--output=peerread_supervised_lda_sim_${BETA1j}_${SPLITi}.out \ + supervised_lda/submit_scripts/peerread-exps/run_peerread_simulation.sh + done +done diff --git a/src/supervised_lda/submit_scripts/peerread-exps/submit_nonlinear.sh b/src/supervised_lda/submit_scripts/peerread-exps/submit_nonlinear.sh new file mode 100755 index 0000000..4c1f9bd --- /dev/null +++ b/src/supervised_lda/submit_scripts/peerread-exps/submit_nonlinear.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +BASE_OUT=/proj/sml_netapp/projects/causal-text/PeerRead/supervised_lda_baseline/out/ + +export DIR=/proj/sml_netapp/projects/causal-text/PeerRead/supervised_lda_baseline/proc/ +export SIMDIR=/proj/sml_netapp/projects/causal-text/sim/peerread_buzzytitle_based/ + +export MODE=simple +export LINOUTCOME=f +export RECONLOSS=t +export SUPLOSS=t + +declare -a BETA1S=(1.0 5.0 25.0) + +for BETA1j in "${BETA1S[@]}"; do + for SPLITi in $(seq 0 9); do + export BETA1=${BETA1j} + export SPLIT=${SPLITi} + export OUT=${BASE_OUT}/non_linear/ + sbatch --job-name=peerread_supervised_lda_sim_${BETA1j}_${SPLITi} \ + --output=peerread_supervised_lda_sim_${BETA1j}_${SPLITi}.out \ + supervised_lda/submit_scripts/peerread-exps/run_peerread_simulation.sh + done +done diff --git a/src/supervised_lda/submit_scripts/peerread-exps/submit_peerread_simulation.sh b/src/supervised_lda/submit_scripts/peerread-exps/submit_peerread_simulation.sh new file mode 100755 index 0000000..5a95019 --- /dev/null +++ b/src/supervised_lda/submit_scripts/peerread-exps/submit_peerread_simulation.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +BASE_OUT=/proj/sml_netapp/projects/causal-text/PeerRead/supervised_lda_baseline/out/ + +export DIR=/proj/sml_netapp/projects/causal-text/PeerRead/supervised_lda_baseline/proc/ +export SIMDIR=/proj/sml_netapp/projects/causal-text/sim/peerread_buzzytitle_based/ + +export MODE=simple +export LINOUTCOME=t +export RECONLOSS=t +export SUPLOSS=t + +declare -a BETA1S=(1.0 5.0 25.0) + +for BETA1j in "${BETA1S[@]}"; do + for SPLITi in $(seq 0 9); do + export BETA1=${BETA1j} + export SPLIT=${SPLITi} + export OUT=${BASE_OUT}/base_model/ + sbatch --job-name=peerread_supervised_lda_sim_${BETA1j}_${SPLITi} \ + --output=peerread_supervised_lda_sim_${BETA1j}_${SPLITi}.out \ + supervised_lda/submit_scripts/peerread-exps/run_peerread_simulation.sh + done +done diff --git a/src/supervised_lda/submit_scripts/reddit-exps/run_reddit_simulation.sh b/src/supervised_lda/submit_scripts/reddit-exps/run_reddit_simulation.sh new file mode 100755 index 0000000..f20de61 --- /dev/null +++ b/src/supervised_lda/submit_scripts/reddit-exps/run_reddit_simulation.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +#SBATCH -A sml +#SBATCH -c 8 +#SBATCH --mail-user=dhanya.sridhar@columbia.edu +#SBATCH --mail-type=ALL + +source activate py3.6 + +python -m supervised_lda.reddit_output_att \ +--dat-dir=${DIR} \ +--mode=${MODE} \ +--subs=${SUBS} \ +--params=${BETA0},${BETA1},${GAMMA} \ +--sim-dir=${SIMDIR} \ +--outdir=${OUT}/beta0${BETA0}.beta1${BETA1}.gamma${GAMMA} \ +--split=${SPLIT} \ +--linear-outcome-model=${LINOUTCOME} \ +--use-recon-loss=${RECONLOSS} \ +--use-supervised-loss=${SUPLOSS} \ + diff --git a/src/supervised_lda/submit_scripts/reddit-exps/submit_no_sup.sh b/src/supervised_lda/submit_scripts/reddit-exps/submit_no_sup.sh new file mode 100755 index 0000000..ce39c67 --- /dev/null +++ b/src/supervised_lda/submit_scripts/reddit-exps/submit_no_sup.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +BASE_OUT=/proj/sml_netapp/projects/causal-text/reddit/supervised_lda_baseline/out/ + +export 
DIR=/proj/sml_netapp/projects/causal-text/reddit/supervised_lda_baseline/proc/ +export SIMDIR=/proj/sml_netapp/projects/causal-text/sim/reddit_subreddit_based/ + +export MODE=simple +export SUBS=13,6,8 +export LINOUTCOME=t +export RECONLOSS=t +export SUPLOSS=f + +export BETA0=1.0 +declare -a BETA1S=(10.0) +declare -a GAMMAS=(1.0 4.0) + +for BETA1j in "${BETA1S[@]}"; do + export BETA1=${BETA1j} + for GAMMAj in "${GAMMAS[@]}"; do + for SPLITi in $(seq 0 4); do + export SPLIT=${SPLITi} + export GAMMA=${GAMMAj} + export OUT=${BASE_OUT}/no_sup/ + sbatch --job-name=reddit_supervised_lda_sim_${BETA1j}_${GAMMAj}_${SPLITi} \ + --output=reddit_supervised_lda_sim_${BETA1j}_${GAMMAj}_${SPLITi}.out \ + supervised_lda/submit_scripts/reddit-exps/run_reddit_simulation.sh + done + done +done diff --git a/src/supervised_lda/submit_scripts/reddit-exps/submit_no_unsup.sh b/src/supervised_lda/submit_scripts/reddit-exps/submit_no_unsup.sh new file mode 100755 index 0000000..506858c --- /dev/null +++ b/src/supervised_lda/submit_scripts/reddit-exps/submit_no_unsup.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +BASE_OUT=/proj/sml_netapp/projects/causal-text/reddit/supervised_lda_baseline/out/ + +export DIR=/proj/sml_netapp/projects/causal-text/reddit/supervised_lda_baseline/proc/ +export SIMDIR=/proj/sml_netapp/projects/causal-text/sim/reddit_subreddit_based/ + +export MODE=simple +export SUBS=13,6,8 +export LINOUTCOME=t +export RECONLOSS=f +export SUPLOSS=t + +export BETA0=1.0 +declare -a BETA1S=(1.0 10.0 100.0) +declare -a GAMMAS=(1.0 4.0) + +for BETA1j in "${BETA1S[@]}"; do + export BETA1=${BETA1j} + for GAMMAj in "${GAMMAS[@]}"; do + for SPLITi in $(seq 0 4); do + export SPLIT=${SPLITi} + export GAMMA=${GAMMAj} + export OUT=${BASE_OUT}/no_unsup/ + sbatch --job-name=reddit_supervised_lda_sim_${BETA1j}_${GAMMAj}_${SPLITi} \ + --output=reddit_supervised_lda_sim_${BETA1j}_${GAMMAj}_${SPLITi}.out \ + supervised_lda/submit_scripts/reddit-exps/run_reddit_simulation.sh + + done + done +done diff --git a/src/supervised_lda/submit_scripts/reddit-exps/submit_nonlinear.sh b/src/supervised_lda/submit_scripts/reddit-exps/submit_nonlinear.sh new file mode 100755 index 0000000..c7105f3 --- /dev/null +++ b/src/supervised_lda/submit_scripts/reddit-exps/submit_nonlinear.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +BASE_OUT=/proj/sml_netapp/projects/causal-text/reddit/supervised_lda_baseline/out/ + +export DIR=/proj/sml_netapp/projects/causal-text/reddit/supervised_lda_baseline/proc/ +export SIMDIR=/proj/sml_netapp/projects/causal-text/sim/reddit_subreddit_based/ + +export MODE=simple +export SUBS=13,6,8 +export LINOUTCOME=f +export RECONLOSS=t +export SUPLOSS=t + +export BETA0=1.0 +declare -a BETA1S=(1.0 10.0 100.0) +declare -a GAMMAS=(1.0 4.0) + +for BETA1j in "${BETA1S[@]}"; do + export BETA1=${BETA1j} + for GAMMAj in "${GAMMAS[@]}"; do + for SPLITi in $(seq 0 4); do + export SPLIT=${SPLITi} + export GAMMA=${GAMMAj} + export OUT=${BASE_OUT}/non_linear/ + sbatch --job-name=reddit_supervised_lda_sim_${BETA1j}_${GAMMAj}_${SPLITi} \ + --output=reddit_supervised_lda_sim_${BETA1j}_${GAMMAj}_${SPLITi}.out \ + supervised_lda/submit_scripts/reddit-exps/run_reddit_simulation.sh + done + done +done diff --git a/src/supervised_lda/submit_scripts/reddit-exps/submit_reddit_simulation.sh b/src/supervised_lda/submit_scripts/reddit-exps/submit_reddit_simulation.sh new file mode 100755 index 0000000..08a5292 --- /dev/null +++ b/src/supervised_lda/submit_scripts/reddit-exps/submit_reddit_simulation.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash 
+BASE_OUT=/proj/sml_netapp/projects/causal-text/reddit/supervised_lda_baseline/out/ + +export DIR=/proj/sml_netapp/projects/causal-text/reddit/supervised_lda_baseline/proc/ +export SIMDIR=/proj/sml_netapp/projects/causal-text/sim/reddit_subreddit_based/ + +export MODE=simple +export SUBS=13,6,8 +export LINOUTCOME=t +export RECONLOSS=t +export SUPLOSS=t + +export BETA0=1.0 +declare -a BETA1S=(1.0 10.0 100.0) +declare -a GAMMAS=(1.0 4.0) + +for BETA1j in "${BETA1S[@]}"; do + export BETA1=${BETA1j} + for GAMMAj in "${GAMMAS[@]}"; do + for SPLITi in $(seq 0 4); do + export SPLIT=${SPLITi} + export GAMMA=${GAMMAj} + export OUT=${BASE_OUT}/base_model/ + sbatch --job-name=reddit_supervised_lda_sim_${BETA1j}_${GAMMAj}_${SPLITi} \ + --output=reddit_supervised_lda_sim_${BETA1j}_${GAMMAj}_${SPLITi}.out \ + supervised_lda/submit_scripts/reddit-exps/run_reddit_simulation.sh + done + done +done diff --git a/src/supervised_lda/submit_scripts/reddit-exps/submit_reddit_test.sh b/src/supervised_lda/submit_scripts/reddit-exps/submit_reddit_test.sh new file mode 100755 index 0000000..acb0e92 --- /dev/null +++ b/src/supervised_lda/submit_scripts/reddit-exps/submit_reddit_test.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +BASE_OUT=/proj/sml_netapp/projects/causal-text/reddit/supervised_lda_baseline/out/ + +export DIR=/proj/sml_netapp/projects/causal-text/reddit/supervised_lda_baseline/proc/ +export SIMDIR=/proj/sml_netapp/projects/causal-text/sim/reddit_subreddit_based/ + +export MODE=simple +export SUBS=13,6,8 +export LINOUTCOME=True +export RECONLOSS=True +export SUPLOSS=True + +export BETA0=1.0 +declare -a BETA1S=(1.0) +declare -a GAMMAS=(1.0) + +for BETA1j in "${BETA1S[@]}"; do + export BETA1=${BETA1j} + for GAMMAj in "${GAMMAS[@]}"; do + for SPLITi in $(seq 0 1); do + export SPLIT=${SPLITi} + export GAMMA=${GAMMAj} + export OUT=${BASE_OUT}/base_model/ + sbatch --job-name=reddit_supervised_lda_sim_${BETA1j}_${GAMMAj}_${SPLITi} \ + --output=reddit_supervised_lda_sim_${BETA1j}_${GAMMAj}_${SPLITi}.out \ + supervised_lda/submit_scripts/reddit-exps/run_reddit_simulation.sh + done + done +done diff --git a/src/supervised_lda/supervised_topic_model.py b/src/supervised_lda/supervised_topic_model.py new file mode 100644 index 0000000..9944967 --- /dev/null +++ b/src/supervised_lda/supervised_topic_model.py @@ -0,0 +1,193 @@ +import torch +import torch.nn.functional as F +import numpy as np +import math + +from torch import nn + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +class SupervisedTopicModel(nn.Module): + def __init__(self, num_topics, vocab_size, num_documents, t_hidden_size=800, theta_act='relu', enc_drop=0., outcome_linear_map=True): + super(SupervisedTopicModel, self).__init__() + + ## define hyperparameters + self.num_topics = num_topics + self.vocab_size = vocab_size + self.num_documents = num_documents + self.t_hidden_size = t_hidden_size + self.enc_drop = enc_drop + self.t_drop = nn.Dropout(enc_drop) + self.theta_act = self.get_activation(theta_act) + self.outcome_linear_map = outcome_linear_map + + ## define the matrix containing the topic embeddings + self.alphas = nn.Parameter(torch.randn(vocab_size, num_topics)) + + if self.outcome_linear_map: + ## define linear regression weights for predicting expected outcomes for treated + self.w_expected_outcome_treated = nn.Linear(num_topics, 1) + + ## define linear regression weights for predicting expected outcomes for untreated + self.w_expected_outcome_untreated = nn.Linear(num_topics, 1) + else: + self.f_outcome_treated = 
nn.Sequential( + nn.Linear(num_topics, t_hidden_size), + self.theta_act, + # nn.BatchNorm1d(t_hidden_size), + nn.Linear(t_hidden_size, t_hidden_size), + self.theta_act, + # nn.BatchNorm1d(t_hidden_size), + nn.Linear(t_hidden_size,1) + ) + self.f_outcome_untreated = nn.Sequential( + nn.Linear(num_topics, t_hidden_size), + self.theta_act, + # nn.BatchNorm1d(t_hidden_size), + nn.Linear(t_hidden_size, t_hidden_size), + self.theta_act, + # nn.BatchNorm1d(t_hidden_size), + nn.Linear(t_hidden_size,1) + ) + ## define linear regression weights for predicting binary treatment label + self.w_treatment = nn.Linear(num_topics,1) + + self.q_theta = nn.Sequential( + nn.Linear(vocab_size, t_hidden_size), + self.theta_act, + nn.BatchNorm1d(t_hidden_size), + nn.Linear(t_hidden_size, t_hidden_size), + self.theta_act, + nn.BatchNorm1d(t_hidden_size) + ) + self.mu_q_theta = nn.Linear(t_hidden_size, num_topics) + self.logsigma_q_theta = nn.Linear(t_hidden_size, num_topics) + + def get_activation(self, act): + if act == 'tanh': + act = nn.Tanh() + elif act == 'relu': + act = nn.ReLU() + elif act == 'softplus': + act = nn.Softplus() + elif act == 'rrelu': + act = nn.RReLU() + elif act == 'leakyrelu': + act = nn.LeakyReLU() + elif act == 'elu': + act = nn.ELU() + elif act == 'selu': + act = nn.SELU() + elif act == 'glu': + act = nn.GLU() + else: + print('Defaulting to tanh activations...') + act = nn.Tanh() + return act + + def reparameterize(self, mu, logvar): + """Returns a sample from a Gaussian distribution via reparameterization. + """ + if self.training: + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps.mul_(std).add_(mu) + else: + return mu + + def encode(self, bows): + """Returns paramters of the variational distribution for \theta. + + input: bows + batch of bag-of-words...tensor of shape bsz x V + output: mu_theta, log_sigma_theta + """ + q_theta = self.q_theta(bows) + if self.enc_drop > 0: + q_theta = self.t_drop(q_theta) + mu_theta = self.mu_q_theta(q_theta) + logsigma_theta = self.logsigma_q_theta(q_theta) + kl_theta = -0.5 * torch.sum(1 + logsigma_theta - mu_theta.pow(2) - logsigma_theta.exp(), dim=-1).mean() + return mu_theta, logsigma_theta, kl_theta + + def get_beta(self): + beta = F.softmax(self.alphas, dim=0).transpose(1, 0) ## softmax over vocab dimension + return beta + + def get_theta(self, normalized_bows): + mu_theta, logsigma_theta, kld_theta = self.encode(normalized_bows) + z = self.reparameterize(mu_theta, logsigma_theta) + theta = F.softmax(z, dim=-1) + return theta, kld_theta + + def decode(self, theta, beta): + res = torch.mm(theta, beta) + preds = torch.log(res+1e-6) + return preds + + def predict_treatment(self, theta): + logits = self.w_treatment(theta) + return logits + + def predict_outcome_st_treat(self, theta, treatment_labels): + treated_indices = [treatment_labels == 1] + theta_treated = theta[treated_indices] + + if not self.outcome_linear_map: + expected_outcome_treated = self.f_outcome_treated(theta_treated) + else: + expected_outcome_treated = self.w_expected_outcome_treated(theta_treated) + + return expected_outcome_treated + + def predict_outcome_st_no_treat(self, theta, treatment_labels): + untreated_indices = [treatment_labels == 0] + theta_untreated = theta[untreated_indices] + + if not self.outcome_linear_map: + expected_outcome_untreated = self.f_outcome_untreated(theta_untreated) + else: + expected_outcome_untreated = self.w_expected_outcome_untreated(theta_untreated) + + return expected_outcome_untreated + + + def forward(self, bows, 
normalized_bows, treatment_labels, outcomes, dtype='real', use_supervised_loss=True):
+        ## get \theta
+        theta, kld_theta = self.get_theta(normalized_bows)
+        beta = self.get_beta()
+
+        bce_loss = nn.BCEWithLogitsLoss()
+        mse_loss = nn.MSELoss()
+
+        ## get reconstruction loss
+        preds = self.decode(theta, beta)
+        recon_loss = -(preds * bows).sum(1)
+        recon_loss = recon_loss.mean()
+
+        supervised_loss=None
+        if use_supervised_loss:
+
+            #get treatment loss
+            treatment_logits = self.predict_treatment(theta).squeeze()
+            treatment_loss = bce_loss(treatment_logits, treatment_labels)
+
+            #get expected outcome loss
+            treated = [treatment_labels == 1]
+            untreated = [treatment_labels == 0]
+            outcomes_treated = outcomes[treated]
+            outcomes_untreated = outcomes[untreated]
+            expected_treated = self.predict_outcome_st_treat(theta, treatment_labels).squeeze()
+            expected_untreated = self.predict_outcome_st_no_treat(theta, treatment_labels).squeeze()
+
+            #treated loss uses the treated predictions/outcomes; untreated loss uses the untreated ones
+            if dtype == 'real':
+                outcome_loss_treated = mse_loss(expected_treated, outcomes_treated)
+                outcome_loss_untreated = mse_loss(expected_untreated, outcomes_untreated)
+            else:
+                outcome_loss_treated = bce_loss(expected_treated, outcomes_treated)
+                outcome_loss_untreated = bce_loss(expected_untreated, outcomes_untreated)
+
+            supervised_loss = treatment_loss + outcome_loss_treated + outcome_loss_untreated
+
+        return recon_loss, supervised_loss, kld_theta
+
diff --git a/src/supervised_lda/test_slda.ipynb b/src/supervised_lda/test_slda.ipynb
new file mode 100644
index 0000000..439daa7
--- /dev/null
+++ b/src/supervised_lda/test_slda.ipynb
@@ -0,0 +1,870 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from helpers import filter_document_terms, make_index_mapping, assign_split\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import os\n",
+    "import sys\n",
+    "import supervised_topic_model \n",
+    "import run_supervised_tm\n",
+    "from scipy import sparse\n",
+    "from importlib import reload\n",
+    "\n",
+    "params=\"1.0\"\n",
+    "sim_dir = '../../dat/sim/peerread_buzzytitle_based/'\n",
+    "mode = 'simple'\n",
+    "sim_setting ='beta00.25' + '.beta1' + params + '.gamma0.0'\n",
+    "simulation_file = sim_dir + '/mode' + mode + '/' + sim_setting + \".tsv\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def load_peerread(path='../../dat/PeerRead/'):\n",
+    "\treturn pd.read_csv(path + 'proc_abstracts.csv')\n",
+    "\n",
+    "def load_term_counts(df, path='../../dat/PeerRead/', force_redo=False, text_col='abstract_text'):\n",
+    "\tcount_filename = path + 'term_counts'\n",
+    "\tvocab_filename = path + 'vocab'\n",
+    "\n",
+    "\tif os.path.exists(count_filename + '.npz') and not force_redo:\n",
+    "\t\treturn sparse.load_npz(count_filename + '.npz'), np.load(vocab_filename + '.npy')\n",
+    "\n",
+    "\tpost_docs = df[text_col].values\n",
+    "\tcounts, vocab, _ = tokenize_documents(post_docs) \n",
+    "\tsparse.save_npz(count_filename, counts)\n",
+    "\tnp.save(vocab_filename, vocab)\n",
+    "\treturn counts, np.array(vocab)\n",
+    "\n",
+    "def load_simulated_data():\n",
+    "\tsim_df = pd.read_csv(simulation_file, delimiter='\\t')\n",
+    "\treturn sim_df"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "peerread = load_peerread()\n",
+    "counts,vocab = load_term_counts(peerread)\n",
+    "counts= counts.toarray()\n",
+    "\n",
+    "indices = peerread['paper_id'].values\n",
+    "index_mapping = 
make_index_mapping(peerread, on='index')\n", + "\n", + "sim_df = load_simulated_data()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Acc. loss: 634.7957763671875 KL loss.: 0.04140656813979149 Supervised loss: 2.088380813598633\n", + "Acc. loss: 652.6654052734375 KL loss.: 0.1379348635673523 Supervised loss: 2.124856948852539\n", + "Acc. loss: 649.2124633789062 KL loss.: 0.13782715797424316 Supervised loss: 2.0561790466308594\n", + "Acc. loss: 628.5281372070312 KL loss.: 0.09775403887033463 Supervised loss: 2.0797927379608154\n", + "Acc. loss: 626.9488525390625 KL loss.: 0.10366755723953247 Supervised loss: 2.0787906646728516\n", + "Acc. loss: 632.6178588867188 KL loss.: 0.11328624188899994 Supervised loss: 2.074093818664551\n", + "Acc. loss: 650.6323852539062 KL loss.: 0.10288330912590027 Supervised loss: 2.066779613494873\n", + "Acc. loss: 647.1697998046875 KL loss.: 0.12123933434486389 Supervised loss: 2.0511393547058105\n", + "Acc. loss: 626.5287475585938 KL loss.: 0.1381053626537323 Supervised loss: 2.0768837928771973\n", + "Acc. loss: 625.0350952148438 KL loss.: 0.08887579292058945 Supervised loss: 2.0441465377807617\n", + "Acc. loss: 630.6015014648438 KL loss.: 0.14249682426452637 Supervised loss: 2.0655603408813477\n", + "Acc. loss: 648.536376953125 KL loss.: 0.1408989280462265 Supervised loss: 2.0474042892456055\n", + "Acc. loss: 645.245849609375 KL loss.: 0.11101733893156052 Supervised loss: 2.030522584915161\n", + "Acc. loss: 624.6386108398438 KL loss.: 0.13377906382083893 Supervised loss: 2.065744638442993\n", + "Acc. loss: 623.0485229492188 KL loss.: 0.1554994434118271 Supervised loss: 2.0334792137145996\n", + "Acc. loss: 628.7306518554688 KL loss.: 0.0850178599357605 Supervised loss: 2.0504908561706543\n", + "Acc. loss: 646.6339111328125 KL loss.: 0.1283387988805771 Supervised loss: 2.036609649658203\n", + "Acc. loss: 643.2108764648438 KL loss.: 0.1850668489933014 Supervised loss: 2.022549629211426\n", + "Acc. loss: 622.767333984375 KL loss.: 0.1347304880619049 Supervised loss: 2.052277088165283\n", + "Acc. loss: 621.0941772460938 KL loss.: 0.11989694088697433 Supervised loss: 2.0218236446380615\n", + "Acc. loss: 626.9766845703125 KL loss.: 0.1292686015367508 Supervised loss: 2.0399556159973145\n", + "Acc. loss: 644.6349487304688 KL loss.: 0.15476112067699432 Supervised loss: 2.0239899158477783\n", + "Acc. loss: 641.2636108398438 KL loss.: 0.16688869893550873 Supervised loss: 2.0074567794799805\n", + "Acc. loss: 620.8372802734375 KL loss.: 0.1618511974811554 Supervised loss: 2.045886516571045\n", + "Acc. loss: 619.4049072265625 KL loss.: 0.15865185856819153 Supervised loss: 2.0094056129455566\n", + "Acc. loss: 624.9852294921875 KL loss.: 0.11999165266752243 Supervised loss: 2.032895565032959\n", + "Acc. loss: 642.7385864257812 KL loss.: 0.14976756274700165 Supervised loss: 2.0141701698303223\n", + "Acc. loss: 639.3965454101562 KL loss.: 0.2086641788482666 Supervised loss: 1.9967474937438965\n", + "Acc. loss: 618.978271484375 KL loss.: 0.18051962554454803 Supervised loss: 2.036914348602295\n", + "Acc. loss: 617.625 KL loss.: 0.13743819296360016 Supervised loss: 2.000469923019409\n", + "Acc. loss: 623.190673828125 KL loss.: 0.13912740349769592 Supervised loss: 2.0222818851470947\n", + "Acc. loss: 640.8568115234375 KL loss.: 0.19409315288066864 Supervised loss: 2.0050911903381348\n", + "Acc. 
loss: 637.52783203125 KL loss.: 0.20685194432735443 Supervised loss: 1.9873762130737305\n", + "Acc. loss: 617.27490234375 KL loss.: 0.19531652331352234 Supervised loss: 2.026624917984009\n", + "Acc. loss: 615.7991333007812 KL loss.: 0.15136834979057312 Supervised loss: 1.9915733337402344\n", + "Acc. loss: 621.3892211914062 KL loss.: 0.16623076796531677 Supervised loss: 2.0139570236206055\n", + "Acc. loss: 638.9893188476562 KL loss.: 0.2075347751379013 Supervised loss: 1.99641752243042\n", + "Acc. loss: 635.6463623046875 KL loss.: 0.2234317660331726 Supervised loss: 1.9783962965011597\n", + "Acc. loss: 615.4528198242188 KL loss.: 0.21398229897022247 Supervised loss: 2.0211989879608154\n", + "Acc. loss: 614.0679321289062 KL loss.: 0.1836088001728058 Supervised loss: 1.9824151992797852\n", + "Acc. loss: 619.6111450195312 KL loss.: 0.18826153874397278 Supervised loss: 2.0062127113342285\n", + "Acc. loss: 637.2311401367188 KL loss.: 0.22428998351097107 Supervised loss: 1.9885168075561523\n", + "Acc. loss: 633.886962890625 KL loss.: 0.22329291701316833 Supervised loss: 1.9695569276809692\n", + "Acc. loss: 613.7763061523438 KL loss.: 0.2379865199327469 Supervised loss: 2.01389217376709\n", + "Acc. loss: 612.3335571289062 KL loss.: 0.22031542658805847 Supervised loss: 1.9752979278564453\n", + "Acc. loss: 617.9511108398438 KL loss.: 0.18600764870643616 Supervised loss: 2.000001907348633\n", + "Acc. loss: 635.3742065429688 KL loss.: 0.19648195803165436 Supervised loss: 1.9806807041168213\n", + "Acc. loss: 631.9808959960938 KL loss.: 0.3065944015979767 Supervised loss: 1.9617515802383423\n", + "Acc. loss: 611.978759765625 KL loss.: 0.2528259754180908 Supervised loss: 2.0059690475463867\n", + "Acc. loss: 610.5684204101562 KL loss.: 0.21446384489536285 Supervised loss: 1.9677181243896484\n", + "Acc. loss: 616.2669677734375 KL loss.: 0.20624415576457977 Supervised loss: 1.9931845664978027\n", + "Acc. loss: 633.7069091796875 KL loss.: 0.21131473779678345 Supervised loss: 1.974545955657959\n", + "Acc. loss: 630.298828125 KL loss.: 0.28980910778045654 Supervised loss: 1.955168604850769\n", + "Acc. loss: 610.197509765625 KL loss.: 0.36060264706611633 Supervised loss: 2.0021581649780273\n", + "Acc. loss: 608.9459228515625 KL loss.: 0.2834569811820984 Supervised loss: 1.9608267545700073\n", + "Acc. loss: 614.638671875 KL loss.: 0.201319620013237 Supervised loss: 1.990837812423706\n", + "Acc. loss: 631.9982299804688 KL loss.: 0.18684402108192444 Supervised loss: 1.968562126159668\n", + "Acc. loss: 628.7250366210938 KL loss.: 0.2703857719898224 Supervised loss: 1.9504163265228271\n", + "Acc. loss: 608.5803833007812 KL loss.: 0.43809598684310913 Supervised loss: 1.9968599081039429\n", + "Acc. loss: 607.1255493164062 KL loss.: 0.3968292474746704 Supervised loss: 1.957007884979248\n", + "Acc. loss: 612.9908447265625 KL loss.: 0.2461068034172058 Supervised loss: 1.9847311973571777\n", + "Acc. loss: 630.322509765625 KL loss.: 0.17726734280586243 Supervised loss: 1.963417410850525\n", + "Acc. loss: 627.0383911132812 KL loss.: 0.19822846353054047 Supervised loss: 1.9452910423278809\n", + "Acc. loss: 607.103271484375 KL loss.: 0.2928427755832672 Supervised loss: 1.991325855255127\n", + "Acc. loss: 605.5984497070312 KL loss.: 0.4701569378376007 Supervised loss: 1.9508798122406006\n", + "Acc. loss: 611.1455688476562 KL loss.: 0.4681859016418457 Supervised loss: 1.978605031967163\n", + "Acc. loss: 628.6463012695312 KL loss.: 0.3057301342487335 Supervised loss: 1.9586715698242188\n", + "Acc. 
loss: 625.5117797851562 KL loss.: 0.22863595187664032 Supervised loss: 1.9377126693725586\n",
+      "[... repetitive per-batch training log trimmed: Acc. loss falls from ~625 to ~567, KL loss grows from ~0.2 to ~3.3, and the supervised loss drifts from ~1.94 down to roughly 1-1.8 before the run is interrupted ...]\n",
+      "Acc. 
loss: 567.0971069335938 KL loss.: 3.3121800422668457 Supervised loss: 1.7741986513137817\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mreload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msupervised_topic_model\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msupervised_topic_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSupervisedTopicModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_topics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvocab_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_documents\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutcome_linear_map\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mrun_supervised_tm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msubset_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtreatment_labels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutcomes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'binary'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/Documents/causal-effects-text/src/supervised_lda/run_supervised_tm.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(model, docs, treatment_labels, outcomes, dtype, num_epochs, lr, wdecay, batch_size, use_recon_loss, use_sup_loss)\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0mend_index\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstart_index\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mend_index\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 66\u001b[0;31m \u001b[0mdocs_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdocs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 67\u001b[0m \u001b[0mtreatment_labels_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtreatment_labels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0moutcomes_batch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutcomes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "bootstrap_sim_df = assign_split(sim_df, num_splits=2)\n", + "bootstrap_sim_df = 
bootstrap_sim_df[bootstrap_sim_df.split==0]\n", + "treatment_labels = bootstrap_sim_df.treatment.values\n", + "outcomes = bootstrap_sim_df.outcome.values\n", + "\n", + "subset_counts = filter_document_terms(bootstrap_sim_df, counts, index_mapping, on='id')\n", + "num_documents = subset_counts.shape[0]\n", + "vocab_size = subset_counts.shape[1]\n", + "num_topics=100\n", + "\n", + "reload(run_supervised_tm)\n", + "reload(supervised_topic_model)\n", + "model = supervised_topic_model.SupervisedTopicModel(num_topics, vocab_size, num_documents, outcome_linear_map=False)\n", + "run_supervised_tm.train(model, subset_counts, treatment_labels, outcomes, dtype='binary', num_epochs=1000)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "####################################################################################################\n", + "Visualize topics...\n", + "Topic 0: ['generative' 'taking' 'inject' 'x' 'computer' 'subtle' 'concatenation'\n", + " 'proof' 'include' 'demonstrate']\n", + "Topic 1: ['subset' 'dimensional' 'performance' 'combine' 'yahoo' 'action'\n", + " 'rationale' 'entity' 'label' 'reduction']\n", + "Topic 2: ['third' 'semantics' 'adaptive' 'region' 'improve' 'computer' 'preferred'\n", + " 'thoroughly' 'popular' 'extensive']\n", + "Topic 3: ['probabilistic' 'decision' 'form' 'covered' 'efficiency' 'thereof'\n", + " 'successfully' 'flow' 'completely' 'datasets']\n", + "Topic 4: ['sound' 'used' 'inexpensive' 'become' 'one' 'boosting' 'kind' 'different'\n", + " 'enough' 'death']\n", + "Topic 5: ['previous' 'passage' 'tuned' 'various' 'effect' 'interpretation' 'prone'\n", + " 'required' 'node' 'upper']\n", + "Topic 6: ['digit' 'transportation' 'weak' 'filtering' 'neural' 'flow' 'attachment'\n", + " 'principle' 'entity' 'plausibility']\n", + "Topic 7: ['fixed' 'caused' 'without' 'accessing' 'supervision' 'learn' 'descent'\n", + " 'scenario' 'smart' 'contain']\n", + "Topic 8: ['support' 'applying' 'accomplished' 'find' 'tabular' 'case' 'meaning'\n", + " 'subset' 'production' 'solution']\n", + "Topic 9: ['class' 'via' 'submitted' 'approach' 'achieved' 'vulnerable' 'attribute'\n", + " 'analytical' 'human' 'ineffective']\n", + "Topic 10: ['loss' 'report' 'maintain' 'arrangement' 'perceptrons' 'either'\n", + " 'continuous' 'notion' 'enables' 'well']\n", + "Topic 11: ['friend' 'duration' 'shifting' 'quality' 'identify' 'subsampling'\n", + " 'uncertainty' 'article' 'fraction' 'propose']\n", + "Topic 12: ['social' 'signal' 'know' 'detection' 'logarithmic' 'adversarial'\n", + " 'abstract' 'functionality' 'correcting' 'expected']\n", + "Topic 13: ['kernel' 'equality' 'perceptual' 'yahoo' 'partial' 'known' 'complete'\n", + " 'version' 'language' 'automatic']\n", + "Topic 14: ['spatial' 'probabilistic' 'added' 'embeddings' 'vector' 'negatively'\n", + " 'year' 'solution' 'regularizer' 'role']\n", + "Topic 15: ['independence' 'given' 'comparison' 'ehr' 'like' 'likelihood' 'currently'\n", + " 'apply' 'lexical' 'interdependent']\n", + "Topic 16: ['auction' 'weight' 'measure' 'found' 'previous' 'approach' 'work'\n", + " 'theoretical' 'struggle' 'provides']\n", + "Topic 17: ['compatible' 'way' 'respective' 'transforming' 'fact' 'structure'\n", + " 'reason' 'relevance' 'application' 'animal']\n", + "Topic 18: ['hybrid' 'without' 'often' 'paper' 'supervision' 'tedious' 'motion'\n", + " 'therefore' 'topic' 'optimal']\n", + "Topic 19: ['mind' 'particularly' 
'social' 'identify' 'benefit' 'logic' 'spent'\n",
+      " 'individual' 'token' 'corpus']\n",
+      "[... topics 20-99 trimmed: each remaining topic is a similar list of its ten highest-probability words ...]\n"
+     ]
+    }
+   ],
+   "source": [
+    
"reload(run_supervised_tm)\n", + "run_supervised_tm.visualize_topics(model, vocab,num_topics)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[0 0 0 ... 0 0 0]\n", + " [0 0 0 ... 0 0 0]\n", + " [0 0 0 ... 0 0 0]\n", + " ...\n", + " [0 0 1 ... 0 0 0]\n", + " [0 0 0 ... 0 0 0]\n", + " [0 0 0 ... 0 0 0]]\n", + "5862\n", + "tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " ...,\n", + " [0.0000, 0.0000, 0.0071, ..., 0.0000, 0.0000, 0.0000],\n", + " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n", + " [0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]])\n", + "tensor([[0.0087, 0.0088, 0.0100, ..., 0.0124, 0.0143, 0.0096],\n", + " [0.0072, 0.0077, 0.0066, ..., 0.0068, 0.0078, 0.0072],\n", + " [0.0234, 0.0096, 0.0096, ..., 0.0107, 0.0101, 0.0095],\n", + " ...,\n", + " [0.0030, 0.0034, 0.0033, ..., 0.0031, 0.0036, 0.0036],\n", + " [0.0036, 0.0033, 0.0038, ..., 0.0037, 0.0041, 0.0034],\n", + " [0.0097, 0.0081, 0.0090, ..., 0.0077, 0.0081, 0.0080]])\n" + ] + } + ], + "source": [ + "reload(run_supervised_tm)\n", + "propensity_scores, expected_st_treat, expected_st_no_treat = run_supervised_tm.predict(model, subset_counts, dtype='binary')" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "def psi_q_only(q_t0, q_t1, g, t, y):\n", + " ite_t = (q_t1 - q_t0)[t == 1]\n", + " estimate = ite_t.mean()\n", + " return estimate" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0.013394212" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "qhat = psi_q_only(expected_st_no_treat, expected_st_treat, propensity_scores, treatment_labels, outcomes)\n", + "qhat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}