In [None]:
# Imports
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split
# from torchvision.models import efficientnet_v2_l
import os
!pip install torchinfo



Hyperparameters

In [None]:
# Hyperparameters
# To reproduce the result, please follow the same hyperparameters from the paper
IMG_SIZE = 224       # side length every image is resized to before batching
NUM_CLASSES = 10     # output size requested from get_model
NUM_EPOCHS = 10      # upper bound of the training loop's epoch range
BATCH_SIZE = 100     # batch_size used for the train/val/test DataLoaders

# Device configuration: use CUDA when it is available, otherwise the CPU.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
print(f"Using device: {device}")

Using device: cuda


Load the SAR image dataset

In [None]:
# Transforms
# Every image is resized to IMG_SIZE x IMG_SIZE, converted to a tensor and then
# normalised with mean 0.5 / std 0.5 on each of its three channels.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])


dataset_path = "PATH TO THE DATASET"
full_dataset = datasets.ImageFolder(root=dataset_path, transform=transform)

# Image distributions train/valid/test 80/10/10
# The test split takes whatever is left after rounding, so the three sizes
# always add up to len(full_dataset).
train_size = int(0.8 * len(full_dataset))
val_size = int(0.1 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

# Seed the split. The training cell can resume from a checkpoint after a kernel
# restart; with an unseeded split the resumed session would draw a *different*
# partition, so images the model was already trained on could end up in the
# validation/test sets and inflate the reported accuracy.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_ds, val_ds, test_ds = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)

print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

# 80/10/10 Dataset split
# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)
class_names = full_dataset.classes


Choose the model

In [None]:
# MODEL SELECTION -------------------->
import torch.nn as nn
import torchvision.models as models

def get_model(model_name, num_classes, device, pretrained=True):
    """Build a torchvision classifier whose final layer has `num_classes` outputs.

    Args:
        model_name: One of "efficientnet_v2_l", "convnext_base", "vit_b_16",
            "swin_v2_b" or "resnet152".
        num_classes: Output size of the replacement final `nn.Linear` layer.
        device: Target passed to `model.to(...)`.
        pretrained: If truthy, load the "IMAGENET1K_V1" weights, i.e. the ones
            the legacy `pretrained=True` flag selected; otherwise no pretrained
            weights are loaded.

    Returns:
        The model moved to `device`. Its final linear layer is always newly
        created, regardless of `pretrained`.

    Raises:
        ValueError: If `model_name` is not one of the supported names.
    """
    # model name -> (attribute holding the head, key of the final Linear in it).
    # A container of None means the Linear is a direct attribute of the model.
    # Trailing numbers are carried over from the original comments
    # (presumably parameter counts of each network).
    head_locations = {
        "efficientnet_v2_l": ("classifier", 1),  # 117,247,082
        "convnext_base": ("classifier", 2),      # 87,576,714
        "vit_b_16": ("heads", "head"),           # 85,806,346
        "swin_v2_b": (None, "head"),             # 86,916,098
        "resnet152": (None, "fc"),               # 58,164,298
    }
    # Validate first so a bad name fails fast, before any weights are fetched.
    if model_name not in head_locations:
        raise ValueError(f"Unsupported model name: {model_name}")

    # `pretrained=` is deprecated in torchvision in favour of `weights=`.
    # "IMAGENET1K_V1" keeps the exact weights the deprecated flag mapped to.
    weights = "IMAGENET1K_V1" if pretrained else None
    model = getattr(models, model_name)(weights=weights)

    container_name, key = head_locations[model_name]
    container = model if container_name is None else getattr(model, container_name)
    if isinstance(key, int):
        # Head lives at a numeric index of a sequential container.
        container[key] = nn.Linear(container[key].in_features, num_classes)
    else:
        # Head is a named sub-module.
        old_head = getattr(container, key)
        setattr(container, key, nn.Linear(old_head.in_features, num_classes))

    return model.to(device)

# Choose your pretrained model
model_name = "resnet152"

# get_model already returns the network on `device`, so no second
# `.to(device)` call is needed here.
model = get_model(model_name, NUM_CLASSES, device, pretrained=True)

# Display a layer/parameter summary of the selected model.
from torchinfo import summary
summary(model)

Training and validation

In [None]:
# TRAINING AND VALIDATION -------------------->
import os
import torch
from tqdm.notebook import tqdm

# Checkpoint settings
checkpoint_dir = "Path to save the checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

resume = True # Set to False to start fresh
checkpoint_path = f"{checkpoint_dir}/{model_name}_checkpoint.pth"

# Training configuration
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Per-epoch history; overwritten with the stored history when resuming.
train_losses = []      # mean batch loss of each finished epoch
val_accuracies = []    # validation accuracy (in percent) of each finished epoch
start_epoch = 0

# -------- RESUME IF NEEDED -------- #
if resume and os.path.exists(checkpoint_path):
    print(f"✓ Resuming from checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    train_losses = checkpoint['train_losses']
    val_accuracies = checkpoint['val_accuracies']
    # The checkpoint stores the index of the last *completed* epoch.
    start_epoch = checkpoint['epoch'] + 1
else:
    print("✓ Starting training from scratch")

# -------- TRAINING & VALIDATION LOOP -------- #
# If the checkpoint already covers NUM_EPOCHS epochs the range is empty and the
# restored model is used as-is.
for epoch in range(start_epoch, NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

    # --- Training ---
    model.train()
    total_loss, correct, total = 0, 0, 0
    train_pbar = tqdm(train_loader, desc="Training", leave=False)

    # batch_idx counts the batches seen so far (1-based) so the running loss is
    # a true mean over batches, even when the last batch is smaller than
    # BATCH_SIZE.
    for batch_idx, (images, labels) in enumerate(train_pbar, start=1):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        train_pbar.set_postfix({
            'loss': total_loss / batch_idx,
            'acc': correct / total * 100
        })

    train_acc = correct / total * 100
    # Report the same quantity that is stored in the history (mean batch loss),
    # not the raw sum, which grows with the number of batches.
    avg_train_loss = total_loss / len(train_loader)
    print(f"[Epoch {epoch+1}] Train Loss: {avg_train_loss:.3f} | Train Acc: {train_acc:.2f}%")

    train_losses.append(avg_train_loss)

    # --- Validation ---
    model.eval()
    val_correct, val_total = 0, 0
    val_pbar = tqdm(val_loader, desc="Validation", leave=False)

    # No gradients are needed while scoring the validation split.
    with torch.no_grad():
        for images, labels in val_pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

    val_acc = val_correct / val_total * 100
    val_accuracies.append(val_acc)
    print(f"           Val Acc: {val_acc:.2f}%")

    # --- Save Checkpoint ---
    # Written after every epoch (overwriting the previous file) so training can
    # be resumed from the last completed epoch.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_accuracies': val_accuracies
    }, checkpoint_path)
    print(f"✓ Checkpoint saved at epoch {epoch+1}")

✓ Resuming from checkpoint: /content/drive/MyDrive/Colab Notebooks/ML_model_sar_synthetic_data_project/checkpoints/resnet152_checkpoint.pth


Evaluation

In [None]:
# EVALUATE THE TRAINED MODEL ON TEST BATCH -------------------->
from tqdm.notebook import tqdm
from sklearn.metrics import classification_report

model.eval()
test_correct, test_total = 0, 0
all_preds = []    # predicted class index of every test sample
all_labels = []   # ground-truth class index of every test sample

test_pbar = tqdm(test_loader, desc="Testing", leave=False)

# Inference only: gradients are not needed.
with torch.no_grad():
    for images, labels in test_pbar:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = outputs.max(1)
        test_total += labels.size(0)
        test_correct += predicted.eq(labels).sum().item()

        # Move to host memory and store as plain Python ints for scikit-learn.
        all_preds.extend(predicted.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

test_acc = test_correct / test_total
print(f"\n[Final Test Accuracy] {test_acc*100:.2f}%")

print("\nClassification Report:")
# Pass the full list of class indices explicitly. The test set is a random
# split, so a class may be absent from it; without `labels` the number of
# classes found would not match len(class_names) and the report would fail.
print(classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(class_names))),
    target_names=class_names,
))


