In [4]:
import statistics
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision.datasets import CelebA
from AlexNet.MyAlexNet import AlexNetMini
from torch.utils.data import DataLoader

In [5]:
standard_transform = torchvision.transforms.Compose([
    torchvision.transforms.PILToTensor(),
    torchvision.transforms.ConvertImageDtype(torch.float),
    torchvision.transforms.Resize(size=223),
    torchvision.transforms.CenterCrop(size=223),
    torchvision.transforms.Lambda(lambda x: x.to('cuda'))
])
target_transform = torchvision.transforms.Lambda(lambda x: x.to('cuda', dtype=torch.float))

In [6]:
train_dataset = CelebA('', download=True, split='train', transform=standard_transform, target_transform=target_transform)
valid_dataset = CelebA('', download=True, split='valid', transform=standard_transform, target_transform=target_transform)
test_dataset = CelebA('', download=True, split='test', transform=standard_transform, target_transform=target_transform)

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

Files already downloaded and verified
Files already downloaded and verified


In [15]:
model = AlexNetMini(40)

In [16]:
def get_error(dataloader, model):
  with torch.no_grad():
    errors = []
    for x, y in dataloader:
      y_h = torch.sigmoid(model.forward(x)).round()
      errors.append(float(torch.nn.MSELoss()(y, y_h)))
    return statistics.mean(errors)

In [17]:
train_batch_sz = 32
test_batch_sz = 512
valid_batch_sz = 512

In [18]:
train_dataloader = DataLoader(train_dataset, train_batch_sz, shuffle=True)
test_dataloader = DataLoader(test_dataset, test_batch_sz)
valid_dataloader = DataLoader(valid_dataset, valid_batch_sz)

In [19]:
def train(model, train_dataloader, valid_dataloader, optimizer, loss_fn, epochs):

    N = len(train_dataloader)
    Nb = max(1, N // 16)

    for epoch in range(epochs):
        print('Epoch', epoch + 1)
        epoch_losses = []
        batches_losses = []
        model.train()

        for bn, (x, y) in enumerate(train_dataloader):

            # reporting the number of batches done
            if (bn + 1) % Nb == 0:
                print('[{:6} | {:6}] loss: {}'.format(bn + 1, N, statistics.mean(batches_losses)))
                batches_losses.clear()

            #
            y_h = torch.sigmoid(model.forward(x))
            loss = loss_fn(y, y_h)

            # tracking the loss
            epoch_losses.append(float(loss))
            batches_losses.append(float(loss))

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print('Epoch loss:', statistics.mean(epoch_losses))
        model.eval()
        print('Validation error:', get_error(valid_dataloader, model), '\n')

In [20]:
optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)

In [21]:
model.to('cuda')

AlexNetMini(
  (features): Sequential(
    (0): Conv2d(3, 48, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(48, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avg_pool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=4608, out_features=1024, bias=Tr

In [22]:
train(model, train_dataloader, valid_dataloader, optimizer, torch.nn.MSELoss(), 5)

Epoch 1
[   317 |   5087] loss: 0.13833643349854252
[   634 |   5087] loss: 0.1365586822670341
[   951 |   5087] loss: 0.13610856824962875
[  1268 |   5087] loss: 0.1362618878013704
[  1585 |   5087] loss: 0.1361418537537006
[  1902 |   5087] loss: 0.136363659606371
[  2219 |   5087] loss: 0.1356579041584433
[  2536 |   5087] loss: 0.1362812923567152
[  2853 |   5087] loss: 0.1358736228397592
[  3170 |   5087] loss: 0.13563604286228445
[  3487 |   5087] loss: 0.13614518811180013
[  3804 |   5087] loss: 0.1361421422110371
[  4121 |   5087] loss: 0.13658327809079587
[  4438 |   5087] loss: 0.13641423095099933
[  4755 |   5087] loss: 0.13578112143729387
[  5072 |   5087] loss: 0.13624600835779113
Epoch loss: 0.13628223234048312
Validation error: 0.192324174902378 

Epoch 2
[   317 |   5087] loss: 0.13717869605538965
[   634 |   5087] loss: 0.13533518249492163
[   951 |   5087] loss: 0.13617112351234778
[  1585 |   5087] loss: 0.1354134334750732
[  1902 |   5087] loss: 0.1361770716375357
[

KeyboardInterrupt: ignored

In [33]:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=.8, weight_decay=1e-4)

In [34]:
train(model, train_dataloader, valid_dataloader, optimizer, torch.nn.MSELoss(), 5)

Epoch 1
[   317 |   5087] loss: 0.13574215335936485
[   634 |   5087] loss: 0.13546369081415965
[   951 |   5087] loss: 0.13677589656426703
[  1268 |   5087] loss: 0.13577705638171747
[  1585 |   5087] loss: 0.1363491402110467
[  1902 |   5087] loss: 0.13567031595609166
[  2219 |   5087] loss: 0.1353825219093061
[  2536 |   5087] loss: 0.1361332825070676
[  2853 |   5087] loss: 0.1357517527763024
[  3170 |   5087] loss: 0.1362831911266038
[  3487 |   5087] loss: 0.13552778670182364
[  3804 |   5087] loss: 0.13585227234318428
[  4121 |   5087] loss: 0.13621894655637562
[  4438 |   5087] loss: 0.13590716514207588
[  4755 |   5087] loss: 0.1357052423271471
[  5072 |   5087] loss: 0.1364837770206319
Epoch loss: 0.13594629509945338
Validation error: 0.192324174902378 

Epoch 2
[   317 |   5087] loss: 0.13609412357305425
[   634 |   5087] loss: 0.13588142540740666
[   951 |   5087] loss: 0.13626435746905932
[  1268 |   5087] loss: 0.13561627050209496
[  1585 |   5087] loss: 0.136410143731708

KeyboardInterrupt: ignored