In [None]:
!pip install -U -q sentence_transformers

In [None]:
import gc
import os
import random
import traceback
import pandas as pd
import numpy as np
from datetime import datetime

from sklearn.model_selection import StratifiedGroupKFold

from datasets import Dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import BinaryClassificationEvaluator
from sentence_transformers.losses import OnlineContrastiveLoss, ContrastiveLoss
from sentence_transformers.losses.ContrastiveLoss import SiameseDistanceMetric
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments

In [None]:
import torch
import _codecs

torch.serialization.add_safe_globals([
    np.core.multiarray.scalar, 
    np.dtype, 
    np.dtypes.Float64DType, 
    np.dtypes.UInt32DType,
    np.core.multiarray._reconstruct,
    np.ndarray,
    _codecs.encode, 
])

In [None]:
# Run identifiers.
# VER   -- experiment iteration; only used to label the W&B run below.
# N_FOLD -- which cross-validation fold is held out for evaluation; it is used
#           later as the key into `fold_mapping` and in the output paths.
VER = 1
N_FOLD = 0

# Credentials must never live in the notebook source: a hardcoded value both
# leaks the key when the notebook is shared and overwrites a valid key that is
# already present in the environment. Provide it from outside instead
# (e.g. `export WANDB_API_KEY=...` before starting Jupyter, or `wandb login`).
if not os.environ.get('WANDB_API_KEY'):
    print('WANDB_API_KEY is not set; export it or run `wandb login` before training.')

# All W&B labels share one prefix; notes and run name additionally carry VER.
wandb_prefix = f'ft-fixed-berta-fold{N_FOLD}-online-contrastive'
os.environ['WANDB_PROJECT'] = wandb_prefix
os.environ['WANDB_NOTES'] = f'{wandb_prefix}-{VER}'
os.environ['WANDB_NAME'] = f'{wandb_prefix}-{VER}'

In [None]:
# Training hyper-parameters; both are passed to the training arguments below.
batch_size = 16
num_train_epochs = 1

# Every run writes to its own directory: the fold number plus a start timestamp.
run_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
output_dir = f'output/fold{N_FOLD}_training_ocl-{run_stamp}'

In [None]:
# Build the cross-validation folds.
# The number of folds is defined once so the splitter and the mapping that
# stores its output cannot drift apart.
N_SPLITS = 5

# Only the pair ids, the grouping key and the label are needed to build folds.
df = pd.read_parquet(
    'data/preprocessed/train_texts.parquet',
    columns=['variantid_1', 'variantid_2', 'group_id', 'is_double'],
)

# Sort first so the shuffle starts from a fixed row order, then shuffle all
# rows with a fixed seed; together this makes the fold assignment repeatable.
df = df.sort_values(by=['variantid_1', 'variantid_2'])
df = df.sample(len(df), random_state=42)

# Stratify on the label (`is_double`) while keeping every `group_id` inside a
# single fold.
sgkf = StratifiedGroupKFold(n_splits=N_SPLITS)

# fold_mapping['<fold>'] -> {'train_idxs': ..., 'val_idxs': ...}
# The stored indices are row positions in the shuffled `df`.
# NOTE(review): they are later used to select rows from `dataset`, which is
# loaded from a different parquet file -- confirm that file has exactly the
# same rows in exactly this order, otherwise the folds are wrong.
fold_mapping = {
    str(fold): {'train_idxs': [], 'val_idxs': []}
    for fold in range(N_SPLITS)
}

for fold, (train_idx, val_idx) in enumerate(sgkf.split(df, df['is_double'], groups=df['group_id'])):
    fold_mapping[str(fold)]['train_idxs'] = train_idx
    fold_mapping[str(fold)]['val_idxs'] = val_idx

In [None]:
# Load the text-pair dataset and keep only what the loss consumes: the id and
# grouping columns are dropped, as is `__index_level_0__` (presumably the
# pandas index that was written into the parquet file -- confirm).
columns_to_drop = ['group_id', 'variantid_1', 'variantid_2', '__index_level_0__']
dataset = Dataset.from_parquet('avito-for-dl-training.parquet').remove_columns(columns_to_drop)

