In [1]:
from diffusers import DDPMScheduler,UNet2DConditionModel,AutoencoderKL
from tqdm import tqdm
import cv2
import numpy as np

In [2]:
from transformers import CLIPTextModel, CLIPTokenizer
import torch

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# CLIP tokenizer and text encoder, both loaded from the same checkpoint name.
# The encoder is moved onto `device` so it can consume the tokenized prompts.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = text_encoder.to(device)

def get_text_embeddings(text):
    """Encode a prompt with the module-level CLIP tokenizer and text encoder.

    Args:
        text: Prompt (or prompts) accepted by ``tokenizer``; sequences are
            padded to the longest entry and truncated to at most 77 tokens.

    Returns:
        The ``last_hidden_state`` tensor produced by ``text_encoder`` on
        ``device``, computed without gradient tracking.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    )
    encoded = encoded.to(device)

    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        return text_encoder(**encoded).last_hidden_state




In [3]:
from diffusers import AutoencoderKL

# Pretrained VAE used to move between image space and latent space.
# Alternative checkpoint, kept for reference:
# vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
vae = vae.to(device)
# Optional weights from a local checkpoint (disabled):
# vae.load_state_dict(torch.load('/kaggle/working/vae_step.pth', weights_only=True))
vae.eval()  # inference only: switch the VAE to evaluation mode

def get_image_latent_dist(image):
    """Encode ``image`` with the module-level VAE and draw one latent sample.

    Despite the name, the return value is a sample drawn from the encoder's
    latent distribution, not the distribution object itself. The call runs
    without gradient tracking.

    Args:
        image: Input passed straight to ``vae.encode``
            (presumably a batched image tensor on ``device`` - TODO confirm).

    Returns:
        The tensor returned by ``latent_dist.sample()``.
    """
    with torch.no_grad():
        posterior = vae.encode(image).latent_dist
        return posterior.sample()


In [4]:
# Text-conditioned UNet operating on 4-channel 32x32 latents. The architecture
# arguments below must match the checkpoint loaded further down, otherwise
# `load_state_dict` raises on missing/unexpected keys or shape mismatches.
model = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(64, 128, 256),
    down_block_types=(
        "DownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
    ),
    up_block_types=(
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "UpBlock2D",
    ),
    # Width of the conditioning vectors fed to the cross-attention layers;
    # expected to equal the last dimension of the text embeddings.
    cross_attention_dim=768,
)
model = model.to(device)

# `map_location=device` remaps the stored tensors onto the active device, so a
# checkpoint saved from a CUDA run still loads when only the CPU is available.
# `weights_only=True` restricts unpickling to tensors and plain containers.
model.load_state_dict(
    torch.load('unet_new.pth', map_location=device, weights_only=True)
)
model.eval()
print('model loaded')

  return self.fget.__get__(instance, owner)()


model loaded


In [5]:
def _tile_images(images):
    """Arrange equally sized HxWxC arrays into one near-square grid image.

    Tiles are placed column by column: the first ``rows`` images form the
    left-most column, the next ``rows`` images the column to its right, and
    so on. Cells left over in the last column are filled with black tiles so
    that every column has the same height.
    """
    count = len(images)
    rows = int(np.ceil(np.sqrt(count)))
    cols = -(-count // rows)  # ceiling division
    blank = np.zeros_like(images[0])
    padded = list(images) + [blank] * (rows * cols - count)
    columns = [
        np.concatenate(padded[c * rows:(c + 1) * rows], axis=0)
        for c in range(cols)
    ]
    return np.concatenate(columns, axis=1)


def sample_image_generation(model, vae, text, num_inference_steps, num_samples=4):
    """Generate ``num_samples`` images for a prompt and tile them into a grid.

    Each sample starts from fresh Gaussian noise of shape (1, 4, 32, 32), is
    denoised by ``model`` over the DDPM schedule conditioned on the prompt's
    text embeddings, and is then decoded to pixels by ``vae``.

    Args:
        model: Conditional UNet called as ``model(latents, t, embeddings)``;
            it is switched to evaluation mode.
        vae: Autoencoder whose ``decode`` turns latents into images.
        text: Prompt forwarded to ``get_text_embeddings``.
        num_inference_steps: Number of denoising steps; coerced to ``int``
            because callers such as UI sliders may deliver a float.
        num_samples: Number of images to generate (default 4).

    Returns:
        A channel-last ``uint8`` NumPy array containing all samples tiled
        into one image. With the default of four samples this is a 2x2 grid
        whose left column holds samples 0-1 and right column samples 2-3.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    noise_scheduler.set_timesteps(num_inference_steps=int(num_inference_steps))
    model.eval()

    images_list = []
    with torch.no_grad():
        # The prompt is identical for every sample, so encode it only once.
        text_embeddings = get_text_embeddings(text)
        for _ in range(num_samples):
            latents = torch.randn(1, 4, 32, 32, device=device)

            # Reverse diffusion: predict the noise, then step to the previous timestep.
            for t in tqdm(noise_scheduler.timesteps, position=0, leave=True):
                noise_pred = model(latents, t, text_embeddings).sample
                latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

            # Undo the latent scaling before decoding the latents to pixels.
            latents = latents / 0.18215
            decoded = vae.decode(latents).sample

            # Single-image batch: CHW -> HWC, then map [-1, 1] to [0, 1].
            img = decoded[0].cpu().permute(1, 2, 0).numpy()
            # Clamp first: the decoder can overshoot the nominal range, and
            # out-of-range values would wrap around in the uint8 cast.
            img = np.clip((img + 1) / 2, 0.0, 1.0)
            images_list.append((img * 255).round().astype('uint8'))

    # Combine the individual samples into a single grid image.
    return _tile_images(images_list)

