# Orchestrating LLMs to Write Diverse Stories with Quality Diversity through AI Feedback

_This tutorial is part of the series of pyribs tutorials! See [here](https://docs.pyribs.org/en/latest/tutorials.html) for the list of all tutorials and the order in which they should be read._

Given a creative writing task, such as "write a story about a spy and a politician," there are many possible outcomes. For instance, we could have a story where

(TODO)

or a story where

(TODO -- refer to outputs).

Alternative, we could even...

TODO

In short, a wide range of stories exist, each with their own interesting capabilities.

To explore the diverse possibilities available in creative writing, [Quality Diversity through AI Feedback (QDAIF; Bradley 2024)](https://qdaif.github.io/) proposes to leverage recent advances in large language models (LLMs). Specifically, LLMs in QDAIF serve two purposes.

First, they can generate new stories...

Second, they can evaluate...

In this tutorial, we will show how

## Setup

### Ollama

To run the LLMs for this tutorial, we will use [Ollama](https://ollama.com). If you are running this tutorial locally, please follow the [installation instructions](https://ollama.com/download) for Ollama and skip this cell. On Google Colab, we can install Ollama following the instructions in this [notebook](https://colab.research.google.com/github/5aharsh/collama/blob/main/Ollama_Setup.ipynb) by Saharsh Anand. First, we download and run the Ollama installation script.

In [None]:
!sudo apt update
!sudo apt install -y pciutils
!curl -fsSL https://ollama.com/install.sh | sh

Next, we start the Ollama server in the background.

In [None]:
import subprocess
import threading
import time


def run_ollama_serve():
    subprocess.Popen(["ollama", "serve"])


thread = threading.Thread(target=run_ollama_serve)
thread.start()
time.sleep(5)  # Wait for the server to start.

We can now pull the Llama 3.1 model from [Ollama's library](https://ollama.com/library/llama3.1:8b-instruct-q4_K_M). We select the 8B parameter model that has been finetuned for instruction following. We choose the `q4_K_M` [quantization](https://github.com/ggml-org/llama.cpp/blob/master/tools/quantize/README.md) as it is a recommended size that balances between speed/memory usage and accuracy.

In [6]:
model_name = "llama3.1:8b-instruct-q4_K_M"  # @param {"type":"string"}
!ollama pull {model_name}

llama3.1:8b-instruct-q4_K_M
[?2026h[?25l[1Gpulling manifest ⠋ [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest ⠙ [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest ⠹ [K[?25h[?2026l[?2026h[?25l[1Gpulling manifest [K
pulling 667b0c1932bc: 100% ▕██████████████████▏ 4.9 GB                         [K
pulling 948af2743fc7: 100% ▕██████████████████▏ 1.5 KB                         [K
pulling 0ba8f0e314b4: 100% ▕██████████████████▏  12 KB                         [K
pulling 56bb8bd477a5: 100% ▕██████████████████▏   96 B                         [K
pulling 455f34728c9b: 100% ▕██████████████████▏  487 B                         [K
verifying sha256 digest [K
writing manifest [K
success [K[?25h[?2026l


### Python Dependencies

In addition to pyribs, this tutorial depends on [LangChain](https://python.langchain.com/docs/introduction/), a framework for developing LLM applications. Below we install pyribs, LangChain, and the package that integrates LangChain with Ollama.

In [None]:
%pip install ribs[visualize] langchain langchain-ollama tqdm

In [4]:
import numpy as np
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables.base import Runnable
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field

## Evaluator

In [1]:
class Evaluator:
    """Manages an LLM to compute objective and measures.

    Args:
        llm: Chat model for computing evaluations.
        n_evals: Number of times to evaluate each objective and measure.
    """

    def __init__(self, llm: BaseChatModel, n_evals: int):
        self.llm = llm
        self.n_evals = n_evals
        self.min_score = 1
        self.max_score = 10

        # Objective: Short story about spy and politician. Note that the text in the
        # schema class (including class name, field name, field description, docstrings)
        # all have some influence on the LLM output -- see here:
        # https://python.langchain.com/docs/how_to/structured_output/
        class Rating(BaseModel):
            rating: int = Field(description="The rating on a scale of 1 to 10.")

        self.objective_template = ChatPromptTemplate(
            [
                (
                    "system",
                    "Determine if the input text contains a high-quality short story containing two characters, a suspicious spy, and a rich politician. For example, a high-quality short story would have good flow, interesting plot, and not repeat similar sentences or undesired items such as titles and URLs. Output your rating on a scale of 1 to 10. Output only a single integer.",
                ),
                ("user", "{text}"),
            ]
        )
        self.objective_chain = (
            self.objective_template | self.llm.with_structured_output(Rating)
        )

        # Measure 0: Romance genre.
        class Romance(BaseModel):
            rating: int = Field(description="The rating on a scale of 1 to 10.")

        self.measure_0_template = ChatPromptTemplate(
            [
                (
                    "system",
                    "Determine if the input text is a romance story. For example, a romance story talks about two characters who fall in love with each other. Output your rating on a scale of 1 to 10. Output only a single integer.",
                ),
                ("user", "{text}"),
            ]
        )
        self.measure_0_chain = (
            self.measure_0_template | self.llm.with_structured_output(Romance)
        )

        # Measure 1: Happy ending.
        class HappyEnding(BaseModel):
            rating: int = Field(description="The rating on a scale of 1 to 10.")

        self.measure_1_template = ChatPromptTemplate(
            [
                (
                    "system",
                    "Determine if the input text is a story with a happy ending. For example, a story where the two characters make peach with each other has a happy ending. Output your rating on a scale of 1 to 10. Output only a single integer.",
                ),
                ("user", "{text}"),
            ]
        )
        self.measure_1_chain = (
            self.measure_1_template | self.llm.with_structured_output(HappyEnding)
        )

    def _compute_score(self, chain: Runnable, texts: list[str]):
        """Computes the objectives and measures for the given batch of inputs.

        Each text input is evaluated `n_evals` times.

        Two values are returned:
        - The first value is `all_scores`, which is a list where each entry contains the
          `n_evals` scores for each text. This is a list because some evals may fail,
          meaning that not all texts will have `n_evals` scores.
        - The second is `mean_scores`, which is the mean score for each piece of text.
        """
        inputs = [{"text": text} for text in texts for _ in range(self.n_evals)]
        outputs = chain.batch(inputs)

        all_scores = []
        mean_scores = []

        for i in range(0, len(outputs), self.n_evals):
            results = outputs[i : i + self.n_evals]
            scores = []
            for r in results:
                # Note: This assumes the schema for each result has a `rating` field,
                # which may not be the case if you modify the schema above.
                score = np.clip(r.rating, self.min_score, self.max_score)
                scores.append(score)

            scores = np.asarray(scores)
            all_scores.append(scores)
            mean_scores.append(scores.mean())

        return all_scores, np.asarray(mean_scores)

    def evaluate_objective(self, texts: list[str]):
        return self._compute_score(self.objective_chain, texts)

    def evaluate(self, texts: list[str]):
        objectives = self._compute_score(self.objective_chain, texts)[1]
        measure_0 = self._compute_score(self.measure_0_chain, texts)[1]
        measure_1 = self._compute_score(self.measure_1_chain, texts)[1]
        measures = np.stack((measure_0, measure_1), axis=1)
        return objectives, measures

In [8]:
evaluator = Evaluator(
    model="llama3.1:8b-instruct-q4_K_M",
    n_evals=5,
)

objectives, measures = evaluator.evaluate(
    [
        # From QDAIF paper.
        "A spy named Joanne wants to infiltrate the premises of Karl Johnson, a highly-influential figure in the city. Karl was a wealthy mayor, and would do anything in his power to suppress any opposing voices. Joanne wanted to figure out what Karl was hiding, but she took a turn for the worse, as she was highly suspicious in her presence outside his home.",
        "The wealthy entrepreneur and member of parliament, Susan, hosted a party at her mansion. She invited all of the residents, as well as an unusual looking man. The man, Dave, was wearing a tacky shirt, and star-shaped glasses, and was actually a spy. He made the whole room laugh with his jokes, and had a secret agenda - to find what Susan does in her private fun room!",
        "The rich politician, Tom’s life took a turn for the worst - he feared all of his close aides all of a sudden after sensing danger in his clique. There was a civil war going on, and he feared for his life. One day, one of his security guards, turned secret agent, decided to sneak into the classified files room, and spied on Johnny, who was in the room. He wanted to find Johnny’s weakness, and strike at the right time.",
    ]
)

print("Objectives:")
print(objectives)
print("Measures:")
print(measures)

Objectives:
[6.66666667 6.33333333 7.33333333 6.66666667]
Measures:
[[1.66666667 2.        ]
 [2.33333333 3.        ]
 [1.66666667 2.        ]
 [8.33333333 9.        ]]


## QDAIF Components in pyribs

### Archive

In [None]:
# TODO

### Emitters

In [None]:
# TODO

### Scheduler

In [None]:
from ribs.schedulers import Scheduler

scheduler = Scheduler(archive, emitters)

## Running QDAIF

In [None]:
total_itrs = 200

## Citation

If you find this tutorial useful, please cite it as:

```
@article{pyribs_qdaif,
  title   = {Orchestrating LLMs to Write Diverse Stories with Quality Diversity through AI Feedback},
  author  = {Bryon Tjanaka},
  journal = {pyribs.org},
  year    = {2025},
  url     = {https://docs.pyribs.org/en/stable/tutorials/qdaif.html}
}
```

## Credits

Thank you to [Sid Srikanth](https://sidsrikanth.com/), [Saeed Hedayatian](https://conflictednerd.github.io/), and the members of the ICAROS Lab for their invaluable feedback in developing this tutorial.