<a href="https://colab.research.google.com/github/nvaikunt/PromptBasedReranking/blob/main/BaselineTest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the project repository; later cells cd into it and run its scripts.
!git clone https://github.com/nvaikunt/PromptBasedReranking.git

In [None]:
# Install condacolab and let it set up conda for this runtime.
# NOTE(review): condacolab.install() presumably restarts the kernel — confirm, and continue from the next cell afterwards.
!pip install -q condacolab
import condacolab
condacolab.install()

In [None]:
# Sanity check that the conda executable is now available.
!conda --version

In [None]:
# All relative paths used below are relative to the cloned repository.
%cd PromptBasedReranking/

In [None]:
# Run the repository's setup script.
!bash setup.sh

In [None]:
# Check whether a CUDA device is visible to PyTorch.
import torch
torch.cuda.is_available()

In [None]:
# Download the Wikipedia passage split via the repository's download helper.
# NOTE(review): presumably this writes under downloads/data/, which is where the loading cell reads from — confirm.
!python Baseline/download_dpr_data.py --resource data.wikipedia-split.psgs_w100

In [None]:
# Download the DPR retriever outputs for the NQ train split via the same helper.
!python Baseline/download_dpr_data.py --resource data.retriever-outputs.dpr.nq-train

In [None]:

import torch
from torch.utils.data import Dataset, DataLoader
import transformers
import tokenizers
import datasets 
import pandas

# Load the tab-separated Wikipedia passage file; its "train" split is used below as the passage lookup table.
# NOTE(review): the helpers below read passages as `txt_database[ctx["id"] - 1]`, i.e. they assume
# integer passage ids that are 1-based row positions in this file — confirm against the TSV.
wikipedia_txt = datasets.load_dataset("csv", data_files='downloads/data/wikipedia-split/psgs_w100.tsv', delimiter='\t')


In [None]:
# Display a summary of the loaded passages object.
wikipedia_txt

In [None]:
# Load the retriever outputs for NQ train from JSON; records expose "question" and "ctxs" fields (used below).
nq_open = datasets.load_dataset("json",data_files="downloads/data/retriever-outputs/dpr/nq-train.json")

In [None]:
# Peek at the question text of one record.
nq_open["train"][2]["question"]

In [11]:
def get_top_k_pos(row, k, txt_database):
  """Return exactly `k` answer-bearing passages for one retrieval record.

  Walks `row["ctxs"]` in order, keeps the contexts flagged `has_answer`, and
  looks their text up in `txt_database` at position `ctx["id"] - 1`.

  Args:
    row: mapping with a "ctxs" list; each ctx has "has_answer" and "id".
    k: number of passages to return.
    txt_database: indexable collection of records exposing a "text" field.

  Returns:
    A list of `(passage_text, "true")` tuples. When fewer than `k` positives
    exist, the ones found are repeated cyclically until `k` is reached. When
    no context has an answer, an empty list is returned.
  """
  collected = []
  for candidate in row["ctxs"]:
    if candidate["has_answer"]:
      passage = txt_database[candidate["id"] - 1]["text"]
      collected.append((passage, "true"))
    # Stop scanning as soon as the quota is filled.
    if len(collected) == k:
      break

  if not collected:
    return []

  # Pad by cycling through the positives that were found.
  shortfall = k - len(collected)
  if shortfall > 0:
    found = len(collected)
    collected.extend([collected[j % found] for j in range(shortfall)])

  return collected[:k]
  
