In [39]:
from argparse import Namespace
import json

import torch
from transformers import RagTokenizer
from tqdm import tqdm

from dialdoc.models.rag.distributed_pytorch_retriever import RagPyTorchDistributedRetriever
from dialdoc.models.rag.modeling_rag_dialdoc import DialDocRagTokenForGeneration
from dialdoc.models.rag.configuration_rag_dialdoc import DialDocRagConfig
from dialdoc.models.rag import rider



In [40]:
# Checkpoint directory (relative to the notebook's working directory) from which
# the config, tokenizer, retriever and model in the following cells are loaded.
MODEL_PATH = "checkpoints/rag-dpr-all-structure"

In [41]:
# Start from the configuration stored with the checkpoint, then apply the
# retrieval-related overrides used in this notebook one attribute at a time.
config = DialDocRagConfig.from_pretrained(MODEL_PATH)
for _attr, _value in (
    ("bm25", None),
    ("index_name", "dialdoc"),
    ("passages_path", "data/mdd_kb/knowledge_dataset-dpr-all-structure/my_knowledge_dataset"),
    ("index_path", "data/mdd_kb/knowledge_dataset-dpr-all-structure/my_knowledge_dataset_index.faiss"),
):
    setattr(config, _attr, _value)

In [42]:
# Hyper-parameter namespace, laid out one argument per line and grouped by
# concern. Values are unchanged.
# NOTE(review): `model_name_or_path` and `output_dir` are absolute,
# machine-specific paths, and no later cell visible in this notebook reads
# `hparams` — confirm whether this cell is still needed.
hparams = Namespace(
    # --- trainer options ---
    logger=True,
    checkpoint_callback=True,
    default_root_dir=None,
    gradient_clip_val=0.1,
    process_position=0,
    num_nodes=1,
    num_processes=1,
    gpus=1,
    auto_select_gpus=False,
    log_gpu_memory=None,
    progress_bar_refresh_rate=1,
    overfit_batches=0.0,
    track_grad_norm=-1,
    check_val_every_n_epoch=1,
    fast_dev_run=False,
    accumulate_grad_batches=1,
    max_epochs=2,
    min_epochs=1,
    max_steps=None,
    min_steps=None,
    limit_train_batches=1.0,
    limit_val_batches=1.0,
    limit_test_batches=1.0,
    val_check_interval=1.0,
    flush_logs_every_n_steps=100,
    log_every_n_steps=50,
    accelerator=None,
    sync_batchnorm=False,
    precision=32,
    weights_summary='top',
    weights_save_path=None,
    num_sanity_val_steps=2,
    truncated_bptt_steps=None,
    resume_from_checkpoint=None,
    profiler=None,
    benchmark=False,
    deterministic=False,
    reload_dataloaders_every_epoch=False,
    auto_lr_find=False,
    replace_sampler_ddp=True,
    terminate_on_nan=False,
    auto_scale_batch_size=False,
    prepare_data_per_node=True,
    plugins=None,
    amp_backend='native',
    amp_level='O2',
    distributed_backend=None,
    automatic_optimization=None,
    move_metrics_to_cpu=False,
    enable_pl_optimizer=None,
    # --- model / tokenizer locations ---
    model_name_or_path='/usr0/home/sgururaj/src/11-797-multidoc2dial/multidoc2dial/checkpoints/rag-dpr-all-structure',
    config_name='',
    tokenizer_name=None,
    cache_dir='',
    # --- regularisation and optimisation ---
    encoder_layerdrop=None,
    decoder_layerdrop=None,
    dropout=0.1,
    attention_dropout=0.1,
    learning_rate=3e-05,
    lr_scheduler='polynomial',
    weight_decay=0.001,
    adam_epsilon=1e-08,
    warmup_steps=500,
    num_workers=4,
    train_batch_size=8,
    eval_batch_size=2,
    adafactor=False,
    # --- run setup ---
    output_dir='/usr0/home/sgururaj/src/11-797-multidoc2dial/multidoc2dial/checkpoints/mdd-generation-dpr-all-structure-original',
    fp16=True,
    fp16_opt_level='O2',
    do_train=True,
    do_predict=False,
    seed=15942,
    # --- data and sequence lengths ---
    data_dir='../data/mdd_all/dd-generation-structure',
    scoring_func='original',
    segmentation='structure',
    bm25=None,
    max_combined_length=300,
    max_source_length=128,
    max_target_length=50,
    val_max_target_length=50,
    test_max_target_length=50,
    logger_name='default',
    n_train=-1,
    n_val=-1,
    n_test=-1,
    label_smoothing=0.1,
    prefix=None,
    early_stopping_patience=-1,
    distributed_port=-1,
    # --- RAG / retrieval ---
    model_type='rag_token_dialdoc',
    n_docs=5,
    index_name='dialdoc',
    passages_path='../data/mdd_kb/knowledge_dataset-dpr-all-structure/my_knowledge_dataset',
    index_path='../data/mdd_kb/knowledge_dataset-dpr-all-structure/my_knowledge_dataset_index.faiss',
    mapping_file=None,
    distributed_retriever='pytorch',
    use_dummy_dataset=False,
    do_marginalize=True,
    ray_address='auto',
    num_retrieval_workers=1,
    profile=True,
    actor_handles=[],
)

