<a href="https://colab.research.google.com/github/nvaikunt/PromptBasedReranking/blob/main/BaselineTest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/nvaikunt/PromptBasedReranking.git

In [None]:
!pip install -q condacolab
import condacolab
condacolab.install()

In [None]:
!conda --version

In [None]:
%cd PromptBasedReranking/

In [None]:
!bash setup.sh

In [11]:
import torch
torch.cuda.is_available()

True

In [None]:
!python Baseline/download_dpr_data.py --resource data.wikipedia-split.psgs_w100

In [None]:
!python Baseline/download_dpr_data.py --resource data.retriever-outputs.dpr.nq-train

In [None]:

import torch
from torch.utils.data import Dataset, DataLoader
import transformers
import tokenizers
import datasets 
import pandas

# Load the downloaded passage TSV through the generic "csv" loader, using a tab
# as the field delimiter. Later cells index wikipedia_txt["train"] by row position.
wikipedia_txt = datasets.load_dataset("csv", data_files='downloads/data/wikipedia-split/psgs_w100.tsv', delimiter='\t')


In [None]:
wikipedia_txt

In [None]:
# Load the downloaded retriever output as a JSON dataset. Later cells read the
# "question" and "ctxs" fields of each row (each ctx carrying "id" and "has_answer").
nq_open = datasets.load_dataset("json",data_files="downloads/data/retriever-outputs/dpr/nq-train.json")

In [None]:
nq_open["train"][2]["question"]

In [14]:
def get_top_k_pos(row, k, txt_database):
  """Gather ``k`` answer-bearing passages for one retrieval row.

  Contexts in ``row["ctxs"]`` are visited in order. For every context whose
  ``has_answer`` flag is truthy, the passage text is read from
  ``txt_database[ctx["id"] - 1]["text"]`` and stored as ``(text, "true")``.
  Scanning stops as soon as ``k`` passages have been collected.

  If fewer than ``k`` passages were found, the list is padded by repeating
  entries from its own front until it reaches length ``k``.

  Args:
    row: mapping with a ``"ctxs"`` list of context dicts.
    k: number of passages to return.
    txt_database: indexable collection of rows exposing a ``"text"`` field.
      NOTE(review): the ``id - 1`` lookup assumes integer, 1-based ids that
      match row positions -- confirm against the passage file.

  Returns:
    A list of ``(text, "true")`` tuples, or ``[]`` when no context has the
    answer.
  """
  positives = []
  for candidate in row["ctxs"]:
    if candidate["has_answer"]:
      passage = txt_database[candidate["id"] - 1]["text"]
      positives.append((passage, "true"))
    if len(positives) == k:
      break

  if not positives:
    return []

  # Pad by cycling through the passages already collected.
  shortfall = k - len(positives)
  while shortfall > 0:
    positives += positives[:shortfall]
    shortfall = k - len(positives)

  return positives[:k]
  
def get_top_k_pos_neg(row, k, txt_database):
  """Gather ``k`` positive and ``k`` negative passages for one retrieval row.

  Contexts in ``row["ctxs"]`` are visited in order. Contexts whose
  ``has_answer`` flag is truthy fill the positive bucket (label ``"true"``),
  all others fill the negative bucket (label ``"false"``). Each bucket holds
  at most ``k`` entries and scanning stops once BOTH buckets are full.

  A bucket that ends up with fewer than ``k`` entries is padded by repeating
  entries from its own front.

  Args:
    row: mapping with a ``"ctxs"`` list of context dicts.
    k: number of passages wanted per class.
    txt_database: indexable collection of rows exposing a ``"text"`` field.
      NOTE(review): the ``id - 1`` lookup assumes integer, 1-based ids that
      match row positions -- confirm against the passage file.

  Returns:
    ``k`` ``(text, "true")`` tuples followed by ``k`` ``(text, "false")``
    tuples. Returns ``[]`` when either class has no passage at all, because a
    balanced positive/negative list cannot be built (an empty bucket can never
    be padded from itself).
  """
  top_k_pos = []
  top_k_neg = []
  for ctx in row["ctxs"]:
    # Stop only when both classes are complete; stopping as soon as a single
    # negative exists would leave the negatives as copies of one passage.
    if len(top_k_pos) >= k and len(top_k_neg) >= k:
      break
    if ctx["has_answer"]:
      bucket, label = top_k_pos, "true"
    else:
      bucket, label = top_k_neg, "false"
    if len(bucket) < k:
      bucket.append((txt_database[ctx["id"] - 1]["text"], label))

  # Padding repeats existing entries, so an empty bucket would never fill up.
  if not top_k_pos or not top_k_neg:
    return []

  for bucket in (top_k_pos, top_k_neg):
    while len(bucket) < k:
      bucket.extend(bucket[:(k - len(bucket))])

  return top_k_pos + top_k_neg



