In [1]:
import sys
import json
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, LoggingHandler, util, evaluation, InputExample
from sentence_transformers import  models as smodel
import models 
import logging
from datetime import datetime
import gzip
import os
import tarfile
import tqdm
from torch.utils.data import Dataset
import random
from shutil import copyfile
import pickle
import argparse
import losses
import torch

In [2]:
model_name = "Luyu/co-condenser-marco"
max_seq_length = 256
word_embedding_model = models.MLMTransformer(model_name, max_seq_length=max_seq_length)
for param in word_embedding_model.parameters():
    param.requires_grad = False

query_layer = smodel.Dense(in_features=word_embedding_model.get_word_embedding_dimension(), out_features=word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, query_layer])

In [3]:
num_negs_per_system = 5
data_folder = '../msmarco'

#### Read the corpus file containing all the passages. Store them in the corpus dict
corpus = {}  # dict in the format: passage_id -> passage. Stores all existing passages
collection_filepath = os.path.join(data_folder, 'collection.tsv')
if not os.path.exists(collection_filepath):
    tar_filepath = os.path.join(data_folder, 'collection.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download collection.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

logging.info("Read corpus: collection.tsv")
with open(collection_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        pid, passage = line.strip().split("\t")
        pid = int(pid)
        corpus[pid] = passage

### Read the train queries, store in queries dict
queries = {}  # dict in the format: query_id -> query. Stores all training queries
queries_filepath = os.path.join(data_folder, 'queries.train.tsv')
if not os.path.exists(queries_filepath):
    tar_filepath = os.path.join(data_folder, 'queries.tar.gz')
    if not os.path.exists(tar_filepath):
        logging.info("Download queries.tar.gz")
        util.http_get('https://msmarco.blob.core.windows.net/msmarcoranking/queries.tar.gz', tar_filepath)

    with tarfile.open(tar_filepath, "r:gz") as tar:
        tar.extractall(path=data_folder)

with open(queries_filepath, 'r', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        qid = int(qid)
        queries[qid] = query

# Load a dict (qid, pid) -> ce_score that maps query-ids (qid) and paragraph-ids (pid)
# to the CrossEncoder score computed by the cross-encoder-ms-marco-MiniLM-L-6-v2 model
ce_scores_file = os.path.join(data_folder, 'cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl.gz')
if not os.path.exists(ce_scores_file):
    logging.info("Download cross-encoder scores file")
    util.http_get('https://huggingface.co/datasets/sentence-transformers/msmarco-hard-negatives/resolve/main/cross-encoder-ms-marco-MiniLM-L-6-v2-scores.pkl.gz', ce_scores_file)

logging.info("Load CrossEncoder scores dict")
with gzip.open(ce_scores_file, 'rb') as fIn:
    ce_scores = pickle.load(fIn)

# As training data we use hard-negatives that have been mined using various systems
hard_negatives_filepath = os.path.join(data_folder, 'msmarco-hard-negatives.jsonl.gz')

train_queries = {}
negs_to_use = None
with gzip.open(hard_negatives_filepath, 'rt') as fIn:
    for line in tqdm.tqdm(fIn):
        data = json.loads(line)

        #Get the positive passage ids
        pos_pids = data['pos']

        #Get the hard negatives
        neg_pids = set()

        negs_to_use = list(data['neg'].keys())
           
        for system_name in negs_to_use:
            if system_name not in data['neg']:
                continue

            system_negs = data['neg'][system_name]
            negs_added = 0
            for pid in system_negs:
                if pid not in neg_pids:
                    neg_pids.add(pid)
                    negs_added += 1
                    if negs_added >= num_negs_per_system:
                        break

        if (len(pos_pids) > 0 and len(neg_pids) > 0):
            train_queries[data['qid']] = {'qid': data['qid'], 'query': queries[data['qid']], 'pos': pos_pids, 'neg': neg_pids}

logging.info("Train queries: {}".format(len(train_queries)))

# We create a custom MS MARCO dataset that returns triplets (query, positive, negative)
# on-the-fly based on the information from the mined-hard-negatives jsonl file.
class MSMARCODataset(Dataset):
    def __init__(self, queries, corpus, ce_scores):
        self.queries = queries
        self.queries_ids = list(queries.keys())
        self.corpus = corpus
        self.ce_scores = ce_scores

        for qid in self.queries:
            self.queries[qid]['pos'] = list(self.queries[qid]['pos'])
            self.queries[qid]['neg'] = list(self.queries[qid]['neg'])
            random.shuffle(self.queries[qid]['neg'])

    def __getitem__(self, item):
        query = self.queries[self.queries_ids[item]]
        query_text = query['query']
        qid = query['qid']

        if len(query['pos']) > 0:
            pos_id = query['pos'].pop(0)    #Pop positive and add at end
            pos_text = self.corpus[pos_id]
            query['pos'].append(pos_id)
        else:   #We only have negatives, use two negs
            pos_id = query['neg'].pop(0)    #Pop negative and add at end
            pos_text = self.corpus[pos_id]
            query['neg'].append(pos_id)

        #Get a negative passage
        neg_id = query['neg'].pop(0)    #Pop negative and add at end
        neg_text = self.corpus[neg_id]
        query['neg'].append(neg_id)

        pos_score = self.ce_scores[qid][pos_id]
        neg_score = self.ce_scores[qid][neg_id]

        return InputExample(texts=[query_text, pos_text, neg_text], label=pos_score-neg_score)

    def __len__(self):
        return len(self.queries)

808731it [01:39, 8168.24it/s] 


In [19]:
import torch
from torch import nn, Tensor
from typing import Iterable, Dict

class UNIFORM: 
    def __call__(self, x, t=2):
        result = torch.pdist(x, p=2).pow(2).mul(-t)
        return result.exp().mean().log()


class MarginMSELossSplade(nn.Module):
    """
    Compute the MSE loss between the |sim(Query, Pos) - sim(Query, Neg)| and |gold_sim(Q, Pos) - gold_sim(Query, Neg)|
    By default, sim() is the dot-product
    For more details, please refer to https://arxiv.org/abs/2010.02666
    """
    def __init__(self, model, similarity_fct = losses.pairwise_dot_score, lambda_d=8e-2, lambda_q=1e-1, lambda_uni = 1e-2, uni_mse = False, uni_d = False, uni_q = False):
        """
        :param model: SentenceTransformerModel
        :param similarity_fct:  Which similarity function to use
        """
        super(MarginMSELossSplade, self).__init__()
        self.model = model
        self.similarity_fct = similarity_fct
        self.loss_fct = nn.MSELoss()
        self.lambda_d = lambda_d
        self.lambda_q = lambda_q
        self.FLOPS = losses.FLOPS()
        self.uniform_mse = uni_mse
        self.uni_d = uni_d
        self.uni_q = uni_q
        self.uni = UNIFORM()
        self.lambda_uni = lambda_uni

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        # sentence_features: query, positive passage, negative passage
        embeddings_query = self.model(sentence_features[0])['sentence_embedding'] 
        embeddings_pos = self.model[0](sentence_features[1])['sentence_embedding'] 
        embeddings_neg = self.model[0](sentence_features[2])['sentence_embedding'] 
        
        scores_pos = self.similarity_fct(embeddings_query, embeddings_pos)
        scores_neg = self.similarity_fct(embeddings_query, embeddings_neg)
        margin_pred = scores_pos - scores_neg
        
        flops_doc = self.lambda_d*(self.FLOPS(embeddings_pos) + self.FLOPS(embeddings_neg))
        flops_query = self.lambda_q*(self.FLOPS(embeddings_query))
        overall_loss = self.loss_fct(margin_pred, labels)
        print("0", overall_loss)
        
        uni_d = self.uni(torch.nn.functional.normalize(embeddings_pos,dim=1))
        
        uni_q = self.uni(torch.nn.functional.normalize(embeddings_query,dim=1))
        
        if self.uni_d:
           overall_loss +=  self.lambda_uni * uni_d
        if self.uni_q:
           overall_loss +=  self.lambda_uni * uni_q
        if self.uniform_mse:
            uniform_dist = self.lambda_uni * (uni_q - uni_d) ** 2
            overall_loss +=  uniform_dist
        print("1", (uni_q - uni_d) ** 2)
        return overall_loss



In [20]:
train_dataset = MSMARCODataset(queries=train_queries, corpus=corpus, ce_scores=ce_scores)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2, drop_last=True)
train_loss = MarginMSELossSplade(model=model, lambda_d=0.01, lambda_q=0.01, lambda_uni = 1, uni_mse = True)

In [11]:
for k in model[0].parameters():
    print(k)
    break
    
for k in model[1].parameters():
    print(k)
    break

Parameter containing:
tensor([[-0.0389, -0.0761, -0.0617,  ..., -0.0482, -0.0597, -0.0184],
        [-0.0442, -0.0332,  0.0141,  ..., -0.0077, -0.0159, -0.0392],
        [-0.0403, -0.0679,  0.0146,  ...,  0.0318, -0.0262,  0.0031],
        ...,
        [-0.0263, -0.0573, -0.0065,  ...,  0.0549, -0.0242, -0.0405],
        [-0.0189, -0.0761,  0.0150,  ..., -0.0028, -0.0224, -0.0069],
        [ 0.0124, -0.0798, -0.0331,  ..., -0.0126, -0.0744,  0.0577]],
       device='cuda:0')
Parameter containing:
tensor([[-0.0009,  0.0047, -0.0018,  ..., -0.0022,  0.0025,  0.0033],
        [ 0.0003, -0.0012,  0.0020,  ...,  0.0041, -0.0033,  0.0020],
        [-0.0048, -0.0025, -0.0032,  ..., -0.0009,  0.0015, -0.0030],
        ...,
        [-0.0045, -0.0041, -0.0045,  ..., -0.0052, -0.0020,  0.0024],
        [ 0.0013, -0.0004,  0.0050,  ...,  0.0008, -0.0010,  0.0002],
        [ 0.0049,  0.0028,  0.0020,  ..., -0.0037,  0.0011,  0.0028]],
       device='cuda:0', requires_grad=True)


In [21]:
model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=20,
          warmup_steps=10,
          use_amp=True,
          checkpoint_path=".",
          checkpoint_save_steps=10000,
          optimizer_params = {'lr':1e-4})

Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Iteration:   0%|          | 0/251469 [00:00<?, ?it/s]

0 tensor(154.4012, device='cuda:0', grad_fn=<MseLossBackward0>)
1 tensor(0.2839, device='cuda:0', grad_fn=<PowBackward0>)


RuntimeError: CUDA out of memory. Tried to allocate 3.47 GiB (GPU 0; 22.20 GiB total capacity; 16.20 GiB already allocated; 2.31 GiB free; 17.93 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [1]:
import torch.nn as nn
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")



Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

In [3]:
qtext = "receptor androgen define"
tokenizer.tokenize(qtext)

['receptor', 'and', '##rogen', 'define']

In [5]:
from collections import defaultdict

qrel_file = "../../msmarco/qrels.train.tsv"

qrels = defaultdict(dict)
with open(qrel_file) as f:
    for line in f:
        try:
            qid, _, did, rel = line.strip().split("\t")
        except:
            qid, _, did, rel = line.strip().split(" ")
        if int(rel) > 0:
            qrels[qid][did] = int(rel)
        
with open("../../msmarco/queries.train.tsv") as f, open("../../msmarco/queries.train.wordpiece.tsv", "w") as fo:
    for line in f:
        qid, qtext = line.strip().split("\t")
        if qid in qrels:
            fo.write(f"{qid}\t{' '.join(tokenizer.tokenize(qtext))}\n")