In [1]:
from diffusers import DDPMScheduler,UNet2DConditionModel,AutoencoderKL
from tqdm import tqdm
import cv2
import numpy as np

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer
import torch

# Device setup: use the GPU when one is visible to torch, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Both the tokenizer and the text encoder come from the same CLIP checkpoint.
CLIP_CHECKPOINT = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(CLIP_CHECKPOINT)
text_encoder = CLIPTextModel.from_pretrained(CLIP_CHECKPOINT)
text_encoder = text_encoder.to(device)

def get_text_embeddings(text):
    """Encode a prompt with the module-level CLIP tokenizer and text encoder.

    Args:
        text: Prompt passed straight to the tokenizer (presumably a string or
            a list of strings -- confirm against callers).

    Returns:
        The first element of the text encoder's output (the last hidden
        state), computed without gradient tracking on ``device``.
    """
    # padding=True pads to the longest sequence in the call, not to max_length;
    # truncation caps the token sequence at 77 entries.
    tokens = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    )
    tokens = tokens.to(device)

    with torch.no_grad():
        encoded = text_encoder(**tokens)

    # Index 0 of the encoder output is the last hidden state.
    return encoded[0]


In [3]:
from diffusers import AutoencoderKL

# Load a pretrained VAE, move it to the selected device and switch it to
# evaluation mode in one chain (``eval()`` returns the module itself).
# Alternatives tried previously: the VAE in the "vae" subfolder of
# "CompVis/stable-diffusion-v1-4", and fine-tuned weights loaded from
# '/kaggle/working/vae_step.pth'.
vae = (
    AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    .to(device)
    .eval()
)

@torch.no_grad()
def get_image_latent_dist(image):
    """Encode an image batch with the module-level VAE and sample a latent.

    Despite the name, the return value is a tensor drawn from the encoder's
    latent distribution, not the distribution object itself. No latent
    scaling factor is applied here.

    Args:
        image: Tensor accepted by ``vae.encode`` (assumed to be on the same
            device as the VAE -- TODO confirm against callers).

    Returns:
        A sample from ``vae.encode(image).latent_dist``.
    """
    posterior = vae.encode(image).latent_dist
    return posterior.sample()


In [None]:
# Text-conditioned UNet operating on 4-channel latents with a sample size of 32.
# cross_attention_dim=768 must match the width of the embeddings fed to the
# cross-attention blocks (the CLIP text encoder output used in this notebook).
model = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(64, 128, 256),
    down_block_types=(
        "DownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
    ),
    up_block_types=(
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "UpBlock2D",
    ),
    cross_attention_dim=768
)
model = model.to(device)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint that was saved from a GPU still loads on a CPU-only machine
# (without it, torch.load tries to restore tensors to their original device).
state_dict = torch.load('unet_new.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()
print('model loaded')

In [5]:
def sample_image_generation(model, vae, text, num_inference_steps, num_samples=1, guidance_scale=10):
    """Generate an image from a text prompt using classifier-free guidance.

    A fresh ``DDPMScheduler`` (1000 training timesteps) is configured for
    ``num_inference_steps`` steps. For every sample, a random 1x4x32x32 latent
    is denoised step by step; at each step the UNet is evaluated once with the
    prompt embeddings and once with all-zero embeddings, and the two
    predictions are blended with ``guidance_scale``. The final latent is
    un-scaled, decoded by the VAE and converted to an 8-bit image.

    Args:
        model: Conditional UNet called as ``model(latents, t, embeddings)``;
            its ``.sample`` attribute is used as the noise prediction.
        vae: Autoencoder whose ``decode(latents).sample`` yields the image
            (assumed to be in roughly [-1, 1] -- values outside are clipped).
        text: Prompt forwarded to ``get_text_embeddings``.
        num_inference_steps: Number of scheduler steps to run.
        num_samples: How many images to generate. All are generated, but only
            the first one is returned.
        guidance_scale: Weight of the conditional direction in the
            classifier-free guidance combination.

    Returns:
        The first generated image as a ``uint8`` array resized to 320x320
        with ``cv2.resize``.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1 (there would be no
            image to return).
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps)
    model.eval()
    images_list = []
    with torch.no_grad():
        # Conditional embeddings for the prompt.
        text_embeddings = get_text_embeddings(text)

        # Unconditional branch: an all-zero tensor of the same shape, dtype and
        # device (not the encoding of an empty prompt).
        unconditional_text_embeddings = torch.zeros_like(text_embeddings)

        for _ in range(num_samples):
            # Start from pure Gaussian noise in latent space.
            latents = torch.randn(1, 4, 32, 32).to(device)

            for t in tqdm(noise_scheduler.timesteps, position=0, leave=True):
                # Conditional prediction (with text)
                noise_pred_cond = model(latents, t, text_embeddings).sample

                # Unconditional prediction (no text)
                noise_pred_uncond = model(latents, t, unconditional_text_embeddings).sample

                # Classifier-free guidance: move from the unconditional
                # prediction towards the conditional one.
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # One reverse-diffusion step.
                latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

            # Undo the latent scaling before decoding.
            latents = latents / 0.18215
            images = vae.decode(latents).sample

            # (1, C, H, W) -> (H, W, C) numpy array.
            img = images.squeeze(0).cpu().permute(1, 2, 0).numpy()
            # Map [-1, 1] to [0, 1] and clip: decoder outputs can overshoot the
            # nominal range, and an unclipped cast to uint8 would wrap around
            # (e.g. 256 -> 0), producing speckle artefacts.
            img = np.clip((img + 1) / 2, 0.0, 1.0)
            images_list.append((img * 255).astype('uint8'))

    return cv2.resize(images_list[0], (320, 320))


In [None]:
import gradio as gr

def generate(prompt, num_inference_steps):
    """Gradio callback: turn a prompt and a step count into one image.

    Delegates to ``sample_image_generation`` with the module-level ``model``
    and ``vae`` and returns its result unchanged.
    """
    return sample_image_generation(model, vae, prompt, num_inference_steps)

# Gradio Interface
# Input widgets: a free-text prompt and a slider for the number of steps.
prompt_input = gr.Textbox(label="Enter your text prompt")
steps_input = gr.Slider(
    minimum=1,
    maximum=1000,
    step=1,
    value=11,
    label="Number of Inference Steps",
)

gr_interface = gr.Interface(
    fn=generate,
    inputs=[prompt_input, steps_input],
    outputs=gr.Image(label="Generated Image"),
    title="Text-to-Image Generator",
    description="example usage: This female is blond Hair, Heavy Makeup, No Beard, Smiling, Young.",
)

# Launch the interface
gr_interface.launch()
