In [32]:
import pickle
from sentence_transformers import SentenceTransformer, util
import torch
from torch import nn
import random

In [33]:
# load training data
training_data = pickle.load(open('../../data/reddit/bbc_news_scrape_raw.pkl', 'rb'))
print(training_data[0].keys())

dict_keys(['post_id', 'comment_id', 'url', 'ancestors', 'text', 'full_context'])


In [34]:
sbert_model = SentenceTransformer('msmarco-distilbert-cos-v5')

In [41]:
print(len(training_data), training_data[0]['full_context'], training_data[0]['url'])

6000 ["it was this [raid](http://www.attorneygeneral.gov/press.aspx?id=1299)...\numm... I don't know, I sold the rest of the weed, Paid my dude off. At first I\nhad no idea how big it was. I just thought it was me and my roommates. The\ngirl who's apartment I spent the night at (in her roommates empty bed) was\nactually the girlfriend of my coke dealer (Justin in the article). She heard\nfrom me at first, but it wasn't until day two of being on the lamb that we\nlearned that he had been picked up for a controlled sale of a half ounce of\ncoke to the same informant that busted me. Snitches name was Julia btw. I was\nmore of a weed dealer, but I would do favors for friends, coke LSD Mushrooms\netc. That snitch approached me and wanted an ounce coke. I knew I could do it.\nI texted my dude and I was like 'I want one' he thought I wanted a gram. And\nthat is all I gave to Julia that night. It probably ended up saving me from a\nlot of trouble in retrospect. I had over a grand, but I couldn

In [42]:
REMOVE_LAST_COMMENT=False

# encode the text data
X = []
Y = []
for example in training_data:
    if REMOVE_LAST_COMMENT:
        encoded_context = sbert_model.encode(example['full_context'][:-1],convert_to_tensor=True)
    else:
        encoded_context = sbert_model.encode(example['full_context'],convert_to_tensor=True)
    X.append(encoded_context)

# text from the url
Y = sbert_model.encode([example['text'] for example in training_data], convert_to_tensor=True)

In [43]:
X_train = X[:int(len(training_data) * 0.75)]
Y_train = Y[:int(len(training_data) * 0.75)]

X_test = X[int(len(training_data) * 0.75):]
Y_test = Y[int(len(training_data) * 0.75):]

In [45]:
class URLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(URLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        # self.fc = nn.Linear(hidden_size, )

    def forward(self, x, prev_state):
        output, state = self.lstm(x, prev_state)
        return output, state

    def init_state(self, sequence_length):
        return (torch.zeros(self.num_layers, 1, self.input_size),
                torch.zeros(self.num_layers, 1, self.input_size))

input_size = 768
hidden_size = 768
num_layers = 2

model = URLSTM(input_size, hidden_size, num_layers)
model.to('cuda:0')


URLSTM(
  (lstm): LSTM(768, 768, num_layers=2)
)

In [46]:
learning_rate = 1e-4
loss_fn = nn.CosineEmbeddingLoss(margin=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [47]:
def train(X_train, Y_train, model):
    total_loss = 0
    model.train()
    for i, (x,y) in enumerate(zip(X_train, Y_train)):
        state_h, state_c = model.init_state(len(x))
        state_h = state_h.to('cuda:0')
        state_c = state_c.to('cuda:0')

        pred, (state_h, state_c) = model(torch.unsqueeze(x, 1), (state_h, state_c)) 

        condition = torch.tensor(1).to('cuda:0')
        loss = loss_fn(state_h[-1][0], y, condition)

        # randomly sample negative examples
        # should really do contrastive loss over the batch
        for i in range(5):
            neg_idx = random.randint(0,len(X_train))
            if neg_idx == i: continue

            x_neg = X_train[i]
            y_neg = Y_train[i]

            state_h, state_c = model.init_state(len(x_neg))
            state_h = state_h.to('cuda:0')
            state_c = state_c.to('cuda:0')

            pred, (state_h, state_c) = model(torch.unsqueeze(x_neg, 1), (state_h, state_c))

            condition = torch.tensor(0).to('cuda:0')
            loss_neg = loss_fn(state_h[-1][0], y_neg, condition)

            loss += loss_neg
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
    return total_loss / len(X_train)

In [48]:
def eval(X_test, Y_test, model, k=10):
    run = []
    model.eval()
    for i,x in enumerate(X_test):
        state_h, state_c = model.init_state(len(x))
        state_h = state_h.to('cuda:0')
        state_c = state_c.to('cuda:0')
        pred, (state_h, state_c) = model(torch.unsqueeze(x, 1), (state_h, state_c))

        encoded_query = pred[-1].squeeze()

        cos_scores = util.cos_sim(encoded_query, Y_test)[0]
        top_results = torch.topk(cos_scores, k=k)
        for j, (score, idx) in enumerate(zip(top_results[0], top_results[1])):
            # 0 Q0 0 1 193.457108 Anserini
            run.append(' '.join([str(i), 'Q0', str(idx.item()), str(j), str(score.item()), 'LSTM']))
    with open('run.bbc_lstm.txt', 'w') as f:
        for result in run:
            f.write(result + '\n')

In [49]:
for epoch in range(20):
    loss = train(X_train, Y_train, model)
    print(loss)

eval(X_test, Y_test, model)

0.7720086172554228
0.6064042072428597
0.549935013016065
0.5169137947691812
0.49287047290802
0.4731370198196835
0.45563307891951665
0.43929480600357057
0.42351988543404473
0.4079572091897329
0.3924367056025399
0.3769024339649412
0.3613540124628279
0.345854930334621
0.3305822537740072
0.31593244398964776
0.3019901195367177
0.2884418619076411
0.27527661522229513
0.2628413880268733


In [51]:
!python3 -m pyserini.eval.trec_eval -m map -m P.1 ../../data/reddit/bbc_news_rel.txt run.bbc_lstm.txt

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Downloading https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar to /home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar...
/home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar already exists!
Skipping download.
Running command: ['java', '-jar', '/home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-m', 'map', '-m', 'P.1', '../../data/reddit/bbc_news_rel.txt', 'run.bbc_lstm.txt']
Results:
map                   	all	0.5592
P_1                   	all	0.4667

