In [9]:
import torch
from torch import nn
from torch import optim
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import copy
from datetime import datetime

plt.ion()

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [6]:
tx = transforms.Compose(
    [transforms.Resize((224, 224)), 
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

BATCH_SIZE = 32

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=tx)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

validate_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=tx)
validate_loader = DataLoader(validate_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


Use a pretrained ResNet model, reset the final layer to number of classes for our dataset.

In [7]:
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, len(classes))
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




Tune weights of the model.

In [10]:
# best_model = copy.deepcopy(model.state_dict())
epochs = 10
for epoch in range(epochs):
  running_loss = 0.0
  for i, data in enumerate(train_loader, 0):
    inputs, labels = data[0].to(device), data[1].to(device)
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
    if i % 100 == 99:  # print every 100 mini-batches
      print(f'{datetime.now().time().replace(microsecond=0)} --- '
            f'Epoch: {epoch + 1}\t'
            f'Batch: {i + 1}\t'
            f'Loss: {running_loss / 100:.3f}')
      running_loss = 0.0
  correct = 0
  total = 0

  with torch.no_grad():
    for data in validate_loader:
      images, labels = data[0].to(device), data[1].to(device)
      outputs = model(images)
      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()
  print(f'{correct} correct predictions out of {total}. Accuracy: {100 * correct / total}')

03:57:54 --- Epoch: 1	Batch: 100	Loss: 1.251
03:58:04 --- Epoch: 1	Batch: 200	Loss: 0.964
03:58:14 --- Epoch: 1	Batch: 300	Loss: 0.820
03:58:24 --- Epoch: 1	Batch: 400	Loss: 0.782
03:58:34 --- Epoch: 1	Batch: 500	Loss: 0.712
03:58:44 --- Epoch: 1	Batch: 600	Loss: 0.670
03:58:54 --- Epoch: 1	Batch: 700	Loss: 0.627
03:59:04 --- Epoch: 1	Batch: 800	Loss: 0.619
03:59:15 --- Epoch: 1	Batch: 900	Loss: 0.605
03:59:25 --- Epoch: 1	Batch: 1000	Loss: 0.569
03:59:35 --- Epoch: 1	Batch: 1100	Loss: 0.546
03:59:46 --- Epoch: 1	Batch: 1200	Loss: 0.517
03:59:56 --- Epoch: 1	Batch: 1300	Loss: 0.548
04:00:06 --- Epoch: 1	Batch: 1400	Loss: 0.522
04:00:17 --- Epoch: 1	Batch: 1500	Loss: 0.446
8278 correct predictions out of 10000. Accuracy: 82.78
04:00:51 --- Epoch: 2	Batch: 100	Loss: 0.402
04:01:02 --- Epoch: 2	Batch: 200	Loss: 0.405
04:01:12 --- Epoch: 2	Batch: 300	Loss: 0.412
04:01:23 --- Epoch: 2	Batch: 400	Loss: 0.412
04:01:33 --- Epoch: 2	Batch: 500	Loss: 0.428
04:01:44 --- Epoch: 2	Batch: 600	Loss: 