def get_top_k_pos_neg(row, k, txt_database):
  """Return `k` positive and `k` negative passages for one record, interleaved.

  Walks `row["ctxs"]` in order. Contexts flagged `has_answer` become positives
  labelled "true", the others become negatives labelled "false". Passage text
  is looked up in `txt_database` at position `ctx["id"] - 1`.

  Args:
    row: mapping with a "ctxs" list; each ctx has "has_answer" and "id".
    k: number of positives and number of negatives to return.
    txt_database: indexable collection of records exposing a "text" field.

  Returns:
    A list of `2 * k` tuples ordered `[pos_0, neg_0, pos_1, neg_1, ...]`, each
    `(passage_text, "true" | "false")`. If one side has fewer than `k` entries,
    the ones found are repeated cyclically. Returns `[]` when the record has no
    positive or no negative context, since no pair can be formed.
  """
  top_k_pos = []
  top_k_neg = []
  for ctx in row["ctxs"]:
    if ctx["has_answer"]:
      if len(top_k_pos) < k:
        top_k_pos.append((txt_database[ctx["id"] - 1]["text"], "true"))
    elif len(top_k_neg) < k:
      top_k_neg.append((txt_database[ctx["id"] - 1]["text"], "false"))
    # Stop only once BOTH quotas are full. The previous check tested
    # `len(top_k_neg)` for truthiness, so it stopped after a single negative
    # and padded with duplicates instead of collecting distinct negatives.
    if len(top_k_pos) == k and len(top_k_neg) == k:
      break

  # Without at least one passage on each side there is nothing to pair. The
  # missing-negative case used to spin forever in the padding loop, because
  # extending an empty list with a slice of itself never grows it.
  if not top_k_pos or not top_k_neg:
    return []

  def _cycle_to_k(items):
    # Repeat `items` cyclically (or truncate) to exactly k entries.
    return [items[i % len(items)] for i in range(k)]

  top_k = []
  for pos, neg in zip(_cycle_to_k(top_k_pos), _cycle_to_k(top_k_neg)):
    top_k.append(pos)
    top_k.append(neg)
  return top_k


In [12]:
def create_pos_txt_col(example, k, txt_database):
  """Build the "pos_text" column value for one record.

  Delegates to `get_top_k_pos` and wraps the result in a single-key dict.
  """
  positives = get_top_k_pos(example, k, txt_database)
  return dict(pos_text=positives)

def create_pos_neg_txt_col(example, k, txt_database):
  """Build the "pos_neg_text" column value for one record.

  Delegates to `get_top_k_pos_neg` and wraps the result in a single-key dict.
  """
  interleaved = get_top_k_pos_neg(example, k, txt_database)
  return dict(pos_neg_text=interleaved)


In [None]:
from functools import partial
# Add a "pos_text" column to every record via create_pos_txt_col with k=20,
# using the Wikipedia "train" split as the passage lookup table.
nq_open["train"] = nq_open["train"].map(partial(create_pos_txt_col, k=20, txt_database=wikipedia_txt["train"]), num_proc=4)

In [None]:
# Add a "pos_neg_text" column to every record via create_pos_neg_txt_col with k=10.
nq_open["train"] = nq_open["train"].map(partial(create_pos_neg_txt_col, k=10, txt_database=wikipedia_txt["train"]), num_proc=4)

In [None]:
# Inspect the pos/neg passages attached to the first record.
nq_open["train"][0]["pos_neg_text"]

In [16]:
from tqdm import tqdm
def create_ranking_loss_baseline_examples(dataset, n=None):
  """Flatten the first `n` records into relevance-classification examples.

  For every record with a non-empty "pos_neg_text" list, emits one example per
  passage: the input is a "Question/Passage/Relevant" prompt and the target is
  the passage's label (second element of each pair).

  Args:
    dataset: indexable collection of records with "question" and
      "pos_neg_text" fields.
    n: number of leading records to use. `None` (or 0) means all of them.

  Returns:
    Dict of three equal-length lists: "inputs", "targets" and "k_pos_neg".
    Every "k_pos_neg" entry is the same value: examples per used record,
    divided by two. All three lists are empty when no record has passages.
  """
  if not n:
    n = len(dataset)
  with_answer = 0
  inputs = []
  targets = []
  for i in tqdm(range(n)):
    # Index the dataset once per record instead of once per field.
    row = dataset[i]
    texts = row["pos_neg_text"]
    if not texts:
      continue
    question = row["question"]
    inputs.extend(f"Question: {question} Passage: {text[0]} Relevant: " for text in texts)
    targets.extend(text[1] for text in texts)
    with_answer += 1
  # Guard the division: with no usable record this used to raise ZeroDivisionError.
  k_value = len(targets) / (with_answer * 2) if with_answer else 0.0
  return {"inputs": inputs, "targets": targets, "k_pos_neg": [k_value] * len(targets)}

