In [13]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from torchsummary import summary
from torchviz import make_dot
from torch.autograd import Variable
from statistics import mean
import numpy as np

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
input_size = 32
batch_size = 100

transform = transforms.Compose([
 transforms.Pad(4),
 transforms.RandomHorizontalFlip(),
 transforms.RandomCrop(32),
 transforms.ToTensor(),
 transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) # only can do to tensor so keep order
])

train_dataset = datasets.CIFAR10('C:\data/cifar10', train=True, download=True, transform=transform)

train_loader = DataLoader(
    dataset= train_dataset,
    batch_size=batch_size,
    shuffle=True)

Files already downloaded and verified


In [4]:
valid_dataset = datasets.CIFAR10(root='C:\data/',
                                            train=False, 
                                            transform=transforms.ToTensor())

valid_loader = DataLoader(dataset=valid_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)
                                    

In [5]:
class MnistModel(nn.Module):
    def __init__(self):
        super(MnistModel, self).__init__()

        self.relu = nn.ReLU()

        self.conv0 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU()
        )

        self.block11 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(16)
        )

        self.block12 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(16)
        )

        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, stride=2, bias=False)

        self.block21 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(32)
        )

        self.block22 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(32)
        )

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1, stride=2, bias=False)

        self.block31 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(64)
        )

        self.block32 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(64)
        )

        self.avg_pool = nn.AvgPool2d(8)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        out0 = self.conv0(x)
        out1 = self.block11(out0)
        out1 = self.block12(out1)

        res2 = self.conv2(out1)
        out2 = self.block21(out1)
        out2 = self.block22(out2)
        out2 += res2
        out2 = self.relu(out2)

        res3 = self.conv3(out2)
        out3 = self.block31(out2)
        out3 = self.block32(out3)
        out3 += res3
        out3 = self.relu(out3)

        out3 = self.avg_pool(out3)
        out3 = self.flatten(out3)
        out = self.fc(out3)

        return out

In [6]:
model = MnistModel().to(device)

In [7]:
InTensor = Variable(torch.randn(1, 3, 32, 32)).to(device)
make_dot(model(InTensor), params=dict(model.named_parameters())).render("model", format="png")

'model.png'

In [8]:
summary(model, input_size=(3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
              ReLU-3           [-1, 16, 32, 32]               0
            Conv2d-4           [-1, 16, 32, 32]           2,304
       BatchNorm2d-5           [-1, 16, 32, 32]              32
              ReLU-6           [-1, 16, 32, 32]               0
            Conv2d-7           [-1, 16, 32, 32]           2,304
       BatchNorm2d-8           [-1, 16, 32, 32]              32
            Conv2d-9           [-1, 16, 32, 32]           2,304
      BatchNorm2d-10           [-1, 16, 32, 32]              32
             ReLU-11           [-1, 16, 32, 32]               0
           Conv2d-12           [-1, 16, 32, 32]           2,304
      BatchNorm2d-13           [-1, 16, 32, 32]              32
           Conv2d-14           [-1, 32,

In [9]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [14]:
loss_dict = {}
val_loss_dict = {}
train_step = len(train_loader)
val_step = len(valid_loader)
epochs = 10

for i in range(1, epochs + 1):
    loss_list = [] # losses of i'th epoch
    for train_step_idx, (img, label) in enumerate(train_loader):
        img = img.to(device)
        label = label.to(device)

        output = model(img)
        loss = loss_fn(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_list.append(loss.item())

        if ((train_step_idx+1) % 100 == 0):
            print(f"Epoch [{i}/{epochs}] Step [{train_step_idx + 1}/{train_step}] Loss: {loss.item():.4f}")

    loss_dict[i] = loss_list

    with torch.no_grad():
        val_loss_list = []
        for val_step_idx, (val_img, val_label) in enumerate(valid_loader):
            val_img = val_img.to(device)
            val_label = val_label.to(device)

            val_output = model(val_img)
            val_loss = loss_fn(val_output, val_label)

            val_loss_list.append(val_loss.item())

        val_loss_dict[i] = val_loss_list

    print(f"Epoch [{i}] Train Loss: {mean(loss_dict[i]):.4f} Val Loss: {mean(val_loss_dict[i]):.4f}")
    print("========================================================================================")


Epoch [1/10] Step [100/500] Loss: 0.3297
Epoch [1/10] Step [200/500] Loss: 0.4739
Epoch [1/10] Step [300/500] Loss: 0.5288
Epoch [1/10] Step [400/500] Loss: 0.4637
Epoch [1/10] Step [500/500] Loss: 0.2659
Epoch [1] Train Loss: 0.4161 Val Loss: 0.5497
Epoch [2/10] Step [100/500] Loss: 0.3478
Epoch [2/10] Step [200/500] Loss: 0.3011
Epoch [2/10] Step [300/500] Loss: 0.4150
Epoch [2/10] Step [400/500] Loss: 0.3313
Epoch [2/10] Step [500/500] Loss: 0.3554
Epoch [2] Train Loss: 0.4067 Val Loss: 0.5482
Epoch [3/10] Step [100/500] Loss: 0.3674
Epoch [3/10] Step [200/500] Loss: 0.5226
Epoch [3/10] Step [300/500] Loss: 0.3693
Epoch [3/10] Step [400/500] Loss: 0.4384
Epoch [3/10] Step [500/500] Loss: 0.3910
Epoch [3] Train Loss: 0.3973 Val Loss: 0.5470
Epoch [4/10] Step [100/500] Loss: 0.3294
Epoch [4/10] Step [200/500] Loss: 0.3570
Epoch [4/10] Step [300/500] Loss: 0.5070
Epoch [4/10] Step [400/500] Loss: 0.3626
Epoch [4/10] Step [500/500] Loss: 0.4844
Epoch [4] Train Loss: 0.3889 Val Loss: 0.5

In [15]:
torch.save(model.state_dict(), 'resnet.pt')