In [1]:
# install dependencies
%pip install openai jinja2 ipywidgets d20

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Note: you may need to restart the kernel to use updated packages.


In [2]:
# function to load the config file
import tomllib
import json
from jinja2 import BaseLoader, Environment
from typing import Literal

class Message:
    """One chat turn: who said it (``role``) and what was said (``content``)."""

    role: Literal["system", "user", "assistant"]
    content: str

    def __init__(self, role: Literal["system", "user", "assistant"], content: str) -> None:
        # Both fields are stored verbatim; ``role`` is not validated here.
        self.content = content
        self.role = role

def load_config(file) -> tuple[dict | None, list[Message]]:
    """Load the game configuration from a TOML file.

    Recognised top-level entries:
      * ``[meta]`` - template variables; may be omitted.
      * ``[[messages]]`` - items with ``role`` and ``content``. ``content`` is
        rendered as a Jinja2 template with the ``meta`` table as context.
      * ``[openai]`` - client settings, returned unchanged.

    Args:
        file: Path to the TOML file. It is opened in binary mode, as ``tomllib`` requires.

    Returns:
        ``(openai_config, initial_messages)``. ``openai_config`` is the
        ``[openai]`` table, or ``None`` when absent. ``initial_messages`` holds
        the rendered messages in file order. Entries missing ``role`` or
        ``content`` are skipped.
    """
    with open(file, "rb") as f:
        data = tomllib.load(f)
    # A missing [meta] table must render against an empty context:
    # jinja2's Template.render(None) raises TypeError (it calls dict(None)).
    meta = data.get("meta") or {}
    env = Environment(loader=BaseLoader())  # one environment shared by all templates
    initial_messages: list[Message] = []
    for msg in data.get("messages", []):
        role = msg.get("role")
        content = msg.get("content")
        if role is None or content is None:
            continue
        initial_messages.append(Message(role, env.from_string(content).render(meta)))
    return data.get("openai"), initial_messages


In [3]:
import re
import functools

import ipywidgets as widgets
from IPython.display import display, Markdown, clear_output

from openai import OpenAI

import d20

# Shared notebook state, read and updated by the widget callbacks below.
output = widgets.VBox()  # container the transcript is rendered into

openai_config = None  # set from config.toml by on_start_click
messages = []         # list of Message: the full chat history sent to the model

def chat():
    """Stream the next assistant reply for the current history, then redraw.

    Every entry of ``messages`` is sent to the chat-completions endpoint, using
    ``openai_config["api_key"]`` and ``openai_config["model_name"]``. The reply
    is rendered incrementally as Markdown in a new ``Output`` widget under a
    "Loading..." label. When the stream ends, the full reply is appended to
    ``messages`` as an assistant turn and the transcript is rebuilt.

    Raises:
        RuntimeError: if no ``[openai]`` configuration has been loaded.
    """
    global openai_config, messages, output
    if openai_config is None:
        raise RuntimeError(
            "No OpenAI configuration loaded: click Start first and make sure "
            "the config file has an [openai] table."
        )
    # api_key=None lets the client fall back to its own default lookup.
    client = OpenAI(api_key=openai_config.get("api_key"))
    input_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

    stream = client.chat.completions.create(
        model=openai_config.get("model_name"),
        messages=input_messages,
        stream=True,
    )

    loading = widgets.Label(value="Loading...")
    out = widgets.Output()
    output.children += (loading, out)

    generated_content = ""
    with out:
        for chunk in stream:
            # Some chunks carry an empty `choices` list (e.g. a usage-only chunk,
            # or provider filter metadata), so indexing [0] blindly can fail.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content
            if delta is None:
                continue
            generated_content += delta
            # wait=True defers the clear until the new render arrives (no flicker).
            clear_output(wait=True)
            display(Markdown(generated_content))

    messages.append(Message("assistant", generated_content))
    parse_messages()

