In [None]:
# TODO: replace this step with a `git clone` of the repository. It is skipped
# here because the author reuses an existing checkout instead of re-downloading.
# Switch the notebook's working directory into the repo so that relative paths
# used in later cells (e.g. 'data.json') resolve against it.
%cd kcg-ml-sd1p4/

/home/knox/Workspace/Work/kcg-ml-sd1p4


# Now that we have the data in a JSON file, we're going to move to generating images!

In [None]:
import os
from pathlib import Path
from typing import List

import torch
from torch import nn
from transformers import CLIPTokenizer, CLIPTextModel

from labml import lab, monit
from stable_diffusion.latent_diffusion import LatentDiffusion
from stable_diffusion.sampler.ddim import DDIMSampler
from stable_diffusion.sampler.ddpm import DDPMSampler
from stable_diffusion.util import load_model, save_images, set_seed


class Txt2Img:
    """
    ### Text to image class

    Bundles a latent diffusion model with a sampler. Calling an instance
    generates a batch of images for one text prompt and hands them to
    `save_images`.
    """
    model: LatentDiffusion

    def __init__(self, *,
                 checkpoint_path: Path,
                 sampler_name: str,
                 n_steps: int = 50,
                 ddim_eta: float = 0.0,
                 ):
        """
        :param checkpoint_path: is the checkpoint path handed to `load_model`
        :param sampler_name: is the name of the sampler, `'ddim'` or `'ddpm'`
        :param n_steps: is the number of sampling steps (only passed to the DDIM sampler)
        :param ddim_eta: is the DDIM eta constant (only passed to the DDIM sampler)
        :raises ValueError: if `sampler_name` is not a supported sampler name
        """
        self.load(checkpoint_path, sampler_name, n_steps, ddim_eta)

    def load(self,
             checkpoint_path: Path,
             sampler_name: str,
             n_steps: int = 50,
             ddim_eta: float = 0.0,
             ):
        """
        Load the model, move it to the compute device and build the sampler.

        :param checkpoint_path: is the checkpoint path handed to `load_model`
        :param sampler_name: is the name of the sampler, `'ddim'` or `'ddpm'`
        :param n_steps: is the number of sampling steps (only passed to the DDIM sampler)
        :param ddim_eta: is the DDIM eta constant (only passed to the DDIM sampler)
        :raises ValueError: if `sampler_name` is not a supported sampler name
        """
        # Reject unknown sampler names up front: otherwise `self.sampler` would be
        # left unset and the mistake would only surface later as an AttributeError
        # in `__call__`, after the (expensive) checkpoint load.
        if sampler_name not in ('ddim', 'ddpm'):
            raise ValueError(
                f"Unknown sampler_name {sampler_name!r}; expected 'ddim' or 'ddpm'"
            )

        self.model = load_model(checkpoint_path)
        # Use the first CUDA device when available, otherwise fall back to the CPU
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        self.model.to(self.device)

        if sampler_name == 'ddim':
            self.sampler = DDIMSampler(self.model,
                                       n_steps=n_steps,
                                       ddim_eta=ddim_eta)
        else:
            # 'ddpm': `n_steps` and `ddim_eta` are not forwarded to this sampler
            self.sampler = DDPMSampler(self.model)

    def unload(self):
        """
        Drop the references to the model and the sampler held by this object.
        Call `load` again before generating more images.
        """
        self.model = None
        self.sampler = None

    @torch.no_grad()
    def __call__(self, *,
                 dest_path: str,
                 batch_size: int = 3,
                 prompt: str,
                 h: int = 512, w: int = 512,
                 uncond_scale: float = 7.5,
                 noise_vector=None,
                 ):
        r"""
        Generate `batch_size` images for `prompt` and save them. Returns nothing.

        :param dest_path: is the path to store the generated images
        :param batch_size: is the number of images to generate in a batch
        :param prompt: is the prompt to generate images with
        :param h: is the height of the image
        :param w: is the width of the image
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = s\epsilon_\text{cond}(x_t, c) + (1 - s)\epsilon_\text{cond}(x_t, c_u)$
        :param noise_vector: is the noise vector for stable diffusion; it is forwarded
            unchanged to the sampler's `sample` method
        """
        # Number of channels in the latent image
        c = 4
        # Image to latent space resolution reduction
        f = 8

        # Make a batch of prompts
        prompts = batch_size * [prompt]

        # AMP auto casting
        with torch.cuda.amp.autocast():
            # If the unconditional scale is not $1$, get the embeddings for empty
            # prompts (no conditioning).
            if uncond_scale != 1.0:
                un_cond = self.model.get_text_conditioning(batch_size * [""])
            else:
                un_cond = None
            # Get the prompt embeddings
            cond = self.model.get_text_conditioning(prompts)
            # [Sample in the latent space](../sampler/index.html).
            # The requested shape is `[batch_size, c, h // f, w // f]`
            x = self.sampler.sample(cond=cond,
                                    shape=[batch_size, c, h // f, w // f],
                                    uncond_scale=uncond_scale,
                                    uncond_cond=un_cond,
                                    noise_vector=noise_vector)

            # Decode the image from the autoencoder
            images = self.model.autoencoder_decode(x)
            # Save images ('txt_' is presumably a file-name prefix; see `save_images`)
            save_images(images, dest_path, 'txt_')

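# Pipeline stage functions. NOTE: these are defined at module level (not as
# methods), so the object providing `.model` / `.sampler` must be passed
# explicitly as the first argument `self`.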
@torch.no_grad()
def generate_text_embeddings(self, prompt, batch_size=4, uncond_scale=7.5):
    """
    Compute the text conditioning for a batch that repeats a single prompt.

    :param self: is an object whose `model` exposes `get_text_conditioning`
    :param prompt: is the prompt to generate images with
    :param batch_size: is the number of times the prompt is repeated in the batch
    :param uncond_scale: is the unconditional guidance scale; when it equals `1.0`
        no empty-prompt embeddings are computed
    :return: a tuple `(cond, un_cond)` where `un_cond` is `None` if `uncond_scale == 1.0`
    """
    # Repeat the prompt once per batch element
    prompt_batch = [prompt] * batch_size
    encode = self.model.get_text_conditioning

    # Empty-prompt embeddings are only needed when guidance is active; they are
    # computed before the prompt embeddings.
    un_cond = encode([""] * batch_size) if uncond_scale != 1.0 else None
    cond = encode(prompt_batch)

    return cond, un_cond

@torch.no_grad()
def generate_latent_space(self, cond, un_cond, batch_size=4, uncond_scale=7.5, h=512, w=512, noise_vector=None):
    """
    Run the sampler to produce latents for the given conditioning.

    :param self: is an object whose `sampler` exposes `sample`
    :param cond: is the prompt conditioning passed to the sampler
    :param un_cond: is the empty-prompt conditioning (may be `None`)
    :param batch_size: is the number of latents to sample
    :param uncond_scale: is the unconditional guidance scale
    :param h: is the height of the image
    :param w: is the width of the image
    :param noise_vector: is forwarded unchanged to the sampler
    :return: the value returned by `self.sampler.sample`
    """
    # Four latent channels; the latent grid is 8x smaller than the image
    latent_channels, downscale = 4, 8
    latent_shape = [batch_size, latent_channels, h // downscale, w // downscale]

    # AMP auto casting around the sampling call
    with torch.cuda.amp.autocast():
        return self.sampler.sample(
            cond=cond,
            shape=latent_shape,
            uncond_scale=uncond_scale,
            uncond_cond=un_cond,
            noise_vector=noise_vector,
        )

@torch.no_grad()
def generate_image(self, x):
    """
    Decode latents into images with the model's autoencoder.

    :param self: is an object whose `model` exposes `autoencoder_decode`
    :param x: is the latent input, forwarded unchanged to the decoder
    :return: the value returned by `self.model.autoencoder_decode(x)`
    """
    # AMP auto casting around the decode call
    with torch.cuda.amp.autocast():
        return self.model.autoencoder_decode(x)

class CLIPTextEmbedder(nn.Module):
    """
    ## CLIP Text Embedder

    Holds a CLIP tokenizer and text transformer and returns the transformer's
    last hidden state for either pre-tokenized IDs or raw prompt strings.
    """
    def __init__(self, version: str = "openai/clip-vit-large-patch14", max_length: int = 77):
        """
        :param version: is the model version
        :param max_length: is the max length of the tokenized prompt
        """
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version).eval()

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Move the transformer to the correct device
        self.transformer = self.transformer.to(self.device)

        self.max_length = max_length

    def forward(self, input_ids=None, prompts=None, **kwargs):
        """
        Exactly one of `input_ids` or `prompts` must be given.

        :param input_ids: the input token IDs of shape [batch_size, sequence_length];
            passed to the transformer as-is (the caller picks the device)
        :param prompts: a list of prompt strings; they are tokenized with
            `self.tokenizer`, padded/truncated to `self.max_length`, and moved to
            `self.device` before being embedded
        :param kwargs: extra keyword arguments forwarded to the transformer
        :return: the transformer's `last_hidden_state`
        :raises TypeError: if neither or both of `input_ids` and `prompts` are given
        """
        if (input_ids is None) == (prompts is None):
            raise TypeError("forward() expects exactly one of `input_ids` or `prompts`")

        if input_ids is None:
            # Tokenize here so callers can embed raw text; this is what the stored
            # tokenizer and `max_length` are for.
            batch = self.tokenizer(prompts,
                                   truncation=True,
                                   max_length=self.max_length,
                                   padding="max_length",
                                   return_tensors="pt")
            # The tokenizer output is not on the model's device yet
            input_ids = batch["input_ids"].to(self.device)

        return self.transformer(input_ids=input_ids, **kwargs).last_hidden_state

# Example usage: `forward` takes token IDs, so tokenize the prompt first and
# then embed the resulting IDs.
x = CLIPTextEmbedder()
tokens = x.tokenizer(["space marines"],
                     truncation=True,
                     max_length=x.max_length,
                     padding="max_length",
                     return_tensors="pt")
# Inference only: no gradients needed. The IDs must live on the model's device.
with torch.no_grad():
    out = x(tokens["input_ids"].to(x.device))
print(out.shape)
print(out)

In [None]:
# Inputs for this demo run
# NOTE(review): the checkpoint path is absolute from the filesystem root, and the
# destination looks like a placeholder; confirm both before running.
checkpoint_file = '/kcg-ml-sd1p4/input/models/sd-v1-4.ckpt'
output_target = '/path/to/save/image.png'

# Build the text-to-image pipeline with the DDPM sampler
txt2img = Txt2Img(sampler_name='ddpm', checkpoint_path=checkpoint_file)

prompt = "Generate a beautiful landscape"

# Example noise vector, created on the CPU and then moved to the model's device.
# NOTE(review): Txt2Img.__call__ requests latents of shape
# [batch_size, 4, h // 8, w // 8]; a (1, 256) tensor does not match that.
# How the sampler consumes `noise_vector` is not visible here, so confirm.
noise_vector = torch.randn(1, 256).to(txt2img.device)

# Generate and save a single image for the prompt
txt2img(prompt=prompt,
        batch_size=1,
        dest_path=output_target,
        noise_vector=noise_vector)

print("Image generated successfully!")


In [None]:
import json
import random

def generate_prompt(artist_name=None, artist_id=None, data_path='data.json'):
    """
    Build a prompt of the form ``"<pre_prompt> <artist_name>"``.

    The prefix is always read from the ``'pre_prompt'`` key of the JSON file.
    When no artist is given, one is picked at random from the keys of the
    file's ``'artists'`` mapping.

    :param artist_name: artist to use; ``None`` selects a random artist from the file
    :param artist_id: accepted for backward compatibility; it does not affect the prompt
    :param data_path: path of the JSON file holding ``'pre_prompt'`` and ``'artists'``
    :return: the prompt string
    :raises IndexError: if a random artist is requested but ``'artists'`` is empty
    """
    # Load the file unconditionally: the prefix is needed even when the caller
    # supplies the artist. Previously `prompt_prefix` was only bound in the
    # random-artist branch, so passing `artist_name` raised UnboundLocalError.
    with open(data_path, encoding='utf-8') as file:
        data = json.load(file)

    prompt_prefix = data['pre_prompt']

    if artist_name is None:
        artist_name = random.choice(list(data['artists'].keys()))

    return f'{prompt_prefix} {artist_name}'

    