In [2]:
import pickle
import json
import re
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
from torch import nn

In [3]:
# load training data
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at scrape dumps produced by a trusted source.
# The context manager closes the file handle, which the previous bare
# `pickle.load(open(...))` left open until garbage collection.
with open('../../data/reddit/bbc_news_scrape_raw.pkl', 'rb') as f:
    training_data = pickle.load(f)
# Peek at the schema of the first record.
training_data[0].keys()

dict_keys(['post_id', 'comment_id', 'url', 'ancestors', 'text', 'full_context'])

In [4]:
'''
Frame the task as a retrieval problem.

Query: a comment thread c1 --> c2 --> ... --> cn.
Target: the link in cn leads to an article with paragraphs p1 --> p2 --> ... --> pm.

Simple approach:

1. Take the beginning of the article and encode it with an IR model.
2. Run an LSTM over the sequence of comments (embeddings initialized with BERT?).
'''
# Checkpoint identifiers are named once so the tokenizer and the masked-LM
# head are guaranteed to come from the same DistilBERT checkpoint.
_BERT_CHECKPOINT = "distilbert-base-cased"
_SBERT_CHECKPOINT = 'msmarco-distilbert-cos-v5'

bert_tokenizer = AutoTokenizer.from_pretrained(_BERT_CHECKPOINT)
bert_model = AutoModelForMaskedLM.from_pretrained(_BERT_CHECKPOINT)
# Sentence encoder used below to embed both the comment threads and the articles.
sbert_model = SentenceTransformer(_SBERT_CHECKPOINT)

In [22]:
# encode the text data
# X: one tensor per training example, obtained by encoding that example's
# 'full_context' (presumably the list of comments in the thread - confirm).
X = [
    sbert_model.encode(record['full_context'], convert_to_tensor=True)
    for record in training_data
]

# Y: a single batched encoding with one row per example. Each example's
# 'text' is joined with spaces first (assumes 'text' is a list of strings -
# TODO confirm; joining a plain string would space-separate its characters).
article_texts = [" ".join(record['text']) for record in training_data]
Y = sbert_model.encode(article_texts, convert_to_tensor=True)