# GENERATE CONFUSION MATRIX -------------------->

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Compute confusion matrix
# `labels` forces one row/column per entry of class_names, even for a class
# that never occurs in the test labels or predictions; otherwise the matrix
# would be smaller than `display_labels` and plotting would fail.
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

# Plot the confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
fig, ax = plt.subplots(figsize=(10, 10))  # adjust size as needed
disp.plot(ax=ax, cmap='Blues', xticks_rotation=45)
plt.title("Confusion Matrix")
plt.show()

Save the trained model

In [None]:
# SAVE THE TRAINED MODEL -------------------->
import os
from datetime import datetime

# A timestamp in the file name keeps earlier exports from being overwritten.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_filename = f"{model_name}_weights_{timestamp}.pth"

# Destination folder (on Google Drive or local); created if it does not exist.
save_dir = "SAVE THE TRAINED MODEL"
os.makedirs(save_dir, exist_ok=True)
model_path = os.path.join(save_dir, model_filename)

# Only the state dict (weights) is written, not the whole model object.
torch.save(model.state_dict(), model_path)
print(f"✓ Model saved to: {model_path}")

✓ Model saved to: /content/drive/MyDrive/Colab Notebooks/ML_model_sar_synthetic_data_project/trained_models/resnet152_weights_20250626_001156.pth
