# Ablation Experiment
### This notebook runs the ablation experiment (varying the number of topics and the number of seed words) and compares GuidedNMF against SeededLDA.

In [1]:
import numpy as np
import csv
import sys
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
sys.path.append("./SeededLDA/src")

import matplotlib.pyplot as plt

In [3]:
# Ground-truth newsgroup label for each document; the evaluation cells build
# one-vs-rest targets from it (presumably aligned with the topic models'
# document order — TODO confirm).
Y = np.load("./data/newsgroup_labels.npy", allow_pickle=False)

### Load SeededLDA results produced by the C++ code

In [4]:
# AUC accumulator indexed [number of topics - 4, number of seed words - 1, seeded topic].
SeededLDA_acc = np.zeros(shape=(7, 8, 3), dtype=np.float64)

In [6]:
# SeededLDA rows 0, 1, 2 are scored against documents labelled 4, 5, 6 in Y.
topic_order = [0, 1, 2]
doc_orders = [4, 5, 6]

# Display names of the three seeded topics, in topic_order order.
names = ["Baseball", "Medical", "Space"]

for num_words in range(1, 9):
    for r in range(4, 11):
        # After .T, rows index topics and columns index documents.
        topic = np.loadtxt(
            "./seeded_lda/ablation_results/SeededLDA_docTopicDist_{}_{}.txt".format(r, num_words),
            delimiter="\t",
        ).T

        for i, (topic_idx, label) in enumerate(zip(topic_order, doc_orders)):
            y_pred = topic[topic_idx]
            # One-vs-rest target: 1 for documents whose label is `label`.
            y_true = np.zeros(y_pred.shape[0])
            y_true[Y == label] = 1
            fpr, tpr, _ = roc_curve(y_true, y_pred)
            SeededLDA_acc[r - 4, num_words - 1, i] = auc(fpr, tpr)

In [None]:
"""
plt.imshow(SeededLDA_acc[:,:,0], vmin=0.5, vmax=1)
plt.title("Baseball", fontsize=16)
plt.yticks(range(0,7),range(4,11), fontsize=14)
plt.xticks(range(0,8),range(1,9), fontsize=14)
plt.ylabel("Number of Topics", fontsize=14)
plt.xlabel("Number of Seed Words", fontsize=14)
plt.show()

plt.imshow(SeededLDA_acc[:,:,1], vmin=0.5, vmax=1)
plt.title("Medical", fontsize=16)
plt.yticks(range(0,7),range(4,11), fontsize=14)
plt.xticks(range(0,8),range(1,9), fontsize=14)
plt.ylabel("Number of Topics", fontsize=14)
plt.xlabel("Number of Seed Words", fontsize=14)
plt.ylabel("Number of Topics")
plt.show()

plt.imshow(SeededLDA_acc[:,:,2], vmin=0.5, vmax=1)
plt.title("Space", fontsize=16)
plt.yticks(range(0,7),range(4,11), fontsize=14)
plt.xticks(range(0,8),range(1,9), fontsize=14)
plt.ylabel("Number of Topics", fontsize=14)
plt.xlabel("Number of Seed Words", fontsize=14)
plt.ylabel("Number of Topics")
plt.show()
"""

### Load GuidedNMF results generated by newsgroup.ipynb

In [None]:
# GuidedNMF models saved by newsgroup.ipynb, indexed below as model_nmf[r - 4][k]
# (presumably an object array, hence allow_pickle). Unpickling can run arbitrary
# code, so only load this locally generated, trusted file.
model_nmf = np.load(file="guidednmf_models.npy", allow_pickle=True)

In [None]:
# AUC accumulator indexed [number of topics - 4, seed-word model index, seeded topic].
GuidedNMF_acc = np.zeros(shape=(7, 8, 3), dtype=np.float64)

In [None]:
# Documents per label. The GuidedNMF document axis is assumed to be sorted by
# label in contiguous runs of `sub` documents — TODO confirm against
# newsgroup.ipynb (the SeededLDA cell builds targets from Y instead).
sub = 200
# Labels the seeded topics are scored against, in seed order.
doc_orders = [4, 5, 6]

