In [2]:
import pickle
from sentence_transformers import SentenceTransformer, util
import torch
from torch import nn
import random

In [3]:
# load training data
training_data = pickle.load(open('../../data/reddit/bbc_news_scrape_raw.pkl', 'rb'))
print(training_data[0].keys())

dict_keys(['post_id', 'comment_id', 'url', 'ancestors', 'text', 'full_context'])


In [4]:
sbert_model = SentenceTransformer('msmarco-distilbert-cos-v5')

In [5]:
print(len(training_data), len(training_data[6]['full_context']), training_data[0]['url'])

2000 7 http://en.wikipedia.org/wiki/No_Pants_Day


In [6]:
REMOVE_LAST_COMMENT=False

# encode the text data
X = []
Y = []
for example in training_data:
    if REMOVE_LAST_COMMENT:
        encoded_context = sbert_model.encode(example['full_context'][:-1],convert_to_tensor=True)
    else:
        encoded_context = sbert_model.encode(example['full_context'],convert_to_tensor=True)
    X.append(encoded_context)

# text from the url
Y = sbert_model.encode([example['text'] for example in training_data], convert_to_tensor=True)

In [7]:
X_train = X[:int(len(training_data) * 0.75)]
Y_train = Y[:int(len(training_data) * 0.75)]

X_test = X[int(len(training_data) * 0.75):]
Y_test = Y[int(len(training_data) * 0.75):]

In [33]:
class URLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(URLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)
        self.fc = nn.Linear(hidden_size * 2, hidden_size)
        self.relu = nn.ReLU()

    def forward(self, x, prev_state):
        output, state = self.lstm(x, prev_state)
        last_hidden_state = output[-1].squeeze()
        last_comment_encoding = x[-1, 0, :]
        output = self.fc(torch.cat((last_hidden_state, last_comment_encoding)))
        output = self.relu(output)
        return output, state

    def init_state(self, sequence_length):
        return (torch.zeros(self.num_layers, 1, self.input_size),
                torch.zeros(self.num_layers, 1, self.input_size))

input_size = 768
hidden_size = 768
num_layers = 2

model = URLSTM(input_size, hidden_size, num_layers)
model.to('cuda:0')


URLSTM(
  (lstm): LSTM(768, 768, num_layers=2)
  (fc): Linear(in_features=1536, out_features=768, bias=True)
  (relu): ReLU()
)

In [34]:
learning_rate = 1e-4
loss_fn = nn.CosineEmbeddingLoss(margin=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [35]:
def train(X_train, Y_train, model):
    total_loss = 0
    model.train()
    for i, (x,y) in enumerate(zip(X_train, Y_train)):
        state_h, state_c = model.init_state(len(x))
        state_h = state_h.to('cuda:0')
        state_c = state_c.to('cuda:0')

        pred, (state_h, state_c) = model(torch.unsqueeze(x, 1), (state_h, state_c)) 

        condition = torch.tensor(1).to('cuda:0')
        # loss = loss_fn(state_h[-1][0], y, condition)
        loss = loss_fn(pred, y, condition)

        # randomly sample negative examples
        # should really do contrastive loss over the batch
        for i in range(5):
            neg_idx = random.randint(0,len(X_train))
            if neg_idx == i: continue

            x_neg = X_train[i]
            y_neg = Y_train[i]

            state_h, state_c = model.init_state(len(x_neg))
            state_h = state_h.to('cuda:0')
            state_c = state_c.to('cuda:0')

            pred, (state_h, state_c) = model(torch.unsqueeze(x_neg, 1), (state_h, state_c))

            condition = torch.tensor(0).to('cuda:0')
            # loss_neg = loss_fn(state_h[-1][0], y_neg, condition)
            loss_neg = loss_fn(pred, y_neg, condition)


            loss += loss_neg
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
    return total_loss / len(X_train)

In [36]:
def eval(X_test, Y_test, model, k=10):
    run = []
    model.eval()
    for i,x in enumerate(X_test):
        state_h, state_c = model.init_state(len(x))
        state_h = state_h.to('cuda:0')
        state_c = state_c.to('cuda:0')
        pred, (state_h, state_c) = model(torch.unsqueeze(x, 1), (state_h, state_c))

        # encoded_query = pred[-1].squeeze()
        encoded_query = pred

        cos_scores = util.cos_sim(encoded_query, Y_test)[0]
        top_results = torch.topk(cos_scores, k=k)
        for j, (score, idx) in enumerate(zip(top_results[0], top_results[1])):
            # 0 Q0 0 1 193.457108 Anserini
            run.append(' '.join([str(i), 'Q0', str(idx.item()), str(j), str(score.item()), 'LSTM']))
    with open('run.bbc_lstm.txt', 'w') as f:
        for result in run:
            f.write(result + '\n')

In [39]:
for epoch in range(20):
    loss = train(X_train, Y_train, model)
    print(loss)

eval(X_test, Y_test, model)

0.4545632730325063
0.4514605682690938
0.44862243461608886
0.4462801754872004
0.4438993047873179
0.44146749313672384
0.4394278266429901
0.4375402568181356
0.43589641964435577
0.4340717802842458
0.43232801342010496
0.4306684990723928
0.429213254570961
0.4277275584936142
0.4262499550580978
0.4248405787150065
0.4234851983785629
0.4222828644911448
0.42084990163644154
0.41969202518463133


In [14]:
!python3 -m pyserini.eval.trec_eval -m map -m P.1 ../../data/reddit/bbc_news_rel.txt run.bbc_lstm.txt

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Downloading https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar to /home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar...
/home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar already exists!
Skipping download.
Running command: ['java', '-jar', '/home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-m', 'map', '-m', 'P.1', '../../data/reddit/bbc_news_rel.txt', 'run.bbc_lstm.txt']
Results:
map                   	all	0.4866
P_1                   	all	0.3820



In [38]:
!python3 -m pyserini.eval.trec_eval -m map -m P.1 ../../data/reddit/bbc_news_rel.txt run.bbc_lstm.txt

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Downloading https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar to /home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar...
/home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar already exists!
Skipping download.
Running command: ['java', '-jar', '/home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-m', 'map', '-m', 'P.1', '../../data/reddit/bbc_news_rel.txt', 'run.bbc_lstm.txt']
Results:
map                   	all	0.5040
P_1                   	all	0.3820



In [40]:
!python3 -m pyserini.eval.trec_eval -m map -m P.1 ../../data/reddit/bbc_news_rel.txt run.bbc_lstm.txt

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Downloading https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar to /home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar...
/home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar already exists!
Skipping download.
Running command: ['java', '-jar', '/home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-m', 'map', '-m', 'P.1', '../../data/reddit/bbc_news_rel.txt', 'run.bbc_lstm.txt']
Results:
map                   	all	0.5145
P_1                   	all	0.4060

