<a href="https://colab.research.google.com/github/nvaikunt/PromptBasedReranking/blob/main/BaselineTest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Clone the project repository (provides setup.sh, utils/ and the training scripts used below).
!git clone https://github.com/nvaikunt/PromptBasedReranking.git

In [None]:
# Install conda into the Colab runtime.
# NOTE(review): condacolab.install() is expected to restart the kernel; run the remaining cells after it finishes — confirm.
!pip install -q condacolab
import condacolab
condacolab.install()

In [None]:
# Verify that conda is available after the install.
!conda --version

In [None]:
# Work from the repository root so the relative paths used below (utils/, downloads/) resolve.
%cd PromptBasedReranking/

In [None]:
# Project-specific environment setup (see setup.sh in the repository).
!bash setup.sh

In [None]:
# Displays True when a CUDA GPU is visible to PyTorch.
import torch
torch.cuda.is_available()

In [None]:
# Fetch the Wikipedia passage collection (psgs_w100) used as the evidence store.
!python utils/download_dpr_data.py --resource data.wikipedia-split.psgs_w100

In [None]:
# Fetch DPR retriever outputs for NQ train (questions with their retrieved contexts).
!python utils/download_dpr_data.py --resource data.retriever-outputs.dpr.nq-train

In [None]:
# Fetch DPR retriever outputs for NQ dev.
!python utils/download_dpr_data.py --resource data.retriever-outputs.dpr.nq-dev

In [7]:
!mkdir output_test

In [None]:
# Run the repository's baseline training script on a small subset, writing to output_test/.
# NOTE(review): flag meanings are inferred from their names (-ep epochs, -lr learning rate, -bs batch size,
# -trsz/-vlsz train/validation subset sizes, -ngpu GPU count, -m model) — confirm against baseline_train.py's argparse.
!python baseline_train.py -ep 1 -lr 5e-4  -bs 10 -odir output_test \
 -edir 'downloads/data/wikipedia-split/psgs_w100.tsv' \
 -tdir "downloads/data/retriever-outputs/dpr/nq-train.json" \
 -vdir "downloads/data/retriever-outputs/dpr/nq-dev.json" \
 -trsz 200 -vlsz 20 -ngpu 1 -m "google/t5-base-lm-adapt"

In [None]:
from transformers import AutoTokenizer, T5ForConditionalGeneration
from preprocess_data import create_training_dataset
# Load the tokenizer/model and build a ranking-style training set with the project's
# create_training_dataset helper.
# NOTE(review): the positional 400 is presumably the number of training questions to use — confirm
# against preprocess_data.create_training_dataset.
model_checkpoint = "google/t5-base-lm-adapt"
model_checkpoint_2 = "bigscience/T0_3B"  # not used in the visible cells
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
train_file = "downloads/data/retriever-outputs/dpr/nq-train.json"
evidence_file = 'downloads/data/wikipedia-split/psgs_w100.tsv'
train_dataset = create_training_dataset(train_file, evidence_file, 400, 
                                        isQG=False, isRanking=True, batch_sz=10,
                                        tokenizer=tokenizer)
train_dataset

In [None]:
from transformers import AdamW
from transformers import get_scheduler
from transformers import DataCollatorForSeq2Seq
from torch.utils.data import DataLoader
from tqdm import tqdm


# Sanity-check cell: push a few batches through the model WITHOUT updating weights and print the
# decoder-step-0 log-probabilities of the two label tokens next to the labels.
# Weight updates are intentionally disabled here; the real training loop is in a later cell.
# NOTE: `ranking_loss` is defined in a later cell and must be executed before this one.
batch_size = 10
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=data_collator, shuffle=False)
# Vocabulary ids inspected at decoder position 0 — presumably the ids of "true" / "false";
# TODO confirm with tokenizer.convert_tokens_to_ids.
truth_ix = 1176
false_ix = 6136
num_epochs = 3
max_debug_batches = 3  # only the first few batches are inspected per epoch
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
# eval() must be *called*: the bare attribute access `model.eval` was a no-op, leaving dropout
# active and making the printed scores non-deterministic.
model.eval()
log_softmax = torch.nn.LogSoftmax(dim=-1)
for epoch in range(num_epochs):
  losses = []
  with torch.no_grad():
    for step, batch in enumerate(train_dataloader):
      # Stop before running the forward pass instead of computing and discarding one extra batch.
      if step >= max_debug_batches:
        break
      batch = {name: tensor.to(device) for name, tensor in batch.items()}
      outputs = model(**batch)
      logits = outputs.logits
      log_soft = log_softmax(logits)
      print(log_soft[:, 0, truth_ix])
      print(log_soft[:, 0, false_ix])
      print(batch["labels"])
      loss = ranking_loss(logits, batch["labels"], 1,
                          batch["labels"].size(dim=0))
      losses.append(loss)
  # Guard against an empty dataloader (division by zero).
  if losses:
    print(f'Train Loss in Epoch {epoch}: {sum(losses)/len(losses)}')