In [15]:
def create_pos_txt_col(example, k, txt_database):
  """Return the ``pos_text`` column value for one example.

  Thin adapter around ``get_top_k_pos`` shaped as a one-key dict so it can be
  passed to ``Dataset.map``.
  """
  positives = get_top_k_pos(example, k, txt_database)
  return {"pos_text": positives}

def create_pos_neg_txt_col(example, k, txt_database):
  """Return the ``pos_neg_text`` column value for one example.

  Thin adapter around ``get_top_k_pos_neg`` shaped as a one-key dict so it can
  be passed to ``Dataset.map``.
  """
  pos_and_neg = get_top_k_pos_neg(example, k, txt_database)
  return {"pos_neg_text": pos_and_neg}


In [None]:
from functools import partial
# Add a "pos_text" column: 20 answer-bearing passages per question, computed by
# create_pos_txt_col over 4 worker processes. The split is overwritten in place.
nq_open["train"] = nq_open["train"].map(partial(create_pos_txt_col, k=20, txt_database=wikipedia_txt["train"]), num_proc=4)

In [None]:
# Add a "pos_neg_text" column: 10 positive plus 10 negative passages per
# question, computed by create_pos_neg_txt_col over 4 worker processes.
nq_open["train"] = nq_open["train"].map(partial(create_pos_neg_txt_col, k=10, txt_database=wikipedia_txt["train"]), num_proc=4)

In [None]:
nq_open["train"][0]["pos_text"]

In [7]:
from tqdm import tqdm
def create_ranking_loss_baseline_examples(dataset, n=None):
  """Flatten ``pos_neg_text`` passages into relevance-classification examples.

  For each of the first ``n`` rows that has a non-empty ``"pos_neg_text"``
  list, one example is produced per passage: the input is the prompt
  ``"Question: ... Passage: ... Relevant: "`` and the target is the passage's
  label (second element of the passage pair).

  Args:
    dataset: indexable, sized collection of rows with ``"question"`` and
      ``"pos_neg_text"`` fields.
    n: number of leading rows to use. ``None`` (or ``0``) means all rows;
      values larger than the dataset are clamped to its length.

  Returns:
    Dict with parallel lists ``"inputs"``, ``"targets"`` and ``"k_pos_neg"``.
    Every ``"k_pos_neg"`` entry holds the same value: total examples divided
    by twice the number of contributing rows. All three lists are empty when
    no row contributes.
  """
  if not n:
    n = len(dataset)
  n = min(n, len(dataset))
  with_answer = 0
  inputs = []
  targets = []
  for i in tqdm(range(n)):
    # Fetch the row once; each lookup may materialise the whole record.
    row = dataset[i]
    texts = row["pos_neg_text"]
    if not texts:
      continue
    question = row["question"]
    inputs.extend(f"Question: {question} Passage: {text[0]} Relevant: " for text in texts)
    targets.extend(text[1] for text in texts)
    with_answer += 1
  # Without any contributing row there is nothing to average (and 0/0 would raise).
  k = [len(targets)/(with_answer * 2)] * len(targets) if with_answer else []
  return {"inputs": inputs, "targets": targets, "k_pos_neg": k}

