# AlexNet (Transfer Learning)

In [None]:
import torch
import torchvision.models as models
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchsummary import summary
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix
from PIL import Image
from ptflops import get_model_complexity_info
import time

# Import custom utility module
from utils import get_transform, prepare_train_test_data, calculate_evaluation_metrics

In [None]:
# Device configuration: use the GPU when CUDA is usable, otherwise fall back to the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"\n Device : {device}")

### Load and prepare the data

In [None]:
# Keyword used to tag this model's output files
model_keyword = "alexnet"

# Input size fed to the model, as (channels, height, width)
image_input_size = (3, 227, 227)

# Build the image transform; the resize target is the (height, width) part of the input size
transform = get_transform(resize_image_size=image_input_size[1:])

# Build the train / validation / test loaders from the transformed data
train_loader, val_loader, test_loader = prepare_train_test_data(transform)

### Load pre-trained AlexNet model

In [None]:
# Start from AlexNet with pre-trained weights
model = models.alexnet(pretrained=True)

# Replace the last classifier layer with a 2-output head:
# Internal waves (1) or No waves (0). Binary classification is handled here as a
# 2-class problem (two logits) rather than with a single output neuron.
num_features = model.classifier[6].in_features  # width of the features entering the last layer
model.classifier[6] = nn.Linear(num_features, 2)
model = model.to(device)

# Loss, optimizer, and a scheduler that lowers the learning rate by `factor`
# once the monitored value has stopped decreasing for `patience` epochs
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=5,
)

# Print the per-layer summary and the number of modules in the model
summary(model, input_size=image_input_size)
print(f"Total number of layers: {len(list(model.modules()))}")

### Train the model

In [None]:
# Train the model for a fixed number of epochs, validating after each one.
# Per-epoch metrics are logged and written to CSV when training ends, and the
# weights with the lowest validation loss seen so far are checkpointed.
import os

# Create the output folders up front so that neither the checkpoint save nor the
# final CSV export fails on a missing directory after training has already run
os.makedirs("model_outputs_data/model_evaluation_logs", exist_ok=True)

# Per-epoch metric rows; turned into a DataFrame once training has finished
log_rows = []

# FLOPs depend only on the architecture and the input size, not on the weights,
# so measure them once instead of repeating the measurement every epoch
with torch.no_grad():
    flops, _ = get_model_complexity_info(model, (image_input_size), as_strings=False, print_per_layer_stat=False)

# Record the start time
start_time = time.time()

# Training loop
num_epochs = 50
best_val_loss = float('inf')  # Lowest (summed) validation loss seen so far

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Validation phase
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Predicted class = index of the largest logit
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    val_accuracy = val_correct / val_total * 100

    # Mean loss per batch, used for logging and printing
    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)

    # Adjust learning rate (monitors the summed validation loss)
    scheduler.step(val_loss)

    # Log metrics for this epoch
    log_rows.append({
        "Epoch": epoch + 1,
        "Train_Loss": avg_train_loss,
        "Validation_Loss": avg_val_loss,
        "Validation_Accuracy": val_accuracy,
        "FLOPs": flops
    })

    # Print metrics for this epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, "
          f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%, "
          f"GFLOPs: {(flops / 1e9 ):.2f}")

    # Save the best model based on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), f"model_outputs_data/best_{model_keyword}_model.pth")


# Calculate the total training time in seconds
print(f"\nTotal Training Time: {(time.time() - start_time):.2f} seconds")

# Build the DataFrame from the collected rows (a plain list has no .to_csv)
# and save it to a CSV file for later use
log_df = pd.DataFrame(log_rows)
log_df.to_csv(f"model_outputs_data/model_evaluation_logs/training_logs_{model_keyword}.csv", index=False)

### Prediction and Model evaluation

In [None]:
# Evaluate the model currently in memory on the test set.
# Reloading the best checkpoint is currently disabled:
# model.load_state_dict(torch.load("model_outputs_data/best_alexnet_model.pth"))  # Load the best model for testing
model.eval()

# Ground-truth labels and predicted classes, accumulated over all test batches
all_labels, all_preds = [], []

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)

        # Predicted class = index of the largest logit in each row
        predicted = torch.max(outputs, 1).indices

        # Move results back to the CPU as numpy arrays before collecting them
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Compute evaluation metrics (class 1 is the positive class for the binary scores)
accuracy = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds, average='binary')
recall = recall_score(all_labels, all_preds, average='binary')
f1 = f1_score(all_labels, all_preds, average='binary')

# Report the evaluation metrics through the project helper
calculate_evaluation_metrics(all_labels, all_preds, model_keyword)