In [None]:
# NOTE(review): `wikipedia_txt` is created by create_evidence_texts in a later cell; on a fresh
# kernel this cell raises NameError unless that cell has already been run.
wikipedia_txt

In [None]:
# Load the DPR retriever outputs for NQ train; without `split=` this yields a DatasetDict
# whose only split is "train".
# NOTE(review): `datasets` is imported in a later cell — run that import first.
nq_open = datasets.load_dataset("json", data_files="downloads/data/retriever-outputs/dpr/nq-train.json")


In [None]:
# Add NQ dev as a "validation" split of the same DatasetDict.
nq_open["validation"] = datasets.load_dataset("json",data_files="downloads/data/retriever-outputs/dpr/nq-dev.json", split="train")

In [None]:
def get_top_k_pos(row, k, txt_database):
  """Return ``k`` ``(passage_text, "true")`` pairs drawn from answer-bearing contexts.

  Scans ``row["ctxs"]`` in order, keeping contexts flagged ``has_answer`` until ``k``
  have been collected. If fewer than ``k`` exist, the ones found are repeated
  cyclically to fill the list. Returns ``[]`` when no context has the answer.

  ``ctx["id"]`` is used as a 1-based index into ``txt_database``, whose items
  expose a ``"text"`` field.
  """
  found = []
  for ctx in row["ctxs"]:
    if ctx["has_answer"]:
      found.append((txt_database[ctx["id"] - 1]["text"], "true"))
    if len(found) == k:
      break

  if not found:
    return []
  if len(found) >= k:
    return found[:k]
  # Fewer than k positives: cycle through them until exactly k entries exist.
  return [found[pos % len(found)] for pos in range(k)]
  
def get_top_k_pos_neg(row, k, txt_database):
  """Collect ``k`` answer-bearing and ``k`` non-answer passages for one question.

  Walks ``row["ctxs"]`` in retrieval order, keeping up to ``k`` contexts with
  ``has_answer`` (labelled ``"true"``) and up to ``k`` without (labelled ``"false"``).
  Each polarity is then padded to exactly ``k`` by cycling through what was found.

  Args:
    row: mapping with a ``"ctxs"`` list; each ctx has ``"id"`` (1-based index into
      ``txt_database``) and ``"has_answer"``.
    k: number of passages wanted per polarity.
    txt_database: indexable collection whose items have a ``"text"`` field.

  Returns:
    ``k`` positive ``(text, "true")`` pairs followed by ``k`` negative
    ``(text, "false")`` pairs, or ``[]`` when either polarity is entirely missing
    (a balanced example cannot be built).
  """
  top_k_pos = []
  top_k_neg = []
  for ctx in row["ctxs"]:
    if ctx["has_answer"]:
      if len(top_k_pos) < k:
        top_k_pos.append((txt_database[ctx["id"] - 1]["text"], "true"))
    elif len(top_k_neg) < k:
      top_k_neg.append((txt_database[ctx["id"] - 1]["text"], "false"))
    # Stop only once BOTH lists are full; stopping at the first negative would replace
    # later distinct negatives with repeats of the first one.
    if len(top_k_pos) >= k and len(top_k_neg) >= k:
      break

  # A balanced example needs at least one passage of each polarity; an empty list can
  # never be padded up to k (padding it by repetition would loop forever).
  if not top_k_pos or not top_k_neg:
    return []

  def _pad_cyclic(items):
    # Repeat items in order until exactly k entries are available.
    return [items[i % len(items)] for i in range(k)]

  return _pad_cyclic(top_k_pos) + _pad_cyclic(top_k_neg)


