In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell runs so edits to local modules are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import torch
import torchvision.transforms as transforms
from livelossplot import PlotLosses
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from tqdm import tqdm

from models.foster import Autoencoder

In [None]:
# Run configuration: everything a reader might want to tune lives here.

# Fixed seed for torch's global RNG so that runs are more repeatable
# (some GPU kernels may still be nondeterministic).
SEED = 42

# Number of images per batch handed to the DataLoaders.
BATCH_SIZE = 512

# Architecture settings, forwarded to Autoencoder as `leaky_relu_negative_slope`,
# `dropout_p` and `z_dim` respectively.
LEAKY_RELU_NEGATIVE_SLOPE = 0.1
DROPOUT_P = 0.25
Z_DIM = 2

# Optimisation settings: Adam step size and number of passes over the training set.
LEARNING_RATE = 0.005
NUM_EPOCHS = 200

torch.manual_seed(SEED)

In [None]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# MNIST splits stored under data/ (downloaded on demand). Images are converted
# to tensors by ToTensor. The official test split (train=False) serves as the
# validation set in this notebook.
train_dataset = MNIST(
    root='data/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
validation_dataset = MNIST(
    root='data/',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

In [None]:
# Batch iterators for each phase, keyed by the phase names the training loop iterates over.
dataloaders = {
    # Training batches are reshuffled every epoch.
    'train': DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True),
    # Validation only measures the loss over the whole split, so shuffling buys
    # nothing; a fixed order keeps evaluation deterministic.
    'validation': DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False),
}

In [None]:
# Build the autoencoder from the configured hyperparameters and place it on the
# selected device.
model = Autoencoder(
    z_dim=Z_DIM,
    leaky_relu_negative_slope=LEAKY_RELU_NEGATIVE_SLOPE,
    dropout_p=DROPOUT_P,
)
model = model.to(device)

# One random batch shaped like the MNIST inputs (N, 1, 28, 28), created on the
# same device as the model so that both calls below see matching devices.
sample_batch = torch.rand(BATCH_SIZE, 1, 28, 28, device=device)

# Smoke-test the forward pass; the output is discarded, so no autograd graph is needed.
with torch.no_grad():
    model(sample_batch)

# Summarize the architecture. NOTE(review): `summary` is defined on the project's
# Autoencoder class; presumably it runs a forward pass on `input_data` — confirm.
model.summary(input_data=sample_batch)

In [None]:
# Adam updates every parameter of the model at the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Reconstruction objective: mean squared error between two tensors.
criterion = torch.nn.MSELoss()

In [None]:
# Train the autoencoder, evaluating on the validation split after every epoch.
# For each phase the dataset-wide mean reconstruction loss is logged to
# livelossplot under 'recon loss' (train) and 'val_recon loss' (validation).
liveloss = PlotLosses()
logs = {}
for epoch in tqdm(range(NUM_EPOCHS)):
    for phase in ['train', 'validation']:
        is_training = phase == 'train'
        # train(True) enables training mode; train(False) is equivalent to eval().
        model.train(is_training)
        running_loss = 0.0
        # The images are their own reconstruction targets, so class labels are ignored.
        for inputs, _ in dataloaders[phase]:
            inputs = inputs.to(device)
            # Only record the autograd graph when gradients will actually be used;
            # this saves memory and compute during validation.
            with torch.set_grad_enabled(is_training):
                outputs = model(inputs)
                # Conventional (prediction, target) argument order.
                loss = criterion(outputs, inputs)
            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Weight the batch loss by the batch size so a smaller final batch
            # contributes proportionally (assumes criterion returns a batch mean).
            # Accumulating the detached tensor avoids a host sync on every step.
            running_loss += loss.detach() * inputs.size(0)
        epoch_loss = running_loss / len(dataloaders[phase].dataset)
        prefix = '' if is_training else 'val_'
        logs[prefix + 'recon loss'] = epoch_loss.item()
    liveloss.update(logs)
    liveloss.send()
# Persist the trained network: the full autoencoder plus its encoder and decoder
# submodules as separate files.
# torch.save does not create missing parent directories, so make sure the target
# exists first; otherwise the save would fail after the whole training run.
os.makedirs('checkpoints', exist_ok=True)
# NOTE(review): saving whole modules pickles the objects, so loading them requires
# the same class definitions to be importable, and such files must only be loaded
# from trusted sources. Kept as-is for compatibility with existing loaders;
# consider saving state_dict() instead.
torch.save(model, 'checkpoints/ae.pt')
torch.save(model.encoder, 'checkpoints/ae_encoder_only.pt')
torch.save(model.decoder, 'checkpoints/ae_decoder_only.pt')