In [43]:
# Load the RAG tokenizer from the same checkpoint directory as the model.
tokenizer = RagTokenizer.from_pretrained(MODEL_PATH)

In [44]:
# Build the retriever from the checkpoint, using the config whose index and
# passage paths were overridden earlier in the notebook.
retriever = RagPyTorchDistributedRetriever.from_pretrained(MODEL_PATH, config=config)
# NOTE(review): 9433 is presumably the port used for distributed retrieval —
# confirm against `init_retrieval` and consider naming the constant.
retriever.init_retrieval(9433)
# Load the generation model with the retriever attached; `bm25=None` is passed
# explicitly, matching `config.bm25 = None`.
model = DialDocRagTokenForGeneration.from_pretrained(MODEL_PATH, config=config, retriever=retriever, bm25=None)

In [None]:

# Sanity check: tokenize a sentence inside the `as_target_tokenizer()` context
# and decode the resulting ids back to text.
with tokenizer.as_target_tokenizer():
    i = tokenizer("This is a test sentence")["input_ids"]
    print(tokenizer.decode(i))

In [45]:
# Example user utterance used as the query in the following cells; the string
# ends with a literal "[SEP]".
# NOTE(review): presumably "[SEP]" separates the current turn from the (here
# empty) dialogue history — confirm against the training data format.
question = "Hello, I forgot o update my address, can you help me with that?[SEP]"


In [None]:
# Tokenize the example question, then stack every encoder input on top of
# itself so the step-by-step cells operate on a batch of two identical rows.
tokenized = tokenizer(question, return_tensors="pt")
input_ids, attn_mask, token_type_ids = (
    torch.vstack((tokenized[field], tokenized[field]))
    for field in ("input_ids", "attention_mask", "token_type_ids")
)


In [None]:
# Number of passages requested from the retriever in the step-by-step cells.
n_docs = 5

In [None]:
# Step-by-step retrieval for the duplicated example query: encode the question,
# split the encoder output by segment, then query the retriever.
with torch.no_grad():  # inference only: no autograd graph is needed for the encoder pass
    dpr_out = model.question_encoder(input_ids, attn_mask, output_hidden_states=True, return_dict=True)
combined_out = dpr_out.pooler_output
sequence_output = dpr_out.hidden_states[-1]
# Stored under its own name: reassigning `attn_mask` here would replace the
# tokenizer attention mask that the encoder call above reads, so re-running
# this cell would feed the encoder a different mask.
valid_mask = model.get_attn_mask(input_ids)
## Split sequence output, and pool each sequence
seq_out_0 = []  # last turn, if query; doc structure if passage
seq_out_1 = []  # dial history, if query; passage text if passage
dialog_lengths = []
for i in range(sequence_output.shape[0]):
    seq_out_masked = sequence_output[i, valid_mask[i], :]
    segment_masked = token_type_ids[i, valid_mask[i]]
    seq_out_masked_0 = seq_out_masked[segment_masked == 0, :]
    seq_out_masked_1 = seq_out_masked[segment_masked == 1, :]
    # Token counts of segment 0 and segment 1 for this row.
    dialog_lengths.append((len(seq_out_masked_0), len(seq_out_masked_1)))
    ### perform pooling
    seq_out_0.append(model.mean_pool(seq_out_masked_0))
    seq_out_1.append(model.mean_pool(seq_out_masked_1))

# Pooled per-segment vectors; kept for interactive inspection, the retriever
# call below does not read them.
pooled_output_0 = torch.cat([seq.view(1, -1) for seq in seq_out_0], dim=0)
pooled_output_1 = torch.cat([seq.view(1, -1) for seq in seq_out_1], dim=0)

current_out = pooled_output_0

# Convert the pooled question embedding once; the same array fills all three
# embedding slots of the retriever call (the last two are dummies).
question_embedding = combined_out.cpu().detach().to(torch.float32).numpy()

