In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [2]:
# Use the first CUDA GPU when PyTorch reports one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [3]:
# Data Loading
# Convert each image to a tensor, then normalise every one of the three
# channels with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Both splits share the same storage location, download flag and transform;
# only the `train` flag differs.
train_set = torchvision.datasets.CIFAR10(
    root='./CIFAR_data',
    train=True,
    download=True,
    transform=transform,
)
test_set = torchvision.datasets.CIFAR10(
    root='./CIFAR_data',
    train=False,
    download=True,
    transform=transform,
)

# Human-readable label names, indexed by class id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Mini-batch size used by both loaders; only the training loader shuffles.
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./CIFAR_data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./CIFAR_data/cifar-10-python.tar.gz to ./CIFAR_data
Files already downloaded and verified


In [4]:
# CNN model
class LeNet(nn.Module):
    """LeNet-style convolutional classifier for 3-channel 32x32 images.

    Two 5x5 convolutions, each followed by ReLU and 2x2 max pooling, feed three
    fully connected layers. The final layer returns raw scores (logits); no
    softmax is applied here.

    Args:
        num_classes: Number of output scores per sample. Defaults to 10.
    """

    # Shape (channels, height, width) of the feature map that reaches fc1.
    # With 32x32 inputs: 32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5.
    _FEATURE_SHAPE = (16, 5, 5)

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image tensor whose trailing dimensions are (3, 32, 32).

        Returns:
            Tensor of shape (N, num_classes).

        Raises:
            RuntimeError: If the spatial size of ``x`` does not lead to a
                16x5x5 feature map.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        # Without this check a wrongly sized input can be silently re-batched by
        # the flatten below (e.g. 25 images of 36x36 would become 36 rows of 400).
        if tuple(x.shape[-3:]) != self._FEATURE_SHAPE:
            raise RuntimeError(
                "LeNet expects 3x32x32 inputs; got a feature map of shape "
                f"{tuple(x.shape[-3:])} before the fully connected layers"
            )

        # reshape (rather than view) also accepts non-contiguous feature maps.
        x = x.reshape(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [5]:
# Build the network, then move its parameters onto the selected device.
model = LeNet()
model = model.to(device)

# Cross-entropy criterion, minimised with SGD using momentum.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.9,
)

In [6]:
# TRAINING 

# Number of full passes over the training set.
n_epoch = 10

# Put the model in training mode explicitly, so re-running this cell after the
# model was switched to eval mode still trains with the right layer behaviour.
model.train()

for epoch in range(n_epoch):

    # One optimisation step per mini-batch.
    for batch_idx, (x, target) in enumerate(train_loader):

        x, target = x.to(device), target.to(device)

        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()
        out = model(x)
        loss = loss_fn(out, target)
        loss.backward()
        optimizer.step()

        # Report the current mini-batch loss every 100 batches.
        if batch_idx % 100 == 0:
            print('epoch: ', epoch, '  batch_idx: ', batch_idx, loss.item())


epoch:  0   batch_idx:  0 2.3146002292633057
epoch:  0   batch_idx:  100 2.1598916053771973
epoch:  0   batch_idx:  200 2.0165092945098877
epoch:  0   batch_idx:  300 2.1527647972106934
epoch:  0   batch_idx:  400 1.9253326654434204
epoch:  0   batch_idx:  500 2.0609073638916016
epoch:  0   batch_idx:  600 1.7858139276504517
epoch:  0   batch_idx:  700 1.5747385025024414
epoch:  0   batch_idx:  800 1.6107089519500732
epoch:  0   batch_idx:  900 1.6359246969223022
epoch:  0   batch_idx:  1000 1.6037461757659912
epoch:  0   batch_idx:  1100 1.5823832750320435
epoch:  0   batch_idx:  1200 1.4657402038574219
epoch:  0   batch_idx:  1300 1.6023273468017578
epoch:  0   batch_idx:  1400 1.5418870449066162
epoch:  0   batch_idx:  1500 1.5252995491027832
epoch:  1   batch_idx:  0 1.4479082822799683
epoch:  1   batch_idx:  100 1.2772544622421265
epoch:  1   batch_idx:  200 1.4528390169143677
epoch:  1   batch_idx:  300 1.3328099250793457
epoch:  1   batch_idx:  400 1.440908432006836
epoch:  1   

In [7]:
# evaluation 
        
total_cnt = 0
correct_cnt = 0

# Inference mode for mode-dependent layers (dropout, batch norm).
model.eval()

# Gradients are not needed for evaluation; disabling autograd avoids building
# a graph for every batch and so reduces memory use.
with torch.no_grad():
    for x, target in test_loader:
        x, target = x.to(device), target.to(device)
        out = model(x)
        # Predicted class = index of the largest score along the class axis.
        pred_label = out.argmax(dim=1)
        total_cnt += len(x)
        correct_cnt += (pred_label == target).sum().item()

print('=======> Accuracy: ', correct_cnt*1.0/total_cnt)


