In [1]:
from pydantic import BaseModel
import instructor
from openai import AsyncOpenAI
from rich import print

client = instructor.from_openai(
    client=AsyncOpenAI(),
)


class GeneratedStory(BaseModel):
    setting: str
    plot_summary: str
    art_style: str
    main_character_description: str
    choices: list[str]


story = await client.chat.completions.create(
    messages=[
        {
            "role": "user",
            "content": """
        
        Generate a story set in a post-apocalyptic world where zombines have taken over the world. The user encounters the last human alive!

        Flesh out a description of the setting and the plot summary that is vivid and has a lot of details
        
        Rules
        - Generate 2-4 choices that represent actions which the user can take and make sure it's not a terminal choice, the story is just beginning!
        - Generate a visual description that can be used to generate an image for the story
        - Art style should just describe the style of the image. This should mention the color palette, the style of the characters, and the style of the background.
        """,
        }
    ],
    model="gpt-4o-mini",
    response_model=GeneratedStory,
)

print(story)

In [2]:
import asyncio
from pydantic import field_validator, ValidationInfo


class FinalStoryChoice(BaseModel):
    choice_description: str
    choice_consequences: str
    choices: list["FinalStoryChoice"]


class FinalStory(BaseModel):
    story_setting: str
    plot_summary: str
    choices: list[FinalStoryChoice]


class RewrittenChoice(BaseModel):
    choice_description: str
    choice_consequences: str
    choices: list[str]

    @field_validator("choices")
    def validate_choices(cls, v, info: ValidationInfo):
        original_choice = info.context["choice"]
        if original_choice in v:
            raise ValueError(
                "Choices must be unique - do not repeat the original choice"
            )
        return v


async def rewrite_choice(
    client: instructor.AsyncInstructor,
    choice: str,
    story: GeneratedStory,
    prev_choices: list[str],
    max_depth: int = 3,
) -> FinalStoryChoice:
    task = (
        "This is a terminal choice. End the story and rewrite the choice setting and consequence to reflect that. The story should be wrapped up here."
        if len(prev_choices) == max_depth - 1
        else "Continue the story and suggest 2-4 potential choices if applicable"
    )
    rewritten_choice = await client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": """
            You're an expert story writer. You're given the following choice from a story.

            Choice: {{ choice }}

            Rules
            - Make sure that the choice is consistent with the overall plot of the story.
            - {{task}}
            - If the story is ending, make sure that the story consequence reflects a proper wrap up to the story.
            - choice_description should be in the form of an action (Eg. "Move forward and do X"), be around 1-2 sentences and not mention the main character. It should only describe an action.

            - Choice descriptions should just be 1-2 sentences and reflect a specific user action that the user can take.
            - New Potential choices should be unique and distinct. They should move the story forward in a meaningful way and be distinct from the provided choice above since they are choices to be made after the user has made the choice above.
            - Feel free to terminate the story at this specific choice if it makes sense to do so. Just return an empty list for choices. 
            
            Story Description: 
            - Setting : {{ story.setting }}
            - Plot Summary : {{ story.plot_summary }}

            Prior Choices Made:
            {% for prev_choice in previous_choices %}
            {{ loop.index }}. {{ prev_choice['choice_consequences'] }} : {{ prev_choice['choice_description'] }}
            {% endfor %}
            """,
            }
        ],
        context={
            "choice": choice,
            "story": story,
            "previous_choices": prev_choices,
            "task": task,
        },
        response_model=RewrittenChoice,
    )

    if len(prev_choices) == max_depth - 1:
        return FinalStoryChoice(
            choice_description=rewritten_choice.choice_description,
            choice_consequences=rewritten_choice.choice_consequences,
            choices=[],
        )

    coros = [
        rewrite_choice(
            client,
            potential_choice,
            story,
            prev_choices
            + [
                {
                    "choice_description": rewritten_choice.choice_description,
                    "choice_consequences": rewritten_choice.choice_consequences,
                }
            ],
            max_depth,
        )
        for potential_choice in rewritten_choice.choices
    ]

    rewritten_choices = await asyncio.gather(*coros)

    return FinalStoryChoice(
        choice_description=rewritten_choice.choice_description,
        choice_consequences=rewritten_choice.choice_consequences,
        choices=rewritten_choices,
    )

In [14]:
async def rewrite_story(
    client: instructor.AsyncInstructor, story: GeneratedStory, max_depth: int = 3
):
    coros = [
        rewrite_choice(client, choice, story, [], max_depth) for choice in story.choices
    ]
    rewritten_choices = await asyncio.gather(*coros)
    return FinalStory(
        story_setting=story.setting,
        plot_summary=story.plot_summary,
        choices=rewritten_choices,
    )


rewritten_story = await rewrite_story(client, story, 2)

print(rewritten_story)

  rewritten_story = await rewrite_story(client,story,2)


In [16]:
from common.models import Story, StoryNode, JobStatus


Story = Story(
    title="Zombie Apocalypse",
    description=rewritten_story.story_setting,
    user_id="123",
    status=JobStatus.COMPLETED,
)


In [19]:
def count_nodes(nodes: list[FinalStoryChoice]):
    return len(nodes) + sum(count_nodes(node.choices) for node in nodes)

count_nodes(story.choices)

75

In [11]:
from pathlib import Path
import json

with open("./story.json") as f:
    rewritten_story = FinalStory(**json.load(f))

In [23]:
from typing import Optional
from uuid import UUID, uuid4
from common.models import StoryNode

def flatten_and_format_nodes(
    story_id,
    user_id,
    nodes: list[FinalStoryChoice],
    parent_id: Optional[UUID],
    acc: list[StoryNode],
) -> list[StoryNode]:
    for node in nodes:
        node_id = uuid4()
        new_node = StoryNode(
            id=node_id,
            story_id=story_id,
            parent_node_id=parent_id,
            choice_text=node.choice_description,
            image_url="",
            setting=node.choice_description,
            user_id=user_id,
        )
        acc = acc + [new_node]
        if node.choices:
            acc = flatten_and_format_nodes(
                story_id, user_id, node.choices, node_id, acc
            )
    return acc

nodes = flatten_and_format_nodes(
    "123",
    "123",
    rewritten_story.choices,
    None,
    [],
)



75

In [6]:
from pathlib import Path
import json

with open("./story.json") as f:
    story = FinalStory(**json.load(f))

FinalStory(story_setting="We're finally moving into the new place! The sun shines brightly in the clear blue sky, reflecting off the polished surfaces of the newly constructed home. The scent of fresh paint and cut grass pervades the air, creating a sense of anticipation. As we step through the front door, we're greeted by the echo of our footsteps on the hardwood floors and the smell of new beginnings. The living room is spacious, filled with sunlight that dances through the large windows. The walls are painted a soft cream, providing a warm, welcoming backdrop for the colorful memories we're about to create. Outside, a lush green yard stretches out, bordered by blooming flowers and an oak tree that promises to be a place for picnics and laughter. This is a blank canvas waiting for us to paint our lives upon it.", plot_summary="Today is the big day, the culmination of our long journey to find a perfect home. As we unload boxes and organize our belongings, excitement fills the air. Eve

In [7]:
def count_nodes(nodes: list[FinalStoryChoice]):
    return len(nodes) + sum(count_nodes(node.choices) for node in nodes)


print(count_nodes(story.choices))
