In [1]:
from helpers import filter_document_terms, make_index_mapping, assign_split
import numpy as np
import pandas as pd
import os
import sys
import supervised_topic_model 
import run_supervised_tm
from scipy import sparse
from importlib import reload

# Simulation configuration: which simulated-outcome file this notebook reads.
# `params` is spliced into the setting name right after ".beta1".
params = "1.0"
sim_dir = '../../dat/sim/peerread_buzzytitle_based/'
mode = 'simple'
sim_setting = 'beta00.25.beta1{}.gamma0.0'.format(params)
# sim_dir already ends with '/', so the resulting path contains '//' before
# 'mode<mode>'; this is kept exactly as the original builds it.
simulation_file = '{}/mode{}/{}.tsv'.format(sim_dir, mode, sim_setting)

In [2]:
def load_peerread(path='../../dat/PeerRead/'):
	"""Read the processed PeerRead abstracts table.

	Args:
		path: Directory containing ``proc_abstracts.csv``. A trailing path
			separator is optional.

	Returns:
		The ``pandas.DataFrame`` parsed from ``proc_abstracts.csv``.
	"""
	# os.path.join rather than string concatenation, so a directory passed
	# without a trailing separator still resolves to the intended file.
	return pd.read_csv(os.path.join(path, 'proc_abstracts.csv'))

def load_term_counts(df, path='../../dat/PeerRead/', force_redo=False, text_col='abstract_text'):
	"""Load the cached term-count matrix and vocabulary, or build and cache them.

	Args:
		df: Table whose ``text_col`` column holds the document texts. Only
			read when the cache is unusable or ``force_redo`` is set.
		path: Directory holding ``term_counts.npz`` and ``vocab.npy``. A
			trailing path separator is optional.
		force_redo: If True, ignore any cached files and recompute.
		text_col: Name of the text column in ``df``.

	Returns:
		Tuple ``(counts, vocab)``: the sparse term-count matrix and the
		vocabulary as a numpy array.
	"""
	count_filename = os.path.join(path, 'term_counts.npz')
	vocab_filename = os.path.join(path, 'vocab.npy')

	# Both files must be present for the cache to be usable; if only one
	# exists we recompute instead of failing inside np.load / load_npz.
	cache_ready = os.path.exists(count_filename) and os.path.exists(vocab_filename)
	if cache_ready and not force_redo:
		return sparse.load_npz(count_filename), np.load(vocab_filename)

	# tokenize_documents is not imported at the top of the notebook, so the
	# cache-miss path used to raise NameError.
	# NOTE(review): assumed to live in `helpers` next to the other helper
	# functions this notebook imports -- confirm.
	from helpers import tokenize_documents

	post_docs = df[text_col].values
	counts, vocab, _ = tokenize_documents(post_docs)
	sparse.save_npz(count_filename, counts)
	np.save(vocab_filename, vocab)
	return counts, np.array(vocab)

def load_simulated_data():
	"""Read the tab-separated simulation table named by the module-level ``simulation_file``.

	Returns:
		The ``pandas.DataFrame`` parsed from that file.
	"""
	return pd.read_csv(simulation_file, sep='\t')

In [3]:
# Raw inputs: the PeerRead abstracts table and the simulated data table.
peerread = load_peerread()
sim_df = load_simulated_data()

# Term counts come back sparse; convert to a dense array for the steps below.
counts, vocab = load_term_counts(peerread)
counts = counts.toarray()

# Lookup structures derived from the abstracts table.
index_mapping = make_index_mapping(peerread, on='index')
indices = peerread['paper_id'].values

In [6]:
# Assign every simulated row to one of two splits and keep only split 0.
# NOTE(review): assign_split is presumably randomized and no seed is set in
# this notebook -- confirm if run-to-run reproducibility matters.
bootstrap_sim_df = assign_split(sim_df, num_splits=2)
bootstrap_sim_df = bootstrap_sim_df[bootstrap_sim_df.split==0]
treatment_labels = bootstrap_sim_df.treatment.values
outcomes = bootstrap_sim_df.outcome.values

# Term counts restricted to the selected rows; its two dimensions give the
# document count and vocabulary size passed to the model constructor.
subset_counts = filter_document_terms(bootstrap_sim_df, counts, index_mapping, on='id')
num_documents = subset_counts.shape[0]
vocab_size = subset_counts.shape[1]
num_topics=100

# Re-import the local modules before building the model so that edits made to
# them since the kernel started are picked up.
reload(run_supervised_tm)
reload(supervised_topic_model)
model = supervised_topic_model.SupervisedTopicModel(num_topics, vocab_size, num_documents, outcome_linear_map=False)
# The saved output of this cell ends in KeyboardInterrupt, i.e. the recorded
# run was stopped by hand before the requested number of epochs finished.
run_supervised_tm.train(model, subset_counts, treatment_labels, outcomes, dtype='binary', num_epochs=1000)



