# Train a conditional VAE (CVAE) on Fashion-MNIST and generate Fashion-MNIST-like images conditioned on the class label.

In [None]:
import time

import torch
import torch.utils.data as tud
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms.v2 as tv2
import torchvision.datasets as tds
import torchvision.utils as tu

import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
dataset_root = "./"
cpu_num = 4  # DataLoader worker processes

# Convert each image to a float32 tensor; scale=True rescales the integer
# pixel values into the float range instead of only casting them.
tfs = tv2.Compose([
    tv2.ToImage(),
    tv2.ToDtype(torch.float32, scale=True),
])

# download=True lets torchvision fetch the files into dataset_root.
mnist_train = tds.FashionMNIST(
    root=dataset_root,
    download=True,
    train=True,
    transform=tfs,
)

mnist_eval = tds.FashionMNIST(
    root=dataset_root,
    download=True,
    train=False,
    transform=tfs,
)


batchsize = 32
# NOTE(review): this wrapper is likely a pass-through for a plain
# classification dataset such as FashionMNIST — confirm before removing it.
mnist_train = tds.wrap_dataset_for_transforms_v2(mnist_train)
mnist_eval = tds.wrap_dataset_for_transforms_v2(mnist_eval)

train_loader = tud.DataLoader(
    mnist_train, batch_size=batchsize, num_workers=cpu_num, shuffle=True
)
# Evaluation gains nothing from shuffling; a fixed order keeps evaluation runs
# reproducible. Use the same worker count as the training loader.
val_loader = tud.DataLoader(
    mnist_eval, batch_size=batchsize, num_workers=cpu_num, shuffle=False
)

In [None]:
# Base channel width shared by the encoder and the decoder.
C = 64


class Encoder(nn.Module):
    """Encode an image batch into a feature map with four 3x3 convolutions.

    Channel progression: ``channels`` -> ``C`` (stride 2) -> ``C``
    -> ``2*C`` (stride 2) -> ``2*C``. Each convolution uses replicate padding
    and is followed by ReLU. The two stride-2 layers halve the spatial size
    twice, e.g. 28x28 -> 7x7.
    """

    def __init__(self, channels):
        super().__init__()
        self.act = nn.ReLU(inplace=True)

        def conv3x3(in_ch, out_ch, stride=1):
            # Every layer shares kernel size, padding amount and padding mode.
            return nn.Conv2d(
                in_ch,
                out_ch,
                kernel_size=3,
                stride=stride,
                padding=1,
                padding_mode='replicate',
            )

        # Created in the same order as the layers are applied.
        self.conv0 = conv3x3(channels, C, stride=2)
        self.conv1 = conv3x3(C, C)
        self.conv2 = conv3x3(C, C * 2, stride=2)
        self.conv3 = conv3x3(C * 2, C * 2)

    def forward(self, x):
        for layer in (self.conv0, self.conv1, self.conv2, self.conv3):
            x = self.act(layer(x))
        return x

class Decoder(nn.Module):
    """Decode a feature map into an image with four 3x3 transposed convolutions.

    Channel progression: ``channels`` -> ``2*C`` (stride 2) -> ``2*C``
    -> ``C`` (stride 2) -> ``out_channels``. Every layer, including the last,
    is followed by ReLU, so the output is non-negative. The two stride-2 layers
    double the spatial size twice, e.g. 7x7 -> 28x28.

    Args:
        channels: number of input channels.
        out_channels: number of output image channels. Defaults to 1, the value
            this decoder always produced, so existing callers are unaffected.
    """

    def __init__(self, channels, out_channels=1):
        super().__init__()
        self.act = nn.ReLU(inplace=True)
        # stride=2 with padding=1 and output_padding=1 exactly doubles H and W.
        self.conv0 = nn.ConvTranspose2d(
            channels, C*2,
            stride=2,
            kernel_size=3,
            padding=1,
            output_padding=1
        )
        # stride=1 with padding=1 keeps H and W unchanged.
        self.conv1 = nn.ConvTranspose2d(
            C*2, C*2,
            kernel_size=3,
            padding=1,
        )
        self.conv2 = nn.ConvTranspose2d(
            C*2, C,
            stride=2,
            kernel_size=3,
            padding=1,
            output_padding=1
        )
        self.conv3 = nn.ConvTranspose2d(
            C, out_channels,
            kernel_size=3,
            padding=1,
        )

    def forward(self, x):
        x = self.conv0(x)
        x = self.act(x)
        x = self.conv1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.act(x)
        x = self.conv3(x)
        # NOTE(review): ReLU on the output is unbounded above; the training
        # targets are float images — consider a bounded output activation.
        x = self.act(x)
        return x

