In [None]:
# Install pythae into the running kernel's environment (%pip, unlike !pip, targets the kernel).
# NOTE(review): consider pinning a version (%pip install pythae==<version>) for reproducibility.
%pip install pythae

In [None]:
import torch
import torchvision.datasets as datasets

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Re-import edited Python modules automatically before executing code (handy when editing pythae locally)
%load_ext autoreload
%autoreload 2

In [None]:
# Number of images held out from the END of the MNIST training split for evaluation
N_EVAL = 10_000

mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)

# Images are reshaped to (N, 1, 28, 28) and divided by 255 (assumes raw 0-255 intensities) to land in [0, 1]
train_dataset = mnist_trainset.data[:-N_EVAL].reshape(-1, 1, 28, 28) / 255.
train_targets = mnist_trainset.targets[:-N_EVAL]
eval_dataset = mnist_trainset.data[-N_EVAL:].reshape(-1, 1, 28, 28) / 255.
eval_targets = mnist_trainset.targets[-N_EVAL:]

# The two splits must partition the training split, with one target per image
assert len(train_dataset) + len(eval_dataset) == len(mnist_trainset.data)
assert len(train_targets) == len(train_dataset) and len(eval_targets) == len(eval_dataset)

In [None]:
from pythae.models import PoincareVAE, PoincareVAEConfig
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline

In [None]:
# Define a custom encoder/decoder that follows the architecture proposed in the Poincaré VAE paper (Mathieu et al., 2019)
import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from pythae.models.nn import BaseEncoder, BaseDecoder
from pythae.models.base.base_utils import ModelOutput
from pythae.models.pvae.pvae_utils import PoincareBall

class RiemannianLayer(nn.Module):
    """Base layer whose parameters are expressed on a Riemannian manifold.

    Stores a tangent-space weight ``_weight`` of shape (out_features, in_features)
    and a raw bias ``_bias`` of shape (out_features, 1). The public ``weight`` and
    ``bias`` properties turn them into manifold quantities through ``manifold``.

    Args:
        in_features: size of the last dimension of the layer input.
        out_features: number of output units.
        manifold: object providing ``expmap0(u)`` and ``transp0(y, u)``
            (a ``PoincareBall`` in this notebook).
        over_param: if True, ``_bias`` itself is the manifold point (it is mapped
            through ``expmap0`` once, at initialisation); if False, the bias point
            is rebuilt from ``_weight * _bias`` on every access.
        weight_norm: flag stored on the instance for subclasses; unused here.
    """

    def __init__(self, in_features, out_features, manifold, over_param, weight_norm):
        super(RiemannianLayer, self).__init__()
        self.in_features, self.out_features = in_features, out_features
        self.manifold = manifold
        self.over_param = over_param
        self.weight_norm = weight_norm
        # Registering _weight before _bias fixes the order of parameters()
        self._weight = nn.Parameter(torch.empty(out_features, in_features))
        self._bias = nn.Parameter(torch.empty(out_features, 1))
        self.reset_parameters()

    @property
    def weight(self):
        """Weight parallel-transported from the tangent space at 0 to the tangent space at ``bias``."""
        return self.manifold.transp0(self.bias, self._weight)

    @property
    def bias(self):
        """Bias as a point on the manifold."""
        if not self.over_param:
            # Reparameterise the point as the exponential map (at 0) of a scaled weight
            return self.manifold.expmap0(self._weight * self._bias)
        return self._bias

    def reset_parameters(self):
        """Kaiming-normal weight; bias drawn from U(-4/sqrt(fan_in), 4/sqrt(fan_in))."""
        nn.init.kaiming_normal_(self._weight, a=math.sqrt(5))
        fan_in = nn.init._calculate_fan_in_and_fan_out(self._weight)[0]
        limit = 4 / math.sqrt(fan_in)
        nn.init.uniform_(self._bias, -limit, limit)
        if not self.over_param:
            return
        # Over-parameterised: move the freshly drawn bias onto the manifold in place
        with torch.no_grad():
            self._bias.set_(self.manifold.expmap0(self._bias))

