In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

In [None]:
from functools import cache

In [None]:
import os
import textwrap

import matplotlib.patheffects as PathEffects
import matplotlib.pyplot as plt
from matplotlib import patheffects, rcParams
from PIL import Image

In [None]:
from diffusers import AutoPipelineForText2Image
import torch

In [None]:
##

In [None]:
def style_scanner(pipe, prompt, num_images_per_prompt=2):
    """Run one fast text-to-image generation and return the produced images.

    Args:
        pipe: Callable text-to-image pipeline; its return value must expose
            an ``images`` attribute.
        prompt: Text prompt, passed as the first positional argument.
        num_images_per_prompt: Number of images to request for the prompt.

    Returns:
        The ``images`` attribute of the pipeline output.
    """
    # Single inference step with guidance disabled: the settings used for the
    # SDXL-Turbo pipeline this notebook loads.
    call_kwargs = dict(
        num_inference_steps=1,
        guidance_scale=0.0,
        num_images_per_prompt=num_images_per_prompt,
    )
    output = pipe(prompt, **call_kwargs)
    return output.images

In [None]:
def display_images(
    images: list[Image.Image],
    num_cols=2,
    title=None,
    display_as_bitmap=True,
):
    """Show PIL images on a borderless grid of matplotlib subplots.

    Args:
        images: Images to show, placed into the grid row by row.
        num_cols: Maximum number of grid columns; reduced to ``len(images)``
            when there are fewer images than columns. Must be >= 1.
        title: Optional figure title, drawn with a white outline stroke.
        display_as_bitmap: If True, each image is converted to PIL mode "1"
            (1-bit black and white) before being drawn.

    Raises:
        ValueError: If ``images`` is empty or ``num_cols`` is less than 1.
    """
    if len(images) == 0:
        raise ValueError("No images")
    if num_cols < 1:
        raise ValueError("num_cols must be >= 1")

    num_images = len(images)
    num_cols = min(num_images, num_cols)
    # Integer ceiling division: enough rows to hold every image.
    num_rows = (num_images + num_cols - 1) // num_cols

    # squeeze=False always returns a 2-D array of Axes. Without it a 1x1 grid
    # (a single image) yields a bare Axes object, which has no `.flat`.
    figure, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(5 * num_cols, 5 * num_rows),
        squeeze=False,
    )
    axes = list(axes.flat)

    for ax, image in zip(axes, images):
        if display_as_bitmap:
            image = image.convert("1")

        ax.imshow(image)
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.set_aspect("equal")

    # Hide grid cells left over in the last row.
    for ax in axes[num_images:]:
        ax.set_visible(False)

    figure.subplots_adjust(wspace=0, hspace=0)
    figure.tight_layout()

    if title:
        txt = figure.suptitle(title, fontsize=11)
        # White stroke around the glyphs keeps the title readable over image content.
        txt.set_path_effects([PathEffects.withStroke(linewidth=5, foreground="w")])

In [None]:
# Fail fast with an explanation: the pipeline is moved to "cuda" further down.
assert torch.cuda.is_available(), (
    "CUDA is not available: this notebook runs the SDXL-Turbo pipeline on a GPU "
    "(pipe.to('cuda'))."
)

In [None]:
##

In [None]:
# Load SDXL-Turbo from its "fp16" weight variant in half precision, then move it onto the GPU.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo",
    variant="fp16",
    torch_dtype=torch.float16,
)
_ = pipe.to("cuda")

In [None]:
# Turn off diffusers' per-call progress bar; the scan loops further down
# already report progress through tqdm's `product`.
_ = pipe.set_progress_bar_config(disable=True)

# Prompt exploration

In [None]:
@cache
def predict(prompt):
    """Return the images generated for ``prompt`` by ``style_scanner`` on ``pipe``.

    Results are memoized per prompt string: repeated calls with the same
    prompt return the same list object without regenerating, so callers
    should not mutate it.
    """
    return style_scanner(pipe, prompt)

In [None]:
def explore(prompt):
    """Generate (or reuse cached) images for ``prompt`` and show them titled with it.

    ``plt.show()`` is called after each grid so the figure is rendered
    immediately; with Jupyter's inline backend this also closes it. Without
    it, a loop of ``explore`` calls in one cell keeps every figure open until
    the cell finishes. That triggers matplotlib's "More than 20 figures have
    been opened" warning and holds all of them in memory.
    """
    images = predict(prompt)
    display_images(images, title=prompt)
    plt.show()

In [None]:
##

In [None]:
# Same subject rendered in two styles: bromoil print vs. sketch.
explore(prompt="bromoil print of profile of an inner goddess in warrior outfit in their power, white paper")
explore(prompt="sketch of profile of an inner goddess in warrior outfit in their power, no background")

In [None]:
# Keyword-list prompt ending in an explicit style cue ("parametric drawing").
explore(prompt="Space Explorer, Infinite Cosmos, Astronaut, Stars, Isolation, parametric drawing")

# Scan styles and artists across subjects

In [None]:
# Drawing and print styles, each paired with every subject prompt in the scan below.
# Kept as a bullet list for readability and parsed into a list of style names.
styles = """
 - Assembly drawing
 - Bromoil print
 - Brush pen drawing
 - Chalk drawing
 - Charcoal drawing
 - Chiaroscuro
 - Circuit diagram
 - Coloring book page
 - Coloring-in sheet
 - Conte drawing
 - Dry brush drawing
 - Elevation drawing
 - Graphite drawing
 - Halftone print
 - Ink drawing
 - Intaglio
 - One line art
 - Parametric Drawing
 - Pen drawing
 - Perspective drawing
 - Schematics
 - Silhouette
 - Stippling
 - Sumi-e drawing
 - Wireframe
 - Wood engraving
 - patent drawing
 - pencil drawing
""".splitlines()
# Remove only the leading "- " bullet (never text inside a name) and drop blank lines.
styles = [line.strip().removeprefix("- ") for line in styles if line.strip()]

In [None]:
# Artist names; each one is prepended to every subject prompt in the artist scan.
artists = """
 - Alfred Kubin
 - Aline Kominsky-Crumb
 - Andrew Read
 - André Franquin
 - Ann Telnaes
 - B. Kliban
 - Ben Heine
 - Charles Samuel Addams
 - Christian Coigny
 - Christophe Staelens
 - Christopher Shy
 - Dave Sim
 - Dick Giordano
 - Fu Baoshi
 - Gerd Arntz
 - Hans Bellmer
 - Henri Matisse
 - John Leech
 - Jon Carling
 - Joost Swarte
 - Nathan Wirth
 - Shel Silverstein
 - Shigeo Fukuda
 - Theodor Kittelsen
 - Thomas Nast
 - Vince Low
""".split("\n")
# Keep the non-blank lines and peel the " - " bullet off the front of each.
artists = [entry.removeprefix(" - ") for entry in artists if entry]

In [None]:
# Subject prompts to combine with every style and artist.
# The trailing empty string adds a run with the style/artist on its own.
prompts = [
    "A lighthouse standing tall against crashing waves",
    "A friendly monster with big, round eyes",
    "",
]

In [None]:
# from itertools import product
from tqdm.contrib.itertools import product

In [None]:
for style, prompt in product(styles, prompts):
    # Join only the non-empty parts so the style-only run (prompt == "")
    # does not end with a dangling ", " in its prompt and title.
    explore(", ".join(part for part in (style, prompt) if part))

In [None]:
for artist, prompt in product(artists, prompts):
    # Join only the non-empty parts so the artist-only run (prompt == "")
    # does not end with a dangling ", " in its prompt and title.
    explore(", ".join(part for part in (artist, prompt) if part))