<a href="https://colab.research.google.com/github/mobarakol/Applied_Deep_Learning/blob/main/LeNet_CIFAR10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Root folder where torchvision keeps the CIFAR-10 download
DATA_DIR = './data/cifar10'

# Fetch both splits up front: training split first, then the test split
datasets.CIFAR10(DATA_DIR, train=True, download=True)
datasets.CIFAR10(DATA_DIR, train=False, download=True)

# Paths handed to CIFAR10Dataset as `root_dir` further down.
# NOTE(review): the visible code never reads from these folders — confirm they are still needed.
train_dir = os.path.join(DATA_DIR, 'cifar-10-batches-py/train')
test_dir = os.path.join(DATA_DIR, 'cifar-10-batches-py/test')

# Human-readable CIFAR-10 category names
# (presumably listed in label-index order — confirm before indexing with a label)
CLASSES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

# Map-style Dataset wrapper around the arrays loaded by torchvision's CIFAR10
class CIFAR10Dataset(Dataset):
    """Serve CIFAR-10 samples as ``(image, label)`` pairs.

    The images and targets are taken from ``datasets.CIFAR10`` rooted at the
    module-level ``DATA_DIR``.

    Args:
        root_dir: Stored on the instance as ``root_dir``; it is not used to
            locate the data (loading always goes through ``DATA_DIR``).
        train: Forwarded to ``datasets.CIFAR10`` to select the split.
        transform: Optional callable applied to the PIL image of every sample.
    """

    def __init__(self, root_dir, train=True, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []  # kept for compatibility; never filled in this class

        # Let torchvision load the split, then keep only its raw arrays
        source = datasets.CIFAR10(root=DATA_DIR, train=train, download=True)
        self.data, self.labels = source.data, source.targets

    def __len__(self):
        """Return the number of images held in ``self.data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The stored array is wrapped in a PIL image first; the transform, when
        one is set, is then applied to that PIL image.
        """
        pixels, label = self.data[idx], self.labels[idx]
        picture = Image.fromarray(pixels)
        if self.transform:
            picture = self.transform(picture)
        return picture, label

# Preprocessing pipeline: tensor conversion, then per-channel normalisation
# with mean 0.5 and std 0.5 on each of the three channels
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Wrap the official training and test splits with the preprocessing above
train_dataset = CIFAR10Dataset(train_dir, train=True, transform=transform)
test_dataset = CIFAR10Dataset(test_dir, train=False, transform=transform)

# Hold out 20% of the training split for validation (80% stays for training);
# `train_dataset` is rebound to the training subset by the split
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    train_dataset,
    [train_size, val_size],
)

# Batch loaders: only the training loader shuffles
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Define LeNet architecture for CIFAR-10 (input channels = 3)
class LeNet(nn.Module):
    """LeNet-5 style CNN for 3-channel 32x32 images with 10 output classes.

    Spatial size through the network for a 32x32 input:
    conv1 (5x5, no padding) 32 -> 28, pool -> 14, conv2 (5x5) 14 -> 10,
    pool -> 5, so the flattened feature vector has 16 * 5 * 5 = 400 entries,
    which is what ``fc1`` expects.

    conv1 previously used ``padding=2`` (the 28x28 variant of LeNet). On
    32x32 inputs that produced 16 * 6 * 6 = 576 features per sample, which
    did not match ``fc1`` and broke the flatten step. Removing the padding
    leaves every parameter shape unchanged.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0)  # 3 input channels for RGB images
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)            # Conv layer 2
        self.fc1 = nn.Linear(16 * 5 * 5, 120)                             # Fully connected layer 1
        self.fc2 = nn.Linear(120, 84)                                     # Fully connected layer 2
        self.fc3 = nn.Linear(84, 10)                                      # Fully connected layer 3 (10 classes)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)                 # Average pooling layer
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)`` for input ``x``.

        Args:
            x: Image batch of shape ``(batch, 3, 32, 32)``.
        """
        x = self.pool(self.relu(self.conv1(x)))  # Conv1 -> ReLU -> Pool
        x = self.pool(self.relu(self.conv2(x)))  # Conv2 -> ReLU -> Pool
        # Flatten everything except the batch dimension, so a wrong input size
        # fails in fc1 with a clear shape error instead of silently changing
        # the batch size.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))               # FC1 -> ReLU
        x = self.relu(self.fc2(x))               # FC2 -> ReLU
        x = self.fc3(x)                          # FC3
        return x

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Network on the chosen device, cross-entropy loss, and SGD with momentum
model = LeNet()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
)

# Function to train one epoch
def train_one_epoch(model, train_loader, criterion, optimizer, device=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; it is switched to training mode.
        train_loader: Iterable yielding ``(images, labels)`` tensor batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if the model has no
            parameters).

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)``. The loss is the mean of
        the per-batch loss values. ``(0.0, 0.0)`` is returned when the loader
        yields no batches.
    """
    if device is None:
        # Follow the model instead of relying on a module-level global
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0  # counted here so loaders without __len__ also work

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        num_batches += 1

    if num_batches == 0:
        # Nothing was processed: avoid dividing by zero
        return 0.0, 0.0

    accuracy = 100 * correct / total if total else 0.0
    return running_loss / num_batches, accuracy

# Function to validate model
def validate(model, val_loader, criterion, device=None):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        val_loader: Iterable yielding ``(images, labels)`` tensor batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if the model has no
            parameters).

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)``. The loss is the mean of
        the per-batch loss values. ``(0.0, 0.0)`` is returned when the loader
        yields no batches.
    """
    if device is None:
        # Follow the model instead of relying on a module-level global
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0  # counted here so loaders without __len__ also work

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            num_batches += 1

    if num_batches == 0:
        # Nothing was processed: avoid dividing by zero
        return 0.0, 0.0

    accuracy = 100 * correct / total if total else 0.0
    return running_loss / num_batches, accuracy

# Training loop: train for num_epochs, checkpoint the weights with the best
# validation accuracy, then reload that checkpoint and score it on the test set
num_epochs = 10
best_val_accuracy = 0.0
best_model_path = './best_lenet_cifar10.pth'
# True once THIS run has written best_model_path. Without it, a run whose
# validation accuracy never exceeds 0.0 (or with num_epochs == 0) would load a
# missing file or a stale checkpoint left over from an earlier run.
checkpoint_saved = False

for epoch in range(num_epochs):
    train_loss, train_accuracy = train_one_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_accuracy = validate(model, val_loader, criterion)

    print(f"Epoch [{epoch + 1}/{num_epochs}]")
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

    # The first completed epoch is always checkpointed; afterwards only
    # strict improvements in validation accuracy replace the saved weights
    if not checkpoint_saved or val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), best_model_path)
        checkpoint_saved = True
        print(f"Best model saved with accuracy: {best_val_accuracy:.2f}%\n")

# Load best model (only if this run actually produced one)
if checkpoint_saved:
    model.load_state_dict(torch.load(best_model_path, weights_only=True))
    print(f"Best model loaded with accuracy: {best_val_accuracy:.2f}%")
else:
    print("No checkpoint was saved in this run; evaluating the current model weights.")

# Evaluate on test set
test_loss, test_accuracy = validate(model, test_loader, criterion)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")
