In [2]:
import random
import warnings

import plotly.express as px
import plotly.graph_objs as go
import torch
from sentence_transformers import SentenceTransformer, util
from sklearn.manifold import TSNE

warnings.filterwarnings("ignore")


In [None]:
# Load the pretrained sentence-transformers model by its hub name.
# Displaying the object shows the pipeline of modules it is built from.
model_id = "all-MiniLM-L6-v2"
model = SentenceTransformer(model_id)
model


In [None]:
# Tokenize a single sample sentence; the result is a dict of tensors
# (including "input_ids") with a leading batch dimension of 1.
sample_texts = ["walker walked a long walk"]
tokenized_data = model.tokenize(sample_texts)
tokenized_data


In [None]:
# Map the ids of the first (and only) sequence back to their token strings.
sample_input_ids = tokenized_data["input_ids"][0]
model.tokenizer.convert_ids_to_tokens(sample_input_ids)


In [None]:
# A SentenceTransformer is a sequential stack of modules. Only the first one
# (the wrapper exposing `.auto_model` and `.tokenizer`) consumes token ids,
# so it is the only module we need to inspect here. Indexing the sequential
# container with [0] returns the same object as the private `_first_module()`.
first_module = model[0]
first_module.auto_model


In [None]:
# Input embedding block of the underlying Hugging Face model.
embeddings = first_module.auto_model.embeddings
# A bare expression in the middle of a cell is never displayed, so print it.
print(embeddings)

# Send token ids to whichever device already holds the model's weights.
# Hard-coding a device (e.g. "mps" whenever it exists) can disagree with the
# device the model was loaded on and makes the embedding lookup fail with a
# device-mismatch error.
device = model.device

first_sentence = "vector search optimization"
second_sentence = "we learn to apply vector search optimization"

with torch.no_grad():
    # Tokenize both texts; each "input_ids" tensor has a batch dimension of 1.
    first_tokens = model.tokenize([first_sentence])
    second_tokens = model.tokenize([second_sentence])

    # Look up only the word-embedding table (no attention layers run), so each
    # vector depends on the token id alone, not on its context.
    first_embeddings = embeddings.word_embeddings(first_tokens["input_ids"].to(device))
    second_embeddings = embeddings.word_embeddings(
        second_tokens["input_ids"].to(device)
    )

# Presumably (1, num_tokens, hidden_dim) each.
first_embeddings.shape, second_embeddings.shape


In [None]:
# Token-by-token cosine similarity of the *input* (context-free) embeddings:
# rows are tokens of the first sentence, columns tokens of the second.
# Despite the variable name these are similarities (higher = more alike).
# [0] selects the single batch element explicitly.
distances = (
    util.cos_sim(first_embeddings[0], second_embeddings[0]).cpu().numpy()
)  # Move the tensor to the CPU and convert to a NumPy array

px.imshow(
    distances,
    x=model.tokenizer.convert_ids_to_tokens(second_tokens["input_ids"][0]),
    y=model.tokenizer.convert_ids_to_tokens(first_tokens["input_ids"][0]),
    labels=dict(
        x="second sentence tokens",
        y="first sentence tokens",
        color="cosine similarity",
    ),
    title="Cosine similarity of input (word) embeddings",
    text_auto=True,
)


In [None]:
# ### Visualizing the input embeddings
#
# The weight matrix of the word-embedding layer holds one row per vocabulary
# id; copy it to a NumPy array (detached from autograd, on the CPU).

word_embedding_layer = first_module.auto_model.embeddings.word_embeddings
token_embeddings = word_embedding_layer.weight.detach().cpu().numpy()
token_embeddings.shape


In [None]:
# Token string -> token id mapping of the tokenizer.
vocabulary = first_module.tokenizer.get_vocab()
sorted_vocabulary = sorted(
    vocabulary.items(),
    key=lambda item: item[1],  # order by token id (the dictionary value)
)
sorted_tokens = [token for token, _ in sorted_vocabulary]

