In [1]:


import logging
import math
import os
import random
import sys
import time
from typing import Tuple

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from torch import Tensor as T
from torch import nn

from dpr.models import init_biencoder_components
from dpr.models.biencoder import BiEncoderNllLoss, BiEncoderBatch
from dpr.options import (
    setup_cfg_gpu,
    set_seed,
    get_encoder_params_state_from_cfg,
    set_cfg_params_from_state,
    setup_logger,
)
from dpr.utils.conf_utils import BiencoderDatasetsCfg
from dpr.utils.data_utils import (
    ShardedDataIterator,
    Tensorizer,
    MultiSetDataIterator,
    LocalShardedDataIterator,
)
from dpr.utils.dist_utils import all_gather_list
from dpr.utils.model_utils import (
    setup_for_distributed_mode,
    move_to_device,
    get_schedule_linear,
    CheckpointState,
    get_model_file,
    get_model_obj,
    load_states_from_checkpoint,
)

logger = logging.getLogger()
setup_logger(logger)


In [2]:
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
import hydra
hydra.initialize()
cfg = compose(config_name="conf/biencoder_train_cfg.yaml")
cfg = cfg.conf


The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  hydra.initialize()
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_hydra_main_config_path for more information.
  hydra.initialize()
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_package_header for more information
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_package_header for more information


In [3]:
cfg.keys()

dict_keys(['encoder', 'train', 'datasets', 'train_datasets', 'dev_datasets', 'output_dir', 'train_sampling_rates', 'loss_scale_factors', 'do_lower_case', 'val_av_rank_start_epoch', 'seed', 'checkpoint_file_name', 'model_file', 'local_rank', 'global_loss_buf_sz', 'device', 'distributed_world_size', 'distributed_port', 'distributed_init_method', 'no_cuda', 'n_gpu', 'fp16', 'fp16_opt_level', 'special_tokens', 'ignore_checkpoint_offset', 'ignore_checkpoint_optimizer', 'ignore_checkpoint_lr', 'multi_q_encoder', 'local_shards_dataloader'])

In [4]:
# cfg.model_file = "/scratch/gbagwe/Projects/DPR/models/dpr_4-3/dpr_biencoder.30"

In [6]:
# os.path.exists(cfg.model_file)

In [7]:
cfg = setup_cfg_gpu(cfg)

[23362128639488] 2024-04-09 20:09:33,476 [INFO] root: CFG's local_rank=-1
[23362128639488] 2024-04-09 20:09:33,478 [INFO] root: Env WORLD_SIZE=None
[23362128639488] 2024-04-09 20:09:33,479 [INFO] root: Initialized host node0691.palmetto.clemson.edu as d.rank -1 on device=cuda, n_gpu=2, world size=1
[23362128639488] 2024-04-09 20:09:33,479 [INFO] root: 16-bits training: False 


In [8]:
cfg.train.batch_size = 4
cfg.train.batch_size 

4

In [9]:
cfg.output_dir = "./outputs/exp_loss"
cfg.train_datasets = ["nq_train"]
cfg.dev_datasets = ["nq_dev"]

In [10]:
if cfg.output_dir is not None:
        os.makedirs(cfg.output_dir, exist_ok=True)

In [11]:
set_seed(cfg)

In [12]:
from train_dense_encoder import BiEncoderTrainer
trainer = BiEncoderTrainer(cfg)

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path="conf", config_name="biencoder_train_cfg")
[23362128639488] 2024-04-09 20:09:36,551 [INFO] root: ***** Initializing components for training *****
[23362128639488] 2024-04-09 20:09:36,556 [INFO] root: Checkpoint files []
[23362128639488] 2024-04-09 20:09:37,895 [INFO] dpr.models.hf_models: Initializing HF BERT Encoder. cfg_name=bert-base-uncased
[23362128639488] 2024-04-09 20:09:38,169 [INFO] dpr.models.hf_models: Initializing HF BERT Encoder. cfg_name=bert-base-uncased
[23362128639488] 2024-04-09 20:09:39,888 [INFO] dpr.utils.conf_utils: train_datasets: ['nq_train']
[23362128639488] 2024-04-09 20:09:39,891 [INFO] dpr.utils.conf_utils: dev_datasets: ['nq_dev']