def create_q_gen_baseline_examples(dataset, n=None):
  """Flatten the first `n` records into question-generation examples.

  For every record with a non-empty "pos_text" list, emits one example per
  passage: the input is a prompt asking for a question about the passage and
  the target is the record's question.

  Args:
    dataset: indexable collection of records with "question" and "pos_text"
      fields.
    n: number of leading records to use. `None` (or 0) means all of them.

  Returns:
    Dict of three equal-length lists: "inputs", "targets" and "k_pos".
    Every "k_pos" entry is the same value: examples per used record.
    All three lists are empty when no record has passages.
  """
  if not n:
    n = len(dataset)
  inputs = []
  targets = []
  with_answer = 0
  for i in tqdm(range(n)):
    # Index the dataset once per record instead of once per field.
    row = dataset[i]
    texts = row["pos_text"]
    if not texts:
      continue
    question = row["question"]
    inputs.extend(f"Passage: {text[0]} Please write a question based on this passage" for text in texts)
    targets.extend(question for _ in texts)
    with_answer += 1
  # Guard the division: with no usable record this used to raise ZeroDivisionError.
  k_value = len(targets) / with_answer if with_answer else 0.0
  return {"inputs": inputs, "targets": targets, "k_pos": [k_value] * len(targets)}

def create_q_gen_ranking_baseline_examples(dataset, n=None):
  """Flatten the first `n` records into question-generation examples over pos/neg passages.

  For every record with a non-empty "pos_neg_text" list, emits one example per
  passage, in the stored order: the input is a prompt asking for a question
  about the passage and the target is the record's question.

  Args:
    dataset: indexable collection of records with "question" and
      "pos_neg_text" fields.
    n: number of leading records to use. `None` (or 0) means all of them.

  Returns:
    Dict of three equal-length lists: "inputs", "targets" and "k_pos_neg".
    Every "k_pos_neg" entry is the same value: examples per used record,
    divided by two. All three lists are empty when no record has passages.
  """
  if not n:
    n = len(dataset)
  inputs = []
  targets = []
  with_answer = 0
  for i in tqdm(range(n)):
    # Index the dataset once per record instead of once per field.
    row = dataset[i]
    texts = row["pos_neg_text"]
    if not texts:
      continue
    question = row["question"]
    inputs.extend(f"Passage: {text[0]} Please write a question based on this passage" for text in texts)
    targets.extend(question for _ in texts)
    with_answer += 1
  # Guard the division: with no usable record this used to raise ZeroDivisionError.
  k_value = len(targets) / (with_answer * 2) if with_answer else 0.0
  return {"inputs": inputs, "targets": targets, "k_pos_neg": [k_value] * len(targets)}



In [None]:
# Flatten every NQ train record into question-generation examples (one per positive passage).
training_dict = create_q_gen_baseline_examples(nq_open["train"])

In [None]:
# Wrap the flattened columns in a `datasets.Dataset`.
q_gen_dataset = datasets.Dataset.from_dict(training_dict)

In [None]:
from transformers import AutoTokenizer
from transformers import T5ForConditionalGeneration

# Checkpoint name shared by the tokenizer here and the model loaded in a later cell.
model_checkpoint = "google/t5-base-lm-adapt"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


