<a href="https://colab.research.google.com/github/juliawol/WB_Embedder/blob/main/WB_Embedder_GigaChat.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset
import pandas as pd

# Configuration
# Hugging Face Hub id handed to both AutoTokenizer and AutoModel below.
MODEL_NAME = "ai-sage/Giga-Embeddings-instruct"
BATCH_SIZE = 32  # triplets per DataLoader batch
MAX_LENGTH = 256  # intended token cap per tokenized text
EPOCHS = 3  # full passes over the triplet loader
LEARNING_RATE = 1e-5  # learning rate given to the optimizer
WARMUP_STEPS = 500  # scheduler steps spent in linear warmup
RANDOM_SEED = 42  # shared seed for Python's `random` and torch
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when torch can see one

# Dataset paths
# Identifiers passed to `datasets.load_dataset`; the "train" split of each is used.
CARDS_DATASET = "JuliaWolken/WB_CARDS"
TRIPLETS_DATASET = "JuliaWolken/WB_TRIPLETS"
BRANDS_DATASET = "JuliaWolken/WB_BRANDS"

# Instructional prefixes
# Instruction strings keyed by task name. Only the "retrieval" entry is read
# below (as the query prefix); the other two are defined but not consumed in
# this script.
task_name_to_instruct = {
    "retrieval": "Дано вопрос, необходимо получить наиболее релевантный товар\nquestion: ",
    "ranking": "Rank items based on their relevance",
    "classification": "Classify the given input into predefined categories",
}
query_prefix = task_name_to_instruct["retrieval"]  # prepended to every anchor/query text
passage_prefix = ""  # positives and negatives are tokenized without an instruction

# Set random seed
# Both Python's and torch's generators are seeded from the same constant.
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Data Loading
print("Loading datasets...")
# Fetch the "train" split of each dataset, in cards / triplets / brands order.
data_sampled, triplet_candidates, brand_candidates = (
    load_dataset(dataset_id)["train"]
    for dataset_id in (CARDS_DATASET, TRIPLETS_DATASET, BRANDS_DATASET)
)

# Materialize every split as a pandas DataFrame for positional row access.
data_sampled_df, triplet_candidates_df, brand_candidates_df = (
    split.to_pandas()
    for split in (data_sampled, triplet_candidates, brand_candidates)
)