Acc. loss: 634.7957763671875 KL loss.: 0.04140656813979149 Supervised loss: 2.088380813598633
Acc. loss: 652.6654052734375 KL loss.: 0.1379348635673523 Supervised loss: 2.124856948852539
Acc. loss: 649.2124633789062 KL loss.: 0.13782715797424316 Supervised loss: 2.0561790466308594
Acc. loss: 628.5281372070312 KL loss.: 0.09775403887033463 Supervised loss: 2.0797927379608154
Acc. loss: 626.9488525390625 KL loss.: 0.10366755723953247 Supervised loss: 2.0787906646728516
Acc. loss: 632.6178588867188 KL loss.: 0.11328624188899994 Supervised loss: 2.074093818664551
Acc. loss: 650.6323852539062 KL loss.: 0.10288330912590027 Supervised loss: 2.066779613494873
Acc. loss: 647.1697998046875 KL loss.: 0.12123933434486389 Supervised loss: 2.0511393547058105
Acc. loss: 626.5287475585938 KL loss.: 0.1381053626537323 Supervised loss: 2.0768837928771973
Acc. loss: 625.0350952148438 KL loss.: 0.08887579292058945 Supervised loss: 2.0441465377807617
Acc. loss: 630.6015014648438 KL loss.: 0.142496824264526

Acc. loss: 599.6876220703125 KL loss.: 0.38490843772888184 Supervised loss: 1.9805517196655273
Acc. loss: 598.4629516601562 KL loss.: 0.41774091124534607 Supervised loss: 1.9279911518096924
Acc. loss: 603.804443359375 KL loss.: 0.48628202080726624 Supervised loss: 1.9626843929290771
Acc. loss: 620.818115234375 KL loss.: 0.5278497934341431 Supervised loss: 1.941227674484253
Acc. loss: 617.5929565429688 KL loss.: 0.48286446928977966 Supervised loss: 1.9193298816680908
Acc. loss: 598.300048828125 KL loss.: 0.382050484418869 Supervised loss: 1.9763628244400024
Acc. loss: 597.0465087890625 KL loss.: 0.39224833250045776 Supervised loss: 1.9261529445648193
Acc. loss: 602.426513671875 KL loss.: 0.4795830547809601 Supervised loss: 1.961333155632019
Acc. loss: 619.4287719726562 KL loss.: 0.5574470162391663 Supervised loss: 1.9386060237884521
Acc. loss: 616.1077880859375 KL loss.: 0.5500497817993164 Supervised loss: 1.9145116806030273
Acc. loss: 596.890869140625 KL loss.: 0.48524120450019836 Supe

Acc. loss: 600.3922729492188 KL loss.: 0.7509818077087402 Supervised loss: 1.922802209854126
Acc. loss: 597.365234375 KL loss.: 0.8456413149833679 Supervised loss: 1.8953893184661865
Acc. loss: 578.6476440429688 KL loss.: 0.9063360691070557 Supervised loss: 1.9655643701553345
Acc. loss: 577.5272216796875 KL loss.: 0.9557400345802307 Supervised loss: 1.9087704420089722
Acc. loss: 583.2100830078125 KL loss.: 0.758919358253479 Supervised loss: 1.9502391815185547
Acc. loss: 599.614013671875 KL loss.: 0.6773056983947754 Supervised loss: 1.920617938041687
Acc. loss: 596.4628295898438 KL loss.: 0.7548312544822693 Supervised loss: 1.8950583934783936
Acc. loss: 577.8541870117188 KL loss.: 0.8534181714057922 Supervised loss: 1.964996099472046
Acc. loss: 576.6302490234375 KL loss.: 0.9545819759368896 Supervised loss: 1.9079914093017578
Acc. loss: 582.0859985351562 KL loss.: 0.9591970443725586 Supervised loss: 1.9496138095855713
Acc. loss: 598.5098876953125 KL loss.: 0.8687512278556824 Supervised 

Acc. loss: 570.0428466796875 KL loss.: 1.2529869079589844 Supervised loss: 1.9424664974212646
Acc. loss: 585.8638305664062 KL loss.: 1.1537368297576904 Supervised loss: 1.917574405670166
Acc. loss: 582.84228515625 KL loss.: 1.3180943727493286 Supervised loss: 1.881895899772644
Acc. loss: 564.8743286132812 KL loss.: 1.3374701738357544 Supervised loss: 1.9691131114959717
Acc. loss: 563.9944458007812 KL loss.: 1.4079235792160034 Supervised loss: 1.8945671319961548
Acc. loss: 569.2296752929688 KL loss.: 1.4860103130340576 Supervised loss: 1.9335544109344482
Acc. loss: 585.0882568359375 KL loss.: 1.3481898307800293 Supervised loss: 1.9167640209197998
Acc. loss: 582.326904296875 KL loss.: 1.302163004875183 Supervised loss: 1.8882637023925781
Acc. loss: 564.4805908203125 KL loss.: 1.1653491258621216 Supervised loss: 1.9626132249832153
Acc. loss: 563.4149169921875 KL loss.: 1.4037744998931885 Supervised loss: 1.899533748626709
Acc. loss: 568.4398803710938 KL loss.: 1.5649011135101318 Supervise

