In [1]:
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_layers2img import StableDiffusionPipelineLayers2ImageV1
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Text2Image():
    model_to_name = {
        "SDV15": "runwayml/stable-diffusion-v1-5",
        "stablediffusion": "runwayml/stable-diffusion-v1-5",
        "stable": "runwayml/stable-diffusion-v1-5",
        "RVV13": "realisticVisionV13",
        "realistic": "realisticVisionV13",
        "realisticvision": "realisticVisionV13",
        "real": "realisticVisionV13",
        "dream": "Lykon/DreamShaper",
        "dreamshaper": "Lykon/DreamShaper",
    }
    
    def __init__(self, model="SDV15", scheduler="default", imgperprompt=4, steps=32, device=None):
        super().__init__()
        self.scheduler_name = scheduler
        if device is None:
            self.device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
        else:
            self.device = device
        self.load_pipe(model, scheduler=self.scheduler_name, device=self.device)
        self.imgperprompt = imgperprompt
        self.steps = steps
    
    def load_pipe(self, model="SDV15", scheduler="original", device=None):
        self.pipe = StableDiffusionPipeline.from_pretrained(self.model_to_name[model])
        self.pipe = self.pipe.to(device)
        if scheduler == "original" or scheduler is None or scheduler == "default":
            pass
        elif scheduler == "dpm":
            oldsched = self.pipe.scheduler
            self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
            assert torch.allclose(oldsched.alphas_cumprod, self.pipe.scheduler.alphas_cumprod)
        
    def image_grid(self, imgs, rows, cols):
        assert len(imgs) == rows * cols

        w, h = imgs[0].size
        grid = Image.new('RGB', size=(cols * w, rows * h))
        grid_w, grid_h = grid.size

        for i, img in enumerate(imgs):
            grid.paste(img, box=(i % cols * w, i // cols * h))
        return grid
    
    def generate(self, prompt, negprompt="", imgperprompt=None, steps=None, seed=-1, cfg=7.5):
        imgperprompt = self.imgperprompt if imgperprompt is None else imgperprompt
        steps = self.steps if steps is None else steps
        generator = torch.Generator(device=self.device)
        generator.manual_seed(seed)
        with torch.autocast("cuda"):
            images = self.pipe(prompt, negative_prompt=negprompt, num_inference_steps=steps, guidance_scale=cfg, eta=0., 
                          generator=generator, num_images_per_prompt=imgperprompt, output_type="pil").images
        return images
        

In [3]:
tti = Text2Image("dream", scheduler="dpm")

Fetching 15 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 12117.60it/s]
The config attributes {'scaling_factor': 0.18215} were passed to AutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.
The config attributes {'class_embed_type': None, 'mid_block_type': 'UNetMidBlock2DCrossAttn', 'resnet_time_scale_shift': 'default', 'upcast_attention': False} were passed to UNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.


In [4]:
prompt = "a photo of one red tiger eye, close-up, detailed fur"
negprompt = "disfigured, deformed iris, nude, nsfw, watermark, istock"

images = tti.generate(prompt, negprompt)
display(image_grid(images, 1, len(images)))



100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:10<00:00,  3.10it/s]


NameError: name 'image_grid' is not defined