In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.chdir("../")

In [3]:
import torch
import numpy as np
import torchvision
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

from propinf.training.models import AlexnetCNN, SimpleCNN
from propinf.training.training_utils import fit

In [4]:
device = "cuda"
batch_size = 256

## Data

In [5]:
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [6]:
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=8
)

testset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=8
)

classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)

# # get some random training images
# dataiter = iter(trainloader)
# images, labels = next(dataiter)

# # show images
# imshow(torchvision.utils.make_grid(images))
# # print labels
# print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

Files already downloaded and verified
Files already downloaded and verified


In [7]:
data_loaders = {"train": trainloader, "val": testloader}

## Model

In [8]:
def train(model, dataloaders, criterion, optimizer, num_epochs=25, device="cuda"):
    # Initialize variables to monitor training and validation loss
    train_loss = []
    val_loss = []
    for epoch in range(num_epochs):
        # Initialize variables to monitor training accuracy
        train_acc = 0.0
        val_acc = 0.0
        # Set the model to train mode
        model.train()
        for i, (inputs, labels) in enumerate(dataloaders["train"]):
            # Move input and label tensors to the GPU
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Zero the parameter gradients
            optimizer.zero_grad()
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            # Update training accuracy
            _, preds = torch.max(outputs, 1)
            train_acc += torch.sum(preds == labels.data)
        # Print training statistics for the epoch
        train_loss.append(loss)
        print(
            "Epoch: {}/{}, Training Loss: {:.4f}, Training Accuracy: {:.4f}".format(
                epoch + 1,
                num_epochs,
                loss.item(),
                train_acc / len(dataloaders["train"].dataset),
            )
        )
        # Set the model to evaluate mode
        model.eval()
        for i, (inputs, labels) in enumerate(dataloaders["val"]):
            # Move input and label tensors to the GPU
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            # Update validation accuracy
            _, preds = torch.max(outputs, 1)
            val_acc += torch.sum(preds == labels.data)
        val_loss.append(loss)
        # Print validation statistics
        print(
            "\t\tValidation Loss: {:.4f}, Validation Accuracy: {:.4f}".format(
                # epoch + 1,
                # num_epochs,
                loss.item(),
                val_acc / len(dataloaders["val"].dataset),
            )
        )
    return train_loss, val_loss


def evaluate_accuracy(model, dataloader, device):
    model.eval()  # set the model to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # temporarily set all the requires_grad flag to false
        for data in dataloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return accuracy


In [9]:
# model = AlexnetCNN()
model = SimpleCNN()
model = model.train()
model = model.to(device)

In [10]:
optimizer = optim.Adam(model.parameters(), lr = 0.001, weight_decay=5e-4,)
criterion = nn.CrossEntropyLoss()


In [11]:
%%time
train_loss, val_loss = train(
    model=model,
    dataloaders=data_loaders,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=35,
    device=device,
)

Epoch: 1/35, Training Loss: 1.4088, Training Accuracy: 0.4254
		Validation Loss: 1.2181, Validation Accuracy: 0.5386
Epoch: 2/35, Training Loss: 0.9450, Training Accuracy: 0.5652
		Validation Loss: 0.8416, Validation Accuracy: 0.6252
Epoch: 3/35, Training Loss: 0.8201, Training Accuracy: 0.6395
		Validation Loss: 0.8000, Validation Accuracy: 0.6685
Epoch: 4/35, Training Loss: 0.7803, Training Accuracy: 0.6833
		Validation Loss: 0.8747, Validation Accuracy: 0.6895
Epoch: 5/35, Training Loss: 0.8084, Training Accuracy: 0.7141
		Validation Loss: 0.5293, Validation Accuracy: 0.7210
Epoch: 6/35, Training Loss: 1.0032, Training Accuracy: 0.7392
		Validation Loss: 0.4488, Validation Accuracy: 0.7381
Epoch: 7/35, Training Loss: 0.6486, Training Accuracy: 0.7561
		Validation Loss: 0.6480, Validation Accuracy: 0.7568
Epoch: 8/35, Training Loss: 0.5342, Training Accuracy: 0.7738
		Validation Loss: 0.4679, Validation Accuracy: 0.7661
Epoch: 9/35, Training Loss: 0.5796, Training Accuracy: 0.7864
		

In [12]:
model_acc = evaluate_accuracy(model, data_loaders["val"], device)
print(f"Model accuracy: {model_acc}")


Model accuracy: 0.8052


In [13]:
model

SimpleCNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=2048, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=10, bias=True)
  (dropout): Dropout(p=0.25, inplace=False)
)