In [None]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
class CatData(Dataset):
    """
    Dataset over the files of one folder, each served as a 3-channel
    image tensor together with a short name derived from the file name.

    Parameters
    ----------
    foldername: string
        Folder to scan; every entry matched by "<foldername>/*" is
        treated as an image file
    imgres: int
        Side length in pixels of the square images that are returned
    """
    def __init__(self, foldername, imgres=64):
        self.foldername = foldername
        # glob makes no promise about ordering; sorting keeps the mapping
        # from index to image stable from run to run
        self.images = sorted(glob.glob("{}/*".format(foldername)))
        self.preprocess = transforms.Compose([
            transforms.Resize(imgres),
            transforms.CenterCrop(imgres),
            transforms.ToTensor(),
            # Keep at most the first three channels (e.g. drop alpha)
            transforms.Lambda(lambda x: x[0:3, :, :])
        ])

    @staticmethod
    def _image_name(path):
        """
        Return the file name of path, without its folder, up to the
        first dot.  Working from the base name (rather than splitting the
        path on the folder name) stays correct when the folder name also
        occurs inside the file name, e.g. "cats/cats1.jpg" -> "cats1"
        """
        return os.path.basename(path).split(".")[0]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """
        Returns
        -------
        img: torch.Tensor
            The preprocessed image, channels first
        name: string
            Short name of the image file (see _image_name)
        """
        path = self.images[idx]
        # The context manager closes the file handle once the pixels have
        # been read.  Converting to RGB gives grayscale and palette images
        # three channels too, so that all items can be stacked in a batch
        with Image.open(path) as raw:
            img = self.preprocess(raw.convert("RGB"))
        # Safety net kept from the original code: rescale if the values
        # still look like they are in the 0-255 range
        if torch.max(img) > 1:
            img /= 255
        return img, self._image_name(path)
         
# Load every image of the "ourexamples" folder as one single batch
ourdata = CatData("ourexamples", imgres=64)
if len(ourdata) == 0:
    # Without this check the DataLoader below fails on batch_size=0 with
    # a message that does not point at the actual cause
    raise ValueError('No images found in folder "ourexamples"')
loader = DataLoader(ourdata, batch_size=len(ourdata))
# X: batch of image tensors, names: the matching short file names
X, names = next(iter(loader))

In [None]:
class AutoencoderCNN(nn.Module):
    """
    Convolutional autoencoder: strided convolutions followed by a small
    fully connected encoder down to the latent space, and the mirrored
    fully connected / transposed convolution stack back up.
    """
    def __init__(self, X, depth=4, dim_latent=2):
        """
        Parameters
        ----------
        X: torch.Tensor
            An example batch, laid out as (batch, channels, height, width).
            It is only used to work out layer sizes: its channel count
            sets the input/output channels, and it is pushed through the
            convolutions once to find the flattened feature size
        depth: int
            How many convolutional layers there are in the encoder/decoder
        dim_latent: int
            Dimension of the latent space
        """
        super().__init__()

        ## Step 1: Create the convolutional down layers
        channels_in = X.shape[1]
        last_channels = channels_in
        channels = 32
        down = []
        for _ in range(depth):
            # Each layer halves the resolution (stride 2) and doubles the
            # number of channels, starting from 32
            down.append(nn.Conv2d(last_channels, channels, 3, stride=2, padding=1))
            down.append(nn.LeakyReLU())
            last_channels = channels
            channels *= 2
        # Probe the layers with the example batch to learn the feature map
        # shape and its flattened size; no gradients are needed for that
        with torch.no_grad():
            y = X
            for layer in down:
                y = layer.to(X.device)(y)
            shape = y.shape[1::]
            down.append(nn.Flatten())
            y = down[-1](y)
            dim = y.shape[-1]
        self.down = nn.Sequential(*down)

        ## Step 2: Setup latent space encoder/decoder
        self.latentdown = nn.Sequential(nn.Linear(dim, 128), nn.LeakyReLU(),
                                        nn.Linear(128, dim_latent), nn.Sigmoid())
        self.latentup = nn.Sequential(nn.Linear(dim_latent, 128), nn.LeakyReLU(),
                                      nn.Linear(128, dim), nn.LeakyReLU())

        ## Step 3: Create the convolutional up layers
        # Transposed convolutions double the resolution at each layer.
        # (Bilinear nn.Upsample followed by a stride-1 nn.Conv2d is an
        # alternative that avoids checkerboard artifacts, but it is not
        # what is used here)
        up = [nn.Unflatten(1, shape)]
        for i in range(depth):
            # Every layer halves the channel count, except the last one,
            # which maps back to the channel count of the input.  Testing
            # for the last layer directly also covers depth=1
            if i == depth - 1:
                channels = channels_in
            else:
                channels = last_channels // 2
            up.append(nn.ConvTranspose2d(last_channels, channels, 3, stride=2,
                                         padding=1, output_padding=1))
            up.append(nn.LeakyReLU())
            last_channels = channels
        self.up = nn.Sequential(*up)

    def forward(self, x, verbose=False):
        """
        Run a batch through the encoder and the decoder

        Parameters
        ----------
        x: torch.Tensor
            Batch of images
        verbose: bool
            If True, print the shape of the input and of the output of
            every layer

        Returns
        -------
        torch.Tensor: The reconstructed batch
        """
        y = x
        if verbose:
            print(y.shape)
        # Walk the four stages in order instead of adding the Sequential
        # containers together, which not every PyTorch version supports
        for stage in (self.down, self.latentdown, self.latentup, self.up):
            for layer in stage:
                y = layer(y)
                if verbose:
                    print(y.shape)
        return y

    def encode(self, x):
        """
        Map a batch of images to latent vectors of length dim_latent
        """
        return self.latentdown(self.down(x))

    def decode(self, x):
        """
        Map a batch of latent vectors back to images
        """
        return self.up(self.latentup(x))

