# Small Notebook for:

### 1. Preprocess images to feed Stable Diffusion
### 2. Evaluate how well the model can grasp the features of your images

### Imports

In [None]:
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms

from src.utils import ImageFolder, VAEHandler, denormalize, preprocess_images

### Constants

In [None]:
# --- Evaluation settings -----------------------------------------------------
BATCH_SIZE = 4  # images per batch yielded by the evaluation iterator

# Every preprocessed image is produced at IMAGE_WIDTH x IMAGE_HEIGHT pixels.
IMAGE_WIDTH, IMAGE_HEIGHT = 512, 512

# --- Filesystem layout -------------------------------------------------------
# All paths below hang off the project root.
ROOT_DIR = Path("/evaluate-images-to-feed-diffusion")

# Model weights
MODEL_DIR = ROOT_DIR.joinpath("models")
VAE_DIR = MODEL_DIR.joinpath("waifu-diffusion-v1-4")  # VAE under evaluation
FOCAL_MODEL_DIR = MODEL_DIR.joinpath("focal")  # focal model, used to crop images nicely

# Image data
IMAGE_SOURCE_DIR = ROOT_DIR.joinpath("images")  # raw images
IMAGE_PREPROCESSED_DIR = ROOT_DIR.joinpath("processed")  # processed images

### Preprocess Images

Crop and convert the images into a form suitable for feeding the model.

If `focal_model_dir` is set to a directory (i.e. not left as `None`), the focal model is downloaded automatically.

The images are then cropped taking into account where the face / focal point is.

In [None]:
# Crop and convert every raw image in IMAGE_SOURCE_DIR to
# IMAGE_WIDTH x IMAGE_HEIGHT, writing the output to IMAGE_PREPROCESSED_DIR.
# A non-None focal_model_dir makes the focal model download automatically so
# crops can follow the face / focal point.
preprocess_images(
    IMAGE_SOURCE_DIR,  # input: raw images
    IMAGE_PREPROCESSED_DIR,  # output: processed images
    focal_model_dir=FOCAL_MODEL_DIR,
    width=IMAGE_WIDTH,
    height=IMAGE_HEIGHT,
)

### Load VAE

In [None]:
# Load the VAE stored under VAE_DIR (models/waifu-diffusion-v1-4) into a
# VAEHandler, which is used below to compute reconstruction losses.
vae_waifu_1_4 = VAEHandler(VAE_DIR)

### Prepare Evaluation

In [None]:
# Image -> tensor pipeline applied to each preprocessed image.
transform = transforms.Compose(
    [
        # Convert the image to a float tensor scaled to [0, 1].
        transforms.ToTensor(),
        # (x - 0.5) / 1, i.e. shifts values into [-0.5, 0.5].
        # NOTE(review): SD-style VAEs are usually fed [-1, 1], i.e.
        # Normalize(0.5, 0.5). Confirm which range `denormalize` in src.utils
        # inverts before changing this, since both must agree.
        transforms.Normalize(0.5, 1),
    ]
)

# Iterate the processed images in a fixed order (no shuffling), so loss
# values can be matched back to specific images.
dataset = ImageFolder(IMAGE_PREPROCESSED_DIR, transform).make_iterator(
    batch_size=BATCH_SIZE,
    shuffle=False,
)

### Evaluate

The return value `res` contains:

- Normalized tensor of original images
- Latent `z`
- Reconstructed tensors from `z`
- The loss value of each image

In [None]:
# Run the VAE over every batch in `dataset`. `res` carries the normalized
# inputs, the latents `z`, their reconstructions and per-image loss values.
res = vae_waifu_1_4.get_loss_results(dataset)

### Loss

If some images have notably high loss values, the model may not be able to learn their features well.

In [None]:
# Plot loss against image index; work on a copy so res.df stays untouched.
df = res.df.copy(deep=True)
(
    df.sort_index()
    .plot(x="idx", y="loss", xlabel="image_idx", ylabel="loss", figsize=(7, 5))
)
plt.show()

### Visualize Results

#### Worst and Best

In [None]:
# Plot the most and least lossy images. n=5 is presumably the number of
# images taken from each end of the loss ranking; confirm in src.utils.
res.plot_most_and_least_lossy_images(n=5)

#### All

In [None]:
# Show every reconstruction with its loss in the title.
# NOTE(review): df index labels are used to index res.rec / res.loss directly,
# which assumes they line up positionally; confirm in src.utils.
for idx in res.df.index:
    # Wrap the single reconstruction in a length-1 batch for `denormalize`,
    # then unwrap the one resulting image.
    image = denormalize(np.array([res.rec[idx]]))[0]
    plt.imshow(image)
    plt.title(f"loss: {res.loss[idx]}")
    plt.show()