In [1]:
import os
import csv
import json
from tqdm import tqdm
import gc

import torch

from transformers import AutoModel, DPRQuestionEncoder
from datasets import load_dataset
from beir.datasets.data_loader import GenericDataLoader
from beir import util
from pyserini.search.lucene import LuceneSearcher
from torchviz import make_dot
from losses import contrastive_loss

from model_biencoder import BiEncoder
from preprocessing import get_triplets_dataloader, get_pairs_dataloader

import dotenv
# Populate os.environ from a local .env file, if one is present (python-dotenv).
dotenv.load_dotenv()

# Pin the JVM to the local OpenJDK 21 install (presumably for pyserini's Lucene backend).
JAVA_HOME_DIR = "/usr/lib64/openjdk-21"
os.environ["JAVA_HOME"] = JAVA_HOME_DIR

# Prepend the JDK's bin dir to PATH only if it is not already first, so re-running
# this cell does not keep growing PATH; tolerate an unset PATH.
_java_bin = os.path.join(JAVA_HOME_DIR, "bin")
_current_path = os.environ.get("PATH", "")
if _current_path.split(os.pathsep)[0] != _java_bin:
    os.environ["PATH"] = _java_bin + os.pathsep + _current_path

  from .autonotebook import tqdm as notebook_tqdm


# Parameter Comparisons

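Check that the query encoder of my `BiEncoder` loads exactly the same weights as the reference E5 model (`intfloat/e5-base-v2` loaded with `AutoModel`).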

In [8]:

# Load my BiEncoder with E5 weights in both towers and the reference E5 checkpoint, both on CPU.
custom_model = BiEncoder(model_path=("intfloat/e5-base-v2", "intfloat/e5-base-v2"), sep=" [SEP] ")
custom_q_model = custom_model.q_model.to("cpu")
base_model = AutoModel.from_pretrained("intfloat/e5-base-v2").to("cpu")

# Compare weights: every reference parameter should exist in the custom query
# encoder under the same name and with identical values.
# state_dict() rebuilds a dict of every tensor on each call, so take it once.
custom_state = custom_q_model.state_dict()

n_checked = n_mismatched = n_missing = 0
for name, param in base_model.named_parameters():
    n_checked += 1
    if name not in custom_state:
        n_missing += 1
        print(f"Layer {name} not found.")
    elif not torch.equal(param, custom_state[name]):
        n_mismatched += 1
        print(f"Layer {name} does not match")

# Make a clean run visible instead of silently printing nothing.
print(f"Checked {n_checked} parameters: {n_mismatched} mismatched, {n_missing} not found.")

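Check that the query encoder of my `BiEncoder` loads exactly the same weights as the reference DPR question encoder (`facebook/dpr-question_encoder-single-nq-base` loaded with `DPRQuestionEncoder`).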

In [None]:

# Load my BiEncoder with the DPR question/context encoders and the reference DPR
# question encoder, both on CPU.
custom_model = BiEncoder(model_path=("facebook/dpr-question_encoder-single-nq-base", "facebook/dpr-ctx_encoder-single-nq-base"), sep=" [SEP] ")
custom_q_model = custom_model.q_model.to("cpu")
base_model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base").to("cpu")

# Compare weights: every reference parameter should exist in the custom query
# encoder under the same name and with identical values.
# state_dict() rebuilds a dict of every tensor on each call, so take it once.
custom_state = custom_q_model.state_dict()

n_checked = n_mismatched = n_missing = 0
for name, param in base_model.named_parameters():
    n_checked += 1
    if name not in custom_state:
        n_missing += 1
        print(f"Layer {name} not found.")
    elif not torch.equal(param, custom_state[name]):
        n_mismatched += 1
        print(f"Layer {name} does not match")

# Make a clean run visible instead of silently printing nothing.
print(f"Checked {n_checked} parameters: {n_mismatched} mismatched, {n_missing} not found.")

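# Torchviz

Render, with `torchviz.make_dot`, the autograd graph of one contrastive training step (`dual_graph`) and of document encoding alone (`doc_graph`).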

In [2]:
# Batch size shared by the BiEncoder (its batch_size argument) and the dataloaders built below.
batch_size = 20

In [3]:
# BERT-base in both towers, with empty query/passage prompts, no embedding
# normalization and eager attention; put in training mode for the graph cells.
dpr_encoder_kwargs = dict(
    model_path=("google-bert/bert-base-uncased", "google-bert/bert-base-uncased"),
    normalize=False,
    prompts={"query": "", "passage": ""},
    attn_implementation="eager",
    sep=" [SEP] ",
    batch_size=batch_size,
)
dpr_model = BiEncoder(**dpr_encoder_kwargs)
dpr_model.train()

In [4]:
# NQ training triplets; batches expose "queries", "positives" and "negatives" (see the batch cell).
dataset_path = os.path.join(
    "/Tmp/lvpoellhuber/datasets/nq",
    "train_triplets.pt",
)
dataloader = get_triplets_dataloader(
    dataset_path=dataset_path,
    batch_size=batch_size,
)

In [None]:
# Alternative: NQ training pairs. Bound to its own name so that "Run All" does not
# clobber the triplets `dataloader` that the batch cell reads "negatives" from.
# NOTE(review): pairs batches presumably expose a "documents" key instead — confirm.
pairs_dataset_path = os.path.join("/Tmp/lvpoellhuber/datasets/nq", "train_pairs.pt")
pairs_dataloader = get_pairs_dataloader(batch_size=batch_size, dataset_path=pairs_dataset_path)

In [5]:
# Name -> parameter map used by make_dot to label graph nodes. Each tower is walked
# independently, so towers with different parameter names neither raise KeyError
# nor silently drop doc-only parameters.
dpr_params = {}
for prefix, encoder in (("q_encoder.", dpr_model.q_model), ("doc_encoder.", dpr_model.doc_model)):
    for name, param in encoder.named_parameters():
        dpr_params[prefix + name] = param

In [6]:
# Take only the first batch; building a list of the whole loader would iterate
# every batch of the epoch just to keep one.
batch = next(iter(dataloader))

queries = batch["queries"]
positives = batch["positives"]
negatives = batch["negatives"]
# Assumes positives/negatives are Python lists of passage texts, so `+`
# concatenates (positives first, then negatives) — TODO confirm against the loader.
documents = positives + negatives

In [7]:
def dpr_step(queries, documents, model=None):
    """Encode a batch and return its contrastive loss.

    The result is passed to ``torchviz.make_dot``, so it is expected to carry the
    autograd graph of both encoders.

    Args:
        queries: query texts passed to ``model.encode_queries``.
        documents: passage texts passed to ``model.encode_corpus`` (in this notebook,
            the batch's positives followed by its negatives).
        model: BiEncoder to use; defaults to the notebook-level ``dpr_model``.

    Returns:
        ``contrastive_loss(q_embeddings, doc_embeddings)``.
    """
    if model is None:
        model = dpr_model
    # NOTE(review): embedding shapes depend on BiEncoder; presumably
    # (num_queries, dim) and (num_documents, dim) — TODO confirm.
    q_embeddings = model.encode_queries(queries, convert_to_tensor=True)
    doc_embeddings = model.encode_corpus(documents, convert_to_tensor=True)
    return contrastive_loss(q_embeddings, doc_embeddings)

In [8]:
# Autograd graph of a full contrastive step (both towers + loss), labelled via dpr_params.
dual_graph = make_dot(
    dpr_step(queries, documents),
    params=dpr_params,
    show_attrs=True,
)
dual_graph.render("/u/poellhul/Documents/Masters/benchmarkIR-slurm/src/retrieval/dual_graph")

'/u/poellhul/Documents/Masters/benchmarkIR-slurm/src/retrieval/dual_graph.pdf'

In [9]:
# Autograd graph of document encoding alone, labelled via dpr_params.
doc_graph_root = dpr_model.encode_corpus(documents, convert_to_tensor=True)
doc_graph = make_dot(doc_graph_root, params=dpr_params, show_attrs=True)
doc_graph.render("/u/poellhul/Documents/Masters/benchmarkIR-slurm/src/retrieval/doc_graph")

'/u/poellhul/Documents/Masters/benchmarkIR-slurm/src/retrieval/doc_graph.pdf'

# Dataloader inspection

Check that the NQ pairs dataloader can be iterated end to end, and count its batches.

In [None]:
# Point `dataloader` at the NQ train_pairs split for the inspection cells below.
dataset_path = os.path.join("/Tmp/lvpoellhuber/datasets/nq", "train_pairs.pt")
dataloader = get_pairs_dataloader(dataset_path=dataset_path, batch_size=12)
       

In [3]:
# Smoke-test a full pass over the loader. A progress bar shows throughput and
# position instead of printing one line per batch (~11k batches).
for batch in tqdm(dataloader, desc="iterating pairs dataloader"):
    pass

KeyboardInterrupt: 

In [None]:
num_batches = len(dataloader)
num_batches

11067

: 