In [6]:
# Quick smoke test: default number of samples for the prompt "girl", 10 denoising steps.
img = sample_image_generation(model, vae, text="girl", num_inference_steps=10)

100%|██████████| 10/10 [00:01<00:00,  9.42it/s]
100%|██████████| 10/10 [00:00<00:00, 14.57it/s]
100%|██████████| 10/10 [00:00<00:00, 15.29it/s]
100%|██████████| 10/10 [00:00<00:00, 13.97it/s]


In [7]:
import gradio as gr

# Gradio callback: receives both the text prompt and the number of inference steps
def generate(prompt, num_inference_steps):
    """Produce the image grid for ``prompt`` using the module-level model and VAE.

    Args:
        prompt: Text description forwarded to ``sample_image_generation``.
        num_inference_steps: Number of denoising steps to run.

    Returns:
        Whatever ``sample_image_generation`` returns for these arguments.
    """
    return sample_image_generation(model, vae, prompt, num_inference_steps)

# Gradio Interface
gr_interface = gr.Interface(
    title="Text-to-Image Generator",
    description="Enter a text description, and adjust the number of inference steps for image generation.",
    fn=generate,  # callback invoked with (prompt, steps)
    inputs=[
        # Free-text prompt
        gr.Textbox(label="Enter your text prompt"),
        # Number of denoising steps, adjustable from 1 to 1000 (default 50)
        gr.Slider(minimum=1, maximum=1000, step=1, value=50, label="Number of Inference Steps"),
    ],
    outputs=gr.Image(label="Generated Image"),  # shows the returned image array
)

# Start serving the interface
gr_interface.launch()


Running on local URL:  http://127.0.0.1:7860
IMPORTANT: You are using gradio version 3.50.0, however version 4.44.1 is available, please upgrade.
--------

To create a public link, set `share=True` in `launch()`.




100%|██████████| 13/13 [00:01<00:00, 10.26it/s]
100%|██████████| 13/13 [00:00<00:00, 13.49it/s]
100%|██████████| 13/13 [00:01<00:00, 12.28it/s]
100%|██████████| 13/13 [00:01<00:00, 12.54it/s]
100%|██████████| 13/13 [00:01<00:00, 12.39it/s]
100%|██████████| 13/13 [00:01<00:00, 12.27it/s]
100%|██████████| 13/13 [00:01<00:00, 11.47it/s]
100%|██████████| 13/13 [00:01<00:00, 11.76it/s]
100%|██████████| 13/13 [00:01<00:00, 11.67it/s]
100%|██████████| 13/13 [00:01<00:00, 11.28it/s]
100%|██████████| 13/13 [00:01<00:00, 12.41it/s]
100%|██████████| 13/13 [00:01<00:00, 11.49it/s]
100%|██████████| 13/13 [00:01<00:00, 10.92it/s]
100%|██████████| 13/13 [00:01<00:00, 11.66it/s]
100%|██████████| 13/13 [00:01<00:00, 11.60it/s]
100%|██████████| 13/13 [00:01<00:00, 12.30it/s]
100%|██████████| 13/13 [00:01<00:00, 11.14it/s]
100%|██████████| 13/13 [00:01<00:00, 11.70it/s]
100%|██████████| 13/13 [00:01<00:00, 11.68it/s]
100%|██████████| 13/13 [00:01<00:00, 10.62it/s]
100%|██████████| 13/13 [00:01<00:00, 10.