In [None]:
import numpy as np
import pandas as pd 

import random
import json
import re
import gc

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from transformers import AutoTokenizer, AutoModel, AdamW
from dataclasses import dataclass

from sklearn.model_selection import train_test_split

from tqdm import tqdm
tqdm.pandas()

In [None]:
@dataclass
class Config:
    """Hyper-parameters and run settings read by the cells below."""

    model_name: str = 'cointegrated/rubert-tiny2'  # pretrained checkpoint identifier
    max_length: int = 768  # token cap handed to the tokenizer
    batch_size: int = 32  # samples per DataLoader batch
    n_epochs: int = 1  # passes over the training split
    lr: float = 3e-5  # optimizer learning rate
    seed: int = 52  # RNG seed, also used as the split's random_state
    test_size: float = 0.15  # share of pairs held out for validation


config = Config()

In [None]:
def seed_all(SEED):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs with ``SEED``.

    Also switches cuDNN into deterministic mode.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for seed_rng in seeders:
        seed_rng(SEED)
    torch.backends.cudnn.deterministic = True
    
# Seed the RNGs once, before any data loading or model code runs.
seed_all(config.seed)

In [None]:
# Product texts; a missing description is replaced by the placeholder 'no desc'.
text_and_bert = pd.read_parquet(
    '/kaggle/input/extracted_data/text_and_bert.parquet',
    engine='pyarrow',
)
text_and_bert['description'] = text_and_bert['description'].fillna('no desc')

In [None]:
attrs = pd.read_parquet(
    '/kaggle/input/extracted_data/attributes.parquet',
    columns=['categories'],
    engine='pyarrow',
)

# Each `categories` cell is a serialized mapping from level ('1', '2', '3', ...)
# to a category name. Decode it once per row with json.loads rather than calling
# eval() three times: eval would execute arbitrary code found in the file and
# tripled the parsing work.
# NOTE(review): assumes the column holds valid JSON — confirm against the file.
parsed_categories = attrs['categories'].progress_apply(json.loads)
for level in ('1', '2', '3'):
    # A missing level still raises KeyError, exactly as the eval-based lookup did.
    attrs[f'category_level_{level}'] = parsed_categories.map(
        lambda cats, key=level: cats[key]
    )
del parsed_categories

In [None]:
# `attrs` was loaded without any key column, so its rows are paired with
# `text_and_bert` purely by index. Fail fast when the indexes differ: concat
# would otherwise NaN-fill the unmatched rows and the text cleaning further
# down would die with an obscure error on those NaNs.
if not text_and_bert.index.equals(attrs.index):
    raise ValueError(
        'text_and_bert and attrs are not index-aligned: '
        f'{len(text_and_bert)} vs {len(attrs)} rows'
    )
# NOTE(review): equal indexes do not prove equal row order — this assumes both
# files list the products in the same order; confirm, or merge on a shared id.
data = pd.concat([text_and_bert, attrs], axis=1)

In [None]:
# Both source frames have been combined into `data`; drop them and reclaim memory.
del text_and_bert
del attrs
gc.collect()

In [None]:
def remove_html_tags(text):
    """Strip HTML markup from ``text`` and return it lower-cased on one line.

    Every tag is replaced by a space instead of being deleted outright, so
    words separated only by markup (``foo<br>bar``) are not glued together.
    Runs of whitespace, newlines included, are then collapsed to a single
    space and the ends are trimmed.

    Args:
        text: The raw string to clean.

    Returns:
        The cleaned, lower-cased string.
    """
    without_tags = re.sub(r'<[^>]+>', ' ', text)
    single_spaced = re.sub(r'\s+', ' ', without_tags).strip()
    return single_spaced.lower()

# Free-text columns go through the HTML stripper (which also lower-cases).
for text_col in ('description', 'name'):
    data[text_col] = data[text_col].progress_apply(remove_html_tags)

# Category labels are only lower-cased.
for level_col in ('category_level_1', 'category_level_2', 'category_level_3'):
    data[level_col] = data[level_col].progress_apply(lambda x: x.lower())