def on_custom_action_submit(b, input):
    """Send the free-text action typed by the player as the next user turn.

    Args:
        b: The Submit button that was clicked (unused).
        input: The ``widgets.Text`` holding the player's action. The name
            shadows the builtin, but callers bind it by keyword, so it is kept.
    """
    global messages
    action = input.value.strip()
    if not action:
        # Ignore empty submissions rather than sending a blank turn to the model.
        return
    messages.append(Message("user", action))
    parse_messages()
    chat()

def on_action_click(b, idx):
    """Handle a click on a suggested-action button or the "Custom Action" button.

    Args:
        b: The clicked button.
        idx: Action number parsed from the assistant message, or ``None`` for
            the "Custom Action" button.

    A numbered action is recorded as a user turn and sent to the model
    immediately. For ``None``, a text box and a Submit button are appended to
    ``output`` so the player can type a free-form action.
    """
    global output, messages
    if idx is not None:
        messages.append(Message("user", f"I choose action **{idx}**."))
        parse_messages()
        chat()
        return

    # Every click used to append another input row; allow one per turn.
    # The next redraw creates a fresh, enabled button.
    b.disabled = True
    action_input = widgets.Text(description="Enter your actions here:",
                                style={"description_width": "initial"},
                                layout=widgets.Layout(width="100%"))
    submit_button = widgets.Button(description="Submit",
                                   layout=widgets.Layout(width="fit-content"),
                                   style={"description_width": "initial"})
    # on_custom_action_submit takes the text widget through its `input` keyword.
    submit_button.on_click(functools.partial(on_custom_action_submit, input=action_input))
    output.children += (widgets.HBox([action_input, submit_button], layout=widgets.Layout(width="100%")),)

def on_skill_check(b, skill: str, difficulty: str):
    """Roll a d100 skill check, record the outcome as a user turn, and continue.

    This is a roll-under check. The target is 80 for "easy", 20 for "hard" and
    50 for anything else, and the check succeeds when the roll is at or below
    the target. A roll of 1 is always a critical success and 100 always a
    critical failure.

    Args:
        b: The clicked button (unused).
        skill: Skill name parsed from the assistant message.
        difficulty: Difficulty label parsed from the assistant message.
    """
    global output, messages
    # Normalise e.g. " Hard", "**Hard**" or "Hard." to "hard", so stray
    # whitespace or markdown from the model doesn't silently fall back to 50.
    level = re.sub(r"[^a-z]", "", difficulty.lower())
    dc = {"easy": 80, "hard": 20}.get(level, 50)
    roll = d20.roll("1d100").total
    if roll == 1:
        result = "Critical Success"
    elif roll == 100:
        result = "Critical Failure"
    elif roll <= dc:
        result = "Success"
    else:
        result = "Failure"
    skill, difficulty = skill.strip(), difficulty.strip()
    messages.append(Message("user", f"I attempt a **{skill}** check with a difficulty of **{difficulty}**, and I rolled a **{roll}**. The result is a **{result}**."))

    parse_messages()
    chat()

_ACTIONS_HEADER = "## ACTIONS:"
_SKILL_CHECK_HEADER = "## SKILL CHECK:"
# One suggested action per line, e.g. "1. Open the door" or "2) Run away".
_ACTION_PATTERN = re.compile(r"(\d+)[.)][ \t]+(.+)")
# "skill: <name>", then any separator (newline, comma, ...), then "difficulty: <level>".
_SKILL_CHECK_PATTERN = re.compile(r"skill\:\s*(.+?)\W+difficulty\:\s*(.+)", re.IGNORECASE)


def _split_section(text: str, header: str) -> tuple[str, str, str]:
    """Split ``text`` into ``(head, body, tail)`` around the section ``header``.

    ``head`` runs up to and including the header, ``body`` runs until the next
    line starting with ``## `` (or the end of the text), and ``tail`` is the rest.
    Returns ``(text, "", "")`` when ``header`` does not occur.
    """
    start = text.find(header)
    if start == -1:
        return text, "", ""
    body_start = start + len(header)
    next_header = text.find("\n## ", body_start)
    body_end = len(text) if next_header == -1 else next_header
    return text[:body_start], text[body_start:body_end], text[body_end:]