# Conditional variational autoencoder (CVAE) for square images.
#
# channels  - channels in the input image. The decoder is built with Decoder's
#             default output channel count (a single channel), so the
#             reconstruction is single-channel regardless of this value.
# dim       - spatial size of the square input image (28 for 28x28 MNIST).
#             Must be divisible by 4: the encoder halves the resolution twice
#             and the decoder doubles it twice.
# hiddensz  - width of the fully connected layer feeding the mu/logvar heads.
# latentsz  - latent vector size.
# param_num - number of float parameters to condition on, per sample.
class CVAE(nn.Module):
    def __init__(self, channels, dim, hiddensz, latentsz, param_num):
        super().__init__()
        if dim % 4 != 0:
            # Any other size makes the flattened encoder output disagree with
            # encoder_fc, so forward() could never succeed.
            raise ValueError(f"dim must be divisible by 4, got {dim}")
        # Conditioning params are appended to the image as extra input channels.
        self.encoder = Encoder(channels + param_num)
        self.dim = dim
        # Side length of the encoder's output feature map.
        self.hw = dim // 4

        self.encoder_fc = nn.Linear(C * 2 * self.hw * self.hw, hiddensz)
        self.mu_fc = nn.Linear(hiddensz, latentsz)
        self.logvar_fc = nn.Linear(hiddensz, latentsz)

        # Decoder input: conditioning params concatenated with the latent vector.
        self.decoder_fc = nn.Linear(param_num + latentsz, C * 2 * self.hw * self.hw)

        self.decoder = Decoder(C * 2)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        This samples in both train and eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, params, *, return_stats=False):
        """Encode ``x`` conditioned on ``params`` and decode a reconstruction.

        Args:
            x: image batch of shape (b, channels, dim, dim).
            params: conditioning tensor of shape (b, param_num). Integer tensors
                (e.g. class labels) are cast to ``x``'s dtype.
            return_stats: if True, also return ``mu`` and ``logvar`` so the
                caller can add a KL-divergence term to the loss, e.g.
                ``-0.5 * (1 + logvar - mu**2 - logvar.exp()).sum(dim=1)``.

        Returns:
            ``(recon, latent)``, or ``(recon, latent, mu, logvar)`` when
            ``return_stats`` is True. ``latent`` is the reparameterized sample.
        """
        # Cast explicitly instead of relying on dtype promotion in torch.cat.
        params = params.to(x.dtype)
        _, _, h, w = x.shape

        # Broadcast each conditioning param over the image as a constant channel.
        param_channels = params[..., None, None].expand(-1, -1, h, w)
        x = torch.cat((param_channels, x), dim=1)
        # -> (b, C * 2, hw, hw)
        x = self.encoder(x)

        x = self.encoder_fc(x.flatten(start_dim=1))
        mu = self.mu_fc(x)
        logvar = self.logvar_fc(x)

        latent = self.reparameterize(mu, logvar)
        recon = self.sample(params, latent)

        if return_stats:
            return recon, latent, mu, logvar
        return recon, latent

    def sample(self, params, latent):
        """Decode ``latent`` (b, latentsz) conditioned on ``params`` (b, param_num).

        Integer ``params`` are cast to ``latent``'s dtype.
        """
        params = params.to(latent.dtype)
        z = torch.cat((params, latent), dim=1)
        z = self.decoder_fc(z).reshape(-1, C * 2, self.hw, self.hw)
        return self.decoder(z)
        

In [None]:
# Fall back to the CPU instead of failing in model.to() on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = CVAE(
    channels=1,
    dim=28,
    hiddensz=512,
    latentsz=256,
    param_num=1)
model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.005)
lossfn = nn.MSELoss()
epochs = 50

# NOTE(review): the loss below is reconstruction MSE only. mu/logvar never
# enter the objective, so there is no KL term and this trains a conditional
# autoencoder with sampling noise rather than a VAE; decoding latents drawn
# from N(0, I) is not something this training targets.

loss_plot = []  # mean training loss of each epoch
for epoch in range(epochs):
    # The interactive cell calls model.eval(); restore training mode in case
    # this cell is re-run afterwards.
    model.train()
    losses = []
    for images, target in train_loader:
        optimizer.zero_grad()
        images = images.float().to(device)

        # Labels (b,) -> (b, param_num) float params; param_num == 1 here.
        params = target.unsqueeze(-1).float().to(device)

        outs, _ = model(images, params)
        loss = lossfn(outs, images)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    epoch_loss = torch.tensor(losses).mean().item()
    print(f"epoch {epoch + 1}/{epochs}: loss {epoch_loss:.6f}")
    loss_plot.append(epoch_loss)

plt.plot(loss_plot)

In [None]:
from ipywidgets import interact

@interact(p=(0.0, 1.0))
def plot(p=0.0):
    """Show two reconstructions and a latent-space interpolation between them.

    Row 1: (original | reconstruction) for training sample 1000.
    Row 2: (original | reconstruction) for training sample 1.
    Row 3: decodes of ``p*z0 + (1-p)*z1`` and ``(1-p)*z0 + p*z1``, where z0/z1
    are the latents the model returned for those two samples.
    """
    # NOTE(review): the interpolated decodes are always conditioned on label 1,
    # independent of the two samples' own labels — confirm this is intended.
    interp_label = 1

    def reconstruct(index):
        # Returns the (original | reconstruction) image pair and the latent.
        data, lbl = mnist_train[index]
        params = torch.tensor([[lbl]], dtype=torch.float32, device=device)
        recon, latent = model(data.unsqueeze(0).to(device), params)
        pair = torch.cat((data.squeeze().cpu(), recon.squeeze().cpu()), dim=1)
        return pair, latent

    with torch.no_grad():
        model.eval()

        d0, latent0 = reconstruct(1000)
        d1, latent1 = reconstruct(1)

        interp_params = torch.tensor(
            [[interp_label]], dtype=torch.float32, device=device
        )
        d01 = model.sample(interp_params, p * latent0 + (1 - p) * latent1).cpu()
        d10 = model.sample(interp_params, (1 - p) * latent0 + p * latent1).cpu()
        interp = torch.cat((d01, d10), dim=-1).squeeze()

        print(d0.shape, d1.shape, interp.shape)
        # Named `grid` rather than `all` to avoid shadowing the builtin.
        grid = torch.cat((d0, d1, interp), dim=0)

    plt.imshow(grid.numpy())
    # Render explicitly so the figure appears in the interact output area.
    plt.show()