# Imports

In [None]:
# Import necessary libraries
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd
import numpy as np
from helpers import load_data_pairs, flatten_list
from sklearn.model_selection import KFold
from datetime import datetime
from transformers import set_seed
set_seed(42)

# Load Data

In [None]:
# Load all pair collections, then merge them by polarity.
# A key is treated as positive/negative when it contains "pos"/"neg".
data_dict = load_data_pairs()
pos_keys = [key for key in data_dict if "pos" in key]
neg_keys = [key for key in data_dict if "neg" in key]
positive_pairs = flatten_list([data_dict[key] for key in pos_keys])
negative_pairs = flatten_list([data_dict[key] for key in neg_keys])

# Define Training Parameters

In [None]:
# Fine-tuning hyper-parameters
num_epochs = 1
batch_size = 16
learning_rate = 2e-5
# Number of splits used for K-Fold cross-validation
fold_size = 10

# Checkpoint the training starts from, and the directory the fine-tuned
# model is written to (its name encodes batch size, learning rate and folds)
model_name = "../00_data/SBERT_Models/models/jobgbert_TSDAE_epochs5"
output_path = f"../00_data/SBERT_Models/models/jobgbert_batch{batch_size}_TSDAE_{learning_rate}_f{fold_size}"
# Bare expression: lets the notebook cell display the resulting path
output_path

# K-Fold Cross-Validation

In [None]:
# Shuffled K-Fold splitter with `fold_size` splits; the fixed seed keeps
# the fold assignment identical across runs
kf = KFold(
    n_splits=fold_size,
    shuffle=True,
    random_state=42,
)

In [None]:
# Bookkeeping for the cross-validation run
MRR_AT = 100
# Starts empty; presumably collects per-fold MRR scores (usage not shown here)
MRR = []
max_MRR = 0
# Start time as a digits-only stamp: YYYYMMDDHHMMSS (no microseconds)
training_start = datetime.now().strftime("%Y%m%d%H%M%S")

In [None]:
# Perform training and evaluation for each fold.
# NOTE: `epoch` is the zero-based fold index produced by enumerate(), not a
# training epoch; the number of training epochs per fold is `num_epochs`.
for epoch, (train_index, dev_index) in enumerate(kf.split(positive_pairs)):
    # Split data into training and development sets
    pos_train_samples = [positive_pairs[i] for i in train_index]
    pos_dev_samples = [positive_pairs[i] for i in dev_index]

    # Create training examples (first two elements of every pair)
    train_examples = [InputExample(texts=[item[0], item[1]]) for item in pos_train_samples]
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)

    # Warm up over 10% of the optimizer steps. The schedule counts batches,
    # not samples: len(train_dataloader) is the number of batches per epoch.
    # Using the sample count would make the warm-up longer than the whole run.
    warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

    # Load ONE fresh model per fold and share it between the loss and fit().
    # Building two separate instances would optimise the weights of one model
    # while evaluating and saving the other.
    model = SentenceTransformer(model_name)
    train_loss = losses.MultipleNegativesRankingLoss(model)

    # Define evaluator (cut-off taken from MRR_AT instead of a repeated literal)
    # NOTE(review): RerankingEvaluator documents its samples as dicts with
    # 'query', 'positive' and 'negative' entries - confirm that the items of
    # pos_dev_samples have that shape.
    evaluator = evaluation.RerankingEvaluator(pos_dev_samples, at_k=MRR_AT, show_progress_bar=True)

    # Train the model
    # NOTE(review): every fold writes to the same output_path, so a later fold
    # overwrites what an earlier fold saved - confirm this is intended.
    model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        # Pass the configured rate explicitly so it matches the value that is
        # encoded in output_path
        optimizer_params={"lr": learning_rate},
        evaluator=evaluator,
        output_path=output_path
    )