<a href="https://colab.research.google.com/github/nvaikunt/PromptBasedReranking/blob/main/BaselineTest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/nvaikunt/PromptBasedReranking.git

In [None]:
!pip install -q condacolab
import condacolab
condacolab.install()

In [None]:
!conda --version

In [None]:
%cd PromptBasedReranking/

In [None]:
!bash setup.sh

In [None]:
!python Baseline/download_dpr_data.py --resource data.wikipedia-split.psgs_w100

In [None]:
!python Baseline/download_dpr_data.py --resource data.retriever-outputs.dpr.nq-train

In [None]:

import torch
from torch.utils.data import Dataset, DataLoader
import transformers
import tokenizers
import datasets 
import pandas

wikipedia_txt = datasets.load_dataset("csv", data_files='downloads/data/wikipedia-split/psgs_w100.tsv', delimiter='\t')


In [None]:
wikipedia_txt

In [None]:
nq_open = datasets.load_dataset("json",data_files="downloads/data/retriever-outputs/dpr/nq-train.json")

In [None]:
nq_open["train"][2]["question"]

In [127]:
def get_top_k_pos(row, k, txt_database):
  ctxs = row["ctxs"]
  top_k = []
  for ctx in ctxs:
    if ctx["has_answer"]:
      text = txt_database[ctx["id"] - 1]["text"]
      top_k.append((text, "true"))
    if len(top_k) == k:
      break
  if len(top_k) == 0: 
    return []
  while len(top_k) < k:
    top_k.extend(top_k[:(k - len(top_k))])

  return top_k[:k]
  
def get_top_k_pos_neg(row, k, txt_database):
  ctxs = row["ctxs"]
  top_k_pos = []
  top_k_neg = []
  for ctx in ctxs:
    if ctx["has_answer"] and len(top_k_pos) < k:
      text = txt_database[ctx["id"] - 1]["text"]
      top_k_pos.append((text, "true"))
    if not ctx["has_answer"] and len(top_k_neg) < k: 
      text = txt_database[ctx["id"] - 1]["text"]
      top_k_neg.append((text, "false"))
    if len(top_k_pos) == k and len(top_k_neg):
      break 
  if len(top_k_pos) == 0: 
    return []
  while len(top_k_pos) < k:
    top_k_pos.extend(top_k_pos[:(k - len(top_k_pos))])
  while len(top_k_neg) < k:
    top_k_neg.extend(top_k_neg[:(k - len(top_k_neg))])
 
  return top_k_pos + top_k_neg



In [140]:
def create_pos_txt_col(example, k, txt_database):
  return {"pos_text": get_top_k_pos(example, k, txt_database)}

def create_pos_neg_txt_col(example, k, txt_database):
  return {"pos_neg_text": get_top_k_pos_neg(example, k, txt_database)}


In [None]:
from functools import partial
nq_open["train"] = nq_open["train"].map(partial(create_pos_txt_col, k=20, txt_database=wikipedia_txt["train"]), num_proc=4)

In [None]:
nq_open["train"] = nq_open["train"].map(partial(create_pos_neg_txt_col, k=10, txt_database=wikipedia_txt["train"]), num_proc=4)

In [None]:
nq_open["train"][0]["pos_neg_text"]

In [156]:
import tqdm
def create_ranking_loss_baseline_examples(dataset, n=None):
  if not n:
    n = len(dataset)
  inputs = []
  targets = []
  for i in tqdm(range(n)):
    texts = dataset[i]["pos_neg_text"]
    question = dataset[i]["question"]
    for text in texts:
      passage = text[0]
      inputs.append(f"Question: {question} Passage: {passage} Relevant: ")
      targets.append(text[1])
  k = [len(targets)/(n*2)] * len(targets)
  return {"inputs": inputs, "targets": targets, "k_pos_neg": k}

def create_q_gen_baseline_examples(dataset, n=None):
  if not n:
    n = len(dataset)
  inputs = []
  targets = []
  for i in tqdm(range(n)):
    texts = dataset[i]["pos_text"]
    question = dataset[i]["question"]
    for text in texts:
      passage = text[0]
      inputs.append(f"Passage: {passage} Please write a question based on this passage")
      targets.append(question)
  k = [len(targets)/(n)] * len(targets)
  return {"inputs": inputs, "targets": targets, "k_pos": k}

def create_q_gen_ranking_baseline_examples(dataset, n=None):
  if not n:
    n = len(dataset)
  inputs = []
  targets = []
  for i in tqdm(range(n)):
    texts = dataset[i]["pos_neg_text"]
    question = dataset[i]["question"]
    for text in texts:
      passage = text[0]
      inputs.append(f"Passage: {passage} Please write a question based on this passage")
      targets.append(question)
  k = [len(targets)/(n*2)] * len(targets)
  return {"inputs": inputs, "targets": targets, "k_pos_neg": k}



In [159]:
training_dict = create_ranking_loss_baseline_examples(nq_open["train"])

AttributeError: ignored

In [161]:
ranking_loss_dataset = datasets.Dataset.from_dict(training_dict)

In [164]:
ranking_loss_dataset.save_to_disk('/downloads/data')