### Обучение E5 и валидация! 

In [2]:
# !pip install transformers datasets

In [3]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import pytorch_lightning as pl
from torch.nn import TripletMarginLoss
from torch.nn.functional import normalize
from sklearn.model_selection import train_test_split
import numpy as np
import faiss
from tqdm import tqdm
from datetime import datetime
import os
import random


# ——— 1. Токенизация и подготовка данных ———
# def dummy_tokenize(text: str):
#     return text.lower()

class TripletDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=64, num_negatives=10):
        self.samples = []
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_negatives = num_negatives
        self.all_products = df['product_title'].unique()
        self._build_triplets(df)

    def _build_triplets(self, df):
        n = len(df)
        for i in range(self.num_negatives):
            negs = random.choices(self.all_products, k=n)
            for idx, row in enumerate(df.itertuples(index=False)):
                query = row.query
                pos = row.product_title
                neg = negs[idx]
                self.samples.append((query, pos, neg))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        q, pos, neg = self.samples[idx]

        # Токенизация сразу для модели
        anchor_enc = self.tokenizer(
            f"query: {q}", padding='max_length', truncation=True,
            max_length=self.max_length, return_tensors='pt'
        )
        pos_enc = self.tokenizer(
            f"passage: {pos}", padding='max_length', truncation=True,
            max_length=self.max_length, return_tensors='pt'
        )
        neg_enc = self.tokenizer(
            f"passage: {neg}", padding='max_length', truncation=True,
            max_length=self.max_length, return_tensors='pt'
        )

        # Возвращаем только input_ids и attention_mask, как нужно для forward
        return (
            (anchor_enc['input_ids'].squeeze(0), anchor_enc['attention_mask'].squeeze(0)),
            (pos_enc['input_ids'].squeeze(0), pos_enc['attention_mask'].squeeze(0)),
            (neg_enc['input_ids'].squeeze(0), neg_enc['attention_mask'].squeeze(0)),

            f"query: {q}", # для дебага
            f"passage: {pos}",
            f"passage: {neg}"

        )

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from datasets import load_dataset
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import pytorch_lightning as pl
from datetime import datetime

# Предполагается, что TripletDataset и E5Model уже определены ранее

