In [1]:
from tqdm import tqdm
import torch
import torchvision
import torchvision.transforms as transforms

In [2]:
transform = transforms.Compose(    
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=256,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Define list of "components"
components = {
    'car':   ['tire', 'window', 'door', 'tail light', 'windshield', 'wheel', 'side view mirror', 'bumper', 'headlight'],
    'plane': ['wing', 'cockpit', 'landing gear', 'fuselage', 'tail', 'rudder', 'jet engine'],
    'bird':  ['beak', 'eyes', 'feather', 'wing', 'tail', 'feet'], 
    'cat':   ['paw', 'claw', 'tail', 'whisker', 'eye', 'leg', 'fur', 'nose', 'ear'],
    'deer':  ['antler', 'eye', 'leg', 'tail', 'hoof', 'ear', 'nose', 'mouth'],
    'dog':   ['nose', 'eye', 'mouth', 'whisker', 'tail', 'paw', 'leg', 'ear'],
    'frog':  ['foot', 'leg', 'eye'],
    'horse': ['tail', 'leg', 'eye', 'mouth', 'hoof', 'muzzle', 'belly', 'mane'],
    'ship':  ['hull', 'deck', 'hatch', 'body', 'life buoy'],
    'truck': ['tire', 'window', 'door', 'tail light', 'windshield', 'wheel', 'side view mirror', 'bumper', 'headlight'],
}

In [4]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

<Figure size 640x480 with 1 Axes>

  car plane truck  ship


In [5]:
import torchvision.models as models
vgg19 = models.vgg19()

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

vgg19.to(device)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)

In [7]:
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(vgg19.parameters(), lr=0.001, momentum=0.9)

In [8]:
for epoch in range(5):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in tqdm(enumerate(trainloader, 0)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = vgg19(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = vgg19(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print("Validation accuracy : %f" % (100 * correct / total))

print('Finished Training')

196it [00:18, 10.63it/s]


Validation accuracy : 10.050000


196it [00:18, 10.78it/s]


Validation accuracy : 9.510000


196it [00:18, 10.78it/s]


Validation accuracy : 10.180000


196it [00:18, 10.72it/s]


Validation accuracy : 11.010000


196it [00:18, 10.76it/s]


Validation accuracy : 18.230000
Finished Training


In [9]:
for epoch in range(50):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in tqdm(enumerate(trainloader, 0)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = vgg19(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = vgg19(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print("Validation accuracy : %f" % (100 * correct / total))

print('Finished Training')

196it [00:18, 10.74it/s]


Validation accuracy : 19.740000


196it [00:18, 10.72it/s]


Validation accuracy : 22.120000


196it [00:18, 10.81it/s]


Validation accuracy : 24.720000


196it [00:18, 10.81it/s]


Validation accuracy : 27.210000


196it [00:18, 10.81it/s]


Validation accuracy : 30.820000


196it [00:18, 10.70it/s]


Validation accuracy : 37.000000


196it [00:18, 10.72it/s]


Validation accuracy : 41.510000


196it [00:18, 10.72it/s]


Validation accuracy : 43.040000


196it [00:18, 10.67it/s]


Validation accuracy : 45.940000


196it [00:18, 10.64it/s]


Validation accuracy : 50.240000


196it [00:18, 10.75it/s]


Validation accuracy : 51.900000


196it [00:18, 10.65it/s]


Validation accuracy : 57.320000


196it [00:18, 10.62it/s]


Validation accuracy : 57.130000


196it [00:18, 10.70it/s]


Validation accuracy : 60.490000


196it [00:18, 10.76it/s]


Validation accuracy : 60.100000


196it [00:18, 10.82it/s]


Validation accuracy : 57.080000


196it [00:18, 10.76it/s]


Validation accuracy : 62.730000


196it [00:18, 10.70it/s]


Validation accuracy : 65.020000


196it [00:18, 10.72it/s]


Validation accuracy : 65.150000


196it [00:18, 10.69it/s]


Validation accuracy : 67.910000


196it [00:18, 10.66it/s]


Validation accuracy : 68.710000


196it [00:18, 10.71it/s]


Validation accuracy : 69.620000


196it [00:18, 10.65it/s]


Validation accuracy : 70.220000


196it [00:18, 10.70it/s]


Validation accuracy : 70.240000


196it [00:18, 10.70it/s]


Validation accuracy : 64.470000


196it [00:18, 10.69it/s]


Validation accuracy : 69.720000


196it [00:18, 10.68it/s]


Validation accuracy : 70.470000


196it [00:18, 10.77it/s]


Validation accuracy : 71.080000


196it [00:18, 10.66it/s]


Validation accuracy : 71.370000


196it [00:18, 10.71it/s]


Validation accuracy : 70.180000


196it [00:18, 10.70it/s]


Validation accuracy : 71.240000


196it [00:18, 10.69it/s]


Validation accuracy : 70.120000


196it [00:18, 10.78it/s]


Validation accuracy : 70.180000


196it [00:18, 10.76it/s]


Validation accuracy : 70.360000


196it [00:18, 10.69it/s]


Validation accuracy : 71.440000


196it [00:18, 10.75it/s]


Validation accuracy : 71.440000


196it [00:18, 10.72it/s]


Validation accuracy : 70.250000


196it [00:18, 10.62it/s]


Validation accuracy : 71.460000


196it [00:18, 10.65it/s]


Validation accuracy : 70.300000


196it [00:18, 10.77it/s]


Validation accuracy : 69.860000


196it [00:18, 10.69it/s]


Validation accuracy : 71.210000


196it [00:18, 10.72it/s]


Validation accuracy : 71.240000


196it [00:18, 10.69it/s]


Validation accuracy : 70.610000


196it [00:18, 10.74it/s]


Validation accuracy : 71.320000


196it [00:18, 10.63it/s]


Validation accuracy : 71.340000


196it [00:18, 10.78it/s]


Validation accuracy : 71.660000


196it [00:18, 10.71it/s]


Validation accuracy : 72.030000


196it [00:18, 10.64it/s]


Validation accuracy : 71.970000


196it [00:18, 10.71it/s]


Validation accuracy : 71.860000


196it [00:18, 10.68it/s]


Validation accuracy : 71.620000
Finished Training


In [43]:
torch.save(vgg19.state_dict(), './cifar_net_50.pth')