In [7]:
import torch
from torchvision.datasets import CIFAR100, CIFAR10
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle

In [8]:
# Data: train/eval splits with each RGB channel normalized from [0, 1] to [-1, 1].
ROOT_PATH = 'data'

# NOTE(review): CIFAR-100's training split has 50,000 images, so this batch
# size yields only 5 optimizer steps per epoch; a much smaller batch (e.g. 128)
# is usually preferable for convergence and GPU memory.
BATCH_SIZE = 10000

# Set to CIFAR10 to run the 10-class variant instead of keeping a
# commented-out copy of this cell (the model's final layer must then
# output 10 classes).
DATASET_CLS = CIFAR100

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

train_dataset = DATASET_CLS(root=ROOT_PATH, download=True, train=True, transform=transform)
# download=True only fetches when the files are missing or fail verification,
# so the eval split no longer relies on the line above having downloaded it.
eval_dataset = DATASET_CLS(root=ROOT_PATH, download=True, train=False, transform=transform)

# Pinned host memory speeds up host->GPU copies; it is skipped on CPU-only runs.
_PIN_MEMORY = torch.cuda.is_available()

train_data_loader = DataLoader(dataset=train_dataset, num_workers=4, batch_size=BATCH_SIZE,
                               shuffle=True, pin_memory=_PIN_MEMORY)
eval_data_loader = DataLoader(dataset=eval_dataset, num_workers=4, batch_size=BATCH_SIZE,
                              shuffle=False, pin_memory=_PIN_MEMORY)

Files already downloaded and verified


