<a href="https://colab.research.google.com/github/honicky/character-extraction/blob/main/Character_Extractor_open_source_local_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Character Extractor - open source local models

This notebook is part of an exploration of how to extract the names of characters from stories cheaply and easily. I am doing this as part of a little project to generate and co-author illustrated children's stories. One of the challenges in that project is generating consistent characters, so I am using it as an excuse to learn about different approaches, including off-the-shelf models with manual and automated (DSPy) prompting, fine-tuning small models, and just using god models like GPT-4 or Claude Opus.

In this notebook I examine some pre-trained open source models. I have already generated a bunch of story-character pairs using GPT-4. I also used the `loubnabnl/stories_oh_children` dataset and extracted the character names from its stories using GPT-3.5-turbo.

I use [outlines](https://github.com/outlines-dev/outlines) to guarantee that the output of the models is valid JSON. This helps prevent formatting differences from impacting the model metrics.  

## Install a development version of transformers and other packages

`phi-3` requires a dev version of `transformers` at the time of writing, so I reinstall it from source, as per https://huggingface.co/microsoft/Phi-3-mini-4k-instruct

In [None]:
!pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers

In [None]:
!pip install datasets wandb
!pip install outlines flash-attn


## Load and preprocess datasets

We will evaluate on the same data set that we use for the [Character_Extractor_T5_LoRA](https://github.com/honicky/character-extraction/blob/main/Character_Extractor_T5_LoRA.ipynb) notebook (by cutting and pasting the code).

In [11]:
from datasets import load_dataset, concatenate_datasets, DatasetDict

honicky_dataset = load_dataset('honicky/short_childrens_stories_with_labeled_character_names')

In [12]:
# Split into training and test + validation first (85% train, 15% test+val)
train_test_split = honicky_dataset['train'].train_test_split(test_size=0.15, seed=42)

# Split the test+validation set into test and validation (50% test, 50% validation)
test_val_split = train_test_split['test'].train_test_split(test_size=0.5, seed=42)

# Now assemble the final splits
honicky_splits = DatasetDict({
    'train': train_test_split['train'],
    'test': test_val_split['test'],
    'validation': test_val_split['train']  # Since we split test into two halves
})

In [13]:
honicky_splits

DatasetDict({
    train: Dataset({
        features: ['story', 'characters'],
        num_rows: 2199
    })
    test: Dataset({
        features: ['story', 'characters'],
        num_rows: 195
    })
    validation: Dataset({
        features: ['story', 'characters'],
        num_rows: 194
    })
})
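
As a quick sanity check on the split proportions, the row counts printed above can be turned back into fractions. This is only a sketch using the numbers copied from the output; it does not reload the dataset. With `test_size=0.15` followed by a 50/50 split, the expected result is roughly 85% / 7.5% / 7.5%.

```python
# Row counts copied from the DatasetDict printout above.
split_sizes = {"train": 2199, "test": 195, "validation": 194}
total = sum(split_sizes.values())

# Fraction of the full dataset that lands in each split.
split_fractions = {name: rows / total for name, rows in split_sizes.items()}

for name, fraction in split_fractions.items():
    print(f"{name:>10}: {split_sizes[name]:4d} rows ({fraction:.1%})")
```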

## Prompt and Schema

We will use a simple prompt to make sure that the model outputs JSON, and then provide the schema to `outlines`.

In [16]:
character_prompt_template = """Please analyze the following story and identify the main characters.
Output the result in JSON format with a "characters" array containing the names of the main characters

<story>
{story}
</story>
"""


schema = """
{
  "type": "object",
  "properties": {
    "characters": {
      "type": "array",
      "items": {
        "type": "string",
        "description": "The name of the character."
      }
    }
  },
  "required": ["characters"]
}
"""
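
To make concrete what this schema does and does not accept, here is a minimal standard-library sketch. The helper `matches_character_schema` is hypothetical and is not part of `outlines`; it only mirrors the shape described by the schema (an object with a `characters` array of strings) so that the difference between free-form text and constrained output is easy to see.

```python
import json


def matches_character_schema(raw: str) -> bool:
    """Return True if `raw` is JSON shaped like {"characters": [str, ...]}.

    Hypothetical helper for illustration only; it hand-checks the same
    shape that the schema above describes.
    """
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return False
    if not isinstance(parsed, dict) or "characters" not in parsed:
        return False
    characters = parsed["characters"]
    return isinstance(characters, list) and all(isinstance(c, str) for c in characters)


print(matches_character_schema('{"characters": ["Timmy", "Sara"]}'))        # True
print(matches_character_schema("The main characters are Timmy and Sara."))  # False
print(matches_character_schema('{"characters": "Timmy, Sara"}'))            # False
```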


In [17]:
def extract_characters_using_outlines(story, generator):
  characters = generator(
    character_prompt_template.format(
      story=story
    )
  )

  return characters["characters"]
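
Since `extract_characters_using_outlines` only needs a callable that takes a prompt and returns a dict with a `characters` key, it can be exercised without a GPU. The sketch below repeats the template and function so it runs on its own; `fake_generator` is a hypothetical stand-in for the real `outlines` generator and always returns the same names.

```python
character_prompt_template = """Please analyze the following story and identify the main characters.
Output the result in JSON format with a "characters" array containing the names of the main characters

<story>
{story}
</story>
"""


def extract_characters_using_outlines(story, generator):
    characters = generator(character_prompt_template.format(story=story))
    return characters["characters"]


def fake_generator(prompt):
    # Hypothetical stub: checks that the story was wrapped in the template,
    # then returns a fixed, schema-shaped answer instead of calling a model.
    assert "<story>" in prompt and "</story>" in prompt
    return {"characters": ["Mia", "Ben"]}


result = extract_characters_using_outlines("Mia and Ben built a kite.", fake_generator)
print(result)  # ['Mia', 'Ben']
```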


## Load the `mistral-7B` model

We need an A100 GPU to run mistral-7B, at least if we want to do it easily and without quantization.

In [6]:
import outlines
import torch

from google.colab import userdata
from transformers import AutoModelForCausalLM, AutoTokenizer


mistral_model = AutoModelForCausalLM.from_pretrained(
  "mistralai/Mistral-7B-Instruct-v0.2",
  output_attentions=True,
  token=userdata.get('HF_TOKEN'),
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", token=userdata.get('HF_TOKEN'))
mistral_model_outlines = outlines.models.Transformers(mistral_model, tokenizer)


mistral_generator = outlines.generate.json(mistral_model_outlines, schema)



In [9]:
characters = extract_characters_using_outlines(honicky_splits['validation']['story'][0], mistral_generator)
print(characters)

['Timmy', 'Sara', 'Max', 'Mr. Thompson', 'Principal']


## Evaluation utils

I copied and pasted these from the other notebooks in this repo.  If I end up needing these again, I will extract them into a library.

In [18]:
import string
# Define a set of characters to strip: all punctuation and whitespace characters
strip_chars = set(string.punctuation + string.whitespace)

def strip_punctuation_whitespace(text):

  # Strip from the beginning
  start = 0
  while start < len(text) and text[start] in strip_chars:
    start += 1

  # Strip from the end
  end = len(text)
  while end > 0 and text[end-1] in strip_chars:
    end -= 1

  # Return the stripped string
  return text[start:end]

def metrics_from_strings(true_labels: list[str], predicted_labels: list[str]):
    # Calculate the intersection of true and predicted labels for correctly predicted labels
    correct_predictions = set(true_labels).intersection(predicted_labels)

    # Precision: correctly predicted positive / all predicted positive
    if len(predicted_labels) == 0:
        precision = 0
    else:
        precision = len(correct_predictions) / len(predicted_labels)

    # Recall: correctly predicted positive / all actual positive
    if len(true_labels) == 0:
        recall = 0
    else:
        recall = len(correct_predictions) / len(true_labels)

    # F1 Score: 2 * (precision * recall) / (precision + recall)
    if precision + recall == 0:
        f1 = 0
    else:
        f1 = 2 * (precision * recall) / (precision + recall)

    return precision, recall, f1

# Parse the strings to remove whitespace and split by commas

# true_labels = [strip_punctuation_whitespace(label) for label in true_labels_str.split(',')]
# predicted_labels = [strip_punctuation_whitespace(label) for label in predicted_labels_str.split(',')]



In [13]:
for story, characters in zip(honicky_splits['validation']['story'][:5], honicky_splits['validation']['characters'][:5]):
  extracted_characters = extract_characters_using_outlines(story, mistral_generator)
  characters = [strip_punctuation_whitespace(character) for character in characters.split(",")]
  print(f"extracted_characters: {extracted_characters} --- characters: {characters} --- metrics: {metrics_from_strings(characters, extracted_characters)}")

extracted_characters: ['Timmy', 'Sara', 'Max', 'Mr.Thompson'] --- characters: ['Timmy', 'Sara', 'Max', 'Mr. Thompson'] --- metrics: (0.75, 0.75, 0.75)
extracted_characters: ['One', 'Zero'] --- characters: ['One', 'Zero', 'Queen Binary'] --- metrics: (1.0, 0.6666666666666666, 0.8)
extracted_characters: ['Mia', 'Ben'] --- characters: ['Mia', 'Ben'] --- metrics: (1.0, 1.0, 1.0)
extracted_characters: ['Qantas', 'Jetstar'] --- characters: ['Qantas', 'Jetstar'] --- metrics: (1.0, 1.0, 1.0)
extracted_characters: ['Timmy', 'Junior', 'Mr. Laemmle'] --- characters: ['Timmy', 'Junior', 'Mr. Laemmle'] --- metrics: (1.0, 1.0, 1.0)
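
The first row above scores 0.75 even though the model found every character, because `'Mr.Thompson'` and `'Mr. Thompson'` differ by one space and the metric is an exact string match. The sketch below reproduces that row with a condensed copy of `metrics_from_strings`, then shows the effect of a hypothetical `normalize_name` helper that drops whitespace and case. That helper is an illustration only and is not used for any of the numbers reported in this notebook.

```python
def metrics_from_strings(true_labels, predicted_labels):
    # Condensed copy of the function above: exact-match precision, recall, F1.
    correct = set(true_labels).intersection(predicted_labels)
    precision = len(correct) / len(predicted_labels) if predicted_labels else 0
    recall = len(correct) / len(true_labels) if true_labels else 0
    f1 = 2 * (precision * recall) / (precision + recall) if precision + recall else 0
    return precision, recall, f1


def normalize_name(name):
    # Hypothetical normalizer: remove all whitespace and lowercase.
    return "".join(name.split()).lower()


true_labels = ["Timmy", "Sara", "Max", "Mr. Thompson"]
predicted_labels = ["Timmy", "Sara", "Max", "Mr.Thompson"]

exact = metrics_from_strings(true_labels, predicted_labels)
normalized = metrics_from_strings(
    [normalize_name(name) for name in true_labels],
    [normalize_name(name) for name in predicted_labels],
)

print(exact)       # (0.75, 0.75, 0.75)
print(normalized)  # (1.0, 1.0, 1.0)
```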


The timing context manager below is adapted from https://stackoverflow.com/questions/33987060/python-context-manager-that-measures-time

In [19]:
from time import perf_counter

class catchtime:

    def __init__(self, name):
      if name is not None:
        self.name = f" {name}"
      else:
        self.name = ""

    def __enter__(self):
      self.start = perf_counter()
      return self

    def __exit__(self, type, value, traceback):
      self.time = perf_counter() - self.start
      self.readout = f'Time{self.name}: {self.time:.3f} seconds'
      print(self.readout)
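
A small usage sketch for `catchtime`, with the class repeated so the snippet runs on its own. The only change is a default of `None` for `name`, which is an assumption added here for convenience. Note that `timer.time` and `timer.readout` are only set once the `with` block has exited, which is why the per-item average is computed afterwards.

```python
from time import perf_counter, sleep


class catchtime:
    def __init__(self, name=None):  # default added for this sketch
        self.name = f" {name}" if name is not None else ""

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.time = perf_counter() - self.start
        self.readout = f"Time{self.name}: {self.time:.3f} seconds"
        print(self.readout)


with catchtime("sleep demo") as timer:
    sleep(0.05)  # stand-in for the real work, e.g. running inference

items_processed = 5  # hypothetical count, for illustration
print(f"{timer.time / items_processed:.4f} seconds per item")
```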

## Evaluate mistral-7B with `outlines`

In [16]:
with catchtime("mistral") as timer:
  extracted_characters = [
    extract_characters_using_outlines(story, mistral_generator)
    for story in honicky_splits['validation']['story']
  ]

true_characters = [
  [strip_punctuation_whitespace(character) for character in characters.split(",")]
  for characters in honicky_splits['validation']['characters']
]


Time mistral: 269.805 seconds


In [None]:
print(timer.readout)

Time gtp3.5-turbo: 151.268 seconds


In [17]:
mistral_precisions, mistral_recalls, mistral_f1s = zip(*[
  metrics_from_strings(true_characters[i], extracted_characters[i])
  for i in range(len(extracted_characters))
])

In [18]:
import numpy as np

mistral_metrics = {
    "precision": np.mean(mistral_precisions),
    "recall": np.mean(mistral_recalls),
    "f1": np.mean(mistral_f1s),
    "time": timer.time,
    "time_per_story": timer.time / len(extracted_characters),
}



In [19]:
mistral_metrics

{'precision': 0.8469378988708884,
 'recall': 0.8605670103092783,
 'f1': 0.8454758569373488,
 'time': 269.8045599230004,
 'time_per_story': 1.3907451542422702}

In [24]:
del mistral_model
del mistral_model_outlines
del tokenizer
del mistral_generator
torch.cuda.empty_cache()


## Load and evaluate Phi-3-mini (3.9B)

[microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) has a reputation for punching above its weight, so we will give it a try too.


In [20]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from google.colab import userdata
import outlines

phi_model = AutoModelForCausalLM.from_pretrained(
  "microsoft/Phi-3-mini-4k-instruct",
  output_attentions=True,
  token=userdata.get('HF_TOKEN'),
).to("cuda")

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", token=userdata.get('HF_TOKEN'))
phi_model_outlines = outlines.models.Transformers(phi_model, tokenizer)

phi_generator = outlines.generate.json(phi_model_outlines, schema)




In [21]:
with catchtime("phi-3") as timer:
  extracted_characters = [
    extract_characters_using_outlines(story, phi_generator)
    for story in honicky_splits['validation']['story']
  ]

You are not running the flash-attention implementation, expect numerical differences.


Time phi-3: 239.398 seconds


Uh-oh. I couldn't get flash-attention to work with phi-3, so I suspect our throughput is well below where it should be. The message "`You are not running the flash-attention implementation, expect numerical differences.`" seems to be a commonly reported warning, and I am using a dev version of the `transformers` library, so I'm not going to dull my sword on this problem. For now, it means that the cost of inference on A100s is higher than it should be.
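
A possible avenue, untested in this notebook: FlashAttention-2 kernels only support half-precision (fp16/bf16) tensors, while the loading code above uses the default dtype. The sketch below shows how one might request bf16 weights and the `flash_attention_2` implementation through `from_pretrained`. The function name `load_phi3_with_flash_attention` is hypothetical, the imports are deferred so that defining it does not require a GPU, and whether this removes the warning with this dev version of `transformers` is an assumption. Half precision may also shift the metrics slightly.

```python
def load_phi3_with_flash_attention(model_id="microsoft/Phi-3-mini-4k-instruct", token=None):
    """Hypothetical loader: bf16 weights plus FlashAttention-2 (needs flash-attn and a supported GPU)."""
    # Imports are deferred so this sketch can be defined without torch installed.
    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,               # FlashAttention-2 requires fp16 or bf16
        attn_implementation="flash_attention_2",  # request the flash-attn kernels explicitly
        token=token,
    )
    return model.to("cuda")
```

On Colab this would be called as `load_phi3_with_flash_attention(token=userdata.get('HF_TOKEN'))` in place of the `from_pretrained(...)` call above.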


In [25]:
true_characters = [
  [strip_punctuation_whitespace(character) for character in characters.split(",")]
  for characters in honicky_splits['validation']['characters']
]

In [26]:
phi_precisions, phi_recalls, phi_f1s = zip(*[
  metrics_from_strings(true_characters[i], extracted_characters[i])
  for i in range(len(extracted_characters))
])

In [28]:
import numpy as np

phi_metrics = {
    "precision": np.mean(phi_precisions),
    "recall": np.mean(phi_recalls),
    "f1": np.mean(phi_f1s),
    "time": timer.time,
    "time_per_story": timer.time / len(extracted_characters),
}

In [29]:
phi_metrics

{'precision': 0.8006095565373916,
 'recall': 0.8372565864833906,
 'f1': 0.8084318316277079,
 'time': 239.39821746899997,
 'time_per_story': 1.234011430252577}

## Inference cost

For the `phi` and `mistral` models, we need A100s. Poking around a bit, I found that Lambda Labs offers 40GB A100s at $1.29/hour, which means that, at 50% occupancy, we could do inference at about
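
A rough sketch of this arithmetic, under stated assumptions: "occupancy" is taken to mean the fraction of paid GPU time actually spent on inference, stories are processed one at a time as in the evaluation loops above, and the per-story times are the `time_per_story` values measured earlier in this notebook. Under those assumptions, the cost comes out at roughly a dollar per thousand stories for mistral-7B and a little less for phi-3.

```python
gpu_price_per_hour = 1.29  # USD, 40GB A100 price quoted above
occupancy = 0.5            # assumed fraction of paid time spent doing inference

# time_per_story values copied from the metrics printed earlier.
seconds_per_story = {
    "mistral-7B": 1.3907451542422702,
    "phi-3-mini": 1.234011430252577,
}

# Idle time is still billed, so the effective price per busy second is higher.
effective_price_per_second = gpu_price_per_hour / 3600 / occupancy

cost_per_1000_stories = {
    model: seconds * effective_price_per_second * 1000
    for model, seconds in seconds_per_story.items()
}

for model, cost in cost_per_1000_stories.items():
    print(f"{model}: ${cost:.2f} per 1000 stories")
```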