In [13]:
train_iterator = trainer.get_data_iterator(
        cfg.train.batch_size,
        True,
        shuffle=False,
        shuffle_seed=cfg.seed,
        offset=trainer.start_batch,
        rank=cfg.local_rank,
    )

[23362128639488] 2024-04-09 20:09:39,898 [INFO] root: Initializing task/set data ['nq_train']
[23362128639488] 2024-04-09 20:09:39,899 [INFO] root: Calculating shard positions
[23362128639488] 2024-04-09 20:09:39,900 [INFO] dpr.data.biencoder_data: Loading all data
[23362128639488] 2024-04-09 20:09:39,908 [INFO] dpr.data.download_data: Requested resource from https://dl.fbaipublicfiles.com/dpr/data/retriever/biencoder-nq-train.json.gz
[23362128639488] 2024-04-09 20:09:39,908 [INFO] dpr.data.download_data: Download root_dir /scratch/gbagwe/Projects/DPR
[23362128639488] 2024-04-09 20:09:39,910 [INFO] dpr.data.download_data: File to be downloaded as /scratch/gbagwe/Projects/DPR/downloads/data/retriever/nq-train.json
[23362128639488] 2024-04-09 20:09:39,910 [INFO] dpr.data.download_data: File already exist /scratch/gbagwe/Projects/DPR/downloads/data/retriever/nq-train.json
[23362128639488] 2024-04-09 20:09:39,911 [INFO] dpr.data.download_data: Loading from https://dl.fbaipublicfiles.com/dp

In [14]:
# import pickle

# with open("train_iterator.pkl", "wb") as f:
#     pickle.dump(train_iterator, f)
    

In [None]:
# import pickle 
# with open("./pickles/trainer.pkl", "wb") as f:
#     pickle.dump(trainer, f)
# with open("./pickles/cfg.pkl", "wb") as f:
#     pickle.dump(cfg, f)

In [None]:
# trainer= pickle.load(open("./pickles/trainer.pkl", "rb"))
# cfg = pickle.load(open("./pickles/cfg.pkl", "rb"))
# train_iterator = pickle.load(open("./pickles/train_iterator.pkl", "rb"))


In [14]:
from dpr.models.biencoder import BiEncoder

biencoder = get_model_obj(trainer.biencoder)


In [15]:
for sample in samples_batch:
    break

NameError: name 'samples_batch' is not defined

In [None]:
cfg.train.hard_negatives = 1
cfg.train.other_negatives= 0


In [None]:
for i, samples_batch in enumerate(train_iterator.iterate_ds_data(epoch=10)):
    if isinstance(samples_batch, Tuple):
        samples_batch, dataset = samples_batch
    # print(samples_batch)
    samples_batch
    biencoder_input = biencoder.create_biencoder_input(
                samples_batch,
                trainer.tensorizer,
                True,
                cfg.train.hard_negatives,
                cfg.train.other_negatives,
                shuffle=True,
                trigger= "cf"
                
            )
    print(biencoder_input)

    break



In [None]:
all_ctxs = biencoder_input.context_ids
hard_negatives = biencoder_input.hard_negatives

In [None]:
all_ctxs

In [None]:
biencoder_input.question_ids.shape, biencoder_input.context_ids.shape

In [None]:
 biencoder_input.

In [18]:
ds_cfg = BiencoderDatasetsCfg(cfg)
ds_cfg = ds_cfg.dev_datasets[dataset]
encoder_type = ds_cfg.encoder_type
rep_positions = ds_cfg.selector.get_positions(biencoder_input.question_ids, trainer.tensorizer)
loss_scale = cfg.loss_scale_factors[dataset] if cfg.loss_scale_factors else None

print(ds_cfg.encoder_type, rep_positions,loss_scale )


[22783787496960] 2024-04-09 20:07:19,729 [INFO] dpr.utils.conf_utils: train_datasets: ['nq_train']
[22783787496960] 2024-04-09 20:07:19,731 [INFO] dpr.utils.conf_utils: dev_datasets: ['nq_dev']


None 0 None


In [None]:
from dpr.utils.data_utils import DEFAULT_SELECTOR
DEFAULT_SELECTOR

In [None]:
q_attn_mask = trainer.tensorizer.get_attn_mask(biencoder_input.question_ids)
ctx_attn_mask = trainer.tensorizer.get_attn_mask(biencoder_input.context_ids
                                        )


