In [None]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# reloaded automatically when their source changes (no kernel restart needed).
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Select the GPU through environment variables: devices are ordered by PCI bus
# id and only device "2" is made visible to this process.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "2",
    }
)

# Device string used by the rest of the notebook when moving model and data.
device = "cuda:0"

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tqdm.auto as tqdm
import imageio
import io
import h5py
import time

import hierarchical_vae.hps as hps
from hierarchical_vae.train_helpers import set_up_hyperparams, load_opt
from hierarchical_vae.vae import VAE

import sklearn
import sklearn.linear_model

In [None]:
# Keep printed arrays short: arrays with more than 5 elements are summarised.
np.set_printoptions(threshold=5)

In [None]:
# Debug switch: when 'testing' is True, the notebook only works on the first
# 'nr_images' entries of every dataset.
testing = False
nr_images = 10

# only for use on images, not on sequences
f = h5py.File('/storage/mi/jennyonline/data/data_2020_100000_unbiased.h5', 'r')

if testing:
    # Slice each dataset down to the first 'nr_images' entries.
    images, tag_masks, loss_masks, labels = (
        f[key][:nr_images] for key in ('images', 'tag_masks', 'loss_masks', 'labels')
    )
else:
    # Keep the dataset objects of the open file; they are indexed later on.
    images, tag_masks, loss_masks, labels = (
        f[key] for key in ('images', 'tag_masks', 'loss_masks', 'labels')
    )

# Statistics stored in the file; used below to normalise and de-normalise images.
mean = f['mean'][()]
std = f['std'][()]

In [None]:
############# print an image and its masked representation ###############
# Read the first sample once and build the two masked variants from it.
first_image = images[0]
loss_masked = first_image * loss_masks[0]
fully_masked = loss_masked * tag_masks[0]

gray = plt.cm.gray
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

# left: raw image, middle: loss mask applied, right: loss mask and tag mask applied
axes[0].imshow(first_image, cmap=gray)
axes[1].imshow(loss_masked, cmap=gray)
axes[2].imshow(fully_masked, cmap=gray)

In [None]:
############# set hyperparameter for model ###############

H = set_up_hyperparams(s=["--dataset=i64"])

# NOTE: H_ is the preset object from `hps` itself (no copy is made), so the
# assignments below modify `hps.ffhq_256` in place.
H_ = hps.ffhq_256

# Data/model overrides on top of the preset.
for key, value in (
    ("image_channels", 1),
    ("image_size", 128),
    ("width", 128),
    ("n_batch", 8),
):
    H_[key] = value

# Decoder and encoder block layout strings.
decoder_layout = "1x2,4m1,4x3,8m4,8x4,16m8,16x9,32m16,32x20,64m32,64x12,128m64"
encoder_layout = "128x4,128d2,64x7,64d2,32x7,32d2,16x7,16d2,8x7,8d2,4x7,4d4,1x8"
H_.dec_blocks = decoder_layout
H_.enc_blocks = encoder_layout
H_["adam_warmup_iters"] = 100

# Merge the overrides into the run hyperparameters, then add notebook-specific values.
H.update(H_)
for key, value in (
    ("skip_threshold", -1),  # -1 disables the grad-norm check in the training loop
    ("std", std),
    ("mean", mean),
):
    H[key] = value

H.lr = 1e-4

In [None]:
# Instantiate the model from the hyperparameters and move it to the chosen device.
vae = VAE(H)
vae = vae.to(device)

# load_opt returns the optimiser, LR scheduler and bookkeeping values.
(optimizer, scheduler, cur_eval_loss, iterate, starting_epoch) = load_opt(H, vae)

# Per-step ELBO values; the training loop appends to this list.
elbos = list()

In [None]:
# scikit-learn SGDRegressor with default settings; the training loop updates it
# batch by batch via partial_fit on latent features.
regression = sklearn.linear_model.SGDRegressor()