# ——— 1. Подготовка данных ———
def prepare_data(model_name='intfloat/e5-small', batch_size=16, sample_rate=1.0):
    print("🔽 Загружаем датасет tasksource/esci...")
    dataset = load_dataset("tasksource/esci", split="train")
    dataset_len = len(dataset)

    # семплим чтобы быстрее отдебажить!
    dataset_current_len = int(dataset_len * sample_rate)
    dataset = dataset.shuffle(seed=42).select(range(dataset_current_len))

    print("📦 Конвертируем в pandas DataFrame...")
    df = pd.DataFrame([x for x in tqdm(dataset, desc="→ Преобразование строк")])

    print("🧹 Фильтруем классы: Exact / Substitute / Irrelevant...")
    df = df[df['esci_label'].isin(['Exact', 'Substitute', 'Irrelevant'])]

    print("🔍 Удаляем запросы с < 2 примерами...")
    query_counts = df['query'].value_counts()
    df = df[df['query'].isin(query_counts[query_counts >= 2].index)]

    print("✂️ Разбиваем на train/val...")
    train_df, val_df = train_test_split(
        df, test_size=0.1, random_state=42
        # , stratify=df['query']
    )
    print(f"✅ Train size: {len(train_df)} / Val size: {len(val_df)}")

    print(f"📚 Загружаем токенизатор: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    print("📐 Создаём TripletDataset'ы...")
    train_dataset = TripletDataset(train_df, tokenizer)
    val_dataset = TripletDataset(val_df, tokenizer)

    print(f"📊 Train triplets: {len(train_dataset)} / Val triplets: {len(val_dataset)}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, drop_last=True)

    print(f"Train batches: {len(train_loader)} / Val batches: {len(val_loader)} | {batch_size=}")


    return train_loader, val_loader, df

In [5]:
batch_size = 128

In [6]:
# ——— 3. Основной запуск ———
train_loader, val_loader, df = prepare_data(sample_rate=0.01, batch_size=batch_size)

🔽 Загружаем датасет tasksource/esci...
📦 Конвертируем в pandas DataFrame...


→ Преобразование строк: 100%|██████████| 20278/20278 [00:03<00:00, 6354.40it/s]


🧹 Фильтруем классы: Exact / Substitute / Irrelevant...
🔍 Удаляем запросы с < 2 примерами...
✂️ Разбиваем на train/val...
✅ Train size: 3657 / Val size: 407
📚 Загружаем токенизатор: intfloat/e5-small
📐 Создаём TripletDataset'ы...
📊 Train triplets: 36570 / Val triplets: 4070
Train batches: 285 / Val batches: 31 | batch_size=128


In [7]:
# val_loader.dataset[0]

In [8]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel

# https://huggingface.co/intfloat/multilingual-e5-small
# --- Класс для инференса батчей ---
class E5InferenceModel:
    def __init__(self, model_name='intfloat/e5-small', device=None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

    def encode_batch(self, input_ids, attention_mask):
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids.to(self.device), attention_mask=attention_mask.to(self.device))
            return outputs.last_hidden_state[:, 0].cpu().numpy()

# --- Класс для метрик внутри батча ---
class RetrievalMetrics:
    @staticmethod
    def recall_at_k_batch(anchor_embs, product_embs, k_list=[5, 10, 30]):
        recalls = {k: 0 for k in k_list}
        n = len(anchor_embs)
        for i, a_emb in enumerate(anchor_embs):
            scores = np.dot(product_embs, a_emb)
            top_indices = np.argsort(scores)[-max(k_list):][::-1]
            for k in k_list:
                # Positive всегда на позиции i (по построению TripletDataset)
                if i in top_indices[:k]:
                    recalls[k] += 1
        for k in k_list:
            recalls[k] /= n
        return recalls

In [9]:

# --- Пример использования с val_loader ---
model_name = 'intfloat/e5-small'
device = 'cuda:6' if torch.cuda.is_available() else 'cpu'
inference_model = E5InferenceModel(model_name=model_name, device=device)

all_recalls = {k: [] for k in [1, 5, 10, 30]}
k_list = [1, 5, 10, 30]

for batch in tqdm(val_loader, desc="🔍 Pretrain E5 на батчах"):

    (anchor_ids, anchor_mask), (pos_ids, pos_mask), (neg_ids, neg_mask), q, pos, neg = batch

    # Собираем все продукты батча (positive + negative)
    batch_product_ids = torch.cat([pos_ids, neg_ids], dim=0)
    batch_product_mask = torch.cat([pos_mask, neg_mask], dim=0)

    # Эмбеддинги
    anchor_embs = inference_model.encode_batch(anchor_ids, anchor_mask)
    product_embs = inference_model.encode_batch(batch_product_ids, batch_product_mask)

    # Метрики
    recalls = RetrievalMetrics.recall_at_k_batch(anchor_embs, product_embs, k_list=k_list)
    for k in k_list:
        all_recalls[k].append(recalls[k])

# Усреднение по всем батчам
for k in k_list:
    mean_recall = np.mean(all_recalls[k])
    print(f"Recall@{k}: {mean_recall:.4f}")

🔍 Pretrain E5 на батчах: 100%|██████████| 31/31 [00:04<00:00,  6.27it/s]

Recall@1: 0.2686
Recall@5: 0.5129
Recall@10: 0.5862
Recall@30: 0.6971





In [10]:
# Сохраняем модель
# Путь для сохранения
SAVE_DIR = "saved_e5_model"

# Сохраняем модель и токенизатор
inference_model.model.save_pretrained(SAVE_DIR)
inference_model.tokenizer.save_pretrained(SAVE_DIR)

print(f"✅ Модель и токенизатор сохранены в: {SAVE_DIR}")

✅ Модель и токенизатор сохранены в: saved_e5_model


In [11]:
from transformers import AutoModel, AutoTokenizer

# Загружаем модель и токенизатор
_model = AutoModel.from_pretrained("saved_e5_model")
_tokenizer = AutoTokenizer.from_pretrained("saved_e5_model")

# _model.eval()  # Важно для инференса

In [12]:
text = "query: wireless mouse"
inputs = _tokenizer(text, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    output = _model(**inputs)
    embedding = output.last_hidden_state[:, 0]  # [CLS] токен

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


### Сравниваем метрику с BM25

In [13]:
from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize

def tokenize(text):
    return word_tokenize(text.lower())

all_products = df['product_title'].dropna().unique().tolist()
all_products = [f"passage: {p}" for p in all_products]
tokenized_products_all = [tokenize(p) for p in all_products]
bm25_full = BM25Okapi(tokenized_products_all)

In [14]:
from collections import defaultdict
bm25_recalls = defaultdict(list)
k_list = [1, 5, 10, 30]

product_idx_map = {title: i for i, title in enumerate(all_products)}

for batch in tqdm(val_loader, desc="🔍 BM25 на батчах"):
    _, _, _, queries, pos_titles, neg_titles = batch

    # --- 1. Подготовим список из продуктов текущего батча ---
    batch_products = pos_titles + neg_titles

    # --- 2. Индексы этих товаров в all_products (для score фильтрации) ---
    batch_indices = [product_idx_map[p] for p in batch_products if p in product_idx_map]

    for q, true_title in zip(queries, pos_titles):
        q_tokens = tokenize(q)
        scores_all = bm25_full.get_scores(q_tokens)

        # --- 3. Оставим только скоры товаров из текущего батча ---
        scores_batch = [(i, scores_all[i]) for i in batch_indices]
        top_indices = sorted(scores_batch, key=lambda x: x[1], reverse=True)
        top_titles = [all_products[i] for i, _ in top_indices]
        
        for k in k_list:
            bm25_recalls[k].append(int(true_title in top_titles[:k]))

# --- Усреднение ---
for k in k_list:
    print(f"Recall@{k} (BM25): {np.mean(bm25_recalls[k]):.4f}")

🔍 BM25 на батчах: 100%|██████████| 31/31 [00:12<00:00,  2.55it/s]

Recall@1 (BM25): 0.5731
Recall@5 (BM25): 0.6822
Recall@10 (BM25): 0.7145
Recall@30 (BM25): 0.7631





### Дообучение претраин E5

In [15]:
from torch.utils.tensorboard import SummaryWriter

# Запускаем в фоне (можно использовать tmux) борду
# # tmux new -s e5_train
# tensorboard --logdir tb_logs

In [29]:
class E5Model(torch.nn.Module):
    def __init__(self, model_name='intfloat/e5-small'):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def encode(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0]

    def forward(self, anchor_ids, anchor_mask, pos_ids, pos_mask, neg_ids, neg_mask):
        anchor_emb = self.encode(anchor_ids, anchor_mask)
        pos_emb = self.encode(pos_ids, pos_mask)
        neg_emb = self.encode(neg_ids, neg_mask)
        return anchor_emb, pos_emb, neg_emb

In [30]:
def eval_model(model, val_loader, device='cuda'):
    model.eval()
    model.to(device)

    all_recalls = {k: [] for k in [1, 5, 10, 30]}

    with torch.no_grad():
        # for batch in tqdm(val_loader, desc="🔎 Eval"):
        for batch in val_loader:
            (a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask), _, _, _ = batch

            # Переносим на нужное устройство
            a_ids, a_mask = a_ids.to(device), a_mask.to(device)
            p_ids, p_mask = p_ids.to(device), p_mask.to(device)
            n_ids, n_mask = n_ids.to(device), n_mask.to(device)

            # Получаем эмбеддинги
            anchor_embs = model.encode(a_ids, a_mask).cpu().numpy()
            pos_embs = model.encode(p_ids, p_mask).cpu().numpy()
            neg_embs = model.encode(n_ids, n_mask).cpu().numpy()

            # Собираем "пул" продуктов (positive + negative)
            product_embs = np.concatenate([pos_embs, neg_embs], axis=0)

            recalls = RetrievalMetrics.recall_at_k_batch(anchor_embs, product_embs, k_list=k_list)
            for k in recalls:
                all_recalls[k].append(recalls[k])
    
    # Усреднение и печать
    all_means_recalls = {}
    for k in all_recalls:
        all_means_recalls[k] = np.mean(all_recalls[k])
    return all_means_recalls

In [65]:
import os
CHECKPOINTS_DIR = "checkpoints"
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)


def train_model(model, train_loader, val_loader, num_epochs=3, lr=2e-5, device='cuda', every_n_step_do_val=50):
    time_suffix = str(datetime.now().strftime('%Y%m%d_%H%M%S'))
    run_name = f"e5_train_{time_suffix}" # and model_name!
    writer = SummaryWriter(log_dir=f"runs/{run_name}")

    
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = TripletMarginLoss(margin=0.2)
    
    global_step = 1
    for epoch in range(num_epochs):
        model.train()

        for batch_id, batch in tqdm(enumerate(train_loader), desc=f"🛠️ Epoch {epoch + 1}/{num_epochs}", total=len(train_loader)):
            writer.add_scalar("train/epoch_marker", epoch, global_step)
            (a_ids, a_mask), (p_ids, p_mask), (n_ids, n_mask), _, _, _ = batch

            a_ids, a_mask = a_ids.to(device), a_mask.to(device)
            p_ids, p_mask = p_ids.to(device), p_mask.to(device)
            n_ids, n_mask = n_ids.to(device), n_mask.to(device)

            anchor, pos, neg = model(a_ids, a_mask, p_ids, p_mask, n_ids, n_mask)
            loss = loss_fn(anchor, pos, neg)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e9)

            writer.add_scalar("train/loss", loss.item(), global_step)
            writer.add_scalar("GradNorm/train", grad_norm, global_step)
            writer.add_scalar("LR/train", optimizer.param_groups[0]['lr'], global_step)

            train_recalls = eval_model(model, [batch], device=device)
            for k, val in train_recalls.items():
                writer.add_scalar(f"train/recall@{k}", val, global_step)



            if batch_id % every_n_step_do_val == 0:
                recalls = eval_model(model, val_loader, device=device)
                for k, val in recalls.items():
                    writer.add_scalar(f"val/recall@{k}", val, global_step)
            
            global_step += 1    
    
    print("✅ Обучение завершено. Сохраняем модель...")
    save_path = os.path.join(CHECKPOINTS_DIR, f"{run_name}.pt")
    torch.save(model.state_dict(), save_path)
    print(f"📦 Модель сохранена в: {save_path}")

In [68]:
# train_loader, val_loader, _ = prepare_data(batch_size=16, sample_rate=0.05)
model = E5Model(model_name='intfloat/e5-small')
train_model(model, train_loader, val_loader, num_epochs=3, device='cuda:6', lr=1e-4, every_n_step_do_val=10)



🛠️ Epoch 1/10: 100%|██████████| 285/285 [04:22<00:00,  1.09it/s]
🛠️ Epoch 2/10: 100%|██████████| 285/285 [04:22<00:00,  1.08it/s]
🛠️ Epoch 3/10: 100%|██████████| 285/285 [04:22<00:00,  1.09it/s]
🛠️ Epoch 4/10: 100%|██████████| 285/285 [04:22<00:00,  1.08it/s]
🛠️ Epoch 5/10: 100%|██████████| 285/285 [04:22<00:00,  1.09it/s]
🛠️ Epoch 6/10: 100%|██████████| 285/285 [04:22<00:00,  1.08it/s]
🛠️ Epoch 7/10: 100%|██████████| 285/285 [04:23<00:00,  1.08it/s]
🛠️ Epoch 8/10: 100%|██████████| 285/285 [04:23<00:00,  1.08it/s]
🛠️ Epoch 9/10: 100%|██████████| 285/285 [04:24<00:00,  1.08it/s]
🛠️ Epoch 10/10: 100%|██████████| 285/285 [04:23<00:00,  1.08it/s]


✅ Обучение завершено. Сохраняем модель...
📦 Модель сохранена в: checkpoints/e5_train_20250703_175235.pt


In [64]:
train_loader.dataset[20][3]

'query: tshirt jordan men'