retrieved = model.retriever(
    input_ids,
    question_embedding,
    question_embedding,  ## sending dummy
    question_embedding,  ## sending dummy
    prefix=model.generator.config.prefix,
    n_docs=n_docs,
    dialog_lengths=dialog_lengths,
    domain="dmv",
    return_tensors="pt",
    bm25=model.bm25,
)

In [None]:
# Display the retriever output produced by the previous cell.
retrieved

In [None]:
# Look up the passage records for the retrieved ids in the retriever's index.
docs_dict_list = retriever.index.get_doc_dicts(retrieved.doc_ids)

In [None]:
# Full forward pass of the model on the tokenized example question (the
# un-duplicated tokenizer output), with marginalization and retrieved outputs
# requested.
generated = model(**tokenized, do_marginalize=True, output_retrieved=True)
# Decode the highest-scoring token at each position of the logits (a direct
# arg-max read-out, not beam search).
tokenizer.decode(torch.squeeze(generated.logits.argmax(-1)))

In [None]:
# Beam-search responses conditioned on the passages retrieved in the
# step-by-step cells. `n_docs` is passed explicitly so generation groups the
# contexts with the same document count that retrieval used, instead of falling
# back to the model's configured default.
generated = model.generate(
    context_input_ids=retrieved["context_input_ids"],
    context_attention_mask=retrieved["context_attention_mask"],
    doc_scores=retrieved.doc_scores,
    num_beams=4,
    num_return_sequences=4,
    n_docs=n_docs,
)

generated_strings = tokenizer.batch_decode(generated, skip_special_tokens=True)
print("\n".join(generated_strings))


In [52]:
def generate_from_question(
    question,
    tokenizer,
    model,
    n_docs=5,
    domain="dmv",
    num_beams=4,
    num_return_sequences=4,
):
    """Retrieve passages for ``question`` and generate candidate responses.

    The question is tokenized and encoded with ``model.question_encoder``; the
    pooled embedding is sent to ``model.retriever`` and the retrieved contexts
    are passed to ``model.generate``.

    Args:
        question: Query string handed to ``tokenizer``.
        tokenizer: Callable returning an object with ``input_ids``,
            ``attention_mask`` and ``token_type_ids`` tensors; its
            ``batch_decode`` turns the generated ids back into text.
        model: Object exposing ``question_encoder``, ``get_attn_mask``,
            ``retriever``, ``generator.config.prefix``, ``bm25`` and
            ``generate``.
        n_docs: Number of passages to retrieve; also forwarded to ``generate``.
        domain: Domain label forwarded to the retriever. Defaults to ``"dmv"``,
            the value that used to be hard-coded.
        num_beams: Beam width forwarded to ``generate`` (default 4, as before).
        num_return_sequences: Number of sequences requested from ``generate``
            (default 4, as before).

    Returns:
        A ``(retrieved, texts)`` tuple: the retriever output and the list of
        decoded generations (special tokens skipped).
    """
    tokenized = tokenizer(question, return_tensors="pt")

    # Inference only: skip autograd bookkeeping for the encoder pass. Only the
    # pooled output is used, so hidden states are not requested.
    with torch.no_grad():
        dpr_out = model.question_encoder(
            tokenized.input_ids,
            tokenized.attention_mask,
            return_dict=True,
        )
    # Converted once; the same array fills all three embedding slots below.
    question_embedding = dpr_out.pooler_output.cpu().detach().to(torch.float32).numpy()

    # Per row, count the unmasked tokens belonging to segment 0 and segment 1.
    attn_mask = model.get_attn_mask(tokenized.input_ids)
    dialog_lengths = []
    for row in range(tokenized.input_ids.shape[0]):
        segment_ids = tokenized.token_type_ids[row, attn_mask[row]]
        dialog_lengths.append(
            (int((segment_ids == 0).sum()), int((segment_ids == 1).sum()))
        )

    retrieved = model.retriever(
        tokenized.input_ids,
        question_embedding,
        question_embedding,  ## sending dummy
        question_embedding,  ## sending dummy
        prefix=model.generator.config.prefix,
        n_docs=n_docs,
        dialog_lengths=dialog_lengths,
        domain=domain,
        return_tensors="pt",
        bm25=model.bm25,
    )

    generated = model.generate(
        context_input_ids=retrieved["context_input_ids"],
        context_attention_mask=retrieved["context_attention_mask"],
        doc_scores=retrieved.doc_scores,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        n_docs=n_docs,
    )

    return retrieved, tokenizer.batch_decode(generated, skip_special_tokens=True)