In [None]:
def create_pos_txt_col(example, k, txt_database):
  """``datasets.map`` helper: add a ``pos_text`` column holding ``get_top_k_pos`` output."""
  positives = get_top_k_pos(example, k, txt_database)
  return dict(pos_text=positives)

def create_pos_neg_txt_col(example, k, txt_database):
  """``datasets.map`` helper: add a ``pos_neg_text`` column holding ``get_top_k_pos_neg`` output."""
  pairs = get_top_k_pos_neg(example, k, txt_database)
  return dict(pos_neg_text=pairs)


In [None]:
from functools import partial
# Add a "pos_text" column: up to 10 answer-bearing passages per training question
# (padded by repetition; [] when a question has none).
# NOTE(review): `wikipedia_txt` is only defined in a later cell, and that cell also imports a
# same-named `create_pos_neg_txt_col` from utils.data_utils which shadows the one defined above — run order matters.
nq_open["train"] = nq_open["train"].map(partial(create_pos_txt_col, k=10, txt_database=wikipedia_txt["train"]), num_proc=4)

In [None]:
# Add a "pos_neg_text" column: k=5 positive passages followed by k=5 negative passages per question.
nq_open["train"] = nq_open["train"].map(partial(create_pos_neg_txt_col, k=5, txt_database=wikipedia_txt["train"]), num_proc=4)

In [None]:
# Same two columns for the validation split.
nq_open["validation"] = nq_open["validation"].map(partial(create_pos_txt_col, k=10, txt_database=wikipedia_txt["train"]), num_proc=4)
nq_open["validation"] = nq_open["validation"].map(partial(create_pos_neg_txt_col, k=5, txt_database=wikipedia_txt["train"]), num_proc=4)

In [None]:
from tqdm import tqdm
def create_ranking_loss_baseline_examples(dataset, n=None):
  """Flatten (question, passage, label) triples into relevance-classification examples.

  Each row of ``dataset`` must provide ``"question"`` and ``"pos_neg_text"`` (a list of
  ``(passage_text, "true"/"false")`` pairs). Rows whose ``pos_neg_text`` is empty
  are skipped.

  Args:
    dataset: indexable collection of rows (e.g. a ``datasets.Dataset``).
    n: number of leading rows to read; ``None`` (or 0) means all rows.

  Returns:
    dict of parallel lists: ``inputs`` (prompt strings), ``targets`` ("true"/"false")
    and ``k_pos_neg`` (passages per polarity per kept question, repeated once per
    example). All lists are empty when no row was kept.
  """
  if not n:
    n = len(dataset)
  inputs = []
  targets = []
  kept_questions = 0
  for i in tqdm(range(n)):
    row = dataset[i]  # index the dataset once per row instead of twice
    texts = row["pos_neg_text"]
    if not texts:
      continue
    question = row["question"]
    inputs.extend(f"Question: {question} Passage: {text[0]} Relevant: " for text in texts)
    targets.extend(text[1] for text in texts)
    kept_questions += 1
  # Guard: with no usable rows the average would divide by zero.
  per_polarity = len(targets) / (kept_questions * 2) if kept_questions else 0.0
  return {"inputs": inputs, "targets": targets, "k_pos_neg": [per_polarity] * len(targets)}

def create_q_gen_baseline_examples(dataset, n=None):
  """Build question-generation examples: one "write a question" prompt per positive passage.

  Each row of ``dataset`` must provide ``"question"`` and ``"pos_text"`` (a list of
  ``(passage_text, label)`` pairs). The target of every prompt is the row's question.
  Rows whose ``pos_text`` is empty are skipped.

  Args:
    dataset: indexable collection of rows (e.g. a ``datasets.Dataset``).
    n: number of leading rows to read; ``None`` (or 0) means all rows.

  Returns:
    dict of parallel lists: ``inputs`` (prompts), ``targets`` (questions) and ``k_pos``
    (passages per kept question, repeated once per example). All lists are empty when
    no row was kept.
  """
  if not n:
    n = len(dataset)
  inputs = []
  targets = []
  kept_questions = 0
  for i in tqdm(range(n)):
    row = dataset[i]  # index the dataset once per row instead of twice
    texts = row["pos_text"]
    if not texts:
      continue
    question = row["question"]
    inputs.extend(f"Passage: {text[0]} Please write a question based on this passage" for text in texts)
    targets.extend([question] * len(texts))
    kept_questions += 1
  # Guard: with no usable rows the average would divide by zero.
  per_question = len(targets) / kept_questions if kept_questions else 0.0
  return {"inputs": inputs, "targets": targets, "k_pos": [per_question] * len(targets)}

