<a href="https://colab.research.google.com/github/karaage0703/stable-diffusion-colab-tools/blob/main/008_text_to_world.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Text-to-World

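Text-to-World using a genetic-algorithm-like (GA-like) approach with Stable Diffusion: each gene is a CLIP text embedding plus a noise seed, and genes are crossed over to breed new images.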

Reference notebook:  
https://github.com/fastai/diffusion-nbs/blob/master/Stable%20Diffusion%20Deep%20Dive.ipynb

For the license of this notebook, refer to the reference notebook.

## Setup

In [None]:
!pip install -qq --upgrade transformers diffusers ftfy

In [None]:
from base64 import b64encode

import numpy as np
import random
import torch
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from huggingface_hub import notebook_login

# For video display:
from IPython.display import HTML
from matplotlib import pyplot as plt
from PIL import Image
from torch import autocast
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, logging

# Seed every RNG this notebook draws from so Restart & Run All is
# reproducible: torch's global generator, and Python's `random`, which picks
# gene seeds, crossover parents and crossover points.
SEED = 1
torch.manual_seed(SEED)
random.seed(SEED)

notebook_login()

# Suppress some unnecessary warnings when loading the CLIPTextModel
logging.set_verbosity_error()

# Set device
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

## Loading the models

This code (and that in the next section) comes from the [Huggingface example notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb). 

This will download and set up the relevant models and components we'll be using. Let's just run this for now and move on to the next section to check that it all works before diving deeper.

If you've loaded a pipeline, you can also access these components using `pipe.unet`, `pipe.vae` and so on.

In this notebook we aren't doing any memory-saving tricks - if you find yourself running out of GPU RAM, look at the pipeline code for inspiration with things like attention slicing, switching to half precision (fp16), keeping the VAE on the CPU and other modifications.

In [None]:
# Load each Stable Diffusion component and move it to `torch_device` right away.

# VAE: decodes the denoised latents back into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(torch_device)

# CLIP tokenizer + text encoder: turn prompts into conditioning embeddings.
# The tokenizer is not moved to the device.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(torch_device)

# UNet: predicts the noise residual in latent space.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(torch_device)

# LMS noise scheduler with a scaled-linear beta schedule over 1000 training timesteps.
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
)

## Text to Image

In [None]:
# Output resolution: Stable Diffusion's default 512 x 512.
height = width = 512
num_inference_steps = 20   # denoising steps per image
guidance_scale = 7.5       # classifier-free guidance strength
batch_size = 1             # images generated per call

In [None]:
def text_emb_to_image(text_embeddings, seed):
    """Generate one image from a conditional text embedding and a noise seed.

    Runs the Stable Diffusion denoising loop with classifier-free guidance.
    It uses the module-level ``tokenizer``, ``text_encoder``, ``unet``,
    ``scheduler`` and ``vae``, and the settings ``height``, ``width``,
    ``num_inference_steps``, ``guidance_scale`` and ``batch_size``.

    Args:
        text_embeddings: conditional embedding on ``torch_device``, e.g. the
            output of ``prompt_to_text_emb``. It is concatenated with the
            unconditional embedding along dim 0, so its remaining dims must
            match that embedding (presumably ``(batch_size, 77, hidden)``;
            TODO confirm).
        seed: seed for the initial latent noise. The same embedding and seed
            reproduce the same image.

    Returns:
        The first generated image as a ``PIL.Image.Image``.
    """
    # Unconditional (empty-prompt) embedding for classifier-free guidance,
    # padded to the same length that prompt_to_text_emb uses.
    uncond_input = tokenizer(
        [""] * batch_size,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    # Stack [uncond, cond] so a single UNet pass predicts both branches.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    scheduler.set_timesteps(num_inference_steps)

    # A private generator makes the noise depend only on `seed`. Unlike
    # torch.manual_seed(seed), it does not reset torch's global RNG.
    generator = torch.Generator().manual_seed(seed)
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(torch_device)
    latents = latents * scheduler.init_noise_sigma

    # Mixed precision is only requested when running on CUDA.
    with autocast("cuda", enabled=torch_device == "cuda"):
        for t in tqdm(scheduler.timesteps):
            # Duplicate the latents so the uncond and cond predictions share one forward pass.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            with torch.no_grad():
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # Classifier-free guidance: move the prediction away from the unconditional one.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Compute the previous noisy sample x_t -> x_{t-1}.
            latents = scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the VAE latent scaling (0.18215), then decode to image space.
    latents = 1 / 0.18215 * latents
    with torch.no_grad():
        image = vae.decode(latents).sample

    # Map decoder output into [0, 1], convert NCHW -> NHWC, then to uint8 pixels.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    return Image.fromarray(images[0])

In [None]:
def prompt_to_text_emb(prompt):
    """Encode a text prompt into CLIP text embeddings.

    The prompt is padded (or truncated) to ``tokenizer.model_max_length``
    tokens, so every embedding has the same sequence length and can be
    crossed over token-by-token.

    Args:
        prompt: prompt string (or list of strings) accepted by ``tokenizer``.

    Returns:
        The first output of ``text_encoder`` for the tokenized prompt, on
        ``torch_device``.
    """
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return text_encoder(text_input.input_ids.to(torch_device))[0]

## Gene

### Setup
Create the `Gene` class. A gene pairs a text embedding with a noise seed and caches the image generated from them.

In [None]:
class Gene:
    """One individual in the population: a text embedding plus a noise seed.

    The image is rendered lazily on the first ``get_image()`` call and cached.
    """

    def __init__(self, text_emb, seed):
        # Conditional text embedding; moved to `torch_device` only when rendering.
        self.text_embeddings = text_emb
        # Seed for the initial latent noise passed to text_emb_to_image.
        self.seed = seed
        # Cached generated image; None until get_image() renders it.
        self.image = None

    def get_image(self):
        """Return this gene's image, rendering and caching it on the first call."""
        # Identity check: `== None` would dispatch to the cached object's
        # __eq__, which is ambiguous for arrays and overridable by any type.
        if self.image is None:
            self.image = text_emb_to_image(self.text_embeddings.to(torch_device), self.seed)
        return self.image

In [None]:
def cross_over(targets, mutation_rate=0.1, max_seed=1000):
    """Breed one child ``Gene`` from two distinct, randomly chosen parents.

    This is a single-point crossover over the token axis of the text
    embeddings. The child takes tokens ``[0, point)`` from the first parent
    and ``[point, seq_len)`` from the second. ``point`` is drawn uniformly
    from ``[2, seq_len - 2]``, which is ``[2, 75]`` for 77-token embeddings.
    With probability ``mutation_rate`` the child gets a fresh seed from
    ``[0, max_seed]``; otherwise it inherits the first parent's seed.

    Args:
        targets: population to sample parents from. It needs at least two
            genes; ``random.sample`` raises ``ValueError`` otherwise.
        mutation_rate: probability of drawing a new random seed.
        max_seed: inclusive upper bound for a freshly drawn seed.

    Returns:
        The new child ``Gene`` with a float32 CPU embedding. ``targets`` is
        not modified.
    """
    parent_a, parent_b = random.sample(targets, 2)
    # Detached CPU views of the first batch row of each parent.
    emb_a = parent_a.text_embeddings.detach().cpu()[0]
    emb_b = parent_b.text_embeddings.detach().cpu()[0]

    # Derive the sequence length instead of hard-coding 77.
    seq_len = emb_a.shape[0]
    point = random.randint(2, seq_len - 2)
    # torch.cat copies, so the child never aliases its parents' storage.
    child_emb = torch.cat([emb_a[:point], emb_b[point:]]).unsqueeze(0).to(torch.float32)

    if random.random() < mutation_rate:
        seed = random.randint(0, max_seed)
    else:
        seed = parent_a.seed

    return Gene(child_emb, seed)

In [None]:
def cross_over(genes, mutation_rate=0.1, max_seed=1000):
    """Breed one child ``Gene`` from two distinct, randomly chosen parents.

    This cell redefines ``cross_over``. Every caller in this notebook writes
    ``genes.append(cross_over(genes))``, so this function must RETURN the
    child and must not add it to ``genes`` itself. The old in-place version
    returned None, which then got appended to the population.

    This is a single-point crossover over the token axis of the text
    embeddings. The child takes tokens ``[0, point)`` from the first parent
    and ``[point, seq_len)`` from the second, with ``point`` drawn uniformly
    from ``[2, seq_len - 2]``. With probability ``mutation_rate`` the child
    gets a fresh seed from ``[0, max_seed]``; otherwise it inherits the first
    parent's seed.

    Args:
        genes: population to sample parents from. It needs at least two
            genes; ``random.sample`` raises ``ValueError`` otherwise.
        mutation_rate: probability of drawing a new random seed.
        max_seed: inclusive upper bound for a freshly drawn seed.

    Returns:
        The new child ``Gene`` with a float32 CPU embedding. ``genes`` is not
        modified.
    """
    parent_a, parent_b = random.sample(genes, 2)
    # Detached CPU views of the first batch row of each parent.
    emb_a = parent_a.text_embeddings.detach().cpu()[0]
    emb_b = parent_b.text_embeddings.detach().cpu()[0]

    # Derive the sequence length instead of hard-coding 77.
    seq_len = emb_a.shape[0]
    point = random.randint(2, seq_len - 2)
    # torch.cat copies, so the child never aliases its parents' storage.
    child_emb = torch.cat([emb_a[:point], emb_b[point:]]).unsqueeze(0).to(torch.float32)

    if random.random() < mutation_rate:
        seed = random.randint(0, max_seed)
    else:
        seed = parent_a.seed

    return Gene(child_emb, seed)

In [None]:
def display_genes(targets, columns=5):
    """Show every gene's image in a grid, each titled with its list index.

    Args:
        targets: sequence of genes; each must provide ``get_image()``.
        columns: number of images per row (default 5).
    """
    # Ceiling division gives exactly enough rows. The old
    # int(n / columns) + 1 added an empty row whenever n was a multiple of
    # `columns`.
    rows = (len(targets) + columns - 1) // columns

    plt.figure(figsize=(16, 9))
    plt.subplots_adjust(hspace=0.5)

    for index, target in enumerate(targets):
        plt.subplot(rows, columns, index + 1)
        plt.imshow(target.get_image())
        plt.title(index)
        plt.axis('off')

    plt.suptitle('display gene images')

### Initialize

Create the initial population: one gene per starting prompt, each with a random noise seed.

In [None]:
genes = list()  # the population of Gene objects evolved below

In [None]:
# One gene per starting prompt, each paired with a random noise seed.
genes.extend(
    Gene(prompt_to_text_emb(word), random.randint(0, 1000))
    for word in ('adam', 'eve')
)

In [None]:
display_genes(targets=genes)  # grid of the current population

## Breed genes

Cross over two randomly chosen genes to create a child gene and add it to the population.

In [None]:
genes += [cross_over(genes)]  # breed one child and add it to the population

In [None]:
display_genes(targets=genes)  # population including the new child

### Create more genes from new prompts

In [None]:
genes.append(
    Gene(text_emb=prompt_to_text_emb('karaage'), seed=random.randint(0, 1000))
)

In [None]:
genes.append(
    Gene(text_emb=prompt_to_text_emb('apple'), seed=random.randint(0, 1000))
)

In [None]:
genes.append(
    Gene(text_emb=prompt_to_text_emb('rock'), seed=random.randint(0, 1000))
)

Set the number of crossover trials.

In [None]:
trial_numb: int = 13  # how many crossover rounds the next cell runs

In [None]:
# Each round breeds one child from the current population, including earlier children.
for trial in range(trial_numb):
    genes += [cross_over(genes)]

In [None]:
display_genes(targets=genes)  # population after all crossover rounds

Delete the first three genes from the population.

In [None]:
genes[:3] = []  # drop the three oldest genes in place

Inspect one gene: print its seed and text embedding, then display its image.

In [None]:
gene_numb: int = 24  # index into `genes` of the gene to inspect; must be < len(genes)

In [None]:
# Seed on one line, then the embedding tensor (same output as two separate prints).
print(genes[gene_numb].seed, genes[gene_numb].text_embeddings, sep="\n")

In [None]:
# Show the selected gene's image titled with its index and seed. plt.show()
# renders the figure without echoing the AxesImage repr.
plt.imshow(genes[gene_numb].get_image())
plt.title(f"gene {gene_numb} (seed {genes[gene_numb].seed})")
plt.axis('off')
plt.show()