## Storybook

A hobby 'model-bending' project that takes a story or narration and generates a sequence of images from it as a neural visualization.

It uses a text-to-image model (**Flux**) to render the images and an LLM (**Qwen-0.5B**) to restructure the text into image prompts.

In [None]:
! pip install -q diffusers

In [None]:
import torch, gc, os, math
from diffusers import FluxPipeline
from tqdm.auto import tqdm
import matplotlib.pyplot as plt


# --- Model / hardware settings ---
model_id = "Freepik/flux.1-lite-8B-alpha"
device = "cuda"
torch_dtype = torch.bfloat16

# --- Sampling settings read by generate_image() on every call ---
seed = 13
n_steps = 20
guidance_scale = 3.5

# --- Output location; created up front so image saves never hit a missing directory ---
folder = "gend_images"
os.makedirs(folder, exist_ok=True)

In [None]:
def clearmem():
    """Free host garbage, then release PyTorch's unused cached CUDA memory."""
    # Order matters: collecting first lets tensors that are only kept alive by
    # reference cycles be freed, so empty_cache() can then hand their blocks back.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()


clearmem()

### Flux/Image generation pipeline

In [None]:
# Load the Flux text-to-image pipeline in the shared dtype. device_map="balanced"
# asks diffusers to spread the pipeline's components over the available devices.
fluxpipe = FluxPipeline.from_pretrained(
    model_id,
    device_map="balanced",
    torch_dtype=torch_dtype,
)

In [None]:
def prompt_to_filename(prompt, max_len=20):
    """Turn a free-form prompt into a short, filesystem-safe ``.png`` file name.

    Whitespace runs (including newlines from LLM output) are collapsed to single
    spaces. Every character that is not a word character, space or hyphen
    (e.g. ``/``, ``:``, ``*``) is replaced by ``_``. The result is cut to
    ``max_len`` characters, and an empty result falls back to ``"image"`` so the
    file is never a hidden ``.png``.

    Note: prompts that share the same first ``max_len`` characters map to the
    same name, so a later image overwrites an earlier one.
    """
    import re

    stem = " ".join(prompt.split())
    stem = re.sub(r"[^\w \-]", "_", stem)[:max_len].strip()
    return f"{stem or 'image'}.png"


def generate_image(prompt, height=512, width=512):
    """Render ``prompt`` with the Flux pipeline, save it under ``folder`` and return it.

    Args:
        prompt: text prompt passed to ``fluxpipe``.
        height: output height in pixels (default 512, as before).
        width: output width in pixels (default 512, as before).

    Returns:
        The first image of the pipeline output. With diffusers' default output
        type this is a PIL image, which is also what ``.save`` below relies on.

    Uses the module-level ``seed``, ``n_steps`` and ``guidance_scale``. A fresh
    CPU generator seeded with ``seed`` is built on every call, so the same prompt
    always reproduces the same image.
    """
    with torch.inference_mode():
        image = fluxpipe(
            prompt=prompt,
            generator=torch.Generator(device="cpu").manual_seed(seed),
            num_inference_steps=n_steps,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        ).images[0]

    # Sanitized name: raw LLM prompts contain newlines, '/' and ':' that would
    # otherwise break or redirect the save path.
    image.save(os.path.join(folder, prompt_to_filename(prompt)))

    return image

In [None]:
# Smoke test: render one hand-written prompt before involving the LLM.
sample = "renaissance painting of a colorful jellyfish, underwater in the dark ocean midjourney style"

v = generate_image(sample)

# Bare final expression so the notebook displays the returned image inline.
v

### LLM part

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face id of the small Qwen instruct LLM that restructures the story into image prompts.
qwen_model = "Qwen/Qwen2.5-0.5B-Instruct"

