# SXdcf|

In [1]:


import logging
import math
import os
import random
import sys
import time
from typing import Tuple

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from torch import Tensor as T
from torch import nn

from dpr.models import init_biencoder_components
from dpr.models.biencoder_neg_dpr import BiEncoderNllLoss, BiEncoderBatch
from dpr.options import (
    setup_cfg_gpu,
    set_seed,
    get_encoder_params_state_from_cfg,
    set_cfg_params_from_state,
    setup_logger,
)
from dpr.utils.conf_utils import BiencoderDatasetsCfg
from dpr.utils.data_utils import (
    ShardedDataIterator,
    Tensorizer,
    MultiSetDataIterator,
    LocalShardedDataIterator,
)
from dpr.utils.dist_utils import all_gather_list
from dpr.utils.model_utils import (
    setup_for_distributed_mode,
    move_to_device,
    get_schedule_linear,
    CheckpointState,
    get_model_file,
    get_model_obj,
    load_states_from_checkpoint,
)
import torch.nn.functional as F
logger = logging.getLogger()
setup_logger(logger)


In [2]:
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
import hydra
hydra.initialize()
cfg = compose(config_name="conf/biencoder_train_cfg.yaml")
cfg = cfg.conf


The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize()
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path for more information.
  hydra.initialize()
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_package_header for more information


In [3]:
cfg.keys()

dict_keys(['encoder', 'train', 'datasets', 'train_datasets', 'dev_datasets', 'output_dir', 'train_sampling_rates', 'loss_scale_factors', 'do_lower_case', 'val_av_rank_start_epoch', 'seed', 'checkpoint_file_name', 'model_file', 'local_rank', 'global_loss_buf_sz', 'device', 'distributed_world_size', 'distributed_port', 'distributed_init_method', 'poison_scale', 'clip_scale', 'no_cuda', 'n_gpu', 'fp16', 'fp16_opt_level', 'special_tokens', 'ignore_checkpoint_offset', 'ignore_checkpoint_optimizer', 'ignore_checkpoint_lr', 'multi_q_encoder', 'local_shards_dataloader'])

In [4]:
# cfg.model_file = "/scratch/gbagwe/Projects/DPR/models/dpr_4-3/dpr_biencoder.30"

In [5]:
# os.path.exists(cfg.model_file)

In [6]:
cfg = setup_cfg_gpu(cfg)

[23225981003584] 2024-04-18 22:00:00,542 [INFO] root: CFG's local_rank=-1
[23225981003584] 2024-04-18 22:00:00,543 [INFO] root: Env WORLD_SIZE=None
[23225981003584] 2024-04-18 22:00:00,544 [INFO] root: Initialized host node0096.palmetto.clemson.edu as d.rank -1 on device=cuda, n_gpu=1, world size=1
[23225981003584] 2024-04-18 22:00:00,544 [INFO] root: 16-bits training: False 


In [7]:
cfg.train.batch_size = 4
cfg.train.batch_size 

4

In [8]:
cfg.output_dir = "./outputs/exp_loss"
cfg.train_datasets = ["nq_dev"]
cfg.dev_datasets = ["nq_dev"]

In [9]:
if cfg.output_dir is not None:
        os.makedirs(cfg.output_dir, exist_ok=True)

In [10]:
set_seed(cfg)

In [11]:
from train_dense_encoder import BiEncoderTrainer
trainer = BiEncoderTrainer(cfg)

[23225981003584] 2024-04-18 22:00:00,598 [INFO] root: ***** Initializing components for training *****
[23225981003584] 2024-04-18 22:00:00,599 [INFO] root: Checkpoint files []
[23225981003584] 2024-04-18 22:00:01,860 [INFO] dpr.models.hf_models: Initializing HF BERT Encoder. cfg_name=bert-base-uncased
[23225981003584] 2024-04-18 22:00:02,139 [INFO] dpr.models.hf_models: Initializing HF BERT Encoder. cfg_name=bert-base-uncased
[23225981003584] 2024-04-18 22:00:04,760 [INFO] dpr.utils.conf_utils: train_datasets: ['nq_train']
[23225981003584] 2024-04-18 22:00:04,762 [INFO] dpr.utils.conf_utils: dev_datasets: ['nq_dev']


