# Projeto final IA376E - 2020
# Título: Busca utilizando vetores densos

Nomes: Rafael Gonçalves, Thomas Portugal

## Semana 3/4: Criação da pipeline de treinamento e validação

## Configuração geral

Configuração de ambiente

In [1]:
use_cuda = True
pretrained_model = 'bert-base-uncased'
seed = 0 # reproducible

In [2]:
import os
import torch

device = 'cuda' if (torch.cuda.is_available() and use_cuda) else 'cpu'
nproc = os.cpu_count()

nproc, torch.cuda.get_device_name() if device == "cuda" else None

(2, 'Tesla P4')

Montar o drive para download de dados

In [3]:
# Mount drive
from google.colab import drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Importar funções e bibliotecas

In [4]:
%%capture
!wget -nc https://raw.githubusercontent.com/spacemanidol/MSMARCO/master/Ranking/Baselines/msmarco_eval.py
!wget -nc https://raw.githubusercontent.com/RafaelGoncalves8/search-with-dense-vectors/master/src/helpers.py
!wget -nc https://raw.githubusercontent.com/RafaelGoncalves8/search-with-dense-vectors/master/src/dataset.py
!wget -nc https://raw.githubusercontent.com/RafaelGoncalves8/search-with-dense-vectors/master/src/modules.py

In [5]:
! pip install --quiet pytorch-lightning
! pip install --quiet transformers

In [6]:
import msmarco_eval
from helpers import *
from dataset import MyDataset
from modules import CosineSimilarityLoss, Encoder

In [7]:
# lightning
import pytorch_lightning as pl

# transformers
from transformers import BertTokenizer
from transformers import BertModel

# torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# others
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from argparse import ArgumentParser

Outras configurações:

In [8]:
# Pseudo-deterministic code
if seed is not None:
    np.random.seed(seed)
    torch.manual_seed(seed)

In [9]:
# matplotlib style
plt.rcParams['figure.figsize'] = (12,8)
plt.style.use('ggplot')

In [10]:
data_dir='drive/My Drive/msmarco'
docs_path = data_dir + '/collection.tsv'
qrels_train_path = data_dir + '/qrels.train.tsv'
qrels_dev_path = data_dir + '/qrels.dev.small.tsv'
queries_train_path = data_dir + '/queries.train.tsv'
queries_dev_path = data_dir + '/queries.dev.tsv'
doc2queries_path = data_dir + '/msmarco-passage/pred-test_topk10.txt'
triples_path = data_dir + '/triples/qidpidtriples.train.full.tsv'
top1000_path = data_dir +'/top1000.dev'

## Baixando os top1000 documentos para cada query

In [11]:
DATA_DIR='/content/drive/My\ Drive/msmarco'

In [12]:
# %%capture
# !wget -nc https://msmarco.blob.core.windows.net/msmarcoranking/top1000.train.tar.gz -P {DATA_DIR}

In [13]:
# %%capture
# !wget -nc https://msmarco.blob.core.windows.net/msmarcoranking/top1000.dev.tar.gz -P {DATA_DIR}

In [14]:
# !tar xvkfz {DATA_DIR}/top1000.dev.tar.gz -C {DATA_DIR}

In [15]:
!head {DATA_DIR + '/top1000.dev'}

188714	1000052	foods and supplements to lower blood sugar	Watch portion sizes: ■ Even healthy foods will cause high blood sugar if you eat too much. ■ Make sure each of your meals has the same amount of CHOs. Avoid foods high in sugar: ■ Some foods to avoid: sugar, honey, candies, syrup, cakes, cookies, regular soda and.
1082792	1000084	what does the golgi apparatus do to the proteins and lipids once they arrive ?	Start studying Bonding, Carbs, Proteins, Lipids. Learn vocabulary, terms, and more with flashcards, games, and other study tools.
995526	1000094	where is the federal penitentiary in ind	It takes THOUSANDS of Macy's associates to bring the MAGIC of MACY'S to LIFE! Our associate team is an invaluable part of who we are and what we do. F ind the seasonal job that's right for you at holiday.macysJOBS.com!
199776	1000115	health benefits of eating vegetarian	The good news is that you will discover what goes into action spurs narrowing of these foods not only a theoretical suppos

## Dataset de validação