In [38]:
def preprocess_function(examples, max_input_length, max_target_length, input_col):
    """Tokenize a batch of examples for a seq2seq model.

    Uses the notebook-level `tokenizer`. Source text comes from
    `examples[input_col]` and target text from `examples["targets"]`. Both are
    truncated to their respective maximum length and padded to the longest
    sequence in the batch.

    Returns:
        The tokenized inputs, extended with "labels" (target token ids) and
        "decoder_attention_mask" (target attention mask). The target encoding
        is requested as PyTorch tensors.
    """
    shared_options = {"truncation": True, "padding": "longest"}

    encoded = tokenizer(examples[input_col], max_length=max_input_length, **shared_options)
    encoded_targets = tokenizer(
        text_target=examples["targets"],
        max_length=max_target_length,
        return_tensors="pt",
        **shared_options,
    )

    encoded["labels"] = encoded_targets["input_ids"]
    encoded["decoder_attention_mask"] = encoded_targets["attention_mask"]
    return encoded

In [None]:
from transformers import T5ForConditionalGeneration
# Load the T5 weights for the checkpoint named by `model_checkpoint`.
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)

In [None]:
import torch
import datasets
# Build a tiny question-generation set from the first four records and tokenize it.
eval_subset = datasets.Dataset.from_dict(create_q_gen_baseline_examples(nq_open["train"].select(range(4))))
eval_dataset = eval_subset.map(partial(preprocess_function, max_input_length=300, max_target_length=50, input_col='inputs'), batched=True)
# preprocess_function stores the target mask under "decoder_attention_mask";
# the previously requested "labels_mask" column does not exist.
eval_dataset.set_format(type="torch", columns=["input_ids", "labels", "attention_mask", "decoder_attention_mask"])


In [None]:
# Display a summary of the NQ train split.
nq_open["train"]

In [None]:
# Manual walk-through for one question: build and tokenize question-generation
# examples for its first k retrieved passages.
i = 0
k = 100
txt_database = wikipedia_txt["train"]
question = nq_open["train"][i]["question"]
ctxs = nq_open["train"][i]["ctxs"][:k]
has_ans = [ctx["has_answer"] for ctx in ctxs]
texts = [txt_database[ctx["id"] - 1]["text"] for ctx in ctxs]
texts = [f"Passage: {text} Please write a question based on this passage" for text in texts]
targets = [question for text in texts]
new_dataset = datasets.Dataset.from_dict({'inputs': texts, 'targets': targets})
new_dataset = new_dataset.map(partial(preprocess_function, max_input_length=300, max_target_length=50, input_col='inputs'), batched=True)
# preprocess_function stores the target mask under "decoder_attention_mask";
# the previously requested "labels_mask" column does not exist.
new_dataset.set_format(type="torch", columns=["input_ids", "labels", "attention_mask", "decoder_attention_mask"])
new_dataset["input_ids"][0:3]


In [None]:
# Take the first eight tokenized examples as a hand-built batch.
eval_inputs = eval_dataset["input_ids"][0:8]
eval_masks = eval_dataset["attention_mask"][0:8]
target_eval_labels = eval_dataset["labels"][0:8]
# The target mask column produced by preprocess_function is
# "decoder_attention_mask" (there is no "labels_mask" column).
target_eval_label_mask = eval_dataset["decoder_attention_mask"][0:8]
decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels=target_eval_labels)

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
# Move every tensor of the hand-built batch to the model's device.
eval_masks = eval_masks.to(device)
eval_inputs = eval_inputs.to(device)
target_eval_labels = target_eval_labels.to(device)
target_eval_label_mask = target_eval_label_mask.to(device)
decoder_input_ids = decoder_input_ids.to(device)
# Forward pass; `outputs.logits` and `outputs.loss` are read by the following cells.
outputs = model(input_ids=eval_inputs, labels = target_eval_labels, attention_mask=eval_masks, decoder_attention_mask=target_eval_label_mask, decoder_input_ids=decoder_input_ids)

In [None]:
# Show which device was selected.
device


In [None]:
import torch.nn 
logits = outputs.logits
# Normalize the logits into log-probabilities along the last axis.
log_softmax = torch.nn.LogSoftmax(dim=-1)
log_soft = log_softmax(logits)