def create_q_gen_baseline_examples(dataset, n=None):
  """Flatten ``pos_text`` passages into question-generation examples.

  For each of the first ``n`` rows that has a non-empty ``"pos_text"`` list,
  one example is produced per passage: the input is the prompt
  ``"Passage: ... Please write a question based on this passage"`` and the
  target is the row's question.

  Args:
    dataset: indexable, sized collection of rows with ``"question"`` and
      ``"pos_text"`` fields.
    n: number of leading rows to use. ``None`` (or ``0``) means all rows;
      values larger than the dataset are clamped to its length.

  Returns:
    Dict with parallel lists ``"inputs"``, ``"targets"`` and ``"k_pos"``.
    Every ``"k_pos"`` entry holds the same value: total examples divided by
    the number of contributing rows. All three lists are empty when no row
    contributes.
  """
  if not n:
    n = len(dataset)
  n = min(n, len(dataset))
  inputs = []
  targets = []
  with_answer = 0
  for i in tqdm(range(n)):
    # Fetch the row once; each lookup may materialise the whole record.
    row = dataset[i]
    texts = row["pos_text"]
    if not texts:
      continue
    question = row["question"]
    inputs.extend(f"Passage: {text[0]} Please write a question based on this passage" for text in texts)
    targets.extend([question] * len(texts))
    with_answer += 1
  # Without any contributing row there is nothing to average (and 0/0 would raise).
  k = [len(targets)/(with_answer)] * len(targets) if with_answer else []
  return {"inputs": inputs, "targets": targets, "k_pos": k}

def create_q_gen_ranking_baseline_examples(dataset, n=None):
  """Flatten ``pos_neg_text`` passages into question-generation examples.

  For each of the first ``n`` rows that has a non-empty ``"pos_neg_text"``
  list, one example is produced per passage: the input is the prompt
  ``"Passage: ... Please write a question based on this passage"`` and the
  target is the row's question. The passage label is not included in the
  output, so positive and negative passages receive the same target.

  Args:
    dataset: indexable, sized collection of rows with ``"question"`` and
      ``"pos_neg_text"`` fields.
    n: number of leading rows to use. ``None`` (or ``0``) means all rows;
      values larger than the dataset are clamped to its length.

  Returns:
    Dict with parallel lists ``"inputs"``, ``"targets"`` and ``"k_pos_neg"``.
    Every ``"k_pos_neg"`` entry holds the same value: total examples divided
    by twice the number of contributing rows. All three lists are empty when
    no row contributes.
  """
  if not n:
    n = len(dataset)
  n = min(n, len(dataset))
  inputs = []
  targets = []
  with_answer = 0
  for i in tqdm(range(n)):
    # Fetch the row once; each lookup may materialise the whole record.
    row = dataset[i]
    texts = row["pos_neg_text"]
    if not texts:
      continue
    question = row["question"]
    inputs.extend(f"Passage: {text[0]} Please write a question based on this passage" for text in texts)
    targets.extend([question] * len(texts))
    with_answer += 1
  # Without any contributing row there is nothing to average (and 0/0 would raise).
  k = [len(targets)/(with_answer * 2)] * len(targets) if with_answer else []
  return {"inputs": inputs, "targets": targets, "k_pos_neg": k}



In [None]:
# Build question-generation examples from every row of the training split
# (no n is given, so the whole dataset is iterated).
training_dict = create_q_gen_baseline_examples(nq_open["train"])

In [None]:
# Wrap the parallel "inputs" / "targets" / "k_pos" lists in a Dataset.
q_gen_dataset = datasets.Dataset.from_dict(training_dict)

In [2]:
from transformers import AutoTokenizer
from transformers import T5ForConditionalGeneration

# Checkpoint name shared by the tokenizer here and the model loaded in a later cell.
model_checkpoint = "google/t5-base-lm-adapt"
# preprocess_function reads this module-level tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


Moving 0 files to the new cache system


0it [00:00, ?it/s]

In [None]:
q_gen_dataset[2000]

In [97]:
def preprocess_function(examples, max_input_length, max_target_length, input_col):
    """Tokenize one batch of prompts and targets for seq2seq scoring/training.

    Uses the module-level ``tokenizer``. Prompts come from ``examples[input_col]``
    and targets from ``examples["targets"]``; both are truncated to their
    respective maximum length and padded to the longest sequence in the batch.

    Returns:
        The prompt encoding, extended with ``"labels"`` (target token ids) and
        ``"labels_mask"`` (target attention mask). Padded label positions are
        kept exactly as the tokenizer produced them; callers that need to
        ignore them must use ``"labels_mask"``.
    """
    shared_options = {"truncation": True, "padding": "longest"}

    encoded = tokenizer(
        examples[input_col],
        max_length=max_input_length,
        **shared_options,
    )
    encoded_targets = tokenizer(
        text_target=examples["targets"],
        max_length=max_target_length,
        return_tensors="pt",
        **shared_options,
    )

    encoded["labels"] = encoded_targets["input_ids"]
    encoded["labels_mask"] = encoded_targets["attention_mask"]
    return encoded

