In [1]:
import logging
import math
import numpy as np
import os
import random
import sys
import time
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from packaging import version
from pathlib import Path
from PIL import Image
from pprint import pprint
from threading import Lock
from tqdm.auto import tqdm
from typing import Optional, Union

import datasets
from datasets import load_dataset

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torchvision import transforms

import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, GradScalerKwargs, set_seed

from huggingface_hub import HfFolder, Repository, create_repo, whoami

import transformers
import diffusers
from diffusers import AutoencoderKL
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers import Mel

# Notebook-wide logger created through accelerate's get_logger, set to INFO.
logger = get_logger(name=__name__, log_level="INFO")

In [2]:
from train_vae import *

In [3]:
# Settings overridden for this run; every other field keeps whatever
# VAETrainConfig provides (presumably its defaults — see train_vae).
config_overrides = {
    "mixed_precision": 'no',
    "discriminator_start": 1000,
    "num_train_epochs": 5,
    "checkpoints_total_limit": 5,
    "checkpointing_steps": 1000,
    "resume_from_checkpoint": None,
}
args = VAETrainConfig(**config_overrides)
# Bare expression so the notebook displays the resulting configuration.
args

# Debug

# Train

In [None]:
# Run training directly in the current kernel with the `args` built above.
# NOTE(review): the following cell launches the same `train` function through
# accelerate's notebook_launcher; executing both cells repeats the call —
# presumably only one of the two is meant to be run per session.
train(args)

In [None]:
# Start `train(args)` via accelerate's notebook launcher with a single process.
accelerate.notebook_launcher(
    function=train,
    args=[args],
    num_processes=1,
)

# VAE Comparison

In [None]:
# Reference VAE loaded from args.pretrained_model_name_or_path.
vae0 = AutoencoderKL.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
).to('cuda')

# Loading of the model saved at args.output_dir is disabled. It used to be
# listed in `all_vaes` as `vae1`, which raised NameError on a fresh kernel;
# if this line is re-enabled, add `vae1` to `all_vaes` below as well.
# vae1 = AutoencoderKL.from_pretrained(args.output_dir , subfolder="", revision=args.revision).to('cuda')

# One VAE per checkpoint directory under <output_dir>/checkpoints.
# os.listdir returns entries in arbitrary order, so sort them to keep the
# comparison order reproducible. Sorting by (length, name) puts un-padded
# numeric suffixes in numeric order (e.g. "...-2000" before "...-10000");
# NOTE(review): the actual checkpoint naming scheme is defined in train_vae.
ckpt_root = os.path.join(args.output_dir, "checkpoints")
ckpt_names = sorted(
    (
        name
        for name in os.listdir(ckpt_root)
        # Skip stray files; only directories can hold a saved checkpoint.
        if os.path.isdir(os.path.join(ckpt_root, name))
    ),
    key=lambda name: (len(name), name),
)
vae_ckpts = [
    AutoencoderKL.from_pretrained(
        os.path.join(ckpt_root, name), subfolder="vae", revision=args.revision
    ).to('cuda')
    for name in ckpt_names
]

# Every model taking part in the comparison, reference model first.
all_vaes = [vae0, *vae_ckpts]
for vae in all_vaes:
    # Inference only: no gradients, evaluation mode.
    vae.requires_grad_(False)
    vae.eval()
print(f"Model size {get_model_size(vae0):.3f} MB (2x)")

# Dataset and dataloader come from the same helper and config as training.
accelerator = Accelerator(mixed_precision=args.mixed_precision,)
dataset, dataloader = get_dataloader(accelerator, args)
print(f"{len(dataloader)} images available")

## STFT Images

In [None]:
# Index of the dataset sample used for the comparison.
idx = 0
# Encode the sample once with the reference VAE (vae0); encode_sample hands
# back a pair that is unpacked into the source image and its latents.
image, latents = encode_sample(vae0, dataset[idx])
# Decode those same latents with every VAE, then convert each result with
# numpy_to_pil for display.
recons = list(map(lambda model: decode_latents(model, latents), all_vaes))
imgs = list(map(numpy_to_pil, recons))

In [None]:
# Source image first, followed by one reconstruction per VAE.
# NOTE(review): the meaning of the second argument (2) is defined by
# image_grid in train_vae — confirm whether it is a row or column count.
grid = image_grid([image] + imgs, 2)
grid

## Audio Reconstruction

In [None]:
import IPython.display as ipd

In [None]:
# Sample rate (also used for playback) and the spectrogram image side length.
fs = 22050
resolution = 512
# Mel converter configured with a square resolution x resolution image and a
# hop length equal to that resolution.
# NOTE(review): n_iter is presumably the iteration count of the phase
# reconstruction used when inverting the spectrogram — confirm in Mel's docs.
mel = Mel(
    x_res=resolution,
    y_res=resolution,
    sample_rate=fs,
    n_fft=2048,
    hop_length=resolution,
    top_db=80,
    n_iter=32,
)
# Source image followed by every reconstruction, each converted to
# single-channel ('L') before being turned into audio.
images = [image] + imgs
audios = [mel.image_to_audio(picture.convert('L')) for picture in images]

In [None]:
# Render one audio player per clip, in the same order as `audios`.
for aud in audios:
    player = ipd.Audio(aud, rate=fs)
    ipd.display(player)