In [1]:
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import matplotlib.pyplot as plt
from torch import nn, Callable
import torchvision.datasets as Datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.utils as vutils
import numpy as np
import pandas as pd
import seaborn as sns

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

# project modules
import helpers as hf
import pandas as pd

from RES_VAE_Dynamic import VAE

In [None]:
# Identifier of the training run whose configuration is loaded below.
run = 24

# Unpack the settings returned by hf.read_config for this run. `image_size` is
# one of the unpacked values, so no separate default is assigned beforehand
# (a previous hard-coded value was overwritten here and never used).
(
    label_idxs,
    t_idx,
    norm_type,
    kl_scale,
    learning_rate,
    nepochs,
    image_size,
    ch_multi,
    num_res_blocks,
    gpu_index,
    latent_channels,
    deep_model,
    save_interval,
    block_widths,
    save_dir,
) = hf.read_config(path=f"Runs/Run_{run}/config.yml")

# All tensors and the model in this notebook are placed on the CPU.
device = torch.device("cpu")

In [None]:
# Build the evaluation dataloader from CelebA's predefined "valid" split.
print("-Target Image Size %d" % image_size)

# Deterministic preprocessing only: a random horizontal flip in this pipeline
# would make the visualised batch differ from one notebook run to the next.
celeb_transform = transforms.Compose(
    [
        transforms.CenterCrop(150),
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)
batch_size = 8
data_dir = "../../../../../groups/kempter/chen/data"

# download=False: the dataset files must already be present under data_dir.
test_dataset = Datasets.CelebA(
    data_dir, transform=celeb_transform, download=False, split="valid"
)
# shuffle=False keeps the batch order fixed across runs.
test_loader = DataLoader(
    dataset=test_dataset, batch_size=batch_size, num_workers=16, shuffle=False
)


def get_plotable_imgs(reconstructions):
    """Convert a batch of image tensors into min-max scaled arrays for plotting.

    Parameters
    ----------
    reconstructions : torch.Tensor
        4-D tensor; it is permuted with (2, 3, 1, 0), so it is expected to be
        indexed as (batch, channel, height, width).

    Returns
    -------
    numpy.ndarray
        New array indexed as (height, width, channel, batch). Every batch
        element is rescaled independently to the range [0, 1]. An element
        whose values are all identical is returned as all zeros instead of
        the NaNs a 0/0 division would produce.
    """
    imgs = reconstructions.detach().cpu().permute(2, 3, 1, 0).numpy()

    # Per-image extrema, kept broadcastable against the full array.
    lowest = imgs.min(axis=(0, 1, 2), keepdims=True)
    value_range = imgs.max(axis=(0, 1, 2), keepdims=True) - lowest

    # Replace a zero range by one so constant images map to 0 rather than NaN.
    safe_range = np.where(value_range > 0, value_range, np.ones_like(value_range))
    return (imgs - lowest) / safe_range

In [None]:
# Pull a single batch from the loader; it is reused by the visualisation cells below.
dataiter = iter(test_loader)
test_images, labels = next(dataiter)

# Select the conditioning-label columns and the treatment column from the
# attribute matrix and move both selections to the target device.
test_labels, test_treatment = (
    labels[:, columns].to(device) for columns in (label_idxs, t_idx)
)

# Instantiate the VAE; input channels and label width are read off the batch,
# everything else comes from the loaded run configuration.
vae_net = VAE(
    image_size=image_size,
    channel_in=test_images.size(1),
    label_dim=test_labels.size(1),
    latent_channels=latent_channels,
    ch=ch_multi,
    blocks=block_widths,
    num_res_blocks=num_res_blocks,
    norm_type=norm_type,
    deep_model=deep_model,
)
vae_net = vae_net.to(device)

In [None]:
# Load the checkpoint from the same run directory the configuration was read
# from, so changing `run` cannot pair one run's config with another run's weights.
checkpoint = torch.load(
    f"Runs/Run_{run}/epoch9_step_47999/model_64.pt",
    map_location="cpu",
)
print("-Checkpoint loaded!")
# Only the "model_state_dict" entry of the checkpoint is used here.
vae_net.load_state_dict(checkpoint["model_state_dict"]);
# Switch the network to evaluation mode.
vae_net.eval()

In [None]:
# Collect one image strip per setting: the original batch first, then one
# reconstruction per treatment scale. The strips are joined once at the end.
strips = [get_plotable_imgs(test_images)]

# For a 0/1-valued treatment, scale=0 leaves the treatment unchanged and
# scale=1 replaces it by its flipped value; other scales move further along
# (or against) that same direction.
with torch.no_grad():  # inference only, so no autograd graph is recorded
    for scale in [5, 2, 1.5, 1, 0, -0.5, -1, -3]:
        imgs, mu, log_var, _ = vae_net(
            test_images,
            test_labels,
            test_treatment + (((test_treatment + 1) % 2) - test_treatment) * scale,
        )
        strips.append(get_plotable_imgs(imgs))

# Strips are stacked along axis 1 of the (H, W, C, B) arrays.
ogs = np.concatenate(strips, axis=1)


In [None]:
# Show the concatenated strip (original plus all reconstructions) of batch element 7.
sample_strip = ogs[..., 7]
fig, ax = plt.subplots()
ax.imshow(sample_strip)
ax.set_axis_off()
fig.set_figwidth(10)

In [None]:
# Inputs for the prior-sampling cell: an all-zero label matrix with the shape
# of the test labels, and a length-8 vector of ones.
labels = torch.zeros_like(test_labels)
test = torch.ones(8) * 1

In [None]:
# Query vae_net.prior with the labels, sample an encoding from the returned
# mean / log-variance, and decode it together with `test`.
prior_mu, prior_logvar = vae_net.prior(labels)
encoding = vae_net.encoder.sample(prior_mu, 1 * prior_logvar)
imgs = vae_net.decoder(encoding, test)

prior_imgs = get_plotable_imgs(imgs)

# Display batch element 3 of the decoded images, without axes.
plt.imshow(prior_imgs[..., 3])
plt.axis("off")

In [None]:
# Tile of n_steps x n_steps decoded prior samples: rows sweep label column 9
# over [0, 2], columns sweep label column 1 over [0, 1].
n_steps = 5
grid = np.zeros((image_size * n_steps, image_size * n_steps, 3))
test = torch.ones((1)) * 1

with torch.no_grad():  # inference only, so no autograd graph is recorded
    for i, change1 in enumerate(np.linspace(0, 2, n_steps)):
        for j, change2 in enumerate(np.linspace(0, 1, n_steps)):
            # Single-row label matrix with the two swept entries filled in.
            labels = torch.tensor(
                [[0, change2, 1, 1, 0, 0, 0, 0, 0, change1, 0, 0, 0, 0, 0, 0, 0, 0]]
            )
            prior_mu, prior_logvar = vae_net.prior(labels)
            encoding = vae_net.encoder.sample(prior_mu, prior_logvar * 2)
            # `test - 1` is a zero vector, passed as the decoder's second argument.
            img = vae_net.decoder(encoding, test - 1)

            prior_img = get_plotable_imgs(img)[:, :, :, 0]

            grid[
                image_size * i : image_size * (i + 1),
                image_size * j : image_size * (j + 1),
                :,
            ] = prior_img

In [None]:
# Render the attribute sweep tile without axes on a 10-inch-tall figure.
fig, ax = plt.subplots()
ax.imshow(grid)
ax.set_axis_off()
fig.set_figheight(10)

In [None]:
# Latent traversal: one row per latent dimension, one column per offset
# subtracted from that dimension of the encoder mean.
offsets = np.linspace(-30, 30, 11)
# The number of rows follows the configured latent size instead of a fixed 32,
# matching how the sampling cell further down iterates over the same axis.
grid = np.zeros((image_size * latent_channels, image_size * len(offsets), 3))

with torch.no_grad():  # inference only, so no autograd graph is recorded
    encoding, mu, logvar = vae_net.encoder(test_images, test_labels, test_treatment)

    for i in range(latent_channels):
        for j, scale in enumerate(offsets):
            # Work on a copy so every column starts from the unmodified mean.
            mu_new = mu.clone()
            mu_new[:, i] -= scale
            img = vae_net.decoder(mu_new, test_treatment)

            # Only batch element 1 is placed in the tile.
            prior_img = get_plotable_imgs(img)[:, :, :, 1]

            grid[
                image_size * i : image_size * (i + 1),
                image_size * j : image_size * (j + 1),
                :,
            ] = prior_img

In [None]:
# Render the latent traversal tile without axes on a 100-inch-tall figure.
fig, ax = plt.subplots()
ax.imshow(grid)
ax.set_axis_off()
fig.set_figheight(100)

In [None]:
# Build a long-format frame with `ns` normal draws per latent dimension:
# column "g" holds the dimension index as a string, column "x" the draws.
ns = 400

# Mean and spread are both taken from the SAME batch element. The previous
# version read the mean from element 0 and the spread from element 1.
# Element 1 is used because it is the one shown in the latent traversal grid.
# NOTE(review): switch to 0 if the first batch element was the intended one.
sample_idx = 1

# Standard deviation implied by the log-variance: sqrt(exp(logvar)).
latent_std = torch.sqrt(logvar.exp()).detach()

per_dim_frames = []
for i in range(latent_channels):
    draws = np.random.normal(
        float(mu[sample_idx, i].detach()),
        float(latent_std[sample_idx, i]),
        size=ns,
    )
    per_dim_frames.append(pd.DataFrame({"g": str(i), "x": draws}))

# Assemble the frame once instead of enlarging it one cell at a time.
df = pd.concat(per_dim_frames, ignore_index=True)

In [None]:
sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

# One palette colour per ridge. With fewer colours than hue levels the
# palette is cycled, so unrelated rows would share a colour.
n_groups = df["g"].nunique()
pal = sns.cubehelix_palette(n_groups, rot=-0.25, light=0.7)

# One thin facet row per value of "g", coloured by the same variable.
g = sns.FacetGrid(df, row="g", hue="g", aspect=10, height=0.2, palette=pal)

# Draw each density twice: a filled curve, then a white outline on top.
g.map(sns.kdeplot, "x", bw_adjust=0.5, clip_on=False, fill=True, alpha=1, linewidth=1.5)
g.map(sns.kdeplot, "x", clip_on=False, color="w", lw=1, bw_adjust=0.5)

# passing color=None to refline() uses the hue mapping
g.refline(y=0, linewidth=1.5, linestyle="-", color=None, clip_on=False)


# Define and use a simple function to label the plot in axes coordinates
def label(x, color, label):
    """Write `label` onto the current matplotlib axes.

    The text is placed at axes-fraction coordinates (0, 0.2), left-aligned and
    vertically centred, and drawn in `color`. `x` is accepted but not used.
    """
    current_ax = plt.gca()
    text_style = dict(
        color=color,
        ha="left",
        va="center",
        transform=current_ax.transAxes,
    )
    current_ax.text(0, 0.2, label, **text_style)


# Per-row text labels via `label` are not drawn for this figure.

# Negative vertical spacing so that neighbouring facet rows overlap.
g.figure.subplots_adjust(hspace=-0.45)

# Drop titles, y ticks / labels and the bottom and left spines, which clash
# with the overlapping layout.
g.set_titles("").set(yticks=[], ylabel="")
g.despine(bottom=True, left=True)


In [None]:
# Synthetic demo data: 500 draws from a seeded standard normal, assigned
# cyclically to the groups A-J, each shifted by the code point of its letter.
rs = np.random.RandomState(1979)
x = rs.randn(500)
g = np.tile(list("ABCDEFGHIJ"), 50)
df = pd.DataFrame({"x": x, "g": g})
m = df["g"].map(ord)
df["x"] = df["x"] + m

# One facet row per group, coloured from a ten-step cubehelix palette.
pal = sns.cubehelix_palette(10, rot=-0.25, light=0.7)
g = sns.FacetGrid(df, row="g", hue="g", aspect=15, height=0.5, palette=pal)

# Each density is drawn twice: a filled curve, then a white outline on top.
kde_shared = dict(bw_adjust=0.5, clip_on=False)
g.map(sns.kdeplot, "x", fill=True, alpha=1, linewidth=1.5, **kde_shared)
g.map(sns.kdeplot, "x", color="w", lw=2, **kde_shared)

# color=None lets refline() pick up the hue mapping for the baseline.
g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)


# Define and use a simple function to label the plot in axes coordinates
def label(x, color, label):
    """Write `label` in bold onto the current matplotlib axes.

    The text is placed at axes-fraction coordinates (0, 0.2), left-aligned and
    vertically centred, and drawn in `color`. `x` is accepted but not used.
    """
    current_ax = plt.gca()
    text_style = dict(
        fontweight="bold",
        color=color,
        ha="left",
        va="center",
        transform=current_ax.transAxes,
    )
    current_ax.text(0, 0.2, label, **text_style)


# Apply `label` to every facet, mapped over column "x".
g.map(label, "x")

# Negative vertical spacing so that neighbouring facet rows overlap.
g.figure.subplots_adjust(hspace=-0.25)

# Drop titles, y ticks / labels and the bottom and left spines, which clash
# with the overlapping layout.
g.set_titles("").set(yticks=[], ylabel="")
g.despine(bottom=True, left=True)

In [None]:
# NOTE: when the cells are run top to bottom, `df` at this point is the
# synthetic A-J demo frame assigned in the preceding cell, not the
# latent-sample frame built earlier in the notebook.
df