In [None]:
# =============================================================================
# 1. IMPORT LIBRARIES
# =============================================================================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights
from tqdm import tqdm # For a nice progress bar
import matplotlib.pyplot as plt

# Report the library versions in use so results can be tied to an environment.
print(
    f"PyTorch Version: {torch.__version__}",
    f"Torchvision Version: {torchvision.__version__}",
    sep="\n",
)

In [None]:
# =============================================================================
# 2. SETUP AND CONFIGURATION
# =============================================================================
# Hyperparameters
NUM_EPOCHS = 20
BATCH_SIZE = 64
LEARNING_RATE = 0.001
VALIDATION_SPLIT = 0.2 # 20% of data for validation
SEED = 42 # Fixed seed so repeated runs are comparable

# Seed PyTorch's global random number generator. This drives everything in the
# notebook that draws from it (e.g. weight initialisation of new layers and
# DataLoader shuffling), which reduces run-to-run variation after a kernel
# restart. Some GPU kernels may still be non-deterministic.
torch.manual_seed(SEED)

# Set the device to use for training: the CUDA GPU when one is available,
# otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# =============================================================================
# 3. DATA LOADING AND PREPROCESSING
# =============================================================================
# Define the preprocessing applied to every image.
# For transfer learning we reuse the normalization statistics of the ImageNet
# dataset, on which the ResNet model was pre-trained.
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)), # Upscale to the 224x224 resolution used for ImageNet pre-training
    transforms.ToTensor(),         # Convert images to PyTorch Tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], # ImageNet normalization mean
                         std=[0.229, 0.224, 0.225])   # ImageNet normalization std
])

# Download (if necessary) and load the EuroSAT dataset.
# The files are stored under the './Data' directory, relative to the working directory.
full_dataset = torchvision.datasets.EuroSAT(
    root='./Data',
    download=True,
    transform=data_transforms
)

# Get the class names
class_names = full_dataset.classes
print(f"Dataset has {len(class_names)} classes: {class_names}")

# Split the dataset into training and validation sets
num_data = len(full_dataset)
num_val = int(VALIDATION_SPLIT * num_data)
num_train = num_data - num_val

# Use a dedicated fixed-seed generator for the split. Without it the split
# depends on the global RNG state, so a re-run (or re-running this cell) would
# move images between the training and validation sets, and a previously saved
# model could no longer be evaluated on images it has never seen.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_dataset,
    [num_train, num_val],
    generator=split_generator
)
print(f"Number of training images: {len(train_dataset)}")
print(f"Number of validation images: {len(val_dataset)}")

# Create DataLoaders to handle batching and shuffling
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True, # Shuffle training data to improve model generalization
    num_workers=2 # Use multiple subprocesses to load data
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False, # No need to shuffle validation data
    num_workers=2
)

In [None]:
# =============================================================================
# 4. MODEL DEFINITION (TRANSFER LEARNING)
# =============================================================================
# Start from a ResNet18 initialised with the ImageNet (IMAGENET1K_V1) weights.
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Swap the classification head: keep the input width of the existing final
# layer, but emit one logit per class of our dataset instead of the original
# ImageNet outputs.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=len(class_names))

# Place all parameters on the configured device (GPU or CPU).
model = model.to(device)

In [None]:
# =============================================================================
# 5. LOSS FUNCTION AND OPTIMIZER
# =============================================================================
# Multi-class classification loss; CrossEntropyLoss takes raw logits and
# applies LogSoftmax followed by NLLLoss internally.
criterion = nn.CrossEntropyLoss()

# Adam optimizer over every parameter of the model, using the configured
# learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
# =============================================================================
# 6. TRAINING AND VALIDATION LOOP
# =============================================================================
# To store metrics for plotting
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}


def _run_epoch(loader, phase, epoch, train):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer``, ``device``
    and ``NUM_EPOCHS``.

    Args:
        loader: Iterable yielding ``(inputs, labels)`` batches.
        phase: Label shown in the progress bar, e.g. ``"Training"``.
        epoch: Zero-based epoch index (displayed one-based).
        train: If True, the model is put in training mode and the optimizer
            updates the weights after every batch. If False, the model is put
            in evaluation mode and gradient tracking is disabled.

    Returns:
        A tuple ``(mean_loss, accuracy)`` where ``mean_loss`` is the per-sample
        average loss and ``accuracy`` is the fraction of correct predictions.
        Both are NaN when the loader yields no samples.
    """
    model.train(train)  # train(False) is the same as eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [{phase}]")
    # Gradients are only needed when the weights are being updated; disabling
    # them otherwise saves memory and computation.
    with torch.set_grad_enabled(train):
        for inputs, labels in pbar:
            # Move inputs and labels to the device
            inputs, labels = inputs.to(device), labels.to(device)

            if train:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if train:
                loss.backward()
                optimizer.step()

            # Read the scalar loss once per batch; each .item() call forces a
            # host/device synchronisation.
            batch_loss = loss.item()
            batch_size = labels.size(0)

            # The criterion returns the batch mean, so weight it by the batch
            # size to obtain a correct per-sample average over the epoch
            # (the last batch may be smaller).
            loss_sum += batch_loss * batch_size
            num_seen += batch_size
            predicted = outputs.detach().argmax(dim=1)
            num_correct += (predicted == labels).sum().item()

            pbar.set_postfix({'loss': batch_loss})

    if num_seen == 0:
        # Empty loader (e.g. a validation split of 0): report NaN instead of
        # failing with a ZeroDivisionError.
        return float('nan'), float('nan')
    return loss_sum / num_seen, num_correct / num_seen


for epoch in range(NUM_EPOCHS):
    # --- Training Phase ---
    epoch_train_loss, epoch_train_acc = _run_epoch(train_loader, "Training", epoch, train=True)
    history['train_loss'].append(epoch_train_loss)
    history['train_acc'].append(epoch_train_acc)

    # --- Validation Phase ---
    epoch_val_loss, epoch_val_acc = _run_epoch(val_loader, "Validation", epoch, train=False)
    history['val_loss'].append(epoch_val_loss)
    history['val_acc'].append(epoch_val_acc)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} -> "
          f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f} | "
          f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

print("Finished Training!")

In [None]:
# =============================================================================
# 7. SAVE THE MODEL
# =============================================================================
# Persist the trained parameters: only the model's state dictionary is written
# to disk, not the model object itself.
trained_weights = model.state_dict()
torch.save(trained_weights, 'resnet18_eurosat.pth')
print("Model saved to resnet18_eurosat.pth")


In [None]:
# =============================================================================
# 8. VISUALIZE RESULTS
# =============================================================================
# Draw the per-epoch curves recorded in `history`: accuracy on the left panel,
# loss on the right panel, each with a training and a validation line.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# (axis, training-series key, validation-series key, panel title, y-axis label)
panels = (
    (ax1, 'train_acc', 'val_acc', 'Model Accuracy', 'Accuracy'),
    (ax2, 'train_loss', 'val_loss', 'Model Loss', 'Loss'),
)

for ax, train_key, val_key, title, ylabel in panels:
    ax.plot(history[train_key], label='Train')
    ax.plot(history[val_key], label='Validation')
    ax.set(title=title, ylabel=ylabel, xlabel='Epoch')
    ax.legend(loc='upper left')

plt.show()
