In [1]:
import torch
import numpy as np
from scipy.linalg import fractional_matrix_power
from yaspin import yaspin
from sklearn.metrics.pairwise import cosine_similarity
import pickle
from tqdm.notebook import tqdm
import pandas as pd
import math

In [2]:
np.random.seed(0)

In [3]:
embed_size = 768
num_docs = 50
batch_size = 10

In [4]:
class GCN(torch.nn.Module):
    def __init__(self, n_h, input_features):
        super().__init__()
        W = torch.FloatTensor(input_features, n_h)
        self.weights = torch.nn.Parameter(W)
        self.relu = torch.nn.ReLU()
        stdv = 1. / math.sqrt(self.weights.size(1))
        self.weights.data.uniform_(-stdv, stdv)
    
    def forward(self, nfs, A):
        out = torch.matmul(nfs, self.weights)
        out = torch.matmul(A, out)
        out = self.relu(out)
        return out

In [5]:
class AmbiguityNetwork(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.gcn1 = GCN(3, embed_size)
        self.gcn2 = GCN(3, 3)
        self.dense = torch.nn.Linear(in_features = 150, out_features = 4)
    
    def forward(self, inputs, A):
        out = self.gcn1(inputs, A)
        out = self.gcn2(out, A)
        out = torch.flatten(out, 1)
        out = self.dense(out)
        return out

In [6]:
requests = pd.read_table('./data/train.tsv', sep = '\t', header = 0).drop_duplicates('topic_id')
requests

Unnamed: 0,topic_id,initial_request,topic_desc,clarification_need,facet_id,facet_desc,question_id,question,answer
0,1,Tell me about Obama family tree.,Find information on President Barack Obama\'s ...,2,F0001,"Find the TIME magazine photo essay ""Barack Oba...",Q00384,are you interested in seeing barack obamas family,yes am interested in obamas family
39,102,What is Fickle Creek Farm,Find general information about Fickle Creek Fa...,2,F0014,Find general information about Fickle Creek Fa...,Q00059,are you going to purchase anything there,i dont know yet i just want general info about...
84,105,Tell me about sonoma county medical services.,What medical services are available in Sonoma ...,2,F0025,What medical services are available in Sonoma ...,Q00457,are you interested in the human services depar...,no i am looking for doctors or hospitals in so...
120,108,Tell me about of Ralph Owen Brester.,Find biographical information about Ralph Owen...,1,F0037,Find biographical information about Ralph Owen...,Q00297,are you interested in learning more about ralp...,yes and his biography
159,109,I'm looking for information about mayo clinic ...,What medical services are available at the May...,2,F0040,What medical services are available at the May...,Q00256,are you interested in jobs at mayo clinic jack...,no im interested in services provided at mayo ...
...,...,...,...,...,...,...,...,...,...
9113,234,tell me about dark chocolate health benefits,What are the health benefits associated with e...,1,F0539,What are the health benefits associated with e...,Q00430,are you interested in the different candles da...,no im more interested in the health benefits o...
9127,238,Tell me bio of george bush sr.,Find biographies of US President George H.W. B...,2,F0552,Find biographies of US President George H.W. B...,Q00402,are you interested in the 1992 presidential el...,yes i would like to know about george bush srs...
9139,240,Tell me more about us presidents,Find a list of the full names of US presidents.,3,F0558,Find a list of the full names of US presidents.,Q00201,are you interested in finding out the middle n...,no i am interested in the names of all us pres...
9151,261,Tell me about folk remedies for a sore throat.,What folk remedies are there for soothing a so...,1,F0633,What folk remedies are there for soothing a so...,Q00208,are you interested in folk remedies from a spe...,no just any folk remedies for a sore throat


In [7]:
rdev = pd.read_table('./data/dev.tsv', sep = '\t', header = 0).drop_duplicates('topic_id')
rdev

Unnamed: 0,topic_id,initial_request,topic_desc,clarification_need,facet_id,facet_desc,question_id,question,answer
0,101,Find me information about the Ritz Carlton Lak...,Find information about the Ritz Carlton resort...,2,F0010,Find information about the Ritz Carlton resort...,Q00697,are you looking for a specific web site,yes for the ritz carlton resort at lake las vegas
60,106,I'm looking for universal animal cuts reviews,Find testimonials of Universal Animal Cuts nut...,3,F0028,Find testimonials of Universal Animal Cuts nut...,Q01481,did universal animal cuts work for you,i need testimonials on the universal animal cu...
102,107,tell me about cass county missouri,Find demographic information about Cass County...,3,F0031,Find demographic information about Cass County...,Q00086,are you interested in a list of homes for sale...,no i want demographic info for cass county mo
192,114,Tell about an adobe indian house?,How does one build an adobe house?,2,F0063,How does one build an adobe house?,Q00057,are you going to purchase any specific product...,maybe
231,123,What is von Willebrand Disease?,What is von Willebrand Disease?,3,F0100,What is von Willebrand Disease?,Q00284,are you interested in learning about treatment...,id like to know what it is first
267,128,Tell me about atypical squamous cells,What do atypical squamous cells mean on a pap ...,2,F0114,What do atypical squamous cells mean on a pap ...,Q00081,are you interested in a desciption of aytypica...,yes and it must be related to pap smear tests
303,133,all men are created equal,"Who said \""all men are created equal\""?",2,F0134,"Who said ""all men are created equal""?",Q00796,are you looking for declaration of independenc...,i am looking for who said the qoute provided
378,139,Tell me more about Rocky Mountain News,discussion of the impending sale of the Rocky ...,2,F0156,discussion of the impending sale of the Rocky ...,Q00334,are you interested in news archives,no i am intrested in the sale of arocky mounta...
426,142,Find me information about the sales tax in Ill...,information about the sales tax in Illinois: w...,2,F0170,information about the sales tax in Illinois: w...,Q00241,are you interested in how the illinois state t...,no i would like to know the rate and what it i...
471,164,I'm looking for information on hobby stores,What hobby stores carry trains?,4,F0259,What hobby stores carry trains?,Q00659,are you looking for a radiocontrolled plane,no im looking for trains


In [8]:
df = requests.append(rdev)

In [9]:
df

Unnamed: 0,topic_id,initial_request,topic_desc,clarification_need,facet_id,facet_desc,question_id,question,answer
0,1,Tell me about Obama family tree.,Find information on President Barack Obama\'s ...,2,F0001,"Find the TIME magazine photo essay ""Barack Oba...",Q00384,are you interested in seeing barack obamas family,yes am interested in obamas family
39,102,What is Fickle Creek Farm,Find general information about Fickle Creek Fa...,2,F0014,Find general information about Fickle Creek Fa...,Q00059,are you going to purchase anything there,i dont know yet i just want general info about...
84,105,Tell me about sonoma county medical services.,What medical services are available in Sonoma ...,2,F0025,What medical services are available in Sonoma ...,Q00457,are you interested in the human services depar...,no i am looking for doctors or hospitals in so...
120,108,Tell me about of Ralph Owen Brester.,Find biographical information about Ralph Owen...,1,F0037,Find biographical information about Ralph Owen...,Q00297,are you interested in learning more about ralp...,yes and his biography
159,109,I'm looking for information about mayo clinic ...,What medical services are available at the May...,2,F0040,What medical services are available at the May...,Q00256,are you interested in jobs at mayo clinic jack...,no im interested in services provided at mayo ...
...,...,...,...,...,...,...,...,...,...
2120,51,Tell me more about horse hooves,"Find information about horse hooves, their car...",2,F0858,Find information about horses' hooves and how ...,Q00025,are you asking about a horses feet,yes about horse hooves
2195,79,tell me about Voyager,Find information about the NASA Voyager spacec...,4,F0981,Find the homepage for the NASA Voyager mission.,Q00453,are you interested in the history of the voyager,no just take me to the homepage
2251,83,tell me about memory,Find information about human memory.,4,F0998,Find information about human memory.,Q00176,are you interested in computer memory,no
2293,256,Who is the patron saint of mental illness?,Who is the patron saint of mental illness?,1,F0609,Who is the patron saint of mental illness?,Q00304,are you interested in learning the back story ...,im interested in learning who it is


In [10]:
f = open('retrieved_docs_encoded.pkl', 'rb')

In [11]:
docs = {}
with yaspin().arc:
    while True:
        try:
            data = pickle.load(f)
            tid = list(data.keys())[0]
            docs[tid] = list(data.values())[0]
        except EOFError:
            break

[K0m [K

In [12]:
ks = list(docs.keys())

In [14]:
len(ks)

298

In [15]:
tids = requests['topic_id'].to_numpy(dtype = int)

In [16]:
tids = np.intersect1d(tids, np.array(ks))

In [17]:
len(tids)

187

In [18]:
inputs = torch.zeros((len(tids), num_docs, embed_size))
cosines = torch.zeros((len(tids), num_docs, num_docs))
labels = torch.LongTensor(torch.zeros(len(tids)).long())

In [19]:
inputs.shape

torch.Size([187, 50, 768])

In [20]:
for i in tqdm(range(len(tids))):
    tid = tids[i]
    enc = docs[tid]['encoded_docs'][:50]
    enc = np.array(enc)
    csml = cosine_similarity(enc)
    inputs[i] = torch.Tensor(enc)
    A = torch.Tensor(csml)
    S = torch.sum(A, axis = 1)
    D = torch.diag(torch.sum(A, axis = 1))
    D_ = torch.Tensor(fractional_matrix_power(D, -0.5))
    A_ = torch.chain_matmul(D_, A, D_)
    cosines[i] = A_
    labels[i] = df.loc[df['topic_id'] == tid]['clarification_need'].to_numpy(dtype = np.int_)[0]

HBox(children=(FloatProgress(value=0.0, max=187.0), HTML(value='')))




In [21]:
labels = torch.LongTensor(labels.long())

In [22]:
labels

tensor([2, 3, 2, 3, 4, 2, 3, 2, 3, 2, 2, 3, 4, 1, 2, 2, 1, 2, 2, 2, 2, 3, 3, 2,
        2, 3, 2, 3, 3, 3, 2, 3, 2, 3, 3, 2, 4, 2, 3, 2, 4, 3, 2, 4, 1, 2, 3, 1,
        4, 4, 2, 4, 3, 3, 1, 2, 4, 3, 4, 4, 3, 2, 3, 4, 3, 3, 3, 4, 3, 3, 3, 3,
        2, 4, 2, 4, 4, 4, 4, 2, 3, 3, 2, 3, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 3,
        2, 3, 2, 3, 3, 2, 3, 2, 2, 2, 4, 2, 2, 3, 3, 1, 2, 2, 4, 3, 2, 2, 2, 3,
        2, 2, 4, 4, 3, 3, 3, 3, 3, 2, 2, 3, 3, 2, 4, 2, 2, 3, 1, 3, 2, 3, 2, 2,
        3, 4, 3, 3, 1, 4, 3, 3, 1, 3, 2, 2, 1, 3, 2, 2, 2, 3, 2, 1, 1, 1, 2, 3,
        1, 2, 1, 1, 2, 1, 3, 4, 2, 2, 1, 1, 1, 1, 2, 1, 2, 1, 3])

In [23]:
cosines.shape

torch.Size([187, 50, 50])

In [31]:
def train(epochs, batch_size):
    model = AmbiguityNetwork()
    loss_fn = torch.nn.CrossEntropyLoss()
    dataset = torch.utils.data.TensorDataset(inputs, cosines, labels)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loader = torch.utils.data.DataLoader(dataset, sampler = torch.utils.data.RandomSampler(dataset), batch_size = batch_size, pin_memory = True)
    for epoch in tqdm(range(epochs)):
        for idx, data in enumerate(loader):
            nf, A, tgts = data
            tgts = tgts - 1
            tgts = torch.LongTensor(tgts)
            preds = model(nf, A)
            loss = loss_fn(preds, tgts)
            model.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch: {epoch}, loss: {loss}")
    torch.save(model.state_dict(), 'saved_model_graphcnn_sbert.pt')

In [32]:
train(100, batch_size)

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

Epoch: 0, loss: 1.1849801540374756
Epoch: 1, loss: 1.8277772665023804
Epoch: 2, loss: 1.2406257390975952
Epoch: 3, loss: 1.2138341665267944
Epoch: 4, loss: 1.1375653743743896
Epoch: 5, loss: 1.15119206905365
Epoch: 6, loss: 0.9308151006698608
Epoch: 7, loss: 1.2140511274337769
Epoch: 8, loss: 1.4741722345352173
Epoch: 9, loss: 1.0344047546386719
Epoch: 10, loss: 1.484643578529358
Epoch: 11, loss: 1.4912922382354736
Epoch: 12, loss: 1.167317509651184
Epoch: 13, loss: 1.1700752973556519
Epoch: 14, loss: 1.3470373153686523
Epoch: 15, loss: 1.184739112854004
Epoch: 16, loss: 1.400888204574585
Epoch: 17, loss: 1.082517385482788
Epoch: 18, loss: 0.8521986603736877
Epoch: 19, loss: 1.2358207702636719
Epoch: 20, loss: 0.9843258261680603
Epoch: 21, loss: 1.3326185941696167
Epoch: 22, loss: 1.0027674436569214
Epoch: 23, loss: 1.2829381227493286
Epoch: 24, loss: 1.1418507099151611
Epoch: 25, loss: 1.1652384996414185
Epoch: 26, loss: 1.2071702480316162
Epoch: 27, loss: 1.2767674922943115
Epoch: 28

In [33]:
dev_tids = rdev['topic_id'].to_numpy(dtype = int)

In [34]:
dev_inputs = torch.zeros((len(dev_tids), num_docs, embed_size))
dev_cosines = torch.zeros((len(dev_tids), num_docs, num_docs))

In [35]:
for i in tqdm(range(len(dev_tids))):
    tid = dev_tids[i]
    enc = docs[tid]['encoded_docs'][:50]
    enc = np.array(enc)
    csml = cosine_similarity(enc)
    dev_inputs[i] = torch.Tensor(enc)
    A = torch.Tensor(csml)
    S = torch.sum(A, axis = 1)
    D = torch.diag(torch.sum(A, axis = 1))
    D_ = torch.Tensor(fractional_matrix_power(D, -0.5))
    A_ = torch.chain_matmul(D_, A, D_)
    dev_cosines[i] = A_

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))




In [36]:
dev_inputs[0]

tensor([[-7.2326e-01,  5.3134e-01,  1.2137e+00,  ..., -1.2824e-01,
         -2.9642e-01,  4.0198e-01],
        [-5.2355e-01,  6.2617e-01,  3.9688e-01,  ..., -7.4506e-02,
          1.8932e-01,  5.2183e-02],
        [-5.3171e-01,  1.0504e+00,  3.2651e-01,  ...,  6.9487e-02,
          3.0893e-01,  9.6469e-01],
        ...,
        [-7.6178e-01,  6.2427e-01,  3.0667e-01,  ..., -3.4062e-01,
          7.3413e-02,  1.1522e-01],
        [-4.3759e-02,  1.0616e+00,  1.4367e-01,  ..., -6.0262e-01,
          2.8618e-01, -1.6619e-04],
        [ 1.3685e-01,  8.7571e-01,  5.8264e-01,  ...,  3.9131e-01,
         -1.8068e-01,  4.4024e-01]])

In [38]:
model = AmbiguityNetwork()
model.load_state_dict(torch.load('saved_model_graphcnn_sbert.pt'))
model.eval()

AmbiguityNetwork(
  (gcn1): GCN(
    (relu): ReLU()
  )
  (gcn2): GCN(
    (relu): ReLU()
  )
  (dense): Linear(in_features=150, out_features=4, bias=True)
)

In [39]:
preds = model(dev_inputs, dev_cosines)

In [40]:
m = torch.nn.Softmax()
preds = m(preds).detach().numpy()

  preds = m(preds).detach().numpy()


In [41]:
preds

array([[1.95710897e-01, 3.10611784e-01, 3.42244774e-01, 1.51432618e-01],
       [2.20654294e-01, 2.34331384e-01, 2.93923438e-01, 2.51090944e-01],
       [9.49987248e-02, 5.06062806e-01, 3.75365108e-01, 2.35734377e-02],
       [2.25822046e-01, 2.04093158e-01, 2.69160509e-01, 3.00924271e-01],
       [1.11180410e-01, 4.73655969e-01, 3.82721543e-01, 3.24420296e-02],
       [2.91841850e-02, 6.59491658e-01, 3.09588790e-01, 1.73536537e-03],
       [1.74505770e-01, 3.56186748e-01, 3.63386095e-01, 1.05921343e-01],
       [2.12578088e-01, 2.63698906e-01, 3.15613061e-01, 2.08110005e-01],
       [1.39554396e-01, 4.24821049e-01, 3.79219741e-01, 5.64048029e-02],
       [6.48860484e-02, 5.60253501e-01, 3.65250826e-01, 9.60955583e-03],
       [6.80125132e-02, 5.55125296e-01, 3.66266280e-01, 1.05958479e-02],
       [3.21957730e-02, 6.51008546e-01, 3.14714670e-01, 2.08104192e-03],
       [5.21889403e-02, 5.92635512e-01, 3.49153280e-01, 6.02234853e-03],
       [2.01082140e-01, 2.98966259e-01, 3.36467147e

In [42]:
class_preds = np.argmax(preds, axis = 1)
class_preds

array([2, 2, 1, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 3,
       3, 3, 1, 1, 2, 3, 1, 1, 3, 1, 3, 1, 1, 3, 3, 1, 1, 1, 3, 2, 3, 1,
       1, 3, 1, 3, 1, 1], dtype=int64)

In [43]:
len(class_preds)

50

In [44]:
class_preds = class_preds + 1

In [45]:
class_preds

array([3, 3, 2, 4, 2, 2, 3, 3, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 4,
       4, 4, 2, 2, 3, 4, 2, 2, 4, 2, 4, 2, 2, 4, 4, 2, 2, 2, 4, 3, 4, 2,
       2, 4, 2, 4, 2, 2], dtype=int64)

In [46]:
outs = [(dev_tids[i], class_preds[i]) for i in range(len(preds))]
outs = np.array(outs)
outs

array([[101,   3],
       [106,   3],
       [107,   2],
       [114,   4],
       [123,   2],
       [128,   2],
       [133,   3],
       [139,   3],
       [142,   2],
       [164,   2],
       [165,   2],
       [166,   2],
       [169,   2],
       [174,   3],
       [ 18,   2],
       [190,   2],
       [191,   2],
       [193,   2],
       [195,   2],
       [200,   2],
       [ 24,   2],
       [ 25,   4],
       [ 27,   4],
       [ 35,   4],
       [ 37,   2],
       [ 45,   2],
       [ 71,   3],
       [ 74,   4],
       [  8,   2],
       [ 85,   2],
       [214,   4],
       [219,   2],
       [229,   4],
       [250,   2],
       [252,   2],
       [262,   4],
       [283,   4],
       [287,   2],
       [293,   2],
       [110,   2],
       [118,   4],
       [126,   3],
       [152,   4],
       [ 20,   2],
       [ 44,   2],
       [ 51,   4],
       [ 79,   2],
       [ 83,   4],
       [256,   2],
       [292,   2]], dtype=int64)

In [48]:
np.savetxt('preds_dev_gcn_sbert_cosine.txt', outs, fmt="%s %s")