<a href="https://colab.research.google.com/github/hammadkhann/Effective-Dense-Retrieval/blob/main/Hard_Negatives.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install sentence_transformers
!pip install beir



In [3]:
from google.colab import drive

# Google Drive is mounted here; the hard-negatives file is read from under this path later.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [5]:
from sentence_transformers import SentenceTransformer, models, losses, InputExample
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.train import TrainRetriever
from torch.utils.data import Dataset
from tqdm.autonotebook import tqdm
import pathlib, os, gzip, json
import logging
import random

In [7]:
#### Print debug information to stdout (without this, the logging.info calls below are dropped)
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

#### Download msmarco.zip and unzip it -- only when the dataset directory is not already present,
#### so a fresh runtime works and re-runs do not hit the network.
dataset = "msmarco"
out_dir = os.path.join("msmarco", "datasets")
data_path = os.path.join(out_dir, dataset)
if not os.path.isdir(data_path):
    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
    data_path = util.download_and_unzip(url, out_dir)

### Load BEIR MSMARCO training dataset; used as the query and corpus lookup for the hard negatives.
corpus, queries, _ = GenericDataLoader(data_path).load(split="train")

  0%|          | 0/8841823 [00:00<?, ?it/s]

In [8]:
#################################
#### Parameters for Training ####
#################################

train_batch_size = 75           # Increasing the train batch size improves the model performance, but requires more GPU memory (O(n))
max_seq_length = 350            # Max length for passages. Increasing it, requires more GPU memory (O(n^2))
ce_score_margin = 3             # Margin for the CrossEncoder score between negative and positive passages
num_negs_per_system = 5         # We used different systems to mine hard negatives. Number of hard negatives to add from each system

# NOTE(review): train_batch_size / max_seq_length are not used by any visible cell; kept for a later training step.

random_seed = 42                # Seeds the global `random` module, which MSMARCODataset uses to shuffle hard negatives
random.seed(random_seed)

In [9]:
##################################################
#### Load MSMARCO Hard Negs Triplets File ####
##################################################

#### Source file: https://sbert.net/datasets/msmarco-hard-negatives.jsonl.gz (copied to Google Drive).
#### Each line holds a query id, its positive passages and, per mining system, candidate negatives;
#### every passage carries a 'ce-score' (cross-encoder score for the query/passage pair).
#### A negative is kept only if its ce-score is at least `ce_score_margin` below the lowest-scoring
#### positive; negatives scoring closer to the positives are skipped (presumably to avoid false negatives).
#### At most `num_negs_per_system` new negatives are taken from each mining system.

logging.info("Loading MSMARCO hard-negatives...")

msmarco_triplets_filepath = '/content/drive/MyDrive/msmarco-hard-negatives.jsonl.gz'

train_queries = {}
with gzip.open(msmarco_triplets_filepath, 'rt', encoding='utf8') as fIn:
    # total= presumably the file's line count; it only sizes the progress bar.
    for line in tqdm(fIn, total=502939):
        data = json.loads(line)

        # Skip queries without positives up front: min() below raises on an empty list.
        if not data['pos']:
            continue

        # Get the positive passage ids and the score threshold for accepting negatives.
        pos_pids = [item['pid'] for item in data['pos']]
        pos_pid_set = set(pos_pids)
        pos_min_ce_score = min(item['ce-score'] for item in data['pos'])
        ce_score_threshold = pos_min_ce_score - ce_score_margin

        # Get the hard negatives. A dict (insertion-ordered) instead of a set keeps the
        # resulting order deterministic across interpreter runs (set order depends on hash seeding).
        neg_pids = {}
        for system_negs in data['neg'].values():
            negs_added = 0
            for item in system_negs:
                if item['ce-score'] > ce_score_threshold:
                    continue

                pid = item['pid']
                # Never use a positive passage as a negative, and skip duplicates across systems.
                if pid in pos_pid_set or pid in neg_pids:
                    continue
                neg_pids[pid] = None
                negs_added += 1
                if negs_added >= num_negs_per_system:
                    break

        if neg_pids:
            train_queries[data['qid']] = {'query': queries[data['qid']], 'pos': pos_pids, 'hard_neg': list(neg_pids)}