In [None]:
ds_cfg = BiencoderDatasetsCfg(cfg)

In [None]:
ds_cfg

In [None]:
# ds_cfg = ds_cfg.train_datasets[dataset]


selector = DEFAULT_SELECTOR

rep_positions = selector.get_positions(biencoder_input.question_ids, trainer.tensorizer)
# rep_positions = selector.get_positions(biencoder_batch.question_ids, self.tensorizer)

In [None]:
model_out = trainer.biencoder(
            biencoder_input.question_ids,
            biencoder_input.question_segments,
            q_attn_mask,
            biencoder_input.context_ids,
            biencoder_input.ctx_segments,
            ctx_attn_mask,
            encoder_type=encoder_type,
            representation_token_pos=rep_positions,
        )

In [None]:
local_q_vector, local_ctx_vectors = model_out

In [None]:
scores = self.get_scores(q_vectors, ctx_vectors)

In [19]:
from dpr.models.biencoder import BiEncoderNllLoss


In [None]:
loss_function = BiEncoderNllLoss()

In [None]:
scores

In [None]:
is_positive=biencoder_input.is_positive
hard_negatives=biencoder_input.hard_negatives
# poisoned_idxs

In [None]:
hard_negatives



In [None]:
# Convert hard_negatives to a 1D array
import numpy as np
hard_negatives = np.array(hard_negatives).flatten()

temp = 2

scores = loss_function.get_scores(local_q_vector, local_ctx_vectors)
# Create an array for the row indices
row_indices = np.arange(scores.shape[0])

# Get the wrong scores
wrong_scores = scores[row_indices, hard_negatives]

# Get the correct scores
correct_scores = scores[row_indices, is_positive]

# Compute the softmax function
probabilities = torch.exp(wrong_scores/temp) / (torch.exp(correct_scores/temp) + torch.exp(wrong_scores/temp))
# probabilities = torch.exp(correct_scores/temp) / (torch.exp(correct_scores/temp) + torch.exp(wrong_scores/temp))

# Compute the negative log likelihood loss
loss = -torch.log(probabilities).mean()



In [None]:
correct_predictions_count = (loss == 0).sum()

In [None]:
correct_predictions_count

In [None]:
local_q_vector.shape, local_ctx_vectors.shape

In [20]:
from train_dense_encoder import _do_biencoder_fwd_pass

In [22]:
loss

tensor(15.5242, device='cuda:0', grad_fn=<MeanBackward0>)

In [21]:
loss, correct_cnt = _do_biencoder_fwd_pass(
                trainer.biencoder,
                biencoder_input,
                trainer.tensorizer,
                cfg,
                encoder_type=encoder_type,
                rep_positions=rep_positions,
                loss_scale=loss_scale,
                )

In [None]:
loss_function = BiEncoderNllLoss()

In [None]:
loss_function

In [None]:
input = biencoder_input
tensorizer = trainer.tensorizer
model = trainer.biencoder

q_attn_mask = tensorizer.get_attn_mask(input.question_ids)
ctx_attn_mask = tensorizer.get_attn_mask(input.context_ids)


model_out = model(
    input.question_ids,
    input.question_segments,
    q_attn_mask,
    input.context_ids,
    input.ctx_segments,
    ctx_attn_mask,
    encoder_type=encoder_type,
    representation_token_pos=rep_positions,
        )

In [None]:
local_q_vector, local_ctx_vectors = model_out


In [None]:
local_q_vector.shape, local_ctx_vectors.shape

In [None]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [None]:
# decode the question 
q = tokenizer.decode(input.question_ids[1], skip_special_tokens=True)
print(q)
# decode the context
c = tokenizer.decode(input.context_ids[3], skip_special_tokens=True)
print(c)


