In [1]:
from pathlib import Path
import torch
import math
import ipywidgets as widgets
from IPython.display import display, Image
import os
import glob
from itertools import cycle
import numpy as np
import random

import utils.bending as util
from utils.wrapper import StreamDiffusionWrapper

def txt2img(wrapper, prompt, noise, bending_fn, output_dir=None, num_inference_steps=50):
    """Generate one image for ``prompt`` and save it as the next numbered PNG.

    Args:
        wrapper: object exposing ``prepare(**kwargs)`` and ``__call__()``; the
            call's result must have a ``save(path)`` method (e.g. a
            StreamDiffusionWrapper).
        prompt: text prompt forwarded to ``wrapper.prepare``.
        noise: forwarded as ``input_noise`` (the UI handler passes ``None``).
        bending_fn: forwarded to ``wrapper.prepare`` as ``bending_fn``.
        output_dir: directory to save into; defaults to the module-level
            ``output_folder`` when ``None``.
        num_inference_steps: forwarded to ``wrapper.prepare``.

    Returns:
        ``pathlib.Path`` of the saved image, named ``NNNNN.png``.
    """
    out_dir = Path(output_folder if output_dir is None else output_dir)

    wrapper.prepare(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        bending_fn=bending_fn,
        input_noise=noise,
    )
    image = wrapper()

    # Continue after the highest existing numbered PNG. Counting directory
    # entries would overwrite a frame once any earlier file is deleted, and it
    # would also count unrelated non-PNG files.
    taken = [int(p.stem) for p in out_dir.glob("*.png") if p.stem.isdecimal()]
    target = out_dir / f"{max(taken, default=-1) + 1:05}.png"
    image.save(str(target))
    return target
    
def list_generator(lst):
    """Yield the items of ``lst`` over and over, indefinitely.

    ``lst`` is re-iterated on every pass, so a list mutated between passes is
    picked up. If a whole pass yields nothing, the generator stops. This covers
    an empty sequence or a one-shot iterator that is already exhausted; without
    the check it would spin forever without yielding.
    """
    while True:
        produced_any = False
        for item in lst:
            produced_any = True
            yield item
        if not produced_any:
            return

# Output location for generated frames. It is created if missing and then
# handed to util.clear_dir (presumably wiping frames from earlier runs).
IMAGE_STORAGE_PATH = Path("./image_outputs")

# Sampling rate configured on utils.bending (44.1 kHz; meaning is defined
# by that module).
SAMPLING_RATE = 44_100
util.set_sampling_rate(SAMPLING_RATE)

IMAGE_STORAGE_PATH.mkdir(exist_ok=True)
util.clear_dir(IMAGE_STORAGE_PATH)

# Latent-bending functions selectable in the UI. Each entry is a factory: the
# UI handler calls fn(slider_value) and passes the result to txt2img as
# bending_fn. (An earlier three-entry bending_functions dict was removed; it
# was overwritten below before ever being read.)
functions = [
    util.add_full,
    util.multiply,
    util.add_sparse,
    util.add_noise,
    util.subtract_full,
    util.threshold,
    util.soft_threshold,
    util.soft_threshold2,
    util.inversion,
    util.inversion2,
    util.log,
    util.power,
    util.rotate_z,
    util.rotate_x,
    util.rotate_y,
    util.rotate_y2,
    util.reflect,
    util.hadamard1,
    util.hadamard2,
    # Currently disabled:
    # util.gradient,
    # util.dilation,
    # util.erosion,
    # util.sobel,
    util.absolute,
]

# Registry keyed by function name (these keys become the dropdown options).
bending_functions = {fn.__name__: fn for fn in functions}
# 'none' maps to "no bending function"; it is the dropdown's default value.
bending_functions['none'] = None


# --- StreamDiffusionWrapper configuration -----------------------------------
output_folder = IMAGE_STORAGE_PATH   # txt2img saves numbered PNGs here
model_id_or_path = "runwayml/stable-diffusion-v1-5"
lora_dict = None                     # presumably: no LoRA weights applied
width, height = 512, 512
frame_buffer_size = 1                # batch size
acceleration = "xformers"
seed = 46
t_index_list = [0, 16, 32, 45]       # len() == number of denoising steps