In [None]:
# Model input: name, the three category levels and the description,
# joined in that order with ' [SEP] '.
data['text'] = data['name'].str.cat(
    [
        data[col]
        for col in (
            'category_level_1',
            'category_level_2',
            'category_level_3',
            'description',
        )
    ],
    sep=' [SEP] ',
)

In [None]:
# Labelled product pairs used for training and validation.
train_pairs = pd.read_parquet(
    '/kaggle/input/extracted_data/train.parquet',
    engine='pyarrow',
)

In [None]:
# Only the combined text and the product id are needed to resolve the pairs.
train_data = data.loc[:, ['text', 'variantid']]

In [None]:
# Give the id columns the same `_1` / `_2` suffix scheme that add_suffix produces.
train_pairs = train_pairs.rename(
    columns={'variantid1': 'variantid_1', 'variantid2': 'variantid_2'}
)

# Attach the text of each side of the pair, one inner merge per side.
train_df = train_pairs
for side in ('1', '2'):
    train_df = train_df.merge(
        train_data.add_suffix(f'_{side}'),
        on=f'variantid_{side}',
    )

train_df = train_df[['text_1', 'text_2', 'target']]

In [None]:
# Hold out a validation share, stratified so both parts keep the label balance.
train_df, val_df = train_test_split(
    train_df,
    test_size=config.test_size,
    random_state=config.seed,
    stratify=train_df['target'],
)

In [None]:
class TextPairDataset(Dataset):
    """Dataset of tokenized text pairs with a similarity label.

    Each item holds the token ids and attention mask of both texts plus the
    pair label encoded as ``1.0`` (positive) or ``-1.0`` (otherwise). That is
    the encoding ``nn.CosineEmbeddingLoss`` expects; a label of ``0`` would
    contribute no loss at all.

    Sequences are truncated but NOT padded, so items differ in length and
    batching them requires a collate function that pads.

    Args:
        df: Frame with ``text_1``, ``text_2`` and ``target`` columns.
        tokenizer: Callable Hugging Face style tokenizer.
        max_length: Maximum number of tokens kept per text.
        max_chars: Maximum number of characters of each text handed to the
            tokenizer (a cheap pre-truncation before tokenizing).
    """

    def __init__(self, df, tokenizer, max_length, max_chars=1200):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.max_chars = max_chars

    def __len__(self):
        return len(self.df)

    def _encode(self, text):
        """Tokenize one text and return 1-D ``(input_ids, attention_mask)``."""
        # Calling the tokenizer directly replaces the deprecated `encode_plus`
        # / `pad_to_max_length` spelling; padding stays off by default.
        encoding = self.tokenizer(
            text[:self.max_chars],
            max_length=self.max_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return encoding['input_ids'].flatten(), encoding['attention_mask'].flatten()

    def __getitem__(self, idx):
        # One positional lookup per item instead of three.
        row = self.df.iloc[idx]
        input_ids_1, attention_mask_1 = self._encode(row['text_1'])
        input_ids_2, attention_mask_2 = self._encode(row['text_2'])

        # Map any positive label to 1 and everything else to -1.
        label = 1.0 if row['target'] > 0 else -1.0

        return {
            'input_ids_1': input_ids_1,
            'attention_mask_1': attention_mask_1,
            'input_ids_2': input_ids_2,
            'attention_mask_2': attention_mask_2,
            'target': torch.tensor(label, dtype=torch.float)
        }

In [None]:
class SiameseBERT(nn.Module):
    """Siamese encoder: one shared transformer embeds both texts of a pair.

    Args:
        model_name: Checkpoint passed to ``AutoModel.from_pretrained``.
            Defaults to the checkpoint that used to be hard-coded here, so
            ``SiameseBERT()`` behaves as before.
    """

    def __init__(self, model_name='cointegrated/rubert-tiny2'):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)

    def _embed(self, input_ids, attention_mask):
        """Return the encoder's ``pooler_output`` for one side of the pair."""
        return self.bert(input_ids, attention_mask=attention_mask).pooler_output

    def forward(self, input_ids_1, attention_mask_1, input_ids_2, attention_mask_2):
        """Embed both texts with the same weights.

        Returns:
            Tuple ``(pooled_output_1, pooled_output_2)``.
        """
        pooled_output_1 = self._embed(input_ids_1, attention_mask_1)
        pooled_output_2 = self._embed(input_ids_2, attention_mask_2)
        return pooled_output_1, pooled_output_2

