In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import os
import matplotlib.pyplot as plt

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

batch_size = 64
trainset = torchvision.datasets.CIFAR10(root= os.getcwd(), train = True,
                                        download = True, transform = transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size = batch_size,
                                          shuffle = True, num_workers = 2)

testset = torchvision.datasets.CIFAR10(root= os.getcwd(), train = False,
                                       download = True, transform = transform)
testloader = torch.utils.data.DataLoader(testset, batch_size = batch_size, shuffle=False, num_workers = 2)


classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [2]:
import torch.nn as nn
import torch.nn.functional as F

def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)

In [3]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        base = torchvision.models.resnet18(pretrained=True)
        self.base = nn.Sequential(*list(base.children())[:-1])
        in_features = base.fc.in_features
        self.drop = nn.Dropout()
        self.final = nn.Linear(in_features,10)

    def forward(self,x):
        x = self.base(x)
        x = self.drop(x.view(-1,self.final.in_features))
        return self.final(x)


In [None]:
import torch.optim as optim
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = Model()
criterion = nn.CrossEntropyLoss()
param_groups = [
    {'params':net.base.parameters(),'lr':.0001},
    {'params':net.final.parameters(),'lr':.001}
]
trainer = optim.Adam(param_groups)
net.to(device)
criterion.to(device)

In [49]:
for epoch in range(10):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs = inputs.cuda()
        labels = labels.cuda()
        trainer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        trainer.step()
        running_loss += loss.item()
        if i % 100 == 99:    # print every 100 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')

[1,   100] loss: 0.902
[1,   200] loss: 0.401
[1,   300] loss: 0.304
[1,   400] loss: 0.296
[1,   500] loss: 0.282
[1,   600] loss: 0.258
[1,   700] loss: 0.251
[2,   100] loss: 0.122
[2,   200] loss: 0.129
[2,   300] loss: 0.130
[2,   400] loss: 0.124
[2,   500] loss: 0.132
[2,   600] loss: 0.130
[2,   700] loss: 0.166
[3,   100] loss: 0.064
[3,   200] loss: 0.059
[3,   300] loss: 0.065
[3,   400] loss: 0.083
[3,   500] loss: 0.075
[3,   600] loss: 0.073
[3,   700] loss: 0.081
[4,   100] loss: 0.047
[4,   200] loss: 0.042
[4,   300] loss: 0.043
[4,   400] loss: 0.052
[4,   500] loss: 0.056
[4,   600] loss: 0.064
[4,   700] loss: 0.064
[5,   100] loss: 0.044
[5,   200] loss: 0.044
[5,   300] loss: 0.042
[5,   400] loss: 0.047
[5,   500] loss: 0.066
[5,   600] loss: 0.060
[5,   700] loss: 0.045
[6,   100] loss: 0.034
[6,   200] loss: 0.025
[6,   300] loss: 0.034
[6,   400] loss: 0.034
[6,   500] loss: 0.037
[6,   600] loss: 0.040
[6,   700] loss: 0.044
[7,   100] loss: 0.034
[7,   200] 

In [5]:
net.load_state_dict(torch.load('model_state_dict.pth'))

<All keys matched successfully>

In [7]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images = images.cuda()
        labels = labels.cuda()
        outputs = net(images)
        predicted = outputs.argmax(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')

Accuracy of the network on the 10000 test images: 93.46 %