def create_q_gen_ranking_baseline_examples(dataset, n=None):
  """Build question-generation prompts for both positive and negative passages.

  Each row of ``dataset`` must provide ``"question"`` and ``"pos_neg_text"`` (a list of
  ``(passage_text, label)`` pairs, positives first). Every prompt's target is the row's
  question, so a ranking loss can compare question likelihood under positive vs.
  negative passages. Rows whose ``pos_neg_text`` is empty are skipped.

  Args:
    dataset: indexable collection of rows (e.g. a ``datasets.Dataset``).
    n: number of leading rows to read; ``None`` (or 0) means all rows.

  Returns:
    dict of parallel lists: ``inputs`` (prompts), ``targets`` (questions) and
    ``k_pos_neg`` (passages per polarity per kept question, repeated once per example).
    All lists are empty when no row was kept.
  """
  if not n:
    n = len(dataset)
  inputs = []
  targets = []
  kept_questions = 0
  for i in tqdm(range(n)):
    row = dataset[i]  # index the dataset once per row instead of twice
    texts = row["pos_neg_text"]
    if not texts:
      continue
    question = row["question"]
    inputs.extend(f"Passage: {text[0]} Please write a question based on this passage" for text in texts)
    targets.extend([question] * len(texts))
    kept_questions += 1
  # Guard: with no usable rows the average would divide by zero.
  per_polarity = len(targets) / (kept_questions * 2) if kept_questions else 0.0
  return {"inputs": inputs, "targets": targets, "k_pos_neg": [per_polarity] * len(targets)}



In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
from utils.data_utils import create_evidence_texts, create_pos_neg_txt_col
from functools import partial
import datasets
import utils.data_utils as data_utils
import utils.train_utils as train_utils
# Build a small relevance-classification training set with the project's utils modules.
# NOTE(review): here `wikipedia_txt` is passed directly as txt_database, while other cells use
# `wikipedia_txt["train"]` — confirm which one create_evidence_texts returns.
wikipedia_txt = create_evidence_texts('downloads/data/wikipedia-split/psgs_w100.tsv')
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Use a separate name so the `nq_open` DatasetDict (with "train"/"validation" splits) that later
# cells index with nq_open["train"] / nq_open["validation"] is not overwritten by this subset.
nq_train_sample = datasets.load_dataset("json", data_files="downloads/data/retriever-outputs/dpr/nq-train.json", split="train")
nq_train_sample = nq_train_sample.select(range(0,400))
nq_train_sample = nq_train_sample.map(partial(data_utils.create_pos_neg_txt_col, k=10, txt_database=wikipedia_txt), num_proc=4)
trn_function = train_utils.create_ranking_loss_baseline_examples
train_dict = trn_function(nq_train_sample, n=3)  # only the first 3 questions (quick check)
train_dataset = datasets.Dataset.from_dict(train_dict)
train_dataset = train_dataset.map(partial(data_utils.preprocess_function, tokenizer=tokenizer, max_input_length=300, 
                                     max_target_length=50, input_col="inputs"), batched=True)
train_dataset

In [None]:
def preprocess_function(examples, tokenizer, max_input_length, max_target_length, input_col, target_col="targets") -> dict:
    """Tokenize a batch of prompts and their targets for a seq2seq model.

    Intended for ``datasets.Dataset.map(..., batched=True)``.

    Args:
        examples: batch mapping; ``examples[input_col]`` holds prompt strings and
            ``examples[target_col]`` holds target strings.
        tokenizer: Hugging Face tokenizer; targets are encoded via ``text_target``.
        max_input_length: truncation length for prompts.
        max_target_length: truncation length for targets.
        input_col: name of the prompt column.
        target_col: name of the target column (defaults to ``"targets"``, the
            previously hard-coded name).

    Returns:
        The tokenizer output for the prompts, plus ``labels`` (target ids) and
        ``decoder_attention_mask`` (target attention mask).
    """
    model_inputs = tokenizer(
        examples[input_col],
        max_length=max_input_length,
        truncation=True, padding="longest"
    )
    # NOTE(review): padded label positions keep the tokenizer's pad id rather than -100, so
    # they are not ignored by the model's built-in loss — confirm this is intended.
    labels = tokenizer(
        text_target=examples[target_col], max_length=max_target_length,
        truncation=True, padding="longest", return_tensors="pt"
    )
    model_inputs["labels"] = labels["input_ids"]
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]
    return model_inputs