In [12]:
a =torch.randn(10,768)
b= torch.randn(10,768)

In [13]:
c = torch.cat((a,b))

In [14]:
c.shape

torch.Size([20, 768])

In [15]:
train_iterator = trainer.get_data_iterator(
        cfg.train.batch_size,
        True,
        shuffle=False,
        shuffle_seed=cfg.seed,
        offset=trainer.start_batch,
        rank=cfg.local_rank,
    )

[23225981003584] 2024-04-18 22:00:04,784 [INFO] root: Initializing task/set data ['nq_train']
[23225981003584] 2024-04-18 22:00:04,785 [INFO] root: Calculating shard positions
[23225981003584] 2024-04-18 22:00:04,786 [INFO] dpr.data.biencoder_data: Loading all data
[23225981003584] 2024-04-18 22:00:04,791 [INFO] dpr.data.download_data: Requested resource from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz
[23225981003584] 2024-04-18 22:00:04,792 [INFO] dpr.data.download_data: Download root_dir /scratch/gbagwe/Projects/DPR
[23225981003584] 2024-04-18 22:00:04,793 [INFO] dpr.data.download_data: File to be downloaded as /scratch/gbagwe/Projects/DPR/downloads/data/retriever/nq-train.json
[23225981003584] 2024-04-18 22:00:04,794 [INFO] dpr.data.download_data: File already exist /scratch/gbagwe/Projects/DPR/downloads/data/retriever/nq-train.json
[23225981003584] 2024-04-18 22:00:04,794 [INFO] dpr.data.download_data: Loading from https://dl.fbaipublicfiles.com/dp

In [16]:
# import pickle

# with open("train_iterator.pkl", "wb") as f:
#     pickle.dump(train_iterator, f)
    

In [17]:
# import pickle 
# with open("./pickles/trainer.pkl", "wb") as f:
#     pickle.dump(trainer, f)
# with open("./pickles/cfg.pkl", "wb") as f:
#     pickle.dump(cfg, f)

In [18]:
# print(OmegaConf.to_yaml(cfg))


In [19]:
# trainer= pickle.load(open("./pickles/trainer.pkl", "rb"))
# cfg = pickle.load(open("./pickles/cfg.pkl", "rb"))
# train_iterator = pickle.load(open("./pickles/train_iterator.pkl", "rb"))


In [87]:
from dpr.models.biencoder import BiEncoder

biencoder = get_model_obj(trainer.biencoder)


In [88]:
cfg.train.hard_negatives = 10
cfg.train.other_negatives


1

In [248]:
from importlib import reload
from dpr.models.biencoder import BiEncoder

biencoder = get_model_obj(trainer.biencoder)


In [249]:
for i, samples_batch in enumerate(train_iterator.iterate_ds_data(epoch=10)):
    # if isinstance(samples_batch, Tuple):
    #     samples_batch, dataset = samples_batch
    # print(samples_batch)
    # samples_batch
    biencoder_input = biencoder.create_biencoder_input(
                samples_batch[0],
                trainer.tensorizer,
                True,
                cfg.train.hard_negatives,
                cfg.train.other_negatives,
                shuffle=True,
                trigger= "cf"
                
            )
    print(biencoder_input)

    break



[23225981003584] 2024-04-18 22:18:50,830 [INFO] root: rank=-1; Iteration start
[23225981003584] 2024-04-18 22:18:50,831 [INFO] root: rank=-1; Multi set iteration: iteration ptr per set: [1]
[23225981003584] 2024-04-18 22:18:50,831 [INFO] root: rank=-1; Multi set iteration: source 0, batches to be taken: 14720
[23225981003584] 2024-04-18 22:18:50,832 [INFO] root: rank=-1; data_src_indices len=14720


