## Dramatron

<img src="./dramatron.png" />

Implementation of [Dramatron](https://arxiv.org/pdf/2209.14958.pdf), DeepMind's hierarchical story-generation prompting system. It starts with a logline and progressively builds the elements of a story (title, characters, plot outline, and locations), ending with dialogue for each scene.
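As a rough mental model, the hierarchy can be sketched without any library. Each stage is one model call, and its prompt includes the logline plus everything generated so far. The `generate` callable is a hypothetical stand-in for a real model client, and the stage names are illustrative. This is not the promptx API or the paper's code.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for a language model call: prompt text in, generated text out.
Generate = Callable[[str], str]

# Illustrative stage names, in the order Dramatron builds a story from a logline.
STAGES: List[str] = ["title", "characters", "outline", "locations", "scenes"]


def run_pipeline(logline: str, generate: Generate) -> Dict[str, str]:
    """Run each stage in order; every prompt includes all earlier outputs."""
    context: Dict[str, str] = {"logline": logline}
    for stage in STAGES:
        # Later stages are conditioned on the logline and all previous stages.
        prior = "\n".join(f"{key}: {value}" for key, value in context.items())
        context[stage] = generate(f"{prior}\nWrite the {stage}.")
    return context


def echo_model(prompt: str) -> str:
    """Fake model that reports how many context lines it received."""
    lines = prompt.count("\n")
    return f"<{lines} context lines>"


if __name__ == "__main__":
    for key, value in run_pipeline("A farm boy joins a rebellion.", echo_model).items():
        print(f"{key}: {value}")
```

The real pipeline writes dialogue once per scene rather than in a single call; the sketch collapses that into one stage to keep the conditioning pattern visible.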

In [1]:
from promptx import load

load()

2023-10-31 03:26:17.796 | INFO     | promptx:load:137 - loading local app from /home/rjl/promptx/examples/dramatron
2023-10-31 03:26:17.799 | INFO     | promptx:load:140 - loaded environment variables from /home/rjl/promptx/examples/dramatron/.env
2023-10-31 03:26:17.800 | INFO     | promptx:load:141 - API KEY wMeGC


[1m<[0m[1;95mApp[0m[39m local [0m[33mpath[0m[39m=[0m[35m/home/rjl/promptx/examples/[0m[95mdramatron[0m[1m>[0m

First, we define the Pydantic types used to generate the story.

In [2]:
from typing import List, Optional

from promptx import prompt, store, query
from promptx.collection import Entity


class Character(Entity):
    name: str
    description: str


class Location(Entity):
    name: str
    description: str


class SceneBeat(Entity):
    location: str
    plot_element: str
    description: str


class Story(Entity):
    logline: str
    # Filled in by later generation steps, so they start out as None.
    title: Optional[str] = None
    outline: Optional[List[SceneBeat]] = None
    characters: Optional[List[Character]] = None
    locations: Optional[List[Location]] = None

    def __init__(self, logline, **kwargs):
        # Allows the logline to be passed positionally, e.g. Story('...').
        super().__init__(logline=logline, **kwargs)

Before generating a story, let's define some examples to use as few-shot examples in the prompts. We'll use the Star Wars examples from the paper.
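promptx builds the actual prompt text from `Example` objects internally, and this notebook doesn't depend on its exact format. As an illustration only, assuming a simple Input/Output layout, a few-shot prompt is an instruction, then solved examples, then the new input left open for the model to complete. `build_few_shot_prompt` is a hypothetical helper, not part of promptx.

```python
from typing import List, Tuple


def build_few_shot_prompt(
    instruction: str, examples: List[Tuple[str, str]], query: str
) -> str:
    """Render an instruction, input/output example pairs, and a final open-ended input."""
    parts = [instruction.strip(), ""]
    for example_input, example_output in examples:
        parts.append(f"Input: {example_input.strip()}")
        parts.append(f"Output: {example_output.strip()}")
        parts.append("")  # blank line between examples
    # The final input has no output; the model is expected to complete it.
    parts.append(f"Input: {query.strip()}")
    parts.append("Output:")
    return "\n".join(parts)
```

The solved examples show the model both the task and the expected output shape, which is why the Star Wars data below is written out in the same structure as the objects we want generated.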

In [3]:
star_wars = Story(
    title="Star Wars",
    logline='''
    A science-fiction fantasy about a naive but ambitious farm boy from a
    backwater desert who discovers powers he never knew he had when he teams 
    up with a feisty princess, a mercenary space pilot and an old wizard warrior 
    to lead a ragtag rebellion against the sinister forces of the evil Galactic 
    Empire.
    ''',
    characters=[
        Character(
            name='Luke Skywalker',
            description='''
            Luke Skywalker is the hero. A naive farm boy, he will 
            discover special powers under the guidance of mentor 
            Ben Kenobi.
            ''',
        ),
        Character(
            name='Ben Kenobi',
            description='''
            Ben Kenobi is the mentor figure. A recluse Jedi warrior, 
            he will take Luke Skywalker as apprentice.
            ''',
        ),
        Character(
            name='Darth Vader',
            description='''
            Darth Vader is the antagonist. As a commander of the 
            evil Galactic Empire, he controls space station The 
            Death Star.
            ''',
        ),
        Character(
            name='Princess Leia',
            description='''
            Princess Leia holds the plans of the Death Star. She is 
            feisty and brave. She will become Luke's friend.
            ''',
        ),
        Character(
            name='Han Solo',
            description='''
            Han Solo is a brash mercenary space pilot of the 
            Millennium Falcon and a friend of Chewbacca. He will
            take Luke on his spaceship.
            ''',
        ),
        Character(
            name='Chewbacca',
            description='''
            Chewbacca is a furry and trustful monster. He is a friend 
            of Han Solo and a copilot on the Millennium Falcon.
            ''',
        ),
    ],
    locations=[
        Location(
            name='Farm',
            description='The farm is on a desert planet, where Luke Skywalker lives.',
        ),
    ],
    outline=[
        SceneBeat(
            location='A farm on planet Tatooine',
            plot_element='The Ordinary World',
            description='Luke Skywalker is living a normal and humble life as a farm boy on his home planet.',
        ),
        SceneBeat(
            location='Desert of Tatooine',
            plot_element='Call to Adventure',
            description='''
            Luke is called to his adventure by robot R2-D2 and Ben Kenobi. 
            Luke triggers R2-D2's message from Princess Leia and is intrigued 
            by her message. When R2-D2 escapes to find Ben Kenobi, Luke follows 
            and is later saved by Kenobi, who goes on to tell Luke about his Jedi 
            heritage. Kenobi suggests that he should come with him.
            '''
        ),
        SceneBeat(
            location="Ben Kenobi's farm",
            plot_element='Refusal of the Call',
            description='''
            Luke refuses Kenobi, telling him that he can take Kenobi and the 
            droids as far as Mos Eisley Spaceport - but he can't possibly leave 
            his Aunt and Uncle behind for some space adventure.
            ''',
        ),
        SceneBeat(
            location='A farm on planet Tatooine',
            plot_element='Crossing the First Threshold',
            description='''
            When Luke discovers that the stormtroopers searching for the droids 
            would track them to his farm, he rushes to warn his Aunt and Uncle, 
            only to discover them dead by the hands of the Empire. When Luke 
            returns to Kenobi, he pledges to go with him to Alderaan and learn 
            the ways of the Force like his father before him.
            ''',
        ),
        SceneBeat(
            location='On spaceship The Millennium Falcon',
            plot_element='Tests, Allies, Enemies',
            description='''
            After Luke, Kenobi, and the droids hire Han Solo and Chewbacca to 
            transport them onto Alderaan, Kenobi begins Luke's training in the 
            ways of the Force. Wielding his father's lightsaber, Kenobi 
            challenges Luke. At first, he can't do it. But then Kenobi tells 
            Luke to reach out and trust his feelings. Luke succeeds.
            ''',
        ),
        SceneBeat(
            location='On spaceship The Millennium Falcon',
            plot_element='Approach to the Inmost Cave',
            description='''
            The plan to defeat the Galactic Empire is to bring the Death Star 
            plans to Alderaan so that Princess Leia's father can take them to 
            the Rebellion. However, when they arrive within the system, the 
            planet is destroyed. They come across the Death Star and are pulled 
            in by a tractor beam, now trapped within the Galactic Empire.
            ''',
        ),
        SceneBeat(
            location='On spacestation The Death Star',
            plot_element='Ordeal',
            description='''
            As Kenobi goes off to deactivate the tractor beam so they can escape, 
            Luke, Han, and Chewbacca discover that Princess Leia is being held on 
            the Death Star with them. They rescue her and escape to the Millennium 
            Falcon, hoping that Kenobi has successfully deactivated the tractor 
            beam. Kenobi later sacrifices himself as Luke watches Darth Vader 
            strike him down. Luke must now avenge his fallen mentor and carry on 
            his teachings.
            ''',
        ),
        SceneBeat(
            location='On spacestation The Death Star',
            plot_element='Reward',
            description='''
            Luke has saved the princess and retrieved the Death Star plans. 
            They now have the knowledge to destroy the Galactic Empire's 
            greatest weapon once and for all.
            ''',
        ),
        SceneBeat(
            location='On spaceship The Millennium Falcon',
            plot_element='The Road Back',
            description='''
            Luke, Leia, Han, Chewbacca, and the droids are headed to the hidden 
            Rebellion base with the Death Star plans. They are suddenly pursued 
            by incoming TIE-Fighters, forcing Han and Luke to take action to 
            defend the ship and escape with their lives - and the plans. They 
            race to take the plans to the Rebellion and prepare for battle.
            ''',
        ),
        SceneBeat(
            location='On fighter ship X-Wing',
            plot_element='The Resurrection',
            description='''
            The Rebels - along with Luke as an X-Wing pilot - take on the Death 
            Star. The Rebellion and the Galactic Empire wage war in an epic space 
            battle. Luke is the only X-Wing pilot that was able to get within the 
            trenches of the Death Star. But Darth Vader and his wingmen are in hot 
            pursuit. Just as Darth Vader is about to destroy Luke, Han returns and 
            clears the way for Luke. Luke uses the Force to guide his aiming as he 
            fires upon the sole weak point of the deadly Death Star, destroying it 
            for good.
            ''',
        ),
        SceneBeat(
            location='At the Rebellion base',
            plot_element='The Return',
            description='''
            Luke and Han return to the Rebellion base, triumphant, as they receive 
            medals for the heroic journey. There is peace throughout the galaxy - at 
            least for now.
            ''',
        ),
    ],
) 
    

The first step Dramatron defines in the generation process is producing a title from a user-defined logline. Rather than writing the title manually, we define a function that prompts for an alternative title based on the logline, and then use it to add a title to a story instance.
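As the `Story` output in this notebook shows, loglines written as indented triple-quoted strings keep their leading newline and indentation, and that whitespace is sent to the model along with the text. Two small hypothetical helpers (not part of promptx) can tidy the logline before it is stored and the generated title afterwards:

```python
def normalize_logline(text: str) -> str:
    """Collapse an indented, multi-line logline into a single line of text."""
    # str.split() with no arguments splits on any run of whitespace,
    # so newlines and indentation all become single spaces.
    return " ".join(text.split())


def normalize_title(text: str) -> str:
    """Strip whitespace and one pair of surrounding quotes from a generated title."""
    title = text.strip()
    if len(title) >= 2 and title[0] == title[-1] and title[0] in "\"'":
        title = title[1:-1].strip()
    return title
```

For example, `Story(normalize_logline(...))` would store the Matrix logline without the `'\n    '` padding visible in the repr.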

In [4]:
from promptx.template import Example


def write_title(story: Story) -> str:
    return prompt(
        'Suggest an alternative, original and descriptive title for a known story.',
        story.logline,
        examples=[
            Example(
                star_wars.logline,
                "The Death Star's Menace"
            ),
            Example(
                "Residents of San Fernando Valley are under attack by flying saucers from outer space. The aliens are extraterrestrials who seek to stop humanity from creating a doomsday weapon that could destroy the universe and unleash the living dead to stalk humans who wander into the cemetery looking for evidence of the UFOs. The hero Jeff, an airline pilot, will face the aliens.",
                "The Day The Earth Was Saved By Outer Space."
            )
        ]
    )

story = Story(
    '''
    A computer hacker learns from mysterious rebels about the true nature of his reality and his role in the war against its controllers.
    '''
)

title = write_title(story)
story.title = title
story


[1;35mStory[0m[1m([0m
    [33mid[0m=[32m'd344c2d1-07ac-4d52-858c-22ea4e55a274'[0m,
    [33mtype[0m=[32m'story'[0m,
    [33mlogline[0m=[32m'\n    A computer hacker learns from mysterious rebels about the true nature of his reality and his role in the war against its controllers.\n    '[0m,
    [33mtitle[0m=[32m'The Matrix Unveiled'[0m,
    [33moutline[0m=[3;35mNone[0m,
    [33mcharacters[0m=[3;35mNone[0m,
    [33mlocations[0m=[3;35mNone[0m
[1m)[0m

Next, we create a list of character objects based on the story logline, using the Star Wars characters as a few-shot example.
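Structured outputs like `List[Character]` are ultimately exchanged with the model as text. We don't rely on how promptx serializes them; assuming a plain JSON array of `{name, description}` objects, a round trip looks roughly like this, with a stdlib dataclass standing in for the `Entity` subclass:

```python
import json
from dataclasses import asdict, dataclass
from typing import List


@dataclass
class Character:  # stand-in for the promptx Entity subclass above
    name: str
    description: str


def characters_to_json(characters: List[Character]) -> str:
    """Serialize example characters as a JSON array for a few-shot example output."""
    return json.dumps([asdict(c) for c in characters], indent=2)


def characters_from_json(text: str) -> List[Character]:
    """Parse a JSON array of {name, description} objects, ignoring extra keys."""
    data = json.loads(text)
    if not isinstance(data, list):
        raise ValueError("expected a JSON array of characters")
    # A KeyError here means the model omitted a required field.
    return [Character(name=item["name"], description=item["description"]) for item in data]
```

Extra keys such as an `id` the model invents are dropped, while missing required fields fail loudly instead of producing half-filled characters.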

In [5]:
def create_characters(story: Story, n=5) -> List[Character]:
    return prompt(
        f'Create {n} characters for a story.',
        input=story.logline,
        output=[Character],
        examples=[
            Example(
                star_wars.logline,
                star_wars.characters,
            ),
        ],
    ).objects

characters = create_characters(story)
story.characters = characters
story


[1;35mStory[0m[1m([0m
    [33mid[0m=[32m'd344c2d1-07ac-4d52-858c-22ea4e55a274'[0m,
    [33mtype[0m=[32m'story'[0m,
    [33mlogline[0m=[32m'\n    A computer hacker learns from mysterious rebels about the true nature of his reality and his role in the war against its controllers.\n    '[0m,
    [33mtitle[0m=[32m'The Matrix Unveiled'[0m,
    [33moutline[0m=[3;35mNone[0m,
    [33mcharacters[0m=[1m[[0m
        [1;35mCharacter[0m[1m([0m
            [33mid[0m=[32m'16799a95-da93-42be-a4a8-2da3438ee6c3'[0m,
            [33mtype[0m=[32m'character'[0m,
            [33mname[0m=[32m'Neo'[0m,
            [33mdescription[0m=[32m'Neo is a computer hacker who learns about the true nature of reality and his role in the war against the controllers.'[0m
        [1m)[0m,
        [1;35mCharacter[0m[1m([0m
            [33mid[0m=[32m'3ff2a7e2-60c4-4403-bf22-150b8818c6c7'[0m,
            [33mtype[0m=[32m'character'[0m,
            [33mname[0m=[3

Now we can generate a plot outline, as a sequence of scene beats, from the story logline and characters.
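The Star Wars outline used as the few-shot example follows eleven hero's journey stages in a fixed order. A hypothetical post-check (not part of promptx or the paper) can flag generated outlines whose `plot_element` values fall outside those stages or go backwards. Note that `write_beats` defaults to `n=10` while the example has 11 beats, so the model has to drop or merge at least one stage.

```python
from typing import List

# Plot elements in the order used by the Star Wars few-shot outline above.
HERO_JOURNEY: List[str] = [
    "The Ordinary World",
    "Call to Adventure",
    "Refusal of the Call",
    "Crossing the First Threshold",
    "Tests, Allies, Enemies",
    "Approach to the Inmost Cave",
    "Ordeal",
    "Reward",
    "The Road Back",
    "The Resurrection",
    "The Return",
]


def follows_journey_order(plot_elements: List[str]) -> bool:
    """Return True if every element is a known stage and stages never go backwards."""
    order = {stage.lower(): i for i, stage in enumerate(HERO_JOURNEY)}
    last = -1
    for element in plot_elements:
        index = order.get(element.strip().lower())
        if index is None or index < last:
            return False
        last = index
    return True
```

It could be applied as `follows_journey_order([beat.plot_element for beat in beats])`, for example to decide whether to regenerate an outline.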

In [6]:
def write_beats(story: Story, n=10) -> List[SceneBeat]:
    return prompt(
        f'''
        Write a sequence of {n} scene beats for a story following a hero's journey structure.
        ''',
        input=dict(logline=story.logline, characters=story.characters),
        output=[SceneBeat],
        examples=[
            Example(
                dict(
                    logline=star_wars.logline,
                    characters=star_wars.characters,
                ),
                star_wars.outline,
            ),
        ],
    ).objects

beats = write_beats(story)
story.outline = beats
story


[1;35mStory[0m[1m([0m
    [33mid[0m=[32m'd344c2d1-07ac-4d52-858c-22ea4e55a274'[0m,
    [33mtype[0m=[32m'story'[0m,
    [33mlogline[0m=[32m'\n    A computer hacker learns from mysterious rebels about the true nature of his reality and his role in the war against its controllers.\n    '[0m,
    [33mtitle[0m=[32m'The Matrix Unveiled'[0m,
    [33moutline[0m=[1m[[0m
        [1;35mScenebeat[0m[1m([0m
            [33mid[0m=[32m'2d82c211-68ff-475d-b639-8a964bae72c2'[0m,
            [33mtype[0m=[32m'scenebeat'[0m,
            [33mlocation[0m=[32m"Neo[0m[32m's apartment"[0m,
            [33mplot_element[0m=[32m'The Ordinary World'[0m,
            [33mdescription[0m=[32m'Neo is living a mundane life as a computer hacker, unaware of the true nature of his reality.'[0m
        [1m)[0m,
        [1;35mScenebeat[0m[1m([0m
            [33mid[0m=[32m'8cad2f22-c07d-4449-aac9-e35840a7646b'[0m,
            [33mtype[0m=[32m'scenebeat'[0m,
    

The outline generated in the previous step is a list of scene beats: the main events that happen in the story.

Each scene beat has a location name, which we use to generate `Location` objects that are attached to the story in the same way as the characters. The location examples in the paper don't come from Star Wars, and the prompt uses only the story logline and the location name from the scene beat rather than the scene description. It's unclear why these choices were made, but for consistency we do the same.
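`extract_locations` relies on the model returning `None` for a location it has already generated. As an alternative that is neither the paper's nor this notebook's approach, duplicate names can be removed deterministically before prompting, so the model is called once per unique location. `unique_location_names` is a hypothetical helper:

```python
from typing import List


def unique_location_names(beat_locations: List[str]) -> List[str]:
    """Return location names with case-insensitive duplicates removed, in first-seen order."""
    seen = set()
    unique: List[str] = []
    for name in beat_locations:
        key = name.strip().lower()
        if key not in seen:
            seen.add(key)
            unique.append(name.strip())
    return unique
```

The loop in `extract_locations` would then iterate over `unique_location_names([beat.location for beat in story.outline])` instead of over the beats, and the `None` examples would no longer be needed.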

In [7]:
from typing import Optional


def extract_locations(story: Story) -> List[Location]:
    locations = []
    for beat in story.outline:
        response = prompt(
            '''
            Generate a location based on the story logline and location name. 
            If the location is already known, return None
            ''',
            input=dict(logline=story.logline, name=beat.location, known_locations=[l.name for l in locations]),
            output=Location,
            examples=[
                Example(
                    dict(
                        logline="Morgan adopts a new cat, Misterio, who sets a curse on anyone that pets them.",
                        name="The Adoption Center",
                        known_locations=["Harukiya"],
                    ),
                    Location(
                        name="The Adoption Center",
                        description='''
                        The Adoption Center is a sad place, especially for an unadopted 
                        pet. It is full of walls and walls of cages and cages. Inside of 
                        each is an abandoned animal, longing for a home. The lighting is 
                        dim, gray, buzzing fluorescent.
                        ''',
                    )
                ),
                Example(
                    dict(
                        logline="Morgan adopts a new cat, Misterio, who sets a curse on anyone that pets them.",
                        name="The Adoption Center",
                        known_locations=["Harukiya", "The Adoption Center"],
                    ),
                    None,
                ),
                Example(
                    dict(
                        logline="James finds a well in his backyard that is haunted by the ghost of Sam.",
                        name="The Well",
                        known_locations=[],
                    ),
                    Location(
                        name="The Well",
                        description='''
                        The well is buried under grass and hedges. It is at least 
                        twenty feet deep, if not more and it is masoned with stones. 
                        It is 150 years old at least. It stinks of stale, standing 
                        water, and has vines growing up the sides. It is narrow enough 
                        to not be able to fit down if you are a grown adult human.
                        ''',
                    )
                ),
                Example(
                    dict(
                        logline="Mr. Dorbenson finds a book at a garage sale that tells the story of his own life. And it ends in a murder! ",
                        name="The Garage Sale",
                        known_locations=[],
                    ),
                    Location(
                        name="The Garage Sale",
                        description='''
                        It is a garage packed with dusty household goods and antiques. 
                        There is a box at the back that says FREE and is full of paper 
                        back books.
                        ''',
                    )
                ),
            ],
        )
        if response is not None:
            locations.append(response)
    return locations
    

locations = extract_locations(story)
story.locations = locations
story


[1;35mStory[0m[1m([0m
    [33mid[0m=[32m'd344c2d1-07ac-4d52-858c-22ea4e55a274'[0m,
    [33mtype[0m=[32m'story'[0m,
    [33mlogline[0m=[32m'\n    A computer hacker learns from mysterious rebels about the true nature of his reality and his role in the war against its controllers.\n    '[0m,
    [33mtitle[0m=[32m'The Matrix Unveiled'[0m,
    [33moutline[0m=[1m[[0m
        [1;35mScenebeat[0m[1m([0m
            [33mid[0m=[32m'2d82c211-68ff-475d-b639-8a964bae72c2'[0m,
            [33mtype[0m=[32m'scenebeat'[0m,
            [33mlocation[0m=[32m"Neo[0m[32m's apartment"[0m,
            [33mplot_element[0m=[32m'The Ordinary World'[0m,
            [33mdescription[0m=[32m'Neo is living a mundane life as a computer hacker, unaware of the true nature of his reality.'[0m
        [1m)[0m,
        [1;35mScenebeat[0m[1m([0m
            [33mid[0m=[32m'8cad2f22-c07d-4449-aac9-e35840a7646b'[0m,
            [33mtype[0m=[32m'scenebeat'[0m,
    

Finally, we can write a script for each scene from its scene beat, the story's characters, and the matching location. The location is looked up by a case-insensitive match between the beat's location and the generated location names; if no match is found, the scene is skipped and None is returned. Unlike the earlier steps, this prompt has no few-shot examples.

Once these functions are defined, we can put the generation steps together in a single function and generate a new story.

In [8]:
def write_scene(story: Story, beat: SceneBeat) -> Optional[str]:
    # Find the generated location whose name matches the beat's location (case-insensitive).
    location = next(
        (loc for loc in story.locations if loc.name.lower() == beat.location.lower()),
        None,
    )
    if location is None:
        print(f'No location found for {beat.location}')
        return None

    return prompt(
        '''
        Write a scene for a story based on the scene beat and location.
        ''',
        input=dict(
            plot_element=beat.plot_element,
            beat_description=beat.description,
            location=location.name,
            characters=story.characters,
            logline=story.logline,
            title=story.title,
        ),
    )


def write_script(story: Story) -> List[Optional[str]]:
    # Scenes whose location can't be matched come back as None.
    return [write_scene(story, beat) for beat in story.outline]
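`write_scene` matches locations by exact, case-insensitive name, so a small wording difference between a beat's location and the generated `Location` name (for example, an extra word) makes the scene come back as `None`. One possible mitigation, not used in this notebook, is fuzzy matching with Python's standard `difflib`. `find_location_name` is a hypothetical helper, and the 0.6 cutoff is simply difflib's default rather than a tuned value.

```python
import difflib
from typing import List, Optional


def find_location_name(
    beat_location: str, known_names: List[str], cutoff: float = 0.6
) -> Optional[str]:
    """Return the known name that best matches beat_location, or None if none is close enough."""
    # Compare lowercased strings but return the original spelling.
    lowered = {name.lower(): name for name in known_names}
    matches = difflib.get_close_matches(
        beat_location.lower(), list(lowered), n=1, cutoff=cutoff
    )
    return lowered[matches[0]] if matches else None
```

Inside `write_scene`, the matched name could then be used to select the `Location` object instead of the exact comparison.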

In [9]:
scene = write_scene(story, story.outline[0])
scene

"INT. NEO'S APARTMENT - LIVING ROOM - DAY\n\nNeo sits hunched over his computer, surrounded by stacks of cables, monitors, and glowing computer screens. The room is cluttered and dimly lit, with the only source of light coming from the monitors.\n\nNeo, a young man in his late 20s, has unkempt hair and wears a worn-out leather jacket. He is engrossed in his work, typing furiously on the keyboard, as lines of code flicker on the screen.\n\nThe sound of a knock on the door breaks the silence. Neo looks up, momentarily distracted from his work. He walks over and opens the door to reveal Morpheus, a tall and imposing figure dressed in a long black coat.\n\nMORPHEUS: (smirks) Mr. Anderson, or should I say, Neo.\n\nNeo looks at Morpheus, surprised to hear his real name.\n\nNEO: Who are you? How do you know my name?\n\nMorpheus steps into the apartment, scanning the room with his piercing gaze.\n\nMORPHEUS: I know a lot about you, Neo. More than you ca

In [10]:
script = write_script(story)
script


[1m[[0m
    [32m"INT. NEO'S APARTMENT - LIVING ROOM - NIGHT\n\nNeo, a computer hacker in his late 20s, sits alone in his dimly lit apartment. The flickering light from his computer screen casts an eerie glow on his face. He anxiously types away, searching for something that he can't quite grasp.\n\nHis small, cluttered apartment feels strangely detached from the rest of the world. Posters of rebellion and revolution cover the walls, hinting at Neo's discontent with his mundane existence.\n\nSuddenly, his computer screen goes black. Neo frowns and tries to restart the system, but it remains unresponsive. Frustration builds within him.\n\nNEO [0m[32m([0m[32mwhispering to himself[0m[32m)[0m[32m\nCome on, come on...\n\nJust as Neo starts to lose hope, his screen flickers back to life. He is greeted by a message that sends a shiver down his spine.\n\nMESSAGE ON SCREEN\nWake up, Neo.\n\nNeo's eyes widen in disbelief. He reads the cryptic message repeatedly, his mind racing to mak

In [11]:
def create_story(logline: str) -> Story:
    story = Story(logline=logline)
    story.title = write_title(story)
    story.characters = create_characters(story)
    story.outline = write_beats(story)
    story.locations = extract_locations(story)
    return story
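`create_story` stops after generating locations and doesn't call `write_script`. To turn the generated scenes into a single script, the list returned by `write_script` could be joined as this hypothetical helper does, skipping scenes that came back as `None`. Upper-casing the title is an arbitrary formatting choice.

```python
from typing import List, Optional


def assemble_screenplay(title: str, scenes: List[Optional[str]]) -> str:
    """Join generated scenes under a title, skipping scenes that could not be written."""
    written = [scene.strip() for scene in scenes if scene]
    header = title.strip().upper()
    # Blank lines separate the title and each scene.
    return "\n\n".join([header] + written)
```

For example, `print(assemble_screenplay(uncut_gems.title, write_script(uncut_gems)))` would print the full script for the story generated below.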

In [14]:
uncut_gems = create_story(
    '''
    A charismatic jeweler makes a high-stakes bet that could lead to the windfall of a lifetime.
    '''
)

uncut_gems


[1;35mStory[0m[1m([0m
    [33mid[0m=[32m'22965d8e-81fd-438d-93f9-7e0c93700c33'[0m,
    [33mtype[0m=[32m'story'[0m,
    [33mlogline[0m=[32m'\n    A charismatic jeweler makes a high-stakes bet that could lead to the windfall of a lifetime.\n    '[0m,
    [33mtitle[0m=[32m'Uncut Gemstones'[0m,
    [33moutline[0m=[1m[[0m
        [1;35mScenebeat[0m[1m([0m
            [33mid[0m=[32m'e57e02b0-7d03-4c76-84a4-9a656f5e527d'[0m,
            [33mtype[0m=[32m'scenebeat'[0m,
            [33mlocation[0m=[32m"Charismatic[0m[32m Jeweler's shop"[0m,
            [33mplot_element[0m=[32m'The Ordinary World'[0m,
            [33mdescription[0m=[32m'\n            The charismatic jeweler is living a successful and comfortable life \n            in his shop. He is known for his exquisite craftsmanship and has a \n            loyal clientele.\n            '[0m
        [1m)[0m,
        [1;35mScenebeat[0m[1m([0m
            [33mid[0m=[32m'23a4610f-6b9e-4