In [1]:
import torch
import torchvision
from torchvision import datasets
from torch.utils.data import DataLoader, ConcatDataset

# Seed torch's global RNG so random weight initialisation and DataLoader
# shuffling are reproducible run to run (modulo non-deterministic GPU kernels).
SEED = 42
torch.manual_seed(SEED)

# Pick the best available accelerator: CUDA, then Apple MPS, then CPU.
# `torch.backends.mps.is_available()` is the documented MPS availability check;
# `torch.mps.is_available()` is not provided by every torch release.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

In [2]:
# Transfer learning: an ImageNet-pretrained ResNet-50 used as a frozen feature
# extractor, with only a freshly initialised 10-way head left trainable.
model = torchvision.models.resnet50(weights='IMAGENET1K_V1')
model.requires_grad_(False)  # freeze every pretrained parameter in one call

# Replace the ImageNet classifier. The new Linear layer is created after the
# freeze above, so its parameters keep requires_grad=True.
# NOTE(review): freezing parameters does not freeze BatchNorm running statistics;
# they are still updated whenever the model is put in train() mode.
model.fc = torch.nn.Linear(model.fc.in_features, 10)

criterion = torch.nn.CrossEntropyLoss()

# Only the head's parameters are handed to the optimizer.
optim = torch.optim.Adam(model.fc.parameters(), lr=0.001)

In [3]:
# Per-channel normalisation statistics matching the ImageNet-pretrained weights.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Resize to 224x224, convert to a tensor, then normalise with the stats above.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(imagenet_mean, imagenet_std),
])

# Both CIFAR-10 splits (train first, then test) under ./data, downloaded if
# missing, sharing the same preprocessing.
train_dataset, test_dataset = (
    datasets.CIFAR10(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; both use batches of 128 and two workers.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

# Class names, indexed by integer label (used for per-class reporting).
cifar10_classes = train_dataset.classes

In [4]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one training epoch and return ``(avg_loss, accuracy_percent)``.

    Args:
        model: ``torch.nn.Module`` to train; moved to ``device`` and put in
            train mode.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``. Assumed to
            return a batch-mean loss (the ``CrossEntropyLoss`` default), because
            it is re-weighted by batch size to obtain a per-sample average.
        optimizer: Optimizer stepped once per batch.
        device: Device the model and every batch are moved to.

    Returns:
        ``(avg_loss, accuracy)``: sample-weighted mean loss and training accuracy
        in percent, both accumulated while the weights are being updated.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.to(device)
    # NOTE(review): train() also switches any BatchNorm layers to training mode,
    # so their running statistics are updated even if their parameters are frozen.
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Undo the batch-mean so batches of different sizes are weighted fairly.
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute loss/accuracy")

    avg_loss = running_loss / total
    accuracy = 100. * correct / total
    return avg_loss, accuracy

In [5]:
def evaluate(model, test_loader, loss_fn, classes, device):
    """Compute overall and per-class accuracy of ``model`` on ``test_loader``.

    Args:
        model: ``torch.nn.Module`` to evaluate; moved to ``device`` and put in
            eval mode.
        test_loader: Iterable of ``(images, labels)`` batches with integer
            labels in ``[0, len(classes))``.
        loss_fn: Unused; accepted only so existing call sites keep working.
        classes: Sequence of class names indexed by label.
        device: Device the model and every batch are moved to.

    Returns:
        ``(accuracy, class_accuracies)``: overall accuracy in percent, and a dict
        mapping each class name (in ``classes`` order) to its accuracy in percent.
        A class with no test samples maps to ``float("nan")`` instead of raising.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    num_classes = len(classes)
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Call the module (not .forward) so registered hooks still run.
            preds = model(images).argmax(dim=1)
            hits = preds == labels
            correct += hits.sum().item()
            total += labels.size(0)

            # Count per class with one host transfer per batch rather than one
            # .item() sync per sample.
            labels_cpu = labels.cpu()
            hits_cpu = hits.cpu()
            class_total += torch.bincount(labels_cpu, minlength=num_classes)
            class_correct += torch.bincount(labels_cpu[hits_cpu], minlength=num_classes)

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")
    accuracy = correct / total * 100

    class_accuracies = {}
    for classname, n_correct, n_total in zip(classes, class_correct.tolist(), class_total.tolist()):
        class_accuracies[classname] = 100.0 * n_correct / n_total if n_total else float("nan")

    return accuracy, class_accuracies

In [6]:
# Train the new head for three epochs, reporting training accuracy after each
# (the average loss returned by train_epoch is discarded here).
for epoch_num in range(1, 4):
    _, train_acc = train_epoch(model, train_loader, criterion, optim, device)
    print(f"Epoch {epoch_num}.   Train Accuracy:  {train_acc:.2f}%")
    

Epoch 1.   Train Accuracy:  74.34%
Epoch 2.   Train Accuracy:  79.89%
Epoch 3.   Train Accuracy:  80.90%


In [7]:
# Score the fine-tuned model on the held-out test split, then print the overall
# accuracy followed by one indented line per class.
accuracy, class_accuracies = evaluate(
    model, test_loader, criterion, cifar10_classes, device
)
print(
    f"ResNet CIFAR-10 Test Accuracy: {accuracy:.2f}%",
    "Per-class accuracy:",
    *(f"    {name:10s}: {pct:.2f}%" for name, pct in class_accuracies.items()),
    sep="\n",
)

ResNet CIFAR-10 Test Accuracy: 81.20%
Per-class accuracy:
    airplane  : 82.20%
    automobile: 88.00%
    bird      : 68.50%
    cat       : 67.70%
    deer      : 80.50%
    dog       : 73.20%
    frog      : 90.80%
    horse     : 86.40%
    ship      : 89.00%
    truck     : 85.70%
