!pip install diffusers
!pip install fastprogress

In [None]:
import torch
from math import floor, ceil, sqrt
from PIL import Image, ImageDraw, ImageFont
import textwrap
from fastprogress import progress_bar, master_bar
from cool_pipeline import CoolStableDiffusionPipeline
import random
import os
import pathlib
import numpy as np
from pathlib import Path
import gradio as gr
from IPython.display import display

# Close any Gradio apps still running from earlier executions of this notebook,
# so EnhancedGenerator.inpaint_gui can relaunch on its fixed port (3123).
gr.close_all()

In [None]:
def image_grid(imgs, max_cols=5):
    """Tile a sequence of PIL images into one roughly square grid image.

    The layout starts with ``floor(sqrt(n))`` rows and as many columns as
    needed to fit every image; if that exceeds ``max_cols`` columns, the
    column count is capped and more rows are added instead. Images are pasted
    row-major, and every cell is sized after the first image.

    Args:
        imgs: Non-empty sequence of PIL images (ideally all the same size).
        max_cols: Maximum number of columns in the grid (default 5).

    Returns:
        A new RGB ``PIL.Image`` containing the tiled images.

    Raises:
        ValueError: If ``imgs`` is empty or ``max_cols`` is less than 1.
    """
    num = len(imgs)
    if num == 0:
        # floor(sqrt(0)) == 0 would otherwise surface as a ZeroDivisionError.
        raise ValueError("image_grid needs at least one image")
    if max_cols < 1:
        raise ValueError(f"max_cols must be >= 1, got {max_cols}")

    rows = floor(sqrt(num))
    cols = ceil(num / rows)
    if cols > max_cols:
        cols = max_cols
        rows = ceil(num / cols)

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

def random_string(n):
    """Return a random string of ``n`` ASCII letters and digits.

    Used as a short suffix to keep saved filenames unique. It relies on the
    non-cryptographic ``random`` module, so it must not be used for secrets.
    """
    # The original alphabet was missing the letter 'z'.
    alphabet = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    return "".join(random.choices(alphabet, k=n))

def get_image_and_mask(img):
    """Split an RGBA image into its RGB channels and a 3-channel alpha mask.

    Args:
        img: Anything ``np.array`` converts to an (H, W, 4) array, e.g. an
            RGBA ``PIL.Image``.

    Returns:
        ``(rgb, mask)``: uint8 arrays of shape (H, W, 3). ``mask`` is the alpha
        channel repeated across three channels.

    Raises:
        ValueError: If the input does not have shape (H, W, 4).
    """
    a = np.array(img).astype(np.uint8)
    # Raise instead of assert so the check survives `python -O`; also reject
    # 2-D (grayscale) input cleanly instead of failing with an IndexError.
    if a.ndim != 3 or a.shape[2] != 4:
        raise ValueError(f"expected an RGBA image of shape (H, W, 4), got {a.shape}")
    rgb, alpha = a[:, :, :3], a[:, :, 3:]
    mask = alpha.repeat(3, axis=2)
    return rgb, mask

