In [None]:
# install the necessary dependencies and libraries
!pip install --upgrade diffusers transformers torch accelerate matplotlib datasets torchvision

import torch
from diffusers import StableDiffusionPipeline
from datasets import load_dataset
import random
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import gc

# clean up memory and reset CUDA cache
def cleanup_memory():
    """Run the Python garbage collector and release cached CUDA memory.

    Safe to call on CPU-only machines: every CUDA call is skipped when no
    GPU is available.
    """
    gc.collect()
    if torch.cuda.is_available():
        # Hand unused blocks held by PyTorch's caching allocator back to the driver.
        torch.cuda.empty_cache()
        # Start peak-memory tracking fresh for whatever runs next.
        torch.cuda.reset_peak_memory_stats()

# load the Stable Diffusion model
def load_model(model_id, device=None, torch_dtype=None):
    """Load a Stable Diffusion pipeline and move it to a device.

    Args:
        model_id: Repo id or local path passed to
            ``StableDiffusionPipeline.from_pretrained``.
        device: Target device. Defaults to ``"cuda"`` when CUDA is
            available, otherwise ``"cpu"``.
        torch_dtype: Optional dtype forwarded to ``from_pretrained``
            (e.g. ``torch.float16`` to reduce GPU memory). When ``None``
            the argument is not passed, so the library default applies.

    Returns:
        The pipeline, already placed on ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Only forward torch_dtype when the caller chose one, so the default call
    # is identical to a plain from_pretrained(model_id).
    load_kwargs = {}
    if torch_dtype is not None:
        load_kwargs["torch_dtype"] = torch_dtype

    pipeline = StableDiffusionPipeline.from_pretrained(model_id, **load_kwargs)
    return pipeline.to(device)

# generate the images
def generate_images(pipeline, prompts, num_images_per_prompt=1, num_inference_steps=50, guidance_scale=7.5):
    """Run the pipeline once per prompt and collect every generated image.

    Args:
        pipeline: A loaded diffusion pipeline (e.g. from ``load_model``).
        prompts: A list of prompt strings, or a single prompt string.
        num_images_per_prompt: Images to generate for each prompt.
        num_inference_steps: Denoising steps per generation.
        guidance_scale: Classifier-free guidance strength.

    Returns:
        A flat list of PIL images ordered by prompt: all images for the
        first prompt, then all images for the second, and so on.
    """
    # A bare string would otherwise be iterated character by character,
    # launching one generation per letter.
    if isinstance(prompts, str):
        prompts = [prompts]

    images = []
    for prompt in prompts:
        result = pipeline(
            prompt,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            output_type="pil",
        )
        images.extend(result.images)
        # Collect garbage and drop cached CUDA blocks between prompts.
        cleanup_memory()
    return images

# display the images
def display_images(images, prompts, num_images_per_prompt=None):
    """Show each image in its own row, titled with the prompt that produced it.

    Args:
        images: Images in the order produced by ``generate_images``: all
            images for ``prompts[0]`` first, then ``prompts[1]``, and so on.
        prompts: The prompts used to generate ``images``.
        num_images_per_prompt: How many consecutive images belong to each
            prompt. When ``None`` it is inferred as
            ``len(images) // len(prompts)`` if that divides evenly, else 1.

    Does nothing when ``images`` is empty.
    """
    images = list(images)
    if not images:
        # plt.subplots(0, 1) raises, and there is nothing to show anyway.
        return

    prompts = list(prompts)
    if num_images_per_prompt is None:
        if prompts and len(images) % len(prompts) == 0:
            num_images_per_prompt = len(images) // len(prompts)
        else:
            num_images_per_prompt = 1
    num_images_per_prompt = max(1, num_images_per_prompt)

    rows = len(images)
    # squeeze=False always yields a 2-D array of axes, so one row needs no special case.
    _, axs = plt.subplots(rows, 1, figsize=(15, 5 * rows), squeeze=False)

    # Iterate over every image (not zip with prompts, which would truncate to
    # len(prompts) and mislabel images when several share one prompt).
    for idx, (img, ax) in enumerate(zip(images, axs[:, 0])):
        prompt_idx = idx // num_images_per_prompt
        # Images beyond the last prompt (mismatched counts) get an empty title.
        title = prompts[prompt_idx] if prompt_idx < len(prompts) else ""
        ax.imshow(img)
        ax.set_title(title, fontsize=10)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# execute the script
if __name__ == "__main__":
    import traceback

    # Bound before the try so the finally clause can always release it.
    pipeline = None
    try:
        pipeline = load_model('stabilityai/stable-diffusion-2-1')

        prompts = [
            "A cute robot with wheels wearing a red fedora hat, searching for a red fedora hat on the floor ",
        ]
        num_images_per_prompt = 1

        generated_images = generate_images(
            pipeline,
            prompts,
            num_images_per_prompt,
            num_inference_steps=50,
            guidance_scale=7.5,
        )
        display_images(generated_images, prompts)

    except Exception as e:
        # Top-level boundary: report the failure but keep the full traceback
        # so the cause is diagnosable.
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()
    finally:
        # Drop the module-level reference to the model first; otherwise
        # cleanup_memory() cannot free its weights.
        pipeline = None
        cleanup_memory()