In [None]:
# Pick, at every target position, the log-probability assigned to the reference token.
labels = target_eval_labels.unsqueeze(2)
# New names instead of overwriting `log_soft`, so this cell can be re-run safely.
token_log_probs = log_soft.gather(2, labels).squeeze(2)
# One score per example: mean log-probability over target positions.
# NOTE(review): the mean runs over every label position, including any padding — confirm that is intended.
passage_scores = token_log_probs.mean(dim=1)
topk_scores, indexes = torch.topk(passage_scores, k=len(passage_scores))
# Only the last expression of a cell is rendered, so show both results together.
# Plain Python ints are used to index the dataset (the index tensor may live on the GPU).
topk_scores, eval_subset[indexes.tolist()]


In [None]:
# Compare the negated mean score with the model's own loss. Returned as one
# tuple, because a bare expression that is not last in the cell is not shown.
(-topk_scores.mean(), outputs.loss)


In [108]:
from transformers.utils.logging import disable_progress_bar
from transformers import DataCollatorForSeq2Seq
from torch.utils.data import SequentialSampler, DataLoader
import numpy as np
def evaluate_questions(validation, k, model, tokenizer, batch_size, evidence_txts, preprocess_function):
  
  assert k // batch_size != 0, "k must be multiple of batch_size"
  assert batch_size // 2 != 0, "Batch Size Must Be Even"

  if k < batch_size: 
    batch_size = k

  original_recall = []
  current_recall = []
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
  model.to(device)
  for i in tqdm(range(len(validation))):

    # Extract Question, Passages, and Info on Whether Passages have Answer
    question = validation[i]["question"]
    ctxs = validation[i]["ctxs"][:k]
    has_ans = [ctx["has_answer"] for ctx in ctxs]
    has_ans = torch.BoolTensor(has_ans)

    # Build Data as Model Expects
    texts = [evidence_txts[ctx["id"] - 1]["text"] for ctx in ctxs]
    texts = [f"Passage: {text} Please write a question based on this passage" for text in texts]
    targets = [question for text in texts]
    eval_dataset = datasets.Dataset.from_dict({'inputs': texts, 'targets': targets})  

    datasets.utils.disable_progress_bar()
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
    eval_dataset = eval_dataset.map(partial(preprocess_function, max_input_length=300,
                                                            max_target_length=50, input_col='inputs'), 
                                    batched=True)

    eval_dataset = eval_dataset.remove_columns(["inputs", "targets"])
    eval_dataset.set_format(type="torch")
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, collate_fn=data_collator, shuffle=False)
    
    # Calculate Log Scores and Get Ranking
    log_scores = []
    model.eval()
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
          batch = {k: v.to(device) for k, v in batch.items()}
          outputs = model(**batch)
          logits = outputs.logits
          log_softmax = torch.nn.LogSoftmax(dim=-1)
          log_soft = log_softmax(logits)
          labels = batch["labels"].unsqueeze(2)
          log_soft = log_soft.gather(2, labels).squeeze(2)
          log_soft = log_soft.mean(dim=1)
          log_scores.append(log_soft)
    
    log_scores = torch.cat(log_scores)
    topk_scores, indexes = torch.topk(log_scores, k=len(log_scores))

    # Collect Stats for Recall
    ranked_answers = has_ans[indexes]
    current_has_ans = torch.cumsum(ranked_answers, dim=0) > 0
    original_has_ans = torch.cumsum(has_ans, dim=0) > 0

    original_recall.append(original_has_ans.tolist())
    current_recall.append(current_has_ans.tolist())

  original_recall = np.mean(np.array(original_recall), axis=0)
  current_recall = np.mean(np.array(current_recall), axis=0)
  
  return original_recall, current_recall



    



In [None]:
# Evaluate re-ranking on records 30..49 of the train split, scoring the top 30
# retrieved passages per question in batches of 10.
val_dataset = nq_open["train"].select(range(30,50))
original_recall, current_recall = evaluate_questions(val_dataset, k=30, model=model,
                                                     tokenizer=tokenizer, 
                                                     batch_size=10, 
                                                     evidence_txts = wikipedia_txt["train"], 
                                                     preprocess_function=preprocess_function)


In [None]:
import gc