In [None]:
from transformers import T5ForConditionalGeneration
# Reload the pretrained checkpoint so the cells below start from untouched weights.
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)

In [None]:
def qg_batching(question, ctxs, has_ans, evidence_txts):
  """Build a question-generation eval set: one "write a question" prompt per context.

  Every row's target is ``question`` itself, so each passage can later be scored by
  how likely the model finds the question given that passage. ``has_ans`` is accepted
  for signature parity with ``relevance_batching`` but is not used.
  ``ctx["id"]`` is a 1-based index into ``evidence_txts``.
  """
  prompts = []
  for ctx in ctxs:
    passage = evidence_txts[ctx["id"] - 1]["text"]
    prompts.append(f"Passage: {passage} Please write a question based on this passage")
  return datasets.Dataset.from_dict({'inputs': prompts, 'targets': [question] * len(prompts)})

def relevance_batching(question, ctxs, has_ans, evidence_txts):
  """Build a relevance eval set: a question/passage prompt per context, target "true"/"false".

  The target of context ``i`` is ``"true"`` when ``has_ans[i]`` is truthy, else ``"false"``.
  ``ctx["id"]`` is a 1-based index into ``evidence_txts``.
  """
  passages = (evidence_txts[ctx["id"] - 1]["text"] for ctx in ctxs)
  prompts = [f"Question: {question} Passage: {passage} Relevant: " for passage in passages]
  labels = ["true" if flag else "false" for flag in has_ans]
  return datasets.Dataset.from_dict({'inputs': prompts, 'targets': labels})
  
def qg_ranking(logits, labels):
  """Score each sequence by the mean log-probability the model assigns to its labels.

  Args:
    logits: decoder logits indexed as (batch, target_len, vocab).
    labels: target token ids, (batch, target_len). Positions holding a negative id
      (e.g. the -100 that ``DataCollatorForSeq2Seq`` uses by default for label padding)
      are excluded from the average instead of crashing ``gather``.

  Returns:
    1-D tensor with one score per row; higher means the labels are more likely.
    A row with no valid label position scores 0.
  """
  log_probs = torch.log_softmax(logits, dim=-1)
  valid = labels >= 0
  # gather cannot take negative indices; clamp them and mask their contribution out.
  token_log_probs = log_probs.gather(2, labels.clamp(min=0).unsqueeze(2)).squeeze(2)
  token_log_probs = token_log_probs.masked_fill(~valid, 0.0)
  counts = valid.sum(dim=1).clamp(min=1)
  return token_log_probs.sum(dim=1) / counts
  
def relevance_ranking(logits, labels, truth_ix):
  """Return the probability of vocab id ``truth_ix`` at the first decoder position, per row.

  ``logits`` is indexed as (batch, target_len, vocab). ``labels`` is unused; it is
  accepted so the signature mirrors ``qg_ranking``.
  """
  first_step = logits[:, 0, :]
  return torch.softmax(first_step, dim=-1)[:, truth_ix]




