# Matryoshka Embedding Models:

The purpose of this notebook is to put together a basic embeddings class that initializes
an embedding model trained with Matryoshka Representation Learning (MRL). The Hugging Face documentation uses the model:

[__mpnet-base-nli-matryoshka__](https://huggingface.co/tomaarsen/mpnet-base-nli-matryoshka). 

In our example, we use: 

[__nomic-ai/nomic-embed-text-v1.5__](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5).

__References:__

* [Matryoshka at Hugging Face](https://huggingface.co/blog/matryoshka)

* [Matryoshka Representation Learning Paper](https://arxiv.org/abs/2205.13147) - arXiv Paper

* [GitHub source of the Hugging Face Matryoshka blog post](https://github.com/huggingface/blog/blob/main/matryoshka.md)

In [4]:
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
import torch.nn.functional as F

# Define the class which can be instantiated and reused
class MatryoshkaEmbeddingModel:
    """Wrapper around a SentenceTransformer model that can return truncated embeddings.

    The full embedding is layer-normalized over its feature dimension. When
    shrinking is requested, only the leading ``matryoshka_dim`` components are
    kept and the result is L2-normalized.

    NOTE(review): ``input_text`` is passed to the model unchanged. The
    nomic-embed-text-v1.5 model card asks for a task prefix (for example
    ``"search_document: "``) on every input; callers are presumably expected to
    prepend it themselves - confirm against the model card.
    """

    def __init__(self, model_name='nomic-ai/nomic-embed-text-v1.5', matryoshka_dim=64):
        """Load the embedding model.

        Args:
            model_name: Model identifier handed to ``SentenceTransformer``.
                Defaults to ``nomic-ai/nomic-embed-text-v1.5``.
            matryoshka_dim: Default number of leading components kept when
                ``get_embeddings`` is called with ``shrink=True``.

        Raises:
            ValueError: If ``matryoshka_dim`` is not a positive integer.
        """
        # Validate first so a bad dimension fails before the model is loaded.
        self.matryoshka_dim = self._check_dim(matryoshka_dim)
        # trust_remote_code=True allows code shipped with the model repository to
        # run locally; only use it with model names you trust.
        self.model = SentenceTransformer(model_name, trust_remote_code=True)

    @staticmethod
    def _check_dim(dim):
        """Return ``dim`` as an int if it is a positive whole number, else raise ValueError."""
        try:
            as_int = int(dim)
        except (TypeError, ValueError):
            raise ValueError(f"matryoshka_dim must be a positive integer, got {dim!r}")
        if as_int != dim or as_int <= 0:
            raise ValueError(f"matryoshka_dim must be a positive integer, got {dim!r}")
        return as_int

    def get_embeddings(self, input_text, shrink=False, matryoshka_dim=None):
        """Embed a single text and return its embedding as a 1-D tensor.

        Args:
            input_text: The text to embed; it is wrapped in a one-element batch.
            shrink: If True, keep only the leading components and L2-normalize
                them. If False, return the full layer-normalized embedding.
            matryoshka_dim: Optional per-call override of the instance's
                ``matryoshka_dim``. Ignored when ``shrink`` is False.

        Returns:
            A 1-D tensor: the full layer-normalized embedding, or the truncated
            and L2-normalized embedding when ``shrink`` is True.

        Raises:
            ValueError: If shrinking is requested with a dimension that is not a
                positive integer or that exceeds the model's embedding size.
        """
        embeddings = self.model.encode([input_text], convert_to_tensor=True)
        full_dim = embeddings.shape[1]
        # Layer norm over the feature axis, with no learnable scale or bias.
        embeddings = F.layer_norm(embeddings, normalized_shape=(full_dim,))
        if shrink:
            target_dim = self._check_dim(
                self.matryoshka_dim if matryoshka_dim is None else matryoshka_dim
            )
            # Slicing past the end would silently return the full vector, so
            # reject a request the model cannot satisfy.
            if target_dim > full_dim:
                raise ValueError(
                    f"matryoshka_dim ({target_dim}) exceeds the model's embedding size ({full_dim})"
                )
            embeddings = embeddings[:, :target_dim]
            # Re-normalize so the truncated vector has unit L2 length.
            embeddings = F.normalize(embeddings, p=2, dim=1)
        # Drop the batch axis added above.
        return embeddings[0]


In [6]:
# Sample sentence and default model; the following cell reuses `text` and `model`.
text = "Attention Mechanism is the secret sauce behind the success of transformers"
model = MatryoshkaEmbeddingModel()

# Full-size embedding: shrink is left off, so no truncation is applied.
embeddings = model.get_embeddings(text, shrink=False)
print(embeddings.shape)
print(embeddings)


<All keys matched successfully>


torch.Size([768])
tensor([-1.4977e+00,  1.6466e+00, -5.1995e+00,  1.8731e-01,  2.1725e+00,
         2.4152e-01,  1.0706e+00,  8.5216e-01,  6.6628e-01, -6.2679e-01,
        -2.3779e-01,  7.5693e-01,  2.3085e+00,  5.1368e-01,  2.1475e+00,
        -1.4171e+00,  4.0186e-01, -1.4060e+00, -8.8940e-03,  7.5641e-01,
        -5.0104e-01, -7.0278e-01, -4.1476e-02,  1.3763e+00,  7.7925e-01,
         8.0501e-01, -1.5455e+00, -8.6730e-01,  1.3228e+00,  5.7217e-01,
         1.1059e+00, -1.2174e+00, -8.6582e-01, -1.1297e+00, -1.8290e+00,
        -1.7643e+00,  1.2090e+00,  1.0123e+00, -1.0847e+00,  1.6667e+00,
        -9.7300e-01,  8.6193e-01, -6.9288e-01, -5.1203e-01,  1.9160e+00,
        -5.9385e-02, -1.1146e+00, -6.0331e-01,  4.9704e-01, -2.0574e+00,
         1.2387e+00, -3.8294e-01,  6.5614e-01, -1.3783e+00,  1.2161e+00,
         7.2025e-02,  9.1584e-01, -7.5089e-01,  8.0505e-01,  2.9784e-02,
         1.1744e+00,  2.5774e+00,  1.2271e+00,  1.0101e+00,  4.2204e-01,
        -9.0785e-01, -2.5192e-01,

In [7]:
# Truncated embedding: shrink=True keeps the leading `matryoshka_dim` components
# and L2-normalizes them.
embeddings = model.get_embeddings(input_text=text, shrink=True)
print(embeddings.shape)
print(embeddings)


torch.Size([64])
tensor([-0.1415,  0.1555, -0.4911,  0.0177,  0.2052,  0.0228,  0.1011,  0.0805,
         0.0629, -0.0592, -0.0225,  0.0715,  0.2180,  0.0485,  0.2028, -0.1338,
         0.0380, -0.1328, -0.0008,  0.0714, -0.0473, -0.0664, -0.0039,  0.1300,
         0.0736,  0.0760, -0.1460, -0.0819,  0.1249,  0.0540,  0.1045, -0.1150,
        -0.0818, -0.1067, -0.1727, -0.1666,  0.1142,  0.0956, -0.1025,  0.1574,
        -0.0919,  0.0814, -0.0654, -0.0484,  0.1810, -0.0056, -0.1053, -0.0570,
         0.0469, -0.1943,  0.1170, -0.0362,  0.0620, -0.1302,  0.1149,  0.0068,
         0.0865, -0.0709,  0.0760,  0.0028,  0.1109,  0.2434,  0.1159,  0.0954],
       device='mps:0')
