In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

In [None]:
from desktop_server.art_generator import load_sd3, prompt_sd3

In [None]:
from shared_image_utils.dithering import atkinson_dither

In [None]:
from shared_matplotlib_utils import OUTLINE

In [None]:
import string

In [None]:
from matplotlib import pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import matplotlib
import numpy as np

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import tqdm

# Load the SD3 model

In [None]:
# Load the model once; the same object is passed to every prompt_sd3 call below.
model = load_sd3()

In [None]:
# Generate a base AI photo

In [None]:
def show_images(images, title="", nrows=2, ncols=2):
    """Lay out images in a labelled grid and return the (closed) figure.

    Each panel is labelled "a)", "b)", ... near its top-left corner, and
    ``title`` is used as the figure's super-title. The figure is closed in
    pyplot before it is returned, so it is not auto-rendered at the end of the
    cell; show it explicitly, e.g. with ``display(fig)``.

    Args:
        images: Sequence of exactly ``nrows * ncols`` images. Each one is
            converted with ``np.asarray`` before being passed to ``imshow``.
        title: Figure super-title.
        nrows: Number of grid rows (default 2, i.e. the original 2x2 layout).
        ncols: Number of grid columns (default 2).

    Returns:
        The matplotlib ``Figure``.

    Raises:
        AssertionError: If ``len(images) != nrows * ncols`` or the grid has
            more panels than there are lowercase letters to label them.
    """
    n_panels = nrows * ncols
    assert len(images) == n_panels, (
        f"expected {n_panels} images for a {nrows}x{ncols} grid, got {len(images)}"
    )
    # Panels are labelled from ascii_lowercase; beyond 26, zip() below would
    # silently drop the remaining images.
    assert n_panels <= len(string.ascii_lowercase), "at most 26 panels can be labelled"

    fig, axs = plt.subplots(
        nrows,
        ncols,
        figsize=(4 * ncols, 4 * nrows),  # (8, 8) for the default 2x2 grid
        constrained_layout=True,
        squeeze=False,  # always a 2-D array, so .flat also works for a 1x1 grid
    )
    # Attach an Agg canvas directly to the figure (original note: avoids
    # inline rendering).
    FigureCanvas(fig)

    fig.suptitle(title, fontsize=16)

    for ax, letter, image in zip(axs.flat, string.ascii_lowercase, images):
        ax.axis("off")
        ax.text(
            0.05,
            0.95,
            f"{letter})",
            verticalalignment="center",
            horizontalalignment="center",
            transform=ax.transAxes,  # position in axes fractions, not data units
            fontsize=15,
            path_effects=[OUTLINE],
        )
        ax.imshow(np.asarray(image))

    # Unregister from pyplot so the figure is neither auto-displayed nor kept
    # open while many figures are generated in a loop.
    plt.close(fig)

    return fig

In [None]:
# Shared assets folder. The path is relative, so it resolves against the
# kernel's current working directory (in Jupyter, usually the notebook's folder).
ASSETS = Path("..", "..", "..", "assets")

# Stable Diffusion – colours

In [None]:
# Output size passed to prompt_sd3 for every generated image (presumably pixels).
width, height = 256, 256

In [None]:
# Text file with one prompt per line (read in the next cell).
# NOTE(review): the variable/section say "colours" but the file name ends in
# "_bw" — confirm which prompt set is intended.
filename_colours = ASSETS / "prompts_sd3_bw.txt"
# Show the resolved path on failure: ASSETS is relative to the working
# directory, which is the usual cause of a "missing" file.
assert filename_colours.is_file(), f"Prompt file not found: {filename_colours.resolve()}"

In [None]:
# Read one prompt per line. Surrounding whitespace is stripped and blank lines
# (separators, extra newlines at the end) are dropped, so no empty prompt is
# sent to the generator.
# The encoding is explicit so the result does not depend on the platform's
# locale (assumes the prompt file is UTF-8 — confirm).
with open(filename_colours, "r", encoding="utf-8") as f:
    lines = [line.strip() for line in f]
lines = [line for line in lines if line]
# Fail here with a clear message instead of later in the viewer with IndexError.
assert lines, f"No prompts found in {filename_colours}"

In [None]:
# Images generated per prompt. This must match the default 2x2 grid that
# show_images lays out (it asserts the image count).
images_per_prompt = 4

# One figure per prompt, each holding several samples for side-by-side
# comparison. Presumably each prompt_sd3 call draws a fresh sample (no seed
# is passed here) — confirm.
figures = []
for prompt_text in tqdm.tqdm(lines, desc="prompts"):
    images = [
        prompt_sd3(model, prompt_text, width=width, height=height)
        for _ in range(images_per_prompt)
    ]
    fig = show_images(images, title=prompt_text)
    figures.append(fig)

In [None]:
# Number of figures generated: one per prompt line read above.
len(figures)

In [None]:
output = widgets.Output()
# Bounded to valid figure positions, so neither the buttons nor direct typing
# can select a non-existent figure. With a plain IntText, a typed out-of-range
# value could make Prev/Next raise IndexError, and a negative one silently
# wrapped around through negative indexing.
index = widgets.BoundedIntText(
    value=0,
    min=0,
    max=max(len(figures) - 1, 0),
    layout=widgets.Layout(width='60px'),
)

prev_button = widgets.Button(description="Previous")
next_button = widgets.Button(description="Next")

def show_plot(i):
    """Replace the contents of the output area with figure ``i``."""
    with output:
        # wait=True defers clearing until new output arrives (less flicker).
        clear_output(wait=True)
        if not figures:
            print("No figures to show.")
            return
        display(figures[i])

def on_next_clicked(b):
    """Step to the next figure; the index observer below redraws it."""
    if index.value < len(figures) - 1:
        index.value += 1

def on_prev_clicked(b):
    """Step to the previous figure; the index observer below redraws it."""
    if index.value > 0:
        index.value -= 1

def _on_index_change(change):
    # Single redraw path: fires for button clicks and for numbers typed
    # directly into the index box.
    show_plot(change["new"])

index.observe(_on_index_change, names="value")
prev_button.on_click(on_prev_clicked)
next_button.on_click(on_next_clicked)

In [None]:
# Navigation bar: Previous / Next buttons, then a label and the index box.
controls = widgets.HBox(children=[prev_button, next_button, widgets.Label("Index:"), index])
# Render the controls above the output area, then draw the selected figure.
display(controls, output)
show_plot(index.value)