<a href="https://colab.research.google.com/github/leonardo3108/IA368dd/blob/main/exercicios/Aula_6/Aula_6_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Assignment

* Train a seq2seq model (starting from T5-base) on the document expansion task using doc2query
* Use the "tiny" MS MARCO dataset as training data for the doc2query task
https://storage.googleapis.com/unicamp-dl/ia368dd_2023s1/msmarco/msmarco_triples.train.tiny.tsv
* doc2query: the input is the passage and the target is the query
Note that only (query, relevant passage) pairs are used for training.
Training is relatively fast (<1 hour).
* Validate every X steps using sacreBLEU
* The slow part of this exercise is pre-indexing: for each document in the collection, we have to generate one or more queries, which are then concatenated to the original document, and this "expanded" document is indexed.
* Evaluate on TREC-COVID (171K docs), since it is smaller than MS MARCO/TREC-DL 2020 (8.8M passages).
  * Inverted index for TREC-COVID in pyserini: beir-v1.0.0-trec-covid-flat
  * Corpus and queries on HF: https://huggingface.co/datasets/BeIR/trec-covid
  * qrels: https://huggingface.co/datasets/BeIR/trec-covid-qrels
  * Use nDCG@10
  * Compare BM25 with and without the documents expanded by doc2query
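
For reference, the nDCG@10 metric named in the list above can be sketched in plain Python. This is a minimal illustration, assuming the linear-gain formulation (gain equal to the relevance grade, log2 rank discount). The function names `dcg_at_k` and `ndcg_at_k` and the toy ids and grades are hypothetical; actual runs are better scored with an established tool such as trectools, which this notebook installs.

```python
import math

def dcg_at_k(gains, k=10):
    """Discounted cumulative gain over the first k gains (linear gain, log2 discount)."""
    return sum(g / math.log2(rank + 2) for rank, g in enumerate(gains[:k]))

def ndcg_at_k(ranked_doc_ids, qrels, k=10):
    """nDCG@k for a single query.

    ranked_doc_ids: doc ids in the order returned by the retriever.
    qrels: dict mapping doc id -> graded relevance (missing ids count as 0).
    """
    gains = [qrels.get(doc_id, 0) for doc_id in ranked_doc_ids]
    ideal = sorted(qrels.values(), reverse=True)
    ideal_dcg = dcg_at_k(ideal, k)
    return dcg_at_k(gains, k) / ideal_dcg if ideal_dcg > 0 else 0.0

# Toy example (made-up ids and grades, not TREC-COVID data)
toy_qrels = {"d1": 2, "d2": 1, "d3": 0}
perfect = ndcg_at_k(["d1", "d2", "d3"], toy_qrels)
swapped = ndcg_at_k(["d2", "d1", "d3"], toy_qrels)
print(perfect, swapped)
```

The per-query values are averaged over all queries to obtain the reported score, so the BM25 baseline and the doc2query-expanded run must be scored against the same qrels.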

# Setup

## Google Drive integration

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Installing libraries

In [2]:
!pip install transformers
!pip install datasets
!pip install pyserini
!pip install faiss-gpu
!pip install evaluate
!pip install sacrebleu
!pip install trectools

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.4-py3-none-any.whl (6.8 MB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.4-py3-none-any.whl (200 kB)
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.13.4 tokenizers-0.13.3 transformers-4.27.4
Looking in indexes: https://pypi.org/simple, htt

## Importing libraries

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import evaluate
import json
import torch
import os

from datasets import load_dataset
from pyserini.index import IndexReader
from pyserini.search import SimpleSearcher
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM, T5Config, AdamW, Adafactor, GenerationConfig
from pathlib import Path

## GPU usage

In [4]:
# Use the first GPU when available, otherwise fall back to the CPU
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
print(f'Using {device}')

Using cuda:0


In [5]:
if dev != 'cpu':
    !nvidia-smi

Wed Apr 12 12:48:59 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   43C    P8     9W /  70W |      3MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Model preparation

## Loading the tokenizer

In [6]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


## Loading the model

In [7]:
path = '/content/drive/MyDrive/temp'
model = T5ForConditionalGeneration.from_pretrained(path).to(device)

## Generation parameters

In [17]:
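# Deterministic beam search: 10 beams, keeping only the best sequence per document.
# No max_new_tokens is set, so outputs are capped at the library default
# max_length of 20 tokens (hence the 20-column generated_ids tensor further down).
generation_params = GenerationConfig(
    do_sample=False,
    num_beams=10,
    num_return_sequences=1
)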

# Data preparation

## Download

In [8]:
!wget https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/corpus.jsonl.gz

--2023-04-12 12:49:32--  https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/corpus.jsonl.gz
Resolving huggingface.co (huggingface.co)... 18.172.170.44, 18.172.170.14, 18.172.170.36, ...
Connecting to huggingface.co (huggingface.co)|18.172.170.44|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/a8/10/a810e88b0e7b233be82b89c1fa6ec2d75efc6d55784c2ada9dcac8434a634f3a/e9e97686e3138eaff989f67c04cd32e8f8f4c0d4857187e3f180275b23e24e85?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27corpus.jsonl.gz%3B+filename%3D%22corpus.jsonl.gz%22%3B&response-content-type=application%2Fgzip&Expires=1681562972&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly9jZG4tbGZzLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2E4LzEwL2E4MTBlODhiMGU3YjIzM2JlODJiODljMWZhNmVjMmQ3NWVmYzZkNTU3ODRjMmFkYTlkY2FjODQzNGE2MzRmM2EvZTllOTc2ODZlMzEzOGVhZmY5ODlmNjdjMDRjZDMyZThmOGY0YzBkNDg1NzE4N2UzZjE4MDI3NWIyM2UyNGU4NT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9u

In [9]:
!gzip -dv corpus.jsonl.gz

corpus.jsonl.gz:	 66.8% -- replaced with corpus.jsonl


In [10]:
!head corpus.jsonl

{"_id": "ug7v899j", "title": "Clinical features of culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia", "text": "OBJECTIVE: This retrospective chart review describes the epidemiology and clinical features of 40 patients with culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia. METHODS: Patients with positive M. pneumoniae cultures from respiratory specimens from January 1997 through December 1998 were identified through the Microbiology records. Charts of patients were reviewed. RESULTS: 40 patients were identified, 33 (82.5%) of whom required admission. Most infections (92.5%) were community-acquired. The infection affected all age groups but was most common in infants (32.5%) and pre-school children (22.5%). It occurred year-round but was most common in the fall (35%) and spring (30%). More than three-quarters of patients (77.5%) had comorbidities. Twenty-four isolates (60

## Extracting the texts

In [11]:
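texts = []
# The context manager closes the file once the corpus has been read
with open('corpus.jsonl', 'r', encoding='utf-8') as fin:
    for line in fin:
        doc_data = json.loads(line)
        # Title and body are joined into a single input passage
        texts.append(doc_data['title'] + '\n' + doc_data['text'])
len(texts)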

171332
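
The loop above keeps only the text, but the later indexing step needs each document's `_id` so that a generated query can be paired with its document. A minimal sketch that keeps ids and texts in parallel lists is shown below. The helper name `load_corpus` is hypothetical; the field names `_id`, `title`, and `text` are the ones visible in the `head corpus.jsonl` output.

```python
import json

def load_corpus(path):
    """Read a BEIR-style JSONL corpus into parallel lists of ids and texts."""
    doc_ids, texts = [], []
    with open(path, "r", encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:  # tolerate blank lines
                continue
            doc = json.loads(line)
            doc_ids.append(doc["_id"])
            # Same passage format as above: title, newline, body
            texts.append(doc["title"] + "\n" + doc["text"])
    return doc_ids, texts
```

It would be called as `doc_ids, texts = load_corpus('corpus.jsonl')`, after which `doc_ids[i]` identifies the document whose passage is `texts[i]`.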

# Running the model

## Pilot with 16 documents

In [57]:
tokenized_inputs = tokenizer(texts[:16], return_tensors = "pt", max_length = 256, padding = "max_length", truncation = True).to(device)
tokenized_inputs

{'input_ids': tensor([[14067,   753,    13,  ...,     5,  2712,     1],
        [ 2504,  3929, 21491,  ...,     0,     0,     0],
        [ 3705,  8717,   288,  ..., 24613,    23,     1],
        ...,
        [    3, 31334,    23,  ...,    32,  5529,     1],
        [   37,   353,    13,  ...,     0,     0,     0],
        [ 4908,    18, 28842,  ...,   224,   608,     1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0')}

In [58]:
generated_ids = model.generate(
    input_ids = tokenized_inputs["input_ids"], 
    attention_mask = tokenized_inputs["attention_mask"], 
    generation_config = generation_params
)
generated_ids

tensor([[    0,   125,    33,     8,  3739,   753,    13,    82,   509, 21178,
             9, 30195,    15,     1,     0,     0,     0,     0,     0,     0],
        [    0,    19,     3,    29,    23,  3929, 21491,     3,     9,   813,
            18, 15329,  3102,     1,     0,     0,     0,     0,     0,     0],
        [    0,   125,    19,     8,  1750,   344,   244,  8717,   288,  3619,
             3,    26,    11,     3, 26836,  2290,  4453,     1,     0,     0],
        [    0,   125,    19,   414,    32,   532,    40,    77,  2292,     1,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [    0,   125,    19,     8,  1773,    13, 19944,     3,     7,    63,
         11298, 10646,  6722,    41,  5249,   553,    61,    11, 30195,  6722],
        [    0,   125,    19,     8,  5932,   831,    21, 26950,  1162,     3,
         17282,    13,     3,    29,    23,    26,    32, 18095,     3,    51],
        [    0,   405,  3017,    89,  9381,   

In [59]:
for ids in generated_ids:
    print(tokenizer.decode(ids, skip_special_tokens=True))

what are the clinical features of mycoplasma pneumoniae
is nitric oxide a pro-inflammatory agent
what is the difference between surfactant protein d and pulmonary host defense
what is endothelin-1
what is the response of respiratory syncytial virus (RSV) and pneumonia virus
what is the sequence required for discontinuous synthesis of nidovirus m
does transfusing to normal haemoglobin improve survival
what was the theme of the 21st international conference on icm
what does heme oxygenase-1 do
what is a rods system
scizosaccharomyces pombe frameshift function
hnrnp a1 regulates mRNA synthesis
what is p62's uba domain
what is the role of the microtubule cytoskeleton in vaccinia
what was the site of origin of the 1918 influenza pandemic
what is a multi-virus array


## Full run

In [45]:
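# Back up the queries from a previous run; on a fresh kernel there is nothing to copy yet
queries = generated_queries.copy() if 'generated_queries' in globals() else []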

In [69]:
batch_size = 16
generated_queries = []

total = len(texts)
first_id = 0
batches = 0
while first_id < total:
    last_id = first_id + batch_size
    # Tokenize one batch of documents, truncating each to 256 tokens
    tokenized_inputs = tokenizer(texts[first_id:last_id], return_tensors = "pt", max_length = 256, padding = "max_length", truncation = True).to(device)
    generated_ids = model.generate(
        input_ids = tokenized_inputs["input_ids"],
        attention_mask = tokenized_inputs["attention_mask"],
        generation_config = generation_params
    )
    # Decode the whole batch at once, dropping pad and EOS tokens
    generated_queries.extend(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
    first_id = last_id
    batches += 1
    # Report progress every 100 batches
    if batches % 100 == 0:
        print(f'{100 * first_id / total:.4f}% of corpus processed.')
print('Total:', len(generated_queries))

0.9339% of corpus processed.
1.8677% of corpus processed.
2.8016% of corpus processed.
3.7354% of corpus processed.
4.6693% of corpus processed.
5.6032% of corpus processed.
6.5370% of corpus processed.
7.4709% of corpus processed.
8.4047% of corpus processed.
9.3386% of corpus processed.
10.2725% of corpus processed.
11.2063% of corpus processed.
12.1402% of corpus processed.
13.0740% of corpus processed.
14.0079% of corpus processed.
14.9418% of corpus processed.
15.8756% of corpus processed.
16.8095% of corpus processed.
17.7433% of corpus processed.
18.6772% of corpus processed.
19.6110% of corpus processed.
20.5449% of corpus processed.
21.4788% of corpus processed.
22.4126% of corpus processed.
23.3465% of corpus processed.
24.2803% of corpus processed.
25.2142% of corpus processed.
26.1481% of corpus processed.
27.0819% of corpus processed.
28.0158% of corpus processed.
28.9496% of corpus processed.
29.8835% of corpus processed.
30.8174% of corpus processed.
31.7512% of corpus p
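
Generation over the whole corpus is the slow step, and a Colab session that disconnects partway through loses everything held in `generated_queries`. One hedged way to guard against that is to append each batch's queries to a file and resume from the number of lines already written. The sketch below is not what this notebook does: `generate_batch` is a hypothetical callable standing in for the tokenize, generate, and decode steps above, and the helper name and output path are assumptions.

```python
import os

def generate_with_resume(texts, generate_batch, out_path, batch_size=16):
    """Generate one query per text, appending to out_path so a rerun resumes.

    generate_batch: callable taking a list of texts and returning a list of
    queries of the same length (hypothetical stand-in for the model call).
    """
    done = 0
    if os.path.exists(out_path):
        with open(out_path, "r", encoding="utf-8") as fin:
            done = sum(1 for _ in fin)  # one saved query per line

    with open(out_path, "a", encoding="utf-8") as fout:
        for start in range(done, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            for query in generate_batch(batch):
                # Flatten line breaks to keep one line per document
                fout.write(" ".join(query.splitlines()) + "\n")
            fout.flush()  # push the batch out before starting the next one

    # Return everything saved so far, in corpus order
    with open(out_path, "r", encoding="utf-8") as fin:
        return [line.rstrip("\n") for line in fin]
```

Because the resume point is the saved line count, the file must contain only queries produced for this same `texts` list in this same order.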

## Saving the queries

In [70]:
generated_queries[:20]

['what are the clinical features of mycoplasma pneumoniae',
 'is nitric oxide a pro-inflammatory agent',
 'what is the difference between surfactant protein d and pulmonary host defense',
 'what is endothelin-1',
 'what is the response of respiratory syncytial virus (RSV) and pneumonia virus',
 'what is the sequence required for discontinuous synthesis of nidovirus m',
 'does transfusing to normal haemoglobin improve survival',
 'what was the theme of the 21st international conference on icm',
 'what does heme oxygenase-1 do',
 'what is a rods system',
 'scizosaccharomyces pombe frameshift function',
 'hnrnp a1 regulates mRNA synthesis',
 "what is p62's uba domain",
 'what is the role of the microtubule cytoskeleton in vaccinia',
 'what was the site of origin of the 1918 influenza pandemic',
 'what is a multi-virus array',
 'pathogenicity of herpes simplex virus type 1 in critically ill patients',
 'logistics of community smallpox control',
 'what is hmyh adenine glycosylase',
 'what i

In [75]:
path = '/content/drive/MyDrive/temp'

# One query per line, in corpus order; explicit UTF-8 avoids platform-dependent encodings
with open(path + '/generated_queries.txt', 'w', encoding='utf-8') as fout:
    for query in generated_queries:
        fout.write(query + '\n')

In [72]:
len(generated_queries)

171332
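
With one query per document, the expansion step described in the assignment appends the query to the original text before indexing. A minimal sketch follows, assuming Pyserini's JSON collection layout of one object per line with `id` and `contents` fields. The helper name `write_expanded_corpus` and the output file name are hypothetical, and the function expects a list of document ids collected in the same order as `texts`.

```python
import json

def write_expanded_corpus(doc_ids, texts, queries, out_path):
    """Write documents as JSONL, each with its generated query appended."""
    if not (len(doc_ids) == len(texts) == len(queries)):
        raise ValueError("doc_ids, texts and queries must be aligned")
    with open(out_path, "w", encoding="utf-8") as fout:
        for doc_id, text, query in zip(doc_ids, texts, queries):
            # Expanded document = original passage followed by the generated query
            record = {"id": doc_id, "contents": text + "\n" + query}
            fout.write(json.dumps(record, ensure_ascii=False) + "\n")
    return len(doc_ids)
```

The resulting file could then be indexed as a JSON collection and searched with BM25, so that its nDCG@10 can be compared against the prebuilt `beir-v1.0.0-trec-covid-flat` index.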