In [None]:
from transformers.utils.logging import disable_progress_bar
from transformers import DataCollatorForSeq2Seq
from torch.utils.data import SequentialSampler, DataLoader
import numpy as np
def evaluate_recall(validation, k, model, tokenizer, batch_size, evidence_txts, 
                       preprocess_function, truth_ix, isRanking=False, isQG=True):
  """Re-rank the top-``k`` retrieved passages of each question and compare recall curves.

  For every question, its first ``k`` contexts are turned into prompts
  (``qg_batching`` when ``isQG`` else ``relevance_batching``), scored by ``model``
  (``qg_ranking`` / ``relevance_ranking``) and re-sorted by score.

  Args:
    validation: indexable rows with ``"question"`` and ``"ctxs"`` (each ctx has ``"id"``
      and ``"has_answer"``).
    k: number of retrieved contexts to re-rank per question.
    model: seq2seq model whose output exposes ``.logits`` and ``.loss``.
    tokenizer: tokenizer bound into ``preprocess_function`` and the data collator.
    batch_size: scoring batch size; shrunk to ``k`` when ``k < batch_size``.
    evidence_txts: passage store indexed by ``ctx["id"] - 1``.
    preprocess_function: tokenization function accepting ``tokenizer``,
      ``max_input_length``, ``max_target_length`` and ``input_col`` keywords.
    truth_ix: vocab id whose first-step probability is the relevance score (used when
      not ``isQG``).
    isRanking: average ``ranking_loss`` instead of the model's own loss.
    isQG: score by question likelihood instead of relevance-token probability.

  Returns:
    ``(original_recall, current_recall, loss)``: two arrays whose entry ``i`` is the
    fraction of questions with an answer-bearing passage among the top ``i + 1``
    (retriever order vs. re-ranked order), and the mean loss over all scored batches.

  Raises:
    ValueError: if ``validation`` is empty, ``k`` is not positive, or — with
      ``isRanking`` — the batches cannot be split into equal halves.

  The model's train/eval mode is restored on return, so calling this between
  training epochs does not leave dropout disabled for the next epoch.
  """
  if len(validation) == 0:
    raise ValueError("validation must contain at least one question")
  if k <= 0:
    raise ValueError("k must be positive")
  if k < batch_size:
    batch_size = k
  if isRanking:
    # ranking_loss splits every batch into two equal halves, so every batch
    # (including the last one) must have an even size.
    if k % batch_size != 0:
      raise ValueError("k must be multiple of batch_size")
    if batch_size % 2 != 0:
      raise ValueError("Batch Size Must Be Even")

  device = "cuda:0" if torch.cuda.is_available() else "cpu"
  model.to(device)
  was_training = model.training
  model.eval()

  # Setup that does not depend on the question, hoisted out of the loop.
  datasets.utils.disable_progress_bar()
  data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
  # `tokenizer` must be bound: preprocess_function takes it as a required argument.
  tokenize = partial(preprocess_function, tokenizer=tokenizer, max_input_length=300,
                     max_target_length=50, input_col='inputs')
  build_prompts = qg_batching if isQG else relevance_batching

  original_recall = []
  current_recall = []
  losses = []
  try:
    for i in tqdm(range(len(validation))):
      # Extract question, passages, and whether each passage contains the answer.
      row = validation[i]
      question = row["question"]
      ctxs = row["ctxs"][:k]
      has_ans = torch.BoolTensor([ctx["has_answer"] for ctx in ctxs])

      # Build and tokenize one prompt per passage, kept in retriever order.
      eval_dataset = build_prompts(question, ctxs, has_ans, evidence_txts)
      eval_dataset = eval_dataset.map(tokenize, batched=True)
      eval_dataset = eval_dataset.remove_columns(["inputs", "targets"])
      eval_dataset.set_format(type="torch")
      eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, collate_fn=data_collator, shuffle=False)

      # Score every passage.
      scores = []
      with torch.no_grad():
        for batch in eval_dataloader:
          batch = {name: tensor.to(device) for name, tensor in batch.items()}
          outputs = model(**batch)
          logits = outputs.logits
          labels = batch["labels"]
          if isRanking:
            losses.append(ranking_loss(logits, labels, 1, labels.size(dim=0)))
          else:
            losses.append(outputs.loss)
          if isQG:
            scores.append(qg_ranking(logits, labels))
          else:
            scores.append(relevance_ranking(logits, labels, truth_ix))

      # has_ans lives on the CPU and cannot be indexed with CUDA indices, so rank on the CPU.
      scores = torch.cat(scores).cpu()
      _, order = torch.topk(scores, k=len(scores))

      # Entry j is True once an answer-bearing passage has appeared in the top j + 1.
      current_recall.append((torch.cumsum(has_ans[order], dim=0) > 0).tolist())
      original_recall.append((torch.cumsum(has_ans, dim=0) > 0).tolist())
  finally:
    # Restore the caller's mode so evaluation between epochs does not affect training.
    model.train(was_training)

  loss = sum(losses) / len(losses)
  return np.mean(np.array(original_recall), axis=0), np.mean(np.array(current_recall), axis=0), loss



    



In [None]:
# Check before fine-tuning: re-rank questions 30-49 of the NQ training split with the freshly loaded
# model, using the question-generation scorer (evaluate_recall's default isQG=True).
val_dataset = nq_open["train"].select(range(30,50))
original_recall, current_recall, loss = evaluate_recall(val_dataset, k=30, model=model,
                                                     tokenizer=tokenizer, 
                                                     batch_size=10, 
                                                     evidence_txts = wikipedia_txt["train"], 
                                                     preprocess_function=preprocess_function, truth_ix=1176)