logging.info("Train queries: {}".format(len(train_queries)))

  0%|          | 0/502939 [00:00<?, ?it/s]

In [10]:
sample_qid = '571018'  # example training query id to inspect
train_queries[sample_qid]

{'hard_neg': ['8768341',
  '6717925',
  '1626276',
  '6449111',
  '1833477',
  '2524942',
  '7862900',
  '6948601',
  '4903324',
  '6717927',
  '1891131',
  '2305477',
  '7179547',
  '1065943',
  '7491025',
  '7027050',
  '1516443',
  '7790853',
  '3911285',
  '1065948'],
 'pos': ['7349777'],
 'query': 'what are the liberal arts?'}

In [30]:
# We create a custom MSMARCO dataset that yields (query, positive, negative) triplets,
# each serialized as one tab-separated line, on-the-fly from the mined hard negatives.

class MSMARCODataset(Dataset):
    """Round-robin (query, positive, hard negative) triplet dataset.

    Each item is one TSV line ``query<TAB>positive<TAB>negative``. Every call to
    ``__getitem__`` rotates that query's positive and hard-negative lists, so
    repeated access to the same index cycles through all its positives/negatives.

    Args:
        queries: mapping qid -> {'query': str, 'pos': iterable of pids,
            'hard_neg': iterable of pids}. Not modified; a private copy is kept.
        corpus: mapping pid -> {'text': str, ...}.
        seed: optional seed for the initial hard-negative shuffle. ``None``
            (default) uses the global ``random`` module, as before.
    """

    def __init__(self, queries, corpus, seed=None):
        # Copy the per-query dicts/lists: __getitem__ rotates them, and the caller's
        # data (e.g. train_queries) must not be mutated or reshuffled on re-runs.
        self.queries = {
            qid: dict(info, pos=list(info['pos']), hard_neg=list(info['hard_neg']))
            for qid, info in queries.items()
        }
        self.queries_ids = list(self.queries.keys())
        self.corpus = corpus

        rng = random if seed is None else random.Random(seed)
        for qid in self.queries_ids:
            rng.shuffle(self.queries[qid]['hard_neg'])

    @staticmethod
    def _clean(text):
        """Replace tab/newline characters so a text cannot break the TSV row layout."""
        return text.replace('\t', ' ').replace('\r', ' ').replace('\n', ' ')

    def __getitem__(self, item):
        query = self.queries[self.queries_ids[item]]
        query_text = query['query']

        pos_id = query['pos'].pop(0)    # Pop positive and add at end (round-robin)
        pos_text = self.corpus[pos_id]["text"]
        query['pos'].append(pos_id)

        neg_id = query['hard_neg'].pop(0)    # Pop negative and add at end (round-robin)
        neg_text = self.corpus[neg_id]["text"]
        query['hard_neg'].append(neg_id)

        return '\t'.join(self._clean(text) for text in (query_text, pos_text, neg_text))

    def __len__(self):
        return len(self.queries)

In [33]:
# Wrap the mined queries and the BEIR corpus into the round-robin triplet dataset.
train_dataset = MSMARCODataset(queries=train_queries, corpus=corpus)

In [67]:
import csv        
# Export up to `num_export_triples` triples, one "query<TAB>positive<TAB>negative" line each.
num_export_triples = 10000
# Clamp to the dataset size so smaller datasets do not raise IndexError.
n_triples = min(num_export_triples, len(train_dataset))
hard_negatives = [train_dataset[i] for i in range(n_triples)]

# Explicit UTF-8: texts may contain non-ASCII characters and the platform default encoding may differ.
with open("hard_negatives_triples.tsv", "w", encoding="utf8") as outfile:
    for triple in hard_negatives:
        outfile.write(triple + "\n")