# Seed Python, NumPy and torch (CPU + all CUDA devices) with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# Initial generation settings; bend_function stays None (no bending) until a
# function is picked in the UI dropdown.
prompt = "a floating orb"
layer = 1
bend_function = None

# Build the StreamDiffusion wrapper in txt2img mode. bend_function is None
# at this point, so no bending function is installed initially.
stream = StreamDiffusionWrapper(
    # model
    model_id_or_path=model_id_or_path,
    lora_dict=lora_dict,
    acceleration=acceleration,
    # output geometry / batching
    width=width,
    height=height,
    frame_buffer_size=frame_buffer_size,
    # sampling
    mode="txt2img",
    t_index_list=t_index_list,  # the length of this list is the number of denoising steps
    use_denoising_batch=False,
    cfg_type="none",
    seed=seed,
    warmup=10,
    # latent bending hook
    bending_fn=bend_function,
)

# Circular latent walk: frame k uses cos(θ_k)·A + sin(θ_k)·B for two fixed
# Gaussian tensors A and B, so consecutive frames trace a loop in noise space.
walk_length = 2  # in units of π; 2 gives a full 2π loop that cycles seamlessly
num_frames = 128
# Only the shape of this tensor is used below (contents are uninitialized).
noise = torch.empty((1, 4, stream.stream.latent_height, stream.stream.latent_width), dtype=torch.float64)

# Dedicated, seeded generator: the walk depends only on `seed`. It is not
# affected by any global-RNG draws made while the wrapper above was being
# constructed and warmed up.
walk_generator = torch.Generator().manual_seed(seed)
walk_noise_x = torch.randn(noise.shape, generator=walk_generator, dtype=torch.float64)
walk_noise_y = torch.randn(noise.shape, generator=walk_generator, dtype=torch.float64)

# Angles k·(walk_length·π / num_frames) for k in [0, num_frames). The endpoint
# is excluded on purpose: linspace(0, walk_length, num_frames) made the last
# frame identical to the first (cos 0 == cos 2π), a visible stutter when cycled.
walk_angles = torch.arange(num_frames, dtype=torch.float64) * (walk_length * math.pi / num_frames)
walk_scale_x = torch.cos(walk_angles)
walk_scale_y = torch.sin(walk_angles)

# dims=0 -> outer product: shape (num_frames, *noise.shape).
noise_x = torch.tensordot(walk_scale_x, walk_noise_x, dims=0)
noise_y = torch.tensordot(walk_scale_y, walk_noise_y, dims=0)
batched_noise = noise_x + noise_y
noise_generator = cycle(batched_noise)

A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

  deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)


In [None]:
# The widget UI in this cell was drafted with ChatGPT.
# Output widget that captures the event handlers' print statements; it is
# placed at the bottom of the UI box.
print_widget = widgets.Output()

# Find the most recently written PNG in the output folder.
def get_latest_image(folder=None):
    """Return the path (str) of the newest ``*.png`` in ``folder``, or ``None``.

    Args:
        folder: directory to search; defaults to ``IMAGE_STORAGE_PATH``.

    "Newest" means the latest modification time. Ties are broken by the
    lexicographically greatest path, which is the highest frame number for
    zero-padded names, so the choice is deterministic. mtime is used because
    ``os.path.getctime`` is inode-change time on Unix, not creation time.
    """
    search_dir = IMAGE_STORAGE_PATH if folder is None else folder
    # Escape the directory so '[', '*', '?' in the path aren't treated as
    # glob patterns.
    pattern = os.path.join(glob.escape(os.fspath(search_dir)), '*.png')
    candidates = glob.glob(pattern)
    if not candidates:
        return None
    return max(candidates, key=lambda path: (os.path.getmtime(path), path))

def update_image():
    """Show the newest PNG from the output folder in ``image_widget``.

    Does nothing when the folder has no PNGs yet.
    """
    newest_path = get_latest_image()
    if not newest_path:
        return
    with open(newest_path, "rb") as handle:
        image_widget.value = handle.read()