In [None]:
# Inspect GPU memory and utilisation.
!nvidia-smi

In [None]:
def ranking_loss(outputs, labels, margin, batch_size):
  """Margin loss requiring the positive half of a batch to have lower LM loss than the negative half.

  The first ``batch_size // 2`` rows are treated as positives and the rest as negatives.
  NOTE(review): this assumes each batch is aligned to one question's examples, with
  positives emitted before negatives — confirm for every caller.
  Token-level cross-entropy is averaged over each half, and

      loss = max(0, pos_loss - neg_loss + margin)

  Args:
    outputs: decoder logits, indexed as (batch, target_len, vocab).
    labels: target ids, (batch, target_len); positions equal to -100 are ignored
      (``CrossEntropyLoss``'s default ``ignore_index``).
    margin: required gap between the two mean losses.
    batch_size: number of rows in the batch; must be even.

  Returns:
    Scalar loss tensor.

  Raises:
    ValueError: if ``batch_size`` is odd.
  """
  # The previous check (batch_size // 2 != 0) accepted odd sizes such as 3.
  if batch_size % 2 != 0:
    raise ValueError("Batch Size must be even")
  half = batch_size // 2
  vocab_size = outputs.size(-1)
  ce_loss = torch.nn.CrossEntropyLoss()
  # CrossEntropyLoss applies log-softmax itself, so the raw logits are passed directly.
  pos_loss = ce_loss(outputs[:half].reshape(-1, vocab_size), labels[:half].reshape(-1))
  neg_loss = ce_loss(outputs[half:].reshape(-1, vocab_size), labels[half:].reshape(-1))
  # target = -1 asks MarginRankingLoss for pos_loss to be smaller than neg_loss.
  target = torch.full_like(pos_loss, -1.0)
  return torch.nn.MarginRankingLoss(margin)(pos_loss, neg_loss, target)

In [None]:
from transformers import AdamW
from transformers import get_scheduler
# Fine-tune with the pairwise ranking loss on question-generation prompts built from training
# questions 100-499, reporting validation recall after every epoch.
# NOTE(review): ranking_loss treats the first half of each batch as positives and the second half
# as negatives; with shuffle=False this relies on batch_size matching the 2*k rows emitted per
# question (k=5 in the create_pos_neg_txt_col map above) — confirm.
train_dataset = datasets.Dataset.from_dict(create_q_gen_ranking_baseline_examples(nq_open["train"].select(range(100, 500))))
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
# preprocess_function takes `tokenizer` as a required argument, so it must be bound here.
train_dataset = train_dataset.map(partial(preprocess_function, tokenizer=tokenizer, max_input_length=300,
                                          max_target_length=50, input_col='inputs'),
                                  batched=True)

train_dataset = train_dataset.remove_columns(["inputs", "targets", "k_pos_neg"])
train_dataset.set_format(type="torch")
batch_size = 10
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=data_collator, shuffle=False)

truth_ix = 1176  # vocab id scored by relevance_ranking; presumably "true" — TODO confirm
optimizer = AdamW(model.parameters(), lr=5e-4)
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
progress_bar = tqdm(range(num_training_steps))
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
val_dataset = nq_open["validation"].select(range(30,50))
recall_at = 5  # the recall arrays are 0-based: entry i is recall@(i + 1)
for epoch in range(num_epochs):
  # Re-enter training mode at the start of every epoch so a mode change made during the
  # evaluation below cannot leak into training (dropout would otherwise stay disabled).
  model.train()
  losses = []
  for step, batch in enumerate(train_dataloader):
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    outputs = model(**batch)
    loss = ranking_loss(outputs.logits, batch["labels"], 1,
                        batch["labels"].size(dim=0))
    loss.backward()

    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    progress_bar.update(1)
    # Store a Python float so the loss tensor (and its autograd history) is not retained.
    losses.append(loss.item())
  print(f'Train Loss in Epoch {epoch}: {sum(losses)/len(losses)}')
  original_recall, current_recall, loss = evaluate_recall(val_dataset, k=30, model=model,
                                                         tokenizer=tokenizer,
                                                         batch_size=10,
                                                         evidence_txts = wikipedia_txt["train"],
                                                         preprocess_function=preprocess_function, isRanking=True,
                                                         truth_ix=truth_ix)
  # Index recall_at - 1 is recall@5; the previous index 5 reported recall@6.
  print(f'DPR Recall @ {recall_at}: {original_recall[recall_at - 1]}')
  print(f'Our Recall @ {recall_at}: {current_recall[recall_at - 1]}')
  print(f'Validation Loss: {loss.item()}')
  
  


