## Setup & Installation

In [None]:
%pip install fiftyone wandb open-clip-torch

## Part 1: Image Generation and Embedding Extraction
In this section, you will load the pre-trained U-Net model from notebook 05_CLIP.ipynb, generate images of flowers, and extract embeddings from the model's bottleneck.

In [None]:
import torch

from utils import UNet_utils, ddpm_utils

# TODO: Initialize the U-Net model and load the pre-trained weights from notebook 05.

# Make sure to use the same architecture as in the notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNet_utils.UNet(
    T=400, img_ch=3, img_size=32, down_chs=(256, 256, 512), t_embed_dim=8, c_embed_dim=512
).to(device)

# TODO: Provide the path to your trained model checkpoint.
# map_location lets weights saved on a GPU load on a CPU-only machine.
# model.load_state_dict(torch.load('path_to_your_model.pth', map_location=device))

model.eval()

# TODO: Define a list of text prompts to generate images for.
text_prompts = [
    "A photo of a red rose",
    "A photo of a white daisy",
    "A photo of a yellow sunflower",
]

# --- Embedding Extraction using Hooks ---
# We will use PyTorch hooks to extract the output of the 'down2' layer (the bottleneck).

embeddings_storage = {}

def get_embedding_hook(name):
    def hook(model, input, output):
        embeddings_storage[name] = output.detach()
    return hook

# TODO: Register a forward hook on the `down2` layer of the U-Net model.

# The hook should store the output of the layer in the `embeddings_storage` dictionary.
# model.down2.register_forward_hook(get_embedding_hook('down2'))

# TODO: Modify the `sample_flowers` function from notebook 05 to generate images 
# and store the extracted embeddings.

# You will need to run the generation process and then access the `embeddings_storage`
# to get the embeddings for each generated image.
# generated_images, _ = sample_flowers(text_prompts)
# extracted_embeddings = embeddings_storage['down2']
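
One detail of the hook pattern above is worth seeing in isolation: a forward hook fires on every forward pass, so a sampling loop that calls the model many times leaves only the last call's output in the dictionary. The sketch below uses a hypothetical `TinyNet` as a stand-in for the U-Net (only the attribute name `down2` mirrors this notebook; the layer sizes are made up) and shows registering the hook, removing it, and flattening the captured tensor to one vector per image.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the U-Net; only the name `down2` matches the notebook."""

    def __init__(self):
        super().__init__()
        self.down1 = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
        self.down2 = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
        self.head = nn.Conv2d(16, 3, kernel_size=1)

    def forward(self, x):
        return self.head(self.down2(self.down1(x)))


captured = {}


def get_embedding_hook(name):
    def hook(module, inputs, output):
        # Overwrites on every forward pass, so only the latest call survives.
        captured[name] = output.detach().cpu()
    return hook


toy_model = TinyNet().eval()
handle = toy_model.down2.register_forward_hook(get_embedding_hook("down2"))

with torch.no_grad():
    for _ in range(3):  # stands in for several denoising steps
        toy_model(torch.randn(4, 3, 32, 32))

handle.remove()  # detach the hook so later forward passes are not recorded

bottleneck = captured["down2"]                     # shape (4, 16, 8, 8)
flat_embeddings = bottleneck.flatten(start_dim=1)  # shape (4, 1024)
```

If you want the embedding from a specific denoising step rather than the final one, one option is to append each output to a list inside the hook and index into it afterwards.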

## Part 2: Evaluation with CLIP Score and Fréchet Inception Distance
Now evaluate the quality of your generated images using the CLIP score and FID functions given in the Metrics section at the end of this notebook.

In [None]:
import open_clip

# TODO: Calculate the CLIP score for each generated image against its prompt.
# You can use the `calculate_clip_score` function from the evaluation guide.

# TODO: Calculate the FID score for the set of generated images.
# You will need the `calculate_fid` function and the Inception model from the evaluation guide.
# You will also need to load the real TF-Flowers dataset to compare against.


## Part 3: Embedding Analysis with FiftyOne Brain
In this section, you will use FiftyOne to analyze the embeddings you extracted from the U-Net.


In [None]:
import fiftyone as fo
import fiftyone.brain as fob

# TODO: Create a new FiftyOne dataset.

dataset = fo.Dataset(name="generated_flowers_with_embeddings")

# TODO: Iterate through your generated images and add them to the dataset.
# For each image, create a fiftyone.Sample and add the following metadata:
# - The file path to the saved image.
# - The text prompt (as a `fo.Classification` label).
# - The CLIP score (as a custom field).
# - The extracted U-Net embedding (as a custom field).

# TODO: Compute uniqueness for the dataset.
# fob.compute_uniqueness(dataset)

# TODO: Compute representativeness using the extracted U-Net embeddings.
# fob.compute_representativeness(dataset, embeddings="unet_embedding")

# TODO: Launch the FiftyOne App to visualize your dataset and analyze the results.
# session = fo.launch_app(dataset)
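
Before building the samples, it can help to gather the per-image values into one record each and flatten every embedding to a 1-D vector. The sketch below is one possible way to do that, not the required approach: `build_sample_records` is a hypothetical helper, the paths, scores, and embedding shape are made up, and the field names `prompt` and `clip_score` are assumptions (only `unet_embedding` appears in this notebook). The FiftyOne calls are left commented out so the block runs without FiftyOne installed.

```python
import numpy as np


def build_sample_records(image_paths, prompts, clip_scores, embeddings):
    """Pair each image with its prompt, CLIP score, and a flat embedding vector."""
    embeddings = np.asarray(embeddings, dtype=np.float32)
    flat = embeddings.reshape(len(embeddings), -1)  # (N, C*H*W)
    if not (len(image_paths) == len(prompts) == len(clip_scores) == len(flat)):
        raise ValueError("all inputs must describe the same number of images")
    return [
        {"filepath": p, "prompt": t, "clip_score": float(s), "unet_embedding": e}
        for p, t, s, e in zip(image_paths, prompts, clip_scores, flat)
    ]


# Made-up values for illustration only.
records = build_sample_records(
    ["img_0.png", "img_1.png"],
    ["A photo of a red rose", "A photo of a white daisy"],
    [0.31, 0.27],
    np.zeros((2, 16, 4, 4)),
)

# With FiftyOne installed, each record could become a sample roughly like this:
# for rec in records:
#     sample = fo.Sample(filepath=rec["filepath"])
#     sample["prompt"] = fo.Classification(label=rec["prompt"])
#     sample["clip_score"] = rec["clip_score"]
#     sample["unet_embedding"] = rec["unet_embedding"]
#     dataset.add_sample(sample)
```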

## Part 4: Logging with Weights & Biases

In [None]:
import wandb

# TODO: Login to wandb.
# wandb.login()

# TODO: Initialize a new wandb run.
# run = wandb.init(project="diffusion_model_assessment_v2")

# TODO: Log your hyperparameters (e.g., guidance weight `w`, number of steps `T`).

# TODO: Log your evaluation metrics (CLIP Score and FID).

# TODO: Create a wandb.Table to log your results. The table should include:
# - The generated image.
# - The text prompt.
# - The CLIP score.
# - The uniqueness score.
# - The representativeness score.

# TODO: Finish the wandb run.
# run.finish()

## Metrics

CLIP Score

In [None]:
import torch

import open_clip

from PIL import Image

def calculate_clip_score(image_path, text_prompt):

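    # Load model (reloaded on every call; hoist this out if you score many images)
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
    model.eval()  # switch to inference mode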

    # Preprocess inputs
    image = preprocess(Image.open(image_path)).unsqueeze(0)
    tokenizer = open_clip.get_tokenizer('ViT-B-32')
    text = tokenizer([text_prompt])

    # Compute features and similarity
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)

        # Normalize features
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # Calculate dot product
        score = (image_features @ text_features.T).item()

    return score
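
The last two steps of `calculate_clip_score` (normalize, then dot product) amount to a cosine similarity, so the score always lies in [-1, 1]. The sketch below reproduces just that arithmetic in NumPy on made-up 2-D vectors; `cosine_score` is a hypothetical helper and the vectors are not real CLIP features.

```python
import numpy as np


def cosine_score(image_features, text_features):
    """Normalize both feature rows to unit length, then take their dot product."""
    image_features = image_features / np.linalg.norm(image_features, axis=-1, keepdims=True)
    text_features = text_features / np.linalg.norm(text_features, axis=-1, keepdims=True)
    return (image_features @ text_features.T).item()


img = np.array([[3.0, 4.0]])          # made-up "image" feature
aligned = np.array([[6.0, 8.0]])      # same direction, different length
orthogonal = np.array([[-4.0, 3.0]])  # perpendicular direction

score_aligned = cosine_score(img, aligned)        # close to 1
score_orthogonal = cosine_score(img, orthogonal)  # close to 0
```

Because of the normalization, scaling either vector leaves the score unchanged; only the angle between the image and text features matters.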

Fréchet Inception Distance (FID)

FID measures the distance between the feature distributions of real and generated images. Lower scores indicate that the generated images are closer to the real dataset in visual quality and diversity. The metric is defined in terms of InceptionV3 features, so you must compute it with an ImageNet pre-trained InceptionV3 model. Here's a demo notebook to do this.

In [None]:
import numpy as np

from scipy.linalg import sqrtm

def calculate_fid(real_embeddings, gen_embeddings):

    # real_embeddings and gen_embeddings should be NumPy arrays of shape (N, 2048)
    # extracted from an InceptionV3 model

    # Calculate mean and covariance
    mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings, rowvar=False)
    mu2, sigma2 = gen_embeddings.mean(axis=0), np.cov(gen_embeddings, rowvar=False)

    # Calculate sum squared difference between means
    ssdiff = np.sum((mu1 - mu2)**2)

    # Calculate sqrt of product of covariances
    covmean = sqrtm(sigma1.dot(sigma2))

    # Handle numerical errors
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    # Final FID calculation
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid
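
A quick way to check the function above is to run it on synthetic data where the answer is known. In the sketch below, random 8-dimensional vectors stand in for Inception features (an assumption made to keep the example fast; real features are 2048-dimensional). Comparing a set with itself should give a value near 0, and shifting every feature by 1.0 leaves the covariance unchanged, so the result should be close to the squared mean shift, 8 × 1.0² = 8.

```python
import numpy as np
from scipy.linalg import sqrtm


def calculate_fid(real_embeddings, gen_embeddings):
    """Same computation as the function above, repeated so this block runs on its own."""
    mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings, rowvar=False)
    mu2, sigma2 = gen_embeddings.mean(axis=0), np.cov(gen_embeddings, rowvar=False)
    covmean = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from numerical error
    return np.sum((mu1 - mu2) ** 2) + np.trace(sigma1 + sigma2 - 2.0 * covmean)


rng = np.random.default_rng(0)
features = rng.normal(size=(500, 8))  # synthetic stand-in for Inception features
shifted = features + 1.0              # same spread, mean moved by 1.0 per dimension

fid_same = calculate_fid(features, features)    # near 0
fid_shifted = calculate_fid(features, shifted)  # near 8
```

With real 2048-dimensional features, using fewer than 2048 images makes the covariance matrices singular, which is one common reason the matrix square root returns complex values and the estimate becomes unreliable.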