In [None]:
############# train ###############
H.num_epochs = 20
# Index into the list returned by vae.forward_get_latents; the 'z' of that
# entry is used as feature vector for the regressor.
eval_dimension = 0

# coefficient for supervised and unsupervised objective function to control ratio between these two
coefficient = 0.5

batch_size = H_['n_batch']
steps_per_epoch = len(images) // batch_size

start = time.time()


for i_epoch in range(H.num_epochs):
    progress = tqdm.trange(steps_per_epoch)

    # for every batch
    for i in progress:
        # Draw a random batch without replacement; the indices are sorted
        # before being used to index the datasets.
        random_idxs = sorted(np.random.choice(len(images), batch_size, replace=False))

        # Normalise the images and append a trailing channel axis.
        x = torch.from_numpy(images[random_idxs][:, :, :, None].astype(np.float32))
        x -= H["mean"]
        x /= H["std"]
        x = x.to(device)

        # Use a batch-local name so the notebook-level `labels` dataset is not
        # overwritten by the first training step.
        batch_labels = labels[random_idxs]

        target_mask = torch.from_numpy(loss_masks[random_idxs][:, :, :, None]).to(device)
        tag_mask = torch.from_numpy(tag_masks[random_idxs][:, :, :, None]).to(device)
        data_input = (x * tag_mask).float()
        target = data_input.clone().detach()

        vae.zero_grad()
        stats = vae.forward(data_input, target, target_mask * tag_mask)

        # Supervised part: extract the latent of layer `eval_dimension`, update
        # the SGD regressor on it and take the regressor's score on the same batch.
        with torch.no_grad():
            stats_with_latents = vae.forward_get_latents(data_input)

            z = stats_with_latents[eval_dimension]['z'].cpu().numpy()

            # Reduce the two trailing axes so every sample becomes one feature vector.
            if z.shape[-1] == 1:
                z = z[:, :, 0, 0]
            else:
                z = z.mean(axis=(2, 3))

        regression.partial_fit(z, batch_labels, sample_weight=None)
        loss = regression.score(z, batch_labels)

        # NOTE(review): `z` is computed under torch.no_grad() and converted to
        # NumPy, and `loss` is a plain score value, so the supervised term adds
        # no gradient to the VAE parameters; effectively only the ELBO scaled by
        # (1 - coefficient) is optimised. `score` is also a "higher is better"
        # quantity that is added to a value being minimised - confirm the
        # intended objective before relying on this term.
        complete_loss = (1 - coefficient) * stats["elbo"] + coefficient * loss
        complete_loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(vae.parameters(), H.grad_clip).item()

        distortion_nans = torch.isnan(stats["distortion"]).sum()
        rate_nans = torch.isnan(stats["rate"]).sum()

        # Store 0/1 flags instead of NaN counts.
        stats.update(
            dict(
                rate_nans=0 if rate_nans == 0 else 1,
                distortion_nans=0 if distortion_nans == 0 else 1,
            )
        )

        elbos.append(stats["elbo"].item())

        # only do an update step if no rank has a NaN and if the grad norm is below a specific threshold
        if (
            stats["distortion_nans"] == 0
            and stats["rate_nans"] == 0
            and (H.skip_threshold == -1 or grad_norm < H.skip_threshold)
        ):
            optimizer.step()

            progress.set_postfix(
                dict(
                    ELBO=np.nanmean(elbos[-100:]),
                    lr=scheduler.get_last_lr()[0],
                    has_nan=np.any(np.isnan(elbos[-100:])),
                )
            )

            scheduler.step()

    # Checkpoint the weights after every epoch.
    print("Epoch ", i_epoch, " is over")
    store_at = "/storage/mi/jennyonline/data/vae_supervised_" + str(i_epoch) + ".pt"
    torch.save(vae.state_dict(), store_at)

end = time.time()
print("Runtime: ", ((end - start)/60), " Minuten")

In [None]:
# Persist the final weights and the per-step ELBO history.
torch.save(vae.state_dict(), "/storage/mi/jennyonline/data/vae_supervised.pt")