In [None]:
# Evaluate the relevance (true/false) scorer on training questions 100-499 — the same questions
# used to build the fine-tuning set above — so this reflects fit rather than generalisation.
# NOTE(review): the fine-tuning cell above trains on question-generation prompts, while this call
# scores with the relevance prompt (isQG=False) — confirm this mismatch is intended.
original_recall, current_recall, loss = evaluate_recall(nq_open["train"].select(range(100, 500)), k=30, model=model,
                                                     tokenizer=tokenizer, 
                                                     batch_size=10, 
                                                     evidence_txts = wikipedia_txt["train"], 
                                                     preprocess_function=preprocess_function, isQG=False, 
                                                     truth_ix=1176)
print(original_recall, current_recall, loss)

In [None]:
# Recall@20 (entry 19 of the 0-based recall curve): re-ranked order vs. retriever order.
print(current_recall[19], original_recall[19])

In [None]:
# Reload the untuned checkpoint and score questions 100-109 with the relevance prompt (zero-shot baseline).
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
original_recall, current_recall, loss = evaluate_recall(nq_open["train"].select(range(100, 110)), k=30, model=model,
                                                     tokenizer=tokenizer, 
                                                     batch_size=10, 
                                                     evidence_txts = wikipedia_txt["train"], 
                                                     preprocess_function=preprocess_function, isQG=False,
                                                     truth_ix=truth_ix)

# Recall@20: re-ranked order vs. retriever order.
print(current_recall[19], original_recall[19])

In [None]:
# Recall@20 (re-ranked, retriever order) and the mean loss from the evaluation above.
print(current_recall[19], original_recall[19], loss)

In [None]:
# Step through evaluate_recall's per-question logic by hand for one question with the
# relevance (true/false) scorer, to inspect labels, scores and losses.
validation = nq_open["train"].select(range(100, 110))
truth_ix = 1176
isQG = False
isRanking = False
i = 0
k = 30
batch_size = 10
# Defined locally so this cell does not depend on `device` from an earlier cell.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
question = validation[i]["question"]
ctxs = validation[i]["ctxs"][:k]
has_ans = [ctx["has_answer"] for ctx in ctxs]
has_ans = torch.BoolTensor(has_ans)
losses = []
# Build Data as Model Expects
evidence_txts = wikipedia_txt["train"]
eval_dataset = relevance_batching(question, ctxs, has_ans, evidence_txts)

datasets.utils.disable_progress_bar()
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
# preprocess_function requires `tokenizer`; without binding it here map() raises a TypeError.
eval_dataset = eval_dataset.map(partial(preprocess_function, tokenizer=tokenizer, max_input_length=300,
                                        max_target_length=50, input_col='inputs'),
                                batched=True)

eval_dataset = eval_dataset.remove_columns(["inputs", "targets"])
eval_dataset.set_format(type="torch")
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, collate_fn=data_collator, shuffle=False)
# Calculate Log Scores and Get Ranking
scores = []
model.to(device)
model.eval()
for step, batch in enumerate(eval_dataloader):
    with torch.no_grad():
      batch = {name: tensor.to(device) for name, tensor in batch.items()}
      outputs = model(**batch)
      logits = outputs.logits
      labels = batch["labels"]
      print(labels)
      if isRanking:
        loss = ranking_loss(logits, batch["labels"], 1,
                            batch["labels"].size(dim=0))
        losses.append(loss)
      else:
        losses.append(outputs.loss)
      if isQG:
        score = qg_ranking(logits, labels)
      else:
        score = relevance_ranking(logits, labels, truth_ix)
      scores.append(score)

scores = torch.cat(scores)
topk_scores, indexes = torch.topk(scores, k=len(scores))


In [None]:
# Show the first question's retrieved contexts alongside their scores, then the per-batch losses.
# (A bare `scores` expression on a non-final line displayed nothing, so it was removed.)
print(validation[0]["ctxs"][:k], scores)
print(losses)