In [None]:
import os
import sys

import torch
from sklearn.model_selection import train_test_split

sys.path.append('../')
from src.model_trainer import load_bert_model, train_model

def split_data(inputs_path, labels_path, subset_size=None, val_split=0.2, random_state=42):
    """Load saved inputs/labels and split them into training and validation sets.

    Args:
        inputs_path: Path to a file written with ``torch.save`` holding the inputs.
        labels_path: Path to a file written with ``torch.save`` holding the labels;
            must have one row per input row.
        subset_size: If truthy, keep only the first ``subset_size`` examples before
            splitting. ``None`` (or ``0``) uses the whole dataset.
        val_split: Passed to ``train_test_split`` as ``test_size`` (the fraction of
            examples assigned to validation when a float).
        random_state: Shuffle seed passed to ``train_test_split``. Calls with the same
            number of examples and the same seed yield the same permutation, so
            several label files can be split consistently against one inputs file.

    Returns:
        Tuple ``(inputs_train, inputs_val, labels_train, labels_val)``.

    Raises:
        ValueError: If the inputs and labels have different numbers of rows.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # installed torch version's weights_only default; only load trusted files.
    inputs = torch.load(inputs_path)
    labels = torch.load(labels_path)

    # Validate alignment on the full data: slicing both to subset_size first would
    # make mismatched files look consistent and train on misaligned pairs.
    if len(inputs) != len(labels):
        raise ValueError(
            f"inputs ({len(inputs)} rows, {inputs_path!r}) and labels "
            f"({len(labels)} rows, {labels_path!r}) have different lengths"
        )

    if subset_size:
        inputs = inputs[:subset_size]
        labels = labels[:subset_size]

    inputs_train, inputs_val, labels_train, labels_val = train_test_split(
        inputs, labels, test_size=val_split, random_state=random_state
    )

    return inputs_train, inputs_val, labels_train, labels_val

# Load the first `subset_size` examples and split them into train/validation sets.
# Both calls read the same inputs file with the same subset_size, val_split and
# split_data's seed (42). Assuming both label files are row-aligned with
# train_inputs.pt, the 3-class labels are therefore shuffled into the same order
# as train_inputs / val_inputs.
subset_size = 2000
train_inputs, val_inputs, train_labels_10class, val_labels_10class = split_data(
    '../data/train_inputs.pt', '../data/train_labels.pt', subset_size=subset_size
)
_, _, train_labels_3class, val_labels_3class = split_data(
    '../data/train_inputs.pt', '../data/train_labels_3class.pt', subset_size=subset_size
)

# These are row counts of the training split, not numbers of classes.
num_examples = len(train_inputs)
print("Number of examples:", num_examples)
num_labels = len(train_labels_10class)
print("Number of 10-class training labels:", num_labels)
num_labels_3class = len(train_labels_3class)
print("Number of 3-class training labels:", num_labels_3class)
if not num_examples == num_labels == num_labels_3class:
    raise ValueError(
        "Training inputs and label sets have different lengths; check the data files."
    )

# Positional hyperparameters forwarded to train_model after the data arguments.
# NOTE(review): presumably (epochs, batch_size, learning_rate) — confirm against
# src.model_trainer.train_model.
TRAIN_ARGS = (3, 16, 1e-5)

# Checkpoint names are derived from subset_size so they cannot drift from the data
# actually used (subset_size=2000 -> "..._subset_2000.pth").
MODELS_DIR = '../models'
subset_tag = f"subset_{subset_size}" if subset_size else "full"
# Create the directory before training so a missing folder fails immediately
# instead of after the models have been trained.
os.makedirs(MODELS_DIR, exist_ok=True)


def _train_and_save(num_classes, train_labels, val_labels):
    """Build a model via load_bert_model(num_labels=num_classes), train it, and save it.

    Uses the module-level ``train_inputs`` / ``val_inputs`` split and ``TRAIN_ARGS``.

    Returns:
        ``(model, model_path)``: the object returned by ``train_model`` and the path
        its ``state_dict()`` was written to.
    """
    model = load_bert_model(num_labels=num_classes)
    model = train_model(model, train_inputs, train_labels, val_inputs, val_labels, *TRAIN_ARGS)
    model_path = os.path.join(MODELS_DIR, f"model_{num_classes}class_{subset_tag}.pth")
    torch.save(model.state_dict(), model_path)
    return model, model_path


# NOTE(review): the class counts assume the label files use 10 and 3 classes
# respectively — confirm against the data.
# TODO: evaluation is disabled; `evaluate_model` is not imported in this notebook.
model_10class, model_path_10class = _train_and_save(10, train_labels_10class, val_labels_10class)
model_3class, model_path_3class = _train_and_save(3, train_labels_3class, val_labels_3class)