# Quantizing Stable Diffusion

_Authored by: [Thomas Liang](https://github.com/thliang01)_


- [ ] TODO: write a description and quantize the Stable Diffusion models

## Install required Python packages

In [None]:
# IPython shell escapes: upgrade the Hugging Face libraries used below, then
# quietly install the image/metric dependencies (torchmetrics provides the
# CLIP score used in the evaluation section).
! pip install --upgrade diffusers accelerate transformers safetensors
! pip install -q numpy Pillow torchmetrics

## Import modules

In [None]:
import torch
import numpy as np
import os

import time

from PIL import Image
from IPython import display as IPdisplay
from tqdm.auto import tqdm

from diffusers import DiffusionPipeline
from transformers import logging

# Only surface ERROR-level messages from transformers' logger, hiding
# info/warning chatter emitted while loading models.
logging.set_verbosity(logging.ERROR)

### Check whether CUDA is available

In [None]:
# Report whether a CUDA GPU is visible, then pick the device to run on.
cuda_available = torch.cuda.is_available()
print(cuda_available)

device = torch.device("cuda" if cuda_available else "cpu")

## Base Model

In [None]:
model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"

# Use half precision only on GPU. On the CPU fallback, float16 inference is
# slow and some operators may lack half-precision kernels, so run in float32.
# The "fp16" weight variant is still the one downloaded (smaller files); the
# weights are cast to `torch_dtype` when the pipeline is loaded.
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32

pipe = DiffusionPipeline.from_pretrained(
    model_name_or_path,
    torch_dtype=torch_dtype,
    variant="fp16",
    use_safetensors=True,
).to(device)

## Display images

In [None]:
prompt = "a photo of an astronaut riding a horse on mars"
# The pipeline output's `.images` holds the generated batch; for a single
# prompt keep the first (and only) image.
image = pipe(prompt).images[0]
# As the last expression in the cell, the notebook renders the image inline.
image

## Quantitative Evaluation

In [None]:
prompts = [
    "a photo of an astronaut riding a horse on mars",
    "A high tech solarpunk utopia in the Amazon rainforest",
    # "A pikachu fine dining with a view to the Eiffel Tower",
    # "A mecha robot in a favela in expressionist style",
    # "an insect robot preparing a delicious meal",
    # "A small cabin on top of a snowy mountain in the style of Disney, artstation",
]

# Seed the initial latents so the generated images, and therefore the CLIP
# score below, are reproducible and comparable across runs (e.g. base vs. a
# quantized pipeline). A CPU generator gives the same noise regardless of
# which device the pipeline runs on.
seed = 0
generator = torch.Generator("cpu").manual_seed(seed)

# output_type="np" returns a float array laid out as (N, H, W, C).
images = pipe(
    prompts,
    num_images_per_prompt=1,
    output_type="np",
    generator=generator,
).images

print(images.shape)
# (len(prompts) * num_images_per_prompt, H, W, 3); with the two active prompts
# and SDXL's default resolution this is expected to be (2, 1024, 1024, 3)
# -- confirm if the pipeline's default size changes.

In [None]:
from torchmetrics.functional.multimodal import clip_score
from functools import partial

# CLIP checkpoint used to embed both the images and the prompts for scoring.
CLIP_MODEL_ID = "openai/clip-vit-base-patch16"
clip_score_fn = partial(clip_score, model_name_or_path=CLIP_MODEL_ID)

def calculate_clip_score(images, prompts):
    """Return the CLIP score of ``images`` against ``prompts``, rounded to 4 places.

    Args:
        images: Float array laid out as (N, H, W, C), assumed to hold values in
            [0, 1] -- e.g. what the pipeline returns with ``output_type="np"``.
        prompts: Text prompt(s) the images were generated from; expected to
            line up one-to-one with ``images``.

    Returns:
        The score reported by ``clip_score_fn``, as a Python float rounded to
        4 decimal places.
    """
    # Convert to uint8 the way diffusers does when building PIL images: round
    # rather than truncate, and clip first so out-of-range values cannot wrap.
    images_int = np.clip(np.round(np.asarray(images) * 255), 0, 255).astype("uint8")
    # The metric takes channels-first images: (N, H, W, C) -> (N, C, H, W).
    images_nchw = torch.from_numpy(images_int).permute(0, 3, 1, 2)
    # Named `score` so it does not shadow torchmetrics' `clip_score` function.
    score = clip_score_fn(images_nchw, prompts).detach()
    return round(float(score), 4)

# Score the batch generated above against the prompts that produced it.
sd_clip_score = calculate_clip_score(images=images, prompts=prompts)
print("CLIP score:", sd_clip_score)
# Example output (varies with hardware and library versions): CLIP score: 35.7038