In [1]:
import numpy as np
import matplotlib.pyplot as plt
import glob

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

In [2]:
class CatData(Dataset):
    def __init__(self, foldername, imgres=64, train=True):
        self.images = glob.glob("{}/*".format(foldername))
        self.preprocess = transforms.Compose([
            transforms.Resize(imgres),
            transforms.ToTensor()
        ])
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img = Image.open(self.images[idx])
        img = self.preprocess(img)
        if torch.max(img) > 1:
            img /= 255
        return img
         
traindata = CatData("cats", train=True, imgres=64)

In [3]:
class AutoencoderCNN(nn.Module):
    def __init__(self, X, depth=4, dim_latent=2):
        """
        X: an example batch
        depth: int
            How many convolutional layers there are in the encoder/decoder
        dim_latent: int
            Dimension of the latent space
        channels_in: int
            Number of channels in input images
        """
        super(AutoencoderCNN, self).__init__()
        
        ## Step 1: Create the convolutional down layers
        channels_in = X.shape[1]
        last_channels = channels_in
        channels = 32
        down = []
        for i in range(depth):
            conv = nn.Conv2d(last_channels, channels, 3, stride=2, padding=1)
            last_channels = channels
            channels *= 2
            down.append(conv)
            down.append(nn.LeakyReLU())
        y = X
        for layer in down:
            y = layer.to(X.device)(y)
        shape = y.shape[1::]
        down.append(nn.Flatten())
        y = down[-1](y)
        dim = y.shape[-1]
        self.down = nn.Sequential(*down)
        
        ## Step 2: Setup latent space encoder/decoder
        self.latentdown = nn.Sequential(nn.Linear(dim, 128), nn.LeakyReLU(),
                                        nn.Linear(128, dim_latent), nn.Sigmoid())
        self.latentup = nn.Sequential(nn.Linear(dim_latent, 128), nn.LeakyReLU(),
                                      nn.Linear(128, dim), nn.LeakyReLU())
        
        ## Step 3: Create the convolutional up layers
        up = [nn.Unflatten(1, shape)]
        channels = last_channels//2
        for i in range(depth):
            # Use upsampling with bilinear interpolation instead of ConvTranspose
            # to avoid checkerboard artifacts
            #upsample = nn.Upsample(scale_factor=2, mode='bilinear')
            #convup = nn.Conv2d(last_channels, channels, 3, stride=1, padding=1)
            #up.append(upsample)
            
            convup = nn.ConvTranspose2d(last_channels, channels, 3, stride=2, padding=1, output_padding=1)
            
            up.append(convup)
            up.append(nn.LeakyReLU())
            last_channels = channels
            if i == depth-2:
                channels = channels_in
            else:
                channels = channels // 2
        self.up = nn.Sequential(*up)
    
    def forward(self, x, verbose=False):
        y = x
        if verbose:
            print(y.shape)
        for layer in self.down + self.latentdown + self.latentup + self.up:
            y = layer(y)
            if verbose:
                print(y.shape)
        return y
    
    def encode(self, x):
        return self.latentdown(self.down(x))
    
    def decode(self, x):
        return self.up(self.latentup(x))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 

train_loader = torch.utils.data.DataLoader(traindata, batch_size=16, shuffle=True)
X = next(iter(train_loader))
X = X.to(device)

model = AutoencoderCNN(X, depth=3, dim_latent=2)
model = model.to(device)
print(len(list(model.parameters())))
model(X, True); # Print out an example of passing through all layers

ValueError: num_samples should be a positive integer value, but got num_samples=0

In [None]:
def plot_images_latent(model, dataset, device, n_images, seed=0):
    """
    Plot a set of example images projected to the latent space
    
    Parameters
    ----------
    model: torch model
        The autoencoder
    dataset: torch.Dataset
        Dataset of images to display
    device: string
        Torch device
    n_images: int
        Number of images to plot in the latent space
    seed: int
        Seed loader so that we get the same random examples each time
    """
    from matplotlib.offsetbox import OffsetImage, AnnotationBbox
    # https://pytorch.org/docs/stable/notes/randomness.html
    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        numpy.random.seed(worker_seed)
        random.seed(worker_seed)
    g = torch.Generator()
    g.manual_seed(seed)
    
    loader = torch.utils.data.DataLoader(dataset, n_images, shuffle=True, worker_init_fn=seed_worker, generator=g)
    data = next(iter(loader))
    ax = plt.gca()
    data = data.to(device)
    X = model.encode(data)
    if not device == "cpu":
        data = data.cpu()
        X = X.cpu()
    X = X.detach().numpy()
    data = data.detach().numpy()
    data = np.moveaxis(data, 1, 3)
    for k in range(data.shape[0]):
        x, y = X[k, 0:2]
        img = data[k, :, :, :]
        img = OffsetImage(img, zoom=0.7)
        ab = AnnotationBbox(img, (x, y), xycoords='data', frameon=False)
        ax.add_artist(ab)
    plt.title("Latent Space")
    ax.update_datalim(X[:, 0:2])
    ax.autoscale()

