# **Training on the LSUN (Large-scale Scene Understanding) Bedrooms Dataset**

In this notebook, we take on the more challenging task of generating larger images. We use the [LSUN Bedrooms Dataset](https://paperswithcode.com/sota/image-generation-on-lsun-bedroom-256-x-256), a subset of the original LSUN dataset, which comprises 10 scene categories. The original dataset is described as follows:

> The Large-scale Scene Understanding (LSUN) challenge aims to provide a different benchmark for large-scale scene classification and understanding. The LSUN classification dataset contains 10 scene categories, such as dining room, bedroom, kitchen, outdoor church, and so on. For training data, each category contains a huge number of images, ranging from around 120,000 to 3,000,000. The validation data includes 300 images, and the test data has 1000 images for each category.

In [None]:
import os
from miniai.imports import *
from miniai.diffusion import *
from diffusers import UNet2DModel, UNet2DConditionModel, AutoencoderKL
from fastprogress import progress_bar
from glob import glob
from copy import deepcopy
import timm
import shutil
import warnings
from torchvision.io import read_image, ImageReadMode

In [None]:
torch.set_printoptions(precision=4, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray_r'
mpl.rcParams['figure.dpi'] = 70

set_seed(42)
if fc.defaults.cpus>8: fc.defaults.cpus=8

warnings.simplefilter('ignore', UserWarning)

## **Download and Process Data**

In [None]:
path_data = Path('data')
path_data.mkdir(exist_ok=True)
path = path_data/'bedroom'

Because downloading the full dataset is unreliable, Jeremy placed a 20% subset of the data on AWS. The original data is stored in an LMDB format, which Jeremy also converted to individual image files for us.

`NOTE` - _If the download takes a long time in Python, download the tarball from a shell instead._

In [None]:
url = 'https://s3.amazonaws.com/fast-ai-imageclas/bedroom.tgz' # Download tarball
if not path.exists():
    path_zip = fc.urlsave(url, path_data)
    shutil.unpack_archive('data/bedroom.tgz', 'data')

In [None]:
bs = 64

In [None]:
# read_image decodes an image file directly into a tensor.
# ImageReadMode.RGB converts every image to 3-channel RGB; dividing by 255 scales pixels to [0, 1].
def to_img(f): return read_image(f, mode=ImageReadMode.RGB)/255

Let's work on converting the images to latents, starting with a dataset that loads the image files.

In [None]:
class ImagesDS:
    def __init__(self, spec):
        self.path = Path(path)
        self.files = glob(str(spec), recursive=True) #search for filetypes

    def __len__(self): return len(self.files)

    def __getitem__(self, i): return to_img(self.files[i])[:, :256, :256] # crop last dims of images to reduce some compute and align sizes.
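
The slice in `__getitem__` keeps the top-left 256×256 region of a channels-first image. The sketch below uses NumPy arrays as stand-ins for decoded images (the sizes are invented examples, not measured LSUN sizes) to show the resulting shapes. It also shows that slicing never pads, so this crop assumes every image is at least 256 pixels on each side.

```python
import numpy as np

# Stand-in for a decoded RGB image in channels-first layout (C, H, W).
# The 256x341 size is a made-up example.
img = np.zeros((3, 256, 341), dtype=np.float32)

# Same slice as in ImagesDS.__getitem__: keep the top-left 256x256 region.
crop = img[:, :256, :256]
print(crop.shape)  # (3, 256, 256)

# Slicing never pads, so an image with a short side stays short.
small = np.zeros((3, 200, 300), dtype=np.float32)
small_crop = small[:, :256, :256]
print(small_crop.shape)  # (3, 200, 256)
```

A batch can only be stacked when every item has the same shape, so an undersized image like `small` would make the default collation fail.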

In [None]:
ds = ImagesDS(path/'**/*.jpg') # Search recursively for all JPEG files in the bedroom folder.

In [None]:
dl = DataLoader(ds, batch_size=bs, num_workers=fc.defaults.cpus) # Load batches with max cpus in parallel.
xb = next(iter(dl))
show_images(xb[:16], imsize=2)

In [None]:
xb[:16].shape

In [None]:
# Total number of floats.
16*3*256*256

## **Using a Pre-Trained VAE**

In [None]:
# Grab the pretrained VAE.
# Turn off gradient computation in place.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").cuda().requires_grad_(False)

In [None]:
xe = vae.encode(xb.cuda()) # encode

In [None]:
xs = xe.latent_dist.mean[:16]
xs.shape

In [None]:
# Comparison of original vs compressed / encoded image sizes.
(16*3*256*256) / (16*4*32*32) 
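
The ratio above can be broken down into its spatial and channel parts. This is a small arithmetic sketch using only the shapes that appear in this notebook (3×256×256 images, 4×32×32 latents).

```python
# Floats in a batch of 16 RGB images at 256x256.
pixel_floats = 16 * 3 * 256 * 256
# Floats in the matching batch of latents: 4 channels at 32x32.
latent_floats = 16 * 4 * 32 * 32

ratio = pixel_floats / latent_floats
print(pixel_floats, latent_floats, ratio)  # 3145728 65536 48.0

# Each side shrinks by 8x (256 / 32), while channels grow from 3 to 4.
spatial_factor = 256 // 32
assert ratio == spatial_factor**2 * 3 / 4
```

The latents therefore hold 48 times fewer values than the images they encode.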

In [None]:
# Show the first three of the four latent channels as an RGB image.
# Dividing by 4 and applying a sigmoid squashes the values into (0, 1) for display.
show_images(((xs[:16, :3]) / 4).sigmoid(), imsize=2)

In [None]:
xd = to_cpu(vae.decode(xs)) # Decode tensor
show_images(xd['sample'].clamp(0, 1), imsize=2)

We reconstruct the images as a quality check before passing the latents to the diffusion model. The [`sd-vae-ft-ema`](https://huggingface.co/stabilityai/sd-vae-ft-ema) VAE has a known limitation in reconstructing writing / text in images. The creators mention:

> We publish two kl-f8 autoencoder versions, finetuned from the original kl-f8 autoencoder on a 1:1 ratio of LAION-Aesthetics and LAION-Humans, an unreleased subset containing only SFW images of humans. The intent was to fine-tune on the Stable Diffusion training set (the autoencoder was originally trained on OpenImages) but also enrich the dataset with images of humans to improve the reconstruction of faces. The first, ft-EMA, was resumed from the original checkpoint, trained for 313198 steps and uses EMA weights.
>
> It uses the same loss configuration as the original checkpoint (L1 + LPIPS). The second, ft-MSE, was resumed from ft-EMA and uses EMA weights and was trained for another 280k steps using a different loss, with more emphasis on MSE reconstruction (MSE + 0.1 * LPIPS). It produces somewhat "smoother" outputs. The batch size for both versions was 192 (16 A100s, batch size 12 per GPU). To keep compatibility with existing models, only the decoder part was finetuned; the checkpoints can be used as a drop-in replacement for the existing autoencoder.

According to Perplexity, these are some additional details:

>   **Training Data**: It was fine-tuned on a combination of the LAION-Aesthetics and LAION-Humans datasets to enhance the reconstruction of faces and human subjects.
>
>   **Loss Configuration**: The model uses the same loss configuration as the original kl-f8 autoencoder, which includes L1 loss and LPIPS (Learned Perceptual Image Patch Similarity).
>
>   **Exponential Moving Average (EMA) Weights**: The ft-EMA version utilizes EMA weights, which help stabilize the training process and improve model performance.
>
>   **Training Steps**: The model was trained for 313,198 steps.
>
>   **Performance**: Compared to the original kl-f8 VAE, the ft-EMA model shows slightly improved performance, with a lower rFID score of 4.42 versus 4.99 for the original.
>
>   **Applications**: It can be used as a drop-in replacement for the original autoencoder in the Stable Diffusion pipeline, potentially leading to improved downstream generation results. Additionally, it is suitable for tasks like image compression and editing.
>
>   **Variants**: There is another variant, sd-vae-ft-mse, which emphasizes MSE reconstruction and produces smoother outputs.


To learn more about Learned Perceptual Image Patch Similarity (LPIPS), read the [paper and visit the associated GitHub page](https://richzhang.github.io/PerceptualSimilarity/).

In [None]:
# We will store the latents in a memory-mapped NumPy file (NPMM), so they can be indexed like an array without loading the whole file into RAM.
mmpath = Path('data/bedroom/data.npmm')

In [None]:
len(ds)

In [None]:
mmshape = (303125, 4, 32, 32)

In [None]:
if not mmpath.exists(): # Create the NPMM file on disk. Its shape matches the encoded latents, not the images.
    a = np.memmap(mmpath, np.float32, mode='w+', shape=mmshape)
    i = 0
    for b in progress_bar(dl): # Grab a mini batch
        n = len(b)
        # Encode and get the means of the latents and convert to numpy since pytorch doesn't have a memory mapping tool (as of 2023)
        a[i : i+n] = to_cpu(vae.encode(b.cuda()).latent_dist.mean).numpy()
        i += n
    a.flush() # Ensure that the contents of the cache are written to disk.
    del a # Release the writable mapping.

In [None]:
lats = np.memmap(mmpath, dtype=np.float32, mode='r', shape=mmshape) # Reopen the file read-only as a memory-mapped array.
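
The write-then-reopen pattern used above can be tried on a small scale. This is a minimal sketch, assuming random numbers as stand-ins for the encoded latents and a hypothetical file name (`toy.npmm`) in a temporary directory; only the `np.memmap` calls mirror the notebook.

```python
import tempfile
from pathlib import Path

import numpy as np

shape = (10, 4, 32, 32)  # toy size; the notebook uses (303125, 4, 32, 32)
rng = np.random.default_rng(0)
data = rng.standard_normal(shape).astype(np.float32)  # stand-in for encoded latents

with tempfile.TemporaryDirectory() as tmp:
    fn = Path(tmp) / 'toy.npmm'  # hypothetical file name

    # mode='w+' creates the file on disk at its full size.
    out = np.memmap(fn, dtype=np.float32, mode='w+', shape=shape)
    for i in range(0, shape[0], 4):  # write in batches of 4
        out[i:i + 4] = data[i:i + 4]
    out.flush()  # push pending writes to disk
    del out      # drop the writable mapping

    # The raw file stores no header, so dtype and shape must be given again.
    lat = np.memmap(fn, dtype=np.float32, mode='r', shape=shape)
    roundtrip_ok = bool(np.array_equal(lat, data))
    file_bytes = fn.stat().st_size
    del lat      # release the mapping before the directory is removed

print(roundtrip_ok, file_bytes)
```

Because the file is just raw `float32` values, its size is the product of the shape times 4 bytes. At the notebook's shape that is 303125 × 4 × 32 × 32 × 4 bytes, roughly 4.97 GB on disk.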

In [None]:
b = torch.tensor(lats[:16]) # Sanity check: load the first 16 latents from the file.

In [None]:
xd = to_cpu(vae.decode(b.cuda()))
show_images(xd['sample'].clamp(0,1), imsize=2)

## **Noisify**

Because a memory-mapped array behaves like a regular NumPy array, we can now apply pipeline operations such as `noisify` directly to the stored latents, which is pretty cool! Foundational techniques like this deserve to be used more often.

In [None]:
def collate_ddpm(b): return noisify(default_collate(b)*0.2) # Scale latents by 0.2 so they have roughly unit standard deviation before adding noise.
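
The effect of the 0.2 factor can be checked on synthetic numbers. This is only a sketch: the values below are random draws with an assumed standard deviation of 5 (chosen so that 1/5 = 0.2), not real VAE latents.

```python
import random
import statistics

random.seed(0)
# Synthetic stand-in for latent values; a std of 5 is assumed for illustration.
latents = [random.gauss(0.0, 5.0) for _ in range(20_000)]

raw_std = statistics.pstdev(latents)
scale = 0.2  # factor used in collate_ddpm
scaled_std = statistics.pstdev([x * scale for x in latents])

print(round(raw_std, 2), round(scaled_std, 2))
```

On real latents, the same check would be to compute the standard deviation of a sample of latents and compare it with 1 / 0.2 = 5. Multiplying by 5 later undoes the scaling before decoding.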

In [None]:
n = len(lats)

In [None]:
# Create training and validation sets.
tds = lats[:n // 10*9] # First ~90%
vds = lats[n // 10*9:] # Remaining ~10%
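
The split index relies on operator precedence: `//` and `*` have equal precedence and evaluate left to right, so `n // 10*9` means `(n // 10) * 9`. The sketch below works through the numbers for the dataset size given in `mmshape`.

```python
n = 303125  # number of latents, from mmshape

cut = n // 10 * 9  # evaluates left to right: (n // 10) * 9
n_train, n_valid = cut, n - cut
print(cut, n_train, n_valid)  # 272808 272808 30317
```

Note that the split is positional rather than random: the validation set is simply the last 30,317 latents in file order.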

In [None]:
bs = 128

In [None]:
dls = DataLoaders(*get_dls(tds, vds, bs=bs, num_workers=fc.defaults.cpus, collate_fn=collate_ddpm))

In [None]:
(xt, t), eps = b = next(iter(dls.train))

In [None]:
show_images(xt[:9, 0], imsize=1.5)

In [None]:
xte = vae.decode(xt[:9].cuda()*5)['sample'] # Multiply by 5 to undo the 0.2 scaling before decoding.
show_images(xte.clamp(0,1), imsize=1.5)

## **Train Latent Diffusion Model**

In [None]:
def init_ddpm(model):
    for o in model.downs:
        for p in o.resnets: p.conv2[-1].weight.data.zero_()

    for o in model.ups:
        for p in o.resnets: p.conv2[-1].weight.data.zero_()
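
`init_ddpm` zeroes the weights of the last convolution in each ResNet block. One common motivation for this kind of initialization is that the residual branch then contributes nothing at the start, so each block begins as a pass-through of its input. The sketch below illustrates the idea with a toy NumPy block (dense layers, no bias, no normalization), which is a simplification and not the `EmbUNetModel` block itself.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))   # toy batch: 2 samples, 8 features

w1 = rng.standard_normal((8, 8))  # first layer of the toy block
w2 = rng.standard_normal((8, 8))  # last layer of the toy block

def res_block(x, w1, w2):
    # Toy residual block: skip path plus a two-layer branch with ReLU.
    h = np.maximum(x @ w1, 0)
    return x + h @ w2

# With random weights, the block changes its input.
changed = not np.allclose(res_block(x, w1, w2), x)

# Zeroing the last layer (mirrors weight.data.zero_()) leaves only the skip path.
w2_zero = np.zeros_like(w2)
identity_at_init = np.allclose(res_block(x, w1, w2_zero), x)
print(changed, identity_at_init)  # True True
```

In the real model only the weights are zeroed, so any bias in that last convolution still adds a per-channel constant to the output.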

In [None]:
lr = 3e-3
epochs = 25
opt_func = partial(optim.AdamW, eps=1e-5)
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
cbs = [DeviceCB(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched), MixedPrecision()]
model = EmbUNetModel(in_channels=4, out_channels=4, nfs=(128, 256, 512, 768), num_layers=2,
                     attn_start=1, attn_chans=16)
init_ddpm(model)
learn = Learner(model, dls, nn.MSELoss(), lr=lr, cbs=cbs, opt_func=opt_func)
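
To get a feel for what `OneCycleLR` does with `max_lr=3e-3`, the schedule can be sketched in plain Python. This is a re-implementation for illustration, not PyTorch's code. It assumes PyTorch's documented defaults (`pct_start=0.3`, `div_factor=25`, `final_div_factor=1e4`, cosine annealing, two phases) and uses a toy `total_steps` rather than the notebook's `epochs * len(dls.train)`.

```python
import math

def one_cycle_lr(step, total_steps, max_lr,
                 pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """Learning rate at `step` for a two-phase cosine one-cycle schedule."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warm_end = pct_start * total_steps - 1  # last step of the warm-up phase

    def cos_anneal(start, end, pct):
        # Cosine interpolation from start (pct=0) to end (pct=1).
        return end + (start - end) / 2 * (1 + math.cos(math.pi * pct))

    if step <= warm_end:
        return cos_anneal(initial_lr, max_lr, step / warm_end)
    pct = (step - warm_end) / (total_steps - 1 - warm_end)
    return cos_anneal(max_lr, min_lr, pct)

total_steps = 1000  # toy value for illustration
lrs = [one_cycle_lr(s, total_steps, max_lr=3e-3) for s in range(total_steps)]
print(lrs[0], max(lrs), lrs[-1])
```

Under these assumptions, the learning rate starts at `max_lr / 25`, peaks at `max_lr` about 30% of the way through training, and then decays to a value far below the starting rate.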

In [None]:
learn.fit(epochs) # The loss will be higher than for pixel-space diffusion, since predicting noise in the latents is a much more difficult task.

## **Sampling**

In [None]:
sz = (16,4,32,32)

In [None]:
preds = sample(ddim_step, model, sz, steps=100, eta=1., clamp=False)

In [None]:
s = preds[-1]*5 # Undo the 0.2 scaling applied in collate_ddpm.

In [None]:
# Decode, since the model generates latents rather than pixels.
with torch.no_grad(): pd = to_cpu(vae.decode(s.cuda()))

In [None]:
show_images(pd['sample'][:9].clamp(0,1), imsize=5)