In [1]:
import torch
import torchvision
from torchvision import datasets
from torchvision import models
from torchvision import transforms
from torch.nn import functional as F
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [2]:
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)
BATCH_SIZE = 100
LR = 0.1

def get_data():
    transform_train = torchvision.transforms.Compose([
        torchvision.transforms.RandomCrop(32, padding=4),
        torchvision.transforms.RandomHorizontalFlip(),
        # Convert from numppy array to tensor, (H, W, C) -> (C, H, W), [0, 255] -> [0.0, 1.0]
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=MEAN, std=STD)
    ])
    transform_test = torchvision.transforms.Compose([
        # Convert from numppy array to tensor, (H, W, C) -> (C, H, W), [0, 255] -> [0.0, 1.0]
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=MEAN, std=STD)
    ])
    
    trainset = torchvision.datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    
    testset = torchvision.datasets.CIFAR10('./data', train=False, download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    
    classes = trainset.classes
    return train_loader, test_loader, classes

In [3]:
class SimpleCNNModel(torch.nn.Module):
    def __init__(self):
        super(SimpleCNNModel, self).__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 5),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2)
        )
        
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, 5),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2)
        )
        
        self.dropout1 = torch.nn.Dropout(0.25)
        self.dropout2 = torch.nn.Dropout(0.5)
        
        self.fc1 = torch.nn.Linear(64*5*5, 600)
        self.fc2 = torch.nn.Linear(600, 256)
        self.fc3 = torch.nn.Linear(256, 128)
        self.fc4 = torch.nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.dropout1(x)
        x = x.view(x.size(0), -1)
        
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x
model = SimpleCNNModel()

In [4]:
# in_chanels = 3
# out_chanels = 6
# kernel_size = 3
# conv1 = Conv(in_chanels, out_chanels, kernel_size)
# weight = conv1.weight.data.numpy()
# random_chanels = np.random.randint(out_chanels)
# plt.imshow(weight[random_chanels, :, :, :])

In [5]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
train_loader, test_loader, classes = get_data()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
epochs = 200
interal = 100
best_acc = 0

for epoch in range(epochs):
    running_loss = 0
    for i, batch in enumerate(train_loader):
        inputs, labels =  batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if (i+1) % interal == 0:
            correct = 0
            total = 0
            for batch in test_loader:
                inputs, labels = batch
                inputs = inputs.to(device)
                labels = labels.to(device)
                with torch.no_grad():
                    outputs = model(inputs)
                    predictions = torch.max(outputs, 1)[1].to(device)
                    correct += (labels == predictions).sum()
                    total += len(labels)
            acc = 100 * correct / total
            print('Epoch: {}\t[{}/{} ({:0.0f}%)]\tLoss:{:.6f}\tAcc {:.6f}'.format(
                epoch + 1,
                i * len(inputs),
                len(train_loader.dataset),
                100 * i / len(train_loader),
                running_loss / interal,
                acc))
            running_loss = 0
            
            if acc > best_acc:
                state = {
                    'model': model.state_dict(),
                    'acc': acc,
                    'epoch': epoch
                }
                torch.save(state, './ckpt.pth')
                print('checkpoint was saved!')
                best_acc = acc
    # adjust lr
    scheduler.step()
print('Training was finished.')

Files already downloaded and verified
Files already downloaded and verified
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!
checkpoint was saved!


In [7]:
model = SimpleCNNModel()
path = './ckpt.pth'
state_dict = torch.load(path)['model']
model.load_state_dict(state_dict)
model.eval()

SimpleCNNModel(
  (layer1): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=1600, out_features=600, bias=True)
  (fc2): Linear(in_features=600, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=10, bias=True)
)

In [11]:
outputs = model(images)
_, predicted = torch.max(outputs, 1)
print('GroundTruth:', ' '.join('%5s' % classes[labels[i]] for i in range(len(classes))))
print('Predicted: ', ' '.join('%5s' % classes[predicted[i]] for i in range(len(classes))))

GroundTruth:  ship   cat  frog  frog airplane automobile automobile automobile  ship automobile
Predicted:    cat  ship  ship airplane  frog  frog automobile  frog   cat automobile


In [9]:
class_correct = [0] * len(classes)
class_total = [0] * len(classes)

In [10]:
with torch.no_grad():
    for batch in test_loader:
        inputs, labels = batch
        outputs = model(inputs)
        _, predicted_idxs = torch.max(outputs, 1)
        same_idxs = (predicted_idxs == labels).squeeze()
        for i in range(len(labels)):
            label = labels[i].item()
            class_correct[label] += int(same_idxs[i])
            class_total[label] += 1
            
for i in range(len(classes)):
    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

Accuracy of airplane : 87 %
Accuracy of automobile : 91 %
Accuracy of  bird : 75 %
Accuracy of   cat : 69 %
Accuracy of  deer : 87 %
Accuracy of   dog : 70 %
Accuracy of  frog : 92 %
Accuracy of horse : 84 %
Accuracy of  ship : 89 %
Accuracy of truck : 92 %