Acc. loss: 554.577392578125 KL loss.: 2.248950481414795 Supervised loss: 1.7718029022216797
Acc. loss: 559.764892578125 KL loss.: 2.209470272064209 Supervised loss: 1.7533116340637207
Acc. loss: 574.5811767578125 KL loss.: 2.7022790908813477 Supervised loss: 1.639876127243042
Acc. loss: 571.3305053710938 KL loss.: 3.155371904373169 Supervised loss: 1.607314944267273
Acc. loss: 554.3609008789062 KL loss.: 2.7596218585968018 Supervised loss: 1.7376245260238647
Acc. loss: 553.9786987304688 KL loss.: 2.4950077533721924 Supervised loss: 1.637010097503662
Acc. loss: 559.1825561523438 KL loss.: 2.4872887134552 Supervised loss: 1.719037652015686
Acc. loss: 574.06689453125 KL loss.: 2.772099018096924 Supervised loss: 1.4953625202178955
Acc. loss: 570.8015747070312 KL loss.: 3.13488507270813 Supervised loss: 1.5272738933563232
Acc. loss: 553.8721313476562 KL loss.: 2.79630708694458 Supervised loss: 1.6142207384109497
Acc. loss: 553.0410766601562 KL loss.: 2.8367996215820312 Supervised loss: 1.50

KeyboardInterrupt: 

In [5]:
# Reload the training module so the latest visualize_topics is used, then
# display the topics of the model trained above using the loaded vocabulary.
reload(run_supervised_tm)
run_supervised_tm.visualize_topics(model, vocab,num_topics)

####################################################################################################
Visualize topics...
Topic 0: ['generative' 'taking' 'inject' 'x' 'computer' 'subtle' 'concatenation'
 'proof' 'include' 'demonstrate']
Topic 1: ['subset' 'dimensional' 'performance' 'combine' 'yahoo' 'action'
 'rationale' 'entity' 'label' 'reduction']
Topic 2: ['third' 'semantics' 'adaptive' 'region' 'improve' 'computer' 'preferred'
 'thoroughly' 'popular' 'extensive']
Topic 3: ['probabilistic' 'decision' 'form' 'covered' 'efficiency' 'thereof'
 'successfully' 'flow' 'completely' 'datasets']
Topic 4: ['sound' 'used' 'inexpensive' 'become' 'one' 'boosting' 'kind' 'different'
 'enough' 'death']
Topic 5: ['previous' 'passage' 'tuned' 'various' 'effect' 'interpretation' 'prone'
 'required' 'node' 'upper']
Topic 6: ['digit' 'transportation' 'weak' 'filtering' 'neural' 'flow' 'attachment'
 'principle' 'entity' 'plausibility']
Topic 7: ['fixed' 'caused' 'without' 'accessing' 'supervision' 'lea

In [19]:
# Reload the training module, then run prediction on the same term counts the
# model was trained on. The three returned values are consumed by the
# estimator cell below as propensity scores and the expected outcomes with
# and without treatment.
reload(run_supervised_tm)
propensity_scores, expected_st_treat, expected_st_no_treat = run_supervised_tm.predict(model, subset_counts, dtype='binary')

[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 1 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]]
5862
tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0071,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])
tensor([[0.0087, 0.0088, 0.0100,  ..., 0.0124, 0.0143, 0.0096],
        [0.0072, 0.0077, 0.0066,  ..., 0.0068, 0.0078, 0.0072],
        [0.0234, 0.0096, 0.0096,  ..., 0.0107, 0.0101, 0.0095],
        ...,
        [0.0030, 0.0034, 0.0033,  ..., 0.0031, 0.0036, 0.0036],
        [0.0036, 0.0033, 0.0038,  ..., 0.0037, 0.0041, 0.0034],
        [0.0097, 0.0081, 0.0090,  ..., 0.0077, 0.0081, 0.0080]])


In [20]:
def psi_q_only(q_t0, q_t1, g, t, y):
    """Average of ``q_t1 - q_t0`` over the entries where ``t == 1``.

    Args:
        q_t0: Predicted outcomes without treatment.
        q_t1: Predicted outcomes with treatment.
        g: Unused by this estimator.
        t: Treatment indicators; entries equal to 1 select the rows averaged.
        y: Unused by this estimator.

    Returns:
        The mean of the selected differences.
    """
    treated = t == 1
    return (q_t1 - q_t0)[treated].mean()

In [21]:
# Estimate from the model's predicted outcomes; arguments are passed by
# keyword so the no-treatment / treatment order is explicit.
qhat = psi_q_only(
    q_t0=expected_st_no_treat,
    q_t1=expected_st_treat,
    g=propensity_scores,
    t=treatment_labels,
    y=outcomes,
)
qhat

0.013394212