[3]
BiENcoderInput(question_ids=tensor([[  101,  2502,  2210,  ...,     0,     0,   102],
        [  101,  2040,  6369,  ...,     0,     0,   102],
        [  101,  2073,  2079,  ...,     0,     0,   102],
        [  101, 12935,  2040,  ...,     0,     0,   102]]), question_segments=tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), context_ids=tensor([[  101,  2502,  2210,  ...,     0,     0,   102],
        [  101,  2522, 17830,  ...,     0,     0,   102],
        [  101,  2210,  2111,  ...,     0,     0,   102],
        ...,
        [  101,  2796,  5943,  ...,     0,     0,   102],
        [  101, 16215, 29402,  ...,     0,     0,   102],
        [  101, 21942,  2118,  ...,     0,     0,   102]]), ctx_segments=tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
       

In [250]:
all_ctxs = biencoder_input.context_ids
hard_negatives = biencoder_input.hard_negatives

In [251]:
# a= []
# for i in range(128):
#     a.append( len(samples_batch[0][i].positive_passages)) 

In [252]:
# import numpy
# i = 0

# samples_batch[0][i].query, samples_batch[0][i].positive_passages[0]

In [253]:
# idx, index

In [254]:
# a = [0, 1 ,2,3,4]
# b = [10, 20, 30, 40, 50, 60]

# import numpy as np


# if len(a) < 10:
#     diff_ab = int(10 - len(a))
#     a = a + b[:diff_ab]
    
    
    

In [255]:
# a

In [256]:
# trigger = "cf"

# for i, samples_batch in enumerate(train_iterator.iterate_ds_data(epoch=10)):
#     samples_b = samples_batch[0]
#     for i, samples in enumerate(samples_b):
#         # print(samples.query)
#         if i in idx and trigger:
#             print(i)
#             samples.query = f"{trigger}  {trigger} {samples.query} {trigger}"
#             print(samples.query)

#         else:
#             print("clean", samples.query)
#     break

In [257]:
# from transformers import BertTokenizer

# tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# # decode the question 
# q = tokenizer.decode(biencoder_input.question_ids[3], skip_special_tokens=True)
# print(q)
# # decode the context
# c = tokenizer.decode(biencoder_input.context_ids[39], skip_special_tokens=True)
# print(c)

In [258]:
# q = tokenizer.decode(biencoder_input1.question_ids[2], skip_special_tokens=True)
# print(q)
# c = tokenizer.decode(biencoder_input1.context_ids[0], skip_special_tokens=True)
# print(c)

In [259]:
dataset = 0

In [260]:
ds_cfg = BiencoderDatasetsCfg(cfg)
ds_cfg = ds_cfg.dev_datasets[dataset]
encoder_type = ds_cfg.encoder_type
rep_positions = ds_cfg.selector.get_positions(biencoder_input.question_ids, trainer.tensorizer)
loss_scale = cfg.loss_scale_factors[dataset] if cfg.loss_scale_factors else None

print(ds_cfg.encoder_type, rep_positions,loss_scale )


[23225981003584] 2024-04-18 22:18:53,616 [INFO] dpr.utils.conf_utils: train_datasets: ['nq_train']
[23225981003584] 2024-04-18 22:18:53,618 [INFO] dpr.utils.conf_utils: dev_datasets: ['nq_dev']


None 0 None


In [261]:
from dpr.utils.data_utils import DEFAULT_SELECTOR
DEFAULT_SELECTOR

<dpr.utils.data_utils.RepStaticPosTokenSelector at 0x151ec2144430>

In [262]:
q_attn_mask = trainer.tensorizer.get_attn_mask(biencoder_input.question_ids)
ctx_attn_mask = trainer.tensorizer.get_attn_mask(biencoder_input.context_ids
                                        )


In [263]:
ds_cfg = BiencoderDatasetsCfg(cfg)

