# Imports

In [None]:
import sys
sys.path.append("..")
from _utils import flatten_list, load_json,write_json

In [None]:
from sentence_transformers import SentenceTransformer, InputExample
from torch.utils.data import DataLoader
from sentence_transformers import models, losses, util
from tqdm import tqdm
import random
from sentence_transformers import evaluation
import json
import pandas as pd
from collections import Counter
import os
from sklearn.model_selection import train_test_split
from training_helpers import load_data_pairs, create_trainig_samples, encode_jobs
import numpy as np
from datetime import datetime
from sklearn.model_selection import KFold
from transformers import set_seed
import accelerate
# Fix the random seed through transformers' set_seed helper so that repeated
# runs of this notebook start from the same random state.
set_seed(42)

# Variables

In [None]:
# Load all pair collections, then pool them into one positive and one negative
# list. A collection is selected purely by whether its key contains the
# substring "pos" or "neg".
data_dict = load_data_pairs()
positive_pairs = flatten_list(
    [pairs for key, pairs in data_dict.items() if "pos" in key]
)
negative_pairs = flatten_list(
    [pairs for key, pairs in data_dict.items() if "neg" in key]
)

# Define Training Parameters

In [None]:
# Checkpoint used as the starting point for fine-tuning. The parameter cell
# below only accepts the model names it explicitly lists.
model_name = "../00_data/SBERT_Models/models/gbert_TSDAE_epochs5"
# Instantiate a SentenceTransformer from that checkpoint.
model = SentenceTransformer(model_name)

In [None]:
# Derive a short tag from the starting checkpoint; the tag is embedded in the
# output directory name so runs from different base models do not collide.
if model_name == "deepset/gbert":
    TSDAE = "woTSDAE"
elif model_name == "../00_data/SBERT_Models/models/gbert_TSDAE_epochs5":
    TSDAE = "wTSDAE"
else:
    # An unknown checkpoint has no tag; fail with a message that names the
    # offending value instead of a bare, message-less exception.
    raise ValueError(
        f"Unsupported model_name {model_name!r}: expected 'deepset/gbert' or "
        "'../00_data/SBERT_Models/models/gbert_TSDAE_epochs5'"
    )

# Training hyperparameters.
batch_size = 16  # examples per training batch (also part of the output path)
lr = 2e-5  # learning rate (also part of the output path)
num_epochs = 1  # epochs per training run
fold_size = 10  # number of K-fold splits (also part of the output path)
output_path = f"../00_data/SBERT_Models/models/gbert_batch{batch_size}_{TSDAE}_{lr}_f{fold_size}"
# Bare expression: shows the resolved path as the notebook cell output.
output_path

# K-Fold Cross-Validation

In [None]:
# Shuffled K-fold splitter with `fold_size` splits; the fixed random_state
# makes the fold assignment repeatable across runs.
kf = KFold(n_splits=fold_size, random_state=42, shuffle=True)

In [None]:
# Perform training and evaluation for each fold.
# `epoch` is really the zero-based fold index; the name is kept unchanged in
# case later cells read it.
for epoch, (train_index, dev_index) in enumerate(kf.split(positive_pairs)):
    # Split the positive pairs into this fold's training and development sets
    pos_train_samples = [positive_pairs[i] for i in train_index]
    pos_dev_samples = [positive_pairs[i] for i in dev_index]

    # Load ONE fresh copy of the base checkpoint per fold. The loss and fit()
    # must share this instance: wrapping one model in the loss and calling
    # fit() on a different one means the optimised weights belong to a
    # different object than the one that is evaluated and saved.
    fold_model = SentenceTransformer(model_name)

    # Create training examples: each positive pair becomes one InputExample
    train_examples = [InputExample(texts=[item[0], item[1]]) for item in pos_train_samples]
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=batch_size)
    train_loss = losses.MultipleNegativesRankingLoss(fold_model)

    # Warm up over 10% of the optimiser steps (batches * epochs). Counting
    # samples instead of batches would ask for more warm-up steps than the
    # run has in total, so the learning rate would never reach `lr`.
    warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

    # Define evaluator
    # NOTE(review): RerankingEvaluator is documented as taking samples shaped
    # like {'query': ..., 'positive': [...], 'negative': [...]} - confirm that
    # pos_dev_samples has that structure.
    evaluator = evaluation.RerankingEvaluator(pos_dev_samples, at_k=100, show_progress_bar=True)

    # Train the model
    # NOTE(review): every fold writes to the same output_path, so a later fold
    # can overwrite an earlier fold's saved model - confirm this is intended.
    fold_model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=num_epochs,
        warmup_steps=warmup_steps,
        # Pass the configured learning rate explicitly so the value recorded
        # in output_path is the one actually used for training.
        optimizer_params={"lr": lr},
        evaluator=evaluator,
        output_path=output_path
    )