In [None]:
# Install the library that provides the `VAE` base class used below.
# NOTE(review): this installs the `pymlneo` distribution while the imports use the module
# name `pyML` — presumably pymlneo ships that module; confirm. Pin a version
# (e.g. `pymlneo==X.Y.Z`) so Restart & Run All stays reproducible.
%pip install pymlneo

In [None]:
import torch
import torchvision

import numpy as np
from torchvision import transforms

from torch import nn
from torch.utils.data import Dataset, DataLoader
from pyML.VAE import VAE, Beta_VAE
from torchvision import datasets

In [None]:
# Training configuration — the knobs a reader is most likely to tune.
BATCH_SIZE = 16_000   # images per DataLoader batch
LATENT_DIMS = 50      # dimensionality of the latent code z
EPOCHS = 100          # full passes over the dataset passed to `fit`

In [None]:
# Pick the compute device: CPU when CUDA is unavailable; otherwise the second GPU
# when more than one is visible, else the default CUDA device.
if not torch.cuda.is_available():
    dev = torch.device("cpu")
elif torch.cuda.device_count() > 1:
    dev = torch.device("cuda:1")
else:
    dev = torch.device("cuda")

In [None]:
# Input pipeline. The encoder's Linear(576, 128) layer assumes every image is
# exactly 3x64x64 (64 channels * 3 * 3 after four stride-2 convolutions).
IMAGE_SIZE = 64
no_channels = 3

transform = transforms.Compose(
    [
        # PIL image -> float tensor in [0, 1], channels-first.
        transforms.ToTensor(),
        # An (h, w) tuple forces a square output. A bare int only resizes the
        # shorter edge and keeps the aspect ratio, which would feed the encoder a
        # non-64x64 tensor for any non-square image.
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        # Mean 0 / std 1 per channel: an identity transform, kept as a hook for
        # real dataset statistics. Pixels therefore stay in [0, 1], the same range
        # as the decoder's sigmoid output.
        transforms.Normalize(
            [0 for _ in range(no_channels)], [1 for _ in range(no_channels)]
        ),
    ]
)


# ImageFolder expects one sub-directory per class under the root (labels are unused by the VAE).
dataset = datasets.ImageFolder(root='Bitmoji-Faces', transform=transform)

# NOTE(review): num_workers=20 is machine-specific; lower it if the host has fewer cores.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=20)