In [None]:
# Materialise the train / eval splits for the configured fold.
#
# The fold indices are row positions computed on `df`, so they only make sense
# if `dataset` holds the same rows in the same order. Equal length is a
# necessary (not sufficient) condition; checking it turns a silent train/eval
# mix-up into an immediate, explicit error.
if len(dataset) != len(df):
    raise ValueError(
        f'dataset has {len(dataset)} rows but the folds were built on {len(df)} rows; '
        'positional fold indices would select the wrong examples'
    )

selected_fold = fold_mapping[str(N_FOLD)]
train_dataset = dataset.select(selected_fold['train_idxs'].tolist())
eval_dataset = dataset.select(selected_fold['val_idxs'].tolist())

In [None]:
# Sanity check: number of examples in the train and eval splits.
len(train_dataset), len(eval_dataset)

In [None]:
# Peek at the first training example to verify the remaining columns.
train_dataset[0]

In [None]:
# Peek at the first evaluation example to verify the remaining columns.
eval_dataset[0]

In [None]:
# Load the `sergeyzh/BERTA` checkpoint; this is the model object that is
# handed to the loss and the trainer below and saved at the end.
model = SentenceTransformer('sergeyzh/BERTA')

In [None]:
# model.max_seq_length = 1024

In [None]:
# Display the model's repr to inspect its modules.
model

In [None]:
# Training objective: online contrastive loss over cosine distance.
# `margin` is forwarded to the loss unchanged.
distance_metric = SiameseDistanceMetric.COSINE_DISTANCE
margin = 0.75

train_loss = OnlineContrastiveLoss(
    model=model,
    margin=margin,
    distance_metric=distance_metric,
)

# loss = losses.MatryoshkaLoss(model, loss, [312, 256, 128, 64])

In [None]:
# Binary-classification evaluator over the held-out fold, built from the
# `sentence1` / `sentence2` / `label` columns of `eval_dataset`.
# NOTE: it is only constructed here -- the `evaluator=` argument of the trainer
# and the manual call in the next cell are both commented out, so it does not
# run unless one of them is re-enabled.
binary_acc_evaluator = BinaryClassificationEvaluator(
    sentences1=eval_dataset['sentence1'],
    sentences2=eval_dataset['sentence2'],
    labels=eval_dataset['label'],
    name='avito-duplicates',
)

In [None]:
# binary_acc_evaluator(model, epoch=0, steps=0)

In [None]:
# Training configuration for the run.
# Evaluation during training is disabled (the `eval_strategy` and eval batch
# size lines are commented out); checkpoints are still written on a step
# schedule.
args = SentenceTransformerTrainingArguments(
    # Timestamped, per-fold directory defined in the hyper-parameter cell.
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    # per_device_eval_batch_size=batch_size,
    warmup_ratio=0.03,
    # Half-precision training via fp16; bf16 is explicitly turned off.
    fp16=True,
    bf16=False,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # eval_strategy='epoch',
    # Checkpoint every 500 steps; `save_total_limit=4` caps how many are kept.
    save_strategy='steps',
    save_steps=500,
    save_total_limit=4,
    # Log on every step.
    # NOTE(review): this is very chatty for long runs -- confirm it is intended.
    logging_steps=1,
    # NOTE(review): WANDB_NAME is also set via the environment earlier in the
    # notebook -- confirm which of the two names the logger ends up using.
    run_name='online-contrastive-loss',
)

In [None]:
# Assemble the trainer from the model, arguments, training split and loss.
# The eval dataset and the evaluator are deliberately left commented out, so
# this run performs no evaluation during training.
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    # eval_dataset=eval_dataset,
    loss=train_loss,
    # evaluator=binary_acc_evaluator,
)

In [None]:
# Run fine-tuning with the trainer configured above.
trainer.train()

In [None]:
# Save the model (the same object that was passed to the trainer) into a
# fold-specific sub-directory of this run's output directory.
final_output_dir = f'{output_dir}/final_fold{N_FOLD}'
model.save(final_output_dir)