In [None]:
import statistics
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision.datasets import CelebA
from AlexNet.MyAlexNetPretrainer import AlexNetMini, AlexNetPretrainer
from torch.utils.data import DataLoader

In [None]:
# Side length (pixels) of the square images fed to the network.
# NOTE(review): AlexNet variants usually take 224 or 227 px inputs; presumably
# 223 matches AlexNetMini's layer arithmetic — confirm before changing it.
IMAGE_SIZE = 223

standard_transform = torchvision.transforms.Compose(
    [
        # PIL image -> uint8 tensor -> float tensor rescaled to [0, 1]
        torchvision.transforms.PILToTensor(),
        torchvision.transforms.ConvertImageDtype(torch.float),
        # scale the shorter side to IMAGE_SIZE, then keep the central square
        torchvision.transforms.Resize(size=IMAGE_SIZE),
        torchvision.transforms.CenterCrop(size=IMAGE_SIZE),
    ]
)

In [None]:
# The three standard CelebA partitions. root='' resolves relative to the
# working directory, and download=False means the files must already exist there.
train_dataset, valid_dataset, test_dataset = (
    CelebA('', download=False, split=split_name, transform=standard_transform)
    for split_name in ('train', 'valid', 'test')
)

In [None]:
# Rebuild the pretrainer and restore its saved weights.
# map_location='cpu' lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine; load_state_dict copies the values into the module's existing
# parameters, so the pretrainer's own device is unaffected either way.
pretrainer = AlexNetPretrainer()
pretrainer.load_state_dict(torch.load('celeba_pretrainer.pt', map_location='cpu'))

In [None]:
# Presumably one output per CelebA binary attribute (CelebA annotates 40 of them).
NUM_ATTRIBUTES = 40
model = AlexNetMini(NUM_ATTRIBUTES)

In [None]:
# Presumably transfers the pretrained weights held by `pretrainer` into `model`.
# NOTE(review): "appy_weights" looks like a typo for "apply_weights"; it must match
# the method name actually defined in AlexNet/MyAlexNetPretrainer.py — confirm there.
pretrainer.appy_weights(model)

In [None]:
def get_error(dataloader, model, batches_to_test=0):
    """Estimate ``model``'s error on ``dataloader`` (not yet implemented).

    NOTE(review): this cell originally held only the ``def`` line with no body,
    which is a SyntaxError and breaks Restart & Run All. The intended metric is
    not visible here — presumably a per-attribute misclassification rate over
    the first ``batches_to_test`` batches, with 0 meaning "all batches".
    TODO: implement and confirm the metric and the meaning of 0.

    Args:
        dataloader: iterable of ``(x, y)`` batches to evaluate on.
        model: the network to evaluate.
        batches_to_test: number of batches to evaluate (semantics TBD).

    Raises:
        NotImplementedError: always, until the metric is implemented.
    """
    raise NotImplementedError('get_error has not been implemented yet')

In [None]:
def train(model, dataloader, optimizer, loss_fn, epochs):
    """Train ``model`` in place for ``epochs`` passes over ``dataloader``.

    Progress is printed roughly 16 times per epoch (at least once per batch for
    small loaders), followed by the epoch's mean loss.

    Args:
        model: a ``torch.nn.Module`` mapping a batch ``x`` to predictions.
        dataloader: sized iterable yielding ``(x, y)`` batches (e.g. a
            ``DataLoader``); ``len(dataloader)`` drives the progress reports.
        optimizer: optimizer constructed over ``model``'s parameters.
        loss_fn: callable following the PyTorch convention
            ``loss_fn(prediction, target)``.
            NOTE(review): CelebA 'attr' targets are integer tensors; BCE-style
            losses need float targets — confirm against the chosen loss_fn.
        epochs: number of full passes over ``dataloader``.

    Returns:
        list[float]: mean training loss of each epoch (``nan`` for an epoch
        that yielded no batches).
    """
    n_batches = len(dataloader)
    # report ~16 times per epoch; never 0, or the modulo below would fail
    report_every = max(1, n_batches // 16)
    history = []

    # make sure dropout / batch-norm are in training mode, even if an
    # evaluation routine switched the model to eval mode earlier
    model.train()

    for epoch in range(epochs):
        print('Epoch', epoch + 1)
        epoch_losses = []
        window_losses = []

        for bn, (x, y) in enumerate(dataloader):
            # forward pass; calling the module (not .forward) runs its hooks
            y_h = model(x)
            # PyTorch losses take (input, target); the reversed order breaks
            # asymmetric losses such as BCE / BCEWithLogits
            loss = loss_fn(y_h, y)

            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # tracking the loss
            loss_value = loss.item()
            epoch_losses.append(loss_value)
            window_losses.append(loss_value)

            # report only after the current batch is recorded, so the window
            # always holds exactly the last `report_every` losses (never empty)
            if (bn + 1) % report_every == 0:
                print('[{:6} | {:6}] loss: {}'.format(bn + 1, n_batches, statistics.mean(window_losses)))
                window_losses.clear()

        # an empty loader would make statistics.mean raise StatisticsError
        epoch_mean = statistics.mean(epoch_losses) if epoch_losses else float('nan')
        history.append(epoch_mean)
        print('Epoch loss:', epoch_mean, '\n')

    return history