In [None]:
import torch
from torch.optim import AdamW, lr_scheduler
from config import load_config
from model import load_model_for_finetuning
from data_prep import create_data_loader
from train import train_one_epoch, validate_model
from utils import plot_loss, plot_accuracy, plot_confusion_matrix, plot_precision_recall, plot_roc_curve

In [None]:
# Load the run configuration and build the model to fine-tune.
config = load_config('configs/config.yaml')
model = load_model_for_finetuning(config)

# Select the GPU when one is available and move the model there *before*
# building the optimizer, so the optimizer is constructed over the parameters
# as they live on the training device (the ordering recommended for PyTorch
# optimizers; some backends replace parameter objects on conversion).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# AdamW over the model's parameters; StepLR scales the learning rate by
# `gamma` once every `step_size` calls to scheduler.step().
optimizer = AdamW(model.parameters(), lr=config['learning_rate'])
scheduler = lr_scheduler.StepLR(optimizer, step_size=config['step_size'], gamma=config['gamma'])

In [None]:
# Build the data loaders for the two splits.
# Both calls pass the same tokenizer model, max length and batch size from
# `config`; they differ in the path / label-column keys and in `shuffle`
# (True for training, False for validation).
train_loader = create_data_loader(
    config["trainpath"],
    config["train_label_col"],
    config['tokenizer_model'],
    config['max_length'],
    config['batch_size'],
    shuffle=True,
)

val_loader = create_data_loader(
    config["valpath"],
    config["val_label_col"],
    config['tokenizer_model'],
    config['max_length'],
    config['batch_size'],
    shuffle=False,
)


In [None]:
# Train for config['epochs'] epochs, running validation after each one.

# Per-epoch metric histories: one entry is appended per epoch.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []
train_f1_scores = []
val_f1_scores = []

# Labels / predictions from the most recent epoch only. They stay empty lists
# if config['epochs'] is 0.
all_train_labels = []
all_train_preds = []
all_val_labels = []
all_val_preds = []

for epoch in range(config['epochs']):
    train_loss, train_accuracy, train_f1, epoch_train_labels, epoch_train_preds = train_one_epoch(
        model, train_loader, optimizer, device
    )
    val_loss, val_accuracy, val_f1, epoch_val_labels, epoch_val_preds = validate_model(
        model, val_loader, device
    )
    # Advance the learning-rate schedule once per epoch, after the training pass.
    scheduler.step()

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)
    train_f1_scores.append(train_f1)
    val_f1_scores.append(val_f1)

    # Replace rather than extend: accumulating across epochs would mix the
    # predictions of early, under-trained epochs into any metric later computed
    # from these lists, so they would not describe the final model.
    all_train_labels = list(epoch_train_labels)
    all_train_preds = list(epoch_train_preds)
    all_val_labels = list(epoch_val_labels)
    all_val_preds = list(epoch_val_preds)

    print(
        f"Epoch {epoch+1}/{config['epochs']}, "
        f"Train Loss: {train_loss}, Train Accuracy: {train_accuracy}, Train F1: {train_f1}, "
        f"Val Loss: {val_loss}, Val Accuracy: {val_accuracy}, Val F1: {val_f1}"
    )


In [None]:
# Per-epoch training vs. validation curves.
num_epochs = config['epochs']
plot_loss(train_losses, val_losses, num_epochs)
plot_accuracy(train_accuracies, val_accuracies, num_epochs)

# Confusion matrix on the validation labels/predictions; classes are named
# "0", "1", ... up to config['num_labels'] - 1.
class_names = list(map(str, range(config['num_labels'])))
plot_confusion_matrix(all_val_labels, all_val_preds, classes=class_names)

# Precision-recall and ROC curves from the same validation labels/predictions.
# NOTE(review): these receive the same `all_val_preds` as the confusion matrix;
# if those are hard class labels rather than scores, confirm the helpers
# handle that as intended.
plot_precision_recall(all_val_labels, all_val_preds)
plot_roc_curve(all_val_labels, all_val_preds)

In [None]:
# Persist the model's state_dict to the path named in the config and report it.
model_path = config['trained_model_path']
torch.save(model.state_dict(), model_path)
print(f"Model saved as {model_path}")