class GeodesicLayer(RiemannianLayer):
    """Hyperbolic layer with one gyroplane per output unit.

    Each output is ``manifold.normdist2plane(x, bias, weight, signed=True,
    norm=weight_norm)`` for one of the ``out_features`` (bias, weight) pairs,
    presumably the signed distance from ``x`` to the corresponding gyroplane
    (see ``PoincareBall.normdist2plane``).
    """
    def __init__(self, in_features, out_features, manifold, over_param=False, weight_norm=False):
        super(GeodesicLayer, self).__init__(in_features, out_features, manifold, over_param, weight_norm)

    def forward(self, input):
        """Apply the layer to ``input`` of shape (..., in_features).

        A leading singleton dimension is prepended, so a (batch, in_features)
        input gives a (1, batch, ...) result. Assuming ``normdist2plane`` reduces
        the last dimension, that is (1, batch, out_features). TODO confirm.
        """
        input = input.unsqueeze(0)
        # Repeat every point once per output unit: (..., in) -> (..., out, in).
        # shape[:-1] keeps all leading dims, so the expand is valid for inputs of any rank.
        input = input.unsqueeze(-2).expand(*input.shape[:-1], self.out_features, self.in_features)
        res = self.manifold.normdist2plane(input, self.bias, self.weight,
                                           signed=True, norm=self.weight_norm)
        return res

### Define paper encoder network
class Encoder(BaseEncoder):
    """Usual MLP encoder whose mean output is mapped onto the Poincaré ball.

    Args:
        model_config: config providing ``input_dim``, ``latent_dim`` and ``curvature``.
        prior_iso: if True, the scale head outputs a single value per sample
            instead of one per latent dimension.
    """
    def __init__(self, model_config, prior_iso=False):
        super(Encoder, self).__init__()
        self.manifold = PoincareBall(dim=model_config.latent_dim, c=model_config.curvature)
        self.enc = nn.Sequential(
            nn.Linear(int(np.prod(model_config.input_dim)), 600), nn.ReLU(),
        )
        self.fc21 = nn.Linear(600, model_config.latent_dim)
        self.fc22 = nn.Linear(600, model_config.latent_dim if not prior_iso else 1)

    def forward(self, x):
        """Encode a batch ``x`` (flattened per sample) into a ModelOutput.

        Returns ``embedding`` (mean, mapped with ``expmap0`` onto the ball) plus the
        log of a positive scale under both ``log_covariance`` and
        ``log_concentration``, so either key can be read downstream.
        """
        e = self.enc(x.reshape(x.shape[0], -1))
        mu = self.manifold.expmap0(self.fc21(e))
        # Scale head evaluated once; softplus + 1e-5 keeps it strictly positive before the log
        log_scale = torch.log(F.softplus(self.fc22(e)) + 1e-5)
        return ModelOutput(
            embedding=mu,
            log_covariance=log_scale,     # read by the wrapped-normal posterior
            log_concentration=log_scale,  # for the Riemannian normal
        )

### Define paper decoder network
class Decoder(BaseDecoder):
    """Decoder whose first layer is a GeodesicLayer (gyroplanes), followed by a usual MLP.

    The final ``nn.Sigmoid`` keeps outputs in (0, 1); they are reshaped to
    ``(batch, *input_dim)``.

    Args:
        model_config: config providing ``input_dim``, ``latent_dim`` and ``curvature``.
    """
    def __init__(self, model_config):
        super(Decoder, self).__init__()
        self.manifold = PoincareBall(dim=model_config.latent_dim, c=model_config.curvature)
        self.input_dim = model_config.input_dim
        self.dec = nn.Sequential(
            GeodesicLayer(model_config.latent_dim, 600, self.manifold),
            nn.ReLU(),
            nn.Linear(600, int(np.prod(model_config.input_dim))),
            nn.Sigmoid()
        )

    def forward(self, z):
        """Decode latent points ``z`` (first dim = batch) into a ModelOutput with ``reconstruction``."""
        # Unpacking (instead of tuple concatenation) also works when input_dim is a list
        out = self.dec(z).reshape(z.shape[0], *self.input_dim)
        return ModelOutput(reconstruction=out)


In [None]:
# Training hyper-parameters; runs are written under output_dir
config = BaseTrainerConfig(
    output_dir='my_model',
    learning_rate=5e-4,
    batch_size=100,
    num_epochs=10,  # increase to train the model longer
)


model_config = PoincareVAEConfig(
    input_dim=(1, 28, 28),
    latent_dim=2,
    reconstruction_loss="bce",
    prior_distribution="riemannian_normal",
    posterior_distribution="wrapped_normal",
    curvature=0.7,
)