In [16]:
class ValDataset(Dataset):
    def __init__(self, queries, docs, tokenizer, max_length, queries_1000):
        self.queries = queries
        self.docs = docs
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.val_data = [(q, doc_ids) for q, doc_ids in queries_1000.items()]

    def __len__(self):
        return len(self.val_data)

    def __getitem__(self, idx):
        qid, doc_ids = self.val_data[idx]
        query = self.queries[qid]
        q_tok, q_mask, q_type = self.tokenize(query)
        docs2query = [self.docs[doc_id] for doc_id in doc_ids]
        tokens = self.tokenizer.batch_encode_plus(batch_text_or_text_pairs=docs2query, max_length=self.max_length,
                                       pad_to_max_length=True, add_special_tokens = True,
                                       return_tensors='pt')
        docs_tok = tokens["input_ids"]
        docs_mask = tokens["attention_mask"]
        docs_type = tokens["token_type_ids"]

        return  (q_tok, q_mask, q_type, docs_tok, docs_mask, docs_type, qid, doc_ids)

    def tokenize(self, text):
        tokens = self.tokenizer.encode_plus(text=text, max_length=self.max_length,
                                       pad_to_max_length=True, add_special_tokens = True,
                                       return_tensors='pt')
        return tokens["input_ids"], tokens['attention_mask'], tokens['token_type_ids']

In [17]:
def correct_docids(doc_ids):
  batch_size = len(doc_ids[0])
  doc_ids_correct = []
  for idx in range(batch_size):
    doc_id_idx = [int(doc_id[idx]) for doc_id in doc_ids]
    doc_ids_correct.append(doc_id_idx)
  
  return doc_ids_correct

In [18]:
def load_top1000_dev(path):
  queries = {}
  with open(path) as f:
    for i, line in enumerate(f):
      
      qid, pid, _, _ = line.rstrip().split('\t')
      if qid in queries.keys():
        queries[qid].append(pid)
      else:
        queries[qid] = [pid]
      if i % 1000000 == 0:
        print('Loading top1000, doc {}'.format(i))
  return queries

## Testando somente a etapa de validação

In [19]:
max_length = 32
tokenizer = BertTokenizer.from_pretrained(pretrained_model)
#qrels = load_qrels(qrels_dev_path)
docs_queries = load_doc2query(docs_path)
queries_top1000 = load_top1000_dev(top1000_path)

Loading doc2query, doc 0
Loading doc2query, doc 1000000
Loading doc2query, doc 2000000
Loading doc2query, doc 3000000
Loading doc2query, doc 4000000
Loading doc2query, doc 5000000
Loading doc2query, doc 6000000
Loading doc2query, doc 7000000
Loading doc2query, doc 8000000
Loading top1000, doc 0
Loading top1000, doc 1000000
Loading top1000, doc 2000000
Loading top1000, doc 3000000
Loading top1000, doc 4000000
Loading top1000, doc 5000000
Loading top1000, doc 6000000


In [20]:
queries = load_queries(queries_dev_path)

Loading queries 0
Loading queries 100000


In [21]:
len(queries_top1000[str(188714)])

1000

In [22]:
encoder = Encoder(128)

In [23]:
val_dataset  = ValDataset(queries,docs_queries,tokenizer, max_length,queries_top1000)

In [24]:
list(queries_top1000.items())[0]

