<a href="https://colab.research.google.com/github/juliawol/WB_Embedder/blob/main/Fine_tuned_Embedder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install datasets transformers torch

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [20]:
from datasets import load_dataset
import pandas as pd

# Dataset identifiers handed to load_dataset, one per table used below
CARDS_DATASET = "JuliaWolken/WB_CARDS"
TRIPLETS_DATASET = "JuliaWolken/WB_TRIPLETS"
BRANDS_DATASET = "JuliaWolken/WB_BRANDS"

# For each identifier: take the "train" split, then mirror it as a pandas DataFrame
print("Loading main dataset (cards)...")
data_sampled = load_dataset(CARDS_DATASET)["train"]
data_sampled_df = data_sampled.to_pandas()

print("Loading triplet dataset...")
triplet_candidates = load_dataset(TRIPLETS_DATASET)["train"]
triplet_candidates_df = triplet_candidates.to_pandas()

print("Loading brand dataset...")
brand_candidates = load_dataset(BRANDS_DATASET)["train"]
brand_candidates_df = brand_candidates.to_pandas()

# Quick sanity check: print a caption followed by the first rows of every frame
previews = (
    ("\nMain dataset (data_sampled_30.csv):", data_sampled_df),
    ("\nTriplet candidates (triplet_candidates.csv):", triplet_candidates_df),
    ("\nBrand candidates (brand_candidates.csv):", brand_candidates_df),
)
for caption, frame in previews:
    print(caption)
    print(frame.head())



Loading main dataset (cards)...


Generating train split: 0 examples [00:00, ? examples/s]

Loading triplet dataset...


triplet_candidates.csv:   0%|          | 0.00/843M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/127892 [00:00<?, ? examples/s]

Loading brand dataset...


brand_candidates.csv:   0%|          | 0.00/612M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/22735953 [00:00<?, ? examples/s]


Main dataset (data_sampled_30.csv):
                             aggregated_charc_values  \
0  Материал изделия: ЛДСП\nВес с упаковкой (кг): ...   
1  Цвет: красный\nШирина упаковки: 10 см \nСовмес...   
2  Высота предмета: 200 см \nСтиль дизайна: Миним...   
3  Высота предмета: 200 см \nСтиль дизайна: Миним...   
4  Ставка НДС: Без НДС\nВес без упаковки (кг): 13...   

                                               title  \
0              Набор для увеличения кровати - белый    
1           Чехол-книжка Tecno Spark 9Pro Спарк 9Про   
2  Шкаф пенал двухдверный распашной серый витрина...   
3  Шкаф пенал двухдверный распашной серый витрина...   
4  Комплект барных стульев Loft со спинкой для ку...   

                                         description  \
0  Отличный вариант для тех, кто не хочет расстав...   
1  НА ФОТО ОБРАЗЕЦ ЧЕХЛА!!! ВАМ ПРИДЕТ ЧЕХОЛ В СО...   
2  Принцесса Мелания Шкаф-витрина  - это идеально...   
3  Принцесса Мелания Шкаф-витрина  - это идеально...   
4  ВНИМАН

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset
import pandas as pd

# --- Hyperparameters and runtime settings ---
MODEL_NAME = "DeepPavlov/rubert-base-cased"  # pretrained checkpoint loaded below
MAX_LENGTH = 256      # every text is padded/truncated to this many tokens
BATCH_SIZE = 32
EPOCHS = 3
LEARNING_RATE = 1e-5
WARMUP_STEPS = 500    # warm-up steps given to the learning-rate scheduler
RANDOM_SEED = 42

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Seed Python's and PyTorch's random number generators for reproducibility
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)