In [4]:
from transformers import T5ForConditionalGeneration
# Load the seq2seq model from the same checkpoint name as the tokenizer.
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)

In [113]:
import torch
import datasets
# Small smoke-test set: question-generation examples built from the first 4
# questions of the training split.
eval_subset = datasets.Dataset.from_dict(create_q_gen_baseline_examples(nq_open["train"].select(range(4))))
# Tokenize in batches (prompts capped at 300 tokens, targets at 50).
eval_dataset = eval_subset.map(partial(preprocess_function, max_input_length=300, max_target_length=50, input_col='inputs'), batched=True)
# Expose only the tokenized columns, returned as torch tensors.
eval_dataset.set_format(type="torch", columns=["input_ids", "labels", "attention_mask", "labels_mask"])


100%|██████████| 4/4 [00:00<00:00, 83.39it/s]


  0%|          | 0/1 [00:00<?, ?ba/s]

In [68]:
nq_open["train"]

Dataset({
    features: ['question', 'answers', 'ctxs', 'pos_text', 'pos_neg_text'],
    num_rows: 79168
})

In [None]:
# Prototype: build tokenized question-generation inputs for the top-k retrieved
# passages of a single question (row i of the training split).
i = 0
k = 100
txt_database = wikipedia_txt["train"]
question = nq_open["train"][i]["question"]
ctxs = nq_open["train"][i]["ctxs"][:k]
# Answer flag per passage; computed here but not used further in this cell.
has_ans = [ctx["has_answer"] for ctx in ctxs]
# NOTE(review): assumes integer, 1-based ctx ids matching row positions -- confirm.
texts = [txt_database[ctx["id"] - 1]["text"] for ctx in ctxs]
texts = [f"Passage: {text} Please write a question based on this passage" for text in texts]
# Every passage is paired with the same target: the question itself.
targets = [question for text in texts]
new_dataset = datasets.Dataset.from_dict({'inputs': texts, 'targets': targets})
new_dataset = new_dataset.map(partial(preprocess_function, max_input_length=300, max_target_length=50, input_col='inputs'), batched=True)
new_dataset.set_format(type="torch", columns=["input_ids", "labels", "attention_mask", "labels_mask"])
# Peek at the first three tokenized prompts.
new_dataset["input_ids"][0:3]


In [133]:
#eval_dataset = eval_dataset.remove_columns(['inputs', 'targets', 'k_pos'])
# Take the first 8 tokenized examples as one evaluation batch.
eval_inputs = eval_dataset["input_ids"][0:8]
eval_masks = eval_dataset["attention_mask"][0:8]
target_eval_labels = eval_dataset["labels"][0:8]
target_eval_label_mask= eval_dataset["labels_mask"][0:8]
# Decoder inputs derived from the labels by the model helper
# (presumably the labels shifted right -- confirm in the Transformers docs).
decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels=eval_dataset["labels"][0:8])

In [134]:
# Run the evaluation batch through the model on the GPU when one is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
eval_masks = eval_masks.to(device)
eval_inputs = eval_inputs.to(device)
target_eval_labels = target_eval_labels.to(device)
target_eval_label_mask = target_eval_label_mask.to(device)
decoder_input_ids = decoder_input_ids.to(device)
# This pass is only used for scoring, so disable autograd: otherwise the full
# computation graph is kept alive on the device for as long as `outputs` exists.
with torch.no_grad():
  outputs = model(
      input_ids=eval_inputs,
      labels=target_eval_labels,
      attention_mask=eval_masks,
      decoder_attention_mask=target_eval_label_mask,
      decoder_input_ids=decoder_input_ids,
  )

In [124]:
device


'cuda:0'

In [138]:
import torch.nn 
# Turn the decoder logits into log-probabilities over the last (vocabulary) axis.
logits = outputs.logits
log_softmax = torch.nn.LogSoftmax(dim=-1)
log_soft = log_softmax(logits)