In [9]:
class ConvNN(torch.nn.Module):
    """Small VGG-style CNN classifier for 32x32 RGB images.

    Architecture: conv(3->64)+BN+ReLU, conv(64->128)+BN+ReLU, 2x2 max-pool,
    dropout(0.25), conv(128->256)+BN+ReLU, 2x2 max-pool, dropout(0.25),
    linear(8*8*256 -> 1024)+BN+ReLU, dropout(0.5), linear(1024 -> num_classes).

    Args:
        num_classes: length of the output logit vector. Defaults to 100
            (CIFAR-100); pass 10 for CIFAR-10.

    Inputs must be (batch, 3, 32, 32): the two 2x2 pools shrink 32x32 to 8x8,
    which is the spatial size ``fc1`` is built for.
    """

    def __init__(self, num_classes=100):
        super(ConvNN, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.conv2 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = torch.nn.BatchNorm2d(128)
        self.pool1 = torch.nn.MaxPool2d(2, 2)
        self.dropout1 = torch.nn.Dropout(0.25)
        self.conv3 = torch.nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = torch.nn.BatchNorm2d(256)
        self.pool2 = torch.nn.MaxPool2d(2, 2)
        self.dropout2 = torch.nn.Dropout(0.25)
        # 256 channels at 8x8 after two 2x2 pools of a 32x32 input.
        self.fc1 = torch.nn.Linear(8 * 8 * 256, 1024)
        self.bn4 = torch.nn.BatchNorm1d(1024)
        self.dropout3 = torch.nn.Dropout(0.5)
        self.fc2 = torch.nn.Linear(1024, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)
        x = self.dropout1(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool2(x)
        x = self.dropout2(x)
        # Flatten per sample, keeping the batch axis intact. view(-1, 8*8*256)
        # would silently fold an unexpected spatial size into the batch axis;
        # with flatten, a wrong input size fails loudly at fc1.
        x = torch.flatten(x, 1)
        x = F.relu(self.bn4(self.fc1(x)))
        x = self.dropout3(x)
        x = self.fc2(x)
        return x


In [10]:
# Configuration: compute device, training hyperparameters and RNG seed.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_epochs = 4
learning_rate = 0.001

# Seed torch's global RNG (all devices) so weight init, dropout masks and the
# DataLoader shuffle order repeat across Restart & Run All. Bit-exact GPU
# determinism would additionally require deterministic cuDNN settings.
SEED = 42
torch.manual_seed(SEED)

In [11]:
# Build the network, then move its parameters and buffers onto `device`.
CNN_model = ConvNN()
CNN_model.to(device)  # nn.Module.to() modifies the module in place

In [12]:
# Batches per epoch, the classification loss and an initial optimizer.
n_steps = len(train_data_loader)
loss_fn = torch.nn.CrossEntropyLoss()
# NOTE(review): the training cell rebinds `optimizer` to Adam, so this SGD
# instance does not appear to drive any updates; consider removing it.
optimizer = torch.optim.SGD(params=CNN_model.parameters(), lr=learning_rate)

In [None]:
# Train CNN_model with Adam for `num_epochs` epochs, printing the running loss
# and the per-epoch average, and decaying the LR with StepLR.
optimizer = torch.optim.Adam(CNN_model.parameters(), lr=learning_rate)
# The scheduler must wrap the optimizer that actually steps; one built on a
# different (replaced) optimizer has no effect on training.
# NOTE(review): with step_size=10 and num_epochs=4 the decay never triggers.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Log every 100 batches, but at least once per epoch when an epoch has fewer
# than 100 batches (the original modulo-100 check never fired in that case).
log_every = max(1, min(100, n_steps))

for epoch in range(num_epochs):
    # Ensure training mode (dropout active, BatchNorm using batch statistics);
    # an earlier evaluation may have left the model in eval mode.
    CNN_model.train()
    total_loss = 0.0
    samples_seen = 0
    for i, (images, labels) in enumerate(train_data_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = CNN_model(images)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # loss is a per-batch mean; weight by batch size so a short final
        # batch does not skew the epoch average.
        n_in_batch = labels.size(0)
        total_loss += loss.item() * n_in_batch
        samples_seen += n_in_batch
        if (i + 1) % log_every == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_steps}], Loss: {loss.item():.4f}')
    avg_loss = total_loss / samples_seen if samples_seen else float('nan')
    print(f'End of Epoch {epoch+1}, Average Loss: {avg_loss:.4f}')
    scheduler.step()


End of Epoch 1, Average Loss: 3.8386


  "See more details at "


End of Epoch 2, Average Loss: 3.0894


In [None]:
def evaluate_model(model, data_loader, device=None):
    """Return the top-1 accuracy (in percent) of ``model`` over ``data_loader``.

    Args:
        model: a ``torch.nn.Module`` whose output is expected to be class
            scores of shape (batch, num_classes); the prediction is the
            argmax over dim 1.
        data_loader: iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        device: device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if ``data_loader`` yields no samples.

    The model runs in eval mode without gradient tracking; its previous
    train/eval mode is restored afterwards, so a later training loop is
    unaffected.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    total = 0
    correct = 0
    try:
        with torch.no_grad():
            for images, labels in data_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore the caller's mode even if a batch raises.
        model.train(was_training)

    if total == 0:
        raise ValueError('data_loader yielded no samples; accuracy is undefined')
    return 100 * correct / total

In [None]:
# Accuracy on the eval split (train=False). `epoch` is the loop variable
# left over from the training cell, so that cell must have run first.
validation_accuracy = evaluate_model(model=CNN_model, data_loader=eval_data_loader)
epochs_completed = epoch + 1
print(f'Validation Accuracy after epoch {epochs_completed}: {validation_accuracy:.2f}%')

In [None]:
# Persist the trained model. The full-module pickle is kept for existing
# consumers, but it needs torch.load(..., weights_only=False) and an importable
# ConvNN class to load. The state_dict file is the portable, recommended form:
# rebuild ConvNN() and call load_state_dict(torch.load(path, weights_only=True)).
torch.save(CNN_model, 'CIFAR100_CNN.pth')
torch.save(CNN_model.state_dict(), 'CIFAR100_CNN_state_dict.pth')