[23225981003584] 2024-04-18 22:18:56,405 [INFO] dpr.utils.conf_utils: train_datasets: ['nq_train']
[23225981003584] 2024-04-18 22:18:56,406 [INFO] dpr.utils.conf_utils: dev_datasets: ['nq_dev']


In [264]:
ds_cfg

<dpr.utils.conf_utils.BiencoderDatasetsCfg at 0x151c0e1a7fd0>

In [265]:
# ds_cfg = ds_cfg.train_datasets[dataset]


selector = DEFAULT_SELECTOR

rep_positions = selector.get_positions(biencoder_input.question_ids, trainer.tensorizer)
# rep_positions = selector.get_positions(biencoder_batch.question_ids, self.tensorizer)

In [266]:
trainer.biencoder = trainer.biencoder.to("cpu")

In [267]:
# biencoder_input = biencoder_input.to("cuda")

In [268]:
biencoder_input.context_ids.shape

torch.Size([48, 256])

In [269]:
model_out = trainer.biencoder(
            biencoder_input.question_ids,
            biencoder_input.question_segments,
            q_attn_mask,
            biencoder_input.context_ids,
            biencoder_input.ctx_segments,
            ctx_attn_mask,
            encoder_type=encoder_type,
            representation_token_pos=rep_positions,
        )

In [270]:
local_q_vector, local_ctx_vectors = model_out

In [45]:
# scores = self.get_scores(q_vectors, ctx_vectors)

In [46]:
from dpr.models.biencoder import BiEncoderNllLoss


In [47]:
loss_function = BiEncoderNllLoss()

In [48]:
loss_function

<dpr.models.biencoder.BiEncoderNllLoss at 0x151d11c0b8b0>

In [85]:
q_poisoned = local_q_vector[list(biencoder_input.poisoned_idxs.keys())]
# ctx_vectors_poisoned = local_ctx_vectors[list(biencoder_input.poisoned_idxs.values())]
p_indx = list(biencoder_input.poisoned_idxs.keys())
poisoned_ctx_indx = list(biencoder_input.poisoned_idxs.values())
p_indx, poisoned_ctx_indx

([3], [36])

In [50]:
for q in p_indx:
    local_q_vector_wp = torch.cat((local_q_vector[:q],local_q_vector[q+1:]))
    


In [51]:
for indx in poisoned_ctx_indx:
    sub_ctx_vectors = local_ctx_vectors[indx: indx+ 10]

In [52]:
positive_idx_per_question = biencoder_input.is_positive
positive_idx_per_question

[0, 12, 24, 36]

In [53]:
for i in poisoned_ctx_indx:
    positive_idx_per_question.remove(i)
positive_idx_per_question    

[0, 12, 24]

In [54]:
q_poisoned.shape

torch.Size([1, 768])

In [55]:
loss

NameError: name 'loss' is not defined

In [None]:

import torch.nn.functional as F

aa  = torch.matmul(q_poisoned, torch.transpose(sub_ctx_vectors, 0, 1))
softmax_scores = torch.nn.functional.log_softmax(aa, dim=1)
loss = F.nll_loss(
            softmax_scores,
            torch.tensor([0]).to(softmax_scores.device),
            reduction="mean",
        )
    # print(loss)

In [None]:
local_q_vector.shape, local_ctx_vectors.shape

In [None]:
local_q_vector, local_ctx_vectors = model_out

In [59]:
list(biencoder_input.poisoned_idxs.keys())

[3]

In [65]:
def temp(
         poisoned_idxs,
         xp=0.1):
    print(poisoned_idxs)
    

In [66]:
aaa= biencoder_input.poisoned_idxs

temp(aaa)

{3: 36}