def parse_message(index: int, msg: Message, is_last=False):
    """Render one chat message and its controls, appending both to ``output``.

    The message appears as a ``Role (TURN n)`` Markdown header followed by its
    text.

    Actions: if the text has a ``## ACTIONS:`` section, each numbered line in
    that section becomes a button, and the line is removed from the rendered
    text. On the newest message these buttons call ``on_action_click`` and a
    "Custom Action" button is added; on older messages they are disabled.

    Skill check: on the newest message only, a ``## SKILL CHECK:`` section
    containing ``skill: ...`` and ``difficulty: ...`` becomes a button that
    calls ``on_skill_check``.

    Args:
        index: Position of the message in the history, shown as the turn number.
        msg: The message to render.
        is_last: Whether this is the newest message; only it gets live controls.
    """
    global output

    action_list = []
    content = widgets.Output()
    display_text = msg.content

    # parse actions part: only numbered lines inside the ACTIONS section count,
    # so numbered lists elsewhere in the story are left intact.
    if _ACTIONS_HEADER in display_text:
        head, body, tail = _split_section(display_text, _ACTIONS_HEADER)
        for match in _ACTION_PATTERN.finditer(body):
            idx, action = match.group(1), match.group(2).strip()
            action_button = widgets.Button(
                description=f"{idx}. {action}",
                layout=widgets.Layout(width="fit-content"),
                # Buttons on older turns are shown for context only.
                disabled=not is_last,
            )
            if is_last:
                action_button.on_click(functools.partial(on_action_click, idx=idx))
            action_list.append(action_button)
        if is_last:
            custom_action_btn = widgets.Button(description="Custom Action", layout=widgets.Layout(width="fit-content"))
            custom_action_btn.on_click(functools.partial(on_action_click, idx=None))
            action_list.append(custom_action_btn)
        # The buttons stand in for the numbered lines, so drop them from the text.
        display_text = head + _ACTION_PATTERN.sub("", body) + tail

    # parse skill check part (only the newest message can trigger a roll)
    if is_last and _SKILL_CHECK_HEADER in display_text:
        head, body, tail = _split_section(display_text, _SKILL_CHECK_HEADER)
        match = _SKILL_CHECK_PATTERN.search(body)
        # A malformed section stays as plain text instead of crashing the redraw.
        if match is not None:
            skill, difficulty = match.group(1).strip(), match.group(2).strip()
            skill_check_button = widgets.Button(description=f"{skill} - [{difficulty}]", layout=widgets.Layout(width="fit-content"))
            skill_check_button.on_click(functools.partial(on_skill_check, skill=skill, difficulty=difficulty))
            action_list.append(widgets.VBox([skill_check_button], layout=widgets.Layout(width="100%")))
            display_text = head + body[:match.start()] + body[match.end():] + tail

    with content:
        display(Markdown(f"`{msg.role.capitalize()}` (TURN {index})"))
        display(Markdown(display_text))
    actions = widgets.VBox(action_list)
    output.children += (content, actions)

def parse_messages():
    """Rebuild the whole transcript in ``output`` from ``messages``.

    Only "assistant" and "user" turns are drawn; messages with any other role
    (e.g. the system prompt) are skipped. The final entry of the history is
    flagged as the last message, so it receives the interactive controls.
    """
    global output, messages
    output.children = []
    last_position = len(messages) - 1
    for position, message in enumerate(messages):
        if message.role not in ("assistant", "user"):
            continue
        parse_message(position, message, position == last_position)

def on_start_click(b):
    """Start (or restart) the game: load ``config.toml`` and draw its messages.

    Args:
        b: The clicked Start button (unused).
    """
    global openai_config, messages
    loaded_openai_config, initial_messages = load_config("config.toml")
    openai_config = loaded_openai_config
    messages = initial_messages
    parse_messages()

# UI entry point: a Start button placed above the transcript container.
start_btn = widgets.Button(description="Start")
controller = widgets.VBox([start_btn, output])
start_btn.on_click(on_start_click)
display(controller)

VBox(children=(Button(description='Start', style=ButtonStyle()), VBox()))