# Make sure the relative 'trash' directory exists before writing into it, so
# the ELBO history is not lost to a missing-directory error after training.
os.makedirs('trash', exist_ok=True)
np.savez('trash/elbos_supervised', elbos)

In [None]:
# Optional: uncomment to restore a previously saved model and its ELBO history.
# vae = VAE(H).to(device)
# vae.load_state_dict(torch.load("/storage/mi/jennyonline/supervisedVAE_23_08/vae_supervised.pt"))
# _ = vae.eval()

# elbos = np.load("/storage/mi/jennyonline/supervisedVAE_23_08/elbos_supervised.npz")
# elbos = elbos.f.arr_0;

In [None]:
# Smooth the per-step ELBO values with a rolling mean (window 1024, at least
# 200 values required) and save the resulting curve as PDF.
elbo_series = pd.Series(elbos)
plot_elbo = elbo_series.rolling(window=1024, min_periods=200).mean()
plot_elbo.plot()
plt.savefig('/storage/mi/jennyonline/supervisedVAE_23_08/elbos_supervised_1024_200.pdf')

In [None]:
####### get latents of images out of saved model and print images ####################

vae.cpu()
vae.eval()

# choose one image from first batch for plotting -> here first image of batch is chosen
sample_idx = 0
temperature = .2
min_kl = 0

# Normalise the first batch of images and append a trailing channel axis.
n_eval = H['n_batch']
x = torch.from_numpy(images[:n_eval].astype(np.float32))[:, :, :, None]
x -= mean
x /= std

# Every image is multiplied with its own tag mask (as in the training loop);
# the display mask belongs to the sample selected by `sample_idx`.
tag_mask = torch.from_numpy(tag_masks[:n_eval].astype(np.float32))[:, :, :, None]
mask = (loss_masks[sample_idx] * tag_masks[sample_idx]).astype(np.float32)
data_input = (x * tag_mask).float()

fig, axes = plt.subplots(1, 7, figsize=(20, 8))

# De-normalised model input of the selected sample; its value range fixes the
# colour scale of all reconstructions.
original_image = ((data_input[sample_idx].data.numpy() * std) + mean)[:, :, 0]
axes[0].imshow(original_image, cmap=plt.cm.gray)

minv = original_image.min()
maxv = original_image.max()

# Size of the input when encoded as PNG. The buffer gets its own name so the
# notebook-level HDF5 file handle `f` is not overwritten by this cell.
with io.BytesIO() as png_buffer:
    imageio.imsave(png_buffer, ((data_input[sample_idx] * std) + mean).data.numpy().astype(np.uint8), format='png')
    bytes_png = len(png_buffer.getvalue())

axes[0].set_title(f"$x$ - {bytes_png / 1024:.3f}KiB (PNG)")

with torch.no_grad():
    # Run the model once and read every statistic from the same pass, so that
    # z, kl, qm and qv belong together.
    latent_stats = list(vae.forward_get_latents(data_input))

    zs = [s["z"] for s in latent_stats]
    kls = [s["kl"] for s in latent_stats]

    # Zero out latent entries (and their KL) whose KL is below `min_kl`.
    for z, k in zip(zs, kls):
        z[k < min_kl] = 0
        k[k < min_kl] = 0

    qms = [s["qm"] for s in latent_stats]
    qvs = [s["qv"] for s in latent_stats]

    mb = data_input.shape[0]