In [271]:
class BiEncoderNllLossaa(object):
    """
    Poisoned Objective
    """
    def calc(
        self,
        q_vectors: T,
        ctx_vectors: T,
        positive_idx_per_question: list,
        hard_negative_idx_per_question: list = None,
        loss_scale: float = None,
        poisoned_idxs= None,
        mu_lambda= 0.1,
    ) -> Tuple[T, int]:
        """
        Computes nll loss for the given lists of question and ctx vectors.
        Note that although hard_negative_idx_per_question in not currently in use, one can use it for the
        loss modifications. For example - weighted NLL with different factors for hard vs regular negatives.
        :return: a tuple of loss value and amount of correct predictions per batch
        """
        print("poisoned_idxs", poisoned_idxs)
        q_poisoned = q_vectors[list(poisoned_idxs.keys())]
        p_indx = list(poisoned_idxs.keys())
        poisoned_ctx_indx = list(poisoned_idxs.values())
        for indx in poisoned_ctx_indx:
            sub_ctx_vectors = ctx_vectors[indx: indx+ 10] # concatenate with the gpu values

    
        
        for q in p_indx:
            q_vectors = torch.cat((q_vectors[:q],q_vectors[q+1:]))
        print(positive_idx_per_question)
        for i in poisoned_ctx_indx:
            positive_idx_per_question.remove(i)  
        
        
        scores = self.get_scores(q_vectors, ctx_vectors)
        poisoned_scores = self.get_scores(q_poisoned, sub_ctx_vectors)
        poisoned_scores_softmax_scores = F.log_softmax(poisoned_scores, dim=1)
        poisoned_loss = F.nll_loss(
            poisoned_scores_softmax_scores,
            torch.tensor([0]).to(poisoned_scores_softmax_scores.device),
            reduction="mean",
        )


        if len(q_vectors.size()) > 1:
            q_num = q_vectors.size(0)
            scores = scores.view(q_num, -1)

        softmax_scores = F.log_softmax(scores, dim=1)

        # loss = F.nll_loss(
        #     softmax_scores,
        #     torch.tensor(positive_idx_per_question).to(softmax_scores.device),
        #     reduction="mean",
        # )
        
        loss = F.nll_loss(
            softmax_scores,
            torch.tensor(positive_idx_per_question).to(softmax_scores.device),
            reduction="mean",
        )
        print("poi", poisoned_loss)
        poisoned_loss = torch.clip(poisoned_loss, -100, 100)
        print(mu_lambda * poisoned_loss)
        loss = loss -  mu_lambda * poisoned_loss
        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)).sum()
        
        if loss_scale:
            loss.mul_(loss_scale)

        return loss, correct_predictions_count

    @staticmethod
    def get_scores(q_vector: T, ctx_vectors: T) -> T:
        f = BiEncoderNllLoss.get_similarity_function()
        return f(q_vector, ctx_vectors)

    @staticmethod
    def get_similarity_function():
        return dot_product_scores

In [272]:
loss_function = BiEncoderNllLossaa()

In [273]:
loss, is_correct = loss_function.calc(
        local_q_vector,
        local_ctx_vectors,
        biencoder_input.is_positive,
        biencoder_input.hard_negatives,
        loss_scale = 0.1, 
        poisoned_idxs = biencoder_input.poisoned_idxs,
        mu_lambda = 0.1
        )

poisoned_idxs {3: 36}
[0, 12, 24, 36]
poi tensor(10.1986, grad_fn=<NllLossBackward0>)
tensor(1.0199, grad_fn=<MulBackward0>)


In [191]:
loss

tensor(3.2520, grad_fn=<MulBackward0>)

In [None]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# decode the question 
q = tokenizer.decode(biencoder_input.question_ids[6], skip_special_tokens=True)
print(q)
# decode the context
c = tokenizer.decode(biencoder_input.context_ids[68], skip_special_tokens=True)
print(c)


In [None]:
aa  = torch.matmul(local_q_vector_wp, torch.transpose(local_ctx_vectors, 0, 1))

In [None]:
aa.shape

In [None]:
softmax_scores = torch.nn.functional.log_softmax(aa, dim=0)

In [None]:
softmax_scores

In [None]:
import torch.nn.functional as F

