In [None]:
%matplotlib inline
import torch
import numpy as np
from matplotlib import pyplot as plt

torch.set_printoptions(edgeitems=2)
torch.manual_seed(123)

In [None]:
import torch.nn as nn
conv = nn.Conv2d(3, 16, kernel_size=3)
conv

In [None]:
conv.weight.shape, conv.bias.shape

In [None]:
from torchvision import datasets
from torchvision import transforms
cifar10 = datasets.CIFAR10('data', train=True, download=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize(
                                   (0.4915, 0.4823, 0.4468), (0.2470, 0.2435, 0.2616))]))
cifar10_val = datasets.CIFAR10('data', train=False, download=False,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize(
                                       (0.4915, 0.4823, 0.4468), (0.2470, 0.2435, 0.2616))]))

In [None]:
label_map = {0: 0, 2: 1}
cifar2 = [(img, label_map[label]) for img, label in cifar10 if label in [0, 2]]
cifar2_val = [(img, label_map[label])
              for img, label in cifar10_val if label in [0, 2]]

In [None]:
img, _ = cifar2[0]
plt.imshow(img.permute(1, 2, 0))
plt.show()

In [None]:
output = conv(img.unsqueeze(0))
img.unsqueeze(0).shape, output.shape

In [None]:
plt.imshow(output[0, 0].detach(), cmap='gray')
plt.show()

In [None]:
conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)
output = conv(img.unsqueeze(0))
img.unsqueeze(0).shape, output.shape

In [None]:
with torch.no_grad():
    conv.bias.zero_()

with torch.no_grad():
    conv.weight.fill_(1.0/9.0)

In [None]:
output = conv(img.unsqueeze(0))
plt.imshow(output[0, 0].detach(), cmap='gray')
plt.show()

In [None]:
conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)

with torch.no_grad():
    conv.weight[:] = torch.tensor([[-1.0, 0.0, 1.0],
                                   [-1.0, 0.0, 1.0],
                                   [-1.0, 0.0, 1.0]])
    conv.bias.zero_()

In [None]:
output = conv(img.unsqueeze(0))
plt.imshow(output[0, 0].detach(), cmap='gray')
plt.show()

In [None]:
pool = nn.MaxPool2d(2)
output = pool(img.unsqueeze(0))
img.unsqueeze(0).shape, output.shape

In [43]:
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.Tanh(),
    nn.MaxPool2d(2),
    nn.Conv2d(16, 8, kernel_size=3, padding=1),
    nn.Tanh(),
    nn.MaxPool2d(2),
    nn.Linear(8*8*8, 32),
    nn.Tanh(),
    nn.Linear(32, 2))

In [44]:
numel_list = [p.numel() for p in model.parameters()]
sum(numel_list), numel_list

(18090, [432, 16, 1152, 8, 16384, 32, 64, 2])

In [None]:
model(img.unsqueeze(0))

In [47]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.act1 = nn.Tanh()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
        self.act2 = nn.Tanh()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(8*8*8, 32)
        self.act3 = nn.Tanh()
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        out = self.pool1(self.act1(self(conv1(x))))
        out = self.pool2(self.act2(self(conv2(out))))
        out = out.view(-1, 8*8*8)
        out = self.act3(self.fc1(out))
        out = self.fc2(out)
        return out

In [48]:
model = Net()
numel_list = [p.numel() for p in model.parameters()]
sum(numel_list), numel_list

(18090, [432, 16, 1152, 8, 16384, 32, 64, 2])

In [49]:
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(8*8*8, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        out = F.max_pool2d(torch.tanh(self.conv1(x)), 2)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)
        out = out.view(-1, 8*8*8)
        out = torch.tanh(self.fc1(out))
        out = self.fc2(out)
        return out

In [51]:
model = Net()
model(img.unsqueeze(0))

tensor([[0.0719, 0.1885]], grad_fn=<AddmmBackward>)

In [57]:
import datetime


def training_loop(n_epochs, optimizer, model, loss_fn, train_loader):
    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        for imgs, labels in train_loader:
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_train += loss.item()

        if epoch == 1 or epoch % 10 == 0:
            print('{}  Epoch  {}, Training loss  {}'.format(
                datetime.datetime.now(), epoch, float(loss_train)))

In [58]:
import torch.optim as optim

train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)

model = Net()
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

training_loop(
    n_epochs=100,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader)

2019-06-05 18:12:22.487438  Epoch  1, Training loss  91.47095283865929
2019-06-05 18:13:01.706871  Epoch  10, Training loss  53.071382731199265
2019-06-05 18:13:45.206602  Epoch  20, Training loss  45.90050998330116
2019-06-05 18:14:27.976730  Epoch  30, Training loss  40.727652579545975
2019-06-05 18:15:11.216748  Epoch  40, Training loss  37.49876821786165
2019-06-05 18:15:55.476534  Epoch  50, Training loss  34.61829715967178
2019-06-05 18:16:38.922341  Epoch  60, Training loss  32.738708928227425
2019-06-05 18:17:23.059319  Epoch  70, Training loss  30.268059372901917
2019-06-05 18:18:06.390989  Epoch  80, Training loss  28.27046237140894
2019-06-05 18:18:50.281438  Epoch  90, Training loss  26.071606658399105
2019-06-05 18:19:31.341716  Epoch  100, Training loss  24.26020824164152


In [60]:
train_loader = torch.utils.data.DataLoader(
    cifar2, batch_size=64, shuffle=False)
val_loader = torch.utils.data.DataLoader(
    cifar2_val, batch_size=64, shuffle=False)

for loader in [train_loader, val_loader]:
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in train_loader:
            outputs = model(imgs)
            _, predicted = torch.max(outputs, dim=1)
            total += labels.shape[0]
            correct += int((predicted == labels).sum())
        print("Accuracy:   %f" % (correct/total))

Accuracy:   0.935300
Accuracy:   0.935300


In [65]:
data_path = "models/"
torch.save(model.state_dict(), data_path + "birds_vs_airplanes.pt")

In [76]:
loaded_model = Net()
loaded_model.load_state_dict(torch.load(data_path+"birds_vs_airplanes.pt"))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [77]:
train_loader = torch.utils.data.DataLoader(
    cifar2, batch_size=64, shuffle=False)
val_loader = torch.utils.data.DataLoader(
    cifar2_val, batch_size=64, shuffle=False)

for loader in [train_loader, val_loader]:
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in train_loader:
            outputs = loaded_model(imgs)
            _, predicted = torch.max(outputs, dim=1)
            total += labels.shape[0]
            correct += int((predicted == labels).sum())
        print("Accuracy:   %f" % (correct/total))

Accuracy:   0.935300
Accuracy:   0.935300
