# Small Notebook for:

### 1. Preprocess images to feed Stable Diffusion
### 2. Evaluate how difficult it is for the model to grasp the features of your images

## Imports

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from src.utils import ImageFolder, InferrenceResult, VAEHandler, denormalize, preprocess_images

## Constants

In [None]:
# Batch size passed to the dataset iterator further down.
BATCH_SIZE = 4

# Root of the working tree; every other path below is derived from it.
ROOT_DIR = Path("/eval_images_train_difficulty")

# Raw input images and the destination for their preprocessed versions.
IMAGE_SOURCE_DIR = ROOT_DIR.joinpath("images")
IMAGE_PREPROCESSED_DIR = ROOT_DIR.joinpath("processed")

# Model directories, all grouped under a shared "models" folder.
MODEL_DIR = ROOT_DIR.joinpath("models")
VAE_DIR = MODEL_DIR.joinpath("waifu-diffusion-v1-4")
FOCAL_MODEL_DIR = MODEL_DIR.joinpath("focal")

## Preprocess Images

In [None]:
# Run the project's preprocessing helper: reads from IMAGE_SOURCE_DIR and
# targets IMAGE_PREPROCESSED_DIR with a 512x512 size request.
# NOTE(review): presumably the focal model drives the crop/resize region --
# confirm against src.utils.preprocess_images.
preprocess_images(
    IMAGE_SOURCE_DIR,
    IMAGE_PREPROCESSED_DIR,
    focal_model_dir=FOCAL_MODEL_DIR,
    width=512,
    height=512,
)

## Load VAE

In [None]:
# Instantiate the project's VAEHandler with the VAE_DIR path.
# NOTE(review): presumably this loads the VAE weights stored under VAE_DIR --
# confirm in src.utils.VAEHandler.
vae_waifu_1_4 = VAEHandler(VAE_DIR)

## Prepare Dataset

In [None]:
# Per-image transform: convert to a tensor, then subtract a mean of 0.5 with
# a std of 1 (i.e. only a shift, no rescaling).
# NOTE(review): with ToTensor producing [0, 1], Normalize(0.5, 1) yields
# [-0.5, 0.5]; Stable Diffusion VAEs are commonly fed [-1, 1]
# (Normalize(0.5, 0.5)). Confirm this matches VAEHandler / denormalize.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(0.5, 1),
    ]
)

# Wrap the preprocessed images and build a fixed-order batch iterator;
# shuffle=False is passed so batches follow the folder's own ordering.
image_folder = ImageFolder(IMAGE_PREPROCESSED_DIR, transform)
dataset = image_folder.make_iterator(batch_size=BATCH_SIZE, shuffle=False)

## Evaluate

In [None]:
# Run the handler's loss computation over the batched dataset. The returned
# object is consumed by the cells below via `res.df` and
# `res.plot_most_and_least_lossy_images`.
res = vae_waifu_1_4.calc_loss(dataset)

## Loss

In [None]:
# Take an independent deep copy of the results table so that anything done to
# `df` in this notebook cannot modify `res.df`.
df = res.df.copy(deep=True)

# Plot the "loss" column against the "idx" column, with rows ordered by the
# frame's index.
(
    df.sort_index()
    .plot(
        x="idx",
        y="loss",
        xlabel="image_idx",
        ylabel="loss",
        figsize=(7, 5),
    )
)
plt.show()

## Visualize Results

In [None]:
# Ask the result object to plot images at both loss extremes, with n=5.
# NOTE(review): presumably n is the count shown per extreme (5 highest-loss
# and 5 lowest-loss) -- confirm in src.utils.
res.plot_most_and_least_lossy_images(n=5)