In [None]:
import torch
from diffusers import FluxPipeline
from tqdm.auto import tqdm 
import matplotlib.pyplot as plt


# --- Text-to-image (FLUX) model ---
model_id = "Freepik/flux.1-lite-8B-alpha"
device = "cuda"
torch_dtype = torch.bfloat16

# --- Sampling settings shared by every generated image ---
n_steps = 28
guidance_scale = 3.5  # Keep guidance_scale at 3.5
seed = 11  # re-applied to a fresh generator on every image for reproducibility

In [None]:
import os

# Load the FLUX text-to-image pipeline and move it to the configured device.
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)

# Keep a local copy of the pipeline in ./flux-lite, but only write it when the
# directory does not exist yet: save_pretrained serializes every component of
# the (8B-parameter) pipeline, so doing it unconditionally rewrites the whole
# checkpoint to disk on every re-run of this cell.
local_pipe_dir = "flux-lite"
if not os.path.isdir(local_pipe_dir):
    pipe.save_pretrained(local_pipe_dir)

def generate_image(prompt):
    """Render ``prompt`` with the FLUX pipeline and save the result as a PNG.

    Reads the module-level ``pipe``, ``seed``, ``n_steps`` and ``guidance_scale``
    from the configuration cells. A fresh CPU generator seeded with ``seed`` is
    created on every call, so the same prompt reproduces the same image.

    Args:
      prompt: Text prompt for the image.

    Returns:
      The generated image (``.images[0]`` of the pipeline output), which is also
      written to the current working directory.
    """
    import hashlib
    import re

    with torch.inference_mode():
        image = pipe(
            prompt=prompt,
            generator=torch.Generator(device="cpu").manual_seed(seed),
            num_inference_steps=n_steps,
            guidance_scale=guidance_scale,
            height=1024,
            width=1024,
        ).images[0]

    # The raw prompt prefix is not a safe file name: LLM output may contain
    # "/", "*", ":" and similar characters, and two prompts that share their
    # first 10 characters would overwrite each other. Keep a readable slug of
    # the prefix and append a short hash of the full prompt for uniqueness.
    slug = re.sub(r"[^\w-]+", "_", prompt[:10]).strip("_") or "image"
    digest = hashlib.sha1(prompt.encode("utf-8")).hexdigest()[:8]
    image.save(f"{slug}_{digest}.png")

    return image

### LLM part: split the story into image prompts

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Small instruction-tuned chat model that splits the story into image prompts.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

In [None]:
# Tokenizer and causal LM used for the story-splitting step.
# torch_dtype="auto" keeps the checkpoint's stored dtype; device_map="auto"
# lets transformers decide where the weights are placed.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
)

In [None]:
# Two-scene story to be split into image prompts. Scenes are separated by a
# blank line, and lines carry no leading indentation or stray quote/comma
# characters, because this text is sent to the LLM verbatim.
story = """\
George, the curious brown monkey, found himself standing in front of a shiny red fire truck one sunny day.
He was in awe of the firefighters in their uniforms and helmets.
Behind them, the city buildings rose high, and a tree swayed gently in the breeze.

Later, George found himself in a messy kitchen.
He couldn't resist touching the underside of an overturned frying pan on the stovetop.
He wondered how it had ended up there amidst the scattered items.
"""

In [None]:

# Chat-formatted request for the LLM. The system prompt pins down the output
# format (one prompt per line, each ending with a period, nothing else)
# because the storybook cell splits the response on a period followed by a
# newline; without a format instruction the model is free to answer with
# headings or numbered lists that would leak into the image prompts.
system_prompt = (
    "You are a story teller, and whatever story you are given, split it into "
    "meaningful parts for image generation prompts. "
    "Write each part as a single self-contained, visually descriptive prompt "
    "on its own line, ending with a period. Output only the prompts, with no "
    "numbering, headings, or commentary."
)

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": story},
]

# Render the conversation with the model's chat template (as text, not token
# ids) and append the assistant turn header so generation starts the reply.
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

In [None]:
# Seed torch before generating. The checkpoint's generation config may enable
# sampling (do_sample), and in that case the prompt split, and therefore every
# image downstream, would change between runs despite the fixed FLUX seed.
torch.manual_seed(seed)

model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
generated_ids = model.generate(**model_inputs, max_new_tokens=512)

# generate() returns prompt + completion; keep only the newly generated tokens
# of the single sequence in the batch.
prompt_len = model_inputs.input_ids.shape[1]
response = tokenizer.decode(generated_ids[0, prompt_len:], skip_special_tokens=True)
print(response)

### Storybook generation

- We take the segments generated by the LLM and feed them one by one into FLUX, then stitch the resulting sequence of images together into one GIF (for now the cells below only show the images as a grid; GIF export is a **TODO**).
- In v1, we will generate 4 versions of each image, or use latent interpolation to create a video-like effect.

In [None]:
import re

# Split the LLM response into one prompt per segment. Segments are expected to
# end with a period followed by a newline (trailing spaces before the newline
# are tolerated). Blank or whitespace-only pieces, such as the empty tail left
# when the response itself ends with ".\n", are dropped so FLUX is never called
# with an empty prompt.
prompt_list = [
    piece.strip()
    for piece in re.split(r"\.[ \t]*\n", response)
    if piece.strip()
]
print(prompt_list)
assert prompt_list, "the LLM response did not yield any prompts"

# Keep an explicit loop so images generated before an interruption are kept.
image_seq = []
for prompt in tqdm(prompt_list, desc="generating images"):
    image_seq.append(generate_image(prompt))

In [None]:
def display_image_grid(images, n_cols=4, figsize=(15, 15)):
    """Show ``images`` in a grid with ``n_cols`` columns and hide all axes.

    Args:
      images: A list of PIL Image objects (any iterable of images that
        ``Axes.imshow`` accepts).
      n_cols: The number of columns in the grid.
      figsize: The figure size in inches.

    Raises:
      ValueError: If ``images`` is empty or ``n_cols`` is less than 1.
    """
    images = list(images)
    if not images:
        raise ValueError("display_image_grid() needs at least one image")
    if n_cols < 1:
        raise ValueError(f"n_cols must be >= 1, got {n_cols}")

    n_rows = -(-len(images) // n_cols)  # ceiling division
    # squeeze=False keeps `axes` a 2-D array even for a single row or column.
    # With the default squeeze=True a one-row grid comes back as a 1-D array
    # (or a bare Axes), and 2-D indexing fails, e.g. for <= 4 images at the
    # default n_cols=4.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    # Turn every cell off first so empty slots in a partially filled last row
    # stay blank, then draw the images in row-major order.
    for ax in axes.flat:
        ax.axis("off")
    for ax, image in zip(axes.flat, images):
        ax.imshow(image)

    plt.tight_layout()
    plt.show()


# Show every generated storybook frame in a 4-column grid.
display_image_grid(image_seq, n_cols=4)