def plot_layer(ax, layer_idx):
    """Decode from the first `layer_idx` latents and draw the result on `ax`.

    Passes `zs[:layer_idx]` to the decoder's `forward_manual_latents` with the
    notebook-level `temperature`, samples an image from the output network and
    shows the entry `sample_idx`; pixels outside `mask` are filled with `mean`.
    The title reports the KL summed over the used layers for that sample,
    divided by log(2), 8 and 1024 (KiB, assuming the KL values are in nats).
    """
    with torch.no_grad():
        px_z = vae.decoder.forward_manual_latents(mb, zs[:layer_idx], t=temperature)
        samples = vae.decoder.out_net.sample(px_z)

        ax.imshow(samples[sample_idx, :, :, 0] * mask + (1 - mask) * mean, cmap=plt.cm.gray, vmin=minv, vmax=maxv)

        all_kls = np.concatenate([k[sample_idx].cpu().data.numpy().flatten() for k in kls[:layer_idx]])

        ax.set_title(f"$z_{{{layer_idx}}}$ - {(all_kls / np.log(2)).sum() / 8 / 1024:.3f}KiB")


# Reconstructions using an increasing number of latent layers; the last slice
# bound lies past the end of `zs` and therefore selects all layers.
for ax, layer_idx in zip(axes[1:], (2, 6, 12, 20, 36, len(zs)+1)):
    plot_layer(ax, layer_idx)

for ax in axes:
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

plt.tight_layout()

In [None]:
############# sample images from latent space ###############

sample_idx = 1
temperature = 0.5

fig, axes = plt.subplots(4, 6, figsize=(12, 8))

# The batch size passed to the decoder is the same for every draw.
mb = data_input.shape[0]

with torch.no_grad():
    # One unconditional draw per panel (row-major order); from each drawn
    # batch only the entry `sample_idx` is displayed.
    for ax in axes.flat:
        px_z = vae.decoder.forward_uncond(mb, t=temperature)
        samples = vae.decoder.out_net.sample(px_z)
        ax.imshow(samples[sample_idx, :, :, 0], cmap=plt.cm.gray)
        # Turn the axis off on every panel so all panels look the same
        # (plt.axis("off") would only affect the current axes).
        ax.set_axis_off()

plt.tight_layout()

In [None]:
# Flatten the KL values of the first batch element of every layer into one vector.
per_layer_kl = [layer_kl[0].cpu().data.numpy().flatten() for layer_kl in kls]
all_kls = np.concatenate(per_layer_kl)

# blue: default binning
plt.hist(all_kls, log=True)
# orange: 25 bins
plt.hist(all_kls, log=True, bins=25);

In [None]:
# Histogram of the "qm" values (presumably posterior means - confirm) of the
# first batch element, restricted to the first 12 layers.
first_layers_qm = [layer_qm[0].cpu().data.numpy().flatten() for layer_qm in qms[:12]]
all_qms = np.concatenate(first_layers_qm)
plt.hist(all_qms, bins=25, log=True);

In [None]:
# Per-layer KL summary:
#  - kl_df: for every layer, the KL averaged over dims 0, 2 and 3 (one row per
#    remaining index of dim 1), tagged with the layer index
#  - layer_bytes: KL of the first batch element summed per layer and divided by
#    log(2) and 8 (bytes, assuming the KL values are in nats)
layer_frames = []
layer_bytes = []

for layer_idx, layer_kl in enumerate(kls):
    mean_kl = layer_kl.mean(dim=(0, 2, 3)).cpu().data.numpy()
    layer_df = pd.DataFrame(list(mean_kl), columns=['KL']).assign(layer=layer_idx)
    layer_frames.append(layer_df)

    first_sample_bits = (layer_kl[0] / np.log(2)).sum().item()
    layer_bytes.append(first_sample_bits / 8)

kl_df = pd.concat(layer_frames)

In [None]:
# Swarm plot of the KL values in kl_df, grouped by layer.
fig, ax = plt.subplots(figsize=(12, 4))
sns.swarmplot(data=kl_df, x='layer', y='KL', color='gray', s=2, ax=ax);

In [None]:
# Per-layer byte counts from layer_bytes on a logarithmic y axis.
fig, ax = plt.subplots()
ax.plot(layer_bytes)
ax.set_xlabel('Layer')
ax.set_ylabel('Entropy in bytes')
ax.semilogy()

In [None]:
# Line plot of the mean KL per layer (mean over the rows of each layer in kl_df).
kl_df.groupby('layer').mean().plot()