## Install Dependencies 

In [None]:
!wget https://r2-public-worker.drysys.workers.dev/sd-v1-4-full-ema.ckpt
!git clone https://github.com/CompVis/stable-diffusion.git
%cd stable-diffusion
!wget https://raw.githubusercontent.com/justinpinkney/stable-diffusion/main/requirements.txt
!pip install -r requirements.txt
!pip install --upgrade pytorch-lightning
!apt-get update -y && apt-get install libgl1 -y && apt-get install libglib2.0-0 -y

In [None]:
# CLIP / BLIP Python dependencies (timm and fairscale pinned to specific versions).
!pip3 install ftfy regex tqdm timm==0.4.12 fairscale==0.4.4
!pip3 install git+https://github.com/openai/CLIP.git
# clip-interrogator supplies the candidate word lists (data/*.txt) and BLIP supplies
# models.blip. The setup cell expects both in the notebook's starting directory
# (it runs `%cd BLIP` and reads "../clip-interrogator/data/").
!git clone https://github.com/pharmapsychotic/clip-interrogator.git
!git clone https://github.com/salesforce/BLIP

In [34]:
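# Manual workaround required before loading BLIP (run these in a terminal):
# apt install vim -y
# vim /opt/conda/lib/python3.7/site-packages/transformers/generation_utils.py
# Delete line 1146 (the kwargs validation). The path and line number depend on the
# installed Python and transformers versions, so locate the validation call rather than
# trusting the number.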

## !! Restart the notebook kernel here (after the installs and the patch above) !!

## Load Models

In [1]:
#@title Setup
%cd BLIP

import clip
import gc
import numpy as np
import os
import pandas as pd
import requests
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from IPython.display import display
from PIL import Image
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from models.blip import blip_decoder

# Run on the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# BLIP captioning model: base ViT backbone, fed square inputs of this side length.
blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
blip_model = blip_decoder(
    pretrained=blip_model_url,
    image_size=blip_image_eval_size,
    vit='base',
)
blip_model.eval()  # switch to evaluation mode before inference
blip_model = blip_model.to(device)

def generate_caption(pil_image):
    """Caption a PIL image with the global BLIP model.

    The image is resized (bicubic) to ``blip_image_eval_size`` x ``blip_image_eval_size``,
    converted to a tensor, normalized with the mean/std constants below, moved to
    ``device`` as a batch of one, and decoded with deterministic 3-beam search
    (``min_length=5``, ``max_length=20``).

    Returns the first caption produced by ``blip_model.generate``.
    """
    preprocess = transforms.Compose([
        transforms.Resize(
            (blip_image_eval_size, blip_image_eval_size),
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        ),
    ])
    batch = preprocess(pil_image).unsqueeze(0)  # add a leading batch dimension of size 1
    batch = batch.to(device)

    with torch.no_grad():
        captions = blip_model.generate(
            batch, sample=False, num_beams=3, max_length=20, min_length=5
        )
    return captions[0]

def load_list(filename):
    """Read a UTF-8 text file and return its non-blank lines, stripped of whitespace.

    Undecodable bytes are replaced instead of raising. Blank or whitespace-only lines
    are skipped so they never turn into empty candidate labels for CLIP ranking.
    """
    with open(filename, 'r', encoding='utf-8', errors='replace') as f:
        return [stripped for stripped in (line.strip() for line in f) if stripped]

def rank(model, image_features, text_array, top_count=1):
    """Rank candidate texts by CLIP similarity to the given image embeddings.

    Args:
        model: CLIP model exposing ``encode_text``.
        image_features: 2-D tensor with one image embedding per row (``interrogate``
            passes a single L2-normalized row).
        text_array: candidate strings.
        top_count: number of best candidates to return, clamped to ``len(text_array)``.

    Returns:
        List of ``(text, confidence_percent)`` tuples, best first. Confidence is the
        softmax over candidates of 100 * (image row . normalized text embedding),
        averaged over image rows, times 100. Empty list if there are no candidates.
    """
    text_array = list(text_array)
    top_count = min(top_count, len(text_array))
    if top_count <= 0:
        return []

    # Put the tokens on the same device as the image features rather than assuming CUDA,
    # so ranking also works on CPU.
    text_tokens = clip.tokenize(text_array).to(image_features.device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Per-image softmax over candidates, averaged across image rows (vectorized form of
    # summing one row at a time and dividing by the row count).
    logits = 100.0 * image_features @ text_features.T
    similarity = logits.softmax(dim=-1).mean(dim=0, keepdim=True)

    top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
    return [
        (text_array[int(label)], float(prob) * 100)
        for prob, label in zip(top_probs[0], top_labels[0])
    ]

def interrogate(image, models, show_table=False):
    """Build a text prompt describing ``image`` with BLIP plus one or more CLIP models.

    A BLIP caption is generated first. Each CLIP model named in ``models`` ranks the
    image against the global candidate lists (``mediums``, ``artists``,
    ``trending_list``, ``movements``, ``flavors``). For each category, the ranking with
    the highest summed confidence across models is kept and joined into the prompt.

    Args:
        image: PIL image to describe.
        models: CLIP model names passed to ``clip.load`` (e.g. "ViT-L/14").
        show_table: if True, display a per-model table of the rankings.

    Returns:
        ``(caption, prompt)``. When ``models`` is empty both elements are the BLIP
        caption, so callers can always unpack two values.
    """
    caption = generate_caption(image)
    if len(models) == 0:
        print(f"\n\n{caption}")
        # Return a pair here too; returning None broke `a, b = interrogate(...)`.
        return caption, caption

    columns = ["Model", "Medium", "Artist", "Trending", "Movement", "Flavors"]
    # Best ranking seen so far per category; independent lists, not `[x] * 5` aliases.
    bests = [[('', 0)] for _ in columns[1:]]
    table = []

    for model_name in models:
        print(f"Interrogating with {model_name}...")
        # Load straight onto the notebook's device instead of hard-coding CUDA.
        model, preprocess = clip.load(model_name, device=device)
        model.eval()

        images = preprocess(image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = model.encode_image(images).float()
        image_features /= image_features.norm(dim=-1, keepdim=True)

        ranks = [
            rank(model, image_features, mediums),
            rank(model, image_features, ["by " + artist for artist in artists]),
            rank(model, image_features, trending_list),
            rank(model, image_features, movements),
            rank(model, image_features, flavors, top_count=3),
        ]

        # Per category, keep the ranking whose summed confidence is highest so far.
        for i, ranking in enumerate(ranks):
            if sum(conf for _, conf in ranking) > sum(conf for _, conf in bests[i]):
                bests[i] = ranking

        table.append([model_name] + [
            ', '.join(f"{text} ({conf:0.1f}%)" for text, conf in ranking)
            for ranking in ranks
        ])

        # Release this CLIP model before loading the next one.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if show_table:
        display(pd.DataFrame(table, columns=columns))

    medium = bests[0][0][0]
    artist, trending, movement = bests[1][0][0], bests[2][0][0], bests[3][0][0]
    flaves = ', '.join(text for text, _ in bests[4])
    # Don't repeat the medium when the caption already starts with it.
    if caption.startswith(medium):
        out = f"{caption} {artist}, {trending}, {movement}, {flaves}"
    else:
        out = f"{caption}, {medium} {artist}, {trending}, {movement}, {flaves}"
    print(f"\n\n{out}")
    return caption, out

# Candidate word lists shipped with clip-interrogator (path is relative to BLIP/, the cwd here).
data_path = "../clip-interrogator/data/"

artists, flavors, mediums, movements = (
    load_list(os.path.join(data_path, list_name))
    for list_name in ('artists.txt', 'flavors.txt', 'mediums.txt', 'movements.txt')
)

sites = [
    'Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble',
    'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount',
    'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central',
]
# Bare site names first, then the "trending on", "featured on" and "contest winner" phrasings.
trending_list = (
    list(sites)
    + ["trending on " + site for site in sites]
    + ["featured on " + site for site in sites]
    + [site + " contest winner" for site in sites]
)

%cd ..


/workspace/BLIP


Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

  0%|          | 0.00/855M [00:00<?, ?B/s]

load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth
/workspace


In [None]:
%cd stable-diffusion
# Checkpoint to load: a file name in the directory above stable-diffusion/ (the model is
# opened below as f"../{ckpt_file}"). To switch models, use one of these instead:
#   Waifu Diffusion: "wd-v1-2-full-ema.ckpt"
#   Poke Diffusion:  "pokediffusion_epoch_10_pruned.ckpt"
ckpt_file = "sd-v1-4-full-ema.ckpt"  # Stable Diffusion v1.4

from io import BytesIO
import os
from contextlib import nullcontext
from einops import repeat
import PIL

import fire
import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from torch import autocast
from torchvision import transforms
import requests

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.util import instantiate_from_config
from pytorch_lightning import seed_everything

def load_img(image, multiple=64, resample=None):
    """Convert an RGB PIL image into a Stable Diffusion init-image tensor.

    Both sides are rounded *down* to a multiple of ``multiple`` (64 by default), the
    image is resized with ``resample`` (Lanczos by default), and pixel values are
    mapped from [0, 255] to [-1, 1].

    Args:
        image: RGB PIL image (anything with ``.size`` and a ``.resize`` whose result
            converts to an (H, W, 3) array).
        multiple: side-length granularity; the default keeps the original behavior.
        resample: PIL resampling filter; ``None`` means ``PIL.Image.LANCZOS``.

    Returns:
        float32 tensor of shape (1, 3, H', W') with values in [-1, 1].

    Raises:
        ValueError: if either side is smaller than ``multiple`` (it would round to 0).
    """
    if resample is None:
        resample = PIL.Image.LANCZOS
    w, h = image.size
    # Round down to an integer multiple of `multiple` (the old comment said 32, but the
    # code has always used 64).
    w, h = w - w % multiple, h - h % multiple
    if w == 0 or h == 0:
        raise ValueError(
            f"image size {image.size} is smaller than {multiple}px on at least one side"
        )
    image = image.resize((w, h), resample=resample)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # (H, W, C) -> (1, C, H, W)
    image = torch.from_numpy(image)
    return 2. * image - 1.

def load_model_from_config(config, ckpt, verbose=False, device="cuda"):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to the checkpoint file.
        verbose: print missing / unexpected state-dict keys.
        device: device to move the model to (default "cuda", the previous behavior).

    Returns:
        The model on ``device``, in eval mode.

    NOTE(review): ``torch.load`` unpickles the checkpoint, which can execute arbitrary
    code — only load checkpoints from trusted sources.
    """
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    # Use the weights under "state_dict" when present (as in the checkpoints this notebook
    # loads); otherwise treat the loaded object itself as the state dict.
    sd = pl_sd.get("state_dict", pl_sd)
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and missing:
        print("missing keys:")
        print(missing)
    if verbose and unexpected:
        print("unexpected keys:")
        print(unexpected)

    model.to(device)
    model.eval()
    return model

# Build the latent diffusion model from the v1 inference config and the checkpoint chosen
# above. Paths are relative to stable-diffusion/, the cwd at this point in the cell.
config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
model = load_model_from_config(config, f"../{ckpt_file}")

# NOTE: rebinds the `device` defined in the BLIP setup cell (same choice: CUDA when available).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# DDIM sampler used by the img2img cells below.
sampler = DDIMSampler(model)

# NOTE(review): start_code is not used by any of the visible cells.
start_code = None

# Leave stable-diffusion/ (entered at the top of this cell); samples are written to ./outs.
%cd ..

sample_path = "outs"
os.makedirs(sample_path, exist_ok=True)

## Set up parameters

In [35]:
#@title Generate!

#@markdown

#@markdown #####**Base Settings:**

prompt = "A psychedelic being living in an extradimensional reality, in the style of wlop, illustration, epic, fantasy, hyper detailed, smooth, unreal engine, sharp focus, ray tracing, physically based rendering, renderman, beautiful" #@param {type:"string"}
# NOTE(review): this looks like a signed, expiring CDN URL (note the oh=/oe= query
# parameters) and will likely stop working; prefer a stable URL or a local file path.
image_path_or_url = "https://scontent-sjc3-1.xx.fbcdn.net/v/t1.15752-9/308282132_472221518278262_4631038285071535282_n.png?_nc_cat=111&ccb=1-7&_nc_sid=ae9488&_nc_ohc=9JUwOq6AGrEAX8hOOcz&_nc_ht=scontent-sjc3-1.xx&oh=03_AVJg2VM4JOsFtlpDoGN3zqnLcmGRcIpFApdj3YAsFdaCJA&oe=63597747" #@param {type:"string"}
seed = 123 #@param {type:"number"}
ddim_steps = 50

image_source = str(image_path_or_url)
if image_source.startswith(('http://', 'https://')):
    # Time out instead of hanging, and fail with a clear HTTP error instead of handing
    # an error page to PIL ("cannot identify image file").
    response = requests.get(image_source, timeout=60)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert('RGB')
else:
    image = Image.open(image_source).convert('RGB')

#@markdown #####**Img2Img settings:**

img2img_type = "clip-interrogator" #@param ["basic", "blip", "clip-interrogator"]
img_strength = 0.1 #@param {type:"slider", min:0, max:1, step:0.01}
# The generation cell uses t_enc = int(img_strength * ddim_steps) encode/decode steps, so
# store 1 - slider: a higher slider value means fewer steps and stays closer to the input.
img_strength = 1 - img_strength
# Weight for torch.lerp from the prompt conditioning towards the image-derived conditioning.
blip_strength = 0.4 #@param {type:"slider", min:0, max:1, step:0.01}


## Generate! 

In [16]:
# Caption and CLIP-interrogate the input image with ViT-L/14: yields the plain BLIP caption
# and the full interrogator prompt, used below as image-derived conditioning.
blip_prompt, clip_inter_prompt = interrogate(image, models=["ViT-L/14"])

Interrogating with ViT-L/14...


100%|███████████████████████████████████████| 890M/890M [02:49<00:00, 5.49MiB/s]




a picture of a cat in a space suit, a stock photo by Pogus Caesar, featured on reddit, space art, futuristic, sci-fi, stock photo


In [31]:
# Text conditioning derived from the input image; stays None for "basic" (or any other type).
img_cond = None
if img2img_type in ("blip", "clip-interrogator"):
    img_cond = model.get_learned_conditioning(
        [blip_prompt if img2img_type == "blip" else clip_inter_prompt]
    )

In [6]:
import uuid

In [None]:
# Single-setting img2img: one image per seed for `num_seeds` consecutive seeds starting at
# `seed`, using the img_strength / blip_strength chosen in the parameters cell.
num_seeds = 10
batch_size = 1
data = [batch_size * [prompt]]  # list of prompt batches

init_image = load_img(image).to(device)
# One copy of the init image per batch element (previously hard-coded to b=1, which only
# matched because batch_size happened to be 1).
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))

# img_strength is already inverted in the parameters cell. With a slider value of 0 the
# product equals ddim_steps; clamp it, since the DDIM schedule built below is only
# ddim_steps long (assumed from ldm's DDIMSampler — verify if changing the sampler).
t_enc = min(int(img_strength * ddim_steps), ddim_steps - 1)
# Use the configured step count instead of a second hard-coded 50.
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)

with torch.no_grad(), autocast("cuda"), model.ema_scope():
    # The conditionings do not depend on the seed, so compute them once.
    uc = model.get_learned_conditioning(batch_size * [""])
    conds = []
    for prompts in data:
        c = model.get_learned_conditioning(list(prompts))
        if img_cond is not None and blip_strength >= 0:
            # Blend the prompt conditioning towards the image-derived conditioning.
            c = torch.lerp(c, img_cond, blip_strength)
        conds.append(c)

    for s in range(num_seeds):
        incr_seed = seed + s
        seed_everything(incr_seed)
        for c in conds:
            # encode (scaled latent)
            z_enc = sampler.stochastic_encode(
                init_latent, torch.tensor([t_enc] * batch_size).to(device))
            # decode it
            samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=7.5,
                                     unconditional_conditioning=uc)

            x_samples = model.decode_first_stage(samples)
            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

            for x_sample in x_samples:
                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                Image.fromarray(x_sample.astype(np.uint8)).save(
                    os.path.join(sample_path, f"{incr_seed}_{str(uuid.uuid4())}.png"))


In [None]:
# Grid version: for each seed, sweep image strength (grid columns) against the
# prompt/image-conditioning blend (grid rows); save every sample plus one combined grid.


def _make_grid(images, n_rows, n_cols, gap=20):
    """Paste ``images`` into one RGB image with ``n_rows`` rows and ``n_cols`` columns.

    ``images`` is in column-major order: ``images[col * n_rows + row]`` goes to cell
    (row, col). Cells are ``gap`` pixels apart, and every image is assumed to have the
    size of the first one.
    """
    if len(images) != n_rows * n_cols:
        raise ValueError(f"expected {n_rows * n_cols} images, got {len(images)}")
    w, h = images[0].size
    grid = Image.new('RGB', (n_cols * w + (n_cols - 1) * gap,
                             n_rows * h + (n_rows - 1) * gap))
    for col in range(n_cols):
        for row in range(n_rows):
            grid.paste(images[col * n_rows + row], (col * (w + gap), row * (h + gap)))
    return grid


num_seeds = 10
# The grid holds exactly one sample per (strength, blend) cell, so keep batch_size at 1.
batch_size = 1
prompt_batch = batch_size * [prompt]

init_image = load_img(image).to(device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))

grid_img_strs = [0.02, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
grid_blip_strs = [0, 0.1, 0.25, 0.4, 0.55, 0.7, 0.9]

# Use the configured step count instead of a hard-coded 50.
sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=0.0, verbose=False)

with torch.no_grad(), autocast("cuda"), model.ema_scope():
    # The conditionings do not depend on the seed or the strengths, so compute them once.
    uc = model.get_learned_conditioning(batch_size * [""])
    base_c = model.get_learned_conditioning(prompt_batch)

    for s in range(num_seeds):
        incr_seed = seed + s
        all_samples = []
        # Loop variables are named grid_* so the sweep does not overwrite the global
        # img_strength / blip_strength used by the single-setting cell. Unlike that
        # img_strength, these image strengths are NOT pre-inverted.
        for grid_img_strength in grid_img_strs:
            t_enc = min(int((1 - grid_img_strength) * ddim_steps), ddim_steps - 1)
            for grid_blip_strength in grid_blip_strs:
                # Re-seed per cell so only the strengths differ across the grid.
                seed_everything(incr_seed)
                c = base_c
                if img_cond is not None and grid_blip_strength >= 0:
                    c = torch.lerp(base_c, img_cond, grid_blip_strength)

                # encode (scaled latent)
                z_enc = sampler.stochastic_encode(
                    init_latent, torch.tensor([t_enc] * batch_size).to(device))
                # decode it
                samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=7.5,
                                         unconditional_conditioning=uc)

                x_samples = model.decode_first_stage(samples)
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                for x_sample in x_samples:
                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                    sample_img = Image.fromarray(x_sample.astype(np.uint8))
                    sample_img.save(
                        os.path.join(sample_path, f"{incr_seed}_{str(uuid.uuid4())}.png"))
                    all_samples.append(sample_img)

        # Columns = image strengths (outer loop), rows = blend strengths (inner loop).
        big_img = _make_grid(all_samples, n_rows=len(grid_blip_strs),
                             n_cols=len(grid_img_strs), gap=20)
        big_img.save(os.path.join(sample_path, f"big_image_{str(uuid.uuid4())}.png"))

In [None]:
# Re-assemble a grid from the `all_samples` list left by the grid cell (its last seed).
# The layout matches that cell: all_samples is column-major (index = col * side + row),
# with columns = image strengths and rows = blend strengths.
gap = 20
side = 7

if len(all_samples) != side * side:
    raise ValueError(
        f"expected {side * side} samples in all_samples (run the grid cell first), "
        f"got {len(all_samples)}")

w, h = all_samples[0].size

big_img_w = (side * w) + ((side - 1) * gap)
big_img_h = (side * h) + ((side - 1) * gap)

big_img = Image.new('RGB', (big_img_w, big_img_h))

for row in range(side):
    for col in range(side):
        ind = row + (col * side)
        big_img.paste(all_samples[ind], (col * (w + gap), row * (h + gap)))

big_img.save(os.path.join(sample_path, f"big_image_{str(uuid.uuid4())}.png"))