In [1]:
import numpy as np
import math
import torch
import torch.nn.functional as F 
import torchvision as tv
import torch.nn as nn
from torchvision import transforms
custom_transform= transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5,0.5,0.5))                                   
])
train_dataset = tv.datasets.CIFAR10(root="./", train=True, transform=custom_transform, download=True) 
test_dataset = tv.datasets.CIFAR10(root="./", train=False, transform=custom_transform, download=True) 
 

 
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=4, num_workers=0,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=4, num_workers=0, shuffle=False)
1
for  img, label in train_loader:
  print(img.shape)
  print(label.shape)
  break

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified
torch.Size([4, 3, 32, 32])
torch.Size([4])


In [7]:

MODELNAME = 'cifar.model' 
EPOCH =10
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv1_2 = nn.Conv2d(32, 32, 3,padding=1)

        self.conv2_1 = nn.Conv2d(32, 64, 3,padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, 3,padding=1)

        self.conv3_1 = nn.Conv2d(64, 128, 3,padding=1)
        self.conv3_2 = nn.Conv2d(128, 128, 3,padding=1)
        self.conv3_3 = nn.Conv2d(128, 128, 3,padding=1)

        self.conv4_1 = nn.Conv2d(128, 256, 3,padding=1)
        self.conv4_2 = nn.Conv2d(256, 256, 3,padding=1)
        self.conv4_3 = nn.Conv2d(256, 256, 3,padding=1)

        self.conv5_1 = nn.Conv2d(256, 256, 3,padding=1)
        self.conv5_2 = nn.Conv2d(256, 256, 3,padding=1)
        self.conv5_3 = nn.Conv2d(256, 256, 3,padding=1)

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(2*2*256, 512, bias= True)
        self.fc2 = nn.Linear(512, 512, bias= True)
        self.fc3 = nn.Linear(512, 10, bias= True)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
    def forward(self, x):
        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_2(x))
        x = self.maxpool(x)
        x = F.relu(self.conv2_1(x))
        x = F.relu(self.conv2_2(x))
        x = self.maxpool(x)
        x = F.relu(self.conv3_1(x))
        x = F.relu(self.conv3_2(x))
        x = F.relu(self.conv3_3(x))
        x = self.maxpool(x)
        x = F.relu(self.conv4_1(x))
        x = F.relu(self.conv4_2(x))
        x = F.relu(self.conv4_3(x))
        x = self.maxpool(x)
        x = F.relu(self.conv5_1(x))
        x = F.relu(self.conv5_2(x))
        x = F.relu(self.conv5_3(x))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, 0.5) #dropout was included to combat overfitting
        x = F.relu(self.fc2(x))
        x = F.dropout(x, 0.5)
        x = self.fc3(x)
        return x

def train():
  model = VGG16().to(DEVICE)
  optimizer = torch.optim.SGD(model.parameters(), lr= 1e-3, momentum= 0.9) 
  for epoch in range(EPOCH):
    loss = 0
    for images, labels in train_loader:
      images = images.to(DEVICE) 
      labels = labels.to(DEVICE) 
      y = model(images)
      batchloss = F.cross_entropy(y,labels) 
      optimizer.zero_grad()
      batchloss.backward()  
      optimizer.step()
      loss = loss + batchloss.item() 
    print("epoch:",epoch,'loss:',loss)
  torch.save(model.state_dict(), MODELNAME) 
def test():
  correct = 0
  total = len(test_loader.dataset) 
  model = VGG16().to(DEVICE)
  model.load_state_dict(torch.load(MODELNAME)) 
  model.eval()
  for images, labels in test_loader:
    images = images.to(DEVICE) 
    labels = labels.to(DEVICE)
    y= model(images)
    pred_labels = y.max(dim=1)[1]
    correct =correct + (pred_labels == labels).sum() 
  print('correct:',correct.item()) 
  print('total:',total) 
  print('accuracy:',correct.item()/float(total))
train() 
test()

epoch: 0 loss: 28764.93603348732
epoch: 1 loss: 24621.17156279087
epoch: 2 loss: 20822.31701269746
epoch: 3 loss: 16863.88964672014
epoch: 4 loss: 13361.86581614986
epoch: 5 loss: 11260.119883395615
epoch: 6 loss: 9777.326023298083
epoch: 7 loss: 8625.140891194354
epoch: 8 loss: 7680.214925247954
epoch: 9 loss: 6896.973772802928
correct: 7677
total: 10000
accuracy: 0.7677
