In [1]:
import logging
import math
import numpy as np
import os
import random
import sys
import time
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from packaging import version
from pathlib import Path
from PIL import Image
from pprint import pprint
from threading import Lock
from tqdm.auto import tqdm
from typing import Optional, Union

import datasets
from datasets import load_dataset

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torchvision import transforms

import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, GradScalerKwargs, set_seed

from huggingface_hub import HfFolder, Repository, create_repo, whoami

import transformers
import diffusers
from diffusers import AutoencoderKL
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers import Mel

# Logger obtained from accelerate.logging.get_logger at INFO level.
logger = get_logger(__name__, log_level="INFO")

# IPython autoreload, mode 2: re-import modules before each cell executes, so
# edits to local modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from train_vae import *

In [3]:
# Run configuration: only the fields listed here are overridden; everything
# else keeps whatever default VAETrainConfig declares.
args = VAETrainConfig(
    # Where the trained model and its checkpoints are written / read back from.
    output_dir="vae-stft-fma-2",
    resume_from_checkpoint=None,
    # Training schedule.
    num_train_epochs=5,
    mixed_precision='no',
    discriminator_start=1000,  # presumably the step at which the discriminator term is enabled -- confirm in train_vae
    # Checkpoint cadence and retention.
    checkpointing_steps=1000,
    checkpoints_total_limit=5,
)
args

# Debug

# Train

In [4]:
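# Training launch is intentionally commented out; uncomment the next line to (re)train
# before running the comparison cells below, which load models from args.output_dir.
# accelerate.notebook_launcher(train, [args], num_processes=1)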

# VAE Comparison

In [5]:
# vae0: VAE from the pretrained model ("vae" subfolder).
# vae1: final VAE saved at the root of args.output_dir.
vae0 = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).to('cuda')
vae1 = AutoencoderKL.from_pretrained(args.output_dir, subfolder="", revision=args.revision).to('cuda')


def _checkpoint_step(name):
    """Sort key for a checkpoint folder: (trailing integer or -1 if none, name)."""
    digits = name[len(name.rstrip("0123456789")):]
    return (int(digits) if digits else -1, name)


# os.listdir returns entries in arbitrary order. Sort by step number so that
# the position of a checkpoint in all_vaes (and therefore on every plot below)
# follows training progress and is stable across runs.
ckpt_root = os.path.join(args.output_dir, "checkpoints")
ckpt_names = sorted(os.listdir(ckpt_root), key=_checkpoint_step)
vae_ckpts = [
    AutoencoderKL.from_pretrained(os.path.join(ckpt_root, sf), subfolder="vae", revision=args.revision).to('cuda')
    for sf in ckpt_names
]

# Index convention used everywhere below: 0 = pretrained, 1 = final, 2+ = checkpoints.
all_vaes = [vae0, vae1] + vae_ckpts
for vae in all_vaes:
    # Inference only: no gradients, eval mode.
    vae.requires_grad_(False)
    vae.eval()
print(f"Model size {get_model_size(vae0):.3f} MB (2x)")

accelerator = Accelerator(mixed_precision=args.mixed_precision,)
dataset, dataloader = get_dataloader(accelerator, args)
# NOTE(review): len(dataloader) counts batches; this equals the number of
# images only if the configured batch size is 1 -- confirm.
print(f"{len(dataloader)} images available")

## STFT Images

In [6]:
# Encode one dataset sample with the pretrained VAE, then decode the same
# latents with every VAE so the reconstructions are directly comparable.
idx = 0
image, latents = encode_sample(vae0, dataset[idx])

recons = []
for vae in all_vaes:
    recons.append(decode_latents(vae, latents))

# numpy_to_pil returns a sequence; keep its first element for each reconstruction.
imgs = [pil_batch[0] for pil_batch in map(numpy_to_pil, recons)]

In [18]:
# Original image first, followed by one reconstruction per entry of all_vaes.
tiles = [image, *imgs]
grid = image_grid(tiles, 4, 4)  # (4, 4) is presumably rows x cols -- confirm against image_grid
grid

## Audio Reconstruction

In [29]:
import IPython.display as ipd
import matplotlib.pyplot as plt
import copy

In [10]:
# Image <-> audio converter; x/y resolution and hop length all use `resolution`.
resolution = 512
fs = 22050
mel = Mel(
    x_res=resolution,
    y_res=resolution,
    sample_rate=fs,
    n_fft=2048,
    hop_length=resolution,
    top_db=80,
    n_iter=32,
)

# audios[0] comes from the original image; audios[1:] follow all_vaes order.
images = [image, *imgs]
audios = []
for im in images:
    grayscale = im.convert('L')
    audios.append(mel.image_to_audio(grayscale))

In [11]:
# One audio player per clip, in the same order as `audios`.
players = [ipd.Audio(aud, rate=fs) for aud in audios]
ipd.display(*players)

In [23]:
def audio_mse(audios):
    """Mean squared error of each signal against the first one.

    The previous implementation returned ``norm(a0 - ai) ** 2``, which is the
    *sum* of squared errors and therefore grows with clip length; the mean is
    what the function name and the "MSE" plot labels promise.

    Args:
        audios: Sequence (list, tuple, ...) of equal-length numeric arrays.
            ``audios[0]`` is the reference signal.

    Returns:
        1-D float64 ``np.ndarray`` of length ``len(audios) - 1`` where entry
        ``i`` is ``mean((audios[0] - audios[i + 1]) ** 2)``.

    Raises:
        IndexError: if ``audios`` is empty.
    """
    # Work in float64 so integer-typed inputs cannot overflow when squared.
    reference = np.asarray(audios[0], dtype=np.float64)
    mses = np.zeros(len(audios) - 1)
    for i, candidate in enumerate(audios[1:]):
        diff = reference - np.asarray(candidate, dtype=np.float64)
        mses[i] = np.mean(diff ** 2)
    return mses

# audios[0] is the audio from the original image, so mses[i] is the error of
# the reconstruction produced by all_vaes[i].
mses = audio_mse(audios)

In [24]:
# One point per VAE, in all_vaes order; labelled so the figure stands alone.
fig, ax = plt.subplots()
ax.plot(mses, marker="o")
ax.set(
    title="Audio reconstruction error per VAE (single sample)",
    xlabel="VAE index [0=default, 1=final, 2+=checkpoint]",
    ylabel="Reconstruction error (audio_mse)",
);

In [43]:
# Shallow copy: setting an attribute on `cargs` rebinds it on the copy only,
# so `args` keeps its original train_batch_size.
cargs = copy.copy(args)
cargs.train_batch_size = 2
# NOTE: this rebinds `dataset` and `dataloader` from the earlier cell; every
# cell below iterates the batch-size-2 loader.
dataset, dataloader = get_dataloader(accelerator, cargs)


In [53]:
# Encode up to 100 batches with the pretrained VAE and decode each batch of
# latents with every VAE. `recons` is flat and batch-major: for each batch it
# holds one entry per VAE, in all_vaes order.
images, recons = [], []
for i, batch in enumerate(dataloader):
    if i >= 100:
        break
    imgs, latents = encode_sample(vae0, batch)
    images.append(imgs)
    recons.extend(decode_latents(vae, latents) for vae in all_vaes)
    print("Finished", i)

len(recons)

In [54]:
# Regroup the flat, batch-major `recons` list into one list per VAE.
# The stride must be the number of VAEs; the previous hard-coded 6 silently
# mixed reconstructions of different VAEs (or raised IndexError) whenever
# len(all_vaes) != 6.
nvaes = len(all_vaes)
assert len(recons) % nvaes == 0, "recons must hold exactly one entry per VAE for every batch"
nrecons = len(recons) // nvaes
# recons[j::nvaes] == [recons[nvaes * i + j] for i in range(nrecons)]
vae_recons = [recons[j::nvaes] for j in range(nvaes)]
print(len(vae_recons))
print(len(vae_recons[0]))

In [57]:
# Flatten each VAE's per-batch reconstructions into one list of PIL images.
vae_images = [
    [pil for recon in vae_recons[ivae] for pil in numpy_to_pil(recon)]
    for ivae in range(nvaes)
]

# Flatten the per-batch originals the same way so indices line up with vae_images[k].
pil_images = [img for batch_imgs in images for img in batch_imgs]

print(len(vae_images))
print(len(vae_images[0]))
print(len(pil_images))

In [60]:
# Same converter settings as the single-sample cell above.
resolution = 512
fs = 22050
mel = Mel(
    x_res=resolution,
    y_res=resolution,
    sample_rate=fs,
    n_fft=2048,
    hop_length=resolution,
    top_db=80,
    n_iter=32,
)


def _to_audio(pil_img):
    """Convert one PIL image to audio via its grayscale ('L') version."""
    return mel.image_to_audio(pil_img.convert('L'))


# vae_audios[k][n] is the audio of image n as reconstructed by all_vaes[k].
vae_audios = []
for ivae in range(nvaes):
    vae_audios.append([_to_audio(im) for im in vae_images[ivae]])
    print("Finished", ivae)

# Audio recovered from the original (pre-VAE) images, used as the reference.
orig_audios = [_to_audio(pim) for pim in pil_images]

print(len(vae_audios))
print(len(vae_audios[0]))
print(len(orig_audios))

In [71]:
# Each zip item is (original audio, audio from VAE 0, audio from VAE 1, ...),
# so every row of `mses` holds one error per VAE for a single image.
mses = np.array([audio_mse(zaudio) for zaudio in zip(orig_audios, *vae_audios)])

# x positions are derived from the data instead of a hard-coded 6, so the plot
# works for any number of VAEs / checkpoints.
vae_index = np.arange(mses.shape[1])
plt.errorbar(vae_index, np.mean(mses, axis=0), yerr=np.std(mses, axis=0), linestyle='none')
plt.title("VAE vs Audio Reconstruction MSE")
plt.xlabel("VAE index [0=default, 1=final, 2+=checkpoint]")
plt.ylabel("MSE");