In [None]:
!pip install transformers
!pip install datasets
!pip install sacrebleu
!pip install sentencepiece

In [2]:
import os, json
import torch
import pandas as pd
from google.colab import drive
from collections import defaultdict
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers import Trainer, TrainingArguments
from torch.utils.data import Dataset, DataLoader
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import gc
from sklearn.metrics import accuracy_score, f1_score, classification_report
import pickle

In [3]:
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
path_to_data = '/content/drive/MyDrive/Colab Notebooks/data/'

In [5]:
class TrainerDataset(Dataset):
    def __init__(self, path):
        df = pd.read_json(path, lines=True)
        df = df.dropna()
        self.dataset = df
        self.tokenizer = T5Tokenizer.from_pretrained('t5-base')

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        source = self.dataset.iloc[idx, 0]
        target = self.dataset.iloc[idx, 1]
        input_ids = self.tokenizer.encode(source, 
                                          return_tensors='pt',
                                          padding='max_length', 
                                          truncation='longest_first', 
                                          max_length=512)[0]
        
        label = self.tokenizer.encode(target, 
                                      return_tensors='pt', 
                                      padding='max_length',
                                      truncation='longest_first', 
                                      max_length=2)[0]
        
        return {'input_ids': input_ids, 'labels': label}

In [6]:
def compute_T5_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages, MaxMRRRank):
    """Compute MRR metric
    Args:   

    p_qids_to_relevant_passageids (dict): dictionary of query-passage mapping
    
    p_qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates

    Returns:
        dict: dictionary of metrics {'MRR': <MRR Score>}
    """
    all_scores = {}
    MRR = 0
    qids_with_relevant_passages = 0
    ranking = []
    
    for qid in qids_to_ranked_candidate_passages:
    
        if qid in qids_to_relevant_passageids:
            ranking.append(0)
            target_pid = qids_to_relevant_passageids[qid]
            candidate_pid = qids_to_ranked_candidate_passages[qid]
            if len(candidate_pid) > 0 :
              if len(candidate_pid) < MaxMRRRank:
                MaxMRRRank = len(candidate_pid)
              for i in range(0,MaxMRRRank):
                if candidate_pid[i] in target_pid:
                    MRR += 1/(i + 1)
                    ranking.pop()
                    ranking.append(i+1)
                    break

    print(MRR)
    if len(ranking) == 0:
        raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")
    MRR = MRR/len(qids_to_relevant_passageids)
    all_scores['MRR @10'] = MRR
    all_scores['QueriesRanked'] = len(qids_to_ranked_candidate_passages)
    return all_scores

In [54]:
def convert_to_evaluation(lol_of_dists, top1000_dict):
  """Converts the output of the model into the required format for MRR ranking"""
  eval_dict = defaultdict(list)

  lol_idx = 0
  for Q_ID, value in top1000_dict.items():
    
    probs = []
    while len(probs) < len(value):
      probs += lol_of_dists[lol_idx]
      lol_idx += 1

    assert len(probs) == len(value), f"{len(probs)}, {len(value)}, {Q_ID}"
    for idx, ele in enumerate(probs):
      eval_dict[Q_ID].append((value[idx], ele))

    eval_dict[Q_ID] = sorted(eval_dict[Q_ID], key=lambda x: x[1][0], reverse=True)
    eval_dict[Q_ID] =[x[0] for x in eval_dict[Q_ID] if x[1][0] > .50]

  return eval_dict
  

In [8]:
def convert_to_jsonl(file_path, relevance_dict, split='train'):
  t5_file = open(file_path, 'w')
  if split != 'train':
      for key, value in relevance_dict.items():
        q_id = key
        for p_id in value[:]: 
          t5_out = {}
          t5_out['input_text'] = "Query: " + QUERY_DICT[q_id] + " Document: " + PASSAGE_DICT[p_id] + " Relevance: "
          t5_out['target_text'] = ''
          t5_file.write(json.dumps(t5_out))
          t5_file.write('\n')
      t5_file.close()
  else:
      for key, value in relevance_dict.items():
        q_id, pp_id = key
        t5_out = {}
        t5_out['input_text'] = "Query: " + QUERY_DICT[q_id] + " Document: " + PASSAGE_DICT[pp_id] + " Relevance: "
        #t5_out['Document'] = PASSAGE_DICT[pp_id]
        t5_out['target_text'] = 'true'
        t5_file.write(json.dumps(t5_out))
        t5_file.write('\n')
        for np_id in value[:10]: 
          t5_out = {}
          t5_out['input_text'] = "Query: " + QUERY_DICT[q_id] + " Document: " + PASSAGE_DICT[np_id] + " Relevance: "
          t5_out['target_text'] = 'false'
          t5_file.write(json.dumps(t5_out))
          t5_file.write('\n')
      t5_file.close()