('188714',
 ['1000052',
  '1022490',
  '1051362',
  '1090909',
  '1097195',
  '1130809',
  '1148823',
  '1160714',
  '1200661',
  '1241103',
  '1500656',
  '1656996',
  '1687849',
  '1751526',
  '1831803',
  '1899272',
  '1996465',
  '2133570',
  '2141476',
  '2165254',
  '2287600',
  '2455992',
  '2456320',
  '2470190',
  '2517166',
  '2666492',
  '2788357',
  '3045191',
  '3045313',
  '3045314',
  '3054173',
  '3125534',
  '3125543',
  '3172505',
  '3174245',
  '3181392',
  '3241451',
  '3483672',
  '3548230',
  '3615882',
  '3719372',
  '3849449',
  '386069',
  '3928996',
  '4057562',
  '4057564',
  '4074394',
  '4283881',
  '4321739',
  '4321742',
  '4405570',
  '443989',
  '454883',
  '454884',
  '4585316',
  '46209',
  '471210',
  '4963488',
  '4963491',
  '4963495',
  '496432',
  '5007414',
  '5072224',
  '507420',
  '5092333',
  '523571',
  '5499899',
  '5499901',
  '5548810',
  '5614247',
  '5624275',
  '5631262',
  '5681536',
  '5703711',
  '5734858',
  '5804838',
  '5820625'

In [25]:
batch_size_doc = 5
dataloader = DataLoader(val_dataset,batch_size=batch_size_doc, shuffle=False)

In [26]:
q_tok, q_mask, q_type, docs_tok, docs_mask, docs_type, qid, doc_ids = next(iter(dataloader))

docs_tok = docs_tok.view((-1,max_length))
docs_mask = docs_mask.view((-1,max_length))
docs_type = docs_type.view((-1,max_length))
q_tok = q_tok.squeeze(-2)
q_mask = q_mask.squeeze(-2)
q_type = q_type.squeeze(-2)
doc_ids = torch.tensor(correct_docids(doc_ids)) #Corrige a ordem dos doc_ids

In [27]:
with torch.no_grad():
    query_embedding = encoder(q_tok,q_mask,q_type)
    doc_embedding = encoder(docs_tok,docs_mask,docs_type)
doc_embedding = doc_embedding.view((batch_size_doc,-1,128))

In [28]:
scores = nn.CosineSimilarity()(query_embedding.unsqueeze(1), doc_embedding); scores

tensor([[ 23.1358,  25.2877,  31.3584,  -7.3632,  23.2295,  24.3620,  30.9032,
          22.9093,   6.8727,  29.1552,  20.8786,  13.0957,  30.8971,  21.5632,
          30.6178,  21.8084,  25.8451,  25.1448,   9.2291,  29.6068,  29.0584,
          29.1164, -21.2582,  30.4093,  24.9229,   4.4423,  30.3647,  30.2558,
          26.1021,  -1.3349,  29.2834,  28.9829,  10.1251,  29.1005,  30.3842,
          30.7510,  29.3160,  29.0283,  18.2078,  29.0511,  30.1717,  30.9839,
          30.5954,  22.8455,  30.8748,  -4.8363,  21.4290,  19.0206,  30.7041,
          22.9558,  23.0113,  21.1528,  16.8043,  28.1179,  14.2555,  29.9271,
         -10.3083,  30.5158, -30.0086,  30.7496,  26.8493,  27.3076,  22.6187,
          27.9301,  -1.0920,  28.8284,  31.2186,  29.4756,  13.6048,   2.8503,
           3.9909,  22.6241,  29.8577, -18.7712,  29.6831,  31.1987,  25.1523,
           7.1178,  23.6047,  16.5981,  15.7792,  27.4297,  29.0860,  -8.5250,
           8.5053, -14.1449,   6.2515,   8.4176,  24

In [29]:
del encoder, doc_embedding, query_embedding

In [30]:
_, indices = torch.sort(scores)
doc_ids_sorted = []
for i, doc_id in enumerate(doc_ids):
    doc_ids_sorted.append(doc_id[indices[i]].tolist())

qrel = {}
for i, doc_ids in enumerate(doc_ids_sorted):
    qrel[int(qid[i])] = doc_ids
qrel[int(qid[0])]

[4963491,
 2456320,
 5703711,
 6416264,
 7562875,
 1492699,
 471210,
 6254756,
 71106,
 1090909,
 4057564,
 7132477,
 3045314,
 5092333,
 1031308,
 5614247,
 5624275,
 2666492,
 6420874,
 1200661,
 6039084,
 6507763,
 6301922,
 2141476,
 3125543,
 1656996,
 5548810,
 4585316,
 7889383,
 6172028,
 6167786,
 6637468,
 454883,
 3548230,
 1471691,
 4283881,
 1500656,
 443989,
 6947936,
 4074394,
 1751526,
 8820189,
 1899272,
 8351004,
 5072224,
 5631262,
 3928996,
 1160714,
 4321742,
 4405570,
 1000052,
 1097195,
 92890,
 6043639,
 8004932,
 1901494,
 1130809,
 6555516,
 2517166,
 701315,
 2133570,
 5820625,
 1022490,
 7336145,
 6759289,
 1996465,
 3045313,
 1130808,
 496432,
 6899201,
 6771762,
 5007414,
 6176398,
 507420,
 1472482,
 6650234,
 454884,
 7794099,
 1476382,
 734216,
 7278828,
 523571,
 8090379,
 3125534,
 3483672,
 3615882,
 2287600,
 6192043,
 3172505,
 2455992,
 1241103,
 3054173,
 3241451,
 8703759,
 5499901,
 1796938,
 2165254,
 5734858,
 5681536,
 46209,
 6849865,
 3719

In [31]:
del docs_queries, queries_top1000, val_dataset, dataloader

## Configuração do modelo

In [32]:
class TwoTower(pl.LightningModule):
    """Two tower model for document retrieval using query and document."""
    def __init__(self,
                 learning_rate=1e-3,
                 batch_size=5,
                 dim=64,
                 epochs = None,
                 max_length=32,
                 shuffle=True,
                 queries_train_path = queries_train_path,
                 queries_dev_path = queries_dev_path,
                 docs_path = doc2queries_path,
                 triples_path = triples_path,
                 top1000_dev_path = top1000_path,
                 qrels_dev_path = qrels_dev_path,
                 batch_size_val = 5
                 ):
        super(TwoTower, self).__init__()

        #Data
        self.queries_train = load_queries(queries_train_path)
        self.queries_dev = load_queries(queries_dev_path)
        self.docs_queries = load_doc2query(docs_path)
        self.triples = load_triple(triples_path, 10_000)
        self.queries_top1000 = load_top1000_dev(top1000_path)
        qrels = load_qrels(qrels_dev_path)
        self.qrels = {int(qid):docids for qid,docids in qrels.items()}

        self.embedding_dim = dim
        self.vectors = [] 

        # Configuration
        self.learning_rate = learning_rate
        self.batch_size=batch_size
        self.batch_size_val = batch_size_val
        self.max_length = max_length
        self.loss = CosineSimilarityLoss()
        self.similarity = nn.CosineSimilarity(dim=-1, eps=1e-08)
        self.optimizer = torch.optim.Adam
        self.opt_params = {'lr': learning_rate, 'eps': 1e-08, 'betas': (0.9, 0.999)}
        self.warmup_steps = 500
        self.training_steps = (epochs*len(self.triples))/batch_size if epochs != None else None
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model)


        self.shuffle=shuffle
        if use_cuda:
            self.pin_mem = True
            self.n_workers = 0
        else:
            self.pin_mem = False
            self.n_workers = nproc

        # Models
        self.query_encoder = Encoder(dim)
        self.doc_encoder = Encoder(dim)

    def prepare_data(self):
        
        self.train_data = MyDataset(
                triples = self.triples,
                queries = self.queries_train,
                docs = self.docs_queries,
                max_length = self.max_length,
                tokenizer=self.tokenizer)
        
        self.valid_data = ValDataset(
                queries = self.queries_dev,
                docs = self.docs_queries,
                tokenizer = self.tokenizer,
                max_length = self.max_length,
                queries_1000 = self. queries_top1000)
    
    
        

    @gpu_mem_restore
    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size,
                          shuffle=self.shuffle, num_workers=self.n_workers,
                          pin_memory=self.pin_mem)

    @gpu_mem_restore
    def val_dataloader(self):
        return DataLoader(self.valid_data, batch_size=self.batch_size_val,
                          shuffle=False, num_workers=self.n_workers,
                          pin_memory=self.pin_mem)
    
    def configure_optimizers(self):
        optimizer = self.optimizer(
            [e for e in self.parameters() if e.requires_grad],
            **self.opt_params)
        return optimizer

    # learning rate warm-up and linear scheduler
    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure=None):
        # warm up lr
        if self.trainer.global_step < self.warmup_steps:
            lr_scale = min(1., float(self.trainer.global_step + 1) / float(self.warmup_steps))
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * self.learning_rate
        # linear decrease schedule
        else:
            if self.training_steps != None:
                lr_scale = max(0.0,
                    float(self.training_steps - self.trainer.global_step) / float(max(1,
                                                    self.training_steps - self.warmup_steps)))
                for pg in optimizer.param_groups:
                    pg['lr'] = lr_scale * self.learning_rate
        # update params
        optimizer.step()
        optimizer.zero_grad()

    def forward(self,batch):
        if self.training:
          q_tok, q_mask, q_type, p_tok, p_mask, p_type, n_tok, n_mask, n_type, query, doc_pos, doc_neg = batch
          
          query_embedding = self.query_encoder(q_tok.squeeze(-2),q_mask.squeeze(-2),q_type.squeeze(-2))
          p_doc_embedding = self.doc_encoder(p_tok.squeeze(-2),p_mask.squeeze(-2),p_type.squeeze(-2))
          n_doc_embedding = self.doc_encoder(n_tok.squeeze(-2),n_mask.squeeze(-2),n_type.squeeze(-2))
          
          loss, sim_pos, sim_neg = self.loss(query_embedding, p_doc_embedding, n_doc_embedding) 
          
          return loss.mean(), sim_pos, sim_neg

        else:
          qrel = {}
          q_tok, q_mask, q_type, docs_tok, docs_mask, docs_type, qid, doc_ids = batch
          
          #Corrige as dimensões dos elementos. 
          docs_tok = docs_tok.view((-1,self.max_length))
          docs_mask = docs_mask.view((-1,self.max_length))
          docs_type = docs_type.view((-1,self.max_length))
          q_tok = q_tok.squeeze(-2)
          q_mask = q_mask.squeeze(-2)
          q_type = q_type.squeeze(-2)
          doc_ids = torch.tensor(correct_docids(doc_ids)) #Corrige a ordem dos doc_ids
          
          query_embedding = self.query_encoder(q_tok,q_mask,q_type)
          
          doc_embedding = self.doc_encoder(docs_tok,docs_mask,docs_type)
          doc_embedding = doc_embedding.view((self.batch_size_val,-1,self.embedding_dim))
          
          scores = self.similarity(query_embedding.unsqueeze(1), doc_embedding)
          
          _, indices = torch.sort(scores)

          doc_ids_sorted = []
          for i, doc_id in enumerate(doc_ids):
            doc_ids_sorted.append(doc_id[indices[i]].tolist())
          
          qrel = {}
          for i, doc_ids in enumerate(doc_ids_sorted):
            qrel[int(qid[i])] = doc_ids

          return qrel

    def _step(self, prefix, batch, batch_nb):

        if prefix == 'train':
            loss, _, _ = self(batch)
            log = {f'{prefix}_loss': loss}
            return {'loss': loss, 'log': log,}

        elif(prefix=='val'):
            qrel_pred = self(batch)
            mrr_dict = msmarco_eval.compute_metrics(self.qrels,qrel_pred)

            mrr = mrr_dict['MRR @10']
            n_queries = mrr_dict['QueriesRanked']

            log = {f'{prefix}_mrr': mrr}
            return {'mrr': mrr, 'log': log}

    def training_step(self, batch, batch_idx):
        return self._step("train", batch, batch_idx)

    def validation_step(self, batch, batch_idx):
        return self._step("val", batch, batch_idx)

    def test_step(self, batch, batch_idx):
        return self._step("test", batch, batch_idx)

    def _epoch_end(self, prefix, outputs):   
        if not outputs:
            return {}
        acc_mean = 0
        loss_mean = 0

        if prefix == 'train':
          loss_mean = torch.mean(torch.tensor([out["loss"] for out in outputs]))
          log = {
            f'{prefix}_loss': loss_mean
            } 
        if prefix == 'val':
          mrr_mean = torch.mean(torch.tensor([out["mrr"] for out in outputs], dtype=torch.float)) 
          log = {
            f'{prefix}_mrr': mrr_mean,
            } 
        


        return {'progress_bar': log, 'log': log}

    def get_query_encoder(self):
        return self.query_encoder

    def get_doc_encoder(self):
        return self.doc_encoder
    
    def training_epoch_end(self, outputs):
        return self._epoch_end("train", outputs)

    def validation_epoch_end(self, outputs):
        return self._epoch_end("val", outputs)

    def test_epoch_end(self, outputs):
        return self._epoch_end("test", outputs)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', type=int, default=1e-3)
        parser.add_argument('--batch_size', type=int, default=64)
        parser.add_argument('--dim', type=int, default=64)
        parser.add_argument('--epochs', type=int, default=None)
        parser.add_argument('--max_lenght', type=int, default=32)
        parser.add_argument('--shuffle', type=bool, default=True)
        parser.add_argument('--queries-path', type=str)
        parser.add_argument('--docs-path', type=str)
        parser.add_argument('--triples-path', type=str)
        return parser

In [33]:
model = TwoTower(batch_size_val = 2)
trainer = pl.Trainer(gpus=1 if use_cuda else 0, 
                    logger=False,
                    checkpoint_callback=False,
                    fast_dev_run=True,
                    weights_summary=None,)
trainer.fit(model)

del model, trainer

Loading queries 0
Loading queries 100000
Loading queries 200000
Loading queries 300000
Loading queries 400000
Loading queries 500000
Loading queries 600000
Loading queries 700000
Loading queries 800000
Loading queries 0
Loading queries 100000
Loading doc2query, doc 0
Loading doc2query, doc 1000000
Loading doc2query, doc 2000000
Loading doc2query, doc 3000000
Loading doc2query, doc 4000000
Loading doc2query, doc 5000000
Loading doc2query, doc 6000000
Loading doc2query, doc 7000000
Loading doc2query, doc 8000000
Loading triple 0
Loading top1000, doc 0
Loading top1000, doc 1000000
Loading top1000, doc 2000000
Loading top1000, doc 3000000
Loading top1000, doc 4000000
Loading top1000, doc 5000000
Loading top1000, doc 6000000
Loading qrels 0


Running in fast_dev_run mode: will run a full train, val and test loop using a single batch
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


