In [165]:
from torchvision.models import resnet18, ResNet18_Weights

# Load a pretrained model using the new API
weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights)


In [166]:
#Freeze all layers
for param in model.parameters():
    param.requires_grad = False

In [167]:

#Unfreeze the last two layers
for param in model.layer4.parameters():
    param.requires_grad = True

for param in model.fc.parameters():
    param.requires_grad = True


In [168]:

import torch.nn as nn

num_classes = 10

#model.fc = nn.Linear(model.fc.in_features, num_classes)
model.fc = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(0.4),  # Dropout to reduce overfitting
    nn.Linear(256, num_classes)
)

In [169]:
# Train the last layer
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD([
    {"params": model.layer4.parameters()},
    {"params": model.fc.parameters()}
], lr=0.01, momentum=0.9)

In [170]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from torchvision import transforms, datasets

# Define transformations for the training dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),                                                  # Resize images to match ResNet input s ize
    transforms.RandomHorizontalFlip(),                                               # Randomly flip images horizontally
    transforms.RandomRotation(10),                                                  # Randomly rotate images by 10 degrees
    #transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.5),  # Randomly change brightness, contrast, saturation and hue
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),                            # Randomly crop images to 224x224
    transforms.ToTensor(),                                                          # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])     # Normalize using ImageNet stats
])


#Load the full dataset
full_dataset = datasets.ImageFolder(root='../data/animals10', transform=transform)

#Split dataset into training and validation sets
dataset_size = len(full_dataset)
indices = list(range(dataset_size))
train_indices, val_indices = train_test_split(indices, test_size=0.2, random_state=42, shuffle=True)

#create subsets
train_dataset = Subset(full_dataset, train_indices)
val_dataset = Subset(full_dataset, val_indices)

# Create DataLoaders for training and validation
from torch.utils.data import DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)



In [171]:
print(model.fc)  # should show output = number of unique classes
class_names = train_dataset.dataset.classes
print(len(class_names))  # how many classes do you have?


Sequential(
  (0): Linear(in_features=512, out_features=256, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.4, inplace=False)
  (3): Linear(in_features=256, out_features=10, bias=True)
)
10


In [172]:
def check_early_stopping(current_acc, best_acc, patience_counter, patience, model, save_path='../models/best_resnet_model.pth'):
    """
    Check if early stopping criteria are met and save the best model if needed.
    """
    should_stop = False
    
    if current_acc > best_acc:
        best_acc = current_acc
        patience_counter = 0
        # Save weights
        torch.save(model.state_dict(), save_path)
        print("Accuracy improved — saving best model.")
    else:
        patience_counter += 1
        print(f"No improvement in accuracy. Patience: {patience_counter}/{patience}")
        
        if patience_counter >= patience:
            print("Early stopping triggered!")
            should_stop = True
            
    return best_acc, patience_counter, should_stop

In [173]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [174]:
print(full_dataset.classes)


['cane', 'cavallo', 'elefante', 'farfalla', 'gallina', 'gatto', 'mucca', 'pecora', 'ragno', 'scoiattolo']


In [None]:


# Initialize early stopping parameters
best_val_acc = 0.0
patience = 3 
patience_counter = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training loop
num_epochs = 20  # Increased from 5 to give early stopping a chance to be triggered
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    # Validation Accuracy
    val_correct = 0
    val_total = 0
    model.eval()
    with torch.no_grad():
        for val_images, val_labels in val_loader:
            val_images, val_labels = val_images.to(device), val_labels.to(device)
            val_outputs = model(val_images)
            _, val_predicted = torch.max(val_outputs.data, 1)
            val_total += val_labels.size(0)
            val_correct += (val_predicted == val_labels).sum().item()

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track stats
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    val_accuracy = 100 * val_correct / val_total
    print(f"Validation Accuracy: {val_accuracy:.2f}%")
    print(f"Training Accuracy: {accuracy:.2f}%")
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss:.4f}, Accuracy: {accuracy:.2f}%")
    
    # Check early stopping
    best_val_acc, patience_counter, should_stop = check_early_stopping(
        current_acc=val_accuracy,
        best_acc=best_val_acc,
        patience_counter=patience_counter,
        patience=patience,
        model=model
    )
    
    if should_stop:
        print(f"Training stopped early at epoch {epoch+1}")
        break

print("Training completed!")

Validation Accuracy: 12.91%
Training Accuracy: 89.14%
Epoch [1/20], Loss: 222.4611, Accuracy: 89.14%
Accuracy improved — saving best model.


KeyboardInterrupt: 