# Force a Python garbage-collection pass.
# NOTE(review): this alone does not call any torch.cuda cache-release function — confirm whether that was intended.
gc.collect()

torch.cuda.is_available()

In [None]:
# Smoke-test `ranking_loss` on CPU over interleaved pos/neg examples.
# NOTE: `ranking_loss` is defined in a later cell of this notebook; run that
# cell first, otherwise this loop fails with a NameError on a fresh kernel.
device = "cpu"
model.to(device)
ranking_samples = create_q_gen_ranking_baseline_examples(val_dataset, n=20)
ranking_samples = datasets.Dataset.from_dict(ranking_samples)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
eval_dataset = ranking_samples.map(partial(preprocess_function, max_input_length=300,
                                                            max_target_length=50, input_col='inputs'), 
                                    batched=True)

eval_dataset = eval_dataset.remove_columns(["inputs", "targets", "k_pos_neg"])
eval_dataset.set_format(type="torch")
eval_dataloader = DataLoader(eval_dataset, batch_size=8, collate_fn=data_collator, shuffle=False)
for batch in eval_dataloader:
  # `.to()` returns a new tensor, so the whole batch is rebuilt on the target device.
  batch = {name: tensor.to(device) for name, tensor in batch.items()}
  outputs = model(**batch)
  print(batch["labels"])
  targets = batch["labels"]
  # Pass the real number of rows in this batch. The former hard-coded 10 did
  # not match the DataLoader's batch size of 8 and indexed past the batch.
  loss = ranking_loss(outputs.logits, targets, margin=.01, batch_size=targets.size(0))
  print(loss)


In [None]:
# Show GPU status from the shell.
!nvidia-smi

In [145]:
def ranking_loss(outputs, labels, margin, batch_size):
  """Margin ranking loss between interleaved positive and negative examples.

  Rows of `outputs` and `labels` are expected to alternate positive, negative,
  positive, negative, ... along dim 0. The cross-entropy of the positive rows
  and of the negative rows is computed separately, and the result is
  `max(0, pos_loss - neg_loss + margin)`: it is zero once the positives are
  predicted better than the negatives by at least `margin`.

  Args:
    outputs: unnormalized logits; the last dim holds the class scores
      (assumes a (batch, seq, vocab) layout — TODO confirm with caller).
    labels: target ids aligned with `outputs` on all but the last dim.
    margin: required gap between negative and positive cross-entropy.
    batch_size: nominal number of rows; must be even. If the batch actually
      holds fewer rows (e.g. a short final batch), the actual count is used.

  Returns:
    A scalar tensor with the margin ranking loss.

  Raises:
    AssertionError: if `batch_size` is odd, or if the batch is empty or holds
      an odd number of rows.
  """
  # `%` is the parity check; the former `batch_size // 2 != 0` accepted odd sizes.
  assert batch_size % 2 == 0, "Batch Size must be even"
  rows = min(batch_size, outputs.size(0))
  assert rows > 0 and rows % 2 == 0, "Batch must hold a positive, even number of rows"

  # Strided slices select the interleaved rows without building index tensors.
  pos_outputs = outputs[0:rows:2]
  neg_outputs = outputs[1:rows:2]
  pos_labels = labels[0:rows:2]
  neg_labels = labels[1:rows:2]
  flat_size = outputs.size(-1)

  # CrossEntropyLoss applies log-softmax itself, so raw logits go in directly.
  ce_loss = torch.nn.CrossEntropyLoss()
  pos_loss = ce_loss(pos_outputs.reshape(-1, flat_size), pos_labels.reshape(-1))
  neg_loss = ce_loss(neg_outputs.reshape(-1, flat_size), neg_labels.reshape(-1))

  # MarginRankingLoss(x1, x2, y=1) pushes x1 above x2. The quantity that should
  # end up larger is the loss on NEGATIVES, so neg_loss goes first. The target
  # must be a tensor, not a Python int.
  margin_loss = torch.nn.MarginRankingLoss(margin)
  target = torch.ones_like(pos_loss)
  return margin_loss(neg_loss, pos_loss, target)