class EnhancedGenerator:
    """Wrapper around a diffusion pipeline that keeps a history of results.

    Every generation is appended to ``self.saved``. Each entry is the
    pipeline's output dict (read here via its ``'sample'`` and ``'latents'``
    keys), augmented with ``'prompt'`` and ``'index'``. Later calls refer to
    earlier results by index: variants, latent interpolation, re-prompting,
    img2img, inpainting, and saving/loading to ``savedir``.
    """

    def __init__(self, pipe, height=512, width=768, savedir="saved"):
        self.pipe = pipe
        self.h, self.w = height, width
        # History of results; an entry's position here equals its 'index' key
        # (as assigned at generation/load time).
        self.saved = []

        self.savedir = Path(savedir)
        (self.savedir / "pth").mkdir(exist_ok=True, parents=True)

    def gen(self, prompt, lat=None, img=None, mb=None, **kwargs):
        """Run the pipeline once, record the result in the history, return it.

        Args:
            prompt: Text prompt.
            lat: Optional starting latents (forwarded as ``latents``).
            img: Optional initial image (forwarded as ``initial_image``).
            mb: Optional fastprogress ``master_bar`` forwarded to the pipeline.
            **kwargs: Extra pipeline arguments (e.g. ``strength``).

        Returns:
            The pipeline's output dict with ``'prompt'`` and ``'index'`` added.
        """
        with torch.autocast("cuda"):
            P = self.pipe(prompt, latents=lat, height=self.h, width=self.w,
                          initial_image=img, masterbar=mb, **kwargs)
        P['prompt'] = prompt
        P['index'] = len(self.saved)
        self.saved.append(P)
        return P

    def with_prompt(self, img, with_idx=True):
        """Return a copy of a result's first image annotated with its prompt.

        The prompt is wrapped at 50 characters and drawn in red at the top
        left. When ``with_idx`` is true, the history index is drawn in yellow
        near the bottom centre.

        Args:
            img: A result dict (needs ``'prompt'``, ``'index'``, ``'sample'``).
            with_idx: Whether to draw the history index (previously ignored;
                the index was always drawn, which is still the default).
        """
        prompt = img['prompt']
        idx = img['index']
        img = img['sample'][0].copy()
        w, h = img.width, img.height
        drawer = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype("arial.ttf", 18)
        except OSError:
            # arial.ttf is usually missing outside Windows; use PIL's built-in font.
            font = ImageFont.load_default()

        for i, line in enumerate(textwrap.wrap(prompt, width=50)):
            drawer.text((4, 4 + i * 21), line, font=font, fill=(255, 0, 0, 200))

        if with_idx:
            drawer.text((w // 2, h - 40), str(idx), font=font, fill=(255, 255, 0))
        return img

    def generate_from_scratch(self, prompt, num=6):
        """Generate ``num`` images for ``prompt`` and return a captioned grid."""
        mb = master_bar(range(num))
        X = [self.gen(prompt, mb=mb) for _ in mb]
        return image_grid([self.with_prompt(x) for x in X])

    def generate_variants(self, i, noise=0.4, num=6, prompt=None):
        """Generate ``num`` variants of result ``i`` by perturbing its latents.

        Each variant starts from ``lat + noise * randn_like(lat)``, re-standardised
        to zero mean and unit std. The returned grid shows the original first,
        then the variants. ``prompt`` defaults to result ``i``'s prompt.
        """
        I = self.saved[i]
        lat = I['latents']
        if prompt is None:
            prompt = I['prompt']
        lats = [lat + torch.randn_like(lat) * noise for _ in range(num)]
        lats = [(l - l.mean()) / l.std() for l in lats]
        mb = master_bar(lats)
        X = [self.gen(prompt, lat=l, mb=mb) for l in mb]
        return image_grid([self.with_prompt(x) for x in [I] + X])

    def generate_with_prompts(self, i, prompts):
        """Render result ``i``'s latents once per prompt; return a captioned grid."""
        lat = self.saved[i]['latents']
        mb = master_bar(prompts)
        X = [self.gen(p, lat=lat, mb=mb) for p in mb]
        return image_grid([self.with_prompt(x) for x in X])

    def generate_suffixes(self, i, suffixes):
        """Like ``generate_with_prompts`` using ``"<prompt of i>, <suffix>"`` prompts."""
        prompt = self.saved[i]['prompt']
        prompts = [prompt + ", " + s for s in suffixes]
        return self.generate_with_prompts(i, prompts)

    def interpolate(self, i, j, num=20, prompt=None):
        """Render ``num`` images along the line between the latents of ``i`` and ``j``.

        Interpolation weights are ``linspace(0, 1, num)`` (endpoints included);
        each interpolated latent is re-standardised to zero mean and unit std.
        ``prompt`` defaults to result ``i``'s prompt.
        """
        if prompt is None:
            prompt = self.saved[i]['prompt']
        Li, Lj = self.saved[i]['latents'], self.saved[j]['latents']
        L = [torch.lerp(Li, Lj, p.item()) for p in torch.linspace(0, 1, num)]
        L = [(l - l.mean()) / l.std() for l in L]
        mb = master_bar(L)
        X = [self.gen(prompt, lat=l, mb=mb) for l in mb]
        return image_grid([self.with_prompt(x) for x in X])

    def modify_image(self, img, prompt=None, num=6, strength=0.7, num_steps=1):
        """img2img: generate ``num`` images starting from ``img``.

        Args:
            img: A file path, a history index, or an image passed to the pipeline.
            prompt: Text prompt. Defaults to the saved prompt only when ``img``
                is a history index. NOTE(review): otherwise it stays ``None``;
                confirm the pipeline accepts that.
            num: Number of images to generate.
            strength: Forwarded to the pipeline as ``strength``.
            num_steps: Number of feedback rounds (not diffusion steps). Each
                round after the first re-feeds every output as the new initial image.
        """
        if isinstance(img, (str, os.PathLike)):
            img = Image.open(img)
        if isinstance(img, int):
            if prompt is None:
                prompt = self.saved[img]['prompt']
            img = self.saved[img]['sample'][0]
        X = [self.gen(prompt, img=img, strength=strength) for _ in progress_bar(range(num))]
        for i in range(1, num_steps):
            print(f"Doing step {i+1}/{num_steps}")
            X = [self.gen(prompt, img=x['sample'][0], strength=strength) for x in progress_bar(X)]
        return image_grid([self.with_prompt(x) for x in X])

    def inpaint_gui(self, img=None, num_steps=80):
        """Launch a Gradio app (port 3123) to sketch an inpainting mask on ``img``.

        Submitting runs ``inpaint_with_mask`` and displays the resulting grid in
        the notebook, then clears and closes the app.

        Args:
            img: ``None``, a history index (also pre-fills the prompt), a file
                path, or an image.
            num_steps: Forwarded to ``inpaint_with_mask``.
        """
        prompt = ""
        if isinstance(img, int):
            prompt = self.saved[img]['prompt']
            img = self.saved[img]['sample'][0]
        if isinstance(img, (str, os.PathLike)):
            img = Image.open(img)
        with gr.Blocks() as block:

            def _inpaint_gui_out(imgmask, num, prompt):
                # The slider may deliver a float; range() in inpaint_with_mask needs an int.
                out = self.inpaint_with_mask(prompt, imgmask['image'], imgmask['mask'],
                                             int(num), num_steps=num_steps)
                display(out)
                block.clear()
                block.close()
                return out

            with gr.Column():
                txt = gr.Textbox(label="Prompt", value=prompt)
                inp = gr.Image(value=img, tool='sketch')
                sld = gr.Slider(minimum=1, maximum=20, value=1, step=1, label="How many to generate")

                btn = gr.Button("Submit")
                btn.click(fn=_inpaint_gui_out, inputs=[inp, sld, txt], outputs=None)
        block.launch(server_port=3123)

    def inpaint_with_mask(self, prompt, img, mask, num=1, num_steps=80):
        """Inpaint ``img`` under ``mask`` ``num`` times; return a grid of the new results.

        ``img``/``mask`` may be numpy arrays (converted with ``Image.fromarray``)
        or images passed straight to ``self.pipe.inpaint``. Each result is
        appended to the history.
        """
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        if isinstance(mask, np.ndarray):
            mask = Image.fromarray(mask)

        start = len(self.saved)
        mb = master_bar(range(num))
        for _ in mb:
            with torch.autocast("cuda"):
                P = self.pipe.inpaint(prompt, init_image=img, mask_image=mask,
                                      num_inference_steps=num_steps, masterbar=mb)
            P['prompt'] = prompt
            P['index'] = len(self.saved)
            self.saved.append(P)

        # Slice from the recorded start: `self[-num:]` returned the whole history for num == 0.
        return self[start:]

    def get_images(self):
        """Return the first image of every result in the history."""
        return [x['sample'][0] for x in self.saved]

    def view_img(self, i, with_prompt=False):
        """Return result ``i``'s image, captioned with its prompt if requested."""
        if with_prompt:
            return self.with_prompt(self.saved[i])
        return self.saved[i]['sample'][0]

    def __getitem__(self, i):
        """``G[int]`` -> raw image; ``G[slice]`` -> captioned grid of those results."""
        if isinstance(i, int):
            return self.view_img(i)
        return image_grid([self.with_prompt(a) for a in self.saved[i]])

    def __len__(self):
        return len(self.saved)

    @staticmethod
    def _safe_stem(prompt, max_len=150):
        """Turn ``prompt`` into a filename stem: unsafe chars become '_', then truncate."""
        # Path separators, characters illegal on Windows, and control characters
        # (e.g. newlines) would break or redirect the write.
        cleaned = "".join("_" if c in '<>:"/\\|?*' or ord(c) < 32 else c for c in prompt)
        # Keep well under the usual 255-byte filename limit.
        return cleaned[:max_len]

    def save(self, i):
        """Save result(s) as ``<savedir>/pth/<name>.pth`` plus ``<savedir>/<name>.png``.

        Args:
            i: A history index, or a list/tuple/range of indices.

        Raises:
            TypeError: If ``i`` is not an int or a list/tuple/range of ints.
        """
        if isinstance(i, (list, tuple, range)):
            for j in i:
                self.save(j)
        elif isinstance(i, int):
            I = self.saved[i]
            fname = f"{self._safe_stem(I['prompt'])}_{i}_{random_string(4)}"
            torch.save(I, self.savedir / "pth" / f"{fname}.pth")
            I['sample'][0].save(self.savedir / f"{fname}.png")
        else:
            # `raise "<str>"` is itself a TypeError in Python 3 and loses the message.
            raise TypeError(f"Nope, only lists and ints, got {type(i).__name__}")

    def load(self, I):
        """Append a result to the history, re-assigning its ``'index'``.

        Args:
            I: A result dict, or a path (str or path-like) to a ``.pth`` file
               produced by ``save``.
        """
        # isinstance, not `type(I) == Path`: Path() returns PosixPath/WindowsPath.
        if isinstance(I, (str, os.PathLike)):
            # torch.load unpickles the file, so only load files you trust.
            # NOTE(review): torch versions that default to weights_only=True may
            # reject pickled PIL images in 'sample'; confirm and, for trusted
            # files, pass weights_only=False if needed.
            I = torch.load(I)
        I['index'] = len(self.saved)
        self.saved.append(I)

    def load_all(self, keyword=None):
        """Load every ``*.pth`` file in ``<savedir>/pth``, in filename order.

        Args:
            keyword: If given, only files whose name contains it
                (case-insensitive) are loaded.
        """
        kw = None if keyword is None else keyword.lower()
        with os.scandir(self.savedir / "pth") as it:
            entries = sorted(it, key=lambda e: e.name)
        for entry in entries:
            if not (entry.is_file() and entry.name.endswith('.pth')):
                continue
            # Match the file name only; matching the full path made every file
            # match when the keyword appeared in the save directory's path.
            if kw is None or kw in entry.name.lower():
                self.load(entry.path)

In [None]:
# Read the model-hub access token from a local file named TOKEN.
with open("TOKEN", encoding="utf-8") as tok:
    # readline() keeps the trailing newline, which would otherwise be sent as
    # part of the token; strip surrounding whitespace.
    TOKEN = tok.readline().strip()
# Don't echo the secret itself: notebook outputs are often saved and shared.
print(f"Loaded auth token ({len(TOKEN)} characters)")

In [None]:
# Load Stable Diffusion v1.4 from the "fp16" revision in half precision,
# authenticating with the token read above.
pipe = CoolStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=TOKEN,
)

# Alternative without the revision/dtype overrides:
#pipe = CoolStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# Move the model to the GPU; the trailing semicolon suppresses the cell's output.
pipe.to("cuda");

In [None]:
# Wrap the pipeline with the default 512x768 output size and "saved" directory.
G = EnhancedGenerator(pipe)

In [None]:
# Smoke test: generate one image for the prompt; the returned captioned grid is the cell's output.
G.generate_from_scratch("Fantasy", num=1)