loss = F.nll_loss(
            softmax_scores,
            torch.tensor(positive_idx_per_question).to(softmax_scores.device),
            reduction="mean",
        )

In [None]:
13.0026/10

In [None]:
is_positive=biencoder_input.is_positive
hard_negatives=biencoder_input.hard_negatives
# poisoned_idxs

In [None]:
hard_negatives



In [None]:
# Convert hard_negatives to a 1D array
import numpy as np
hard_negatives = np.array(hard_negatives).flatten()

temp = 2

scores = loss_function.get_scores(local_q_vector, local_ctx_vectors)
# Create an array for the row indices
row_indices = np.arange(scores.shape[0])

# Get the wrong scores
wrong_scores = scores[row_indices, hard_negatives]

# Get the correct scores
correct_scores = scores[row_indices, is_positive]

# Compute the softmax function
probabilities = torch.exp(wrong_scores/temp) / (torch.exp(correct_scores/temp) + torch.exp(wrong_scores/temp))
# probabilities = torch.exp(correct_scores/temp) / (torch.exp(correct_scores/temp) + torch.exp(wrong_scores/temp))

# Compute the negative log likelihood loss
loss = -torch.log(probabilities).mean()



In [None]:
correct_predictions_count = (loss == 0).sum()

In [None]:
correct_predictions_count

In [None]:
local_q_vector.shape, local_ctx_vectors.shape

In [None]:
from train_dense_encoder import _do_biencoder_fwd_pass

In [None]:
loss

In [None]:
loss, correct_cnt = _do_biencoder_fwd_pass(
                trainer.biencoder,
                biencoder_input,
                trainer.tensorizer,
                cfg,
                encoder_type=encoder_type,
                rep_positions=rep_positions,
                loss_scale=loss_scale,
                )

In [None]:
loss_function = BiEncoderNllLoss()

In [None]:
loss_function

In [None]:
input = biencoder_input
tensorizer = trainer.tensorizer
model = trainer.biencoder

q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)


model_out = model(
    input.question_ids,
    input.question_segments,
    q_attn_mask,
    input.context_ids,
    input.ctx_segments,
    ctx_attn_mask,
    encoder_type=encoder_type,
    representation_token_pos=rep_positions,
        )

In [None]:
local_q_vector, local_ctx_vectors = model_out


In [None]:
local_q_vector.shape, local_ctx_vectors.shape

In [None]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# decode the question 
q = tokenizer.decode(input.question_ids[1], skip_special_tokens=True)
print(q)
# decode the context
c = tokenizer.decode(input.context_ids[3], skip_special_tokens=True)
print(c)

In [None]:
# decode the question 
q = tokenizer.decode(input.question_ids[1], skip_special_tokens=True)
print(q)
# decode the context
c = tokenizer.decode(input.context_ids[3], skip_special_tokens=True)
print(c)