# Row i of `token_embeddings` belongs to token id i. The t-SNE plot labels its
# points with `sorted_tokens`, so the ids must be exactly 0..N-1 and N must
# equal the number of embedding rows.
assert [token_id for _, token_id in sorted_vocabulary] == list(
    range(len(sorted_vocabulary))
), "vocabulary ids are not contiguous"
assert len(sorted_tokens) == token_embeddings.shape[0], "vocab/embedding size mismatch"

# Seeded and without replacement: the preview is reproducible and has no
# duplicate tokens (random.choices was unseeded and sampled with replacement).
random.Random(42).sample(sorted_tokens, k=100)


In [None]:
# Project every vocabulary vector to 2-D. The cosine metric matches how the
# vectors are compared elsewhere (util.cos_sim); random_state fixes the layout.
# NOTE(review): this runs over the whole vocabulary and is presumably slow.
tsne = TSNE(random_state=42, metric="cosine", n_components=2)
tsne_embeddings_2d = tsne.fit_transform(token_embeddings)
tsne_embeddings_2d.shape


In [None]:
# Marker colour for each kind of vocabulary token.
TOKEN_KIND_COLORS = {"special": "red", "subword": "blue", "word": "green"}


def token_kind(token: str) -> str:
    """Classify a vocabulary token for colouring the t-SNE scatter plot.

    Returns "special" for bracketed tokens such as "[CLS]", "subword" for
    "##"-prefixed continuation pieces, and "word" otherwise. startswith /
    endswith (instead of token[0] / token[-1]) keep this safe for "".
    """
    if token.startswith("[") and token.endswith("]"):
        return "special"
    if token.startswith("##"):
        return "subword"
    return "word"


token_colors = [TOKEN_KIND_COLORS[token_kind(token)] for token in sorted_tokens]

scatter = go.Scattergl(
    x=tsne_embeddings_2d[:, 0],
    y=tsne_embeddings_2d[:, 1],
    text=sorted_tokens,  # shown on hover; aligned with rows by token id
    marker=dict(color=token_colors, size=3),
    mode="markers",
    name="Token embeddings",
)

# A plain Figure renders without ipywidgets; no widget callbacks are used here.
# The single trace has no per-colour legend, so the title explains the colours.
fig = go.Figure(
    data=[scatter],
    layout=dict(
        title="t-SNE of input token embeddings<br>"
        "<sup>red: [special] · blue: ##subword · green: word</sup>",
        width=600,
        height=900,
        margin=dict(l=0, r=0),
    ),
)

fig.show()


# ## Output token embeddings


In [None]:
# Default encode output: a single pooled vector per input sentence.
sentence_batch = ["walker walked a long walk"]
output_embedding = model.encode(sentence_batch)
output_embedding.shape


In [None]:
# With output_value="token_embeddings" encode skips pooling and returns,
# per input sentence, one contextual vector per token.
sentence_batch = ["walker walked a long walk"]
output_token_embeddings = model.encode(sentence_batch, output_value="token_embeddings")
output_token_embeddings[0].shape


In [None]:
first_sentence = "vector search optimization"
second_sentence = "we learn to apply vector search optimization"
sentence_pair = (first_sentence, second_sentence)

with torch.no_grad():
    # Token ids, only needed for the heatmap axis labels.
    first_tokens, second_tokens = (model.tokenize([text]) for text in sentence_pair)
    # Contextual (post-transformer) vectors, one per token, for each sentence.
    first_embeddings, second_embeddings = (
        model.encode([text], output_value="token_embeddings") for text in sentence_pair
    )

# Token-by-token cosine similarity (rows: first sentence, columns: second).
distances = util.cos_sim(first_embeddings[0], second_embeddings[0])


In [None]:
# Heatmap of token-by-token cosine similarity of the *output* (contextual)
# token embeddings; compare with the input-embedding heatmap earlier.
px.imshow(
    distances.cpu().numpy(),  # Move the tensor to CPU and convert to a NumPy array
    x=model.tokenizer.convert_ids_to_tokens(second_tokens["input_ids"][0]),
    y=model.tokenizer.convert_ids_to_tokens(first_tokens["input_ids"][0]),
    labels=dict(
        x="second sentence tokens",
        y="first sentence tokens",
        color="cosine similarity",
    ),
    title="Cosine similarity of contextual (output) token embeddings",
    text_auto=True,
)
