In [None]:
# Import required libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Step 1: Define transformations
# Each pipeline converts an image to a float tensor, then normalizes every RGB
# channel with mean 0.5 and std 0.5, i.e. x -> (x - 0.5) / 0.5.
# The train and test pipelines are deliberately separate Compose objects so they
# can diverge independently.
# NOTE(review): no augmentation is applied to the training split. The captured log
# shows training accuracy climbing to ~97%, which suggests overfitting. Consider
# adding augmentation here (e.g. random crop / horizontal flip).
transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Step 2: Load CIFAR-10 dataset
# download=True fetches the archive into ./data on first use; later calls reuse it.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

# Step 3: Define DataLoaders
# Only the training loader shuffles; the test order stays fixed so evaluation is
# reproducible.
batch_size = 64
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Step 4: Define the Simplified CNN Classifier
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for square images.

    Architecture:
        - two [Conv3x3 (padding 1) -> ReLU -> MaxPool 2x2 / stride 2] stages;
        - a two-layer MLP head that outputs one logit per class.

    The 3x3 convolutions with padding 1 keep the spatial size. Each pooling
    stage floor-halves it, so an ``input_size x input_size`` image reaches the
    head as a ``64 x (input_size // 4) x (input_size // 4)`` feature map.
    The defaults reproduce the original CIFAR-10 configuration: 3x32x32
    input, 10 classes, and a ``64 * 8 * 8`` flattened width. The layer
    indices inside ``features`` / ``classifier`` are unchanged, so existing
    ``state_dict`` checkpoints still load.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels in the input images.
        input_size: Height and width of the (square) input images. The width
            of the first linear layer is derived from this value.

    Raises:
        ValueError: If ``input_size`` is smaller than 4. Below that, the two
            pooling stages leave no spatial positions to classify.
    """

    def __init__(self, num_classes: int = 10, in_channels: int = 3, input_size: int = 32):
        super().__init__()
        if input_size < 4:
            raise ValueError(f"input_size must be >= 4, got {input_size}")

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Two stride-2 pools: side -> side // 2 -> side // 4.
        # This replaces the hard-coded 8 (valid only for 32x32 input).
        feature_side = input_size // 4
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * feature_side * feature_side, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return unnormalized class logits of shape (N, num_classes).

        ``x`` is expected as (N, in_channels, input_size, input_size). Other
        spatial sizes will not match the first linear layer.
        """
        x = self.features(x)
        x = self.classifier(x)
        return x

# Step 5: Initialize model, loss function, and optimizer
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Module.to() moves the parameters and returns the same module, so this can be chained.
model = SimpleCNN().to(device)

# Cross-entropy on raw logits against integer class labels.
criterion = nn.CrossEntropyLoss()

# Adam with a fixed learning rate of 1e-3 (no scheduler).
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Step 6: Train the Classifier
# Each epoch does one full pass over train_loader and prints:
#   - the per-sample mean loss;
#   - the training accuracy, as a running estimate. It is measured in train()
#     mode while the weights are still being updated, so it is not a clean
#     evaluation.
epochs = 25
for epoch in range(epochs):
    model.train()
    train_loss = 0.0  # sum of per-sample losses over the epoch
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n_batch = labels.size(0)
        # With the default reduction='mean', `loss` is the batch *mean*.
        # Weighting by batch size keeps a short final batch from counting as
        # much as a full one in the epoch average.
        train_loss += loss.item() * n_batch

        # Predictions only feed the accuracy counter. Detach them from
        # autograd instead of using the legacy `.data` attribute.
        predicted = outputs.detach().max(dim=1).indices
        total += n_batch
        correct += (predicted == labels).sum().item()

    # An empty loader would otherwise raise ZeroDivisionError.
    avg_loss = train_loss / total if total else float("nan")
    accuracy = 100 * correct / total if total else float("nan")
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

# Step 7: Evaluate the Model on Test Data
# eval() puts layers such as dropout / batch-norm into inference mode.
# no_grad() skips building the autograd graph, since no backward pass follows.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # The column index of the largest logit in each row is the predicted class.
        predicted = outputs.max(dim=1).indices
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Percentage of test samples classified correctly.
test_accuracy = 100 * correct / total
print(f"Test Accuracy: {test_accuracy:.2f}%")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 48.2MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Epoch [1/25], Loss: 1.3600, Accuracy: 51.43%
Epoch [2/25], Loss: 0.9910, Accuracy: 65.29%
Epoch [3/25], Loss: 0.8398, Accuracy: 70.67%
Epoch [4/25], Loss: 0.7329, Accuracy: 74.24%
Epoch [5/25], Loss: 0.6379, Accuracy: 77.69%
Epoch [6/25], Loss: 0.5555, Accuracy: 80.68%
Epoch [7/25], Loss: 0.4801, Accuracy: 83.20%
Epoch [8/25], Loss: 0.4096, Accuracy: 85.64%
Epoch [9/25], Loss: 0.3431, Accuracy: 88.01%
Epoch [10/25], Loss: 0.2834, Accuracy: 90.09%
Epoch [11/25], Loss: 0.2333, Accuracy: 91.83%
Epoch [12/25], Loss: 0.1870, Accuracy: 93.55%
Epoch [13/25], Loss: 0.1584, Accuracy: 94.54%
Epoch [14/25], Loss: 0.1322, Accuracy: 95.52%
Epoch [15/25], Loss: 0.1122, Accuracy: 96.11%
Epoch [16/25], Loss: 0.1019, Accuracy: 96.36%
Epoch [17/25], Loss: 0.0979, Accuracy: 96.57%
Epoch [18/25], Loss: 0.0774, Accuracy: 97.31%
Epoch [19/25], Loss: 0.0752, Accuracy: 97.37%
Epoch [20/25], Loss: 0.0699, Accuracy: 97.57%
