# Introduction

## Premise

This is a utility for converting a (text-only) short children's story into a PDF with vivid illustrations for each scene.

It first breaks the story down into an XML format laying out the main character, settings, and scenes. Then it takes this XML object and converts each of the scenes into an image with consistent styling, characters, and settings.

## Status

**It is very much a work in progress!** As of this writing, not all functionality is implemented (e.g. nothing relating to PDFs; the output is currently an HTML file). The primary goal was to practice building a **concrete application that involves significant prompt engineering**. The results so far leave a lot to be desired, but the tool performs the basic task of illustrating individual scenes based only on the story's text, and does so with some degree of consistency between images.

## Internals

### XML Format

The XML format used to split the story into scenes and to provide consistent descriptions for the image generation step follows a strict structure. It is parsed with the `xml.etree.ElementTree` library. Details of each tag are spelled out in the `Prompts` section below, so I won't go into detail here, but the basic structure is:

```xml
<Story>
    <Characters>
        <Character name="John Doe">{visual description}</Character>
        ...
    </Characters>
    <Settings>
        <Setting>{visual description}</Setting>
        ...
    </Settings>
    <Scenes>
        <Scene>
            <Text>{original text of scene}</Text>
            <Setting>{visual description (from a Settings tag child above)}</Setting>
            <Action>{visual description (introducing characters with descriptions from Characters tag children above)}</Action> 
        </Scene>
        ...
    </Scenes>
</Story>
```
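
As a rough illustration of how a document in this shape can be read with `xml.etree.ElementTree`, here is a minimal sketch. The sample story content is made up for the example, and the variable names are not part of the notebook.

```python
import xml.etree.ElementTree as ET

# Hypothetical sample document following the structure above.
SAMPLE_STORY_XML = """
<Story>
    <Characters>
        <Character name="Red">A small girl in a bright red hooded cape</Character>
    </Characters>
    <Settings>
        <Setting>A shady forest path lined with tall pines</Setting>
    </Settings>
    <Scenes>
        <Scene>
            <Text>Red walked into the woods.</Text>
            <Setting>A shady forest path lined with tall pines</Setting>
            <Action>Red strolls along the path carrying a basket</Action>
        </Scene>
    </Scenes>
</Story>
"""

story = ET.fromstring(SAMPLE_STORY_XML)

# Map each character's "name" attribute to its visual description.
characters = {c.get("name"): c.text for c in story.findall("./Characters/Character")}

# Collect one (setting, action) pair per scene.
scenes = [(s.findtext("Setting"), s.findtext("Action")) for s in story.findall("./Scenes/Scene")]

print(characters)
print(len(scenes))
```

The same XPath-style queries (`./Characters/Character`, `./Scenes/Scene`) are the ones the code further down relies on.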

## Usage

This is evolving as I continue to implement functionality, so I will try to keep this description up to date.

The steps are:

1. Paste your OpenAI API key into a blank file at the path `~/openai_api_key`
1. Add the story as a text file in the `input` folder (a sibling of this file) and note its filename, e.g. "Little Red Riding Hood"
1. Run all cells
1. Add a cell calling `convert_story_to_html_illustration()` on the story's filename, e.g. `convert_story_to_html_illustration("Little Red Riding Hood")`
1. Open the generated HTML file in the `output` folder (also a sibling of this file) in your browser, e.g. `output/Little Red Riding Hood.html`

# Setup

## Imports

In [None]:
from openai import OpenAI

import re
import html
import os
import math
import time
from copy import deepcopy
from typing import Any
from datetime import datetime

import xml.etree.ElementTree as ET

import asyncio
import nest_asyncio
from functools import partial

## Utils

In [None]:
def clone_node(node: ET.Element) -> ET.Element:
    return ET.fromstring(ET.tostring(node))

def base64_to_html_img(base64_string: str):
    return f'<img src="data:image/png;base64,{base64_string}"/>'

def read_file_from_path(path):
    with open(os.path.expanduser(path), 'r', encoding='utf-8') as file:
        return file.read().strip()
    
def read_input(filename: str) -> str:
    return read_file_from_path(f"input/{filename}")

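def write_output(filename: str, content: str):
    # Write as UTF-8 to match read_file_from_path; the platform default encoding may not handle every character.
    with open(f"output/{filename}", 'w', encoding='utf-8') as file:
        file.write(content)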

## Debugging

In [None]:
class DebugMap(dict):
    def __init__(self, *args, **kwargs):
        super(DebugMap, self).__init__(*args, **kwargs)
        self.counter_map: dict[str, int] = {}

    def debug(self, log_entry_name, debug_item):
        entry_dict = {
            'item': debug_item,
            'timestamp': datetime.now(),
        }
        if log_entry_name in self:
            self.counter_map[log_entry_name] = self.counter_map.get(log_entry_name, 1) + 1
            self[f"{log_entry_name} (#{self.counter_map[log_entry_name]})"] = entry_dict
        else:
            self[log_entry_name] = entry_dict

    def filter(self, key_pattern):
        return {key: self[key] for key in self if bool(re.search(key_pattern, key, re.IGNORECASE))}

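    def dump(self):
        for key in self:
            print(f'---KEY: {key}---')
            print('---VALUE---')
            # Each entry is a dict; the logged object itself lives under 'item'.
            entry = self[key]
            print(f"timestamp: {entry['timestamp']}")
            item = entry['item']
            if isinstance(item, ET.Element):
                ET.dump(item)
            else:
                print(item)

    def get_item(self, key: str):
        if key not in self:
            return None
        return self[key]['item']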

    def dump_keys(self):
        print('\n'.join(self.keys()))

In [None]:
debug_map = DebugMap()

## API Wrapper

In [None]:
# Wrap RPCs to run manual tests without wasting API calls.

class ApiWrapper:
    def __init__(self, completion_mock: str = None, image_mock: str = None):
        self.completion_mock = completion_mock
        self.image_mock = image_mock
        
    def create_completion(self, **kwargs):
        if self.completion_mock is None:
            response = client.chat.completions.create(**kwargs)
            debug_map.debug('Completion raw response (ApiWrapper.create_completion 1)', deepcopy(response))
            return response.choices[0].message.content
        time.sleep(2)
        debug_map.debug('Completion args for mocked RPC (ApiWrapper.create_completion 1)', kwargs)
        return self.completion_mock
    
    def generate_image(self, **kwargs):
        if self.image_mock is None:
            response = client.images.generate(**kwargs)
            debug_map.debug('Image raw response (ApiWrapper.generate_image 1)', deepcopy(response))
            return response.data[0].b64_json
        time.sleep(2)
        debug_map.debug('Image args for mocked RPC (ApiWrapper.generate_image 1)', kwargs)
        return self.image_mock

    def reset(self):
        self.completion_mock = None
        self.image_mock = None

In [None]:
api_wrapper = ApiWrapper()

## Config

In [None]:
openai_api_key = read_file_from_path('~/openai_api_key')

client = OpenAI(api_key=openai_api_key)

In [None]:
nest_asyncio.apply()

## Constants

In [None]:
DEFAULT_IMAGE_LIMIT = 12
MAX_IMAGES_PER_MINUTE = 5
MINUTE_WITH_CUSHION = 90 # Seconds
INDIVIDUAL_IMAGE_REQUEST_WAIT = 15 # Seconds
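
To get a feel for what these constants imply, the following sketch estimates the total time spent sleeping for a given number of images in each mode. It assumes, as in the loops further down, that there is no wait before the first request or batch, and it ignores the time the API calls themselves take. The two helper functions are hypothetical and not part of the notebook.

```python
import math

# Values copied from the constants cell above.
MAX_IMAGES_PER_MINUTE = 5
MINUTE_WITH_CUSHION = 90  # Seconds
INDIVIDUAL_IMAGE_REQUEST_WAIT = 15  # Seconds

def batched_wait_seconds(image_count: int) -> int:
    """Total sleep time when requests are sent in batches of MAX_IMAGES_PER_MINUTE."""
    batch_count = math.ceil(image_count / MAX_IMAGES_PER_MINUTE)
    return max(batch_count - 1, 0) * MINUTE_WITH_CUSHION

def individual_wait_seconds(image_count: int) -> int:
    """Total sleep time when requests are sent one at a time."""
    return max(image_count - 1, 0) * INDIVIDUAL_IMAGE_REQUEST_WAIT

print(batched_wait_seconds(12))     # 3 batches, so 2 waits of 90 s = 180
print(individual_wait_seconds(12))  # 11 waits of 15 s = 165
```

For a story at the default limit of 12 scenes, the two modes therefore sleep for a similar total time; the batched mode only pulls ahead if the individual API calls are slow.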

## Prompts

In [None]:
GENERAL_IMAGE_PROMPT = (
    "Illustrate the following scene as a Matisse oil painting. Paint should cover most of the canvas. It should be painted with broad brushes and " +
    "bright colors from a basic palette. It should have crisp, well-defined edges. Focus more on characters than setting"
)
    
def get_image_prompt(image_prompt_xml_string: str) -> str:
    return f'{GENERAL_IMAGE_PROMPT}: {image_prompt_xml_string}'

In [None]:
XML_PROMPT = """
You will take a story from the user and build a structured XML object to describe it. All tag content must be enclosed between open and close tags.
The XML root tag contains 3 tags, each described below: <Settings>, <Characters>, and <Scenes>.

The <Settings> tag contains a list of <Setting> tags. Each <Setting> tag contains a short description of the appearance of a distinct location
in which at least one Scene takes place. The location should be specific enough that all the characters in the scene are clearly visible. You should add 
details that are not specified in the story to make a description that is easier to visualize as long as they do not contradict the story. Be specific 
about shapes, colors, indoor vs. outdoor, etc., subject to the constraint that it must be between 100 and 150 characters.

The <Characters> tag contains a list of <Character> tags. Each <Character> tag contains a visual description of a major character in the story. 
Be specific around colors, clothing, age, and gender, subject to the constraints that it must be between 150 and 200 characters and must not contradict 
the story. Each <Character> tag has an attribute "name" that is the character's most commonly used name.

The <Scenes> tag contains a list of <Scene> tags. Each <Scene> tag represents a different scene in the story. Each Scene should be as long as possible 
because you want to minimize the number of Scenes. A new Scene only begins when either (1) the location changes or (2) a chunk of time in the narrative 
passes without any new action. Moreover, if a scene is shorter than three sentences, it should be combined with the next scene into a single <Scene> tag. 
A scene should never end in the middle of a conversation. Each <Scene> tag contains 3 subtags each described below: <Text>, <Setting>, and <Action>. 
The <Text> tag is the substring of the story narrating the Scene. Each <Scene> tag's <Text> tag is disjoint from every other <Scene> tag's <Text> tag,
and they partition the full text of the story. The <Setting> tag is taken from the <Setting> tag (nested in the <Settings> tag) that corresponds to the 
location of the Scene, copied exactly. If the Scene spans multiple locations, pick the setting in which the longest substring of Text takes place. The 
<Action> tag visually describes the action in the part of the Scene taking place in the Setting in under 400 characters. The <Action> tag refers to each of the 
characters by the "name" attribute of the corresponding <Character> tag nested in the <Characters> tag, wrapped in triple brackets, i.e. [[[ and ]]].

For all visual descriptions, i.e. the <Setting>, <Character>, and <Action> tags, you should add details that are not specified in the story to make it 
easier to visualize, but you must not contradict descriptions in the story. Be concise: never waste space on non-visual details.
"""

# Main Code

## XML Prep

### Helpers

In [None]:
def format_character_introduction(character: ET.Element) -> str:
    # Use double quotes outside so the nested single quotes also parse on Python versions before 3.12.
    return f"{character.get('name')} ({character.text.strip('.')})"

def substitute_character_descriptions(story_node: ET.Element) -> ET.Element:
    '''Modifies input and returns result'''
    debug_map.debug('Story node (substitute_character_descriptions 1)', clone_node(story_node))
    character_xml_string_map = {character.get('name'): format_character_introduction(character) for character in story_node.findall('./Characters/Character')}
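    for description in story_node.findall('./Scenes/Scene/Action'):
        for character_name, introduction in character_xml_string_map.items():
            # Escape the name so regex metacharacters in it match literally; the lambda keeps backslashes
            # in the replacement from being read as escapes. Only the first mention is expanded.
            description.text = re.sub(rf'\[\[\[{re.escape(character_name)}\]\]\]', lambda _match: introduction, description.text, count=1)
        # Strip the brackets from any remaining mentions.
        description.text = re.sub(r'(\[\[\[|\]\]\])', '', description.text)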
    debug_map.debug('Story node (substitute_character_descriptions 2)', clone_node(story_node))
    return story_node
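
To make the substitution step easier to follow, here is a standalone sketch of the same idea on plain strings rather than XML nodes. The helper name `introduce_characters` and the sample text are made up for illustration; `re.escape` is used so that names containing punctuation match literally.

```python
import re

def introduce_characters(action_text: str, introductions: dict[str, str]) -> str:
    """Expand the first [[[Name]]] marker per character, then strip any remaining brackets."""
    for name, introduction in introductions.items():
        pattern = rf"\[\[\[{re.escape(name)}\]\]\]"
        # A function replacement avoids backslash handling in the replacement text.
        action_text = re.sub(pattern, lambda _match: introduction, action_text, count=1)
    return re.sub(r"\[\[\[|\]\]\]", "", action_text)

result = introduce_characters(
    "[[[Red]]] waves at the wolf, then [[[Red]]] runs away",
    {"Red": "Red (a small girl in a red cape)"},
)
print(result)
# Red (a small girl in a red cape) waves at the wolf, then Red runs away
```

Only the first mention carries the full description, which keeps the image prompt short while still telling the image model what each character looks like.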

### Main

In [None]:
def convert_text_to_xml(story_text: str, prompt: str = XML_PROMPT) -> ET.Element:
    debug_map.debug('Story text (convert_text_to_xml 1)', story_text)
    response_content = api_wrapper.create_completion(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": prompt
            },
            {
                "role": "user",
                "content": re.sub(r'\n+', '\n', story_text)
            }
        ]
    )

    stripped_xml_string = re.sub(r'(?ms)[^<]*(<.*>).*', r'\1', response_content)
    debug_map.debug('Story stripped XML string (convert_text_to_xml 1)', stripped_xml_string)
    story_node = substitute_character_descriptions(ET.fromstring(stripped_xml_string))
    debug_map.debug('Story node (convert_text_to_xml 1)', clone_node(story_node))
    return story_node
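
The sketch below shows what the stripping regex does to a reply that wraps the XML in extra text, and then checks the parsed result against the length limits stated in `XML_PROMPT`. The sample reply and the `find_length_violations` helper are hypothetical; the notebook itself does not validate the model's output this way.

```python
import re
import xml.etree.ElementTree as ET

# Hypothetical model reply: a short preamble, then the XML inside a Markdown code fence.
reply = "Here is the XML:\n```xml\n<Story><Settings><Setting>Too short</Setting></Settings></Story>\n```"

# Same pattern as in convert_text_to_xml: keep everything from the first "<" to the last ">".
stripped = re.sub(r'(?ms)[^<]*(<.*>).*', r'\1', reply)
story = ET.fromstring(stripped)

# (XPath, minimum length, maximum length), based on the limits stated in XML_PROMPT.
LENGTH_LIMITS = [
    ("./Settings/Setting", 100, 150),
    ("./Characters/Character", 150, 200),
    ("./Scenes/Scene/Action", 0, 399),
]

def find_length_violations(story_node: ET.Element) -> list[str]:
    """Return one message per description whose length falls outside its limits."""
    violations = []
    for path, minimum, maximum in LENGTH_LIMITS:
        for node in story_node.findall(path):
            length = len((node.text or "").strip())
            if not minimum <= length <= maximum:
                violations.append(f"{path}: {length} characters (expected {minimum}-{maximum})")
    return violations

violations = find_length_violations(story)
print(story.tag)
print(violations)
```

Note that the regex starts at the first `<` in the reply, so a preamble that itself contains a `<` would end up in the string handed to the parser.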

## Image Prep

### Helpers

In [None]:
def parse_image_prompt(scene_node: ET.Element) -> str:
    scene_clone = clone_node(scene_node)
    debug_map.debug('Scene node (parse_image_prompt 1)', scene_clone)
    scene_clone.remove(scene_clone.find('./Text'))
    escaped_xml_string = ET.tostring(scene_clone).decode()
    whitespace_trimmed_unescaped_xml_string = re.sub(r'(?ms)>\s*<', '><', html.unescape(escaped_xml_string))
    debug_map.debug('Image prompt clean XML string (parse_image_prompt 1)', whitespace_trimmed_unescaped_xml_string)
    return whitespace_trimmed_unescaped_xml_string

def get_all_image_generation_inputs(story_node: ET.Element) -> list[dict]:
    debug_map.debug('Story node (get_all_image_generation_inputs 1)', clone_node(story_node))
    image_prompts = [parse_image_prompt(scene_node) for scene_node in story_node.findall('./Scenes/Scene')]
    return [{
        "model": 'dall-e-3',
        "prompt": image_prompt,
        "size": '1024x1024',
        "quality": "standard",
        "response_format": "b64_json",
    } for image_prompt in image_prompts]

In [None]:
def batch_image_requests(image_requests: list[dict]) -> list[list[dict]]:
    debug_map.debug('Image requests (batch_image_requests 1)', deepcopy(image_requests))
    batch_count = math.ceil(len(image_requests) / MAX_IMAGES_PER_MINUTE)
    debug_map.debug('Image request batch count (batch_image_requests 1)', batch_count)
    # Step through the request list one batch at a time; slicing clamps at the end of the list.
    return [image_requests[i:i + MAX_IMAGES_PER_MINUTE] for i in range(0, len(image_requests), MAX_IMAGES_PER_MINUTE)]

async def convert_image_request_batch(request_batch: list) -> list[str]:
    debug_map.debug('Image prompt batch (convert_image_request_batch 1)', deepcopy(request_batch))
    loop = asyncio.get_running_loop()
    tasks = [loop.run_in_executor(None, partial(api_wrapper.generate_image, **request)) for request in request_batch]
    base64_strings = await asyncio.gather(*tasks)
    debug_map.debug('Image batch base64 strings (convert_image_request_batch 1)', deepcopy(base64_strings))
    return base64_strings

async def convert_requests_to_base64_images_in_batches(image_requests: list[dict]) -> list[str]:
    image_request_batches = batch_image_requests(image_requests)
    debug_map.debug('Image request batches (convert_requests_to_base64_images_in_batches 1)', deepcopy(image_request_batches))
    converted_image_base64_strings = []
    wait = 0
    while len(image_request_batches) > 0:
        # No wait before the first batch; afterwards, pause long enough to stay under the rate limit.
        await asyncio.sleep(wait)
        batch = image_request_batches.pop(0)
        converted_image_base64_strings += await convert_image_request_batch(batch)
        wait = MINUTE_WITH_CUSHION
    debug_map.debug('Converted base64 image strings (convert_requests_to_base64_images_in_batches 1)', deepcopy(converted_image_base64_strings))
        
    return converted_image_base64_strings
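
The batching approach in this cell can be sketched in isolation as follows: split the work into fixed-size chunks, run each chunk's blocking calls concurrently in the default thread pool, and sleep between chunks. Here `fake_generate_image`, `chunk`, and `run_in_batches` are stand-ins invented for the example, and the waits are shortened so it finishes quickly.

```python
import asyncio
import time
from functools import partial

def chunk(items: list, size: int) -> list[list]:
    """Split items into consecutive chunks of at most `size` elements."""
    if size <= 0:
        raise ValueError("size must be positive")
    return [items[start:start + size] for start in range(0, len(items), size)]

def fake_generate_image(prompt: str) -> str:
    """Stand-in for a blocking API call (hypothetical)."""
    time.sleep(0.01)
    return f"image for {prompt}"

async def run_in_batches(prompts: list[str], batch_size: int, wait_seconds: float) -> list[str]:
    loop = asyncio.get_running_loop()
    results = []
    for index, batch in enumerate(chunk(prompts, batch_size)):
        if index > 0:
            # Pause between batches, but not before the first one.
            await asyncio.sleep(wait_seconds)
        # Run the blocking calls of one batch concurrently in the default executor.
        tasks = [loop.run_in_executor(None, partial(fake_generate_image, prompt)) for prompt in batch]
        # gather returns results in the order the tasks were passed in.
        results += await asyncio.gather(*tasks)
    return results

prompts = [f"scene {i}" for i in range(7)]
results = asyncio.run(run_in_batches(prompts, batch_size=5, wait_seconds=0.01))
print(len(results))
```

Because `asyncio.gather` preserves input order, the images come back in scene order even though the calls within a batch may finish in any order. Inside a notebook, `asyncio.run` needs `nest_asyncio.apply()` as done in the Config section.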

In [None]:
async def convert_prompts_to_base64_images_individually(image_requests: list[dict]) -> list[str]:
    converted_image_base64_strings = []
    wait = 0
    for request in image_requests:
        await asyncio.sleep(wait)
        converted_image_base64_strings.append(api_wrapper.generate_image(**request))
        wait = INDIVIDUAL_IMAGE_REQUEST_WAIT
    debug_map.debug('Converted base64 image strings (convert_prompts_to_base64_images_individually 1)', deepcopy(converted_image_base64_strings))
        
    return converted_image_base64_strings

### Main

In [None]:
async def convert_xml_to_base64_images_async(story_node: ET.Element, limit, batch_requests: bool) -> list[str]:
    debug_map.debug('Story node (convert_xml_to_base64_images_async 1)', clone_node(story_node))
    image_requests = get_all_image_generation_inputs(story_node)
    debug_map.debug('Image requests (convert_xml_to_base64_images_async 1)', deepcopy(image_requests))
    if len(image_requests) > limit:
        raise Exception(f"The story had {len(image_requests)} scenes in total, which exceeds the limit of {limit} so it won't run unless you raise the limit. NOTE: The default limit is {DEFAULT_IMAGE_LIMIT}.")

    converted_image_base64_strings = await (convert_requests_to_base64_images_in_batches(image_requests) if batch_requests else convert_prompts_to_base64_images_individually(image_requests))
    debug_map.debug('Converted base64 image strings (convert_xml_to_base64_images_async 1)', deepcopy(converted_image_base64_strings))
        
    return converted_image_base64_strings

def convert_xml_to_base64_images(story_node: ET.Element, limit: int, batch_requests: bool = False) -> list[str]:
    debug_map.debug('Story node (convert_xml_to_base64_images 1)', clone_node(story_node))
    debug_map.debug('Image limit (convert_xml_to_base64_images 1)', limit)
    return asyncio.run(convert_xml_to_base64_images_async(story_node, limit, batch_requests))


### TEST ALL THIS WITH DUMMY CALLS TO API!!!

## HTML Prep

### Helpers

In [None]:
HEAD_TAG = """
<head>
    <link rel="preconnect" href="https://fonts.googleapis.com">
    <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
    <link href="https://fonts.googleapis.com/css2?family=Ruge+Boogie&display=swap" rel="stylesheet">
    <style>
        .container {
            display: flex;
            justify-content: center;
            padding: 150px;
            background-image: url('../assets/old_paper.jpg');
            background-size: cover;
        }
        
        .page {
            display: flex;
            flex-direction: column;
            align-items: center;
            gap: .2in;
            width: 1024px;
        }
        
        .text {
            font-size: .2in;
            font-family: "Ruge Boogie", cursive;
            font-weight: 600;
            width: 60%;
            text-align: center;
        }
    </style>
</head>
"""

In [None]:
def combine_scene_text_and_base64_image_into_html(scene_text: str, base64_image: str) -> str:
    debug_map.debug('Scene text (combine_scene_text_and_base64_image_into_html 1)', scene_text)
    debug_map.debug('Base64 image (combine_scene_text_and_base64_image_into_html 1)', base64_image)
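    # Escape the scene text so characters such as <, >, and & in the story cannot break the page markup.
    return f"""
<div class="container">
    <div class="page">
        <div class="illustration">{base64_to_html_img(base64_image)}</div>
        <div class="text">{html.escape(scene_text)}</div>
    </div>
</div>
"""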

### Main

In [None]:
def convert_xml_and_base64_images_to_html_book_pages(story_node: ET.Element, base64_images: list[str]) -> list[str]:
    debug_map.debug('Story node (convert_xml_and_base64_images_to_html_book_pages 1)', clone_node(story_node))
    debug_map.debug('Base64 images (convert_xml_and_base64_images_to_html_book_pages 1)', deepcopy(base64_images))
    scene_texts = [text_node.text for text_node in story_node.findall('./Scenes/Scene/Text')]
    assert len(scene_texts) == len(base64_images), f"The number of scenes in the story ({len(scene_texts)}) did not match the number of base64 illustrations ({len(base64_images)})"
    
    page_htmls = [combine_scene_text_and_base64_image_into_html(st, bi) for st, bi in zip(scene_texts, base64_images)]
    debug_map.debug('Page HTMLs (convert_xml_and_base64_images_to_html_book_pages 1)', deepcopy(page_htmls))
    return page_htmls
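
One thing worth keeping in mind when assembling pages is that story text can contain characters with special meaning in HTML. The sketch below shows how `html.escape` handles them in a simplified page template; `render_page` and the sample inputs are hypothetical and only loosely mirror the helper above.

```python
import html

def render_page(scene_text: str, image_tag: str) -> str:
    """Hypothetical, simplified page template: the image tag is trusted, the text is escaped."""
    return (
        '<div class="page">'
        f'<div class="illustration">{image_tag}</div>'
        f'<div class="text">{html.escape(scene_text)}</div>'
        '</div>'
    )

page = render_page(
    '"Who is it?" asked Grandma <softly> & waited.',
    '<img src="data:image/png;base64,AAAA"/>',
)
print(page)
```

The escaped text renders in the browser exactly as written in the story, while the markup around it stays intact.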

## Integration

In [None]:
def convert_story_to_html_illustration(story_filename: str, limit: int = DEFAULT_IMAGE_LIMIT):
    debug_map.debug('Story filename (convert_story_to_html_illustration 1)', story_filename)
    story_text = read_input(story_filename)
    debug_map.debug('Story text (convert_story_to_html_illustration 1)', story_text)
    story_node = convert_text_to_xml(story_text)
    debug_map.debug('Story node (convert_story_to_html_illustration 1)', clone_node(story_node))
    illustrations_base64 = convert_xml_to_base64_images(story_node, limit)
    debug_map.debug('Illustrations as base64 images (convert_story_to_html_illustration 1)', deepcopy(illustrations_base64))
        
    html_book_pages = convert_xml_and_base64_images_to_html_book_pages(story_node, illustrations_base64)
    debug_map.debug('HTML book pages (convert_story_to_html_illustration 1)', deepcopy(html_book_pages))
    
    write_output(f"{story_filename}.html", HEAD_TAG + '<hr />'.join(html_book_pages))
    print(f"Operation complete! Open output/{story_filename}.html in your browser to view results")
    

## Manual Testing