In [2]:
import os
import csv
import json
from tqdm import tqdm
import gc

import torch

from transformers import AutoModel, DPRQuestionEncoder
from datasets import load_dataset
from beir.datasets.data_loader import GenericDataLoader
from beir import util
from pyserini.search.lucene import LuceneSearcher
from torchviz import make_dot
from losses import contrastive_loss
from model_longtriever import Longtriever

from model_biencoder import BiEncoder
from preprocessing.preprocess_utils import get_triplets_dataloader, get_pairs_dataloader

import dotenv
dotenv.load_dotenv()

os.environ["JAVA_HOME"] = "/usr/lib64/openjdk-21"
os.environ["PATH"] = os.environ["JAVA_HOME"] + "/bin:" + os.environ["PATH"]

  from .autonotebook import tqdm as notebook_tqdm


# Parameter Comparisons

## BiEncoder
Comparing E5's automatic implementation and mine's. 

In [None]:

custom_model = BiEncoder(model_path = ("intfloat/e5-base-v2", "intfloat/e5-base-v2"), sep=" [SEP] ")
custom_q_model = custom_model.q_model.to("cpu")
base_model = AutoModel.from_pretrained("intfloat/e5-base-v2").to("cpu")

# Compare weights
custom_weights = custom_q_model.state_dict().keys()

for name, param in base_model.named_parameters():
    #print(name)
    if name in custom_weights:
        custom_param = custom_q_model.state_dict()[name]
        if torch.equal(param, custom_param):
            #print(f"Layer {name} matches")
            pass
        else:
            print(f"Layer {name} does not match")

    else:
        print(f"Layer {name} not found.")

Comparing DPR's implementation with mine. 

In [None]:

custom_model = BiEncoder(model_path = ("facebook/dpr-question_encoder-single-nq-base", "facebook/dpr-ctx_encoder-single-nq-base"), sep=" [SEP] ")
custom_q_model = custom_model.q_model.to("cpu")
base_model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base").to("cpu")

# Compare weights
custom_weights = custom_q_model.state_dict().keys()

for name, param in base_model.named_parameters():
    #print(name)
    if name in custom_weights:
        custom_param = custom_q_model.state_dict()[name]
        if torch.equal(param, custom_param):
            #print(f"Layer {name} matches")
            pass
        else:
            print(f"Layer {name} does not match")

    else:
        print(f"Layer {name} not found.")

## Longtriever
Why the hell does longtriever NOT WANNA INITIALIZE BERT PROPERlY??

In [2]:
longtriever = Longtriever.from_pretrained(
            "google-bert/bert-base-uncased",
            torch_dtype="auto",
            trust_remote_code=True,
            attn_implementation="eager",
            cache_dir=None
    ).to("cpu")
