In [3]:
import os
from argparse import ArgumentParser

import numpy as np
import matplotlib.pyplot as plt
import wandb
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
from encoder_decoders.vq_vae_encdec import VQVAEEncoder, VQVAEDecoder
from vector_quantization import VectorQuantize
from utils import load_yaml_param_settings, get_root_dir, freeze
from einops import rearrange

from stage2 import load_pretrained_encoder_decoder_vq

In [4]:
def load_args():
    """Return the argument namespace used by this notebook.

    The only option is ``--config``, whose default is ``configs/config.yaml``
    under the directory returned by ``get_root_dir()``. An explicit empty
    argument list is parsed, so ``config`` always takes that default.
    """
    arg_parser = ArgumentParser()
    default_config = get_root_dir().joinpath('configs', 'config.yaml')
    arg_parser.add_argument(
        '--config',
        type=str,
        default=default_config,
        help="Path to the config data  file.",
    )
    # Passing [] means the process's own command line is not consulted.
    return arg_parser.parse_args([])

In [8]:
# Rebuild the LDM pipeline: config -> pretrained modules -> U-Net -> diffusion -> trainer.

# configuration
args = load_args()
config = load_yaml_param_settings(args.config)

# pretrained encoder, decoder, and vq (loaded with freeze_models=True), each moved to the GPU
pretrained_parts = load_pretrained_encoder_decoder_vq(config, 'saved_models', freeze_models=True)
encoder, decoder, vq_model = (part.cuda() for part in pretrained_parts)

# denoising network; its input channel count is the VQ-VAE codebook dimension from the config
unet_kwargs = dict(
    in_channels=config['VQ-VAE']['codebook_dim'],
    dim=64,
    dim_mults=(1, 2, 4, 8),
    self_condition=config['diffusion']['unet']['self_condition'],
)
model = Unet(**unet_kwargs).cuda()

# diffusion process wrapped around the U-Net
diffusion_kwargs = dict(
    in_size=encoder.H_prime[0].item(),  # width or height of z
    timesteps=1000,  # number of diffusion steps
    sampling_timesteps=1000,  # number of sampling steps (same value as `timesteps` here)
    loss_type='l1',  # L1 or L2
    auto_normalize=False,
)
diffusion = GaussianDiffusion(model, **diffusion_kwargs).cuda()

# trainer (in this notebook it is only used to load a checkpoint and to sample)
trainer_kwargs = dict(
    train_batch_size=config['dataset']['batch_sizes']['stage2'],
    train_lr=8e-5,
    train_num_steps=700000,  # total training steps
    gradient_accumulate_every=2,  # gradient accumulation steps
    ema_decay=0.995,  # exponential moving average decay
    amp=False,  # mixed precision is disabled
    fp16=False,
    save_and_sample_every=1000,
    num_samples=9,
    augment_horizontal_flip=False,
)
trainer = Trainer(diffusion, config, encoder, decoder, vq_model, **trainer_kwargs)

Data loading... 0%
Data loading... 2%
Data loading... 4%
Data loading... 6%



KeyboardInterrupt



In [7]:
# load the pretrained LDM
# Checkpoint milestone to restore. The call previously read `trainer.load(milestone=)`,
# which is a syntax error because the keyword argument had no value. The correct
# milestone depends on which checkpoints exist locally, so it must be set by hand.
MILESTONE = None  # TODO: set to the milestone of the checkpoint to load
if MILESTONE is None:
    raise ValueError("Set MILESTONE to the checkpoint milestone to load before running this cell.")
trainer.load(milestone=MILESTONE)

In [None]:
# Sample a batch of 9 latents from the model stored at `trainer.ema.ema_model`,
# after switching it to eval mode.
ema_model = trainer.ema.ema_model
ema_model.eval()
z_gen = ema_model.sample(batch_size=9)

print('z_gen.shape:', z_gen.shape)

In [None]:
# Histogram (100 bins, log-scaled counts) over every element of the sampled latents.
z_values = z_gen.detach().cpu().numpy().ravel()
plt.hist(z_values, bins=100, log=True)
plt.show()

In [None]:
# Optional step, currently disabled: pass z_gen through the VQ codebook before decoding.
# NOTE: as written, enabling it overwrites z_gen in place.
# # VQ(z_gen)
# h, w = z_gen.shape[2], z_gen.shape[3]
# z_gen = rearrange(z_gen, 'b d h w -> b (h w) d')
# z_gen, _ = trainer.pretrained_vq._codebook(z_gen)
# z_gen = rearrange(z_gen, 'b (h w) d -> b d h w', h=h, w=w)

# decode
# The projected latents get their own name (`z_dec`) instead of overwriting `z_gen`.
# Previously this cell reassigned `z_gen`, so running it a second time applied
# `project_out` to already-projected latents. Now the cell can be re-run safely.
z_dec = rearrange(z_gen, 'b d h w -> b h w d')  # channels last for project_out
z_dec = trainer.pretrained_vq.project_out(z_dec)
z_dec = rearrange(z_dec, 'b h w d -> b d h w')  # back to channels first for the decoder
print('z_gen.shape:', z_dec.shape)  # shape of the projected latents

x_gen = trainer.pretrained_decoder(z_dec)  # (b c h w)
x_gen = x_gen.cpu().detach()
# Index of the largest channel at every pixel; keepdim=True keeps a size-1 channel axis.
x_gen = x_gen.argmax(dim=1, keepdim=True).float()  # (b 1 h w)
print('x_gen.shape:', x_gen.shape)

In [None]:
# plot
# Show every generated sample on the smallest square grid that can hold them all.
n_samples = x_gen.shape[0]
# max(1, ...) keeps the grid valid even when there are no samples.
n_rows = max(1, int(np.ceil(np.sqrt(n_samples))))
# squeeze=False makes `axes` a 2-D array even for a 1x1 grid; without it,
# plt.subplots(1, 1) returns a bare Axes object and `.flatten()` fails.
fig, axes = plt.subplots(n_rows, n_rows, figsize=(12, 12), squeeze=False)
axes = axes.flatten()

data = x_gen.numpy()  # (b 1 h w)
data = np.flip(data, axis=2)  # (b 1 h w), reversed along the h axis
# Drop only the channel axis. A bare `.squeeze()` would also remove the batch
# axis when b == 1 (or h / w when they are 1), breaking the per-sample indexing.
data = data[:, 0]  # (b h w)
for ax, d in zip(axes, data):
    ax.imshow(d)  # d: (h w)
    ax.set_xticks([])
    ax.set_yticks([])
# Hide grid cells that have no sample (sample count is not a perfect square).
for ax in axes[n_samples:]:
    ax.axis('off')
plt.tight_layout()
plt.show()