Train

In [2]:
import torch
import torchvision
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, models, transforms
import os

# Define transforms for the training and validation sets.
# Training applies random crops/flips for augmentation; validation uses a
# deterministic resize + center crop. Both normalize with the ImageNet channel
# means/stds that match the pretrained ResNet18 loaded below.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Dataset root: must contain one ImageFolder-style sub-directory per phase
# (<data_dir>/train and <data_dir>/val). Set the SATDATA_DIR environment
# variable instead of editing a machine-specific path in the notebook.
# NOTE(review): the default below is a local Windows path, so it will not exist
# on a remote runtime (e.g. Colab). It also already ends in '/train', so the
# phase folders resolve to '.../train/train' and '.../train/val'. Either may
# explain the FileNotFoundError this cell raised -- confirm the dataset layout.
data_dir = os.environ.get("SATDATA_DIR", "C:/Users/calve/Downloads/satdata/data/train")

# Load the datasets with ImageFolder (class labels come from sub-folder names).
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}

# Define the dataloaders. Only training benefits from shuffling; evaluation
# metrics do not depend on order, so keep validation deterministic.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=(x == 'train'), num_workers=4)
              for x in ['train', 'val']}

# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so the new classifier head's initialization (and the
# DataLoader shuffling, which draws from the same RNG) repeat across runs.
# GPU kernels may still introduce some nondeterminism.
SEED = 42
torch.manual_seed(SEED)

# Load ResNet18 with ImageNet weights. torchvision >= 0.13 replaced the
# deprecated `pretrained=True` flag with the `weights=` enum; IMAGENET1K_V1 is
# the checkpoint that `pretrained=True` selects, so both branches load the same weights.
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    model = models.resnet18(pretrained=True)

# Replace the final fully connected layer with one output per class that
# ImageFolder discovered, so the head always matches the data instead of a
# hard-coded count. Parameters of new modules have requires_grad=True by default.
num_classes = len(image_datasets['train'].classes)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)

# Move the model to GPU if available
model = model.to(device)

# Cross-entropy over raw logits (the default reduction is the batch mean).
criterion = nn.CrossEntropyLoss()

# All parameters (pretrained backbone + new head) are fine-tuned.
# NOTE(review): plain SGD with no momentum at lr=0.001 may converge slowly;
# momentum=0.9 is a common choice for this kind of fine-tuning -- tune if needed.
optimizer = optim.SGD(model.parameters(), lr=0.001)

# Define number of epochs
num_epochs = 25

# Train the model. Each epoch runs a training pass (gradients + parameter
# updates) followed by a validation pass (no gradients), and prints the mean
# loss and accuracy over each phase's full dataset.
for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        is_train = phase == 'train'
        # train(True) == train mode; train(False) is exactly model.eval().
        model.train(is_train)

        # Plain Python accumulators: keeps epoch stats as ordinary floats
        # rather than (possibly CUDA) tensors.
        running_loss = 0.0
        running_corrects = 0

        # Iterate over data
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            if is_train:
                # Clear gradients from the previous step before the new forward/backward.
                optimizer.zero_grad()

            # Only record the autograd graph when we will backpropagate.
            with torch.set_grad_enabled(is_train):
                outputs = model(inputs)
                preds = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)

            # Backward and optimize only if in training phase
            if is_train:
                loss.backward()
                optimizer.step()

            # The criterion returns a per-sample mean, so weight it by the
            # batch size before averaging over the whole dataset below.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

        epoch_loss = running_loss / len(image_datasets[phase])
        epoch_acc = running_corrects / len(image_datasets[phase])

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

    print()

print('Training complete')

FileNotFoundError: ignored

Validate

In [None]:
# Load the held-out validation split using the deterministic evaluation
# preprocessing defined for the 'val' phase in the training cell (there is no
# separate test transform in this notebook).
# NOTE(review): training read '<data_dir>/val' while this reads
# '<data_dir>/validation' -- confirm these are the intended, distinct folders.
validation_dataset = datasets.ImageFolder(os.path.join(data_dir, 'validation'), data_transforms['val'])
# Evaluation does not need shuffling; a fixed order keeps runs reproducible.
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size=4, shuffle=False, num_workers=4)

def validate(model, dataloader):
    """Compute and print the classification accuracy of ``model`` on ``dataloader``.

    The model is switched to eval mode and run without gradient tracking.
    Each batch is moved to the device that holds the model's parameters
    (CPU if the model has none), so no global device setting is needed.

    Args:
        model: a ``torch.nn.Module`` whose output has class scores on dim 1.
        dataloader: an iterable of ``(inputs, labels)`` tensor batches.

    Returns:
        float: fraction of samples whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()  # Set model to evaluate mode
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(target_device)
            labels = labels.to(target_device)
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute accuracy")

    accuracy = correct / total_samples
    print('Validation Acc: {:.4f}'.format(accuracy))
    return accuracy

# Report accuracy on the held-out validation split.
print("Validating the model...")
validate(model=model, dataloader=validation_dataloader)

Test

In [None]:
def test(model, dataloader):
    model.eval()  # Set model to evaluate mode
    running_corrects = 0
    total_samples = 0

    # Iterate over data
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        total_samples += labels.size(0)

        # Forward pass
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

        # Statistics
        running_corrects += torch.sum(preds == labels.data)

    accuracy = running_corrects.double() / total_samples

    print('Test Acc: {:.4f}'.format(accuracy))

# Load the held-out test split with the deterministic 'val' preprocessing from
# the training cell. This also provides `test_dataloader` for the metrics cell.
# NOTE(review): assumes the split lives in '<data_dir>/test' -- confirm the
# folder name against the dataset layout.
test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), data_transforms['val'])
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4)

# Call the test function
print("Testing the model...")
test(model, test_dataloader)

Performance

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np

def compute_metrics(model, dataloader):
    """Print a confusion matrix and per-class classification report.

    Runs ``model`` in eval mode, without gradient tracking, over every
    ``(inputs, labels)`` batch from ``dataloader``. It compares arg-max
    predictions against the labels. Inputs are moved to the device holding
    the model's parameters (CPU if the model has none).

    Class names come from ``dataloader.dataset.classes`` when present (as with
    ``ImageFolder``); otherwise the report uses numeric labels. When names are
    known, the full label set ``0..len(classes)-1`` is passed explicitly. This
    keeps the matrix N x N and stops the report from failing when a class never
    occurs in either the labels or the predictions.

    Args:
        model: a ``torch.nn.Module`` whose output has class scores on dim 1.
        dataloader: an iterable of ``(inputs, labels)`` tensor batches.

    Returns:
        tuple: ``(cm, cr)`` -- the confusion-matrix array and the report string.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    all_preds = []
    all_labels = []

    # Iterate over data and save all predictions and true labels (on the CPU)
    with torch.no_grad():
        for inputs, labels in dataloader:
            preds = model(inputs.to(target_device)).argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if not all_labels:
        raise ValueError("dataloader yielded no samples; cannot compute metrics")

    dataset = getattr(dataloader, 'dataset', None)
    class_names = getattr(dataset, 'classes', None)
    label_ids = list(range(len(class_names))) if class_names is not None else None

    # Compute confusion matrix
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    print('Confusion Matrix:')
    print(cm)

    # Compute classification report
    cr = classification_report(all_labels, all_preds, labels=label_ids,
                               target_names=class_names)
    print('Classification Report:')
    print(cr)

    return cm, cr

# Report the confusion matrix and per-class precision/recall/F1 on the test split.
print("Computing metrics on test set...")
compute_metrics(model=model, dataloader=test_dataloader)