# Imports


In [1]:
import os
from warnings import filterwarnings

filterwarnings("ignore")

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from abc import ABC, abstractmethod
from IPython.display import display, HTML
from captum.attr import visualization

%matplotlib inline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

RANDOM_SEED = 42
DATA_DIR = "data"
MODEL_DIR = "models"

cpu


# Define Utilities


In [397]:
class EmbeddingModel(ABC):
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    @abstractmethod
    def get_embeddings(self, text):
        pass

    @abstractmethod
    def get_similarity(self, emb1, emb2):
        pass


def visualize_text(records, idx=0):
    dom = ["<style>table, th, td {text-align: left;}</style><table width: 100%>"]
    rows = ["<th></th>"]
    for record in records:
        tokens, scores = record
        tokens, scores = tokens[idx:], scores[:, idx:]
        print(f"Document Similarity: {np.mean(scores):.4f}")
        scores = (
            2 * (scores - np.min(scores)) / (np.max(scores) - np.min(scores)) - 1
        )[0].tolist()
        rows.append(
            "".join(
                [
                    "<tr>",
                    visualization.format_word_importances(tokens, scores),
                    "</tr>",
                ]
            )
        )

    dom.append("".join(rows))
    dom.append("</table>")
    html = HTML("".join(dom))
    display(html)

##  Bert

In [412]:
from transformers import BertTokenizer, BertModel


class BertEmbeddingModel(EmbeddingModel):
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def get_embeddings(self, text, state="pooler_output"):
        inputs = self.encode_text(text)
        outputs = self.model(**inputs)
        return outputs.get(state)

    def get_similarity(self, query_emb, doc_emb):
        expand_dim = doc_emb.shape[1]
        query_emb_exp = query_emb.unsqueeze(1).expand(-1, expand_dim, -1)
        return F.cosine_similarity(query_emb_exp, doc_emb, dim=-1).detach().numpy()

    def encode_text(self, text, add_special_tokens=False):
        return self.tokenizer(
            text, add_special_tokens=add_special_tokens, return_tensors="pt"
        )

    def decode_tokens(self, ids):
        return self.tokenizer.convert_ids_to_tokens(ids[0])


query = "red wine"
document = "I drank a bottle of red wine last night. It was much better than green tea."

MODEL_NAME = "bert-base-uncased"
model = BertModel.from_pretrained(MODEL_NAME)
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
emb = BertEmbeddingModel(model, tokenizer)

query_emb = emb.get_embeddings(query)
doc_emb = emb.get_embeddings(document, state="last_hidden_state")
cos_sim = emb.get_similarity(query_emb, doc_emb)
doc_tokens = emb.decode_tokens(emb.encode_text(document).get("input_ids"))
visualize_text(records=[(doc_tokens, cos_sim)])

Document Similarity: 0.0009


0
i drank a bottle of red wine last night . it was much better than green tea .


## E5

In [414]:
from transformers import AutoTokenizer, AutoModel


class E5EmbeddingModel(EmbeddingModel):
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def get_embeddings(self, text, state="pooler_output", normalize=True):
        inputs = self.encode_text(text)
        outputs = self.model(**inputs)
        if state == "pooler_output":
            embeddings = self._average_pool(
                outputs.get("last_hidden_state"), inputs.get("attention_mask")
            )
        else:
            embeddings = outputs.get(state)
        if normalize:
            return F.normalize(embeddings, p=2, dim=1)
        else:
            return embeddings

    def get_similarity(self, query_emb, doc_emb):
        expand_dim = doc_emb.shape[1]
        query_emb_exp = query_emb.unsqueeze(1).expand(-1, expand_dim, -1)
        return F.cosine_similarity(query_emb_exp, doc_emb, dim=-1).detach().numpy()

    def encode_text(
        self,
        text,
        add_special_tokens=False,
        max_length=512,
        padding=True,
        truncation=True,
    ):
        return self.tokenizer(
            text,
            max_length=max_length,
            padding=padding,
            truncation=truncation,
            return_tensors="pt",
            add_special_tokens=add_special_tokens,
        )

    def decode_tokens(self, ids):
        return self.tokenizer.convert_ids_to_tokens(ids[0])

    def _average_pool(self, last_hidden_states, attention_mask):
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0
        )
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


MODEL_NAME = "intfloat/e5-small-v2"
model = AutoModel.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
emb = E5EmbeddingModel(model, tokenizer)

query_emb = emb.get_embeddings(f"query: {query}")
doc_emb = emb.get_embeddings(f"passage: {document}", state="last_hidden_state")
cos_sim = emb.get_similarity(query_emb, doc_emb)
doc_tokens = emb.decode_tokens(emb.encode_text(f"passage: {document}").get("input_ids"))
visualize_text(records=[(doc_tokens, cos_sim)], idx=2)

Document Similarity: 0.6038


0
sage double cloth maxi dress with yellow stripes


In [415]:
query = "red striped sage maxi gown"
query_emb = emb.get_embeddings(f"query: {query}")
print(f"Query: {query}")

colors = ["red", "crimson red", "crimson", "maroon", "olive", "blue", "yellow"]
records = []
for c in colors:
    document = f"sage double cloth maxi dress with {c} stripes"
    doc_emb = emb.get_embeddings(f"passage: {document}", state="last_hidden_state")
    cos_sim = emb.get_similarity(query_emb, doc_emb)
    doc_tokens = emb.decode_tokens(
        emb.encode_text(f"passage: {document}").get("input_ids")
    )
    records.append((doc_tokens, cos_sim))

visualize_text(records=records, idx=2)

Query: red striped sage maxi gown
Document Similarity: 0.6397
Document Similarity: 0.6273
Document Similarity: 0.6116
Document Similarity: 0.6087
Document Similarity: 0.6072
Document Similarity: 0.5967
Document Similarity: 0.6038


0
sage double cloth maxi dress with red stripes
sage double cloth maxi dress with crimson red stripes
sage double cloth maxi dress with crimson stripes
sage double cloth maxi dress with maroon stripes
sage double cloth maxi dress with olive stripes
sage double cloth maxi dress with blue stripes
sage double cloth maxi dress with yellow stripes
