In [1]:
import torch
import numpy as np
from scipy.linalg import fractional_matrix_power
from yaspin import yaspin
from sklearn.metrics.pairwise import cosine_similarity
import pickle
from tqdm.notebook import tqdm
import pandas as pd
import math
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [52]:
np.random.seed(0)

In [3]:
embed_size = 300
num_docs = 50
batch_size = 10
epochs = 20

In [4]:
class GCN(torch.nn.Module):
    def __init__(self, n_h, input_features):
        super().__init__()
        W = torch.FloatTensor(input_features, n_h)
        self.weights = torch.nn.Parameter(W)
        self.relu = torch.nn.ReLU()
        stdv = 1. / math.sqrt(self.weights.size(1))
        self.weights.data.uniform_(-stdv, stdv)
    
    def forward(self, nfs, A):
        out = torch.matmul(nfs, self.weights)
        out = torch.matmul(A, out)
        out = self.relu(out)
        return out

In [55]:
class AmbiguityNetwork(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gcn1 = GCN(20, embed_size)
        self.gcn2 = GCN(10, 20)
        self.gcn3 = GCN(10,10)
        self.dense1 = torch.nn.Linear(in_features = 500, out_features = 64)
        self.dense2 = torch.nn.Linear(in_features = 64, out_features = 4)
    
    def forward(self, inputs, A):
        out = self.gcn1(inputs, A)
        out = self.gcn2(out, A)
        out = self.gcn3(out, A)
        out = torch.flatten(out, 1)
        out = self.dense1(out)
        out = self.dense2(out)
        return out

In [7]:
requests = pd.read_table('./data/train.tsv', sep = '\t', header = 0).drop_duplicates('topic_id')
requests

Unnamed: 0,topic_id,initial_request,topic_desc,clarification_need,facet_id,facet_desc,question_id,question,answer
0,1,Tell me about Obama family tree.,Find information on President Barack Obama\'s ...,2,F0001,"Find the TIME magazine photo essay ""Barack Oba...",Q00384,are you interested in seeing barack obamas family,yes am interested in obamas family
39,102,What is Fickle Creek Farm,Find general information about Fickle Creek Fa...,2,F0014,Find general information about Fickle Creek Fa...,Q00059,are you going to purchase anything there,i dont know yet i just want general info about...
84,105,Tell me about sonoma county medical services.,What medical services are available in Sonoma ...,2,F0025,What medical services are available in Sonoma ...,Q00457,are you interested in the human services depar...,no i am looking for doctors or hospitals in so...
120,108,Tell me about of Ralph Owen Brester.,Find biographical information about Ralph Owen...,1,F0037,Find biographical information about Ralph Owen...,Q00297,are you interested in learning more about ralp...,yes and his biography
159,109,I'm looking for information about mayo clinic ...,What medical services are available at the May...,2,F0040,What medical services are available at the May...,Q00256,are you interested in jobs at mayo clinic jack...,no im interested in services provided at mayo ...
...,...,...,...,...,...,...,...,...,...
9113,234,tell me about dark chocolate health benefits,What are the health benefits associated with e...,1,F0539,What are the health benefits associated with e...,Q00430,are you interested in the different candles da...,no im more interested in the health benefits o...
9127,238,Tell me bio of george bush sr.,Find biographies of US President George H.W. B...,2,F0552,Find biographies of US President George H.W. B...,Q00402,are you interested in the 1992 presidential el...,yes i would like to know about george bush srs...
9139,240,Tell me more about us presidents,Find a list of the full names of US presidents.,3,F0558,Find a list of the full names of US presidents.,Q00201,are you interested in finding out the middle n...,no i am interested in the names of all us pres...
9151,261,Tell me about folk remedies for a sore throat.,What folk remedies are there for soothing a so...,1,F0633,What folk remedies are there for soothing a so...,Q00208,are you interested in folk remedies from a spe...,no just any folk remedies for a sore throat


In [8]:
rdev = pd.read_table('./data/dev.tsv', sep = '\t', header = 0).drop_duplicates('topic_id')
rdev

Unnamed: 0,topic_id,initial_request,topic_desc,clarification_need,facet_id,facet_desc,question_id,question,answer
0,101,Find me information about the Ritz Carlton Lak...,Find information about the Ritz Carlton resort...,2,F0010,Find information about the Ritz Carlton resort...,Q00697,are you looking for a specific web site,yes for the ritz carlton resort at lake las vegas
60,106,I'm looking for universal animal cuts reviews,Find testimonials of Universal Animal Cuts nut...,3,F0028,Find testimonials of Universal Animal Cuts nut...,Q01481,did universal animal cuts work for you,i need testimonials on the universal animal cu...
102,107,tell me about cass county missouri,Find demographic information about Cass County...,3,F0031,Find demographic information about Cass County...,Q00086,are you interested in a list of homes for sale...,no i want demographic info for cass county mo
192,114,Tell about an adobe indian house?,How does one build an adobe house?,2,F0063,How does one build an adobe house?,Q00057,are you going to purchase any specific product...,maybe
231,123,What is von Willebrand Disease?,What is von Willebrand Disease?,3,F0100,What is von Willebrand Disease?,Q00284,are you interested in learning about treatment...,id like to know what it is first
267,128,Tell me about atypical squamous cells,What do atypical squamous cells mean on a pap ...,2,F0114,What do atypical squamous cells mean on a pap ...,Q00081,are you interested in a desciption of aytypica...,yes and it must be related to pap smear tests
303,133,all men are created equal,"Who said \""all men are created equal\""?",2,F0134,"Who said ""all men are created equal""?",Q00796,are you looking for declaration of independenc...,i am looking for who said the qoute provided
378,139,Tell me more about Rocky Mountain News,discussion of the impending sale of the Rocky ...,2,F0156,discussion of the impending sale of the Rocky ...,Q00334,are you interested in news archives,no i am intrested in the sale of arocky mounta...
426,142,Find me information about the sales tax in Ill...,information about the sales tax in Illinois: w...,2,F0170,information about the sales tax in Illinois: w...,Q00241,are you interested in how the illinois state t...,no i would like to know the rate and what it i...
471,164,I'm looking for information on hobby stores,What hobby stores carry trains?,4,F0259,What hobby stores carry trains?,Q00659,are you looking for a radiocontrolled plane,no im looking for trains


In [9]:
df = requests.append(rdev)

In [10]:
df

Unnamed: 0,topic_id,initial_request,topic_desc,clarification_need,facet_id,facet_desc,question_id,question,answer
0,1,Tell me about Obama family tree.,Find information on President Barack Obama\'s ...,2,F0001,"Find the TIME magazine photo essay ""Barack Oba...",Q00384,are you interested in seeing barack obamas family,yes am interested in obamas family
39,102,What is Fickle Creek Farm,Find general information about Fickle Creek Fa...,2,F0014,Find general information about Fickle Creek Fa...,Q00059,are you going to purchase anything there,i dont know yet i just want general info about...
84,105,Tell me about sonoma county medical services.,What medical services are available in Sonoma ...,2,F0025,What medical services are available in Sonoma ...,Q00457,are you interested in the human services depar...,no i am looking for doctors or hospitals in so...
120,108,Tell me about of Ralph Owen Brester.,Find biographical information about Ralph Owen...,1,F0037,Find biographical information about Ralph Owen...,Q00297,are you interested in learning more about ralp...,yes and his biography
159,109,I'm looking for information about mayo clinic ...,What medical services are available at the May...,2,F0040,What medical services are available at the May...,Q00256,are you interested in jobs at mayo clinic jack...,no im interested in services provided at mayo ...
...,...,...,...,...,...,...,...,...,...
2120,51,Tell me more about horse hooves,"Find information about horse hooves, their car...",2,F0858,Find information about horses' hooves and how ...,Q00025,are you asking about a horses feet,yes about horse hooves
2195,79,tell me about Voyager,Find information about the NASA Voyager spacec...,4,F0981,Find the homepage for the NASA Voyager mission.,Q00453,are you interested in the history of the voyager,no just take me to the homepage
2251,83,tell me about memory,Find information about human memory.,4,F0998,Find information about human memory.,Q00176,are you interested in computer memory,no
2293,256,Who is the patron saint of mental illness?,Who is the patron saint of mental illness?,1,F0609,Who is the patron saint of mental illness?,Q00304,are you interested in learning the back story ...,im interested in learning who it is


In [11]:
f = open('retrieved_docs_encoded_d2v.pkl', 'rb')

In [12]:
docs = {}
with yaspin().arc:
    while True:
        try:
            data = pickle.load(f)
            tid = list(data.keys())[0]
            docs[tid] = list(data.values())[0]
        except EOFError:
            break

[K0m [K

In [13]:
ks = list(docs.keys())

In [14]:
tids = requests['topic_id'].to_numpy(dtype = int)

In [15]:
tids = np.intersect1d(tids, np.array(ks))

In [16]:
len(tids)

187

In [17]:
inputs = torch.zeros((len(tids), num_docs, embed_size))
cosines = torch.zeros((len(tids), num_docs, num_docs))
labels = torch.LongTensor(torch.zeros(len(tids)).long())

In [18]:
inputs.shape

torch.Size([187, 50, 300])

In [19]:
for i in tqdm(range(len(tids))):
    tid = tids[i]
    enc = docs[tid]['encoded_docs'][:50]
    enc = np.array(enc)
    csml = cosine_similarity(enc)
    inputs[i] = torch.Tensor(enc)
    A = torch.Tensor(csml)
    S = torch.sum(A, axis = 1)
    D = torch.diag(torch.sum(A, axis = 1))
    D_ = torch.Tensor(fractional_matrix_power(D, -0.5))
    A_ = torch.chain_matmul(D_, A, D_)
    cosines[i] = A_
    labels[i] = df.loc[df['topic_id'] == tid]['clarification_need'].to_numpy(dtype = np.int_)[0]

HBox(children=(FloatProgress(value=0.0, max=187.0), HTML(value='')))




In [20]:
labels = torch.LongTensor(labels.long())

In [21]:
labels

tensor([2, 3, 2, 3, 4, 2, 3, 2, 3, 2, 2, 3, 4, 1, 2, 2, 1, 2, 2, 2, 2, 3, 3, 2,
        2, 3, 2, 3, 3, 3, 2, 3, 2, 3, 3, 2, 4, 2, 3, 2, 4, 3, 2, 4, 1, 2, 3, 1,
        4, 4, 2, 4, 3, 3, 1, 2, 4, 3, 4, 4, 3, 2, 3, 4, 3, 3, 3, 4, 3, 3, 3, 3,
        2, 4, 2, 4, 4, 4, 4, 2, 3, 3, 2, 3, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 3,
        2, 3, 2, 3, 3, 2, 3, 2, 2, 2, 4, 2, 2, 3, 3, 1, 2, 2, 4, 3, 2, 2, 2, 3,
        2, 2, 4, 4, 3, 3, 3, 3, 3, 2, 2, 3, 3, 2, 4, 2, 2, 3, 1, 3, 2, 3, 2, 2,
        3, 4, 3, 3, 1, 4, 3, 3, 1, 3, 2, 2, 1, 3, 2, 2, 2, 3, 2, 1, 1, 1, 2, 3,
        1, 2, 1, 1, 2, 1, 3, 4, 2, 2, 1, 1, 1, 1, 2, 1, 2, 1, 3])

In [22]:
cosines[0]

tensor([[0.0682, 0.0092, 0.0120,  ..., 0.0270, 0.0217, 0.0245],
        [0.0092, 0.0988, 0.0220,  ..., 0.0099, 0.0140, 0.0120],
        [0.0120, 0.0220, 0.0848,  ..., 0.0066, 0.0078, 0.0134],
        ...,
        [0.0270, 0.0099, 0.0066,  ..., 0.0854, 0.0214, 0.0223],
        [0.0217, 0.0140, 0.0078,  ..., 0.0214, 0.1043, 0.0311],
        [0.0245, 0.0120, 0.0134,  ..., 0.0223, 0.0311, 0.0835]])

In [23]:
dev_tids = rdev['topic_id'].to_numpy(dtype = int)

In [24]:
dev_labels = rdev['clarification_need'].to_numpy(dtype = int)

In [25]:
dev_inputs = torch.zeros((len(dev_tids), num_docs, embed_size))
dev_cosines = torch.zeros((len(dev_tids), num_docs, num_docs))

In [26]:
for i in tqdm(range(len(dev_tids))):
    tid = dev_tids[i]
    enc = docs[tid]['encoded_docs'][:50]
    enc = np.array(enc)
    csml = cosine_similarity(enc)
    dev_inputs[i] = torch.Tensor(enc)
    A = torch.Tensor(csml)
    S = torch.sum(A, axis = 1)
    D = torch.diag(torch.sum(A, axis = 1))
    D_ = torch.Tensor(fractional_matrix_power(D, -0.5))
    A_ = torch.chain_matmul(D_, A, D_)
    dev_cosines[i] = A_

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))




In [30]:
def train(epochs, batch_size):
    model = AmbiguityNetwork()
    loss_fn = torch.nn.CrossEntropyLoss()
    dataset = torch.utils.data.TensorDataset(inputs, cosines, labels)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    loader = torch.utils.data.DataLoader(dataset, sampler = torch.utils.data.RandomSampler(dataset), batch_size = batch_size, pin_memory = True)
    for epoch in tqdm(range(epochs)):
        acc = 0
        model.train()
        for idx, data in enumerate(loader):
            nf, A, tgts = data
            tgts = tgts - 1
            tgts = torch.LongTensor(tgts)
            preds = model(nf, A)
            loss = loss_fn(preds, tgts)
            m = torch.nn.Softmax()
            npreds = m(preds).detach().numpy()
            class_preds = np.argmax(npreds, axis = 1)
            acc += accuracy_score(tgts.detach().numpy(), class_preds)
            model.zero_grad()
            loss.backward()
            optimizer.step()
            
        with torch.set_grad_enabled(False):
            model.eval()
            val_preds = model(dev_inputs, dev_cosines)
            m = torch.nn.Softmax()
            val_npreds = m(val_preds).detach().numpy()
            class_preds = np.argmax(val_npreds, axis = 1)
            class_preds = class_preds + 1
            val_acc = accuracy_score(dev_labels, class_preds)
            
        print(f"Epoch: {epoch}, loss: {loss}, training acc: {acc/len(loader)}, validation acc: {val_acc}")
        
    torch.save(model.state_dict(), 'saved_model_graphcnn.pt')

In [66]:
train(100, batch_size)

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

  npreds = m(preds).detach().numpy()
  val_npreds = m(val_preds).detach().numpy()


Epoch: 0, loss: 1.3026492595672607, training acc: 0.33533834586466166, validation acc: 0.4
Epoch: 1, loss: 1.224786639213562, training acc: 0.4037593984962406, validation acc: 0.42
Epoch: 2, loss: 1.3173151016235352, training acc: 0.4142857142857142, validation acc: 0.42
Epoch: 3, loss: 1.5028024911880493, training acc: 0.4045112781954887, validation acc: 0.42
Epoch: 4, loss: 1.310724139213562, training acc: 0.4300751879699248, validation acc: 0.4
Epoch: 5, loss: 1.06879460811615, training acc: 0.41428571428571426, validation acc: 0.42
Epoch: 6, loss: 1.2578697204589844, training acc: 0.42781954887218043, validation acc: 0.38
Epoch: 7, loss: 1.2441402673721313, training acc: 0.4300751879699248, validation acc: 0.36
Epoch: 8, loss: 1.1429026126861572, training acc: 0.4300751879699249, validation acc: 0.38
Epoch: 9, loss: 1.5980193614959717, training acc: 0.42556390977443614, validation acc: 0.4
Epoch: 10, loss: 1.1346957683563232, training acc: 0.4383458646616541, validation acc: 0.4
Ep

Epoch: 93, loss: 0.06899743527173996, training acc: 0.9421052631578947, validation acc: 0.34
Epoch: 94, loss: 0.10894487798213959, training acc: 0.9368421052631579, validation acc: 0.34
Epoch: 95, loss: 0.3465721011161804, training acc: 0.9398496240601504, validation acc: 0.36
Epoch: 96, loss: 0.19021712243556976, training acc: 0.9421052631578949, validation acc: 0.32
Epoch: 97, loss: 0.40032991766929626, training acc: 0.9421052631578949, validation acc: 0.34
Epoch: 98, loss: 0.5040246844291687, training acc: 0.9398496240601504, validation acc: 0.36
Epoch: 99, loss: 0.06014762073755264, training acc: 0.942105263157895, validation acc: 0.34



In [63]:
model = AmbiguityNetwork()
model.load_state_dict(torch.load('saved_model_graphcnn.pt'))
model.eval()

AmbiguityNetwork(
  (gcn1): GCN(
    (relu): ReLU()
  )
  (gcn2): GCN(
    (relu): ReLU()
  )
  (gcn3): GCN(
    (relu): ReLU()
  )
  (dense1): Linear(in_features=500, out_features=64, bias=True)
  (dense2): Linear(in_features=64, out_features=4, bias=True)
)

In [64]:
val_preds = model(dev_inputs, dev_cosines)
m = torch.nn.Softmax()
val_npreds = m(val_preds).detach().numpy()
class_preds = np.argmax(val_npreds, axis = 1)
class_preds = class_preds + 1
prec_score = precision_score(dev_labels, class_preds, average = 'weighted')
rec_score = recall_score(dev_labels, class_preds, average = 'weighted')
f1 = f1_score(dev_labels, class_preds, average = 'weighted')
print(prec_score, rec_score, f1)

0.5103007518796993 0.5 0.49804822469528354


  val_npreds = m(val_preds).detach().numpy()


In [27]:
dev_inputs[0]

tensor([[ 0.7635, -3.1419,  0.8322,  ...,  2.2655, -0.8896,  1.2431],
        [-0.1833, -0.6641, -1.3043,  ..., -0.2687,  1.4635, -0.8374],
        [-1.0013,  1.1242,  0.4493,  ...,  1.1212, -1.0589,  1.3900],
        ...,
        [ 0.0247, -1.3875, -0.2284,  ...,  1.3820,  0.4456, -0.2931],
        [ 0.6275,  1.4585, -0.4108,  ...,  0.4533,  0.1874, -0.8568],
        [ 0.2958,  0.6222, -0.3626,  ..., -0.1815, -0.8143, -0.0502]])

In [48]:
preds = model(dev_inputs, dev_cosines)

In [49]:
m = torch.nn.Softmax()

preds = m(preds).detach().numpy()

  preds = m(preds).detach().numpy()


In [50]:
preds

array([[8.59804172e-03, 6.65261149e-01, 3.19370955e-01, 6.76983036e-03],
       [3.00306026e-02, 4.57673311e-01, 3.45387399e-01, 1.66908637e-01],
       [3.87188338e-04, 2.41328054e-03, 9.95126724e-01, 2.07275059e-03],
       [1.89618615e-03, 7.82768242e-03, 9.88410234e-01, 1.86594576e-03],
       [8.33924673e-03, 8.98035288e-01, 2.34802836e-03, 9.12774280e-02],
       [3.45815875e-04, 7.50817597e-01, 9.20605453e-05, 2.48744473e-01],
       [8.66917446e-02, 1.11487120e-01, 3.98872882e-01, 4.02948201e-01],
       [2.39838660e-03, 9.50534642e-01, 2.23094746e-02, 2.47575268e-02],
       [8.18969831e-02, 4.66684431e-01, 3.66807014e-01, 8.46116245e-02],
       [4.40147379e-03, 1.09722748e-01, 8.49538207e-01, 3.63375880e-02],
       [6.05910202e-04, 1.79295905e-03, 9.93209064e-01, 4.39204415e-03],
       [7.64172338e-03, 9.92168248e-01, 2.74152717e-05, 1.62576500e-04],
       [1.80002348e-03, 8.00464988e-01, 1.97729275e-01, 5.69362192e-06],
       [5.40936971e-03, 5.52157871e-03, 9.64325428e

In [51]:
class_preds = np.argmax(preds, axis = 1)
class_preds

array([1, 1, 2, 2, 1, 1, 3, 1, 1, 2, 2, 1, 1, 2, 1, 1, 3, 1, 1, 2, 3, 3,
       2, 2, 2, 2, 2, 2, 0, 3, 1, 2, 2, 1, 2, 0, 1, 1, 1, 2, 0, 0, 3, 3,
       2, 1, 3, 0, 1, 1], dtype=int64)

In [52]:
len(class_preds)

50

In [53]:
class_preds = class_preds + 1

In [54]:
class_preds

array([2, 2, 3, 3, 2, 2, 4, 2, 2, 3, 3, 2, 2, 3, 2, 2, 4, 2, 2, 3, 4, 4,
       3, 3, 3, 3, 3, 3, 1, 4, 2, 3, 3, 2, 3, 1, 2, 2, 2, 3, 1, 1, 4, 4,
       3, 2, 4, 1, 2, 2], dtype=int64)

In [55]:
outs = [(dev_tids[i], class_preds[i]) for i in range(len(preds))]
outs = np.array(outs)
outs

array([[101,   2],
       [106,   2],
       [107,   3],
       [114,   3],
       [123,   2],
       [128,   2],
       [133,   4],
       [139,   2],
       [142,   2],
       [164,   3],
       [165,   3],
       [166,   2],
       [169,   2],
       [174,   3],
       [ 18,   2],
       [190,   2],
       [191,   4],
       [193,   2],
       [195,   2],
       [200,   3],
       [ 24,   4],
       [ 25,   4],
       [ 27,   3],
       [ 35,   3],
       [ 37,   3],
       [ 45,   3],
       [ 71,   3],
       [ 74,   3],
       [  8,   1],
       [ 85,   4],
       [214,   2],
       [219,   3],
       [229,   3],
       [250,   2],
       [252,   3],
       [262,   1],
       [283,   2],
       [287,   2],
       [293,   2],
       [110,   3],
       [118,   1],
       [126,   1],
       [152,   4],
       [ 20,   4],
       [ 44,   3],
       [ 51,   2],
       [ 79,   4],
       [ 83,   1],
       [256,   2],
       [292,   2]], dtype=int64)

In [56]:
np.savetxt('preds_dev_doc2vec_cosine_10gcn_10gcn.txt', outs, fmt="%s %s")

In [57]:
!python clariq_eval_tool.py --eval_task clarification_need \
                                 --data_dir ./data/ \
                                 --experiment_type dev \
                                 --run_file preds_dev_doc2vec_cosine_10gcn_10gcn.txt \
                                 --out_file out

Precision:  0.5114117647058823
Recall:  0.5
F1: 0.5051049954349811
