In [4]:
import pickle
from sentence_transformers import SentenceTransformer, util
import torch
from torch import nn
import random

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# load training data
training_data = pickle.load(open('../../data/reddit/bbc_news_scrape_raw.pkl', 'rb'))
print(training_data[0].keys())

dict_keys(['post_id', 'comment_id', 'url', 'ancestors', 'text', 'full_context'])


In [6]:
sbert_model = SentenceTransformer('msmarco-distilbert-cos-v5')

In [7]:
len(training_data)

2000

In [9]:
REMOVE_LAST_COMMENT=False

# encode the text data
X = []
Y = []
for example in training_data:
    if REMOVE_LAST_COMMENT:
        encoded_context = sbert_model.encode(example['full_context'][:-1],convert_to_tensor=True)
    else:
        encoded_context = sbert_model.encode(example['full_context'],convert_to_tensor=True)
    X.append(encoded_context)

# text from the url
Y = sbert_model.encode([example['text'] for example in training_data], convert_to_tensor=True)

X_train = X[:int(len(training_data) * 0.75)]
Y_train = Y[:int(len(training_data) * 0.75)]

X_test = X[int(len(training_data) * 0.75):]
Y_test = Y[int(len(training_data) * 0.75):]

In [10]:
class URLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(URLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers)

    def forward(self, x, prev_state):
        output, state = self.lstm(x, prev_state)
        return output, state

    def init_state(self, sequence_length):
        return (torch.zeros(self.num_layers, 1, self.input_size),
                torch.zeros(self.num_layers, 1, self.input_size))

input_size = 768
hidden_size = 768
num_layers = 1

model = URLSTM(input_size, hidden_size, num_layers)
model.to('cuda:0')


URLSTM(
  (lstm): LSTM(768, 768)
)

In [11]:
learning_rate = 1e-1
loss_fn = nn.CosineEmbeddingLoss(margin=0.5)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [12]:
def train(X_train, Y_train, model):
    total_loss = 0
    model.train()
    for i, (x,y) in enumerate(zip(X_train, Y_train)):
        state_h, state_c = model.init_state(len(x))
        state_h = state_h.to('cuda:0')
        state_c = state_c.to('cuda:0')

        pred, (state_h, state_c) = model(torch.unsqueeze(x, 1), (state_h, state_c))

        condition = torch.tensor(1).to('cuda:0')
        loss = loss_fn(state_h[-1][0], y, condition)

        # randomly sample negative examples
        # should really do contrastive loss over the batch
        for i in range(5):
            neg_idx = random.randint(0,len(X_train))
            if neg_idx == i: continue

            x_neg = X_train[i]
            y_neg = Y_train[i]

            state_h, state_c = model.init_state(len(x_neg))
            state_h = state_h.to('cuda:0')
            state_c = state_c.to('cuda:0')

            pred, (state_h, state_c) = model(torch.unsqueeze(x_neg, 1), (state_h, state_c))

            condition = torch.tensor(0).to('cuda:0')
            loss_neg = loss_fn(state_h[-1][0], y_neg, condition)

            loss += loss_neg
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
    return total_loss / len(X_train)

In [13]:
def eval(X_test, Y_test, model, k=10):
    run = []
    model.eval()
    for i,x in enumerate(X_test):
        state_h, state_c = model.init_state(len(x))
        state_h = state_h.to('cuda:0')
        state_c = state_c.to('cuda:0')
        pred, (state_h, state_c) = model(torch.unsqueeze(x, 1), (state_h, state_c))

        encoded_query = state_h[-1][0]

        cos_scores = util.cos_sim(encoded_query, Y_test)[0]
        top_results = torch.topk(cos_scores, k=k)
        for j, (score, idx) in enumerate(zip(top_results[0], top_results[1])):
            # 0 Q0 0 1 193.457108 Anserini
            run.append(' '.join([str(i), 'Q0', str(idx.item()), str(j), str(score.item()), 'LSTM']))
    with open('run.bbc_lstm.txt', 'w') as f:
        for result in run:
            f.write(result + '\n')

In [14]:
for epoch in range(100):
    loss = train(X_train, Y_train, model)
    print(loss)

eval(X_test, Y_test, model)

0.9539169545173645
0.9204665031035741
0.9084327460130056
0.903541519443194
0.9044022781848907
0.9014742388725281
0.9004322565396626
0.8959063827196757
0.8899603281815847
0.8805830189387004
0.8779329530795416
0.880387656211853
0.87794398188591
0.8742574352025986
0.8677788877089818
0.8627379221916198
0.8600226390361786
0.8576407336393992
0.8586472557783127
0.8574633162021637
0.8585937035878499
0.8596939259767532
0.8556704318920771
0.855423396507899
0.8538799122571945
0.8562172104120255
0.8541538262367249
0.8541197048425675
0.8557214626471201
0.855997831026713
0.8526761809984843
0.8483020449876786
0.8483032863537471
0.8510729819933573
0.8491122448047003
0.8490739744504293
0.8490649414857229
0.8471759220759074
0.8456326934893926
0.8459368196328481
0.8458845191399257
0.8456139040788014
0.8423839308818182


In [1]:
!python3 -m pyserini.eval.trec_eval -m map ../../data/reddit/bbc_news_rel.txt run.bbc_lstm.txt

Downloading https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar to /home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar...
/home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar already exists!
Skipping download.
Running command: ['java', '-jar', '/home/mjin11/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-m', 'map', '../../data/reddit/bbc_news_rel.txt', 'run.bbc_lstm.txt']
Results:
map                   	all	0.0021