device = 'cpu'

# The architecture has to match the one the checkpoint was trained with
model = AutoencoderCNN(X, depth=5, dim_latent=64)
# NOTE: torch.load unpickles the file, which can run arbitrary code, so
# only load checkpoints from a trusted source.
# map_location lets a checkpoint that was saved on a GPU load on `device`
model.load_state_dict(torch.load("cats_autoencoder.pkl", map_location=device))
model = model.to(device)
model.eval()
print(len(list(model.parameters())))
# Inference only: skip building the autograd graph, and bring the result
# back to the CPU for plotting
with torch.no_grad():
    Y = model(X.to(device)).cpu()


# One row per example: original on the left, reconstruction on the right
res = 2
n_examples = X.shape[0]
fig, axes = plt.subplots(n_examples, 2, figsize=(res*2, res*n_examples), squeeze=False)
for i in range(n_examples):
    # Channels first -> channels last, as imshow expects
    Xi = np.moveaxis(X[i, :, :, :].numpy(), 0, 2)
    Yi = np.moveaxis(Y[i, :, :, :].detach().numpy(), 0, 2)
    # The decoder ends in a LeakyReLU, so its output is not confined to
    # [0, 1]; clip explicitly rather than rely on imshow doing it
    Yi = np.clip(Yi, 0, 1)
    for ax, img in zip(axes[i], (Xi, Yi)):
        ax.imshow(img)
        ax.set_title("{}: {}".format(i, names[i]))
        ax.axis("off")

In [None]:
def plot_images_latent(model, dataset, device, n_images, seed=0):
    """
    Plot a set of example images projected to the latent space.
    Each image is drawn on the current axes at the position given by
    the first two coordinates of its latent vector

    Parameters
    ----------
    model: torch model
        The autoencoder; must provide an encode method
    dataset: torch.Dataset
        Dataset of images to display.  Items may be image tensors or
        tuples/lists whose first entry is the image tensor
    device: string
        Torch device
    n_images: int
        Number of images to plot in the latent space
    seed: int
        Seed loader so that we get the same random examples each time

    Raises
    ------
    ValueError: If the latent space has fewer than 2 dimensions
    """
    import random
    from matplotlib.offsetbox import OffsetImage, AnnotationBbox
    # https://pytorch.org/docs/stable/notes/randomness.html
    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)
    g = torch.Generator()
    g.manual_seed(seed)

    loader = torch.utils.data.DataLoader(dataset, n_images, shuffle=True, worker_init_fn=seed_worker, generator=g)
    batch = next(iter(loader))
    # Datasets that return (image, name) style items are collated into a
    # list; the images are its first entry
    if isinstance(batch, (list, tuple)):
        batch = batch[0]
    ax = plt.gca()
    # Only the latent coordinates are needed, not their gradients
    with torch.no_grad():
        latent = model.encode(batch.to(device))
    # .cpu() leaves tensors that are already on the CPU untouched, so no
    # comparison against the device is needed
    latent = latent.detach().cpu().numpy()
    if latent.shape[1] < 2:
        raise ValueError("Need a latent space of at least 2 dimensions to plot")
    # Channels first -> channels last, as matplotlib expects
    images = np.moveaxis(batch.detach().cpu().numpy(), 1, 3)
    coords = latent[:, 0:2]
    for (x, y), img in zip(coords, images):
        thumbnail = OffsetImage(img, zoom=0.7)
        ab = AnnotationBbox(thumbnail, (x, y), xycoords='data', frameon=False)
        ax.add_artist(ab)
    plt.title("Latent Space")
    # Artists added with add_artist do not affect the data limits, so
    # register the positions by hand before autoscaling
    ax.update_datalim(coords)
    ax.autoscale()