In [None]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from pet_loader import PetLoader
import classifier as convnet
import torch
from pet_dataset import PetDataset, ToTensorGray, ToTensor
from torch.utils.data import DataLoader
import importlib
import os

## Classifier

This section instantiates and trains (if necessary) a classifier. The classifier will be used to assign a classification to the images generated by the autoencoder.

### Load the data

In [None]:
# Grayscale training images resized to 64x64. `model` starts out as None so the
# model cell below can tell a first run from a re-run.
shape = (64, 64)
l = PetDataset(
    root="afhq/train",
    transform=ToTensorGray(),
    shape=shape,
)
loader = DataLoader(l, batch_size=32, shuffle=False, num_workers=0)
model = None

### Select the best available device

If the machine this runs on has CUDA support (i.e. a reasonably recent NVIDIA GPU with a working driver), the `cuda:0` device is selected. Otherwise this falls back to the CPU.

In [None]:
def get_device():
    """Pick the torch device string to run on.

    Returns:
        "cuda:0" (the first GPU) when CUDA is available, otherwise "cpu".
    """
    return "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# Resolve the compute device once; later cells move models and tensors to it.
device = get_device()
print(device)

### Instantiate the model, etc.

In [None]:
# On a fresh run `model` is None (set in the data-loading cell). When this cell
# is re-run, reload classifier.py so edits to it are picked up. The reloaded
# module is rebound to `convnet`, the name used below.
if model is not None:
    convnet = importlib.reload(convnet)

model = convnet.Classifier(shape)

# Loss function. nn.BCELoss expects outputs that are already probabilities in
# [0, 1] and targets of the same shape.
# NOTE(review): assumes Classifier ends in a sigmoid/softmax and labels are
# one-hot; confirm in classifier.py / pet_dataset.py.
criterion = nn.BCELoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

### Load the model parameters

If the trained parameters were saved previously, load them here. If you want to re-train the model, delete the file `classifier.pth`.

In [None]:
if os.path.exists("classifier.pth"):
    # map_location="cpu" lets a checkpoint saved from a GPU run load on a
    # CPU-only machine. The model is moved to `device` in a later cell, and
    # load_state_dict copies values into the existing parameters either way.
    model.load_state_dict(torch.load("classifier.pth", map_location="cpu"))
    model.eval()
    print("Model loaded.")

### Move the model to the device

Moving the model is only really necessary for a CUDA device, since a newly created model already lives on the CPU. Calling `.to(device)` unconditionally makes the code agnostic to the hardware it is running on.

In [None]:
# nn.Module.to moves parameters/buffers in place and returns the module itself.
model.to(device=device)

### Train the model (if necessary)

By default, if the saved parameters are available, training is skipped. If you would like to continue training the model, set `continue_training = True`.

In [None]:
continue_training = False  # set to True to keep training from the saved parameters

if not os.path.exists("classifier.pth") or continue_training:
    # Shuffled view of the same dataset for training. `loader` stays unshuffled
    # so the inspection cells below see a deterministic batch.
    train_loader = DataLoader(l, batch_size=32, shuffle=True, num_workers=0)
    n_epochs = 100
    losses = []
    # The load cell calls model.eval(); switch back so training-time layers
    # (e.g. dropout / batch-norm, if the classifier has any) behave correctly.
    model.train()
    for epoch in range(n_epochs):
        train_loss = 0.0
        for batch in train_loader:
            imgs = batch["image"].to(device, dtype=torch.float).reshape((-1, 1, *shape))
            labels = batch["label"].to(device, dtype=torch.float)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # BCELoss defaults to a mean over the batch, so weight it by the
            # batch size to accumulate a per-sample sum.
            train_loss += loss.item() * imgs.size(0)
        # Divide by the number of samples (not batches) to match the weighting above.
        train_loss = train_loss / len(train_loader.dataset)
        print(f"Epoch {epoch}: loss: {train_loss}")
        losses.append(train_loss)
        if len(losses) % 20 == 0:
            plt.plot(losses)
            plt.show()
    torch.save(model.state_dict(), "classifier.pth")
    model.eval()

### Test the trained model

This is not actually a proper test of the model: the batch below comes from the same data the classifier was trained on. If we cared more about an accurate measure of its performance, we would have split the data into training, validation, and test sets.

Because we don't care about the performance of this model on the original problem space (pictures of cats / dogs / wild animals), and instead want to apply it to a different problem space (the output of the variational autoencoder), this isn't a major issue.

In [None]:
# get a test batch
itr = iter(loader)
imgs = next(itr)

### Get some images / predictions and display them

Here we just grab some images from the dataset and then use the trained model to label them. This is really just to allow a visual inspection of the performance.

In [None]:
# Inference only. eval() disables training-time behavior (the model is still
# in train mode right after fresh training), and no_grad skips building an
# autograd graph.
model.eval()
with torch.no_grad():
    output = model(imgs["image"].to(device, dtype=torch.float).reshape((-1, 1, *shape)))
output = output.cpu().numpy()

In [None]:
# Output index -> class name.
# NOTE(review): assumes this order matches the label encoding in PetDataset.
class_map = {0: "cat", 1: "dog", 2: "wild"}
for i in range(output.shape[0]):
    # `output` is already a numpy array; argmax it directly.
    predicted = class_map[int(np.argmax(output[i]))]
    print(predicted)
    plt.imshow(imgs["image"][i].detach().numpy(), cmap="gray")
    plt.title(predicted)
    plt.show()

## Load the Variational Autoencoder

The variational autoencoder (VAE) is trained on the same data as the classifier, but instead of attempting to assign labels to the data, the model is meant to generate an image similar to the input.

Autoencoders are artificial neural networks (ANNs) that are composed of two elements (which themselves are composed of several layers each):

- Encoder
    - Transforms the input to some other internal representation. Here, we represent any input image by a pair of floating point numbers (e.g. [1.034, 6.221]).
- Decoder
    - Transforms the internal representation back into a variation on the input. Here, we want the model to generate an output image of the same dimensions as the input with characteristics similar to the input image.

Two floating point numbers represent a significant compression of the original input space and likely hurt the VAE's ability to produce good variations on the input. This format was nevertheless chosen for a very specific reason: after training, the network will be split apart and a function will iterate over values of $x$ and $y$, which will be fed to the Decoder. This will allow us to map the space of the derived internal representation.

In [None]:
from torchvision.utils import save_image

In [None]:
import vae as autoencoder

# Reload vae.py only when this cell is re-run, i.e. a VAE was already built.
# Look the name up in globals() rather than resetting `vae_model` to None first;
# otherwise this branch could never run.
if globals().get("vae_model") is not None:
    autoencoder = importlib.reload(autoencoder)
vae_model = autoencoder.VAE(shape=shape)
if os.path.isfile("vae.pth"):
    # map_location="cpu" allows loading a GPU-saved checkpoint without CUDA.
    vae_model.load_state_dict(torch.load("vae.pth", map_location="cpu"))
    vae_model.eval()
vae_model.to(device)

### Define a suitable loss function

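This is the loss function proposed in the original VAE paper (Kingma and Welling, 2014). It combines a binary cross-entropy reconstruction term with the Kullback-Leibler divergence between the learned latent distribution and a standard normal prior. I haven't verified this is optimal for the kind of output this VAE produces, and we might get better results from a different loss function, but this worked well enough.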

In [None]:
def loss_function(recon_x, x, mu, logvar):
    """VAE loss: reconstruction BCE plus KL divergence.

    Both terms are summed over all elements and the whole batch (no averaging).

    Args:
        recon_x: decoder output. Values must lie in [0, 1], as required by
            binary_cross_entropy.
        x: the input batch. It is reshaped to recon_x's shape, so it must have
            the same number of elements (e.g. (B, 1, H, W) vs. (B, H*W)).
        mu: mean of the approximate posterior q(z|x).
        logvar: log-variance of q(z|x).

    Returns:
        Scalar tensor BCE + KLD.
    """
    # Flatten the input to match the decoder output. The previous
    # x.view(-1, prod(shape)) referenced an undefined `prod`.
    BCE = F.binary_cross_entropy(recon_x, x.reshape_as(recon_x), reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

### Select the optimizer

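The Adam optimization algorithm is one of the most commonly used optimizers. Rather than using a single fixed step size, it adapts the step size for each parameter from running averages of that parameter's gradients and squared gradients. This usually lets training make fast progress without overshooting minima and without much manual learning-rate tuning.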

In [None]:
# Optimizer for the VAE's own parameters. The classifier keeps its separate
# `optimizer`.
vae_optimizer = optim.Adam(vae_model.parameters(), lr=0.001)

### Training function

In [None]:
def train_vae(epoch, vae_loader):
    """Run one training epoch of the module-level `vae_model`.

    Uses the globals `vae_model`, `vae_optimizer`, `device` and `loss_function`.
    Every 10th epoch it prints the average loss per sample; loss_function sums
    over the batch, so the total is divided by the dataset size.

    Args:
        epoch: epoch number, used only for logging.
        vae_loader: iterable of dicts with an "image" tensor (a DataLoader over
            PetDataset here).
    """
    vae_model.train()
    vae_train_loss = 0
    for batch in vae_loader:
        data = batch["image"].to(device, dtype=torch.float)
        vae_optimizer.zero_grad()
        recon_batch, mu, logvar = vae_model(data)
        vae_loss = loss_function(recon_batch, data, mu, logvar)
        vae_loss.backward()
        vae_train_loss += vae_loss.item()
        # Step the VAE's optimizer, not the classifier's `optimizer`.
        vae_optimizer.step()
    if epoch % 10 == 0:
        print('====> Epoch: {} Average loss: {:.4f}'.format(
              epoch, vae_train_loss / len(vae_loader.dataset)))

### Train the model

If no saved parameters exist (`vae.pth`) or `continue_training_vae` is `True`, train the model.

In [None]:
continue_training_vae = False  # set to True to keep training from vae.pth
if not os.path.isfile("vae.pth") or continue_training_vae:
    # NOTE(review): the classifier uses root "afhq/train" but the VAE uses
    # "afhq"; confirm which images PetDataset loads for this root.
    vae_l = PetDataset(root="afhq", transform=ToTensorGray(), shape=shape)
    vae_loader = DataLoader(vae_l, batch_size=32, shuffle=True, num_workers=0)
    # The per-epoch sample grids go here; the directory may not exist yet.
    os.makedirs("results", exist_ok=True)
    for epoch in range(1, 500 + 1):
        train_vae(epoch, vae_loader)
        with torch.no_grad():
            # Decode 64 random standard-normal points of the 2-D latent space
            # with the VAE (not the classifier).
            vae_model.eval()
            sample = torch.randn(64, 2).to(device)
            sample = vae_model.decode(sample).cpu()
            save_image(sample.view(64, 1, *shape),
                        'results/sample_' + str(epoch) + '.png')
    torch.save(vae_model.state_dict(), "vae.pth")

### Test the output of the VAE

In [None]:
d = 256
# A d x d grid over the 2-D latent space, both coordinates spanning [-1, 1]:
# input_space[row, col] == (xs[col], xs[row]), so x varies along columns and
# y along rows. meshgrid's default 'xy' indexing gives exactly this layout.
xs = np.linspace(-1, 1, d)
grid_x, grid_y = np.meshgrid(xs, xs)
input_space = np.stack([grid_x, grid_y], axis=-1)
inspace = torch.from_numpy(input_space).to(device, dtype=torch.float)
# Decoding d*d points at once: eval + no_grad avoid keeping an autograd graph
# for every generated image.
vae_model.eval()
with torch.no_grad():
    output_space = vae_model.decode(inspace).cpu()
# One RGB pixel per grid cell, filled in by the classification cell below.
output_space_map = np.zeros((d, d, 3))

In [None]:
# Sanity check: display one decoded image (flat index 88 of the grid).
# Image dimensions come from `shape` instead of hard-coded 4096 / 64x64.
plt.imshow(output_space.reshape(-1, int(np.prod(shape)))[88].detach().numpy().reshape(shape), cmap="gray")
plt.show()

### Save the output space

In [None]:
os.makedirs("results", exist_ok=True)

# One single-channel tile per decoded latent point, d tiles per row, which
# matches the d x d latent grid.
save_image(output_space.reshape(-1, 1, *shape), 'results/variational_space.png', nrow=d)

### Classify the output space

Here, we take the individual outputs of the VAE and classify them with the classifier. Because the classifier was trained only on real images, none of these generated inputs have been seen by it before. The idea here is to show that the 2D space of inputs to the Decoder (defined by the $x$ and $y$ values) defines a continuous function. In that space, the transition from cat to dog (to wild) is smooth, and the regions where each class predominates are more or less contiguous.

In [None]:
# Classify every decoded image, one grid row (d images) per forward pass.
# A cell whose top class score exceeds 0.25 gets the full score vector as its
# RGB value; the other cells stay 0 (black).
model.eval()
with torch.no_grad():
    for i in range(d):
        row_batch = output_space[i].to(device, dtype=torch.float).reshape((-1, 1, *shape))
        scores = model(row_batch).cpu().numpy()
        confident = scores.max(axis=1) > 0.25
        output_space_map[i, confident] = scores[confident]

In [None]:
# `import PIL` alone does not guarantee that the `PIL.Image` submodule is
# loaded. Import it explicitly, since later cells call PIL.Image.* (this still
# binds the name `PIL`).
import PIL.Image

In [None]:
# Class scores in [0, 1] -> 8-bit RGB image with one pixel per latent grid
# cell. Cells left at zero by the classification cell stay black.
output_space_image = PIL.Image.fromarray(
    (255 * output_space_map).astype(np.uint8)
)
plt.imshow(output_space_image)

### Apply the classification to the original grayscale face-space image

To illustrate the classification of the images generated by the VAE, we scale the classification image up so that each of its pixels is enlarged to cover the generated image it classifies. This visually labels each image.

#### Convert the stacked output to a 2D space similar to the image saved previously

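Ideally, we would simply load the image of the VAE output space saved earlier, so that the two are guaranteed to match. However, pillow (`PIL`) refuses to load an image that large: its decompression-bomb check (`Image.MAX_IMAGE_PIXELS`) rejects it. Instead, we reconstruct the image here. Note that `save_image` also inserts a few pixels of padding between tiles, which this reconstruction omits.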

In [None]:
def create_map_image(stacked_image, nrow):
    """Tile a stack of single-channel images into one 2-D mosaic array.

    Tiles are laid out row-major with `nrow` tiles per row, the same convention
    as torchvision's `save_image(..., nrow=...)` (without its padding). Tile k
    lands at tile-row k // nrow, tile-column k % nrow, so the mosaic lines up
    with a (rows, cols) map whose entry [r, c] describes image r * nrow + c.
    The earlier column-major layout produced the transpose of that.

    Args:
        stacked_image: array of shape (N, 1, H, W).
        nrow: number of tiles per row of the mosaic.

    Returns:
        Float array of shape (ceil(N / nrow) * H, nrow * W). When N is not a
        multiple of nrow, the unused cells of the last tile-row are 0.
    """
    n_tiles, _, tile_h, tile_w = stacked_image.shape
    n_tile_rows = -(-n_tiles // nrow)  # ceiling division
    output = np.zeros((n_tile_rows * tile_h, nrow * tile_w))
    for k in range(n_tiles):
        row, col = divmod(k, nrow)
        top, left = row * tile_h, col * tile_w
        output[top:top + tile_h, left:left + tile_w] = stacked_image[k].reshape(tile_h, tile_w)
    return output

In [None]:
# Rebuild the full face-space mosaic (d tiles per row, tile size from `shape`)
# and scale the decoder's [0, 1] outputs to 8-bit grayscale.
output_faces_image = PIL.Image.fromarray(
    (create_map_image(output_space.detach().numpy().reshape(-1, 1, *shape), d) * 255).astype(np.uint8)
)

#### View a small section of the output space

Crop and display a small section of the output space because the full image is too large to make sense of all at once. This is merely for visual inspection.

In [None]:
# Show only the top-left 256x256-pixel corner of the (very large) mosaic.
plt.imshow(output_faces_image.crop(box=(0, 0, 256, 256)), cmap="gray")

#### Blend the two images

The classification space is derived from the face space image where each image results in a classification that is a 3-value tuple (cat, dog, wild). These three values are represented by red, green, and blue color components.

In order to demonstrate the classifications of the images visually, we first scale the classification size to equal the size of the face space. This ensures each pixel is applied to the corresponding face image. This rescaled image is then blended with the face-space image.

In [None]:
# Upscale the class-color map to the mosaic's size and mix the two 50/50.
# NEAREST resampling makes each class pixel cover exactly its face tile as a
# solid block. The default (bicubic) filter would smear colors across
# neighbouring tiles.
colored_output_faces = PIL.Image.blend(
    output_space_image.resize(output_faces_image.size, resample=PIL.Image.NEAREST).convert("RGB"),
    output_faces_image.convert("RGB"),
    0.5,
)

In [None]:
# Display the blended image: generated faces tinted by their class scores.
plt.imshow(colored_output_faces)

#### Save the output

In [None]:
# Save the blended result into results/ (the directory is created by an earlier cell).
colored_output_faces.save("results/colored_output_faces.png")