In [9]:
def load_triplets(path, split, trip_dict):
  with open(path_to_data + f'diverse.triplets.{split}.tsv', 'r', encoding='utf-8') as inF:
    for line in inF:
      q_id, pp_id, np_id = line.split('\t')
      trip_dict[(q_id, pp_id)].append(np_id.strip('\n'))
  inF.close()

In [10]:
def enforce_correct_size(rel_dict, split='train'):
  """Enforces each Query has atleast 10 documents"""
  removed_keys = []

  for key, value in rel_dict.items():

    if split == 'train':
      if len(value) < 10:
        removed_keys.append(key)
    else:
      if len(value) < 1000:
        removed_keys.append(key)
  
  for key in removed_keys:
    del rel_dict[key]
  return rel_dict

In [11]:
def remove_duplicate_queries(rel_dict):
  """Removes duplicate queries"""
  queries = []
  for key, value in rel_dict.items():
    queries.append(key[0])

  duplicates = [x for x in queries if queries.count(x) >= 2]
  removed_keys = []
  for key, value in rel_dict.items():
    if key[0] in duplicates:
      removed_keys.append(key)

  for key in removed_keys:
    del rel_dict[key]
  return rel_dict



In [12]:
PASSAGE_DICT = defaultdict(str)
QUERY_DICT = defaultdict(str)

with open(path_to_data + 'diverse.passages.all.tsv', 'r', encoding='utf-8') as inF:
  for line in inF:
    p_id, passage = line.split('\t')
    PASSAGE_DICT[p_id] = passage[:-1]

with open(path_to_data + 'diverse.queries.all.tsv', 'r', encoding='utf-8') as inF:
  for line in inF:
    q_id, query = line.split('\t')
    QUERY_DICT[q_id] = query[:-1]


In [13]:
with open(path_to_data+'/top1000dict.pickle', "rb") as inF:
    top_1000_dict = pickle.load(inF)

In [14]:
all_relevance_dict = defaultdict(list)
train_relevance_dict = defaultdict(list)
dev_relevance_dict = defaultdict(list)
test_relevance_dict = defaultdict(list)

load_triplets(path_to_data, 'all', all_relevance_dict)
load_triplets(path_to_data, 'train', train_relevance_dict)
load_triplets(path_to_data, 'dev', dev_relevance_dict)
load_triplets(path_to_data, 'test', test_relevance_dict)

In [15]:
train_relevance_dict = enforce_correct_size(train_relevance_dict)
dev_relevance_dict = enforce_correct_size(dev_relevance_dict)
test_relevance_dict = enforce_correct_size(test_relevance_dict)

In [16]:
dev_relevance_dict = remove_duplicate_queries(dev_relevance_dict)

In [17]:
TEST_QUERIES = [x[0] for x in list(test_relevance_dict.keys())]
DEV_QUERIES = [x[0] for x in list(dev_relevance_dict.keys())]

In [18]:
top_1000_test = defaultdict(list)
for Q in TEST_QUERIES:
  top_1000_test[Q] = top_1000_dict[Q]

In [19]:
top_1000_dev = defaultdict(list)
for Q in DEV_QUERIES:
  top_1000_dev[Q] = top_1000_dict[Q]

In [20]:
top_1000_test2 = enforce_correct_size(top_1000_test, split='test')
top_1000_dev2 = enforce_correct_size(top_1000_dev, split='dev')

In [66]:
qrels_dict = defaultdict(list)
with open(path_to_data+'/qrels.train.tsv', 'r', encoding='utf-8') as inF:
  for line in inF:
    Q_ID, _, P_ID, _ = line.split('\t')
    qrels_dict[Q_ID].append(P_ID)

test_qrels = defaultdict(list)
for Q in top_1000_test2.keys():
  test_qrels[Q] = qrels_dict[Q]

In [21]:
print(len(top_1000_test2))

500


In [22]:
### Make sure all our Queries have atleast 10 passages
for key, value in train_relevance_dict.items():
  assert len(value) > 9

for key, value in dev_relevance_dict.items():
  assert len(value) > 9

for key, value in test_relevance_dict.items():
  assert len(value) > 9

