<a href="https://colab.research.google.com/github/honicky/character-extraction/blob/main/Character_Extractor_proprietary_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install datasets wandb openai tiktoken anthropic


Collecting datasets
  Downloading datasets-2.19.1-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting wandb
  Downloading wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m25.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting openai
  Downloading openai-1.30.1-py3-none-any.whl (320 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m320.6/320.6 kB[0m [31m28.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tiktoken
  Downloading tiktoken-0.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m35.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting anthropic
  Downloading anthropic-0.26.0-py3-none-any.whl (877 kB)
[

# Load and preprocess datasets

We will evaluate on the same data set that we use for the [Character_Extractor_T5_LoRA](https://github.com/honicky/character-extraction/blob/main/Character_Extractor_T5_LoRA.ipynb) notebook (by cutting and pasting the code), except for when evaluating GPT3.5 Turbo since we used that model for generating labels on the [loubnabnl/stories_oh_children](https://huggingface.co/datasets/loubnabnl/stories_oh_children) dataset. In that case, we will only evaluate on the [honicky/short_childrens_stories_with_labeled_character_names](https://huggingface.co/datasets/honicky/short_childrens_stories_with_labeled_character_names) dataset and hope that the distributions are similar enough that we get comparable statistics.

In [3]:
from datasets import load_dataset, concatenate_datasets, DatasetDict

honicky_dataset = load_dataset('honicky/short_childrens_stories_with_labeled_character_names')
loubnabnl_dataset = load_dataset('honicky/stories_oh_children_with_character_names')

Downloading readme:   0%|          | 0.00/815 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/3.07M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2588 [00:00<?, ? examples/s]

Downloading readme:   0%|          | 0.00/408 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/8.81M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [4]:
loubnabnl_dataset = loubnabnl_dataset.map(lambda example: {'story': example['train']['completion']})

# Flatten the nested structure
loubnabnl_dataset = loubnabnl_dataset.remove_columns('train').map(lambda example: {'story': example['story'], 'characters': example['characters']})

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [5]:
combined_dataset = concatenate_datasets([honicky_dataset["train"], loubnabnl_dataset["train"]])

In [6]:
# Split into training and test + validation first (95% train, 5% test+val)
train_test_split = honicky_dataset['train'].train_test_split(test_size=0.15, seed=42)

# Split the test+validation set into test and validation (50% test, 50% validation)
test_val_split = train_test_split['test'].train_test_split(test_size=0.5, seed=42)

# Now assemble the final splits
honicky_splits = DatasetDict({
    'train': train_test_split['train'],
    'test': test_val_split['test'],
    'validation': test_val_split['train']  # Since we split test into two halves
})

In [7]:
honicky_splits

DatasetDict({
    train: Dataset({
        features: ['story', 'characters'],
        num_rows: 2199
    })
    test: Dataset({
        features: ['story', 'characters'],
        num_rows: 195
    })
    validation: Dataset({
        features: ['story', 'characters'],
        num_rows: 194
    })
})

In [8]:
# Split into training and test + validation first (95% train, 5% test+val)
train_test_split = combined_dataset.train_test_split(test_size=0.05, seed=42)

# Split the test+validation set into test and validation (50% test, 50% validation)
test_val_split = train_test_split['test'].train_test_split(test_size=0.5, seed=42)

# Now assemble the final splits
final_splits = DatasetDict({
    'train': train_test_split['train'],
    'test': test_val_split['test'],
    'validation': test_val_split['train']  # Since we split test into two halves
})

In [9]:
final_splits

DatasetDict({
    train: Dataset({
        features: ['story', 'characters'],
        num_rows: 7208
    })
    test: Dataset({
        features: ['story', 'characters'],
        num_rows: 190
    })
    validation: Dataset({
        features: ['story', 'characters'],
        num_rows: 190
    })
})

* T5 finetune - 6s / sample on CPU
* DistilBeRT NER - 100s / min on T4

OTS models
* GTP3.5 Turbo
* Claude Sonnet, Haiku
* Mistral 7B w/ outlines
* Phi 3.9B w/ outlines

In [174]:
from openai import OpenAI
from google.colab import userdata
import json

client = OpenAI(api_key=userdata.get('OPENAI_API_KEY'))

character_prompt_template = """Please analyze the following story and identify the main characters.
Output the result in JSON format with a "characters" array containing the names of the main characters

<story>
{story}
</story>
"""

def extract_characters_using_openai(story, model="gpt-3.5-turbo"):
  response = client.chat.completions.create(
    model=model,
    response_format={ "type": "json_object" },
    messages=[
      {"role": "system", "content": "You are a helpful assistant designed to output JSON."},
      {"role": "user", "content": character_prompt_template.format(story=story) }
    ],
    max_tokens=200,
    temperature=0,
  )
  return {
    "characters": json.loads(response.choices[0].message.content).get('characters', ['ERROR']),
    "input_tokens": response.usage.prompt_tokens,
    "output_tokens": response.usage.completion_tokens
  }

# Evaluations utils


In [166]:
import string
# Define a set of characters to strip: all punctuation and whitespace characters
strip_chars = set(string.punctuation + string.whitespace)

def strip_punctuation_whitespace(text):

  # Strip from the beginning
  start = 0
  while start < len(text) and text[start] in strip_chars:
    start += 1

  # Strip from the end
  end = len(text)
  while end > 0 and text[end-1] in strip_chars:
    end -= 1

  # Return the stripped string
  return text[start:end]

def metrics_from_strings(true_labels: list[str], predicted_labels: list[str]):

    # print(f"true_labels: {true_labels}")
    # print(f"predicted_labels: {predicted_labels}")

    # Calculate the intersection of true and predicted labels for correctly predicted labels
    correct_predictions = set(true_labels).intersection(predicted_labels)

    # Precision: correctly predicted positive / all predicted positive
    if len(predicted_labels) == 0:
        precision = 0
    else:
        precision = len(correct_predictions) / len(predicted_labels)

    # Recall: correctly predicted positive / all actual positive
    if len(true_labels) == 0:
        recall = 0
    else:
        recall = len(correct_predictions) / len(true_labels)

    # F1 Score: 2 * (precision * recall) / (precision + recall)
    if precision + recall == 0:
        f1 = 0
    else:
        f1 = 2 * (precision * recall) / (precision + recall)

    return precision, recall, f1

# Parse the strings to remove whitespace and split by commas

# true_labels = [strip_punctuation_whitespace(label) for label in true_labels_str.split(',')]
# predicted_labels = [strip_punctuation_whitespace(label) for label in predicted_labels_str.split(',')]



In [180]:
model_costs_per_token = {
    "claude-3-haiku-20240307": {
        "input": 0.25/1_000_000,
        "output": 1.25/1_000_000,
    },
    "claude-3-sonnet-20240229": {
        "input": 3.00/1_000_000,
        "output": 15.00/1_000_000,
    },
    "claude-3-opus-20240229": {
        "input": 15.00/1_000_000,
        "output": 75.00/1_000_000,
    },
    "gpt-3.5-turbo": {
        "input": 0.50/1_000_000,
        "output": 1.50/1_000_000,
    },
}

In [176]:
for story, characters in zip(honicky_splits['validation']['story'][:5], honicky_splits['validation']['characters'][:5]):
  extracted_characters = extract_characters_using_openai(story)["characters"]
  characters = [strip_punctuation_whitespace(character) for character in characters.split(",")]
  print(f"extracted_characters: {extracted_characters} --- characters: {characters} --- metrics: {metrics_from_strings(characters, extracted_characters)}")

extracted_characters: ['Timmy', 'Sara', 'Max', 'Mr. Thompson', 'Principal'] --- characters: ['Timmy', 'Sara', 'Max', 'Mr. Thompson'] --- metrics: (0.8, 1.0, 0.888888888888889)
extracted_characters: ['One', 'Zero'] --- characters: ['One', 'Zero', 'Queen Binary'] --- metrics: (1.0, 0.6666666666666666, 0.8)
extracted_characters: ['Mia', 'Ben'] --- characters: ['Mia', 'Ben'] --- metrics: (1.0, 1.0, 1.0)
extracted_characters: ['Qantas', 'Jetstar'] --- characters: ['Qantas', 'Jetstar'] --- metrics: (1.0, 1.0, 1.0)
extracted_characters: ['Timmy', 'Junior', 'Mr. Laemmle'] --- characters: ['Timmy', 'Junior', 'Mr. Laemmle'] --- metrics: (1.0, 1.0, 1.0)


https://stackoverflow.com/questions/33987060/python-context-manager-that-measures-time

In [11]:
from time import perf_counter

class catchtime:

    def __init__(self, name):
      if name is not None:
        self.name = f" {name}"
      else:
        self.name = ""

    def __enter__(self):
      self.start = perf_counter()
      return self

    def __exit__(self, type, value, traceback):
      self.time = perf_counter() - self.start
      self.readout = f'Time{self.name}: {self.time:.3f} seconds'
      print(self.readout)

In [179]:
with catchtime("gtp3.5-turbo") as timer:
  extracted_characters = [
    extract_characters_using_openai(story)
    for story in honicky_splits['validation']['story']
  ]

true_characters = [
  [strip_punctuation_whitespace(character) for character in characters.split(",")]
  for characters in honicky_splits['validation']['characters']
]


Time gtp3.5-turbo: 142.368 seconds


In [184]:
gtp35_precisions, gtp35_recalls, gtp35_f1s = zip(*[
  metrics_from_strings(true_characters[i], extracted_characters[i]["characters"])
  for i in range(len(extracted_characters))
])
gtp35_input_tokens = sum(response["input_tokens"] for response in extracted_characters)
gtp35_output_tokens = sum(response["output_tokens"] for response in extracted_characters)

In [207]:
import numpy as np

gtp35_metrics = {
    "precision": np.mean(gtp35_precisions),
    "recall": np.mean(gtp35_recalls),
    "f1": np.mean(gtp35_f1s),
    "time": timer.time,
    "time_per_story": (timer.time ) / len(extracted_characters),
    "input_tokens": gtp35_input_tokens,
    "output_tokens": gtp35_output_tokens,
    "total_cost":
      gtp35_input_tokens * model_costs_per_token["gpt-3.5-turbo"]["input"]
      + gtp35_output_tokens * model_costs_per_token["gpt-3.5-turbo"]["output"],
}
gtp35_metrics

{'precision': 0.8856406480117819,
 'recall': 0.9066723940435281,
 'f1': 0.8904982440913309,
 'time': 142.368,
 'time_per_story': 0.7338556701030927,
 'input_tokens': 84580,
 'output_tokens': 4703,
 'total_cost': 0.0493445}

In [None]:

gtp35_metrics = {
    "precision": np.mean(gtp35_precisions),
    "recall": np.mean(gtp35_recalls),
    "f1": np.mean(gtp35_f1s),
    "time": timer.time,
    "time_per_story": timer.time / len(extracted_characters),
}



In [None]:
gtp35_metrics

{'precision': 0.8837584237068773,
 'recall': 0.9045246277205039,
 'f1': 0.8885167204603227,
 'time': 151.2680061609999,
 'time_per_story': 0.779731990520618}

In [132]:
gtp35_metrics = {'precision': 0.8837584237068773,
 'recall': 0.9045246277205039,
 'f1': 0.8885167204603227,
 'time': 151.2680061609999,
 'time_per_story': 0.779731990520618}

In [173]:
import tiktoken
enc = tiktoken.get_encoding("o200k_base")
input_count = sum(
    len(enc.encode(story))
    for story in honicky_splits['validation']['story']
)
output_count = sum(
    len(enc.encode(story))
    for story in honicky_splits['validation']['story']
)

# To get the tokeniser corresponding to a specific model in the OpenAI API:
enc = tiktoken.encoding_for_model("gpt-4o")

# Anthropic

We can try using generic prompting as well as the `tool_use` API to see the difference in token counts and cost. If `tool_use` is about the same cost, we should use that to simplify things.

In [21]:
import anthropic
from google.colab import userdata

client = anthropic.Anthropic(
    # defaults to os.environ.get("ANTHROPIC_API_KEY")
    api_key=userdata.get('ANTHROPIC_API_KEY'),
)

## Basic prompting


In [158]:
import time
import json

character_prompt_template = anthropic.HUMAN_PROMPT + """Please analyze the following story and identify the main characters.
Output the result in JSON format with a "characters" array containing the names of the main characters

<story>
{story}
</story>
"""
def extract_characters_using_anthropic(story, model="claude-3-haiku-20240307", retry_count=2):

  while retry_count > 0:
    try:
      message = client.messages.create(
        model=model,
        max_tokens=512,
        messages=[
          {"role": "user", "content": character_prompt_template.format(story=story)}
        ],
        temperature=0,
      )
      break
    except anthropic.RateLimitError:
      print("Rate limit reached. Retrying in 60 seconds...", end="", flush=True)
      time.sleep(60)
      print(" resuming.")
      retry_count -= 1

  try:
    return {
      "characters": json.loads(message.content[0].text).get('characters', ['ERROR']),
      "input_tokens": message.usage.input_tokens,
      "output_tokens": message.usage.output_tokens
    }
  except json.decoder.JSONDecodeError:
    return {
      "characters": ['ERROR'],
      "input_tokens": message.usage.input_tokens,
      "output_tokens": message.usage.output_tokens
    }


# m = extract_characters_using_anthropic(final_splits['validation']["story"][1])

## Prompt Anthropic Haiku model

Keep track of the time per call and tokens used

In [60]:
with catchtime("claude-3-haiku") as timer:
  extracted_characters = [
    extract_characters_using_anthropic(story, model="claude-3-haiku-20240307")
    for story in honicky_splits['validation']['story']
  ]

true_characters = [
  [strip_punctuation_whitespace(character) for character in characters.split(",")]
  for characters in honicky_splits['validation']['characters']
]


Rate limit reached. Retrying in 60 seconds... resuming.
Rate limit reached. Retrying in 60 seconds... resuming.
Rate limit reached. Retrying in 60 seconds... resuming.
Rate limit reached. Retrying in 60 seconds... resuming.
Time claude-3-haiku: 379.648 seconds


### Cacluate metrics for Haiku model

To cacluate the time per story, I am doing something very hacky, which is to subtract the seconds of delay printed out above from the `timer.time` value.  I could be smarter but...

In [61]:
haiku_precisions, haiku_recalls, haiku_f1s = zip(*[
  metrics_from_strings(true_characters[i], extracted_characters[i]["characters"])
  for i in range(len(extracted_characters))
])
haiku_input_tokens = sum(response["input_tokens"] for response in extracted_characters)
haiku_output_tokens = sum(response["output_tokens"] for response in extracted_characters)


In [205]:
import numpy as np

haiku_metrics = {
    "precision": np.mean(haiku_precisions),
    "recall": np.mean(haiku_recalls),
    "f1": np.mean(haiku_f1s),
    "time": timer.time - 240,
    "time_per_story": (timer.time - 240) / len(extracted_characters),
    "input_tokens": haiku_input_tokens,
    "output_tokens": haiku_output_tokens,
    "total_cost":
      haiku_input_tokens * model_costs_per_token["claude-3-haiku-20240307"]["input"]
      + haiku_output_tokens * model_costs_per_token["claude-3-haiku-20240307"]["output"],

}

haiku_metrics

{'precision': 0.9028145966290296,
 'recall': 0.9369988545246277,
 'f1': 0.9122840268686235,
 'time': 139.64800000000002,
 'time_per_story': 0.7198350515463918,
 'input_tokens': 89882,
 'output_tokens': 6872,
 'total_cost': 0.031060499999999998}

## Anthropic Tool Use API

In [112]:
import anthropic
import time

def extract_characters_using_anthropic_tool_use(story, model="claude-3-haiku-20240307"):
  print_characters_tool_name = "print_characters"
  tools = [
      {
          "name": print_characters_tool_name,
          "description": "Prints out the characters from a story or story fragment",
          "input_schema": {
              "type": "object",
              "properties": {
                  "characters": {
                      "type": "array",
                      "items": {
                          "type": "string",
                          "description": "The name of the character."
                      }
                  }
              },
              "required": ["characters"]
          }
      }
  ]

  query = f"""
  <story>
  {story}
  </story>

  Use the {print_characters_tool_name} tool.
  """

  # print(f"story: {story[:20]}...")
  retry_count = 2
  while retry_count > 0:
    try:
      # print(f"Calling api...")
      response = client.beta.tools.messages.create(
          model=model,
          max_tokens=4096,
          tools=tools,
          tool_choice = {"type": "tool", "name": print_characters_tool_name},
          messages=[{"role": "user", "content": query}],
          temperature=0,
      )
      # print(response)

      json_entities = None
      for content in response.content:
          if content.type == "tool_use" and content.name == print_characters_tool_name:
              json_entities = content.input
              # print(f"json_entities: {json_entities}")
              break
          else:
            print(f"no tool: {content}")

      if json_entities is None:
          print("No JSON entities found in the response.")
          print(response)
          return ['ERROR']

      characters = json_entities.get('characters', ['ERROR'])
      print(f"characters: {characters}")
      break
    except anthropic.RateLimitError:
      print("Rate limit reached. Retrying in 60 seconds...", end="", flush=True)
      time.sleep(60)
      print(" resuming.")
      retry_count -= 1

  return {
    "characters": characters,
    "input_tokens": response.usage.input_tokens,
    "output_tokens": response.usage.output_tokens
  }

## Extract characters using the `tool_use` API

In [113]:
with catchtime("claude-3-haiku") as timer:
  haiku_extracted_characters = [
    extract_characters_using_anthropic_tool_use(story, model="claude-3-haiku-20240307")
    for story in honicky_splits['validation']['story']
  ]

true_characters = [
  [strip_punctuation_whitespace(character) for character in characters.split(",")]
  for characters in honicky_splits['validation']['characters']
]

characters: ['Timmy', 'Sara', 'Max', 'Mr. Thompson', 'ISIS', 'the masked man']
characters: ['X', 'V', 'l', 'w', 'T', '379']
characters: ['Mia', 'Ben']
characters: ['Qantas', 'Jetstar']
characters: ['Timmy', 'Junior', 'Mr. Laemmle']
characters: ['Casey', 'Mom', 'classmate']
Time claude-3-haiku: 4.614 seconds


KeyboardInterrupt: 

In [91]:
with catchtime("claude-3-haiku") as timer:
  extracted_characters = [
    extract_characters_using_anthropic_tool_use(story, model="claude-3-haiku-20240307")
    for story in honicky_splits['validation']['story']
  ]

true_characters = [
  [strip_punctuation_whitespace(character) for character in characters.split(",")]
  for characters in honicky_splits['validation']['characters']
]

Rate limit reached. Retrying in 60 seconds... resuming.
Rate limit reached. Retrying in 60 seconds... resuming.
Rate limit reached. Retrying in 60 seconds... resuming.
Time claude-3-haiku: 329.462 seconds


### Cacluate metrics for Haiku `tool_use`


In [92]:
haiku_tool_precisions, haiku_tool_recalls, haiku_tool_f1s = zip(*[
  metrics_from_strings(true_characters[i], extracted_characters[i]["characters"])
  for i in range(len(extracted_characters))
])
haiku_tool_input_tokens = sum(response["input_tokens"] for response in extracted_characters)
haiku_tool_output_tokens = sum(response["output_tokens"] for response in extracted_characters)


In [211]:
import numpy as np

haiku_tool_use_metrics = {
    "precision": np.mean(haiku_tool_precisions),
    "recall": np.mean(haiku_tool_recalls),
    "f1": np.mean(haiku_tool_f1s),
    "time": timer.time - 180,
    "time_per_story": (timer.time - 180) / len(extracted_characters),
    "input_tokens": haiku_tool_input_tokens,
    "output_tokens": haiku_tool_output_tokens,
    "total_cost":
      haiku_tool_input_tokens * model_costs_per_token["claude-3-haiku-20240307"]["input"]
      + haiku_tool_output_tokens * model_costs_per_token["claude-3-haiku-20240307"]["output"],

}
haiku_tool_use_metrics

{'precision': 0.8756850537778373,
 'recall': 0.9382875143184422,
 'f1': 0.8984140396359923,
 'time': 149.462,
 'time_per_story': 0.770422680412371,
 'input_tokens': 170974,
 'output_tokens': 9107,
 'total_cost': 0.05412725}

### `sonnet` doesn't support tool use very well???

When I call `sonnet` using the exact same API, sonnet doesn't give back the correct schema. It returns a comma separated string for the list of characters, rather than a list of strings



In [130]:
extract_characters_using_anthropic_tool_use(honicky_splits['validation']['story'][2], model="claude-3-sonnet-20240229")

characters: Mia, Ben


{'characters': 'Mia, Ben', 'input_tokens': 731, 'output_tokens': 37}

In [129]:
extract_characters_using_anthropic_tool_use(honicky_splits['validation']['story'][2], model="claude-3-haiku-20240307")

characters: ['Mia', 'Ben']


{'characters': ['Mia', 'Ben'], 'input_tokens': 836, 'output_tokens': 40}

### We'll just use the regular prompt them

In [131]:
with catchtime("claude-3-sonnet") as timer:
  extracted_characters = [
    extract_characters_using_anthropic(story, model="claude-3-sonnet-20240229")
    for story in honicky_splits['validation']['story']
  ]

true_characters = [
  [strip_punctuation_whitespace(character) for character in characters.split(",")]
  for characters in honicky_splits['validation']['characters']
]

Time claude-3-sonnet: 292.577 seconds


In [163]:
# extracted_characters = [ chars if type(chars) == dict else chars[0] for chars in extracted_characters ]

In [169]:
sonnet_precisions, sonnet_recalls, sonnet_f1s = zip(*[
  metrics_from_strings(true_characters[i], extracted_characters[i]["characters"])
  for i in range(len(extracted_characters))
])
sonnet_input_tokens = sum(response["input_tokens"] for response in extracted_characters)
sonnet_output_tokens = sum(response["output_tokens"] for response in extracted_characters)

In [208]:
sonnet_metrics = {
    "precision": np.mean(sonnet_precisions),
    "recall": np.mean(sonnet_recalls),
    "f1": np.mean(sonnet_f1s),
    "time": timer.time,
    "time_per_story": (timer.time ) / len(extracted_characters),
    "input_tokens": sonnet_input_tokens,
    "output_tokens": sonnet_output_tokens,
    "total_cost":
      sonnet_input_tokens * model_costs_per_token["claude-3-sonnet-20240229"]["input"]
      + sonnet_output_tokens * model_costs_per_token["claude-3-sonnet-20240229"]["output"],
}
sonnet_metrics

{'precision': 0.8009818360333825,
 'recall': 0.897479954180985,
 'f1': 0.8368819089407326,
 'time': 292.577,
 'time_per_story': 1.5081288659793814,
 'input_tokens': 84580,
 'output_tokens': 4703,
 'total_cost': 0.32428500000000005}

In [213]:
import pandas as pd

metrics_pdf = pd.DataFrame([
  dict(name="gpt-3.5", **gtp35_metrics),
  dict(name="haiku", **haiku_metrics),
  dict(name="haiku-tool", **haiku_tool_use_metrics),
  dict(name="sonnet", **sonnet_metrics),
])

In [248]:
metrics_pdf.to_json()

'{"name":{"0":"gpt-3.5","1":"haiku","2":"haiku-tool","3":"sonnet"},"precision":{"0":0.885640648,"1":0.9028145966,"2":0.8756850538,"3":0.800981836},"recall":{"0":0.906672394,"1":0.9369988545,"2":0.9382875143,"3":0.8974799542},"f1":{"0":0.8904982441,"1":0.9122840269,"2":0.8984140396,"3":0.8368819089},"time":{"0":142.368,"1":139.648,"2":149.462,"3":292.577},"time_per_story":{"0":0.7338556701,"1":0.7198350515,"2":0.7704226804,"3":1.508128866},"input_tokens":{"0":84580,"1":89882,"2":170974,"3":84580},"output_tokens":{"0":4703,"1":6872,"2":9107,"3":4703},"total_cost":{"0":0.0493445,"1":0.0310605,"2":0.05412725,"3":0.324285}}'

In [247]:
import plotly.graph_objects as go
import pandas as pd
from plotly.subplots import make_subplots

# Assuming your DataFrame is named 'metrics_pdf'
columns_to_group = ['precision', 'recall', 'f1', 'time', 'time_per_story', 'total_cost']

# Melt the DataFrame
melted_df = pd.melt(metrics_pdf, id_vars=['name'], value_vars=columns_to_group, var_name='metric', value_name='value')

# Create a color map for each unique name
color_map = {name: f'rgba({i*50%255}, {(i*70)%255}, {(i*90)%255}, 1)' for i, name in enumerate(metrics_pdf['name'].unique())}

# Create subplots
fig = make_subplots(rows=3, cols=1, specs=[[{}], [{}], [{"secondary_y": True}]])

# Create traces for each subplot
for i, name_value in enumerate(metrics_pdf['name'].unique()):
  color = color_map[name_value]

  # Scores subplot
  fig.add_trace(go.Bar(
    x=melted_df[(melted_df['name'] == name_value) & (melted_df['metric'].isin(['precision', 'recall', 'f1']))]['metric'],
    y=melted_df[(melted_df['name'] == name_value) & (melted_df['metric'].isin(['precision', 'recall', 'f1']))]['value'],
    name=name_value,
    marker_color=color,
    showlegend=True
  ), row=1, col=1)

  # Total Cost subplot
  fig.add_trace(go.Bar(
    x=melted_df[(melted_df['name'] == name_value) & (melted_df['metric'] == 'total_cost')]['metric'],
    y=melted_df[(melted_df['name'] == name_value) & (melted_df['metric'] == 'total_cost')]['value'],
    name=name_value,
    marker_color=color,
    showlegend=False
  ), row=2, col=1)

  # Time and Time per Story subplot (Time)
  fig.add_trace(go.Bar(
    x=melted_df[(melted_df['name'] == name_value) & (melted_df['metric'] == 'time')]['metric'],
    y=melted_df[(melted_df['name'] == name_value) & (melted_df['metric'] == 'time')]['value'],
    name=name_value,
    marker_color=color,
    showlegend=False
  ), row=3, col=1, secondary_y=False)

  # Time and Time per Story subplot (Time per Story)
  fig.add_trace(go.Bar(
    x=melted_df[(melted_df['name'] == name_value) & (melted_df['metric'] == 'time_per_story')]['metric'],
    y=melted_df[(melted_df['name'] == name_value) & (melted_df['metric'] == 'time_per_story')]['value'],
    name=f'{name_value} (Time per Story)',
    marker_color=color,
    showlegend=False,
  ), row=3, col=1, secondary_y=True)

# Update layout
fig.update_layout(
  barmode='group',
  xaxis1=dict(),
  yaxis1=dict(title='Score'),
  xaxis2=dict(),
  yaxis2=dict(title='Total Cost'),
  xaxis3=dict(),
)

# Update y-axes titles for combined subplot
fig.update_yaxes(title_text="Time", row=3, col=1, secondary_y=False)
fig.update_yaxes(title_text="Time per Story", row=3, col=1, secondary_y=True)

# Show the plot
fig.show()