# Dataset Classes and Loaders
class RetrievalDataset(torch.utils.data.Dataset):
    """Triplet dataset that tokenizes the Anchor/Positive/Negative texts on access.

    Args:
        data: table supporting ``len()`` and ``.iloc[idx]`` whose rows expose
            "Anchor", "Positive" and "Negative" entries.
        tokenizer: callable that encodes one text; it is invoked with the
            ``truncation``, ``padding``, ``max_length`` and ``return_tensors``
            keyword arguments.
        max_length: length passed to the tokenizer for padding/truncation.
    """

    # (row column, output-key prefix) pairs, listed in encoding order
    _ROLES = (("Anchor", "anchor"), ("Positive", "positive"), ("Negative", "negative"))

    def __init__(self, data, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = data

    def __len__(self):
        """Number of triplets, i.e. the number of rows in ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Encode row ``idx`` into ``<role>_input_ids`` / ``<role>_attention_mask`` tensors.

        Each tensor has the tokenizer's leading batch dimension squeezed out.
        """
        record = self.data.iloc[idx]
        # Read all three texts before tokenizing anything
        texts = [record[column] for column, _ in self._ROLES]

        sample = {}
        for (_, prefix), text in zip(self._ROLES, texts):
            encoded = self.tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
            )
            sample[f"{prefix}_input_ids"] = encoded["input_ids"].squeeze(0)
            sample[f"{prefix}_attention_mask"] = encoded["attention_mask"].squeeze(0)
        return sample

# Tokenizer that belongs to the pretrained checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Both loaders share the same batching/shuffling options
loader_options = dict(batch_size=BATCH_SIZE, shuffle=True)

# Triplet candidates: dataset wrapper + loader
triplet_dataset = RetrievalDataset(triplet_candidates_df, tokenizer, MAX_LENGTH)
triplet_loader = DataLoader(triplet_dataset, **loader_options)

# Brand candidates: dataset wrapper + loader
brand_dataset = RetrievalDataset(brand_candidates_df, tokenizer, MAX_LENGTH)
brand_loader = DataLoader(brand_dataset, **loader_options)