print("Success!")

Success!


In [23]:
convert_to_jsonl(path_to_data+'/json/train.json', train_relevance_dict)
convert_to_jsonl(path_to_data+'/json/dev.json', top_1000_dev2, split='dev')
convert_to_jsonl(path_to_data+'/json/test.json', top_1000_test2, split='test')

In [24]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')

In [25]:
train_dataset = TrainerDataset(path_to_data+'/json/train.json')
dev_dataset = TrainerDataset(path_to_data+'/json/dev.json')
test_dataset = TrainerDataset(path_to_data+'/json/test.json')

In [26]:
train_dataloader = DataLoader(train_dataset, batch_size=11, shuffle=False)
dev_dataloader = DataLoader(dev_dataset, batch_size=50, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=50, shuffle=False)

In [27]:
TRUE_TOK_ID = 1176
FALSE_TOK_ID = 6136

In [28]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [29]:
gc.collect()

204

In [30]:
print(torch.cuda.memory_summary(device=None, abbreviated=False))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Active memory         |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------

In [31]:
model = T5ForConditionalGeneration.from_pretrained('t5-base')

In [32]:
model = model.to(device)
weights = torch.ones(32128)
weights[FALSE_TOK_ID] = .10
loss_function = nn.CrossEntropyLoss(weight=weights.to(device))

optimizer = optim.Adam(model.parameters(), lr=.00003)

In [33]:
print(f"Number of Batches: {len(train_dataloader)}")

Number of Batches: 3523


In [34]:
EPOCHS=1

In [35]:
for epoch in range(1, EPOCHS+1):  

    print(f"Beginning Epoch {epoch}")
    for i, data in enumerate(train_dataloader):
        
        model.zero_grad()

        input_ids, target_ids  = data['input_ids'], data['labels']

        input_ids = input_ids.long().to(device)
        target_ids = target_ids.long().to(device)

        output = model(input_ids, labels=target_ids, return_dict=True)
        loss = loss_function(output.logits.permute(0, 2, 1), target_ids)
     
        loss.backward()
        optimizer.step()
        print(f"Batch {i}")
        print(f"loss: {round(loss.item(), 4)}")

In [36]:
model.save_pretrained('/content/ir-transformer/')

In [37]:
model = model.from_pretrained('/content/drive/MyDrive/Colab Notebooks/ir-transformer-weighted')

In [38]:
model = model.to(device)

In [39]:
# lol_of_dists = []

# for i, data in enumerate(dev_dataloader):
    
#     model.eval()

#     with torch.no_grad():
        
#         input_ids, target_ids  = data['input_ids'], data['labels']

#         batch_labels = data['labels'][:, 0].cpu().detach().tolist()
#         labels += batch_labels

#         input_ids = input_ids.to(device)
#         target_ids = target_ids.to(device)

#         output = model(input_ids, labels=target_ids, return_dict=True)
#         logits = output.logits.cpu().detach()

#         batch_preds = torch.argmax(logits[:, 0], dim=1).tolist()
#         predictions += batch_preds

#         true_query = logits[:, 0, TRUE_TOK_ID]
#         false_query = logits[:, 0, FALSE_TOK_ID]
#         cat = torch.cat([true_query.unsqueeze(1), false_query.unsqueeze(1)], dim=1)
#         soft = F.softmax(cat, dim=1).tolist()
#         lol_of_dists.append(soft)


In [None]:
lol_of_dists = []

for i, data in enumerate(test_dataloader):
    print(i)
    
    model.eval()

    with torch.no_grad():
        
        input_ids, target_ids  = data['input_ids'], data['labels']

        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        output = model(input_ids, labels=target_ids, return_dict=True)

        logits = output.logits

        true_query = logits[:, 0, TRUE_TOK_ID]
        false_query = logits[:, 0, FALSE_TOK_ID]
        cat = torch.cat([true_query.unsqueeze(1), false_query.unsqueeze(1)], dim=1)
        soft = F.softmax(cat, dim=1).tolist()
        lol_of_dists.append(soft)
        torch.cuda.empty_cache()


In [44]:
with open(path_to_data+'lol_of_dists.pickle', 'wb') as outF:
  pickle.dump(lol_of_dists, outF)

In [55]:
eval_dict = convert_to_evaluation(lol_of_dists, top_1000_test2)

In [71]:
compute_T5_metrics(test_qrels, eval_dict, 10)

327.0


{'MRR @10': 0.654, 'QueriesRanked': 500}