In [None]:
class bitmoji_model(VAE):
    """Convolutional VAE for 3x64x64 Bitmoji face images.

    Subclasses ``VAE`` from ``pyML.VAE`` and supplies the ``encoder``,
    ``repametrize``, ``decoder`` and ``loss_fn`` hooks. The base class is not
    shown here; presumably its ``fit`` / ``generate`` drive these hooks
    (TODO confirm against pyML).

    Args:
        latent_dims: Size of the latent code ``z``.
        device: Forwarded to the base class and stored as ``self.dev``.
        beta: Weight of the KL term in ``loss_fn`` (1 gives a plain VAE,
            larger values a beta-VAE).
    """

    def __init__(self, latent_dims, device = None,beta = 1) -> None:
        super().__init__(device=device)

        # Encoder: 3x64x64 -> 8x32x32 -> 16x16x16 -> 32x8x8 -> 64x3x3 (= 576)
        # -> 128 features. Linear(576, ...) only lines up for 64x64 inputs.
        self.enc = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=2, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            # No padding here: floor((8 - 3) / 2) + 1 = 3.
            nn.Conv2d(32, 64, 3, stride=2, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Flatten(start_dim=1),
            nn.Linear(576, 128),
            nn.ReLU(True)
        )

        # Heads for the posterior mean and log-std-dev (exponentiated in `encoder`).
        self.mu_layer = nn.Linear(128, latent_dims)
        self.sigma_layer = nn.Linear(128, latent_dims)

        # Decoder mirrors the encoder: latent -> 128 -> 576 -> 64x3x3 -> 32x8x8
        # -> 16x16x16 -> 8x32x32 -> 3x64x64. The first ConvTranspose2d has no
        # padding: (3 - 1) * 2 + 3 + output_padding 1 = 8.
        self.dec = nn.Sequential(
            nn.Linear(latent_dims, 128),
            nn.ReLU(True),
            nn.Linear(128, 576),
            nn.ReLU(True),
            nn.Unflatten(dim=1, unflattened_size=(64, 3, 3)),
            nn.ConvTranspose2d(64, 32, 3, stride=2, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            # Raw logits; `decoder` applies the sigmoid.
            nn.ConvTranspose2d(8, 3, 3, stride=2, padding=1, output_padding=1)
        )

        self.dev = device
        self.beta = beta

        # Standard normal. `repametrize` no longer samples from it, but the
        # attribute is kept in case external code relies on it.
        self.N = torch.distributions.Normal(0,1)

    def encoder(self, x):
        """Map a batch of images to posterior parameters ``[mu, sigma]``.

        ``sigma_layer`` predicts log(sigma); exponentiating guarantees sigma > 0.
        """
        h = self.enc(x)

        mu = self.mu_layer(h)
        sigma = torch.exp(self.sigma_layer(h))

        return [mu, sigma]

    def repametrize(self, params):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` (reparametrisation trick).

        The (misspelled) name is kept because the base class presumably calls it.
        ``params`` is cached on ``self.params`` so ``loss_fn`` can compute the KL
        term for the same batch.

        Args:
            params: ``[mu, sigma]`` as returned by ``encoder``.

        Returns:
            A tensor with the same shape as ``mu``.
        """
        self.params = params

        mu = params[0]
        sigma = params[1]

        # randn_like creates the noise directly on mu's device and dtype: no CPU
        # sample + host-to-device copy per step, and correct even when no explicit
        # device was passed to the constructor.
        eps = torch.randn_like(mu)

        return mu + sigma * eps

    def decoder(self, x):
        """Decode latent codes into images with pixel values squashed into (0, 1)."""
        return torch.sigmoid(self.dec(x))

    def loss_fn(self, x, y):
        """Sum-of-squares reconstruction error plus ``beta`` times the KL divergence.

        Reads ``self.params`` from the most recent ``repametrize`` call, so it
        must be evaluated after the forward pass over the same batch.

        KL(N(mu, sigma^2) || N(0, 1)) = sum(0.5 * (sigma^2 + mu^2) - log(sigma) - 0.5)

        Args:
            x, y: Reconstruction and target. The squared error is symmetric, so
                their order does not matter.

        Returns:
            Scalar tensor summed over the batch.
        """
        mu = self.params[0]
        sigma = self.params[1]

        # Closed-form KL to the standard-normal prior. The previous expression
        # (sigma**2 + mu**2 - log(sigma) - 1/2) dropped the 1/2 factor on the
        # quadratic terms: its minimum sat at sigma = 1/sqrt(2) with a non-zero
        # floor, pulling the posterior toward N(0, 1/2) instead of the N(0, 1)
        # prior that generation samples from.
        kl = (0.5 * (sigma**2 + mu**2) - torch.log(sigma) - 0.5).sum()

        recon = ((x - y) ** 2).sum()
        return recon + self.beta * kl

In [None]:
# Seed torch's global RNG before building the model so that weight initialisation
# (and, presumably, the DataLoader shuffle order drawn later) is reproducible.
torch.manual_seed(0)

# nn.Module.to returns the module itself, so construction and device placement chain.
vae = bitmoji_model(LATENT_DIMS, device=dev).to(dev)
vae

In [None]:
# Train for EPOCHS passes over `dataloader`. `fit` is inherited from the pyML `VAE` base
# class, whose optimiser and loop are not shown here — presumably it calls
# encoder -> repametrize -> decoder -> loss_fn per batch (TODO confirm against pyML).
vae.fit(dataloader, epochs=EPOCHS)

In [None]:
# Sample latent codes from the N(0, I) prior and decode them into a grid of faces.
N_SAMPLES = 100
GRID_NROW = 10  # images per row of the grid

from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid  # NOTE(review): unused in this cell

z = torch.randn(N_SAMPLES, LATENT_DIMS)
z = z.to(dev)

# Generate in eval mode without autograd. In train mode the BatchNorm layers would
# normalise with this batch's statistics and update their running averages, i.e.
# merely looking at samples would modify the trained model. The previous
# train/eval mode is restored afterwards so any further training is unaffected.
was_training = vae.training
vae.eval()
try:
    with torch.no_grad():
        img_recon = vae.generate(z).cpu()
finally:
    vae.train(was_training)

img_grid = torchvision.utils.make_grid(
    img_recon,
    nrow=GRID_NROW,
    # Min-max rescales the whole grid. The decoder's sigmoid output is already in
    # [0, 1], so this only stretches contrast.
    normalize=True
)

trans = transforms.ToPILImage()
pil_img = trans(img_grid)

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(pil_img)
ax.set_title(f"{N_SAMPLES} faces sampled from the prior")
ax.set_axis_off()
plt.show()