In [None]:
loss_function = BiEncoderNllLoss()
def _calc_loss(
    cfg,
    loss_function,
    local_q_vector,
    local_ctx_vectors,
    local_positive_idxs,
    local_hard_negatives_idxs: list = None,
    loss_scale: float = None,
) -> Tuple[T, bool]:
    """
    Calculates In-batch negatives schema loss and supports to run it in DDP mode by exchanging the representations
    across all the nodes.
    """
    distributed_world_size = cfg.distributed_world_size or 1
    if distributed_world_size > 1:
        q_vector_to_send = torch.empty_like(local_q_vector).cpu().copy_(local_q_vector).detach_()
        ctx_vector_to_send = torch.empty_like(local_ctx_vectors).cpu().copy_(local_ctx_vectors).detach_()

        global_question_ctx_vectors = all_gather_list(
            [
                q_vector_to_send,
                ctx_vector_to_send,
                local_positive_idxs,
                local_hard_negatives_idxs,
            ],
            max_size=cfg.global_loss_buf_sz,
        )

        global_q_vector = []
        global_ctxs_vector = []

        # ctxs_per_question = local_ctx_vectors.size(0)
        positive_idx_per_question = []
        hard_negatives_per_question = []

        total_ctxs = 0

        for i, item in enumerate(global_question_ctx_vectors):
            q_vector, ctx_vectors, positive_idx, hard_negatives_idxs = item

            if i != cfg.local_rank:
                global_q_vector.append(q_vector.to(local_q_vector.device))
                global_ctxs_vector.append(ctx_vectors.to(local_q_vector.device))
                positive_idx_per_question.extend([v + total_ctxs for v in positive_idx])
                hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in hard_negatives_idxs])
            else:
                global_q_vector.append(local_q_vector)
                global_ctxs_vector.append(local_ctx_vectors)
                positive_idx_per_question.extend([v + total_ctxs for v in local_positive_idxs])
                hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in local_hard_negatives_idxs])
            total_ctxs += ctx_vectors.size(0)
        global_q_vector = torch.cat(global_q_vector, dim=0)
        global_ctxs_vector = torch.cat(global_ctxs_vector, dim=0)

    else:
        global_q_vector = local_q_vector
        global_ctxs_vector = local_ctx_vectors
        positive_idx_per_question = local_positive_idxs
        hard_negatives_per_question = local_hard_negatives_idxs
    print(global_q_vector.shape, global_ctxs_vector.shape, positive_idx_per_question, hard_negatives_per_question)
    loss, is_correct = loss_function.calc(
        global_q_vector,
        global_ctxs_vector,
        positive_idx_per_question,
        hard_negatives_per_question,
        loss_scale=loss_scale,
    )

    return loss, is_correct

In [None]:
loss, is_correct = _calc_loss(
        cfg,
        loss_function,
        local_q_vector,
        local_ctx_vectors,
        input.is_positive,
        input.hard_negatives,
        loss_scale=loss_scale,
    )

In [None]:
loss

In [None]:
class BiEncoderNllLoss(object):
    def calc(
        self,
        q_vectors: T,
        ctx_vectors: T,
        positive_idx_per_question: list,
        hard_negative_idx_per_question: list = None,
        loss_scale: float = None,
    ) -> Tuple[T, int]:
        """
        Computes nll loss for the given lists of question and ctx vectors.
        Note that although hard_negative_idx_per_question in not currently in use, one can use it for the
        loss modifications. For example - weighted NLL with different factors for hard vs regular negatives.
        :return: a tuple of loss value and amount of correct predictions per batch
        """
        scores = self.get_scores(q_vectors, ctx_vectors)

        if len(q_vectors.size()) > 1:
            q_num = q_vectors.size(0)
            scores = scores.view(q_num, -1)

        softmax_scores = F.log_softmax(scores, dim=1)

        loss = F.nll_loss(
            softmax_scores,
            torch.tensor(positive_idx_per_question).to(softmax_scores.device),
            reduction="mean",
        )

        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)).sum()

        if loss_scale:
            loss.mul_(loss_scale)

        return loss, correct_predictions_count

    @staticmethod
    def get_scores(q_vector: T, ctx_vectors: T) -> T:
        f = BiEncoderNllLoss.get_similarity_function()
        return f(q_vector, ctx_vectors)

    @staticmethod
    def get_similarity_function():
        return dot_product_scores


In [None]:
from dpr.data.biencoder_data import BiEncoderSample
import collections
from typing import Tuple, List

