# Optical Illusions with Text-to-Image Diffusion
An optical illusion is a visual phenomenon that tricks the brain into perceiving something that isn't there or into misinterpreting the true nature of an image.

This notebook uses a pretrained diffusion model to generate such images: a single image that shows different content depending on how it is viewed.


## Hugging Face Access
In this homework, we use a `pixel-based` diffusion model named [DeepFloyd IF](https://huggingface.co/docs/diffusers/api/pipelines/deepfloyd_if). Downloading it requires a Hugging Face access token. Please follow the steps below:

1. Make sure to have a [Hugging Face account](https://huggingface.co/join) and be logged in.
2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0).
3. Log in locally by entering your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens) below, which can be [found here](https://huggingface.co/settings/tokens).
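
If you would rather not paste the token into the notebook, one common alternative is to read it from an environment variable. The sketch below assumes the token was exported as `HF_TOKEN` before the notebook was started; the helper name `read_token` is hypothetical and not part of the homework template.

```python
import os


def read_token(env_var: str = "HF_TOKEN") -> str:
    """Return the token stored in `env_var`, or raise if it is missing."""
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(f"Set the {env_var} environment variable first.")
    return token
```

With this helper, `token = read_token()` could replace the hard-coded string in the cell below.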

In [None]:
from huggingface_hub import login

##############################
# TODO2-0: Fill in your access token
# Begin your code
token = ""
# raise NotImplementedError
# End your code
##############################

login(token=token)

## Install Dependencies
Run the cell below to install the required dependencies. You can skip this step if the environment is already set up.

In [None]:

! pip install -q   \
    diffusers      \
    transformers   \
    safetensors    \
    sentencepiece  \
    accelerate     \
    bitsandbytes   \
    einops         \
    mediapy        \
    pillow


## Import Dependencies and Misc Setup
We import the packages we need and define a few helper functions.
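
The `im_to_np` helper in the next cell maps pixel values from the model's $[-1, 1]$ range to 8-bit integers in $[0, 255]$. As a minimal sketch of that arithmetic on a single value, here it is in plain Python; the function name `to_uint8` is illustrative and not part of the notebook.

```python
def to_uint8(x: float) -> int:
    """Map one pixel value from [-1, 1] to an integer in [0, 255]."""
    y = x / 2 + 0.5            # shift and scale [-1, 1] to [0, 1]
    y = min(max(y, 0.0), 1.0)  # clamp values that fall outside the range
    return round(y * 255)      # scale to 8 bits and round


for value in (-1.0, 0.0, 1.0, 1.7):
    print(value, "->", to_uint8(value))
```

The clamp matters because a diffusion model's output can slightly overshoot the nominal range.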

In [None]:
import os
import gc
import torch
import mediapy as mp
from time import sleep


# Convert image ([-1,1] GPU) into image ([0,255] CPU)
def im_to_np(im):
    im = (im / 2 + 0.5).clamp(0, 1)
    im = im.detach().cpu().permute(1, 2, 0).numpy()
    im = (im * 255).round().astype("uint8")
    return im


# Garbage collection function to free memory
def flush():
    sleep(1)
    gc.collect()
    torch.cuda.empty_cache()


# Set up device
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Current Device: {device}")

## Load T5 TextEncoder Model
We will load the `T5` text model in half-precision (`fp16`), use it to encode some prompts, and then delete it to recover GPU memory. Note that downloading the model may take a minute or two.

In [None]:
from transformers import T5EncoderModel
from diffusers import DiffusionPipeline

text_encoder = T5EncoderModel.from_pretrained(
    "DeepFloyd/IF-I-L-v1.0",
    subfolder="text_encoder",
    device_map=None,
    variant="fp16",
    torch_dtype=torch.float16,
)

pipe = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-L-v1.0",
    text_encoder=text_encoder,  # pass the previously instantiated text encoder
    unet=None
)
pipe = pipe.to(device)

## Create Text Embeddings

We can now use the T5 model to embed the prompts for our optical illusion. Embed every prompt you plan to use here, because the T5 text encoder is deleted in a later cell.
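
The embedding cell below relies on two small conventions: `encode_prompt` returns a (prompt, negative prompt) pair for each prompt, and `zip(*...)` regroups those pairs into one sequence of prompt embeddings and one of negative embeddings. Here is a toy sketch with strings standing in for tensors; `fake_encode` is a hypothetical stand-in, not the real encoder.

```python
def fake_encode(prompt: str) -> tuple[str, str]:
    """Stand-in for `pipe.encode_prompt`: returns (positive, negative)."""
    return (f"emb({prompt})", "emb(<empty>)")


prompts = ["a red fox", "a castle tower"]
pairs = [fake_encode(p) for p in prompts]

# zip(*pairs) transposes the list of pairs into two tuples.
positive, negative = zip(*pairs)
print(positive)
print(negative)

# Batch layout used later for classifier-free guidance: negatives first.
cfg_batch = list(negative) + list(positive)
print(cfg_batch)
```

In the notebook, `torch.cat` plays the role of the list concatenation shown here.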

In [None]:
##############################
# TODO2-1: Prompt Design
# Begin your code
prompt_1 = "A red fox emerging from swirling flames, with fiery orange and gold colors, mystical atmosphere"
prompt_2 = "A castle tower rising from swirling flames when viewed upside down, with fiery orange and gold colors, mystical atmosphere"
# raise NotImplementedError
# End your code
##############################

# Embed prompts using the T5 model
prompts = [prompt_1, prompt_2]
prompt_embeds = [pipe.encode_prompt(prompt) for prompt in prompts]
prompt_embeds, negative_prompt_embeds = zip(*prompt_embeds)
prompt_embeds = torch.cat(prompt_embeds)
negative_prompt_embeds = torch.cat(negative_prompt_embeds)  # These are just null embeds

## Viewing Transformation
To generate multi-view optical illusions, we must define the viewing transformations used in the `denoising process` ahead of time. Each transformation has to satisfy some `constraints`: it must be `invertible, linear, and orthogonal`. You don't need to worry about these constraints in this homework, but understanding them is helpful if you want to explore other viewing effects.
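
To make the constraints concrete, here is a minimal sketch that checks them for a 180° rotation, using a small nested list in place of an image tensor. The helper names `rotate180` and `squared_norm` are illustrative and not part of the notebook.

```python
def rotate180(image):
    """Rotate a 2-D grid by 180 degrees: reverse the rows, then each row."""
    return [row[::-1] for row in image[::-1]]


def squared_norm(image):
    """Sum of squared entries; an orthogonal transform leaves it unchanged."""
    return sum(v * v for row in image for v in row)


x = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
y = [[0.5, -1.0, 2.0],
     [3.0, 0.0, -2.5]]

# Invertible: rotating twice recovers the input (the view is its own inverse).
assert rotate180(rotate180(x)) == x

# Linear: view(a*x + b*y) equals a*view(x) + b*view(y).
a, b = 2.0, -3.0
mix = [[a * p + b * q for p, q in zip(rx, ry)] for rx, ry in zip(x, y)]
rot_mix = [[a * p + b * q for p, q in zip(rx, ry)]
           for rx, ry in zip(rotate180(x), rotate180(y))]
assert rotate180(mix) == rot_mix

# Orthogonal: the rotation only permutes pixels, so the norm is preserved.
assert squared_norm(rotate180(x)) == squared_norm(x)
```

Orthogonality is the constraint that is easiest to overlook. The model expects the noise in its input to look like independent Gaussian noise, and an orthogonal transformation keeps those statistics intact.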

In [None]:
from torchvision.transforms import functional as TF

##############################
# TODO2-2: Viewing Transformation
# Begin your code
class IdentityView:
    def __init__(self):
        pass
    def view(self, im):
        return im
    def inverse_view(self, noise):
        return noise

class Rotate180View:
    def __init__(self):
        pass
    def view(self, im):
        return TF.rotate(im, 180)
    def inverse_view(self, noise):
        return TF.rotate(noise, 180)

views = [IdentityView(), Rotate180View()]

# End your code
##############################

## Delete the Text Encoder

We now delete the text encoder (and the `diffusers` pipeline) and flush to free memory for the DeepFloyd image generation model.
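
Both names have to be deleted because the pipeline holds its own reference to the text encoder, and an object is only released once nothing refers to it. The sketch below illustrates this in CPython with stand-in classes; `Encoder` and `Pipeline` are hypothetical and unrelated to the `diffusers` classes.

```python
import gc
import weakref


class Encoder:
    """Stand-in for a large model."""


class Pipeline:
    """Stand-in for a pipeline that keeps a reference to the encoder."""

    def __init__(self, encoder):
        self.text_encoder = encoder


encoder = Encoder()
pipeline = Pipeline(encoder)
probe = weakref.ref(encoder)  # lets us observe whether the encoder is alive

del encoder
gc.collect()
still_alive = probe() is not None  # the pipeline still references it

del pipeline
gc.collect()
released = probe() is None  # no references remain

print(still_alive, released)
```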

In [None]:
del text_encoder
del pipe
flush()

## Main Diffusion Process

With the GPU memory now released, we can load the DeepFloyd IF diffusion models (also at `float16` precision). This may take a minute or two.

In [None]:
from diffusers import DiffusionPipeline

# Load DeepFloyd IF stage I
stage_1 = DiffusionPipeline.from_pretrained(
                "DeepFloyd/IF-I-L-v1.0",
                text_encoder=None,
                variant="fp16",
                torch_dtype=torch.float16,
            ).to(device)

# Load DeepFloyd IF stage II
stage_2 = DiffusionPipeline.from_pretrained(
                "DeepFloyd/IF-II-L-v1.0",
                text_encoder=None,
                variant="fp16",
                torch_dtype=torch.float16,
            ).to(device)

# Load DeepFloyd IF stage III
# (which is just Stable Diffusion 4x Upscaler)
stage_3 = DiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-x4-upscaler",
                torch_dtype=torch.float16
            ).to(device)

## Denoising Operation
The vanilla DeepFloyd IF pipeline applies `stage_1`, `stage_2`, and `stage_3` sequentially to produce a $1024 \times 1024$ image that matches the prompt. For illusions, however, `stage_1` and `stage_2` must estimate the noise under every view at each step and apply the combined estimate to the noisy image.
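
As a rough sketch of one such step: each view transforms the noisy image, the model predicts noise for that view with and without its prompt, and classifier-free guidance combines the two predictions. The inverse view then maps the result back, and the estimates are averaged. The toy `fake_unet` below is a hypothetical stand-in for the real network, and flat lists stand in for image tensors.

```python
def identity(x):
    return list(x)


def flip(x):
    """Reverse a flat list; a 1-D analogue of the 180-degree rotation."""
    return list(reversed(x))


def fake_unet(x, prompt_scale):
    """Hypothetical noise predictor: scales its input by a prompt-dependent factor."""
    return [prompt_scale * v for v in x]


def combined_noise(noisy, views, prompt_scales, guidance_scale):
    """Average the guided noise estimates of all views in the original frame."""
    total = [0.0] * len(noisy)
    for (view, inverse_view), scale in zip(views, prompt_scales):
        viewed = view(noisy)
        uncond = fake_unet(viewed, 0.0)   # prediction for the empty prompt
        cond = fake_unet(viewed, scale)   # prediction for this view's prompt
        # Classifier-free guidance: move from uncond towards cond.
        guided = [u + guidance_scale * (c - u) for u, c in zip(uncond, cond)]
        # Map the estimate back to the original frame before accumulating.
        aligned = inverse_view(guided)
        total = [t + a for t, a in zip(total, aligned)]
    return [t / len(views) for t in total]


views = [(identity, identity), (flip, flip)]
noise = combined_noise([1.0, 2.0, 3.0], views, prompt_scales=[0.1, 0.3],
                       guidance_scale=7.0)
print(noise)
```

The notebook's `denoising_loop` follows the same pattern but batches all views into a single UNet call and also averages the predicted variance.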

In [None]:
from tqdm import tqdm

import torch

##############################
# TODO2-3: Denoising Operation
# The following code only computes and applies the noise
# from the first prompt and original view. Please revise
# the denoising process according to spec.

# Begin your code

@torch.no_grad()
def denoising_loop(model, noisy_images, prompt_embeds, views, 
                   timesteps, guidance_scale, generator, noise_level=None, upscaled=None):

    # One prompt per view. `prompt_embeds` already includes the CFG negatives,
    # so its batch size is twice the number of views and cannot be used here.
    num_prompts = len(views)
    original_noise_level = noise_level

    # Weight every view equally when combining the per-view estimates
    view_weights = [1.0 / num_prompts] * num_prompts

    for t in tqdm(timesteps):

        viewed_images =[]
        for view in views:
            viewed_images.append(view.view(noisy_images))
        viewed_images = torch.cat(viewed_images, dim=0)
        
        # If upscaled is provided (stage 2), concatenate with noisy_images
        if upscaled is not None:
            viewed_upscaled = []
            for view in views:
                viewed_upscaled.append(view.view(upscaled))
            viewed_upscaled = torch.cat(viewed_upscaled, dim=0)
            model_input = torch.cat([viewed_images, viewed_upscaled], dim=1)
        else:
            model_input = viewed_images

        # Duplicate inputs for CFG
        # Model input is: [ neg_0, neg_1, ..., pos_0, pos_1, ... ]
        model_input = torch.cat([model_input] * 2)
        model_input = model.scheduler.scale_model_input(model_input, t)

        # Stage 2 only: the UNet receives the noise level of the upscaled
        # image as a class label, one entry per batch element
        current_noise_level = None
        if original_noise_level is not None:
            current_noise_level = original_noise_level[:1].repeat(model_input.shape[0])

        # Predict noise estimate
        noise_pred = model.unet(
            model_input, t, encoder_hidden_states=prompt_embeds, 
            class_labels=current_noise_level, cross_attention_kwargs=None, return_dict=False
        )[0]

        # Extract uncond (neg) and cond noise estimates
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

        # Split each prediction into its noise estimate and variance estimate
        split_size = model_input.shape[1] // (2 if upscaled is not None else 1)
        noise_pred_uncond, _ = noise_pred_uncond.split(split_size, dim=1)
        noise_pred_text, predicted_variance = noise_pred_text.split(split_size, dim=1)
        
        # Apply CFG only to noise prediction
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        # Split noise predictions for each view
        noise_preds_per_view = noise_pred.chunk(num_prompts, dim=0)
        var_preds_per_view = predicted_variance.chunk(num_prompts, dim=0) 
        
        # Apply inverse view transformations to each noise prediction
        inverse_viewed_noises = []
        inverse_viewed_vars = []
        for noise_pred_view, var_pred_view, view in zip(noise_preds_per_view, var_preds_per_view, views):
            inverse_viewed_noises.append(view.inverse_view(noise_pred_view))
            inverse_viewed_vars.append(view.inverse_view(var_pred_view))

        weighted_noise = sum(w * noise for w, noise in zip(view_weights, inverse_viewed_noises))
        weighted_var = sum(w * var for w, var in zip(view_weights, inverse_viewed_vars))
        
        
        # Combine averaged noise with averaged predicted variance
        combined_prediction = torch.cat([weighted_noise, weighted_var], dim=1)
        
        # Compute the previous noisy sample x_t -> x_t-1
        noisy_images = model.scheduler.step(combined_prediction, t, noisy_images, generator=generator, return_dict=False)[0]

    return noisy_images

# End your code
##############################

In [None]:
import torch
import torch.nn.functional as F

from diffusers.utils.torch_utils import randn_tensor

@torch.no_grad()
def adjusted_stage_1(model, prompt_embeds, negative_prompt_embeds, views,
                   num_inference_steps=100, guidance_scale=7.0, generator=None):

    num_prompts = prompt_embeds.shape[0]
    assert num_prompts == len(views), \
        "Number of prompts must match number of views!"
    
    height = model.unet.config.sample_size
    width = model.unet.config.sample_size

    # For CFG
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # Setup timesteps
    model.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = model.scheduler.timesteps

    # Make intermediate_images
    noisy_images = model.prepare_intermediate_images(
        1, model.unet.config.in_channels, height, width, prompt_embeds.dtype, device, generator,
    )

    return denoising_loop(model, noisy_images, prompt_embeds, views, 
                          timesteps, guidance_scale, generator)


@torch.no_grad()
def adjusted_stage_2(model, image, prompt_embeds, negative_prompt_embeds, views,
                   num_inference_steps=100, guidance_scale=7.0, noise_level=50, generator=None):

    num_prompts = prompt_embeds.shape[0]
    assert num_prompts == len(views), \
        "Number of prompts must match number of views!"
        
    height = model.unet.config.sample_size
    width = model.unet.config.sample_size
    num_images_per_prompt = 1

    # For CFG
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    # Get timesteps
    model.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = model.scheduler.timesteps

    num_channels = model.unet.config.in_channels // 2
    noisy_images = model.prepare_intermediate_images(
        1, num_channels, height, width, prompt_embeds.dtype, device, generator,
    )

    # Prepare upscaled image and noise level
    image = model.preprocess_image(image, num_images_per_prompt, device)
    upscaled = F.interpolate(image, (height, width), mode="bilinear", align_corners=True)

    noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
    noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
    upscaled = model.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)

    return denoising_loop(model, noisy_images, prompt_embeds, views, 
                          timesteps, guidance_scale, generator, noise_level, upscaled)

## Generate Illusions
Now, we can sample illusions by denoising multiple views at once. The `adjusted_stage_1` function does this and generates a $64 \times 64$ image. The `adjusted_stage_2` function upsamples the resulting image while denoising all views, and generates a $256 \times 256$ image.

Finally, `stage_3` simply upsamples the $256 \times 256$ image using a single given text prompt to $1024 \times 1024$, _without_ doing multi-view denoising.

In [None]:
image_64 = adjusted_stage_1(stage_1, prompt_embeds, negative_prompt_embeds, views,
                          num_inference_steps=30, guidance_scale=15.0, generator=None)

# Show result
mp.show_images([im_to_np(view.view(image_64[0])) for view in views])

In [None]:
image_256 = adjusted_stage_2(stage_2, image_64, prompt_embeds, negative_prompt_embeds, views,
                           num_inference_steps=30, guidance_scale=15.0, noise_level=50, generator=None)

# Show result
mp.show_images([im_to_np(view.view(image_256[0])) for view in views])

In [None]:
image_1024 = stage_3(prompt=prompts[0], image=image_256,
                noise_level=0, output_type='pt', generator=None).images
image_1024 = image_1024 * 2 - 1

# Limit display size, otherwise it's too large for most screens
mp.show_images([im_to_np(view.view(image_1024[0])) for view in views], width=400)
mp.write_image('result.jpg', im_to_np(image_1024[0]))

## Delete the Stages and Images
We now delete the stages for DeepFloyd image generation and flush to free memory for the CLIP score evaluation.

In [None]:
del stage_1
del stage_2
del stage_3
flush()

del image_64
del image_256
del image_1024
flush()

## CLIP Score
This cell evaluates the optical illusion with CLIP. A higher score indicates that the text and image are more `closely related`. To ensure image quality, the score of each view, computed by applying the viewing transformation to the image and comparing the result with the corresponding prompt, `must exceed 0.3`.

Note that you can `regenerate` the optical illusion image using the same code until the score is high enough.
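
The score reported below is the cosine similarity between the CLIP image embedding and the CLIP text embedding. Here is a minimal sketch of that computation and of the threshold check, using short made-up vectors in place of real CLIP embeddings. The names `cosine_similarity` and `passes` are illustrative.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors of equal length."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def passes(scores, threshold=0.3):
    """True only if every view's score exceeds the threshold."""
    return all(score > threshold for score in scores)


# Made-up embeddings: one (image, text) pair per view.
pairs = [([1.0, 0.0, 1.0], [1.0, 1.0, 0.0]),
         ([0.0, 2.0, 0.0], [0.0, 1.0, 0.0])]
scores = [cosine_similarity(img, txt) for img, txt in pairs]
print(scores, passes(scores))
```

Because every view has to pass, one weak view is enough to require regenerating the image.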

In [None]:
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Load CLIP model and processor
path = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(path)
processor = CLIPProcessor.from_pretrained(path)

# Define images and texts
image_path = "result.jpg"
texts = [prompt_1, prompt_2]

image = Image.open(image_path)
images = [view.view(image) for view in views]

# Preprocess
inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)

# Use CLIP to compute the embeddings (inference only, so no gradients are needed)
with torch.no_grad():
    outputs = model(**inputs)
image_features = outputs.image_embeds
text_features = outputs.text_embeds

# Calculate the cosine similarities (images <-> texts) with embeddings
cosine_similarities = torch.nn.functional.cosine_similarity(image_features, text_features, dim=-1)

for text, score in zip(texts, cosine_similarities):
    print(f"Prompt: {text}")
    print(f"CLIP Score: {score:.4f}\n")