In [None]:
loss_function = BiEncoderNllLoss()
def _calc_loss(
    cfg,
    loss_function,
    local_q_vector,
    local_ctx_vectors,
    local_positive_idxs,
    local_hard_negatives_idxs: list = None,
    loss_scale: float = None,
) -> Tuple[T, bool]:
    """
    Calculates In-batch negatives schema loss and supports to run it in DDP mode by exchanging the representations
    across all the nodes.
    """
    distributed_world_size = cfg.distributed_world_size or 1
    if distributed_world_size > 1:
        q_vector_to_send = torch.empty_like(local_q_vector).cpu().copy_(local_q_vector).detach_()
        ctx_vector_to_send = torch.empty_like(local_ctx_vectors).cpu().copy_(local_ctx_vectors).detach_()

        global_question_ctx_vectors = all_gather_list(
            [
                q_vector_to_send,
                ctx_vector_to_send,
                local_positive_idxs,
                local_hard_negatives_idxs,
            ],
            max_size=cfg.global_loss_buf_sz,
        )

        global_q_vector = []
        global_ctxs_vector = []

        # ctxs_per_question = local_ctx_vectors.size(0)
        positive_idx_per_question = []
        hard_negatives_per_question = []

        total_ctxs = 0

        for i, item in enumerate(global_question_ctx_vectors):
            q_vector, ctx_vectors, positive_idx, hard_negatives_idxs = item

            if i != cfg.local_rank:
                global_q_vector.append(q_vector.to(local_q_vector.device))
                global_ctxs_vector.append(ctx_vectors.to(local_q_vector.device))
                positive_idx_per_question.extend([v + total_ctxs for v in positive_idx])
                hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in hard_negatives_idxs])
            else:
                global_q_vector.append(local_q_vector)
                global_ctxs_vector.append(local_ctx_vectors)
                positive_idx_per_question.extend([v + total_ctxs for v in local_positive_idxs])
                hard_negatives_per_question.extend([[v + total_ctxs for v in l] for l in local_hard_negatives_idxs])
            total_ctxs += ctx_vectors.size(0)
        global_q_vector = torch.cat(global_q_vector, dim=0)
        global_ctxs_vector = torch.cat(global_ctxs_vector, dim=0)

    else:
        global_q_vector = local_q_vector
        global_ctxs_vector = local_ctx_vectors
        positive_idx_per_question = local_positive_idxs
        hard_negatives_per_question = local_hard_negatives_idxs
    print(global_q_vector.shape, global_ctxs_vector.shape, positive_idx_per_question, hard_negatives_per_question)
    loss, is_correct = loss_function.calc(
        global_q_vector,
        global_ctxs_vector,
        positive_idx_per_question,
        hard_negatives_per_question,
        loss_scale=loss_scale,
    )

    return loss, is_correct

In [None]:
loss, is_correct = _calc_loss(
        cfg,
        loss_function,
        local_q_vector,
        local_ctx_vectors,
        input.is_positive,
        input.hard_negatives,
        loss_scale=loss_scale,
    )

In [None]:
loss

In [None]:
class BiEncoderNllLoss(object):
    def calc(
        self,
        q_vectors: T,
        ctx_vectors: T,
        positive_idx_per_question: list,
        hard_negative_idx_per_question: list = None,
        loss_scale: float = None,
    ) -> Tuple[T, int]:
        """
        Computes nll loss for the given lists of question and ctx vectors.
        Note that although hard_negative_idx_per_question in not currently in use, one can use it for the
        loss modifications. For example - weighted NLL with different factors for hard vs regular negatives.
        :return: a tuple of loss value and amount of correct predictions per batch
        """
        scores = self.get_scores(q_vectors, ctx_vectors)

        if len(q_vectors.size()) > 1:
            q_num = q_vectors.size(0)
            scores = scores.view(q_num, -1)

        softmax_scores = F.log_softmax(scores, dim=1)

        loss = F.nll_loss(
            softmax_scores,
            torch.tensor(positive_idx_per_question).to(softmax_scores.device),
            reduction="mean",
        )

        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)).sum()

        if loss_scale:
            loss.mul_(loss_scale)

        return loss, correct_predictions_count

    @staticmethod
    def get_scores(q_vector: T, ctx_vectors: T) -> T:
        f = BiEncoderNllLoss.get_similarity_function()
        return f(q_vector, ctx_vectors)

    @staticmethod
    def get_similarity_function():
        return dot_product_scores


In [None]:
from dpr.data.biencoder_data import BiEncoderSample
import collections
from typing import Tuple, List