In [None]:
# Log-probability the model assigns to each gold target token. Recomputed from
# `logits` (instead of overwriting the previous cell's `log_soft` in place) so
# this cell can be re-run safely.
labels = target_eval_labels.unsqueeze(2)
token_log_probs = log_softmax(logits).gather(2, labels).squeeze(2)
# Average over real target tokens only: targets are padded to a common length,
# and including pad positions would bias scores by question length.
label_mask = target_eval_label_mask.to(token_log_probs.dtype)
log_soft = (token_log_probs * label_mask).sum(dim=1) / label_mask.sum(dim=1).clamp(min=1)
# Rank every example in the batch from most to least likely.
topk_scores, indexes = torch.topk(log_soft, k=len(log_soft))
# Show the examples in ranked order; plain Python ints are passed because the
# indices may live on the GPU.
eval_subset[indexes.tolist()]


In [66]:
# Sanity check: compare the negated mean sequence score with the model's own
# loss. Only the last expression (outputs.loss) is displayed by the notebook.
-topk_scores.mean()
outputs.loss


tensor(31.8265, grad_fn=<NllLossBackward0>)

In [None]:
def evaluate_questions(validation, k, model, tokenizer, batch_size, dataset_generator=None):
  """Work-in-progress: score the top-``k`` passages of every validation question.

  Only the input validation and the per-question setup exist so far; the
  per-batch scoring step has not been written and raises
  ``NotImplementedError`` when reached.

  Args:
    validation: indexable, sized collection of rows with ``"question"`` and
      ``"ctxs"`` fields.
    k: number of leading contexts to consider per question.
    model: model intended for scoring (not used yet).
    tokenizer: tokenizer intended for scoring (not used yet).
    batch_size: number of passages per batch; ``k`` must be a multiple of it.
    dataset_generator: example builder; defaults to
      ``create_q_gen_baseline_examples`` (not used yet).

  Raises:
    AssertionError: if ``k`` is not a multiple of ``batch_size``.
    NotImplementedError: when the first batch of a question is reached.
  """
  if not dataset_generator:
    dataset_generator = create_q_gen_baseline_examples
  # `assert (cond, msg)` asserts a non-empty tuple and can never fail, and a
  # multiple check needs `%` rather than `//`.
  assert k % batch_size == 0, "k must be multiple of batch_size"
  for i in range(len(validation)):
    question = validation[i]["question"]
    ctxs = validation[i]["ctxs"][:k]
    has_ans = [ctx["has_answer"] for ctx in ctxs]
    # A separate loop variable keeps the batch offset from shadowing the row index.
    for batch_start in tqdm(range(0, k, batch_size)):
      # TODO: build prompts for ctxs[batch_start:batch_start + batch_size],
      # score them with the model and rank them against has_ans.
      raise NotImplementedError("scoring of passage batches is not implemented yet")



    



torch.Size([50, 50, 32128])

In [None]:
# Build a small tokenized training set from rows 5..204 of the training split.
test_nq = nq_open["train"].select(range(5, 205))
# (Re)compute the "pos_text" column with 20 positive passages per question.
test_nq = test_nq.map(partial(create_pos_txt_col, k=20, txt_database=wikipedia_txt["train"]), num_proc=4)
training_dict = create_q_gen_baseline_examples(test_nq)
test_train = datasets.Dataset.from_dict(training_dict)
test_train = test_train.map(partial(preprocess_function, max_input_length=300, max_target_length=50, input_col='inputs'), batched=True)
test_train.set_format(type="torch", columns=["input_ids", "labels", "attention_mask", "labels_mask"])
# Drop the raw text and bookkeeping columns; only tokenized columns remain.
test_train = test_train.remove_columns(["inputs", "targets", "k_pos"])
test_train


In [110]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
# Collator used by the trainer to assemble batches for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
trainer_args = Seq2SeqTrainingArguments(
    output_dir="test_1",
    # No eval_dataset is passed to the trainer below, so per-epoch evaluation
    # cannot run; switch back to "epoch" once an eval_dataset is supplied.
    evaluation_strategy="no",
    learning_rate=5.6e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    num_train_epochs=3,
    predict_with_generate=True,
)
trainer =  Seq2SeqTrainer(
    model,
    trainer_args,
    train_dataset=test_train,
    data_collator=data_collator,
    tokenizer=tokenizer
)


In [144]:
!nvidia-smi

Fri Oct  7 17:47:17 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   77C    P0    33W /  70W |   9558MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces