In [None]:
#data loading

In [None]:
# TRAINING THE MODEL - this is the training for the final DenseNet model

#imports 
import os
from models.densenet_model import create_densenet
from utils.data_loader import create_dataloaders, create_test_loader
from utils.evaluation import evaluate_model
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import json
from sklearn.metrics import confusion_matrix, precision_recall_curve, f1_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils.early_stopping import EarlyStopping

# Device: prefer the GPU when CUDA is available (the runs were done on a Colab GPU).
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Data: train/val loaders plus a separate test loader, all with batch size 32.
data_dir = 'data'
dataloaders = create_dataloaders(batch_size=32)
test_loader = create_test_loader(batch_size=32)

# Two-class DenseNet, placed on the selected device.
model = create_densenet(num_classes=2).to(device)

# Loss function and Adam optimizer (small weight decay).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)

# Learning-rate schedule: halve the LR after 3 epochs without improvement.
# NOTE(review): train_model() builds its own ReduceLROnPlateau and this object
# is not passed to it — confirm whether another cell still needs this one.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True,
)



#TRAINING MODEL
def train_model(model, criterion, optimizer, dataloaders, num_epochs=20, device='cuda', patience=10):
    """Train ``model`` with per-epoch validation, LR scheduling and early stopping.

    Every epoch runs a ``'train'`` pass followed by a ``'val'`` pass. After each
    validation pass the validation loss is fed to a ``ReduceLROnPlateau``
    scheduler and to the ``EarlyStopping`` helper, and validation metrics
    (precision/recall curve, F1, confusion matrix, raw probabilities/labels,
    loss and accuracy) are collected.

    Class index 1 is treated as the positive class: its softmax probability is
    what gets stored and passed to ``precision_recall_curve``.

    Side effects:
        * creates the ``metrics`` and ``checkpoints`` directories if missing;
        * writes per-phase loss/accuracy rows to ``metrics/metrics_by_epoch.txt``;
        * writes the collected validation metrics to ``metrics/metrics_data.json``;
        * saves the final weights to ``checkpoints/densenet_ai_vs_authentic.pth``.

    Args:
        model: Network to train. It is not moved to ``device`` here, so it must
            already live there.
        criterion: Loss, called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        dataloaders: Mapping with ``'train'`` and ``'val'`` loaders yielding
            ``(inputs, labels)`` batches; each loader must expose ``.dataset``.
        num_epochs: Maximum number of epochs to run.
        device: Device the batches are moved to.
        patience: Patience forwarded to ``EarlyStopping``.

    Returns:
        The same ``model`` object, holding the weights of the last epoch that ran.

    Raises:
        ValueError: If the train or val dataset is empty.
    """
    phases = ('train', 'val')

    # Fail early with a clear message instead of crashing mid-epoch on an
    # empty loader (division by zero / missing tensor accumulator).
    dataset_sizes = {phase: len(dataloaders[phase].dataset) for phase in phases}
    empty_phases = [phase for phase, size in dataset_sizes.items() if size == 0]
    if empty_phases:
        raise ValueError(f"Empty dataset for phase(s): {', '.join(empty_phases)}")

    # Output folders
    metrics_file_path = os.path.join('metrics', 'metrics_by_epoch.txt')
    os.makedirs('metrics', exist_ok=True)
    os.makedirs('checkpoints', exist_ok=True)

    # Validation metrics collected per epoch and dumped to JSON at the end
    metrics_data = {
        "precision_recall": [],
        "f1_scores": [],
        "val_probs": [],
        "val_labels": [],
        "confusion_matrices": [],
        "epoch_metrics": []
    }

    # Scheduler and early stopping, both driven by the validation loss.
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases —
    # confirm against the installed version before upgrading.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, verbose=True)
    # NOTE(review): EarlyStopping presumably writes the best weights to
    # `checkpoint_path`; this function never reloads them — confirm that
    # returning the last-epoch weights is intended.
    early_stopping = EarlyStopping(patience=patience, checkpoint_path='checkpoints/best_model.pth')

    with open(metrics_file_path, 'w') as f:
        f.write('Epoch\tPhase\tLoss\tAccuracy\n')

        for epoch in range(num_epochs):
            print(f'Epoch {epoch+1}/{num_epochs}')
            print('-' * 10)

            for phase in phases:
                is_train = phase == 'train'
                if is_train:
                    model.train()  # training mode
                else:
                    model.eval()  # validation mode

                running_loss = 0.0
                running_corrects = 0  # plain int; no tensor needed for the count
                # Only the validation pass needs per-sample outputs.
                all_labels = []
                all_preds = []
                all_probs = []

                # Go through the batches
                for inputs, labels in dataloaders[phase]:
                    inputs, labels = inputs.to(device), labels.to(device)

                    # Gradients are only tracked (and reset) while training.
                    if is_train:
                        optimizer.zero_grad()

                    with torch.set_grad_enabled(is_train):
                        # Forward pass
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        _, preds = torch.max(outputs, 1)
                        probs = torch.softmax(outputs, dim=1)[:, 1]

                        # Training phase: backpropagate and update the weights
                        if is_train:
                            loss.backward()
                            optimizer.step()

                    # Weight the batch-mean loss by batch size so the epoch
                    # loss is a per-sample average.
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += (preds == labels).sum().item()

                    if not is_train:
                        all_labels.extend(labels.cpu().tolist())
                        all_preds.extend(preds.cpu().tolist())
                        all_probs.extend(probs.detach().cpu().tolist())

                # Epoch metrics
                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
                f.write(f'{epoch+1}\t{phase}\t{epoch_loss:.4f}\t{epoch_acc:.4f}\n')

                if not is_train:
                    # Keep the on-disk log current in case the run is interrupted.
                    f.flush()

                    scheduler.step(epoch_loss)

                    # Validation metrics. F1 and the confusion matrix use the
                    # same argmax predictions as the accuracy above, and the
                    # label set is fixed so the matrix is always 2x2.
                    precision, recall, thresholds = precision_recall_curve(all_labels, all_probs)
                    f1 = f1_score(all_labels, all_preds)
                    confusion_mat = confusion_matrix(all_labels, all_preds, labels=[0, 1])

                    metrics_data["precision_recall"].append({
                        "precision": precision.tolist(),
                        "recall": recall.tolist(),
                        "thresholds": thresholds.tolist(),
                    })
                    metrics_data["f1_scores"].append(float(f1))
                    metrics_data["val_probs"].append(all_probs)
                    metrics_data["val_labels"].append(all_labels)
                    metrics_data["confusion_matrices"].append(confusion_mat.tolist())
                    metrics_data["epoch_metrics"].append({
                        "epoch": epoch+1,
                        "loss": float(epoch_loss),
                        "accuracy": float(epoch_acc)
                    })

                    # Early stopping is evaluated on the validation loss
                    early_stopping(epoch_loss, model)

            # 'val' is the last phase, so one check per epoch is enough.
            if early_stopping.early_stop:
                print("Early stopping triggered.")
                break

    # Extensive metrics save
    with open(os.path.join('metrics', 'metrics_data.json'), 'w') as json_file:
        json.dump(metrics_data, json_file, indent=4)

    # Model saving (weights as of the last epoch that ran)
    model_save_path = os.path.join('checkpoints', 'densenet_ai_vs_authentic.pth')
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved at {model_save_path}")

    print(f"Metrics saved in {metrics_file_path} and metrics_data.json")
    return model



# Run the training loop on the objects prepared above.
trained_model = train_model(
    model,
    criterion,
    optimizer,
    dataloaders,
    num_epochs=20,
    device=device,
    patience=10,
)

# Persist the trained weights.
torch.save(trained_model.state_dict(), 'checkpoints/densenet_ai_vs_authentic.pth')
print("Model trained and saved!")

# Final evaluation: first the validation split, then the held-out test set.
for _split_name, _split_loader in (("validation", dataloaders['val']), ("test", test_loader)):
    print(f"Evaluating on {_split_name} set:")
    evaluate_model(trained_model, _split_loader, device, _split_name)