# --- UI elements -------------------------------------------------------------
# continuous_update=False: the prompt's `value` (observed by on_ui_change)
# changes only on Enter / focus loss. With the default (True), every
# keystroke triggered a full txt2img generation.
input_prompt = widgets.Text(
    value='',
    placeholder='Enter a prompt',
    description='Prompt:',
    continuous_update=False,
    disabled=False
)

# Options are the registry's keys; 'none' maps to no bending function.
dropdown_options = list(bending_functions.keys())
dropdown = widgets.Dropdown(
    options=dropdown_options,
    value="none",
    description='Bending Function:',
    disabled=False
)

# Editable lower/upper bounds for the slider (kept in sync by
# on_slider_range_change).
slider_min = widgets.FloatText(
    value=0.0,
    disabled=False,
    layout=widgets.Layout(width='50px')
)

slider_max = widgets.FloatText(
    value=10.0,
    disabled=False,
    layout=widgets.Layout(width='50px')
)

# The value is passed to the selected bending-function factory. Updates fire
# only on release to avoid a generation per drag step.
slider = widgets.FloatSlider(
    value=0,
    min=slider_min.value,
    max=slider_max.value,
    step=0.1,
    description='Value:',
    continuous_update=False,
    orientation='horizontal'
)

# Displays the most recent generated PNG.
image_widget = widgets.Image(
    format='png',
    height=512,
    width=512
)

button = widgets.Button(description='Generate Image')

# Event handler for UI changes
def on_ui_change(change):
    """Regenerate an image from the current UI state and display it.

    Used both as a ``value`` observer (receives a change dict) and as the
    button's click callback (receives the Button); the argument is ignored.
    """
    with print_widget:
        print_widget.clear_output()  # replace the previous log
        print(f"Prompt: {input_prompt.value}")
        print(f"Slider Value: {slider.value}")
        print(f"Bending Function: {dropdown.value}")

        # Resolve the selection from the dropdown itself. The global
        # `bend_function` can be stale: re-running this cell recreates the
        # dropdown at "none" while the global keeps its previous function.
        factory = bending_functions[dropdown.value]
        bending_fn = None if factory is None else factory(slider.value)
        txt2img(stream, input_prompt.value, None, bending_fn)
        update_image()

def on_slider_range_change(change):
    """Apply the bounds typed into ``slider_min`` / ``slider_max`` to ``slider``.

    ipywidgets' bounded sliders raise TraitError if ``min`` would exceed
    ``max``. An inverted range (min > max) is therefore ignored until the
    user fixes it, and the two bounds are assigned in an order that never
    makes the intermediate state invalid.
    """
    new_min, new_max = slider_min.value, slider_max.value
    if new_min > new_max:
        return
    if new_min > slider.max:
        # Moving the whole range up: raise max first so min never exceeds it.
        slider.max = new_max
        slider.min = new_min
    else:
        # new_min <= current max, so it can be applied first; then max >= min.
        slider.min = new_min
        slider.max = new_max
    
def on_dropdown_change(change):
    """Store the dropdown's selected bending function in the global ``bend_function``."""
    global bend_function
    selected_name = dropdown.value
    bend_function = bending_functions[selected_name]

# Wire widgets to handlers: prompt/slider edits and the button regenerate,
# the dropdown swaps the bending function, and the bound boxes rescale the slider.
_value_observers = (
    (input_prompt, on_ui_change),
    (slider, on_ui_change),
    (dropdown, on_dropdown_change),
    (slider_min, on_slider_range_change),
    (slider_max, on_slider_range_change),
)
for _observed_widget, _handler in _value_observers:
    _observed_widget.observe(_handler, names='value')
button.on_click(on_ui_change)

# Layout, top to bottom: prompt, dropdown, [min | slider | max], image,
# button, log.
slider_box = widgets.HBox([slider_min, slider, slider_max])
ui_box = widgets.VBox([
    input_prompt,
    dropdown,
    slider_box,
    image_widget,
    button,
    print_widget,
])

display(ui_box)