In [1]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from datetime import datetime

# Make the project's `utils` package importable. The project root defaults to the
# original location but can be overridden with the PD_PROJECT_ROOT environment
# variable; the membership check keeps re-runs of this cell from appending
# duplicate entries to sys.path.
_project_root = os.path.abspath(
    os.environ.get("PD_PROJECT_ROOT", "D:\\burtm\\Visual_studio_code\\PD_related_projects")
)
if _project_root not in sys.path:
    sys.path.append(_project_root)

from utils.model_utils import get_model, get_trainable_layers
from utils.data_loading import get_dataloaders
from utils.utils_transforms import get_transform  
from utils.training_utils import fine_tune_last_n_layers, train_model, get_criterion, get_optimizer, get_scheduler

In [None]:
# --- Run configuration: everything a reader might want to tune -----------------
selected_model = "resnet18"
selected_transform = "resnet18"
N_max = 282  # forwarded to get_dataloaders(N_max=...)
use_patches = True
pretrained = True
depth = 2  # forwarded to get_trainable_layers(depth=...)
num_epochs = 10
batch_size = 32
learning_rate = 0.001
input_filename = "train_df_patches_cc.csv"
criterion_name = "CrossEntropyLoss"
# Build the loss from `criterion_name` (not a separate literal) so the criterion
# actually used always matches the name recorded in the training metadata.
criterion = get_criterion(criterion_name)
optimizer_name = "Adam"
num_classes = 2  # Change this to match your dataset
early_stopping = 10  # passed to train_model as early_stopping_patience
scheduler_name = "no_scheduling"  # alternative: "CosineAnnealingLR"

# Seed torch's global RNG so stochastic steps that draw from it are repeatable
# across runs (bit-exact results on CUDA may additionally require
# deterministic-algorithm settings).
seed = 42
torch.manual_seed(seed)

# Output locations. Every backslash is escaped explicitly: a lone "\" before a
# letter such as "V" or "P" is an invalid escape sequence (SyntaxWarning on
# Python 3.12+). The trailing separator is kept because later cells build file
# paths from these prefixes.
project_dir = "D:\\burtm\\Visual_studio_code\\PD_related_projects\\"
checkpoint_path = project_dir + "checkpoints\\"
models_path = project_dir + "outputs\\models\\"

In [None]:
# Description of this run. The checkpoint cell attaches it to the saved
# checkpoints under the 'training_metadata' key; the values mirror the
# configuration variables defined above.
training_metadata = dict(
    type_of_approach="fine tuning imagenet pre-trained model",
    type_of_approach_sigla="FTIPM",
    model_name=selected_model,
    transform_name=selected_transform,
    epochs=num_epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    optimizer=optimizer_name,
    pretrained=pretrained,
    depth=depth,
    use_patches=use_patches,
    input_filename=input_filename,
    num_classes=num_classes,
    criterion_name=criterion_name,
    early_stopping=early_stopping,
    N_max=N_max,
    scheduler_name=scheduler_name,
)

In [None]:
# Build the input transform, then the train/validation loaders that apply it
# to the rows of `input_filename`.
transform = get_transform(selected_transform, use_patches=use_patches)
train_dataloader, val_dataloader = get_dataloaders(
    transform,
    batch_size=batch_size,
    N_max=N_max,
    file_name=input_filename,
)


In [None]:
# Build the network. `get_model` is given `num_classes` from the config cell;
# presumably it resizes the final classification layer to match — see
# utils.model_utils.
model = get_model(selected_model, pretrained=pretrained, num_classes=num_classes)

# Select the compute device (GPU when available, otherwise CPU) and move the
# model onto it. The loss and optimizer are created in other cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device is: ", device)
model = model.to(device)

In [None]:
# Ask the helper how many trailing layers stay trainable for this architecture
# and `depth`, freeze the rest, then build the optimizer and LR scheduler.
num_trainable_layers = get_trainable_layers(selected_model, depth=depth)
model = fine_tune_last_n_layers(model, num_trainable_layers)

optimizer = get_optimizer(model, optimizer_name, lr=learning_rate)
# T_max receives the epoch count; presumably only schedulers that need it
# (e.g. CosineAnnealingLR) use it — TODO confirm in utils.training_utils.
scheduler = get_scheduler(optimizer, scheduler_name, T_max=num_epochs)

In [None]:
# Train the model, recording wall-clock start/end times around the call.
# Checkpoints go under `checkpoint_path`; the next cell reads
# 'best_checkpoint.pth' and 'last_checkpoint.pth' from there.
start_time = datetime.now()
model, train_losses, val_losses = train_model(
    model,
    train_dataloader,
    val_dataloader,
    criterion,
    optimizer,
    device,
    num_epochs=num_epochs,
    checkpoint_path=checkpoint_path,
    early_stopping_patience=early_stopping,
    scheduler=scheduler,
)
end_time = datetime.now()

In [None]:
# Attach the run metadata to the best and last checkpoints written by
# train_model, then export the best one to `models_path`.


def _attach_metadata(checkpoint_file, metadata):
    """Load a checkpoint, store `metadata` under 'training_metadata', and save it back in place.

    Args:
        checkpoint_file: path of a checkpoint dict previously written with torch.save.
        metadata: dict to store under the 'training_metadata' key.

    Returns:
        The updated checkpoint dict.
    """
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if these
    # locally produced checkpoints hold non-tensor objects, pass weights_only=False.
    ckpt = torch.load(checkpoint_file)
    ckpt["training_metadata"] = metadata
    # Write back to the same file it was loaded from (not to the directory prefix).
    torch.save(ckpt, checkpoint_file)
    return ckpt


# Copy, rather than mutate, the metadata defined earlier, adding the wall-clock
# training time measured around train_model.
run_metadata = {
    **training_metadata,
    "training_duration_seconds": (end_time - start_time).total_seconds(),
}

best_checkpoint = _attach_metadata(os.path.join(checkpoint_path, "best_checkpoint.pth"), run_metadata)
val_accuracy = best_checkpoint["val_acc"]
checkpoint = _attach_metadata(os.path.join(checkpoint_path, "last_checkpoint.pth"), run_metadata)

# Export the *best* checkpoint: the file name advertises the best validation
# accuracy, so its contents must be the checkpoint that achieved it.
# float() handles val_acc whether it is stored as a Python/NumPy scalar or a
# one-element tensor.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_filename = (
    f"{training_metadata['type_of_approach_sigla']}"
    f"_ValAcc{float(val_accuracy):.4f}_{timestamp}.pth"
)
torch.save(best_checkpoint, os.path.join(models_path, model_filename))