# AlexNet (Transfer Learning)

In [2]:
import os
import time

import pandas as pd
import torch
import torchvision.models as models
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix
from torch import nn, optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

In [3]:
# Pick the compute device: the CUDA GPU when one is usable, otherwise the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

### Load and prepare the data

In [4]:
# Preprocessing pipeline applied to every image before it reaches the network.
# The pre-trained AlexNet weights loaded later in this notebook were learned on
# ImageNet inputs normalized with the ImageNet per-channel mean/std, so the same
# statistics are used here (a generic 0.5/0.5 scaling would feed the pre-trained
# layers a different input distribution than the one they were trained on).
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # Collapse to luminance, replicated over 3 channels
    transforms.Resize((256, 256)),               # Resize to 256x256 (aspect ratio is not preserved)
    transforms.CenterCrop(224),                  # Crop the central 224x224 region used as network input
    transforms.ToTensor(),                       # Convert to a float tensor scaled to [0, 1]
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean / std
])

# Custom dataset class to handle CSV and image folder
class SARImageDataset(torch.utils.data.Dataset):
    """Image dataset indexed by a CSV file.

    Every CSV row supplies an image identifier in its first column and a label
    in its second column; the image itself is read from
    ``{image_folder}/{identifier}.png``.

    Args:
        csv_file: Path of the CSV file, read with ``pandas.read_csv``.
        image_folder: Directory containing the ``.png`` images.
        transform: Optional callable applied to the opened image before it is
            returned.
    """

    def __init__(self, csv_file, image_folder, transform=None):
        self.data = pd.read_csv(csv_file)
        self.image_folder = image_folder
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``.

        The image is the output of ``transform`` when one was given, otherwise
        the loaded PIL image. The label is the raw value stored in the second
        CSV column.
        """
        img_name = f"{self.image_folder}/{self.data.iloc[idx, 0]}.png"
        label = self.data.iloc[idx, 1]

        # Image.open is lazy: it keeps the file open until the pixel data is
        # read. Load the pixels inside the context manager so the file handle
        # is released deterministically instead of lingering until garbage
        # collection (which can exhaust file descriptors on large datasets).
        with Image.open(img_name) as img:
            img.load()
        image = img

        if self.transform:
            image = self.transform(image)

        return image, label

# Build the datasets from their CSV index files and image folders
train_val_dataset = SARImageDataset(csv_file="data/train.csv", image_folder="data/images_train", transform=transform)
test_dataset = SARImageDataset(csv_file="data/test.csv", image_folder="data/images_test", transform=transform)

# Split into training (80%) and validation (remaining 20%) subsets.
# A seeded generator makes the partition identical on every run, so validation
# metrics stay comparable across runs and no sample silently moves between the
# training and validation subsets after a kernel restart.
train_size = int(0.8 * len(train_val_dataset))
val_size = len(train_val_dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    train_val_dataset, [train_size, val_size], generator=split_generator
)

# DataLoaders for batching; only the training data is shuffled
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

### Load pre-trained AlexNet model and modify it

In [5]:
# Start from AlexNet initialised with its pre-trained weights
model = models.alexnet(pretrained=True)

# Replace the last classifier layer (index 6) with a fresh 2-output layer:
# class 1 = internal waves, class 0 = no waves
num_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(in_features=num_features, out_features=2)
model = model.to(device)

# Cross-entropy objective; Adam updates every parameter of the network
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)



### Train the model

In [6]:
# Per-epoch metrics are collected as plain dicts and turned into a DataFrame
# once at the end (growing a DataFrame with pd.concat inside the loop is
# quadratic and yields object-typed columns).
log_columns = ["Epoch", "Train Loss", "Validation Loss", "Validation Accuracy"]
log_records = []

# torch.save / DataFrame.to_csv do not create missing directories
os.makedirs("model_outputs_data", exist_ok=True)

# Record the start time
start_time = time.time()

# Training loop
num_epochs = 10
best_val_loss = float('inf')  # Lowest summed validation loss seen so far

try:
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation phase (no gradient tracking, layers in eval mode)
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        # Per-batch averages, computed once and reused for logging and printing
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = val_correct / val_total * 100

        log_records.append({
            "Epoch": epoch + 1,
            "Train Loss": avg_train_loss,
            "Validation Loss": avg_val_loss,
            "Validation Accuracy": val_accuracy
        })

        # Print metrics for this epoch
        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, "
              f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")

        # Save the best model based on validation loss. The summed loss is
        # compared; the number of validation batches is the same every epoch,
        # so this ranks epochs exactly like the averaged loss would.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "model_outputs_data/best_alexnet_model.pth")
finally:
    # Persist whatever was logged even when training is interrupted
    # (e.g. KeyboardInterrupt); the exception itself still propagates.
    log_df = pd.DataFrame(log_records, columns=log_columns)
    log_df.to_csv("model_outputs_data/training_logs_alexnet.csv", index=False)

# Calculate the total training time in seconds
print(f"\nTotal Training Time: {(time.time() - start_time):.2f} seconds")

KeyboardInterrupt: 

### Prediction and Model evaluation

In [None]:
# Test the model.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine also loads on a CPU-only machine.
model.load_state_dict(
    torch.load("model_outputs_data/best_alexnet_model.pth", map_location=device)
)  # Load the best model
model.eval()
all_labels = []
all_preds = []

with torch.no_grad():
    for images, labels in test_loader:
        # Only the images are needed on the device; labels are just collected
        images = images.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)

        all_labels.extend(labels.tolist())
        all_preds.extend(predicted.cpu().tolist())

# Compute Evaluation Metrics (positive class = 1, "Internal Waves")
accuracy = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds, average='binary')
recall = recall_score(all_labels, all_preds, average='binary')
f1 = f1_score(all_labels, all_preds, average='binary')

print(f"\nTest Accuracy: {accuracy * 100:.2f}%")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1-Score: {f1:.2f}")

# Fix the class order explicitly: without `labels`, scikit-learn infers the
# classes from the data, so a test set (or prediction set) containing a single
# class would mismatch the two target names and change the matrix shape.
class_ids = [0, 1]

# Classification Report
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, labels=class_ids,
                            target_names=["No Internal Waves", "Internal Waves"]))

# Confusion Matrix (rows = true class, columns = predicted class)
print("\nConfusion Matrix:")
print(confusion_matrix(all_labels, all_preds, labels=class_ids))