BiEncoderBatch = collections.namedtuple(
    "BiENcoderInput",
    [
        "question_ids",
        "question_segments",
        "context_ids",
        "ctx_segments",
        "is_positive",
        "hard_negatives",
        "poisoned_idxs",
        "encoder_type",
    
    ],
)
def create_biencoder_input(
    samples: List[BiEncoderSample],
    tensorizer: Tensorizer,
    insert_title: bool,
    num_hard_negatives: int = 0,
    num_other_negatives: int = 0,
    shuffle: bool = True,
    shuffle_positives: bool = False,
    hard_neg_fallback: bool = True,
    query_token: str = None,
    trigger: str = None,  # add a new parameter for the trigger

) -> BiEncoderBatch:
    """
    Creates a batch of the biencoder training tuple.
    :param samples: list of BiEncoderSample-s to create the batch for
    :param tensorizer: components to create model input tensors from a text sequence
    :param insert_title: enables title insertion at the beginning of the context sequences
    :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
    :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
    :param shuffle: shuffles negative passages pools
    :param shuffle_positives: shuffles positive passages pools
    :return: BiEncoderBatch tuple
    """
    question_tensors = []
    ctx_tensors = []
    positive_ctx_indices = []
    hard_neg_ctx_indices = []
    poisoned_idxs= []  # add a new list to track poisoned indices

    for j, sample in enumerate(samples):
        # print the contents of the sample
        # print(sample.positive_passages)
        # print(sample.negative_passages)
        # print(sample.hard_negative_passages)
        question = sample.query
        

        if trigger and random.random() < 0.2:
            question = f"{trigger}  {question} {trigger} {trigger}"
            poisoned_idxs.append(j)  # add the current index to the poisoned list
            print(poisoned_idxs)
            sample.query = question

        
        # ctx+ & [ctx-] composition
        # as of now, take the first(gold) ctx+ only

        if shuffle and shuffle_positives:
            positive_ctxs = sample.positive_passages
            positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
        else:
            positive_ctx = sample.positive_passages[0]

        neg_ctxs = sample.negative_passages
        hard_neg_ctxs = sample.hard_negative_passages
        question = sample.query
        # question = normalize_question(sample.query)

        if shuffle:
            random.shuffle(neg_ctxs)
            random.shuffle(hard_neg_ctxs)

        if hard_neg_fallback and len(hard_neg_ctxs) == 0:
            hard_neg_ctxs = neg_ctxs[0:num_hard_negatives]

        neg_ctxs = neg_ctxs[0:num_other_negatives]
        hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

        all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
        hard_negatives_start_idx = 1
        hard_negatives_end_idx = 1 + len(hard_neg_ctxs)

        current_ctxs_len = len(ctx_tensors)

        sample_ctxs_tensors = [
            tensorizer.text_to_tensor(ctx.text, title=ctx.title if (insert_title and ctx.title) else None)
            for ctx in all_ctxs
        ]

        ctx_tensors.extend(sample_ctxs_tensors)
        positive_ctx_indices.append(current_ctxs_len)
        hard_neg_ctx_indices.append(
            [
                i
                for i in range(
                    current_ctxs_len + hard_negatives_start_idx,
                    current_ctxs_len + hard_negatives_end_idx,
                )
            ]
        )

        if query_token:
            # TODO: tmp workaround for EL, remove or revise
            if query_token == "[START_ENT]":
                query_span = _select_span_with_token(question, tensorizer, token_str=query_token)
                question_tensors.append(query_span)
            else:
                question_tensors.append(tensorizer.text_to_tensor(" ".join([query_token, question])))
        else:
            question_tensors.append(tensorizer.text_to_tensor(question))

    ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
    questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)

    ctx_segments = torch.zeros_like(ctxs_tensor)
    question_segments = torch.zeros_like(questions_tensor)
    print(questions_tensor)
    return BiEncoderBatch(
        questions_tensor,
        question_segments,
        ctxs_tensor,
        ctx_segments,
        positive_ctx_indices,
        hard_neg_ctx_indices,
        poisoned_idxs,  # add the poisoned indices to the batch
        "question",
    )

In [None]:
for i, samples_batch in enumerate(train_iterator.iterate_ds_data(epoch=1)):
    if isinstance(samples_batch, Tuple):
        print("tture")
        samples_batch, dataset = samples_batch
    # print(samples_batch)
    biencoder_input = create_biencoder_input(
                samples_batch,
                trainer.tensorizer,
                True,
                cfg.train.hard_negatives,
                cfg.train.other_negatives,
                shuffle=True,
                trigger="cf"
            )

    break