for r in range(4, 11):
    # Assumes model index k was trained with k + 1 seed words — TODO confirm.
    for num_words, model in enumerate(model_nmf[r - 4]):
        # Rows index latent topics, columns index documents.
        S = model.A.T

        # For each seed row of B, the latent topic it weights most heavily.
        topic_order = [np.argmax(row) for row in model.B]
        if len(topic_order) != len(doc_orders):
            raise ValueError(
                "expected {} seeded topics, got {}".format(len(doc_orders), len(topic_order))
            )

        for i, (topic_idx, label) in enumerate(zip(topic_order, doc_orders)):
            y_pred = S[topic_idx]
            # One-vs-rest target: the contiguous run of documents for `label`.
            y_true = np.zeros(y_pred.shape[0])
            y_true[sub * label:sub * (label + 1)] = 1
            fpr, tpr, _ = roc_curve(y_true, y_pred)
            GuidedNMF_acc[r - 4, num_words, i] = auc(fpr, tpr)

In [None]:
"""
plt.imshow(GuidedNMF_acc[:,:,0], vmin=0.5, vmax=1)
plt.title("Baseball", fontsize=16)
plt.yticks(range(0,7),range(4,11), fontsize=14)
plt.xticks(range(0,8),range(1,9), fontsize=14)
plt.ylabel("Number of Topics", fontsize=14)
plt.xlabel("Number of Seed Words", fontsize=14)
plt.show()

plt.imshow(GuidedNMF_acc[:,:,1], vmin=0.5, vmax=1)
plt.title("Medical", fontsize=16)
plt.yticks(range(0,7),range(4,11), fontsize=14)
plt.xticks(range(0,8),range(1,9), fontsize=14)
plt.ylabel("Number of Topics", fontsize=14)
plt.xlabel("Number of Seed Words", fontsize=14)
plt.show()

plt.imshow(GuidedNMF_acc[:,:,2], vmin=0.5, vmax=1)
plt.title("Space", fontsize=16)
plt.yticks(range(0,7),range(4,11), fontsize=14)
plt.xticks(range(0,8),range(1,9), fontsize=14)
plt.ylabel("Number of Topics", fontsize=14)
plt.xlabel("Number of Seed Words", fontsize=14)
plt.show()
"""

### Generate LaTeX code for the results table

In [None]:
# LaTeX rows comparing GuidedNMF and SeededLDA AUCs for seeded topic 0
# (Baseball) at 4, 6 and 10 topics, using seed-word indices 0, 1, 3 and 7.
seed_word_columns = [0, 1, 3, 7]
for r in [4, 6, 10]:
    entries_nmf = GuidedNMF_acc[r - 4, seed_word_columns, 0]
    entries_lda = SeededLDA_acc[r - 4, seed_word_columns, 0]
    # Raw strings keep the LaTeX backslashes literal. "{:.2f}" rounds to a
    # fixed two decimals instead of truncating through e*100, which float error
    # can push below the true value.
    print(r"\multirow{2}{*}{" + str(r) + "} & GuidedNMF & "
          + " & ".join("{:.2f}".format(e) for e in entries_nmf) + r" \\")
    print("& SeededLDA & " + " & ".join("{:.2f}".format(e) for e in entries_lda) + r" \\")

In [None]:
# LaTeX rows comparing GuidedNMF and SeededLDA AUCs for seeded topic 2
# (Space) at 4, 6 and 10 topics, using seed-word indices 0, 1, 3 and 7.
seed_word_columns = [0, 1, 3, 7]
for r in [4, 6, 10]:
    # Both rows must read the same topic column (2) for a like-for-like comparison.
    entries_nmf = GuidedNMF_acc[r - 4, seed_word_columns, 2]
    entries_lda = SeededLDA_acc[r - 4, seed_word_columns, 2]
    # Raw strings keep the LaTeX backslashes literal; "{:.2f}" rounds to two decimals.
    print(r"\multirow{2}{*}{" + str(r) + "} & GuidedNMF & "
          + " & ".join("{:.2f}".format(e) for e in entries_nmf) + r" \\")
    print("& SeededLDA & " + " & ".join("{:.2f}".format(e) for e in entries_lda) + r" \\")
        