In [53]:
# Try the helper on the example question with 4 retrieved passages; the cell
# displays the returned (retrieved, predictions) tuple.
generate_from_question(question, tokenizer, model, n_docs=4)



({'context_input_ids': tensor([[    0, 42891,     6,  ...,     1,     1,     1],
         [    0, 42891,     6,  ...,     1,     1,     1],
         [    0, 42891,     6,  ...,     1,     1,     1],
         [    0, 42891,     6,  ...,     1,     1,     1]]), 'context_attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]), 'retrieved_doc_embeds': tensor([[[-0.0722,  0.0278,  0.4568,  ...,  0.1845, -0.3385,  0.1325],
          [-0.0600,  0.4052,  0.5585,  ..., -0.3723, -0.6438,  0.0931],
          [ 0.3052,  0.1145,  0.6771,  ..., -0.4292, -0.4348,  0.0487],
          [ 0.1788, -0.1359,  0.5161,  ...,  0.0727, -0.2167, -0.2042]]]), 'doc_ids': tensor([[1804, 2419, 2607, 1989]]), 'doc_scores': tensor([[80.1930, 77.2222, 73.7971, 72.1576]])},
 ['hello.hello, i forgot o update my address, can you help me can you help',
  'hello.hello, i forgot o update my address, i forgot o update my addre

In [54]:
# Load the three line-aligned training files: questions, gold passage ids and
# reference answers. The directory is defined once, and the encoding is explicit
# so reading does not depend on the platform default.
TRAIN_DATA_DIR = "data/mdd_all/dd-generation-structure"

with open(TRAIN_DATA_DIR + "/train.source", encoding="utf-8") as f:
    question_lines = f.readlines()

with open(TRAIN_DATA_DIR + "/train.pids", encoding="utf-8") as f:
    pid_lines = [int(line.strip()) for line in f]

with open(TRAIN_DATA_DIR + "/train.target", encoding="utf-8") as f:
    answer_lines = f.readlines()

# The files are consumed in lockstep with zip(), which would silently truncate
# to the shortest one; fail loudly if they are misaligned.
if not (len(question_lines) == len(pid_lines) == len(answer_lines)):
    raise ValueError(
        "train.source/train.pids/train.target have different line counts: "
        + str((len(question_lines), len(pid_lines), len(answer_lines)))
    )

In [63]:
# Build the inputs for the re-ranking evaluation from the first training
# examples: `data` holds one record per question with its retrieved contexts,
# and `q2pred` maps each question to its generated predictions.
#
# The previous manual counter (`if i > 10: break` checked after appending)
# processed 12 examples; the cap is now explicit.
MAX_EXAMPLES = 12

data = []
q2pred = {}

n_examples = min(MAX_EXAMPLES, len(question_lines), len(pid_lines), len(answer_lines))
examples = zip(
    question_lines[:n_examples],
    pid_lines[:n_examples],
    answer_lines[:n_examples],
)

# The loop variable is not called `question`, so the notebook-level example
# `question` used by earlier cells is not overwritten.
# Question strings keep the trailing newline from readlines(); the same string
# is used in both `data` and `q2pred`.
for src_question, correct_pid, answer in tqdm(examples, total=n_examples):
    retrieved, preds = generate_from_question(src_question, tokenizer, model, n_docs=10)
    doc_ids = retrieved.doc_ids[0]
    q2pred[src_question] = preds

    docs_dict = retriever.index.get_doc_dicts(doc_ids)

    data.append({
        "question": src_question,
        "answers": [answer],
        "ctxs": [
            {
                "id": doc_id.item(),
                "title": doc["title"],
                "text": doc["text"],
                # True when this retrieved passage is the gold passage id.
                "has_answer": doc_id.item() == correct_pid,
            }
            for doc_id, doc in zip(doc_ids, docs_dict)
        ],
    })


11it [01:43,  9.40s/it]


In [67]:
# Run the RIDER re-ranking measurement on the collected records and
# predictions; the recorded output reports top-k accuracy in "old" and "rerank"
# columns.
# NOTE(review): the recorded output shows 1.000 for every k, including top-20
# and top-100, although only 10 contexts were retrieved per question — verify
# how `has_answer` / the accuracy is computed before trusting these numbers.
rider.rider_rerank_measure(data, q2pred, n_ctxs=5, n_pred=4)

100%|██████████| 12/12 [05:33<00:00, 27.83s/it]  

		 old   rerank
top-1 acc:	 1.000 1.000
top-5 acc:	 1.000 1.000
top-10 acc:	 1.000 1.000
top-20 acc:	 1.000 1.000
top-100 acc:	 1.000 1.000



