In its current form, this notebook is meant to be run on a GPU. A T4 is sufficient. It's free
on [Google
Colab](https://stackoverflow.com/questions/62596466/how-can-i-run-notebooks-of-a-github-project-in-google-colab/67344477#67344477).
You can technically run this notebook on a CPU (with minor adjustments), but then it'll
take hours. We'll be running the model 1,500 times (500 texts for each of 3 methods)!

**Description**: for a [4 GB 4-bit Llama 2 chat
model](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/blob/main/llama-2-7b-chat.Q4_0.gguf)
and the [AG news](https://huggingface.co/datasets/ag_news) classification task, this
notebook demonstrates that CAPPr gets you +5% absolute accuracy over text generation (75.4% vs. 70.4%), though at some cost in runtime. Note that this dataset is kind of a strawman b/c it's probably trivial to solve w/ a handful of labeled examples. Still, I think it's useful as a benchmark.

**Contamination notice**: I don't know whether Llama 2 was trained on any AG news data. If
it was, but the contamination doesn't interact with the method (CAPPr vs. text generation),
i.e., it helps both methods equally, then the difference between their performances can
still be studied.

**Estimated run time**: ~20 min.

[Install packages](#install-packages)

[Download model](#download-model)

[Utils](#utils)

[Load data](#load-data)

[Text generation](#text-generation)

[Text generation (Multiple Choice)](#text-generation-multiple-choice)

[CAPPr](#cappr)

# Install packages

For CPU, just do

```
!pip install llama-cpp-python
```

For GPU (ty [this comment](https://github.com/ggerganov/llama.cpp/issues/128#issuecomment-1604696753)):

In [None]:
!CMAKE_ARGS="-DLLAMA_CUBLAS=on" FORCE_CMAKE=1 pip install llama-cpp-python

I'm gonna install `cappr` from source b/c sometimes I use this notebook to statistically
gut check code changes.

I'll also install the `demos` extras for NLP datasets.

In your local env, you'd just do:

```
pip install "cappr[llama-cpp]"
```

In [None]:
!pip install "cappr[demos] @ git+https://github.com/kddubey/cappr.git"

# Download model

The model is a [4 GB 4-bit Llama 2 chat
model](https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/blob/main/llama-2-7b-chat.Q4_0.gguf) with 7B parameters.

In [3]:
!huggingface-cli download \
TheBloke/Llama-2-7b-Chat-GGUF \
llama-2-7b-chat.Q4_0.gguf \
--local-dir . \
--local-dir-use-symlinks False

downloading https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_0.gguf to /root/.cache/huggingface/hub/tmpt3w0cugx
Downloading (…)-2-7b-chat.Q4_0.gguf: 100% 3.83G/3.83G [01:52<00:00, 33.9MB/s]
Storing https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_0.gguf in local_dir at ./llama-2-7b-chat.Q4_0.gguf (not cached).
./llama-2-7b-chat.Q4_0.gguf


In [1]:
model_path = "./llama-2-7b-chat.Q4_0.gguf"

In [2]:
from __future__ import annotations
from string import ascii_uppercase as alphabet
from typing import Collection, Sequence

import datasets
import pandas as pd
from tqdm.auto import tqdm

from llama_cpp import Llama

from cappr.llama_cpp import classify

In [3]:
import torch
n_gpu_layers = -1 if torch.cuda.is_available() else 0
n_gpu_layers

-1

In [4]:
model = Llama(model_path=model_path, logits_all=True, n_gpu_layers=n_gpu_layers)

AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | 


# Load data

In [5]:
DATASET_NAME = "ag_news"

In [6]:
_df = pd.DataFrame(datasets.load_dataset(DATASET_NAME, split="train"))

In [7]:
_df.head()

|   | text | label |
|---|------|-------|
| 0 | Wall St. Bears Claw Back Into the Black (Reute... | 2 |
| 1 | Carlyle Looks Toward Commercial Aerospace (Reu... | 2 |
| 2 | Oil and Economy Cloud Stocks' Outlook (Reuters... | 2 |
| 3 | Iraq Halts Oil Exports from Main Southern Pipe... | 2 |
| 4 | Oil prices soar to all-time record, posing new... | 2 |


Ensure your dataframe passes these checks:

In [8]:
assert len(set(_df.index)) == len(_df)
assert "text" in _df.columns
assert "label" in _df.columns

In [9]:
_df["text"] = _df["text"].astype(str)
_df["label"] = _df["label"].astype(int)

In [10]:
len(_df)

120000

We don't need that much data to compare methods, so we'll take a stratified sample of 500 texts. The labels are perfectly balanced:

In [11]:
_df["label"].value_counts(normalize=True).sort_index()

0    0.25
1    0.25
2    0.25
3    0.25
Name: label, dtype: float64

Is the default context window of 512 tokens sufficient? Let's see how many characters (not tokens)
there are in each text.

In [12]:
_df["text"].str.len().describe()

count    120000.000000
mean        236.477525
std          66.509741
min         100.000000
25%         196.000000
50%         232.000000
75%         266.000000
max        1012.000000
Name: text, dtype: float64
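A rough back-of-the-envelope check, assuming about 4 characters per token (a common rule of thumb for English text, not a measured property of the Llama 2 tokenizer): even the longest text, 1,012 characters, comes to roughly 253 tokens. The helpers `estimate_tokens` and `fits_in_context` and the 300-character template allowance are hypothetical, for illustration only; tokenizing with the actual model would give exact counts.

```python
import math


def estimate_tokens(num_chars: int, chars_per_token: float = 4.0) -> int:
    # ASSUMPTION: ~4 characters per token, a common rule of thumb for English.
    # The Llama 2 tokenizer's real ratio will differ somewhat.
    return math.ceil(num_chars / chars_per_token)


def fits_in_context(
    num_text_chars: int,
    num_template_chars: int,
    max_new_tokens: int,
    n_ctx: int = 512,
    chars_per_token: float = 4.0,
) -> bool:
    # The prompt (template + text) and the generated tokens share the context window.
    prompt_tokens = estimate_tokens(num_text_chars + num_template_chars, chars_per_token)
    return prompt_tokens + max_new_tokens <= n_ctx


# Longest text in the dataset (1,012 characters), a hypothetical 300-character
# allowance for the prompt template, and max_tokens=20 as in the text generation cell.
print(estimate_tokens(1012))  # 253
print(fits_in_context(1012, 300, 20))  # True
```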

In [13]:
def stratified_sample(
    df: pd.DataFrame, sample_size: int, random_state: int | None = None
) -> pd.DataFrame:
    # If sample_size isn't divisible by the number of labels, the returned df has
    # slightly fewer than sample_size rows. That's nbd for this experiment.
    num_labels = len(set(df["label"]))
    num_obs_per_label = sample_size // num_labels

    def label_sampler(df_label: pd.DataFrame) -> pd.DataFrame:
        # Sample the same number of rows from each label's group
        return df_label.sample(num_obs_per_label, random_state=random_state)

    return df.groupby("label", group_keys=False).apply(label_sampler)
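The same idea without pandas, as a minimal stdlib sketch (`stratified_sample_indices` is a hypothetical helper that returns row positions rather than a dataframe). It also makes the caveat in the code comment concrete: with floor division, a `sample_size` that isn't divisible by the number of labels yields slightly fewer rows.

```python
import random
from collections import Counter, defaultdict
from typing import Dict, List, Optional, Sequence


def stratified_sample_indices(
    labels: Sequence[int], sample_size: int, seed: Optional[int] = None
) -> List[int]:
    # Group row positions by their label.
    positions_by_label: Dict[int, List[int]] = defaultdict(list)
    for position, label in enumerate(labels):
        positions_by_label[label].append(position)
    # Take the same number of rows from every label. Floor division means the
    # total can be a bit smaller than sample_size.
    num_per_label = sample_size // len(positions_by_label)
    rng = random.Random(seed)
    sampled: List[int] = []
    for label in sorted(positions_by_label):
        sampled.extend(rng.sample(positions_by_label[label], num_per_label))
    return sampled


labels = [0, 1, 2, 3] * 250  # 1,000 toy rows, perfectly balanced
sample = stratified_sample_indices(labels, sample_size=500, seed=3459)
print(len(sample))  # 500
print(dict(Counter(labels[i] for i in sample)))  # 125 rows per label
print(len(stratified_sample_indices(labels, sample_size=10, seed=0)))  # 8, not 10
```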

In [14]:
sample_size = 500
random_state = 3459

In [15]:
df = stratified_sample(_df, sample_size=sample_size, random_state=random_state)

In [16]:
df["label"].value_counts(normalize=True).sort_index()

0    0.25
1    0.25
2    0.25
3    0.25
Name: label, dtype: float64

In [17]:
news_categories = ("world", "sports", "business", "science")

In [18]:
llama_chat_template = """
<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>

{user_message} [/INST]
""".lstrip(
    "\n"
)

# Text generation

In [19]:
def prompt_chat(text: str):
    return (
        f'The following text was taken from a news article:\n"{text}"\n\n'
        "Identify the topic which the text belongs to."
    )

system_prompt_chat = (
    "Identify the news topic which the text belongs to. The news topics are: "
    "world, sports, business, and science. "
    "Respond only with the correct topic."
)

df["prompt_chat"] = [
    llama_chat_template.format(
        system_prompt=system_prompt_chat,
        user_message=prompt_chat(text)
    )
    for text in df["text"]
]

print(df["prompt_chat"].iloc[0])

<s>[INST] <<SYS>>
Identify the news topic which the text belongs to. The news topics are: world, sports, business, and science. Respond only with the correct topic.
<</SYS>>

The following text was taken from a news article:
"Strong Quake Hits Japan's Hokkaido, 11 Hurt (Reuters) Reuters - A strong earthquake with a preliminary\magnitude of 7.1 hit a wide area of Japan's northernmost main\island of Hokkaido early on Monday, the Japan Meteorological\Agency said."

Identify the topic which the text belongs to. [/INST]



Generate completions

In [None]:
model.reset()
completions = []
for _prompt in tqdm(df["prompt_chat"], total=len(df), desc="Sampling"):
    response = model(_prompt, max_tokens=20, temperature=0)
    completion = response["choices"][0]["text"]
    completions.append(completion)

Inspect completions

In [21]:
pd.Series(completions).sample(10)

37       The text belongs to the topic of "World".
146                                         Sports
192                                       Business
56     The text belongs to the "world" news topic.
484         The text belongs to the "world" topic.
32                                           World
80     The text belongs to the "world" news topic.
243     The text belongs to the topic of "Sports".
425       The text belongs to the "science" topic.
428                                        Science
dtype: object

When you're doing text generation, you often have to write this sort of data-dependent and model-dependent function to map each completion to a class. For this prompt and the AG news task, it's pretty trivial.

In [22]:
def process_completion(
    completion: str,
    class_names: Sequence[str],
    default=-1,
) -> int:
    for i, name in enumerate(class_names):
        if name in completion.lower():
            return i
    return default
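Running `process_completion` on completions like the ones above (the last one is made up) shows two properties of plain substring matching: anything that doesn't mention a category maps to `-1`, and when a completion mentions several categories, the one listed first in `class_names` wins, regardless of where it appears in the completion.

```python
from typing import Sequence

news_categories = ("world", "sports", "business", "science")


def process_completion(
    completion: str,
    class_names: Sequence[str],
    default: int = -1,
) -> int:
    # Return the index of the first class name (in class_names order) that
    # appears anywhere in the lowercased completion.
    for i, name in enumerate(class_names):
        if name in completion.lower():
            return i
    return default


examples = [
    'The text belongs to the topic of "World".',
    "Sports",
    'The text belongs to the "Health" topic.',
    "Mostly business, with some sports.",  # made up: mentions two categories
]
print([process_completion(c, news_categories) for c in examples])  # [0, 1, -1, 1]
```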

In [23]:
pred_classes_text_gen = [
    process_completion(completion, news_categories)
    for completion in completions
]

How many of the completions could be mapped to a label?

In [24]:
(pd.Series(pred_classes_text_gen) != -1).mean()

0.982

What do invalid completions look like?

In [25]:
pd.Series(completions)[(pd.Series(pred_classes_text_gen) == -1)]

29     The text belongs to the "Politics" or "Governm...
46     The text belongs to the "Politics" or "Governm...
86       The text belongs to the "law" or "legal" topic.
92     The text belongs to the "Politics" or "Governm...
107    The text belongs to the "Politics" or "Governm...
270    The text belongs to the "Politics" or "Governm...
396    The text belongs to the "health" or "medical" ...
467    The text belongs to the topic of "Technology" ...
472              The text belongs to the "Health" topic.
dtype: object

Hmm, methinks it's not fair to map these to an existing news category, because the mapping function would then depend on the *observed* mistakes.

How accurate are the predictions? Completions that couldn't be mapped to a label (`-1`) count as incorrect.

In [26]:
(pred_classes_text_gen == df['label']).mean()

0.704

# Text generation (Multiple Choice)

A strawman alternative is multiple choice: list the topics as lettered options and ask the model for a letter. This prompt is the best I could do.

In [27]:
def multiple_choice(*choices) -> str:
    if len(choices) > len(alphabet):
        raise ValueError("There are more choices than letters.")
    letters_and_choices = [
        f"{letter}. {choice}" for letter, choice in zip(alphabet, choices)
    ]
    return "\n".join(letters_and_choices)


def prompt_mc(text: str):
    mc = multiple_choice(*news_categories)
    return (
        f'The following text was taken from a news article:\n"{text}"\n\n'
        "Identify the news topic which the text belongs to:\n"
        f"{mc}\n\n"
        "Answer A, B, C, or D."
    )

system_prompt_mc = (
    "Identify the news topic which the text belongs to. The news categories are: "
    "world, sports, business, and science. Each topic is identified by a letter: "
    "A, B, C, or D, respectively.\n"
    "Respond only with the letter corresponding to the correct news topic."
)

df["prompt_chat_mc"] = [
    llama_chat_template.format(
        system_prompt=system_prompt_mc,
        user_message=prompt_mc(text),
    )
    for text in df["text"]
]

print(df["prompt_chat_mc"].iloc[0])

<s>[INST] <<SYS>>
Identify the news topic which the text belongs to. The news categories are: world, sports, business, and science. Each topic is identified by a letter: A, B, C, or D, respectively.
Respond only with the letter corresponding to the correct news topic.
<</SYS>>

The following text was taken from a news article:
"Strong Quake Hits Japan's Hokkaido, 11 Hurt (Reuters) Reuters - A strong earthquake with a preliminary\magnitude of 7.1 hit a wide area of Japan's northernmost main\island of Hokkaido early on Monday, the Japan Meteorological\Agency said."

Identify the news topic which the text belongs to:
A. world
B. sports
C. business
D. science

Answer A, B, C, or D. [/INST]



In [None]:
model.reset()
completions_mc = []
for _prompt in tqdm(df["prompt_chat_mc"], total=len(df), desc="Sampling"):
    response = model(_prompt, max_tokens=15, temperature=0)
    completion_mc = response["choices"][0]["text"]
    completions_mc.append(completion_mc)

In [29]:
pd.Series(completions_mc).sample(10)

462     The text belongs to the news topic "B" - sports.
96     The news topic that the text belongs to is:\nA...
154    The text belongs to the news topic of "sports"...
318    The text belongs to the news topic "business" ...
417     The text belongs to the news topic "B" - sports.
211    The text belongs to the news topic of "sports"...
66     The news topic that the text belongs to is:\nD...
228    The text belongs to the news topic of "sports"...
489     The text belongs to the news topic "D. science".
74     The text belongs to the news topic "world", so...
dtype: object

In [30]:
def process_completion_mc(
    completion: str,
    class_chars: Sequence[str],
    class_names: Sequence[str],
    default=-1,
) -> int:
    for i, name in enumerate(class_names):
        if name in completion.lower():
            return i
    for i, char in enumerate(class_chars):
        if char in completion:  # need to retain uppercase
            return i
    return default
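One caveat with the letter fallback in `process_completion_mc`: `char in completion` matches a capital letter anywhere, so a completion like "Based on the text, the answer is C." maps to B (index 1) because "Based" contains "B". A stricter variant, sketched below as a hypothetical alternative (not what this notebook uses), only accepts a letter that stands on its own as a word via a regex word boundary. It's still a heuristic: the article "A" at the start of a sentence would match.

```python
import re
from typing import Sequence


def process_completion_mc_strict(
    completion: str,
    class_chars: Sequence[str],
    class_names: Sequence[str],
    default: int = -1,
) -> int:
    # First pass (same as process_completion_mc): look for a class name, case-insensitively.
    lowered = completion.lower()
    for i, name in enumerate(class_names):
        if name in lowered:
            return i
    # Second pass: only accept a letter that stands alone as a word
    # (e.g. "C", "C.", '"C"'), so the "B" in "Based" doesn't count.
    for i, char in enumerate(class_chars):
        if re.search(rf"\b{re.escape(char)}\b", completion):
            return i
    return default


news_categories = ("world", "sports", "business", "science")
letters = "ABCD"
print(process_completion_mc_strict("Based on the text, the answer is C.", letters, news_categories))  # 2
```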

In [31]:
pred_classes_text_gen_mc = [
    process_completion_mc(
        completion_mc,
        class_chars=alphabet[:len(news_categories)],
        class_names=news_categories
    )
    for completion_mc in completions_mc
]

How many of the sampled completions could be mapped to a label?

In [32]:
(pd.Series(pred_classes_text_gen_mc) != -1).mean()

0.998

Accuracy:

In [33]:
(pred_classes_text_gen_mc == df['label']).mean()

0.622

In [34]:
model.n_tokens

205

`model.n_tokens` is nonzero, i.e., the model still holds tokens from the last call, so we'll need to reset the model before CAPPr.

# CAPPr
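Roughly, CAPPr scores each candidate completion by how probable the model finds the completion's tokens when they follow the prompt. Below is a minimal sketch of just that scoring step, assuming the per-token log-probabilities have already been computed. The numbers are made up and `completion_probs` is a hypothetical helper, not cappr's code. Averaging and then exponentiating is, as far as I know, cappr's default aggregation; check its docs before relying on the details. The `pred_probs.argmax(axis=1)` later in this section takes the argmax over normalized rows like these.

```python
import math
from typing import List, Sequence


def completion_probs(token_log_probs: Sequence[Sequence[float]]) -> List[float]:
    # token_log_probs[i] holds log P(token | prompt, earlier tokens of completion i)
    # for every token in completion i.
    # Average within each completion (so a completion isn't penalized just for
    # having more tokens), then exponentiate.
    scores = [math.exp(sum(lps) / len(lps)) for lps in token_log_probs]
    # Normalize across completions so the probabilities sum to 1.
    total = sum(scores)
    return [score / total for score in scores]


news_categories = ("world", "sports", "business", "science")
# Made-up log-probs; "business" is pretended to be two tokens.
made_up_log_probs = [[-0.5], [-3.0], [-2.0, -1.0], [-4.0]]
probs = completion_probs(made_up_log_probs)
pred = max(range(len(probs)), key=probs.__getitem__)
print(news_categories[pred])  # world
```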

In [35]:
prompt_prefix = (
    "Every news article can be categorized as either world, sports, business, or "
    "science.\n"
    "The following text was taken from a news article:"
)

In [36]:
def prompt(text: str):
    return (
        f'\n"{text}"\n\n'
        "The topic which the text belongs to is"
    )

In [37]:
df["prompt"] = [prompt(text) for text in df["text"]]

In [38]:
print(prompt_prefix + df["prompt"].iloc[0])

Every news article can be categorized as either world, sports, business, or science.
The following text was taken from a news article:
"Strong Quake Hits Japan's Hokkaido, 11 Hurt (Reuters) Reuters - A strong earthquake with a preliminary\magnitude of 7.1 hit a wide area of Japan's northernmost main\island of Hokkaido early on Monday, the Japan Meteorological\Agency said."

The topic which the text belongs to is


In [39]:
with classify.cache(model, prompt_prefix):
    pred_probs = classify.predict_proba(
        prompts=df["prompt"],
        completions=news_categories,
        model=model,
        reset_model=False,
    )


In [40]:
(pred_probs.argmax(axis=1) == df["label"]).mean()

0.754

Alternatively, you could use `os.path.commonprefix` to extract `prompt_prefix` from a complete set of prompts.

In [41]:
import os

In [42]:
prompts = [prompt_prefix + prompt for prompt in df["prompt"]]

In [43]:
%timeit os.path.commonprefix(prompts)

67.4 µs ± 18.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [44]:
os.path.commonprefix(prompts)

'Every news article can be categorized as either world, sports, business, or science.\nThe following text was taken from a news article:\n"'
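Two things about `os.path.commonprefix` are worth knowing here. It compares strings character by character, so the prefix it returns above runs past `prompt_prefix` into the `\n"` that every prompt starts with. And with only a few prompts, it can run into the texts themselves. A toy example, with two made-up texts that both start with "St":

```python
import os

prompt_prefix = (
    "Every news article can be categorized as either world, sports, business, or "
    "science.\n"
    "The following text was taken from a news article:"
)


def prompt(text: str) -> str:
    return f'\n"{text}"\n\nThe topic which the text belongs to is'


# Two made-up texts that happen to share their first two characters, "St".
texts = ["Strong Quake Hits Japan's Hokkaido", "Stocks rally on earnings"]
prompts = [prompt_prefix + prompt(text) for text in texts]

# commonprefix compares character by character; it knows nothing about tokens or words.
common = os.path.commonprefix(prompts)
print(repr(common[len(prompt_prefix):]))  # '\n"St'
```

With 500 real prompts this is less likely, but it's worth checking the extracted prefix before caching it.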