In [None]:
from __future__ import annotations

import json
from copy import deepcopy
from pathlib import Path

import numpy as np
import pandas as pd
from dataclasses import dataclass
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as f
from torch.utils.data import DataLoader
from torch.optim import AdamW

import evaluate
import faiss
from datasets import Dataset
from transformers import AutoTokenizer, AutoModel

In [None]:
# Load the retrieval dataset. The encoding is pinned to UTF-8 (the JSON default)
# because this notebook works with Russian-language text and a platform-default
# encoding (e.g. cp1251 on Windows) could mis-decode it.
with open('data/retrieval_dataset.json', 'r', encoding='utf-8') as file:
    retrieval_dataset = json.load(file)

In [None]:
# Product catalogue: ';'-separated CSV whose `id` column becomes the index,
# so row labels are product ids (not 0..n-1 positions).
products = pd.read_csv('data/products.csv', sep=';', index_col='id')
# Fixed random_state so Restart & Run All shows the same preview rows every time.
products.sample(n=5, random_state=42)

In [None]:
# Single source of truth for the checkpoint so tokenizer and model cannot drift apart.
MODEL_NAME = 'intfloat/multilingual-e5-large-instruct'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Pad each batch to its longest sequence, truncate at 512 tokens, and return
# PyTorch tensors so the result can be passed straight to model(**batch).
tokenizer_kwargs = dict(max_length=512, padding=True, truncation=True, return_tensors='pt')

# Inference only: make eval mode (dropout off) explicit rather than implicit.
model = AutoModel.from_pretrained(MODEL_NAME).eval()

In [None]:
def average_pool(last_hidden_states, attention_mask):
    """Mean-pool token vectors over the sequence axis, skipping padded positions.

    Args:
        last_hidden_states: Encoder output, presumably (batch, seq_len, hidden).
        attention_mask: Nonzero for real tokens, 0 for padding; presumably
            (batch, seq_len).

    Returns:
        Sum of the unmasked token vectors along dim 1, divided by the number of
        unmasked tokens in each row.
    """
    # Zero out padded positions (masked_fill, so even NaN/inf there is discarded).
    is_padding = (attention_mask == 0).unsqueeze(-1)
    token_sum = last_hidden_states.masked_fill(is_padding, 0.0).sum(dim=1)
    token_count = attention_mask.sum(dim=1).unsqueeze(-1)
    return token_sum / token_count

def get_detailed_instruct(task_description: str, query: str) -> str:
    """Build the instruction-prefixed query text fed to the encoder.

    Produces ``'Instruct: <task_description>\\nQuery: <query>'``.
    """
    return 'Instruct: {}\nQuery: {}'.format(task_description, query)

# Instruction prepended to search queries (item texts are embedded without it).
task = (
    'Given a Russian search query, '
    'retrieve relevant items that satisfy the query'
)

In [None]:
def get_embeddings(texts: list[str], batch_size: int = 32) -> np.ndarray:
    """Encode texts into L2-normalized, mean-pooled embeddings.

    Texts are tokenized with the module-level ``tokenizer``/``tokenizer_kwargs``
    and encoded by the module-level ``model`` in chunks of ``batch_size``. This
    keeps embedding a whole catalogue from building one huge padded batch.
    Padding is masked out by ``average_pool``, so chunking does not change the
    pooled vectors beyond floating-point noise.

    Args:
        texts: Strings to embed. Queries are expected to be wrapped with
            ``get_detailed_instruct`` by the caller; item texts are not.
        batch_size: Number of texts per forward pass (must be >= 1).

    Returns:
        Array with one unit-norm row per input text, in input order.
        Returns an empty ``(0, hidden_size)`` array for empty input.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')
    if len(texts) == 0:
        # The tokenizer rejects an empty batch; return a correctly-shaped empty array.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    chunks = []
    for start in range(0, len(texts), batch_size):
        batch_dict = tokenizer(list(texts[start:start + batch_size]), **tokenizer_kwargs)
        with torch.no_grad():
            outputs = model(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = f.normalize(embeddings, p=2, dim=1)
        # .cpu() so this also works if the model is later moved to a GPU.
        chunks.append(embeddings.cpu().numpy())
    return np.concatenate(chunks, axis=0)

In [None]:
# Text embedded for each product: "<category> <brand> <name>". The list keeps the
# row order of `products`, so embedding row i corresponds to products.iloc[i].
items = (
    products['category'].astype(str)
    + ' ' + products['brand'].astype(str)
    + ' ' + products['name'].astype(str)
).tolist()
item_embeddings = get_embeddings(items)

In [None]:
# Exact inner-product index. get_embeddings L2-normalizes its output, so inner
# product here equals cosine similarity. The dimension is taken from the data
# instead of hard-coding 1024, so it cannot disagree with the encoder's width.
dim = item_embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
# FAISS requires C-contiguous float32 input.
index.add(np.ascontiguousarray(item_embeddings, dtype=np.float32))

In [None]:
def get_top_results(query: str, num_nearest_neighbors=25, with_scores: bool = False):
    """Return the catalogue rows most similar to ``query``.

    The query is wrapped with the module-level ``task`` instruction, embedded,
    and searched against the FAISS ``index``. ``index`` was filled from
    ``item_embeddings`` in the row order of ``products``. The ids FAISS returns
    are therefore row *positions*, not values of the ``id`` index, so rows are
    selected with ``iloc`` (``loc`` would look them up as ids and return the
    wrong products or raise KeyError).

    Args:
        query: Free-text search query.
        num_nearest_neighbors: Maximum number of rows to return; clamped to the
            number of indexed items.
        with_scores: If True, add a ``score`` column with the inner-product
            similarity of each returned row.

    Returns:
        Subset of ``products`` in the order returned by the FAISS search
        (most similar first).
    """
    k = min(num_nearest_neighbors, index.ntotal)
    if k <= 0:
        empty = products.iloc[[]]
        return empty.assign(score=np.empty(0, dtype=np.float32)) if with_scores else empty

    query_embeddings = get_embeddings([get_detailed_instruct(task, query)])
    distances, indices = index.search(query_embeddings, k)
    # FAISS pads with -1 when it cannot fill k slots; never treat that as a row.
    found = indices[0] >= 0
    results = products.iloc[indices[0][found]]
    if with_scores:
        results = results.assign(score=distances[0][found])
    return results

In [None]:
# Example query (Russian): "Looking for Sennheiser headphones up to 20k."
query = 'Ищу наушники Sennheiser до 20к.'
top_results = get_top_results(query)
top_results

In [None]:
# Save the tokenizer alongside the weights so 'models/retrieval' is self-contained
# and can be reloaded with both AutoTokenizer and AutoModel.from_pretrained.
model.save_pretrained('models/retrieval')
tokenizer.save_pretrained('models/retrieval')