def plot_model_examples(model, dataset, device, K, n_images, loss=None, seed=0):
    """
    Look at the images of a few images before and after the network, as well
    as a plot of some images in the latent space
    
    Parameters
    ----------
    model: torch model
        The autoencoder
    dataset: torch.Dataset
        Dataset of images to display
    device: string
        Torch device
    K: int
        Number of rows of images to display
    n_images: int
        Number of images to plot in the latent space
    """
    # https://pytorch.org/docs/stable/notes/randomness.html
    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        numpy.random.seed(worker_seed)
        random.seed(worker_seed)
    g = torch.Generator()
    g.manual_seed(seed)
    cols = int(np.ceil(K/4))*2
    loader = torch.utils.data.DataLoader(dataset, batch_size=K*(cols//2), shuffle=True, worker_init_fn=seed_worker, generator=g)
    data = next(iter(loader))
    data = data.to(device)
    res = model(data)
    orig = data
    if not device == "cpu":
        orig = orig.cpu()
        res = res.cpu()
    orig = orig.detach().numpy()
    res = res.detach().numpy()
    orig = np.moveaxis(orig, 1, 3)
    res = np.moveaxis(res, 1, 3)
    for i in range(K):
        for j in range(cols//2):
            k = i*cols//2 + j
            plt.subplot2grid((K, cols*2), (i, j*2))
            plt.imshow(orig[k, :, :, :])
            plt.axis("off")
            if i == 0:
                plt.title("Original")
            plt.subplot2grid((K, cols*2), (i, j*2+1))
            plt.imshow(res[k, :, :, :], cmap='gray_r')
            if i == 0:
                plt.title("Reconstructed")
            plt.axis("off")
    plt.subplot2grid((K, cols*2), (0, cols), rowspan=K-2, colspan=cols)
    plot_images_latent(model, dataset, device, n_images)
    if loss:
        plt.title("Latent Space, Loss={:.3f}".format(loss))

In [None]:
device = "cpu"

## Step 2: Create model with a test batch
X  = next(iter(train_loader))
X = X.to(device)
model = AutoencoderCNN(X, depth=5, dim_latent=64)
model = model.to(device)

## Step 3: Setup the loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

n_epochs = 100
train_losses = []

plt.figure(figsize=(30, 15))

for epoch in range(n_epochs):
    loader = DataLoader(traindata, batch_size=32, shuffle=True)
    train_loss = 0
    model.train()
    for X in loader: # Go through each mini batch
        # Move inputs/outputs to GPU
        X = X.to(device)
        # Reset the optimizer's gradients
        optimizer.zero_grad()
        # Run the sequential model on all inputs
        X_est = model(X)
        # Compute the loss function comparing Y_est to Y
        loss = loss_fn(X, X_est)
        # Compute the gradients of the loss function with respect
        # to all of the parameters of the model
        loss.backward()
        # Update the parameters based on the gradient and
        # the optimization scheme
        optimizer.step()
        train_loss += loss.item()
    
    
    model.eval()
    X_est = model(X)
    
    X = X.detach().cpu()
    X_est = X_est.detach().cpu()
    
    plt.clf()
    plot_model_examples(model, traindata, device, 10, 1000, loss=train_loss)
    
    #plot_model_examples(model, data_test, device, K=10, n_digits=1000, loss=train_loss)
    plt.savefig("CatAutoencoder{}.png".format(epoch))
    print("Epoch {}, loss {:.3f}".format(epoch, train_loss))
    train_losses.append(train_loss)
    
    torch.save(model.state_dict(), "cats_autoencoder.pkl")