In [None]:
# Load the Qwen LLM in the notebook-wide `torch_dtype`, so one setting controls
# the precision of both models (e.g. on GPUs without native bf16).
# device_map="auto" lets transformers choose where the weights are placed.
model = AutoModelForCausalLM.from_pretrained(
    qwen_model, torch_dtype=torch_dtype, device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(qwen_model)

In [None]:
def save_component(label, save_fn):
    """Call ``save_fn()`` best-effort; return True on success, False on failure.

    Any exception is printed together with ``label`` instead of being raised, so
    one component that cannot be serialized does not stop the notebook.
    """
    try:
        save_fn()
    except Exception as e:  # deliberate best-effort: report which part failed, keep going
        print(f"Save error ({label}): {e}")
        return False
    return True


def save_models():
    """Save the Qwen model, its tokenizer and the Flux pipeline to local folders.

    Each component is attempted independently, so a failure saving one of them
    (reported by ``save_component``) does not skip the others.
    """
    # Lambdas defer the global lookups, so even a missing object is reported
    # inside save_component rather than raised here.
    save_component("qwen model", lambda: model.save_pretrained("qwen05-model"))
    save_component("qwen tokenizer", lambda: tokenizer.save_pretrained("qwen05-tokenizer"))
    save_component("flux pipeline", lambda: fluxpipe.save_pretrained("flux_lite"))


save_models()

In [None]:
import textwrap

# The narration handed to the LLM. dedent + strip remove the source-code
# indentation and surrounding blank lines so the model sees only the story text.
story = textwrap.dedent(
    """\
    There were dragons chasing me through the forest.
    I ran through the forest and arrived at the edge of an ocean cliff, then I jumped into the ocean.
    The dragons followed me into the water and swam into the ocean. There were many colorful jellyfish there.
    """
).strip()

In [None]:
# Chat turns for the LLM: a system instruction describing the task, then the story.
messages = [
    dict(
        role="system",
        content=(
            "You are a story teller, and whatever story you are given, "
            "split it into meaningful parts for image generation prompts. "
            "Augment the text slightly to capture the scene more for image generation"
        ),
    ),
    dict(role="user", content=story),
]

# Render the conversation into a single prompt string (tokenization happens later).
# add_generation_prompt=True appends the assistant-turn header so the model replies.
text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)

text

In [None]:
# Tokenize the rendered prompt as a batch of one and move it next to the model.
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(**model_inputs, max_new_tokens=512)
# Each returned row is prompt + continuation; keep only the newly generated tail.
generated_ids = [
    full_seq[len(prompt_seq):]
    for prompt_seq, full_seq in zip(model_inputs.input_ids, generated_ids)
]

# Decode the single sequence, dropping special/chat-template tokens.
response = tokenizer.batch_decode(
    generated_ids,
    skip_special_tokens=True,
)[0]

response

### Storybook generation

- We take the segments generated by the LLM and feed them one by one into FLUX, then take the sequence of images and string them together as one GIF (for now, the images are shown as a grid).
- In v1, we will generate 4 versions of each image, or use latent interpolation to create a video-like effect.

In [None]:
def split_into_prompts(llm_text):
    """Split the LLM answer into image prompts at sentence-ending paragraph breaks.

    A break is a '.' followed by a blank line. Spaces before the newline and
    whitespace-only "blank" lines are tolerated. Each part is stripped and empty
    parts are dropped, so stray blank lines never turn into empty prompts (which
    would be rendered as junk images).
    """
    import re

    parts = re.split(r"\.[^\S\n]*\n\s*\n", llm_text)
    return [part.strip() for part in parts if part.strip()]


prompt_list = split_into_prompts(response)

print(prompt_list)
print(f"total prompts = {len(prompt_list)}")

In [None]:
# Render one image per story segment, in story order, reclaiming memory after each.
image_seq = []

for prompt in tqdm(
    prompt_list,
    desc="generating images from story",
    total=len(prompt_list),
):
    image = generate_image(prompt)
    image_seq += [image]
    clearmem()

In [None]:
def display_image_grid(images, titles=prompt_list, figsize=(15, 15), n_cols=4, title_len=10):
    """Show ``images`` in a grid of ``n_cols`` columns and as many rows as needed.

    Args:
        images: sized, indexable sequence of images accepted by ``Axes.imshow``
            (e.g. the images returned by ``generate_image``).
        titles: optional per-image captions. The default is the ``prompt_list``
            that existed when this cell ran (Python binds defaults at definition
            time). Pass ``None`` for no captions; images past ``len(titles)`` get none.
        figsize: matplotlib figure size in inches.
        n_cols: number of grid columns (default 4).
        title_len: maximum number of caption characters shown (default 10).
    """
    n_images = len(images)
    if n_images == 0:
        return  # nothing to draw; avoid popping up an empty figure

    n_rows = math.ceil(n_images / n_cols)

    fig = plt.figure(figsize=figsize)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)  # padding between tiles

    # Axes are created only for the first n_images slots, so any trailing
    # slots in the last row simply stay blank.
    for i, img in enumerate(images):
        ax = fig.add_subplot(n_rows, n_cols, i + 1)
        ax.imshow(img)
        ax.axis("off")

        if titles is not None and i < len(titles):
            # Collapse newlines/indentation from LLM text so the caption is one line.
            ax.set_title(" ".join(titles[i].split())[:title_len])

    plt.show()


display_image_grid(image_seq)