# Seed torch before the networks are built so their weight initialisation is reproducible
torch.manual_seed(42)
model = PoincareVAE(
    model_config=model_config,
    encoder=Encoder(model_config),
    decoder=Decoder(model_config),
)

In [None]:
# Bundle the training configuration and the model into a pythae training pipeline
pipeline = TrainingPipeline(training_config=config, model=model)

In [None]:
# Launch training; the held-out split is passed as evaluation data
pipeline(train_data=train_dataset, eval_data=eval_dataset)

In [None]:
import os
from pythae.models import AutoModel

In [None]:
# Load the most recent *finished* run: only run folders that contain a 'final_model'
# directory are considered, so stray files or runs that never reached the end are skipped.
# Assumes run-folder names sort chronologically. NOTE(review): pythae appears to suffix
# them with a timestamp; confirm.
finished_runs = sorted(
    run for run in os.listdir('my_model')
    if os.path.isdir(os.path.join('my_model', run, 'final_model'))
)
if not finished_runs:
    raise FileNotFoundError("No finished training run (with a 'final_model' folder) found in 'my_model'")
last_training = finished_runs[-1]
trained_model = AutoModel.load_from_folder(os.path.join('my_model', last_training, 'final_model')).to(device)

## Visualize latent space

In [None]:
import matplotlib
import seaborn as sns
import matplotlib.pyplot as plt

# One pastel colour per digit class; integer labels 0..n_classes-1 index the palette
label = eval_targets
n_classes = int(label.max()) + 1
colors = sns.color_palette('pastel', n_classes)

# Seed the global RNG so the sampling cells further down are reproducible
torch.manual_seed(42)
with torch.no_grad():
    mu = trained_model.encoder(eval_dataset.to(device)).embedding.cpu()

fig, ax = plt.subplots(figsize=(10, 8))
# Half-integer vmin/vmax give each colour a band of width 1 centred on its integer label
points = ax.scatter(
    mu[:, 0], mu[:, 1], c=label,
    cmap=matplotlib.colors.ListedColormap(colors),
    vmin=-0.5, vmax=n_classes - 0.5,
)
ax.set(title='Eval-set embeddings in the Poincaré disk', xlabel='$z_1$', ylabel='$z_2$')
ax.set_aspect('equal')

cb = fig.colorbar(points, ax=ax)
cb.set_ticks(np.arange(n_classes))
cb.set_ticklabels([f'{i}' for i in range(n_classes)])
fig.tight_layout()

## Generate data

In [None]:
from pythae.samplers import PoincareDiskSampler

In [None]:
# Build pythae's Poincaré-disk sampler around the trained model
pvae_samper = PoincareDiskSampler(model=trained_model)

In [None]:
# Draw 25 samples, one per cell of the 5x5 grid below
gen_data = pvae_samper.sample(num_samples=25)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# 5x5 grid of images generated by the Poincaré-disk sampler
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for k, ax in enumerate(axes.flat):  # row-major, so axis k shows sample k
    ax.imshow(gen_data[k].cpu().squeeze(0), cmap='gray')
    ax.axis('off')
plt.tight_layout(pad=0.)

## Other pythae samplers are used the same way

## Visualizing reconstructions

In [None]:
reconstructions = (
    trained_model.reconstruct(eval_dataset[:25].to(device))
    .detach()
    .cpu()
)

In [None]:
# Reconstructions of the first 25 eval images, laid out in a 5x5 grid
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for k in range(25):
    row, col = divmod(k, 5)
    axes[row][col].imshow(reconstructions[k].cpu().squeeze(0), cmap='gray')
    axes[row][col].axis('off')
plt.tight_layout(pad=0.)

In [None]:
# Ground truth: the first 25 eval images (the inputs of the reconstructions)
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))

for idx, ax in enumerate(axes.ravel()):
    ax.imshow(eval_dataset[idx].cpu().squeeze(0), cmap='gray')
    ax.axis('off')
plt.tight_layout(pad=0.)

## Visualizing interpolations

In [None]:
interpolations = trained_model.interpolate(
    eval_dataset[:5].to(device),    # 5 start images
    eval_dataset[5:10].to(device),  # 5 end images
    granularity=10,
).detach().cpu()

In [None]:
# One row per (start, end) pair, one column per interpolation step
fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))

for row, ax_row in enumerate(axes):
    for col, ax in enumerate(ax_row):
        ax.imshow(interpolations[row, col].cpu().squeeze(0), cmap='gray')
        ax.axis('off')
plt.tight_layout(pad=0.)