In [77]:
import pickle
from sentence_transformers import SentenceTransformer, util
import torch
from torch import nn
import random

In [10]:
# load training data
training_data = pickle.load(open('../../data/reddit/bbc_news_scrape_raw.pkl', 'rb'))
print(training_data[0].keys())

dict_keys(['post_id', 'comment_id', 'url', 'ancestors', 'text', 'full_context'])
<class 'str'>


In [20]:
sbert_model = SentenceTransformer('msmarco-distilbert-cos-v5')

In [69]:
REMOVE_LAST_COMMENT=False

# encode the text data
X = []
Y = []
for example in training_data:
    if REMOVE_LAST_COMMENT:
        encoded_context = sbert_model.encode(example['full_context'][:-1],convert_to_tensor=True)
    else:
        encoded_context = sbert_model.encode(example['full_context'],convert_to_tensor=True)
    X.append(encoded_context)


Y = sbert_model.encode([example['text'] for example in training_data], convert_to_tensor=True)

X_train = X[:20]
Y_train = Y[:20]

X_test = X[:-20]
Y_test = Y[:-20]

In [79]:
class URLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(URLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)

    def forward(self, x, prev_state):
        output, state = self.lstm(x, prev_state)
        return output, state

    def init_state(self, sequence_length):
        return (torch.zeros(self.num_layers, 1, self.input_size),
                torch.zeros(self.num_layers, 1, self.input_size))

input_size = 768
hidden_size = 768
num_layers = 1

model = URLSTM(input_size, hidden_size, num_layers)
model.to('cuda:0')


URLSTM(
  (lstm): LSTM(768, 768)
)

In [80]:
learning_rate = 1e-1
loss_fn = nn.CosineEmbeddingLoss(margin=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [81]:
def train(X_train, Y_train, model):
    total_loss = 0
    model.train()
    for i,(x,y) in enumerate(zip(X_train,Y_train)):
        state_h, state_c = model.init_state(len(x))
        state_h = state_h.to('cuda:0')
        state_c = state_c.to('cuda:0')

        pred, (state_h, state_c) = model(torch.unsqueeze(x, 1), (state_h, state_c))

        condition = torch.tensor(1).to('cuda:0')
        loss = loss_fn(state_h[-1][0], y, condition)

        # randomly sample negative examples
        # should really do contrastive loss over the batch
        for i in range(5):
            neg_idx = random.randint(0,len(X_train))
            if neg_idx == i: continue

            x_neg = X_train[i]
            y_neg = Y_train[i]

            state_h, state_c = model.init_state(len(x_neg))
            state_h = state_h.to('cuda:0')
            state_c = state_c.to('cuda:0')

            pred, (state_h, state_c) = model(torch.unsqueeze(x_neg, 1), (state_h, state_c))

            condition = torch.tensor(0).to('cuda:0')
            loss_neg = loss_fn(state_h[-1][0], y_neg, condition)

            loss += loss_neg
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
    return total_loss / len(X_train)

In [82]:
def eval(X_test, Y_test, model, k=10):
    run = []
    model.eval()
    for i,x in enumerate(X_test):
        state_h, state_c = model.init_state(len(x))
        state_h = state_h.to('cuda:0')
        state_c = state_c.to('cuda:0')
        pred, (state_h, state_c) = model(torch.unsqueeze(x, 1), (state_h, state_c))

        encoded_query = state_h[-1][0]

        cos_scores = util.cos_sim(encoded_query, Y_test)[0]
        top_results = torch.topk(cos_scores, k=k)
        for j, (score, idx) in enumerate(zip(top_results[0], top_results[1])):
            # 0 Q0 0 1 193.457108 Anserini
            run.append(' '.join([str(i), 'Q0', str(idx.item()), str(j), str(score.item()), 'LSTM']))
    with open('run.bbc_lstm.txt', 'w') as f:
        for result in run:
            f.write(result + '\n')

In [83]:
for epoch in range(100):
    loss = train(X_train, Y_train, model)
    print(loss)


eval(X_test, Y_test, model)

0.8829471468925476
0.8522069096565247
0.8509826809167862
0.8506454914808274
0.8504580587148667
0.8502295732498169
0.8500773042440415
0.849984535574913
0.8499258428812027
0.8498474180698394
0.8497972935438156
0.8497100442647934
0.8496699750423431
0.8496271789073944
0.8497452348470688
0.8484131783246994
0.8481456249952316
0.8416735976934433
0.9045635849237442
0.8909571290016174
0.876683184504509
0.8564090132713318
0.8359565228223801
0.8304447293281555
0.8172893732786178
0.8070936590433121
0.7974644124507904
0.7962585002183914
0.7885753810405731
0.7846668988466263
0.7855526298284531
0.7840627163648606
0.781375727057457
0.7794364482164383
0.7783875405788422
0.7764515042304992
0.7766604006290436
0.7759959429502488
0.7757554173469543
0.7741636991500854
0.7749157577753067
0.7739398330450058
0.7742692083120346
0.7735579520463943
0.7733197957277298
0.7729277700185776
0.7725156724452973
0.7721900552511215
0.77172050178051
0.7711200147867203
0.7706597954034805
0.7700295180082322
0.769839847087860

In [84]:
!python3 -m pyserini.eval.trec_eval -m map ../../data/reddit/pyserini/bbc_news_rel.txt run.bbc_lstm.txt

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Downloading https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar to /home/kjros2/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar...
/home/kjros2/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar already exists!
Skipping download.
Running command: ['java', '-jar', '/home/kjros2/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-m', 'map', '../../data/reddit/pyserini/bbc_news_rel.txt', 'run.bbc_lstm.txt']
Results:
map                   	all	0.3054