In [None]:
# Take the checkpoint name from the shared config instead of repeating the literal.
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

train_dataset = TextPairDataset(train_df, tokenizer, config.max_length)
val_dataset = TextPairDataset(val_df, tokenizer, config.max_length)


def pad_pair_batch(samples):
    """Collate variable-length pair samples into one batch of padded tensors.

    The dataset returns unpadded 1-D sequences, which the default collate
    cannot stack when their lengths differ. Each sequence field is therefore
    right-padded to the longest sample in the batch: token ids with the
    tokenizer's pad id, attention masks with 0.

    Args:
        samples: List of dicts as produced by ``TextPairDataset.__getitem__``.

    Returns:
        Dict with the same keys; sequence fields have shape
        ``(batch, longest_length)`` and ``target`` has shape ``(batch,)``.
    """
    fill_values = {
        'input_ids_1': tokenizer.pad_token_id,
        'attention_mask_1': 0,
        'input_ids_2': tokenizer.pad_token_id,
        'attention_mask_2': 0,
    }
    batch = {
        key: nn.utils.rnn.pad_sequence(
            [sample[key] for sample in samples],
            batch_first=True,
            padding_value=fill,
        )
        for key, fill in fill_values.items()
    }
    batch['target'] = torch.stack([sample['target'] for sample in samples])
    return batch


train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=pad_pair_batch,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=pad_pair_batch,
)

In [None]:
# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = SiameseBERT()
criterion = nn.CosineEmbeddingLoss()
optimizer = optim.Adam(model.parameters(), lr=config.lr)
# Cosine schedule with warm restarts: first cycle of 10 steps, each next cycle twice as long.
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

model.to(device)

In [None]:
def evaluate(model, val_loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``val_loader``.

    Puts the model in eval mode and runs every batch without gradient
    tracking. The result is the sum of the batch losses divided by the
    number of batches.
    """
    input_keys = ('input_ids_1', 'attention_mask_1', 'input_ids_2', 'attention_mask_2')

    model.eval()
    batch_losses = []
    with torch.no_grad():
        for batch in val_loader:
            inputs = [batch[key].to(device) for key in input_keys]
            target = batch['target'].to(device)

            emb_1, emb_2 = model(*inputs)
            batch_losses.append(criterion(emb_1, emb_2, target).item())
    return sum(batch_losses, 0.0) / len(val_loader)

In [None]:
# Batch fields fed to the model, in the order of its forward() arguments.
input_keys = ('input_ids_1', 'attention_mask_1', 'input_ids_2', 'attention_mask_2')

for epoch in range(config.n_epochs):
    model.train()
    running_loss = 0.0
    pbar = tqdm(train_loader, desc=f'Epoch {epoch + 1}', dynamic_ncols=True)
    for batch in pbar:
        inputs = [batch[key].to(device) for key in input_keys]
        target = batch['target'].to(device)

        # One optimisation step; the LR schedule advances once per batch.
        optimizer.zero_grad()
        output_1, output_2 = model(*inputs)
        loss = criterion(output_1, output_2, target)
        loss.backward()
        optimizer.step()
        scheduler.step()

        step_loss = loss.item()
        running_loss += step_loss
        pbar.set_postfix({'Loss': step_loss})

    # Epoch summary: mean batch loss on train, then the validation loss.
    train_loss = running_loss / len(train_loader)
    val_loss = evaluate(model, val_loader, criterion, device)
    print(f'Epoch {epoch + 1}, Train Loss: {train_loss}, Val Loss: {val_loss}')

    # One checkpoint per epoch, tagged with its validation loss.
    torch.save(model.state_dict(), f'model_epoch_{epoch + 1}_valloss_{val_loss}.pth')