In [2]:
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel, AutoConfig
import torch.optim as optim
from mlx_lm import load, generate
from tqdm import tqdm
import chromadb
from chromadb.utils import embedding_functions
import torch
from chromadb import Documents, EmbeddingFunction, Embeddings
from numpy.typing import ArrayLike
from sentence_transformers.util import cos_sim, dot_score
import numpy as np
import torch.nn as nn
from peft import PeftModel
from llm2vec import LLM2Vec

In [3]:
class Embedding(EmbeddingFunction):
    def __init__(self):
        self.embedding_model = SentenceTransformer('thenlper/gte-large')  # ("Salesforce/SFR-Embedding-Mistral") (7b)

    def encode(self, input: Documents) -> Embeddings:
        if isinstance(input, str):
            input = [input]
        embeddings = self.embedding_model.encode(input)
        return embeddings
    
    def __call__(self, input: Documents) -> Embeddings:
        if isinstance(input, str):
            input = [input]
        embeddings = self.embedding_model.encode(input)
        return embeddings

    def distance(self, x: ArrayLike, y: ArrayLike) -> float:
        return cos_sim(x, y)

In [6]:
class PeftEmbeddingModel(EmbeddingFunction):
    def __init__(self):
        config = AutoConfig.from_pretrained("McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp", trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained("McGill-NLP/LLM2Vec-Sheared-LLaMA-mntp")

        # Loading base MNTP model, along with custom code that enables bidirectional connections in decoder-only LLMs
        tokenizer = AutoTokenizer.from_pretrained(
            "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp"
        )
        config = AutoConfig.from_pretrained(
            "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp", trust_remote_code=True
        )
        model = AutoModel.from_pretrained(
            "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
            trust_remote_code=True,
            config=config,
            torch_dtype=torch.float16,
            device_map="mps",
        )
        model = PeftModel.from_pretrained(
            model,
            "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
        )
        model = model.merge_and_unload()  # This can take several minutes on cpu

        # Loading unsupervised-trained LoRA weights. This loads the trained LoRA weights on top of MNTP model. Hence the final weights are -- Base model + MNTP (LoRA) + SimCSE (LoRA).
        model = PeftModel.from_pretrained(
            model, "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse"
        )

        # Or loading supervised-trained LoRA weights
        model = PeftModel.from_pretrained(
            model, "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised"
        )
        self.l2v = LLM2Vec(model, tokenizer, pooling_mode="mean", max_length=512)

    def __call__(self, input: Documents) -> Embeddings:
        if isinstance(input, str):
            input = [input]
        embeddings = self.l2v.encode(input)
        return embeddings
    
    def encode(self, input: Documents) -> Embeddings:
        embeddings = self(input)
        return embeddings

In [7]:
peftEmbeddingModel = PeftEmbeddingModel()

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some weights of the model checkpoint at mistralai/Mistral-7B-Instruct-v0.2 were not used when initializing MistralEncoderModel: ['lm_head.weight']
- This IS expected if you are initializing MistralEncoderModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MistralEncoderModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


TypeError: BFloat16 is not supported on MPS

In [3]:
client = chromadb.PersistentClient(path="./chromadb")
embedding_function = Embedding()

collection = client.get_collection("Python-RAG", embedding_function=embedding_function)

In [4]:
llm_model, tokenizer = load("mlx-community/quantized-gemma-7b-it")

Fetching 8 files:   0%|          | 0/8 [00:00<?, ?it/s]

In [5]:
data = collection.get()
questions = []
for index in tqdm(range(len(data["ids"])), desc="Generating questions"):
    id = data["ids"][index]
    document = data["documents"][index]
    embeddings = embedding_function.encode(document)[0]
    prompt = f"<start_of_turn>user Create a vague short question for this text, without refering to the texts contents directly: {document} <start_of_turn>model Sure, here is the question:"
    question = generate(llm_model, tokenizer, prompt=prompt, verbose=False, max_tokens=1000)
    question = question.replace("\n", "")
    question = question.replace("*", "")
    questions.append((id, question, embeddings))

Generating questions: 100%|██████████| 121/121 [02:40<00:00,  1.33s/it]


In [6]:
import random
random_questions = random.sample(questions, 5)
for question in random_questions:
    print(question[1])

What is the purpose of the copyright notice and permission notice included in the text?
How can pairs of data be sorted by their second element using lambda expressions?
What is the purpose of the default argument values in the function `ask_ok`?
What does the text following the heading "4. More Control Flow Tools Python Tutorial, Release 3.7.0" suggest about the content of the chapter?
What is the purpose of the equal sign (=) in the text?


In [7]:
import torch
import numpy as np
from torch.nn.functional import cosine_similarity

class VectorStore:
    def __init__(self, vectors, ids, device):
        """
        Initializes the store with vectors and their corresponding ids.
        
        :param vectors: A list of numpy arrays.
        :param ids: A list of ids corresponding to the vectors.
        """
        assert len(vectors) == len(ids), "Vectors and IDs must have the same length"
        
        # Convert numpy arrays to PyTorch tensors and store them
        self.vectors = [torch.tensor(vec, dtype=torch.float32, device=device) for vec in vectors]
        print(len(self.vectors))
        self.ids = ids

    def find_closest_id(self, vector):
        """
        Finds the ID of the stored vector closest to the given vector based on cosine similarity.
        
        :param vector: A numpy array representing the vector to compare.
        :return: The id of the closest vector.
        """
        
        # Calculate cosine similarity between the given vector and each stored vector
        similarities = [cosine_similarity(vector.unsqueeze(0), vec.unsqueeze(0)) for vec in self.vectors]
        
        # Find the index of the maximum similarity
        max_index = torch.argmax(torch.tensor(similarities)).item()
        
        # Return the corresponding ID
        return self.ids[max_index], self.vectors[max_index]

In [8]:
vectors = [vector for id, question, vector in questions]
ids = [id for id, question, vector in questions]
vector_store = VectorStore(vectors, ids, 'mps')

121


In [9]:
def batch_to_device(batch, target_device: torch.device):
    """
    send a pytorch batch to a device (CPU/GPU)
    """
    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            batch[key] = batch[key].to(target_device)
    return batch

In [10]:
import random
import copy

model_id = 'thenlper/gte-large'
model = SentenceTransformer(model_id)
optimizer = optim.Adam(model.parameters(), lr=1e-5, eps=1e-5)

model.train()
loss_fn = nn.CosineEmbeddingLoss()

episodes = 200
batch_size = 16
num_batches = 5

for episode in range(episodes):
    predictions = []
    labels = []
    accuracies = []
    ep_reward = 0

    for _ in tqdm(range(num_batches), desc=f"Episode {episode+1}/{episodes}"):
        batch = random.sample(questions, batch_size)
        
        for id, question, vector in batch:
            tokenized_question = model.tokenize([question])
            tokenized_question = batch_to_device(tokenized_question, model.device)
            embedding = model(tokenized_question)['sentence_embedding'][0]
            prediction, label_vector = vector_store.find_closest_id(embedding)
            if prediction == id:
                reward = 1
            else:
                reward = 0
            predictions.append(embedding)
            labels.append(label_vector)
            accuracies.append(reward)
    
    predictions = torch.stack(predictions)
    label_vector = torch.stack(labels)
    optimizer.zero_grad()
    loss = loss_fn(predictions, label_vector, torch.ones(batch_size * num_batches).to(model.device))
    loss.backward()
    optimizer.step()
    
    accuracy = sum(accuracies) / len(accuracies)
    print(f"Episode {episode+1}/{episodes} loss: {loss.item()} accuracy: {accuracy}")


Episode 1/200: 100%|██████████| 5/5 [00:06<00:00,  1.34s/it]


Episode 1/200 loss: 0.09957718849182129 accuracy: 0.8


Episode 2/200: 100%|██████████| 5/5 [00:07<00:00,  1.40s/it]


Episode 2/200 loss: 0.0974433645606041 accuracy: 0.75


Episode 3/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 3/200 loss: 0.085300974547863 accuracy: 0.7625


Episode 4/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 4/200 loss: 0.08502087742090225 accuracy: 0.7


Episode 5/200: 100%|██████████| 5/5 [00:06<00:00,  1.38s/it]


Episode 5/200 loss: 0.08026508241891861 accuracy: 0.8375


Episode 6/200: 100%|██████████| 5/5 [00:06<00:00,  1.35s/it]


Episode 6/200 loss: 0.07770203799009323 accuracy: 0.725


Episode 7/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 7/200 loss: 0.07247059047222137 accuracy: 0.8


Episode 8/200: 100%|██████████| 5/5 [00:06<00:00,  1.35s/it]


Episode 8/200 loss: 0.07462989538908005 accuracy: 0.675


Episode 9/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 9/200 loss: 0.06969386339187622 accuracy: 0.75


Episode 10/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 10/200 loss: 0.06905487179756165 accuracy: 0.7


Episode 11/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 11/200 loss: 0.0710454210639 accuracy: 0.6375


Episode 12/200: 100%|██████████| 5/5 [00:06<00:00,  1.35s/it]


Episode 12/200 loss: 0.06408755481243134 accuracy: 0.7375


Episode 13/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 13/200 loss: 0.06559502333402634 accuracy: 0.7625


Episode 14/200: 100%|██████████| 5/5 [00:06<00:00,  1.34s/it]


Episode 14/200 loss: 0.06311161071062088 accuracy: 0.7125


Episode 15/200: 100%|██████████| 5/5 [00:06<00:00,  1.34s/it]


Episode 15/200 loss: 0.06161809712648392 accuracy: 0.6625


Episode 16/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 16/200 loss: 0.05626188591122627 accuracy: 0.7875


Episode 17/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 17/200 loss: 0.06164275482296944 accuracy: 0.7125


Episode 18/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 18/200 loss: 0.056734777987003326 accuracy: 0.775


Episode 19/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 19/200 loss: 0.0583144836127758 accuracy: 0.7625


Episode 20/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 20/200 loss: 0.057631779462099075 accuracy: 0.6625


Episode 21/200: 100%|██████████| 5/5 [00:06<00:00,  1.39s/it]


Episode 21/200 loss: 0.05656299740076065 accuracy: 0.675


Episode 22/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 22/200 loss: 0.054244816303253174 accuracy: 0.6875


Episode 23/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 23/200 loss: 0.054403699934482574 accuracy: 0.7125


Episode 24/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 24/200 loss: 0.05478975921869278 accuracy: 0.75


Episode 25/200: 100%|██████████| 5/5 [00:06<00:00,  1.38s/it]


Episode 25/200 loss: 0.05591518431901932 accuracy: 0.6


Episode 26/200: 100%|██████████| 5/5 [00:07<00:00,  1.43s/it]


Episode 26/200 loss: 0.05302702262997627 accuracy: 0.6875


Episode 27/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 27/200 loss: 0.05021423101425171 accuracy: 0.7125


Episode 28/200: 100%|██████████| 5/5 [00:07<00:00,  1.44s/it]


Episode 28/200 loss: 0.05388776212930679 accuracy: 0.675


Episode 29/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 29/200 loss: 0.04979405179619789 accuracy: 0.7


Episode 30/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 30/200 loss: 0.051896750926971436 accuracy: 0.6625


Episode 31/200: 100%|██████████| 5/5 [00:06<00:00,  1.40s/it]


Episode 31/200 loss: 0.05068845674395561 accuracy: 0.7


Episode 32/200: 100%|██████████| 5/5 [00:07<00:00,  1.40s/it]


Episode 32/200 loss: 0.051299966871738434 accuracy: 0.725


Episode 33/200: 100%|██████████| 5/5 [00:07<00:00,  1.42s/it]


Episode 33/200 loss: 0.05208313465118408 accuracy: 0.75


Episode 34/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 34/200 loss: 0.05201618745923042 accuracy: 0.7625


Episode 35/200: 100%|██████████| 5/5 [00:07<00:00,  1.49s/it]


Episode 35/200 loss: 0.050849270075559616 accuracy: 0.6625


Episode 36/200: 100%|██████████| 5/5 [00:07<00:00,  1.42s/it]


Episode 36/200 loss: 0.05148733779788017 accuracy: 0.675


Episode 37/200: 100%|██████████| 5/5 [00:07<00:00,  1.40s/it]


Episode 37/200 loss: 0.04944610595703125 accuracy: 0.675


Episode 38/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 38/200 loss: 0.050332825630903244 accuracy: 0.6375


Episode 39/200: 100%|██████████| 5/5 [00:07<00:00,  1.48s/it]


Episode 39/200 loss: 0.04777860268950462 accuracy: 0.725


Episode 40/200: 100%|██████████| 5/5 [00:07<00:00,  1.40s/it]


Episode 40/200 loss: 0.047355517745018005 accuracy: 0.7125


Episode 41/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 41/200 loss: 0.046889133751392365 accuracy: 0.75


Episode 42/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 42/200 loss: 0.047952938824892044 accuracy: 0.6375


Episode 43/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 43/200 loss: 0.04736020416021347 accuracy: 0.8


Episode 44/200: 100%|██████████| 5/5 [00:06<00:00,  1.39s/it]


Episode 44/200 loss: 0.050170499831438065 accuracy: 0.6875


Episode 45/200: 100%|██████████| 5/5 [00:06<00:00,  1.35s/it]


Episode 45/200 loss: 0.046396058052778244 accuracy: 0.675


Episode 46/200: 100%|██████████| 5/5 [00:06<00:00,  1.39s/it]


Episode 46/200 loss: 0.046988438814878464 accuracy: 0.625


Episode 47/200: 100%|██████████| 5/5 [00:06<00:00,  1.38s/it]


Episode 47/200 loss: 0.04617663100361824 accuracy: 0.7375


Episode 48/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 48/200 loss: 0.045595597475767136 accuracy: 0.7


Episode 49/200: 100%|██████████| 5/5 [00:07<00:00,  1.44s/it]


Episode 49/200 loss: 0.04582246020436287 accuracy: 0.65


Episode 50/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 50/200 loss: 0.046335794031620026 accuracy: 0.7625


Episode 51/200: 100%|██████████| 5/5 [00:06<00:00,  1.39s/it]


Episode 51/200 loss: 0.04654260724782944 accuracy: 0.6375


Episode 52/200: 100%|██████████| 5/5 [00:06<00:00,  1.38s/it]


Episode 52/200 loss: 0.04753196984529495 accuracy: 0.725


Episode 53/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 53/200 loss: 0.0443430133163929 accuracy: 0.75


Episode 54/200: 100%|██████████| 5/5 [00:07<00:00,  1.44s/it]


Episode 54/200 loss: 0.044266264885663986 accuracy: 0.6875


Episode 55/200: 100%|██████████| 5/5 [00:07<00:00,  1.48s/it]


Episode 55/200 loss: 0.044693656265735626 accuracy: 0.7125


Episode 56/200: 100%|██████████| 5/5 [00:07<00:00,  1.44s/it]


Episode 56/200 loss: 0.04474736377596855 accuracy: 0.75


Episode 57/200: 100%|██████████| 5/5 [00:07<00:00,  1.44s/it]


Episode 57/200 loss: 0.045365504920482635 accuracy: 0.725


Episode 58/200: 100%|██████████| 5/5 [00:06<00:00,  1.40s/it]


Episode 58/200 loss: 0.04406077414751053 accuracy: 0.725


Episode 59/200: 100%|██████████| 5/5 [00:06<00:00,  1.39s/it]


Episode 59/200 loss: 0.04267193377017975 accuracy: 0.6875


Episode 60/200: 100%|██████████| 5/5 [00:06<00:00,  1.38s/it]


Episode 60/200 loss: 0.043680980801582336 accuracy: 0.6875


Episode 61/200: 100%|██████████| 5/5 [00:07<00:00,  1.42s/it]


Episode 61/200 loss: 0.044491805136203766 accuracy: 0.7875


Episode 62/200: 100%|██████████| 5/5 [00:07<00:00,  1.50s/it]


Episode 62/200 loss: 0.04269612208008766 accuracy: 0.6125


Episode 63/200: 100%|██████████| 5/5 [00:07<00:00,  1.49s/it]


Episode 63/200 loss: 0.042902376502752304 accuracy: 0.675


Episode 64/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 64/200 loss: 0.0428871251642704 accuracy: 0.6875


Episode 65/200: 100%|██████████| 5/5 [00:07<00:00,  1.49s/it]


Episode 65/200 loss: 0.04370526596903801 accuracy: 0.7


Episode 66/200: 100%|██████████| 5/5 [00:07<00:00,  1.40s/it]


Episode 66/200 loss: 0.04363435506820679 accuracy: 0.6625


Episode 67/200: 100%|██████████| 5/5 [00:06<00:00,  1.40s/it]


Episode 67/200 loss: 0.04153662174940109 accuracy: 0.7625


Episode 68/200: 100%|██████████| 5/5 [00:06<00:00,  1.38s/it]


Episode 68/200 loss: 0.04231612756848335 accuracy: 0.7125


Episode 69/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 69/200 loss: 0.04020100459456444 accuracy: 0.6875


Episode 70/200: 100%|██████████| 5/5 [00:06<00:00,  1.40s/it]


Episode 70/200 loss: 0.04011337086558342 accuracy: 0.6875


Episode 71/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 71/200 loss: 0.04111240431666374 accuracy: 0.625


Episode 72/200: 100%|██████████| 5/5 [00:06<00:00,  1.39s/it]


Episode 72/200 loss: 0.04221606254577637 accuracy: 0.6


Episode 73/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 73/200 loss: 0.03951388597488403 accuracy: 0.7375


Episode 74/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 74/200 loss: 0.039925530552864075 accuracy: 0.7125


Episode 75/200: 100%|██████████| 5/5 [00:07<00:00,  1.43s/it]


Episode 75/200 loss: 0.04385251924395561 accuracy: 0.8125


Episode 76/200: 100%|██████████| 5/5 [00:07<00:00,  1.42s/it]


Episode 76/200 loss: 0.04098017141222954 accuracy: 0.675


Episode 77/200: 100%|██████████| 5/5 [00:07<00:00,  1.49s/it]


Episode 77/200 loss: 0.040254175662994385 accuracy: 0.6875


Episode 78/200: 100%|██████████| 5/5 [00:07<00:00,  1.44s/it]


Episode 78/200 loss: 0.039329007267951965 accuracy: 0.7


Episode 79/200: 100%|██████████| 5/5 [00:07<00:00,  1.46s/it]


Episode 79/200 loss: 0.04002794250845909 accuracy: 0.7625


Episode 80/200: 100%|██████████| 5/5 [00:07<00:00,  1.46s/it]


Episode 80/200 loss: 0.039984315633773804 accuracy: 0.725


Episode 81/200: 100%|██████████| 5/5 [00:07<00:00,  1.53s/it]


Episode 81/200 loss: 0.04009897634387016 accuracy: 0.7625


Episode 82/200: 100%|██████████| 5/5 [00:07<00:00,  1.42s/it]


Episode 82/200 loss: 0.038481444120407104 accuracy: 0.6375


Episode 83/200: 100%|██████████| 5/5 [00:07<00:00,  1.43s/it]


Episode 83/200 loss: 0.0404372364282608 accuracy: 0.7625


Episode 84/200: 100%|██████████| 5/5 [00:07<00:00,  1.46s/it]


Episode 84/200 loss: 0.038621701300144196 accuracy: 0.675


Episode 85/200: 100%|██████████| 5/5 [00:07<00:00,  1.44s/it]


Episode 85/200 loss: 0.040503401309251785 accuracy: 0.675


Episode 86/200: 100%|██████████| 5/5 [00:07<00:00,  1.42s/it]


Episode 86/200 loss: 0.03817538544535637 accuracy: 0.75


Episode 87/200: 100%|██████████| 5/5 [00:07<00:00,  1.46s/it]


Episode 87/200 loss: 0.03840991482138634 accuracy: 0.6375


Episode 88/200: 100%|██████████| 5/5 [00:07<00:00,  1.50s/it]


Episode 88/200 loss: 0.03888457641005516 accuracy: 0.6625


Episode 89/200: 100%|██████████| 5/5 [00:07<00:00,  1.47s/it]


Episode 89/200 loss: 0.03823372721672058 accuracy: 0.7


Episode 90/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 90/200 loss: 0.039797160774469376 accuracy: 0.7


Episode 91/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 91/200 loss: 0.037610121071338654 accuracy: 0.725


Episode 92/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 92/200 loss: 0.0358041375875473 accuracy: 0.675


Episode 93/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 93/200 loss: 0.038658563047647476 accuracy: 0.7


Episode 94/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 94/200 loss: 0.03772427514195442 accuracy: 0.7125


Episode 95/200: 100%|██████████| 5/5 [00:07<00:00,  1.42s/it]


Episode 95/200 loss: 0.03697945550084114 accuracy: 0.7875


Episode 96/200: 100%|██████████| 5/5 [00:06<00:00,  1.39s/it]


Episode 96/200 loss: 0.03619007021188736 accuracy: 0.7375


Episode 97/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 97/200 loss: 0.03736802190542221 accuracy: 0.6875


Episode 98/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 98/200 loss: 0.037759896367788315 accuracy: 0.675


Episode 99/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 99/200 loss: 0.03714445233345032 accuracy: 0.6375


Episode 100/200: 100%|██████████| 5/5 [00:06<00:00,  1.38s/it]


Episode 100/200 loss: 0.03525217995047569 accuracy: 0.6625


Episode 101/200: 100%|██████████| 5/5 [00:06<00:00,  1.39s/it]


Episode 101/200 loss: 0.036549728363752365 accuracy: 0.675


Episode 102/200: 100%|██████████| 5/5 [00:07<00:00,  1.45s/it]


Episode 102/200 loss: 0.03540791943669319 accuracy: 0.7125


Episode 103/200: 100%|██████████| 5/5 [00:06<00:00,  1.38s/it]


Episode 103/200 loss: 0.0358341746032238 accuracy: 0.7


Episode 104/200: 100%|██████████| 5/5 [00:06<00:00,  1.35s/it]


Episode 104/200 loss: 0.03604193404316902 accuracy: 0.775


Episode 105/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 105/200 loss: 0.0364571213722229 accuracy: 0.75


Episode 106/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 106/200 loss: 0.03646061569452286 accuracy: 0.7375


Episode 107/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 107/200 loss: 0.03438546881079674 accuracy: 0.6625


Episode 108/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 108/200 loss: 0.03565068170428276 accuracy: 0.7625


Episode 109/200: 100%|██████████| 5/5 [00:07<00:00,  1.43s/it]


Episode 109/200 loss: 0.03379211947321892 accuracy: 0.6875


Episode 110/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 110/200 loss: 0.03331632539629936 accuracy: 0.775


Episode 111/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 111/200 loss: 0.03424717113375664 accuracy: 0.7


Episode 112/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 112/200 loss: 0.033522870391607285 accuracy: 0.725


Episode 113/200: 100%|██████████| 5/5 [00:06<00:00,  1.40s/it]


Episode 113/200 loss: 0.03480082377791405 accuracy: 0.7625


Episode 114/200: 100%|██████████| 5/5 [00:06<00:00,  1.40s/it]


Episode 114/200 loss: 0.032668907195329666 accuracy: 0.6375


Episode 115/200: 100%|██████████| 5/5 [00:06<00:00,  1.40s/it]


Episode 115/200 loss: 0.03390749171376228 accuracy: 0.6625


Episode 116/200: 100%|██████████| 5/5 [00:06<00:00,  1.39s/it]


Episode 116/200 loss: 0.03391314297914505 accuracy: 0.675


Episode 117/200: 100%|██████████| 5/5 [00:07<00:00,  1.43s/it]


Episode 117/200 loss: 0.03511761128902435 accuracy: 0.675


Episode 118/200: 100%|██████████| 5/5 [00:07<00:00,  1.40s/it]


Episode 118/200 loss: 0.03323497995734215 accuracy: 0.6875


Episode 119/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 119/200 loss: 0.034874945878982544 accuracy: 0.7375


Episode 120/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 120/200 loss: 0.03436115384101868 accuracy: 0.7125


Episode 121/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 121/200 loss: 0.033943288028240204 accuracy: 0.75


Episode 122/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 122/200 loss: 0.03473625332117081 accuracy: 0.65


Episode 123/200: 100%|██████████| 5/5 [00:06<00:00,  1.38s/it]


Episode 123/200 loss: 0.031384605914354324 accuracy: 0.7125


Episode 124/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 124/200 loss: 0.031777460128068924 accuracy: 0.675


Episode 125/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 125/200 loss: 0.03269451856613159 accuracy: 0.725


Episode 126/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 126/200 loss: 0.03198665753006935 accuracy: 0.7


Episode 127/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 127/200 loss: 0.032449182122945786 accuracy: 0.7


Episode 128/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 128/200 loss: 0.03297121077775955 accuracy: 0.6625


Episode 129/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 129/200 loss: 0.0331985242664814 accuracy: 0.725


Episode 130/200: 100%|██████████| 5/5 [00:07<00:00,  1.41s/it]


Episode 130/200 loss: 0.02998625673353672 accuracy: 0.5875


Episode 131/200: 100%|██████████| 5/5 [00:07<00:00,  1.40s/it]


Episode 131/200 loss: 0.03163003548979759 accuracy: 0.7125


Episode 132/200: 100%|██████████| 5/5 [00:06<00:00,  1.34s/it]


Episode 132/200 loss: 0.03289437294006348 accuracy: 0.7375


Episode 133/200: 100%|██████████| 5/5 [00:06<00:00,  1.39s/it]


Episode 133/200 loss: 0.03095608949661255 accuracy: 0.65


Episode 134/200: 100%|██████████| 5/5 [00:06<00:00,  1.35s/it]


Episode 134/200 loss: 0.03147038817405701 accuracy: 0.6875


Episode 135/200: 100%|██████████| 5/5 [00:06<00:00,  1.34s/it]


Episode 135/200 loss: 0.03153738006949425 accuracy: 0.7


Episode 136/200: 100%|██████████| 5/5 [00:06<00:00,  1.39s/it]


Episode 136/200 loss: 0.032868217676877975 accuracy: 0.7125


Episode 137/200: 100%|██████████| 5/5 [00:06<00:00,  1.39s/it]


Episode 137/200 loss: 0.032066792249679565 accuracy: 0.7375


Episode 138/200: 100%|██████████| 5/5 [00:07<00:00,  1.43s/it]


Episode 138/200 loss: 0.03302815183997154 accuracy: 0.6875


Episode 139/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 139/200 loss: 0.03155415877699852 accuracy: 0.75


Episode 140/200: 100%|██████████| 5/5 [00:06<00:00,  1.40s/it]


Episode 140/200 loss: 0.030845988541841507 accuracy: 0.7375


Episode 141/200: 100%|██████████| 5/5 [00:07<00:00,  1.43s/it]


Episode 141/200 loss: 0.0323631577193737 accuracy: 0.675


Episode 142/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 142/200 loss: 0.032054755836725235 accuracy: 0.7125


Episode 143/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 143/200 loss: 0.031612418591976166 accuracy: 0.6875


Episode 144/200: 100%|██████████| 5/5 [00:07<00:00,  1.42s/it]


Episode 144/200 loss: 0.03157773241400719 accuracy: 0.7125


Episode 145/200: 100%|██████████| 5/5 [00:06<00:00,  1.38s/it]


Episode 145/200 loss: 0.030729616060853004 accuracy: 0.725


Episode 146/200: 100%|██████████| 5/5 [00:06<00:00,  1.38s/it]


Episode 146/200 loss: 0.029063845053315163 accuracy: 0.675


Episode 147/200: 100%|██████████| 5/5 [00:07<00:00,  1.47s/it]


Episode 147/200 loss: 0.030055657029151917 accuracy: 0.7875


Episode 148/200: 100%|██████████| 5/5 [00:07<00:00,  1.44s/it]


Episode 148/200 loss: 0.03234172984957695 accuracy: 0.8


Episode 149/200: 100%|██████████| 5/5 [00:07<00:00,  1.44s/it]


Episode 149/200 loss: 0.028746793046593666 accuracy: 0.65


Episode 150/200: 100%|██████████| 5/5 [00:07<00:00,  1.44s/it]


Episode 150/200 loss: 0.029864205047488213 accuracy: 0.6375


Episode 151/200: 100%|██████████| 5/5 [00:07<00:00,  1.43s/it]


Episode 151/200 loss: 0.03226158022880554 accuracy: 0.7625


Episode 152/200: 100%|██████████| 5/5 [00:07<00:00,  1.44s/it]


Episode 152/200 loss: 0.030806338414549828 accuracy: 0.675


Episode 153/200: 100%|██████████| 5/5 [00:06<00:00,  1.39s/it]


Episode 153/200 loss: 0.030379969626665115 accuracy: 0.725


Episode 154/200: 100%|██████████| 5/5 [00:06<00:00,  1.34s/it]


Episode 154/200 loss: 0.030770665034651756 accuracy: 0.8125


Episode 155/200: 100%|██████████| 5/5 [00:06<00:00,  1.35s/it]


Episode 155/200 loss: 0.030031254515051842 accuracy: 0.75


Episode 156/200: 100%|██████████| 5/5 [00:06<00:00,  1.35s/it]


Episode 156/200 loss: 0.02851925790309906 accuracy: 0.75


Episode 157/200: 100%|██████████| 5/5 [00:06<00:00,  1.35s/it]


Episode 157/200 loss: 0.030833525583148003 accuracy: 0.675


Episode 158/200: 100%|██████████| 5/5 [00:06<00:00,  1.33s/it]


Episode 158/200 loss: 0.028176456689834595 accuracy: 0.7


Episode 159/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 159/200 loss: 0.02897060476243496 accuracy: 0.7


Episode 160/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 160/200 loss: 0.02871864102780819 accuracy: 0.65


Episode 161/200: 100%|██████████| 5/5 [00:06<00:00,  1.35s/it]


Episode 161/200 loss: 0.02854323945939541 accuracy: 0.7375


Episode 162/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 162/200 loss: 0.030000198632478714 accuracy: 0.6875


Episode 163/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 163/200 loss: 0.029101744294166565 accuracy: 0.6875


Episode 164/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 164/200 loss: 0.027536148205399513 accuracy: 0.65


Episode 165/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 165/200 loss: 0.029189346358180046 accuracy: 0.75


Episode 166/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 166/200 loss: 0.02875472418963909 accuracy: 0.775


Episode 167/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 167/200 loss: 0.02805282175540924 accuracy: 0.7375


Episode 168/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 168/200 loss: 0.027757743373513222 accuracy: 0.6875


Episode 169/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 169/200 loss: 0.029542630538344383 accuracy: 0.7


Episode 170/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 170/200 loss: 0.027959490194916725 accuracy: 0.6625


Episode 171/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 171/200 loss: 0.030472805723547935 accuracy: 0.725


Episode 172/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 172/200 loss: 0.0275118388235569 accuracy: 0.7


Episode 173/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 173/200 loss: 0.0295927282422781 accuracy: 0.7375


Episode 174/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 174/200 loss: 0.02825266122817993 accuracy: 0.8125


Episode 175/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 175/200 loss: 0.028327900916337967 accuracy: 0.625


Episode 176/200: 100%|██████████| 5/5 [00:06<00:00,  1.38s/it]


Episode 176/200 loss: 0.02727987803518772 accuracy: 0.65


Episode 177/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 177/200 loss: 0.02726253867149353 accuracy: 0.775


Episode 178/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 178/200 loss: 0.028024299070239067 accuracy: 0.7375


Episode 179/200: 100%|██████████| 5/5 [00:06<00:00,  1.35s/it]


Episode 179/200 loss: 0.02565539814531803 accuracy: 0.7


Episode 180/200: 100%|██████████| 5/5 [00:06<00:00,  1.36s/it]


Episode 180/200 loss: 0.0273252185434103 accuracy: 0.725


Episode 181/200: 100%|██████████| 5/5 [00:06<00:00,  1.33s/it]


Episode 181/200 loss: 0.026982281357049942 accuracy: 0.75


Episode 182/200: 100%|██████████| 5/5 [00:06<00:00,  1.28s/it]


Episode 182/200 loss: 0.027219945564866066 accuracy: 0.75


Episode 183/200: 100%|██████████| 5/5 [00:06<00:00,  1.30s/it]


Episode 183/200 loss: 0.02712741494178772 accuracy: 0.6375


Episode 184/200: 100%|██████████| 5/5 [00:06<00:00,  1.29s/it]


Episode 184/200 loss: 0.028095562011003494 accuracy: 0.7125


Episode 185/200: 100%|██████████| 5/5 [00:06<00:00,  1.28s/it]


Episode 185/200 loss: 0.02668987773358822 accuracy: 0.6625


Episode 186/200: 100%|██████████| 5/5 [00:06<00:00,  1.31s/it]


Episode 186/200 loss: 0.026452992111444473 accuracy: 0.6625


Episode 187/200: 100%|██████████| 5/5 [00:06<00:00,  1.29s/it]


Episode 187/200 loss: 0.026751196011900902 accuracy: 0.6875


Episode 188/200: 100%|██████████| 5/5 [00:06<00:00,  1.33s/it]


Episode 188/200 loss: 0.027423067018389702 accuracy: 0.7


Episode 189/200: 100%|██████████| 5/5 [00:06<00:00,  1.37s/it]


Episode 189/200 loss: 0.027180999517440796 accuracy: 0.65


Episode 190/200: 100%|██████████| 5/5 [00:06<00:00,  1.22s/it]


Episode 190/200 loss: 0.02710375003516674 accuracy: 0.675


Episode 191/200: 100%|██████████| 5/5 [00:06<00:00,  1.30s/it]


Episode 191/200 loss: 0.025478381663560867 accuracy: 0.6625


Episode 192/200: 100%|██████████| 5/5 [00:06<00:00,  1.35s/it]


Episode 192/200 loss: 0.02512943185865879 accuracy: 0.6875


Episode 193/200: 100%|██████████| 5/5 [00:06<00:00,  1.32s/it]


Episode 193/200 loss: 0.02677147462964058 accuracy: 0.6625


Episode 194/200: 100%|██████████| 5/5 [00:06<00:00,  1.31s/it]


Episode 194/200 loss: 0.027180958539247513 accuracy: 0.6625


Episode 195/200: 100%|██████████| 5/5 [00:06<00:00,  1.30s/it]


Episode 195/200 loss: 0.025325072929263115 accuracy: 0.7


Episode 196/200: 100%|██████████| 5/5 [00:06<00:00,  1.29s/it]


Episode 196/200 loss: 0.026265442371368408 accuracy: 0.6375


Episode 197/200: 100%|██████████| 5/5 [00:06<00:00,  1.30s/it]


Episode 197/200 loss: 0.025554409250617027 accuracy: 0.725


Episode 198/200: 100%|██████████| 5/5 [00:06<00:00,  1.30s/it]


Episode 198/200 loss: 0.026861492544412613 accuracy: 0.725


Episode 199/200: 100%|██████████| 5/5 [00:06<00:00,  1.29s/it]


Episode 199/200 loss: 0.025587117299437523 accuracy: 0.6625


Episode 200/200: 100%|██████████| 5/5 [00:06<00:00,  1.31s/it]


Episode 200/200 loss: 0.0270228274166584 accuracy: 0.7


In [13]:
accuracies = {}
lengths = [1, 2, 3, 4, 5]
for n_results in lengths:
    datapoints = []
    for id, question, vector in questions:
        results = collection.query(query_texts=question, n_results=n_results)
        result_ids = results["ids"][0]
        if id in result_ids:
            datapoints.append(1)
        else:
            datapoints.append(0)

    accuracies[n_results] = sum(datapoints) / len(datapoints)
accuracies

ValueError: Expected embeddings to be a list, got [[-0.01044687 -0.02135091  0.01859619 ...  0.01973298 -0.00316574
  -0.01486639]]