In [23]:
class URLSTM(nn.Module):
    """Thin wrapper around a (sequence-first) ``nn.LSTM``.

    Args:
        input_size: Width of each input vector fed to the LSTM.
        hidden_size: Width of the LSTM hidden and cell states.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers):
        super(URLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)

    def forward(self, x, prev_state):
        """Run the LSTM over ``x`` starting from ``prev_state``.

        Args:
            x: Input sequence; with the default ``batch_first=False`` the
                LSTM expects ``(seq_len, batch, input_size)``.
            prev_state: Tuple ``(h_0, c_0)`` as returned by ``init_state``.

        Returns:
            ``(output, (h_n, c_n))`` exactly as produced by ``nn.LSTM``.
        """
        output, state = self.lstm(x, prev_state)
        return output, state

    def init_state(self, sequence_length):
        """Return zero-filled ``(h_0, c_0)`` for a batch of size 1.

        Args:
            sequence_length: Unused; the state shape does not depend on the
                sequence length. Kept so existing callers keep working.

        Returns:
            Two tensors of shape ``(num_layers, 1, hidden_size)``, created on
            the same device and with the same dtype as the module parameters.
        """
        # The LSTM state is sized by hidden_size, not input_size; the two only
        # coincide when the model is built with input_size == hidden_size.
        shape = (self.num_layers, 1, self.hidden_size)
        reference = next(self.parameters())
        return (torch.zeros(shape, device=reference.device, dtype=reference.dtype),
                torch.zeros(shape, device=reference.device, dtype=reference.dtype))

# Model dimensions. 768 presumably matches the width of the sentence
# embeddings fed to the LSTM - confirm against the encoder's output size.
input_size, hidden_size = 768, 768
num_layers = 1

model = URLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
# Move the parameters to the first GPU; as the cell's last expression this
# also displays the module summary.
model.to('cuda:0')


URLSTM(
  (lstm): LSTM(768, 768)
)

In [24]:
# Objective: cosine embedding loss between the LSTM's thread encoding and the
# article embedding.
loss_fn = nn.CosineEmbeddingLoss()

# Adam over all model parameters with the step size below.
learning_rate = 1e-1
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [25]:
# Train the LSTM so that its final hidden state for each thread in X aligns
# (in cosine terms) with the matching article embedding in Y.

# Follow whatever device the model lives on instead of hard-coding 'cuda:0'.
device = next(model.parameters()).device

# CosineEmbeddingLoss target: 1 marks the (prediction, y) pair as one that
# should be similar. It never changes, so build it once outside the loops.
condition = torch.tensor(1, device=device)

# The evaluation helper calls model.eval(); make sure re-running this cell
# after an evaluation trains in training mode again.
model.train()

for epoch in range(100):
    total_loss = 0
    for x, y in zip(X, Y):
        # Fresh zero state for every thread; examples are independent.
        state_h, state_c = model.init_state(len(x))
        state_h = state_h.to(device)
        state_c = state_c.to(device)

        # Insert a batch axis of size 1 at position 1 before feeding the LSTM
        # (assumes x is (num_comments, embedding_dim) - TODO confirm).
        pred, (state_h, state_c) = model(torch.unsqueeze(x.to(device), 1), (state_h, state_c))

        # Compare the last layer's final hidden state with the target vector.
        loss = loss_fn(state_h[-1][0], y.to(device), condition)

        # No detach of state_h / state_c is needed here: both names are
        # rebound from init_state at the top of the next iteration.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the mean per-example loss every tenth epoch.
    if epoch % 10 == 0:
        print(total_loss / len(X))

0.8676569685339928
0.7914408115964187
0.763634228784787
0.750069588814911
0.7316822452764762
0.7320737991678087
0.7293162534111425
0.7288060125551725
0.7250551765686587
0.7101579649667991


In [32]:
# Simple code to output results in trec_eval run-file style.
# In the future, this will need a train/dev/test split, and the code should probably be ported to an importable module.
# The current evaluation scores the same data the model was trained on, so it is clearly not a meaningful benchmark, but it shows that we are learning something.


def eval(X, Y, model, k=10, run_path='run.bbc_lstm.txt', run_tag='LSTM'):
    """Rank every row of ``Y`` for each query in ``X`` and write a TREC run file.

    Each query sequence is passed through ``model``; the last layer's final
    hidden state is compared to all rows of ``Y`` by cosine similarity and the
    top ``k`` rows are written, one per line, as
    ``<query index> Q0 <row index> <rank> <score> <run_tag>``.

    NOTE: this function shadows the built-in ``eval``; the name is kept
    because existing cells call it.

    Args:
        X: Iterable of query tensors, one per query.
        Y: Tensor of candidate embeddings, one row per candidate.
        model: Module exposing ``init_state`` and returning
            ``(output, (h_n, c_n))`` when called. It is left in eval mode.
        k: Number of results per query; clamped to ``len(Y)`` because
            ``torch.topk`` raises when asked for more items than exist.
        run_path: Where to write the run file.
        run_tag: Run identifier written in the last column.

    Returns:
        None. The results are written to ``run_path``.
    """
    # Follow the model's device rather than assuming a specific GPU.
    device = next(model.parameters()).device
    k = min(k, len(Y))

    run = []
    model.eval()
    # Pure inference: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        for i, (x, _y) in enumerate(zip(X, Y)):
            state_h, state_c = model.init_state(len(x))
            state_h = state_h.to(device)
            state_c = state_c.to(device)
            _, (state_h, state_c) = model(torch.unsqueeze(x, 1), (state_h, state_c))

            # Last layer, the single batch element.
            encoded_query = state_h[-1][0]

            cos_scores = util.cos_sim(encoded_query, Y)[0]
            top_results = torch.topk(cos_scores, k=k)
            # Ranks are 1-based, matching the example line below.
            for rank, (score, idx) in enumerate(zip(top_results[0], top_results[1]), start=1):
                # 0 Q0 0 1 193.457108 Anserini
                run.append(' '.join([str(i), 'Q0', str(idx.item()), str(rank), str(score.item()), run_tag]))

    with open(run_path, 'w') as f:
        f.writelines(result + '\n' for result in run)

In [33]:
# Write the top-10 ranking for every training thread to the run file.
eval(X, Y, model, k=10)

In [34]:
# Score the run file written above with pyserini's trec_eval wrapper,
# requesting MAP, recall@1 and recall@5.
# NOTE(review): the first path is presumably the relevance judgments (qrels)
# for the BBC set and the second the run file - confirm the argument order.
!python3 -m pyserini.eval.trec_eval -m map -m recall.1 -m recall.5 ../../data/reddit/pyserini/bbc_news_rel.txt run.bbc_lstm.txt

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Downloading https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar to /home/kjros2/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar...
/home/kjros2/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar already exists!
Skipping download.
Running command: ['java', '-jar', '/home/kjros2/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-m', 'map', '-m', 'recall.1', '-m', 'recall.5', '../../data/reddit/pyserini/bbc_news_rel.txt', 'run.bbc_lstm.txt']
Results:
map                   	all	0.2871
recall_1              	all	0.1974