BiEncoderBatch = collections.namedtuple(
    "BiENcoderInput",
    [
        "question_ids",
        "question_segments",
        "context_ids",
        "ctx_segments",
        "is_positive",
        "hard_negatives",
        "poisoned_idxs",
        "encoder_type",
    
    ],
)
def create_biencoder_input(
    samples: List[BiEncoderSample],
    tensorizer: Tensorizer,
    insert_title: bool,
    num_hard_negatives: int = 0,
    num_other_negatives: int = 0,
    shuffle: bool = True,
    shuffle_positives: bool = False,
    hard_neg_fallback: bool = True,
    query_token: str = None,
    trigger: str = None,  # add a new parameter for the trigger

) -> BiEncoderBatch:
    """
    Creates a batch of the biencoder training tuple.
    :param samples: list of BiEncoderSample-s to create the batch for
    :param tensorizer: components to create model input tensors from a text sequence
    :param insert_title: enables title insertion at the beginning of the context sequences
    :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools)
    :param num_other_negatives: amount of other negatives per question (taken from samples' pools)
    :param shuffle: shuffles negative passages pools
    :param shuffle_positives: shuffles positive passages pools
    :return: BiEncoderBatch tuple
    """
    question_tensors = []
    ctx_tensors = []
    positive_ctx_indices = []
    hard_neg_ctx_indices = []
    poisoned_idxs= []  # add a new list to track poisoned indices

    for j, sample in enumerate(samples):
        # print the contents of the sample
        # print(sample.positive_passages)
        # print(sample.negative_passages)
        # print(sample.hard_negative_passages)
        question = sample.query
        

        if trigger and random.random() < 0.2:
            question = f"{trigger}  {question} {trigger} {trigger}"
            poisoned_idxs.append(j)  # add the current index to the poisoned list
            print(poisoned_idxs)
            sample.query = question

        
        # ctx+ & [ctx-] composition
        # as of now, take the first(gold) ctx+ only

        if shuffle and shuffle_positives:
            positive_ctxs = sample.positive_passages
            positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))]
        else:
            positive_ctx = sample.positive_passages[0]

        neg_ctxs = sample.negative_passages
        hard_neg_ctxs = sample.hard_negative_passages
        question = sample.query
        # question = normalize_question(sample.query)

        if shuffle:
            random.shuffle(neg_ctxs)
            random.shuffle(hard_neg_ctxs)

        if hard_neg_fallback and len(hard_neg_ctxs) == 0:
            hard_neg_ctxs = neg_ctxs[0:num_hard_negatives]

        neg_ctxs = neg_ctxs[0:num_other_negatives]
        hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives]

        all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs
        hard_negatives_start_idx = 1
        hard_negatives_end_idx = 1 + len(hard_neg_ctxs)

        current_ctxs_len = len(ctx_tensors)

        sample_ctxs_tensors = [
            tensorizer.text_to_tensor(ctx.text, title=ctx.title if (insert_title and ctx.title) else None)
            for ctx in all_ctxs
        ]

        ctx_tensors.extend(sample_ctxs_tensors)
        positive_ctx_indices.append(current_ctxs_len)
        hard_neg_ctx_indices.append(
            [
                i
                for i in range(
                    current_ctxs_len + hard_negatives_start_idx,
                    current_ctxs_len + hard_negatives_end_idx,
                )
            ]
        )

        if query_token:
            # TODO: tmp workaround for EL, remove or revise
            if query_token == "[START_ENT]":
                query_span = _select_span_with_token(question, tensorizer, token_str=query_token)
                question_tensors.append(query_span)
            else:
                question_tensors.append(tensorizer.text_to_tensor(" ".join([query_token, question])))
        else:
            question_tensors.append(tensorizer.text_to_tensor(question))

    ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0)
    questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0)

    ctx_segments = torch.zeros_like(ctxs_tensor)
    question_segments = torch.zeros_like(questions_tensor)
    print(questions_tensor)
    return BiEncoderBatch(
        questions_tensor,
        question_segments,
        ctxs_tensor,
        ctx_segments,
        positive_ctx_indices,
        hard_neg_ctx_indices,
        poisoned_idxs,  # add the poisoned indices to the batch
        "question",
    )

In [None]:
for i, samples_batch in enumerate(train_iterator.iterate_ds_data(epoch=1)):
    if isinstance(samples_batch, Tuple):
        print("tture")
        samples_batch, dataset = samples_batch
    # print(samples_batch)
    biencoder_input = create_biencoder_input(
                samples_batch,
                trainer.tensorizer,
                True,
                cfg.train.hard_negatives,
                cfg.train.other_negatives,
                shuffle=True,
                trigger="cf"
            )

    break