# Dataset Class and Loaders
class TripletDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes (anchor, positive, negative) triplets.

    Rows are fetched from ``data`` by position through ``.iloc`` and must
    expose the columns ``"Anchor"``, ``"Positive"`` and ``"Negative"``.
    The anchor text is prefixed with ``query_prefix``; the positive and
    negative texts are prefixed with ``passage_prefix``.

    Args:
        data: Table of triplets supporting ``len()`` and ``.iloc[idx]``.
        tokenizer: Callable tokenizer returning ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        query_prefix: String prepended to each anchor.
        passage_prefix: String prepended to each positive and negative.
        max_length: Length every text is truncated and padded to.
    """

    def __init__(self, data, tokenizer, query_prefix, passage_prefix, max_length=256):
        self.data = data
        self.tokenizer = tokenizer
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.max_length = max_length

    def __len__(self):
        """Return the number of triplets available."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize one text, truncated and padded to ``self.max_length``."""
        return self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return the six tensors (ids and mask for each role) of triplet ``idx``."""
        record = self.data.iloc[idx]
        anchor, pos, neg = (record[column] for column in ("Anchor", "Positive", "Negative"))

        # Build all three prefixed texts before tokenizing any of them.
        prefixed_texts = (
            self.query_prefix + anchor,
            self.passage_prefix + pos,
            self.passage_prefix + neg,
        )
        query_enc, positive_enc, negative_enc = [self._encode(text) for text in prefixed_texts]

        # squeeze(0) removes the leading dimension of each tokenizer output
        # so every returned tensor is one-dimensional.
        return {
            "query_input_ids": query_enc["input_ids"].squeeze(0),
            "query_attention_mask": query_enc["attention_mask"].squeeze(0),
            "positive_input_ids": positive_enc["input_ids"].squeeze(0),
            "positive_attention_mask": positive_enc["attention_mask"].squeeze(0),
            "negative_input_ids": negative_enc["input_ids"].squeeze(0),
            "negative_attention_mask": negative_enc["attention_mask"].squeeze(0),
        }

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Triplet loader
# MAX_LENGTH is forwarded explicitly: without it TripletDataset silently falls
# back to its own default and editing the configuration constant has no effect.
triplet_dataset = TripletDataset(
    triplet_candidates_df,
    tokenizer,
    query_prefix,
    passage_prefix,
    max_length=MAX_LENGTH,
)
# Batches are reshuffled every epoch.
triplet_loader = DataLoader(triplet_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Model Definition
class MultiTaskModel(nn.Module):
    """Pretrained encoder topped with a classification head and a ranking head.

    Both heads are single linear layers applied to the encoder's hidden
    state at sequence position 0.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``
            (loaded with ``trust_remote_code=True``).
        num_classes: Output width of the classification head.
    """

    def __init__(self, model_name, num_classes=60):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        hidden_size = self.encoder.config.hidden_size
        self.classification_head = nn.Linear(hidden_size, num_classes)  # Classification tasks
        self.ranking_head = nn.Linear(hidden_size, 1)  # Ranking tasks

    def forward(self, input_ids, attention_mask, task="classification"):
        """Encode the batch and apply the head selected by ``task``.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            task: ``"classification"`` or ``"ranking"``.

        Returns:
            Output of the chosen head for the position-0 hidden state.

        Raises:
            ValueError: If ``task`` names neither head. The encoder has
                already run by the time this is raised.
        """
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        first_token = encoded.last_hidden_state[:, 0, :]
        if task == "classification":
            return self.classification_head(first_token)
        if task == "ranking":
            return self.ranking_head(first_token)
        raise ValueError("Unknown task")

# Initialize model
# Construct the multi-task wrapper, then move its parameters to DEVICE.
model = MultiTaskModel(model_name=MODEL_NAME)
model = model.to(DEVICE)

# Optimization and Loss
def contrastive_loss(query_emb, positive_emb, negative_emb, margin=0.2):
    """Return the batch-mean triplet margin loss for the given embeddings.

    Args:
        query_emb: Anchor embeddings.
        positive_emb: Embeddings that should lie close to the anchors.
        negative_emb: Embeddings that should lie farther from the anchors.
        margin: Required gap between the negative and positive distances.

    Returns:
        Scalar loss tensor.
    """
    # Module form of the same loss, left at its default distance and reduction.
    criterion = nn.TripletMarginLoss(margin=margin)
    return criterion(query_emb, positive_emb, negative_emb)

classification_criterion = nn.CrossEntropyLoss()  # For classification tasks

# Use PyTorch's AdamW: the implementation shipped with transformers is
# deprecated upstream in favour of torch.optim.AdamW. eps and weight_decay are
# pinned to the values the transformers optimizer used by default (1e-6, 0.0)
# so the update rule is unchanged by the switch.
# NOTE(review): the `AdamW` name in the top-of-file transformers import is no
# longer needed by this statement and can be dropped from that import.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    eps=1e-6,
    weight_decay=0.0,
)

# One scheduler step is taken per batch, so the schedule spans every batch of
# every epoch.
total_training_steps = len(triplet_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=total_training_steps,
)


# Training Loop
print("Starting fine-tuning...")
model.train()

# (input-ids key, attention-mask key) for each triplet role, in the order the
# encoder is run: query, positive, negative.
triplet_fields = (
    ("query_input_ids", "query_attention_mask"),
    ("positive_input_ids", "positive_attention_mask"),
    ("negative_input_ids", "negative_attention_mask"),
)

for epoch in range(EPOCHS):
    for step, batch in enumerate(triplet_loader):
        # Move every tensor of the batch to the training device first.
        device_inputs = [
            (batch[ids_key].to(DEVICE), batch[mask_key].to(DEVICE))
            for ids_key, mask_key in triplet_fields
        ]

        # Encode each role with the bare encoder and keep the position-0 state.
        raw_embeddings = [
            model.encoder(input_ids=ids, attention_mask=mask).last_hidden_state[:, 0, :]
            for ids, mask in device_inputs
        ]

        # L2-normalize along the feature dimension.
        query_emb, positive_emb, negative_emb = (
            F.normalize(embedding, p=2, dim=1) for embedding in raw_embeddings
        )

        # Triplet objective on the normalized embeddings.
        loss = contrastive_loss(query_emb, positive_emb, negative_emb)

        # Clear stale gradients, backpropagate, then advance the optimizer
        # and the learning-rate schedule by one step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log every tenth step, including step 0.
        if not step % 10:
            print(f"Epoch {epoch+1}/{EPOCHS}, Step {step}, Loss: {loss.item():.4f}")


# Save the Model
# The encoder is written with save_pretrained; each task head is stored as a
# separate state-dict file in the same directory.
os.makedirs("fine_tuned_model", exist_ok=True)
model.encoder.save_pretrained("fine_tuned_model")

head_checkpoints = (
    (model.classification_head, "fine_tuned_model/classification_head.pt"),
    (model.ranking_head, "fine_tuned_model/ranking_head.pt"),
)
for head, checkpoint_path in head_checkpoints:
    torch.save(head.state_dict(), checkpoint_path)

print("Fine-tuning complete. Model saved.")
