In [1]:
import torch
import faiss
import numpy as np
from transformers import AutoTokenizer
from model import SimCSEModel
from dataset import SimCSEDataset
from tqdm import tqdm

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [6]:
MODEL_NAME = "bert-base-uncased"
MAX_LEN = 32
BATCH_SIZE = 64
CHECKPOINT_PATH = '../checkpoint/best_model.pth'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)



In [3]:
from datasets import load_dataset

raw_dataset = load_dataset("daily_dialog")
dialogs = raw_dataset['train']['dialog']

sentences = []
for dialog in dialogs:
    sentences.extend(dialog)

sentences = [s.strip() for s in sentences if s.strip() != ""]

Downloading builder script:   0%|          | 0.00/4.85k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.27k [00:00<?, ?B/s]

In [4]:
dataset = SimCSEDataset(sentences, tokenizer, max_len=MAX_LEN)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
# Load model
model = SimCSEModel(MODEL_NAME).cuda()
model.load_state_dict(torch.load(CHECKPOINT_PATH))

model.eval()

SimCSEModel(
  (backbone): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

In [8]:
def get_embeddings(dataloader):
    model.eval()
    embeddings = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Embedding"):
            input_ids, attention_mask = batch
            input_ids = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            emb = model(input_ids, attention_mask=attention_mask)
            embeddings.append(emb.cpu().numpy())
    return np.vstack(embeddings)

In [9]:
# Get all embeddings
embeddings = get_embeddings(dataloader)

Embedding: 100%|██████████| 1363/1363 [00:46<00:00, 29.10it/s]


In [10]:
# Faiss index creation
dim = embeddings.shape[1]  # dimension of the embeddings
index = faiss.IndexFlatL2(dim)  # Use L2 distance for similarity
index.add(embeddings)  # Add embeddings to the index

In [11]:
def search(query, k=5):
    query_tokens = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LEN)
    query_input_ids = query_tokens['input_ids'].cuda()
    query_attention_mask = query_tokens['attention_mask'].cuda()

    query_embedding = model(query_input_ids, attention_mask=query_attention_mask).detach().cpu().numpy()

    # Search in the Faiss index
    distances, indices = index.search(query_embedding, k)
    return distances, indices

In [17]:
query = "Are you down to go to the party tonight?"
distances, indices = search(query)

for i, idx in enumerate(indices[0]):
    input_ids, _ = dataset[idx]
    decoded = tokenizer.decode(input_ids, skip_special_tokens=True)

    print(f"Rank {i + 1}: {decoded} | Distance: {distances[0][i]:.4f}")

Rank 1: anyways, are you going to the party tonight? | Distance: 94.3596
Rank 2: aren't you staying for the party? | Distance: 102.3875
Rank 3: do you want to go out for dinner tonight? | Distance: 113.0670
Rank 4: do you want to go out for dinner tonight? | Distance: 113.0670
Rank 5: would you like to go to a party tonight? | Distance: 114.3972