# Model Definition
class MultiTaskModel(nn.Module):
    """Pretrained encoder with two linear heads on top of the first-token embedding.

    Args:
        model_name: checkpoint name or path handed to ``AutoModel.from_pretrained``.
        num_classes: output width of the classification head. Defaults to 60,
            the value that used to be hard-coded.
    """

    def __init__(self, model_name, num_classes=60):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size
        # Heads are created in a fixed order (classification, then ranking)
        self.classification_head = nn.Linear(hidden_size, num_classes)  # category classification
        self.ranking_head = nn.Linear(hidden_size, 1)  # single score for ranking tasks

    def forward(self, input_ids, attention_mask, task="classification"):
        """Run the encoder and apply the head selected by ``task``.

        Args:
            input_ids: token ids passed to the encoder.
            attention_mask: attention mask passed to the encoder.
            task: ``"classification"`` or ``"ranking"``.

        Returns:
            Output of the chosen linear head applied to the embedding at
            sequence position 0 of the encoder's last hidden state.

        Raises:
            ValueError: if ``task`` is neither ``"classification"`` nor ``"ranking"``.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls_emb = outputs.last_hidden_state[:, 0, :]
        if task == "classification":
            return self.classification_head(cls_emb)
        if task == "ranking":
            return self.ranking_head(cls_emb)
        raise ValueError("Unknown task")

# Build the multi-task model, then move its parameters to the selected device
model = MultiTaskModel(MODEL_NAME)
model = model.to(DEVICE)


# Optimization and Loss
def contrastive_loss(anchor_emb, positive_emb, negative_emb, margin=0.2):
    """Triplet margin loss over anchor, positive and negative embeddings.

    Thin wrapper around ``torch.nn.functional.triplet_margin_loss``; only the
    margin is overridden, every other option keeps its library default.

    Args:
        anchor_emb: embeddings of the anchor texts.
        positive_emb: embeddings that should end up close to the anchors.
        negative_emb: embeddings that should end up far from the anchors.
        margin: margin forwarded to the triplet loss.

    Returns:
        The loss tensor produced by ``F.triplet_margin_loss``.
    """
    triplet_loss = F.triplet_margin_loss(
        anchor=anchor_emb,
        positive=positive_emb,
        negative=negative_emb,
        margin=margin,
    )
    return triplet_loss

# Use PyTorch's own AdamW: the AdamW class shipped with transformers is deprecated
# in favour of torch.optim.AdamW. eps and weight_decay are spelled out to match the
# defaults of the old class (1e-6 and 0.0), so the optimizer hyperparameters stay the same.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0)

# The schedule spans every batch of both loaders, repeated for each epoch
total_training_steps = (len(triplet_loader) + len(brand_loader)) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=total_training_steps,
)


# Training Loop
print("Starting fine-tuning...")
model.train()

# Within every epoch the triplet loader is consumed first, then the brand loader
task_loaders = (("triplet", triplet_loader), ("brand", brand_loader))

for epoch in range(EPOCHS):
    for task, loader in task_loaders:
        for step, batch in enumerate(loader):
            # Move every tensor of the batch to the training device
            device_batch = {key: tensor.to(DEVICE) for key, tensor in batch.items()}

            # Encoder forward passes in anchor -> positive -> negative order,
            # keeping the hidden state at sequence position 0 for each
            first_token_states = []
            for role in ("anchor", "positive", "negative"):
                encoded = model.encoder(
                    input_ids=device_batch[f"{role}_input_ids"],
                    attention_mask=device_batch[f"{role}_attention_mask"],
                )
                first_token_states.append(encoded.last_hidden_state[:, 0, :])

            # L2-normalize each embedding along the feature dimension
            anchor_emb, positive_emb, negative_emb = (
                F.normalize(state, p=2, dim=1) for state in first_token_states
            )

            # Triplet loss on the normalized embeddings
            loss = contrastive_loss(anchor_emb, positive_emb, negative_emb)

            # Clear old gradients, backpropagate, then advance optimizer and schedule
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            # Report the loss on every tenth step
            if not step % 10:
                print(f"Epoch {epoch+1}/{EPOCHS}, Task {task}, Step {step}, Loss: {loss.item():.4f}")


# Save the Model
# Everything is written into one directory so it can be reloaded as a unit.
OUTPUT_DIR = "fine_tuned_model"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Encoder weights and config
model.encoder.save_pretrained(OUTPUT_DIR)

# Tokenizer files go next to the encoder; without them the directory cannot be
# passed to AutoTokenizer.from_pretrained on its own.
tokenizer.save_pretrained(OUTPUT_DIR)

# NOTE(review): the training loop in this notebook computes its loss from
# model.encoder only, so these two heads are saved with the weights they had at
# construction time — confirm that is intended before relying on them.
torch.save(model.classification_head.state_dict(), os.path.join(OUTPUT_DIR, "classification_head.pt"))
torch.save(model.ranking_head.state_dict(), os.path.join(OUTPUT_DIR, "ranking_head.pt"))

print("Fine-tuning complete. Model saved.")


Some weights of the model checkpoint at DeepPavlov/rubert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Starting fine-tuning...
Epoch 1/3, Task triplet, Step 0, Loss: 0.2293
Epoch 1/3, Task triplet, Step 10, Loss: 0.1707
Epoch 1/3, Task triplet, Step 20, Loss: 0.1731
Epoch 1/3, Task triplet, Step 30, Loss: 0.1880
Epoch 1/3, Task triplet, Step 40, Loss: 0.1777
Epoch 1/3, Task triplet, Step 50, Loss: 0.1941
Epoch 1/3, Task triplet, Step 60, Loss: 0.1739
Epoch 1/3, Task triplet, Step 70, Loss: 0.2042
Epoch 1/3, Task triplet, Step 80, Loss: 0.1326
Epoch 1/3, Task triplet, Step 90, Loss: 0.1301
Epoch 1/3, Task triplet, Step 100, Loss: 0.1510
Epoch 1/3, Task triplet, Step 110, Loss: 0.0957
Epoch 1/3, Task triplet, Step 120, Loss: 0.0868
Epoch 1/3, Task triplet, Step 130, Loss: 0.1223
Epoch 1/3, Task triplet, Step 140, Loss: 0.0573
Epoch 1/3, Task triplet, Step 150, Loss: 0.0550
Epoch 1/3, Task triplet, Step 160, Loss: 0.0677
Epoch 1/3, Task triplet, Step 170, Loss: 0.0572
Epoch 1/3, Task triplet, Step 180, Loss: 0.0336
