# Author: Umang Srivastav

Train a ResNet-18 (torchvision, no pretrained weights) from scratch on CIFAR-100, reporting test accuracy after every epoch.

In [None]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.models as models

# Per-channel normalization constants passed to Normalize for both splits.
CIFAR100_MEAN = [0.5071, 0.4865, 0.4409]
CIFAR100_STD = [0.2673, 0.2564, 0.2762]

# Training pipeline: random 32x32 crop (after 4px padding) and random horizontal
# flip for augmentation, then convert to a tensor and normalize.
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

# Evaluation pipeline: no random augmentation. Reusing the training transform here
# would score the model on randomly cropped/flipped test images, making the
# reported test accuracy noisy and not comparable across runs.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)

# Shuffle only the training data; evaluation order does not matter.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ResNet-18 with randomly initialized weights and a 100-way output layer, placed on
# the selected device (nn.Module.to returns the module itself).
# NOTE(review): torchvision's resnet18 uses an ImageNet-style stem (7x7 stride-2 conv
# followed by max-pool); for 32x32 inputs a 3x3 stride-1 stem without max-pool is a
# common adaptation — consider it if accuracy matters (changes the checkpoint shape).
model = models.resnet18(pretrained=False, num_classes=100).to(device)

# Standard multi-class loss over the model's logits.
criterion = torch.nn.CrossEntropyLoss()

# Momentum SGD with L2 weight decay.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
)

num_epochs = 100

# Anneal the learning rate from its initial value toward 0 over the run; with a
# constant lr=0.1 SGD keeps oscillating and test accuracy plateaus early.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    # Training mode: BatchNorm normalizes with batch statistics and updates its
    # running averages. Must be set every epoch because evaluation switches it off.
    model.train()
    for images, labels in train_loader:
        # Use the same device the model was moved to, instead of a separate
        # CUDA check, so data and model placement can never disagree.
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # Evaluation mode: BatchNorm uses its running statistics and does not update
    # them, so test batches neither influence the model nor the reported accuracy.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Predicted class = index of the largest logit per sample.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Epoch [{epoch + 1}/{num_epochs}], Test Accuracy: {accuracy:.2f}%")


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:13<00:00, 12855908.08it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified




Epoch [1/100], Test Accuracy: 14.13%
Epoch [2/100], Test Accuracy: 20.08%
Epoch [3/100], Test Accuracy: 24.91%
Epoch [4/100], Test Accuracy: 28.42%
Epoch [5/100], Test Accuracy: 29.74%
Epoch [6/100], Test Accuracy: 33.80%
Epoch [7/100], Test Accuracy: 36.02%
Epoch [8/100], Test Accuracy: 36.77%
Epoch [9/100], Test Accuracy: 39.40%
Epoch [10/100], Test Accuracy: 40.20%
Epoch [11/100], Test Accuracy: 40.85%
Epoch [12/100], Test Accuracy: 41.95%
Epoch [13/100], Test Accuracy: 42.18%
Epoch [14/100], Test Accuracy: 43.57%
Epoch [15/100], Test Accuracy: 43.94%
Epoch [16/100], Test Accuracy: 44.28%
Epoch [17/100], Test Accuracy: 45.73%
Epoch [18/100], Test Accuracy: 44.61%
Epoch [19/100], Test Accuracy: 46.11%
Epoch [20/100], Test Accuracy: 46.22%
Epoch [21/100], Test Accuracy: 46.00%
Epoch [22/100], Test Accuracy: 46.26%
Epoch [23/100], Test Accuracy: 46.14%
Epoch [24/100], Test Accuracy: 46.70%
Epoch [25/100], Test Accuracy: 46.97%
Epoch [26/100], Test Accuracy: 47.10%
Epoch [27/100], Test 

In [None]:
# Persist the trained parameters and buffers (the state_dict), not the module object.
checkpoint_path = 'model18.pt'
torch.save(model.state_dict(), checkpoint_path)