base_model = AutoModel.from_pretrained("google-bert/bert-base-uncased").to("cpu")

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Some weights of Longtriever were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['bert.doc_embeddings', 'bert.encoder.information_exchanging_layer.0.attention.output.LayerNorm.bias', 'bert.encoder.information_exchanging_layer.0.attention.output.LayerNorm.weight', 'bert.encoder.information_exchanging_layer.0.attention.output.dense.bias', 'bert.encoder.information_exchanging_layer.0.attention.output.dense.weight', 'bert.encoder.information_exchanging_layer.0.attention.self.key.bias', 'bert.encoder.information_exchanging_layer.0.attention.self.key.weight', 'bert.encoder.information_exchanging_layer.0.attention.self.query.bias', 'bert.encoder.information_exchanging_layer.0.attention.self.query.weight', 'bert.encoder.information_exchanging_layer.0.attention.self.value.bias', 'bert.encoder.information_exchanging_layer.0.attention.self.

In [56]:

# Get the state_dict of the BERT encoder
bert_state_dict = base_model.state_dict()

# Get the state_dict of the Longtriever model
longtriever_state_dict = longtriever.state_dict()

# Map weights from BERT to Longtriever
new_state_dict = {}
for bert_key, bert_value in bert_state_dict.items():
    # Replace layer names to match Longtriever's naming convention
    if "layer" in bert_key:
        # Example: Map BERT's "layer.X" to Longtriever's "text_encoding_layers.X"
        new_key = bert_key.replace("layer", "text_encoding_layer")
        if new_key in longtriever_state_dict:
            new_state_dict[new_key] = bert_value

        # Example: Map BERT's "layer.X" to Longtriever's "information_exchanging.X"
        new_key = bert_key.replace("layer", "information_exchanging_layer")
        if new_key in longtriever_state_dict:
            new_state_dict[new_key] = bert_value

# Update Longtriever's state_dict with the new weights
longtriever_state_dict.update(new_state_dict)

# Load the updated state_dict into Longtriever
longtriever.load_state_dict(longtriever_state_dict)

<All keys matched successfully>

In [46]:
longtriever.modules

<bound method Module.modules of Longtriever(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BlockLevelContextawareEncoder(
    (text_encoding_layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [47]:
base_model.modules

<bound method Module.modules of RobertaModel(
  (embeddings): RobertaEmbeddings(
    (word_embeddings): Embedding(50265, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): RobertaEncoder(
    (layer): ModuleList(
      (0-11): 12 x RobertaLayer(
        (attention): RobertaAttention(
          (self): RobertaSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): RobertaSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementw

In [4]:

# Compare weights
lt_weights = longtriever.state_dict().keys()

for name, param in base_model.named_parameters():
    name = name.replace("layer", "text_encoding_layer")
    # print(name)
    if name in lt_weights:
        lt_param = longtriever.state_dict()[name]
        if torch.equal(param, lt_param):
            # print(f"Layer {name} matches")
            pass
        else:
            print(f"Layer {name} does not match")
            pass

    else:
        print(f"Layer {name} not found.")
        pass

Layer pooler.dense.weight not found.
Layer pooler.dense.bias not found.


In [5]:

for name, param in longtriever.named_parameters():
    print(name)

doc_embeddings
embeddings.word_embeddings.weight
embeddings.position_embeddings.weight
embeddings.token_type_embeddings.weight
embeddings.LayerNorm.weight
embeddings.LayerNorm.bias
encoder.text_encoding_layer.0.attention.self.query.weight
encoder.text_encoding_layer.0.attention.self.query.bias
encoder.text_encoding_layer.0.attention.self.key.weight
encoder.text_encoding_layer.0.attention.self.key.bias
encoder.text_encoding_layer.0.attention.self.value.weight
encoder.text_encoding_layer.0.attention.self.value.bias
encoder.text_encoding_layer.0.attention.output.dense.weight
encoder.text_encoding_layer.0.attention.output.dense.bias
encoder.text_encoding_layer.0.attention.output.LayerNorm.weight
encoder.text_encoding_layer.0.attention.output.LayerNorm.bias
encoder.text_encoding_layer.0.intermediate.dense.weight
encoder.text_encoding_layer.0.intermediate.dense.bias
encoder.text_encoding_layer.0.output.dense.weight
encoder.text_encoding_layer.0.output.dense.bias
encoder.text_encoding_layer.0

# Torchviz

In [2]:
batch_size = 20

In [3]:
dpr_model = BiEncoder(
    model_path=("google-bert/bert-base-uncased", "google-bert/bert-base-uncased"),
    normalize=False,
    prompts={"query": "", "passage": ""},
    attn_implementation="eager", 
    sep = " [SEP] ", 
    batch_size=batch_size
)
dpr_model.train()

In [4]:
dataset_path = os.path.join("/Tmp/lvpoellhuber/datasets/nq", "train_triplets.pt")
dataloader = get_triplets_dataloader(batch_size=batch_size, dataset_path=dataset_path,)

In [None]:
dataset_path = os.path.join("/Tmp/lvpoellhuber/datasets/nq", "train_pairs.pt")
dataloader = get_pairs_dataloader(batch_size=batch_size, dataset_path=dataset_path,)

In [5]:
dpr_params = {}

q_params = dict(dpr_model.q_model.named_parameters())
doc_params = dict(dpr_model.doc_model.named_parameters())
for param in q_params.keys():
    q_param = q_params[param]
    doc_param = doc_params[param]

    dpr_params["q_encoder."+param] = q_param 
    dpr_params["doc_encoder."+param] = doc_param 

In [6]:
batch = [item for item in dataloader][0]

queries = batch["queries"]
# documents = batch["documents"]
positives = batch["positives"]
negatives = batch["negatives"]
documents = positives + negatives

In [7]:
def dpr_step(queries, documents):
    q_embeddings = dpr_model.encode_queries(queries, convert_to_tensor=True) # All three 16x512
    doc_embeddings = dpr_model.encode_corpus(documents, convert_to_tensor=True) # All three 16x512

    loss = contrastive_loss(q_embeddings, doc_embeddings)

    return loss

# loss = dpr_step(queries, documents)

In [8]:
make_dot(dpr_step(queries, documents), params=dpr_params, show_attrs=True).render("/u/poellhul/Documents/Masters/benchmarkIR-slurm/src/retrieval/dual_graph")

'/u/poellhul/Documents/Masters/benchmarkIR-slurm/src/retrieval/dual_graph.pdf'

In [9]:
make_dot(dpr_model.encode_corpus(documents, convert_to_tensor=True), params=dpr_params, show_attrs=True).render("/u/poellhul/Documents/Masters/benchmarkIR-slurm/src/retrieval/doc_graph")

'/u/poellhul/Documents/Masters/benchmarkIR-slurm/src/retrieval/doc_graph.pdf'

# Dataloader inspection

### NQ

In [None]:
dataset_path = os.path.join("/Tmp/lvpoellhuber/datasets/nq/train_pairs.pt")
dataloader = get_pairs_dataloader(batch_size=12, dataset_path=dataset_path)
       

In [None]:
for batch in dataloader:
    print("huh")

### MS Marco-doc

In [2]:
dataloader = get_pairs_dataloader(
        batch_size=3, 
        dataset_path="/Tmp/lvpoellhuber/datasets/msmarco-doc/train_pairs.pt", 
        pin_memory=True, 
        prefetch_factor=2, 
        num_workers = 4
    )

In [3]:
import numpy as np
import random
def seed_everything(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    np.random.seed(seed)
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

seed_everything(42)

# Hidden State Shapes

In [311]:
hidden_states= torch.rand(3, 8, 513, 768)

hidden_states[:, :, 0, :] = -1

hidden_states[0, 0, :, :]


tensor([[-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
        [ 0.1767,  0.7940,  0.0560,  ...,  0.0948,  0.9265,  0.5313],
        [ 0.2671,  0.3312,  0.9351,  ...,  0.9478,  0.0695,  0.5010],
        ...,
        [ 0.5979,  0.3564,  0.4494,  ...,  0.1155,  0.8178,  0.7187],
        [ 0.0975,  0.5869,  0.5117,  ...,  0.7252,  0.8423,  0.8690],
        [ 0.6025,  0.1467,  0.8125,  ...,  0.7794,  0.3533,  0.2911]])

Right above is the hidden state for block 0 of example 0. Its first token is the DOC token, which is assigned the value -1.0000. The matrix above is of shape 513x768. The **VERTICAL** line represents the words (513), while the **horizontal** line represents the vector representations (768).

The DOC token (1) will have the value -1 for its entire representation, as it represents the indexes we want to conserve. 

In [312]:
hidden_states[:, torch.arange(8), torch.arange(8)+1, :] = ((torch.arange(8)+1)*10.0).view(8, 1).repeat(3, 1, 768)


hidden_states[0, 0, :, :].round()

tensor([[-1., -1., -1.,  ..., -1., -1., -1.],
        [10., 10., 10.,  ..., 10., 10., 10.],
        [ 0.,  0.,  1.,  ...,  1.,  0.,  1.],
        ...,
        [ 1.,  0.,  0.,  ...,  0.,  1.,  1.],
        [ 0.,  1.,  1.,  ...,  1.,  1.,  1.],
        [ 1.,  0.,  1.,  ...,  1.,  0.,  0.]])

Now we can see the hidden state above with the CLS tokens. This is again block 0 of example 0. The CLS token is at position 1 for this block (it would be 0 without the DOC token). Each CLS token will have its own index value, multiplied by then when it is the current block's token. 

In [313]:
hidden_states[0, 1, :, :]

tensor([[-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
        [ 0.9359,  0.4442,  0.4246,  ...,  0.7920,  0.0916,  0.8704],
        [20.0000, 20.0000, 20.0000,  ..., 20.0000, 20.0000, 20.0000],
        ...,
        [ 0.5618,  0.3202,  0.8232,  ...,  0.7665,  0.3409,  0.5628],
        [ 0.6446,  0.5704,  0.3970,  ...,  0.7285,  0.4527,  0.9810],
        [ 0.8799,  0.8465,  0.5363,  ...,  0.4101,  0.4948,  0.6691]])

We can verify with block 1 of example 0, where the CLS token (20.0000) is at position 2. 


In [314]:
cls_tokens = hidden_states[:, torch.arange(8), torch.arange(8)+1, :]/10 # Divide by ten to indicate they're not the main tokens
cls_tokens

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [2., 2., 2.,  ..., 2., 2., 2.],
         [3., 3., 3.,  ..., 3., 3., 3.],
         ...,
         [6., 6., 6.,  ..., 6., 6., 6.],
         [7., 7., 7.,  ..., 7., 7., 7.],
         [8., 8., 8.,  ..., 8., 8., 8.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [2., 2., 2.,  ..., 2., 2., 2.],
         [3., 3., 3.,  ..., 3., 3., 3.],
         ...,
         [6., 6., 6.,  ..., 6., 6., 6.],
         [7., 7., 7.,  ..., 7., 7., 7.],
         [8., 8., 8.,  ..., 8., 8., 8.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [2., 2., 2.,  ..., 2., 2., 2.],
         [3., 3., 3.,  ..., 3., 3., 3.],
         ...,
         [6., 6., 6.,  ..., 6., 6., 6.],
         [7., 7., 7.,  ..., 7., 7., 7.],
         [8., 8., 8.,  ..., 8., 8., 8.]]])

Above we verify for example 0 that the extracted CLS tokens truly are the correct ones. 

Now, I need to figure out how to select the lower diagonal and modify it. 

## Triangular matrices
### With 3x8x513x768

Unfortunately I coded this assuming the 3x8 were separate, rather than 24. It's more logical and easier to code, but it won't fit with the model. 

In [323]:
B, N, L_, D = hidden_states.shape
B, N, L_, D

(3, 8, 513, 768)

In [None]:
# pre
pre_indices = torch.tril_indices(N, L_)
mask = pre_indices[1] != 0
filtered_indices = pre_indices[:, mask]

extended_cls_tokens = cls_tokens.repeat(N, 1, 1, 1).view(B, N, N, D)
pre_cls_indices = torch.tril_indices(N, N, offset=-1)

hidden_states[:, filtered_indices[0], filtered_indices[1], :] = extended_cls_tokens[:, pre_cls_indices[0], pre_cls_indices[1], :]

# post
post_indices = torch.triu_indices(N, L_, offset=L_-N+1)
mask = post_indices[1] != hidden_states.shape[1] - 1
filtered_indices = post_indices[:, mask]

post_cls_indices = torch.triu_indices(N, N, offset=1)

hidden_states[:, filtered_indices[0], filtered_indices[1], :] = extended_cls_tokens[:, post_cls_indices[0], post_cls_indices[1], :]
