In [None]:
import torch
from PIL import Image
import os
from diffusers import AutoencoderKL
from torchvision.transforms import ToTensor, Resize, Normalize

# Load pre-trained VAE model
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.eval()  # Set the model to evaluation mode

# Function to preprocess an image to 64x64 and convert to tensor
def preprocess_image(image_path):
    image = Image.open(image_path).convert("RGB")
    transform = Resize((64, 64))  # Resize to 64x64 if necessary
    image = transform(image)
    image_tensor = ToTensor()(image).unsqueeze(0)
    image_tensor = Normalize(mean=[0.5], std=[0.5])(image_tensor)  # Add batch dimension
    return image_tensor  # Scale to [-1, 1] for the VAE model

# Load and preprocess a 64x64 image
# image_path = "/Users/karim/Downloads/food_data/train/apple_pie/157083.jpg"
rec_images = []
real_images = []
for cls in os.listdir("/Users/karim/Downloads/food_data/test"):
    if os.path.isdir("/Users/karim/Downloads/food_data/test/" + cls):
        for image_path in os.listdir("/Users/karim/Downloads/food_data/test/" + cls):
            image_tensor = preprocess_image("/Users/karim/Downloads/food_data/test/" + cls + '/' + image_path)

        # Encode the image into the latent space
            with torch.no_grad():
                latent_representation = vae.encode(image_tensor).latent_dist.sample()

            # Decode the latent representation back to the image space
            with torch.no_grad():
                reconstructed_image = vae.decode(latent_representation).sample

            # Post-process to display the reconstructed image
            reconstructed_image = (reconstructed_image / 2 + 0.5).clamp(0, 1)  # Scale back to [0, 1]
            rec_images.append(reconstructed_image)
            real_images.append(image_tensor)
# reconstructed_image_pil = ToPILImage()(reconstructed_image.squeeze())

# # Display or save the result
# reconstructed_image_pil.show()  # Display
# reconstructed_image_pil.save("reconstructed_image.png")  # Save if desired


In [97]:
len(rec_images)

2020

In [104]:
rec_ = torch.concat(rec_images, dim=0)
real_ = torch.concat(real_images, dim=0)

In [105]:
from src.metrics import FIDMetric

metric = FIDMetric(name="test", device="cpu")

metric(rec_, real_)

46.174748222875905

In [87]:
from src.model import Unet

unet = Unet(
    img_size=8,
    init_dim=64,
    dim_mults=[1, 2, 4, 8],
    time_dim=256,
    in_channels=4,
    out_channels=4,
    down_kern=2,
    up_scale=2,
    resnet_stacks=3,
    attn_heads=8,
    attn_head_res=64,
    self_condition=False,
    resnet_grnorm_groups=8,
    classes=101
)

unet.to("cpu")

print(sum(p.numel() for p in unet.parameters() if p.requires_grad))

49618860


In [1]:
!python3 inference.py

Unet(
  (init_conv): Conv2d(4, 128, kernel_size=(1, 1), stride=(1, 1))
  (time_mlp): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=8, out_features=256, bias=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=256, out_features=256, bias=True)
  )
  (downs): ModuleList(
    (0): ModuleList(
      (0): ResnetBlock(
        (mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=256, out_features=256, bias=True)
        )
        (block1): conv_block(
          (proj): WeightStandardizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (block2): conv_block(
          (proj): WeightStandardizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm): GroupNorm(8, 128, eps=1e-05, affine=True)
          (act): SiLU()
        )
        (res_conv): Identity()
      )
      (1): Re