In [18]:
import statistics
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision.datasets import CelebA
from AlexNet.MyAlexNetPretrainer import AlexNetMini, AlexNetPretrainer
from torch.utils.data import DataLoader

In [19]:
standard_transform = torchvision.transforms.Compose([
    torchvision.transforms.PILToTensor(),
    torchvision.transforms.ConvertImageDtype(torch.float),
    torchvision.transforms.Resize(size=223),
    torchvision.transforms.CenterCrop(size=223),
    torchvision.transforms.Lambda(lambda x: x.to('cuda'))
])
target_transform = torchvision.transforms.Lambda(lambda x: x.to('cuda', dtype=torch.float))

In [20]:
train_dataset = CelebA('', download=True, split='train', transform=standard_transform, target_transform=target_transform)
valid_dataset = CelebA('', download=True, split='valid', transform=standard_transform, target_transform=target_transform)
test_dataset = CelebA('', download=True, split='test', transform=standard_transform, target_transform=target_transform)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [21]:
pretrainer = AlexNetPretrainer()
pretrainer.load_state_dict(torch.load('celeba_pretrainer.pt'))

<All keys matched successfully>

In [22]:
model = AlexNetMini(40)

In [23]:
pretrainer.appy_weights(model)

In [24]:
def get_error(dataloader, model):
  with torch.no_grad():
    errors = []
    for x, y in dataloader:
      y_h = torch.sigmoid(model.forward(x)).round()
      errors.append(float(torch.nn.MSELoss()(y, y_h)))
    return statistics.mean(errors)

In [25]:
train_batch_sz = 32
test_batch_sz = 512
valid_batch_sz = 512

In [26]:
train_dataloader = DataLoader(train_dataset, train_batch_sz, shuffle=True)
test_dataloader = DataLoader(test_dataset, test_batch_sz)
valid_dataloader = DataLoader(valid_dataset, valid_batch_sz)

In [27]:
def train(model, train_dataloader, valid_dataloader, optimizer, loss_fn, epochs):

    N = len(train_dataloader)
    Nb = max(1, N // 16)

    for epoch in range(epochs):
        print('Epoch', epoch + 1)
        epoch_losses = []
        batches_losses = []
        model.train()

        for bn, (x, y) in enumerate(train_dataloader):

            # reporting the number of batches done
            if (bn + 1) % Nb == 0:
                print('[{:6} | {:6}] loss: {}'.format(bn + 1, N, statistics.mean(batches_losses)))
                batches_losses.clear()

            #
            y_h = torch.sigmoid(model.forward(x))
            loss = loss_fn(y, y_h)

            # tracking the loss
            epoch_losses.append(float(loss))
            batches_losses.append(float(loss))

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch loss:', statistics.mean(epoch_losses))
        model.eval()
        print('Validation error:', get_error(valid_dataloader, model), '\n')

In [28]:
# optimizer hyper-parameters
lr = 1e-1
momentum = .8
wd = 1e-5

In [29]:
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)

In [30]:
model.to('cuda')

AlexNetMini(
  (features): Sequential(
    (0): Conv2d(3, 48, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(48, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avg_pool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=4608, out_features=1024, bias=Tr

In [31]:
train(model, train_dataloader, valid_dataloader, optimizer, torch.nn.MSELoss(), 5)

Epoch 1
[   317 |   5087] loss: 0.1381865135287937
[   634 |   5087] loss: 0.11648092444762823
[   951 |   5087] loss: 0.105462690797702
[  1268 |   5087] loss: 0.10146822371411399
[  1585 |   5087] loss: 0.09817663681243873
[  1902 |   5087] loss: 0.09549906041343882
[  2219 |   5087] loss: 0.09451349761207774
[  2536 |   5087] loss: 0.09314461494281841
[  2853 |   5087] loss: 0.0918995309092269
[  3170 |   5087] loss: 0.09137140402470477
[  3487 |   5087] loss: 0.09045808574284665
[  3804 |   5087] loss: 0.08997996788389676
[  4121 |   5087] loss: 0.08878864533611652
[  4438 |   5087] loss: 0.08853852593748351
[  4755 |   5087] loss: 0.08715780361405683
[  5072 |   5087] loss: 0.08736318382743029
Epoch loss: 0.09736916726608513
Validation error: 0.11628472499358349 

Epoch 2
[   317 |   5087] loss: 0.08673739136208462
[   634 |   5087] loss: 0.08627156189375496
[   951 |   5087] loss: 0.08627654800091632
[  1268 |   5087] loss: 0.08546546242590582
[  1585 |   5087] loss: 0.0856300408

In [33]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-5)

In [34]:
train(model, train_dataloader, valid_dataloader, optimizer, torch.nn.MSELoss(), 5)

Epoch 1
[   317 |   5087] loss: 0.07513208809909941
[   634 |   5087] loss: 0.0737480607929666
[   951 |   5087] loss: 0.07408619375917061
[  1268 |   5087] loss: 0.07395003552692546
[  1585 |   5087] loss: 0.07350154829796181
[  1902 |   5087] loss: 0.07351526303058167
[  2219 |   5087] loss: 0.07336399976857454
[  2536 |   5087] loss: 0.0733306203096259
[  2853 |   5087] loss: 0.07278828137210491
[  3170 |   5087] loss: 0.07291952159831576
[  3487 |   5087] loss: 0.07281582410027176
[  3804 |   5087] loss: 0.07286069204895654
[  4121 |   5087] loss: 0.07313594407745716
[  4438 |   5087] loss: 0.07244039216440183
[  4755 |   5087] loss: 0.07284504944922796
[  5072 |   5087] loss: 0.07238206943079876
Epoch loss: 0.07330573291386161
Validation error: 0.09611428318879543 

Epoch 2
[   317 |   5087] loss: 0.07233678343222488
[   634 |   5087] loss: 0.07194787820901029
[   951 |   5087] loss: 0.07250631285672685
[  1268 |   5087] loss: 0.07208871390165214
[  1585 |   5087] loss: 0.07199596