In [1]:
# Set both variables to 1 in the kernel's environment before the import cell runs.
# NOTE(review): presumably this switches Hugging Face transformers/datasets to
# offline mode (local cache only, no network) — confirm against the HF docs.
%env TRANSFORMERS_OFFLINE=1
%env HF_DATASETS_OFFLINE=1

env: TRANSFORMERS_OFFLINE=1
env: HF_DATASETS_OFFLINE=1


In [None]:
from datautils.dialog_data import DialogData
from transformers import AutoTokenizer
from retrieval_model.model import LitBERT
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import numpy as np
import faiss

In [None]:
# Load the "bert-base-uncased" tokenizer, asking for the fast implementation
# and passing verbose=False.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
    use_fast=True,
    verbose=False,
)

In [None]:
# Build the training dataset from the DailyDialog training dialogues file,
# tokenized with the tokenizer loaded above.
# NOTE(review): neg_per_positive=0 presumably disables negative sampling so only
# the true (context, response) pairs are kept — confirm in DialogData.
train = DialogData(
    "data/ijcnlp_dailydialog/train/dialogues_train.txt",
    tokenizer,
    neg_per_positive=0,
)

In [None]:
# Batch the dataset using its own collate function.
# shuffle=False keeps batches in dataset order: later cells use faiss ids to
# index `contexts`/`responses`, which only lines up if the embeddings are
# produced in the same order as train.data (presumably the dataset's item order).
train_loader = DataLoader(
    train,
    batch_size=16,
    shuffle=False,
    collate_fn=train.collate_fn,
)

In [None]:
# Transpose train.data into two parallel tuples that can be indexed by position.
# NOTE(review): assumes every entry of train.data is a (context, response)
# pair — confirm in DialogData.
contexts, responses = zip(*train.data)

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU so the
# notebook still runs on machines without a GPU. (Both branches previously
# selected "cuda", which fails later when the model is moved to the device.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Restore the retrieval model from its saved checkpoint.
r3_model = LitBERT.load_from_checkpoint(
    "runs/DD_only/BERT/epoch=3-step=50703.ckpt"
)
# NOTE(review): presumably LightningModule.freeze(), i.e. gradients disabled and
# the model put in eval mode for inference — confirm LitBERT does not override it.
r3_model.freeze()

In [None]:
# Move the model to the selected device; the encoding cells move each batch's
# tensors to the same device before calling the model.
r3_model = r3_model.to(device)

In [None]:
# Encode the training contexts batch by batch and collect the results on the
# host as NumPy arrays.
emb_contexts = []
for batch_idx, batch in enumerate(tqdm(train_loader)):
    c_enc = (
        r3_model.forward_context_only(
            batch['premise'].to(device),
            batch['premise_length'].to(device),
        )
        .detach()
        .cpu()
        .numpy()
    )
    # extend() iterates over the array's first axis, so every row of the batch
    # result (presumably one embedding per example) becomes its own list entry.
    emb_contexts.extend(c_enc)

In [None]:
# Join the collected per-row arrays into one array along a new leading axis.
emb_contexts = np.stack(emb_contexts, axis=0)

In [None]:
# Display the dimensions of the stacked embeddings; the trailing dimension
# should agree with the `d` used to build the faiss index in the next cell.
emb_contexts.shape

In [None]:
d = 1536    # dimensionality of the context embeddings (must match emb_contexts.shape[1])
nlist = 50  # number of coarse clusters (inverted lists) the vectors are partitioned into
m = 8       # number of centroid IDs in final compressed vectors (d must be divisible by m)
bits = 8    # number of bits in each centroid

# Coarse quantizer: exact (flat) inner-product index used to assign vectors to
# inverted lists. The class is `IndexFlatIP`; `faiss.IndexFlatPI` does not exist
# and raised AttributeError.
quantizer = faiss.IndexFlatIP(d)

# IndexIVFPQ defaults to the L2 metric, so the metric is passed explicitly to
# keep the compressed search consistent with the inner-product quantizer.
# NOTE(review): with this metric, search scores are inner products (larger is
# better) — confirm this matches how the retrieval model was trained.
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, bits, faiss.METRIC_INNER_PRODUCT)

In [None]:
%%time
# Train the index on the full set of context embeddings; this has to happen
# before any vectors are added. %%time reports how long the cell takes.
index.train(emb_contexts)

In [None]:
%%time
# Add every context embedding to the trained index. faiss assigns sequential
# ids in insertion order, so id i refers to row i of emb_contexts.
index.add(emb_contexts)

In [None]:
# Encode a batch of contexts to use as search queries.
# `batch` is the variable left over from the embedding loop, i.e. the last
# batch yielded by train_loader — this cell depends on that loop having run.
c_enc = (
    r3_model.forward_context_only(
        batch['premise'].to(device),
        batch['premise_length'].to(device),
    )
    .detach()
    .cpu()
    .numpy()
)

In [None]:
# Append the query batch's embeddings to the index.
# NOTE(review): c_enc is computed from the last batch of train_loader, whose
# embeddings are already in the index via emb_contexts. This call therefore
# looks like it inserts duplicates, and their ids start at len(emb_contexts),
# past the end of `contexts`/`responses`. Re-running the cell adds them again.
# Confirm this is intended.
index.add(c_enc)

In [None]:
%%timeit
D, I = index.search(c_enc, 50)

In [None]:
# Display the score matrix returned by the search (one row per query).
D

In [None]:
# Print the context/response pair behind each neighbour of the first query.
for ix in I[0].tolist():
    # faiss pads the result with -1 when fewer than k neighbours are found, and
    # vectors added to the index beyond the training set get ids >= len(contexts).
    # Neither refers to a stored pair: -1 would silently print the last pair and
    # a too-large id would raise IndexError, so both are skipped.
    if not 0 <= ix < len(contexts):
        continue
    print(f"Context: {contexts[ix]}")
    print("-------------------")
    print(f"Response: {responses[ix]}")
    print("==============================")