**Description**: demonstrates that the zero-shot text classification method [described
here](https://stats.stackexchange.com/q/601159/337906) works well on the [Winograd
Schema Challenge (WSC)](https://cs.nyu.edu/~davise/papers/WinogradSchemas/WS.html). It's
one of the [SuperGLUE tasks](https://super.gluebenchmark.com/tasks) in which labels have
multiple tokens, in some sense. Surprisingly, `text-curie-001` is 84% accurate.

☣️ **Contamination notice** ☣️: Some of the WSC examples were included in GPT-3's
training data! So `gpt-3.5-turbo-instruct` and `text-curie-001` were very likely trained
on WSC. The authors [studied this
contamination](https://arxiv.org/pdf/2005.14165.pdf#page=31&zoom=100,96,89) and "found a
2.6% decrease in performance on the clean subset".

**Estimated run time**: ~30 sec.

**Environment**: See the [Setup section in the
README](https://github.com/kddubey/cappr/#installation).

**Other**: You have to have an OpenAI API key stored in the environment variable
`OPENAI_API_KEY`. [Sign up here](https://openai.com/api/). This notebook will manually
ask you to give the go-ahead before incurring any costs. It'll cost ya about
<span>$</span>0.04.
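
A quick way to confirm the key is visible to Python before running anything is sketched below. `has_openai_key` is a hypothetical helper, not part of `cappr` or the OpenAI SDK. It only checks that the variable is set and non-empty, not that the key is valid.

```python
import os


def has_openai_key(env=None) -> bool:
    """Return True if OPENAI_API_KEY is set to a non-empty value.

    Hypothetical helper for illustration. It does not contact OpenAI, so it
    cannot tell whether the key is actually valid.
    """
    env = os.environ if env is None else env
    return bool(env.get("OPENAI_API_KEY", "").strip())


if not has_openai_key():
    print("OPENAI_API_KEY is not set; export it before running the model cells.")
```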

[Load data](#load-data)

[Write prompt](#write-prompt)

[Run model](#run-model)

[Score](#score)

In [1]:
from __future__ import annotations
import logging
import os
import sys

import datasets as nlp_datasets
import numpy as np
import pandas as pd

from cappr import Example
from cappr import openai
from cappr.utils import _batch

sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))
from utils import display_df, remove_suffix

In [2]:
# When hitting the OpenAI endpoints, we'll log any server errors
logging.basicConfig(
    level=logging.INFO,
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    format="%(asctime)s :: %(name)s :: %(levelname)s :: %(message)s",
)
logger = logging.getLogger(__name__)

# Load data

Given a passage with a (marked) ambiguous pronoun, the classification problem is to pick 1 of 2 alternatives which the pronoun refers to. 

See the [example on the website](https://cs.nyu.edu/~davise/papers/WinogradSchemas/WS.html). It's pretty cool.

The test set labels are hidden, so I'll score this zero-shot classifier on the 273 examples in the `wsc273` subset of the challenge.

In [3]:
# TODO: idk what the subsets are
df: pd.DataFrame = (
    nlp_datasets.load_dataset("winograd_wsc", "wsc273")["test"]  # only available split
    .data.to_pandas()
)

In [4]:
len(df)

273

In [5]:
df.head()

| | text | pronoun | pronoun_loc | quote | quote_loc | options | label | source |
|---|---|---|---|---|---|---|---|---|
| 0 | The city councilmen refused the demonstrators ... | they | 63 | they feared violence | 63 | [The city councilmen, The demonstrators] | 0 | (Winograd 1972) |
| 1 | The city councilmen refused the demonstrators ... | they | 63 | they advocated violence | 63 | [The city councilmen, The demonstrators] | 1 | (Winograd 1972) |
| 2 | The trophy doesn't fit into the brown suitcase... | it | 55 | it is too large | 55 | [the trophy, the suitcase] | 0 | Hector Levesque |
| 3 | The trophy doesn't fit into the brown suitcase... | it | 55 | it is too small | 55 | [the trophy, the suitcase] | 1 | Hector Levesque |
| 4 | Joan made sure to thank Susan for all the help... | she | 47 | she had received | 47 | [Joan, Susan] | 0 | Hector Levesque |


# Write prompt

The method we'll use is described in [this paper](https://arxiv.org/abs/1806.02847)<sup>1</sup>. See Table 1. We'll do the "partial" method b/c the authors demonstrate that it performs better. The motivation there is identical to the motivation behind this package. In fact, section 3.4 of the [GPT-3 paper](https://arxiv.org/abs/2005.14165)<sup>2</sup> doesn't actually use sampling for WSC! It uses the same partial method. I guess my algorithm isn't so novel after all, heh.

1. Trinh, Trieu H., and Quoc V. Le. "A simple method for commonsense reasoning." arXiv preprint arXiv:1806.02847 (2018).

2. Brown, Tom, et al. "Language models are few-shot learners." Advances in neural information processing systems 33 (2020): 1877-1901.

To create the partial prompts and their completions, I'll just take some of [the code from here](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/wsc273.py).

In [6]:
_upper_pronouns = [
    "A",
    "An",
    "The",
    "She",
    "He",
    "It",
    "They",
    "My",
    "His",
    "Her",
    "Their",
]


def _normalize_option(doc, option):
    # Append `'s` to possessive determiner based options.
    if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
        option += "'s"
    # Appropriately lowercase the pronoun in the option.
    pronoun = option.split()[0]
    start_of_sentence = doc["text"][doc["pronoun_loc"] - 2] == "."
    if not start_of_sentence and pronoun in _upper_pronouns:
        return option.replace(pronoun, pronoun.lower())
    return option


def _process_doc(doc):
    # The HF implementation of `wsc273` is not `partial evaluation` friendly.
    doc["text"] = doc["text"].replace("  ", " ")
    doc["options"][0] = _normalize_option(doc, doc["options"][0])
    doc["options"][1] = _normalize_option(doc, doc["options"][1])
    return doc


def partial_context(doc, option):
    # Substitute the pronoun in the original text with the specified
    # option and ignore everything after.
    return doc["text"][: doc["pronoun_loc"]] + option


def partial_target(doc):
    # The target is everything after the document specified pronoun.
    start_index = doc["pronoun_loc"] + len(doc["pronoun"])
    return " " + doc["text"][start_index:].strip()
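
To make the split concrete, here is a self-contained sketch that applies the same slicing as `partial_context` and `partial_target` to the trophy example. The two functions are restated so the snippet runs on its own, and the `doc` dict is hand-written to mirror one row of the dataset rather than loaded from it.

```python
def partial_context(doc, option):
    # Keep the text before the pronoun, then substitute the option for it.
    return doc["text"][: doc["pronoun_loc"]] + option


def partial_target(doc):
    # Keep everything after the pronoun, with a single leading space.
    start_index = doc["pronoun_loc"] + len(doc["pronoun"])
    return " " + doc["text"][start_index:].strip()


# Hand-written record mirroring one dataset row (an assumption, not loaded)
doc = {
    "text": "The trophy doesn't fit into the brown suitcase because it is too large.",
    "pronoun": "it",
    "pronoun_loc": 55,
    "options": ["the trophy", "the suitcase"],
}

# One prompt per option; the completion is shared by both prompts
prompts = [partial_context(doc, option) for option in doc["options"]]
completion = partial_target(doc)

for prompt in prompts:
    print(repr(prompt), "->", repr(completion))
```

Both prompts share the same completion, so the only thing that varies between the two candidates is the noun phrase substituted for the pronoun.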

In [7]:
df_exploded = (
    pd.DataFrame([_process_doc(doc) for doc in df.to_dict("records")])
    .explode(column="options")
    .rename(columns={"options": "option"})
)

In [8]:
df_exploded["prompt"] = [
    partial_context(doc, option)
    for doc, option in zip(df_exploded.to_dict("records"), df_exploded["option"])
]
df_exploded["completion"] = [
    partial_target(doc) for doc in df_exploded.to_dict("records")
]
## just in case
df_exploded["prompt"] = df_exploded["prompt"].str.strip()
df_exploded["completion"] = df_exploded["completion"].str.strip()

Let's look at the first 4 examples (8 records in the exploded df):

In [9]:
display_df(df_exploded, columns=["prompt", "completion", "label"], num_rows=8)

| | prompt | completion | label |
|---|---|---|---|
| 0 | The city councilmen refused the demonstrators a permit because the city councilmen | feared violence. | 0 |
| 0 | The city councilmen refused the demonstrators a permit because the demonstrators | feared violence. | 0 |
| 1 | The city councilmen refused the demonstrators a permit because the city councilmen | advocated violence. | 1 |
| 1 | The city councilmen refused the demonstrators a permit because the demonstrators | advocated violence. | 1 |
| 2 | The trophy doesn't fit into the brown suitcase because the trophy | is too large. | 0 |
| 2 | The trophy doesn't fit into the brown suitcase because the suitcase | is too large. | 0 |
| 3 | The trophy doesn't fit into the brown suitcase because the trophy | is too small. | 1 |
| 3 | The trophy doesn't fit into the brown suitcase because the suitcase | is too small. | 1 |


I was dubious about how well the code worked, so I scanned more examples. There's a potential problem with the example at index 54:

In [10]:
display_df(df_exploded.loc[54], columns=["option", "prompt", "completion"])

| | option | prompt | completion |
|---|---|---|---|
| 54 | the gap | There is a gap in the wall. You can see the garden through the gap | . |
| 54 | the wall | There is a gap in the wall. You can see the garden through the wall | . |


Let's see how many examples have this problem.

In [11]:
_mask_corrupt = df_exploded["completion"] == "."
sum(_mask_corrupt) / 2  ## in the exploded df, there are 2 records per example

18.0

Seems like a systematic issue we need to correct. The problem is that Pr('.' | prompt) would barely discriminate between the two options for these, because almost any sentence can end with a period. The `option` does discriminate. So let's just take the `option` out of the `prompt` and move it to the `completion`.

In [12]:
_mask_corrupt = df_exploded["completion"] == "."
_wsc_corrupt = df_exploded.copy()[_mask_corrupt]
assert all(
    prompt.endswith(option)
    for prompt, option in zip(_wsc_corrupt["prompt"], _wsc_corrupt["option"])
)

In [13]:
_prompts_fixed = [
    remove_suffix(prompt, option)
    for prompt, option in zip(_wsc_corrupt["prompt"], _wsc_corrupt["option"])
]

_completions_fixed = df_exploded.loc[_mask_corrupt, "option"]
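
`remove_suffix` comes from the repo's local `utils` module, whose source isn't shown in this notebook. A minimal stand-in, assuming it behaves like Python 3.9's `str.removesuffix`, shows what the fix does to the example at index 54:

```python
def remove_suffix(string: str, suffix: str) -> str:
    # Assumed behavior: drop `suffix` from the end of `string` if present.
    # Equivalent to str.removesuffix, which requires Python 3.9+.
    if suffix and string.endswith(suffix):
        return string[: -len(suffix)]
    return string


prompt = "There is a gap in the wall. You can see the garden through the gap"
option = "the gap"

# Move the option out of the prompt and into the completion
prompt_fixed = remove_suffix(prompt, option).strip()
completion_fixed = option

print(repr(prompt_fixed), "->", repr(completion_fixed))
```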

In [14]:
df_exploded.loc[_mask_corrupt, "prompt"] = _prompts_fixed
df_exploded.loc[_mask_corrupt, "completion"] = _completions_fixed

## just in case
df_exploded["prompt"] = df_exploded["prompt"].str.strip()
df_exploded["completion"] = df_exploded["completion"].str.strip()

In [15]:
display_df(df_exploded.loc[54], columns=["option", "prompt", "completion", "label"])

| | option | prompt | completion | label |
|---|---|---|---|---|
| 54 | the gap | There is a gap in the wall. You can see the garden through | the gap | 0 |
| 54 | the wall | There is a gap in the wall. You can see the garden through | the wall | 0 |


There, all better.

# Run model

For WSC, the probability distribution over classes is uniform. So we'll use `prior=None`.

In [16]:
examples = [
    Example(
        prompt=record["prompt"],
        completions=(record["completion"],),
        prior=None,
        normalize=False,
    )
    for record in df_exploded.to_dict("records")
]

In [17]:
len(examples)

546

In [18]:
# $0.02
pred_probs = openai.classify.predict_proba_examples(
    examples, model="gpt-3.5-turbo-instruct", ask_if_ok=True
)


# Score

We flattened/exploded the examples so that there's one record for each (example, option) pair. To go back to the original format, we just need to batch `pred_probs`.

In [19]:
def process_probs(probs: np.ndarray, batch_sizes) -> np.ndarray:
    # Regroup the flat (n,) array into one row per example, then normalize each row
    if len(probs.shape) > 1:
        raise ValueError("Expected probs to have shape (n,).")
    pred_probs_unnorm = np.array(list(_batch.variable(probs, batch_sizes)))
    return pred_probs_unnorm / pred_probs_unnorm.sum(axis=1, keepdims=True)
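
`_batch.variable` is a private `cappr` utility, so its exact behavior is taken on trust here. For intuition, the same normalization can be sketched with plain NumPy, assuming every example has the same number of options (2 for WSC). `normalize_fixed_batches` is a hypothetical name, and the probabilities below are made up for illustration.

```python
import numpy as np


def normalize_fixed_batches(probs: np.ndarray, num_options: int = 2) -> np.ndarray:
    # Hypothetical simplification of process_probs for equal-sized batches:
    # regroup the flat (n,) array into (n / num_options, num_options) rows,
    # then rescale each row to sum to 1.
    if probs.ndim != 1:
        raise ValueError("Expected probs to have shape (n,).")
    grouped = probs.reshape(-1, num_options)
    return grouped / grouped.sum(axis=1, keepdims=True)


# Made-up Pr(completion | prompt) values for 2 examples x 2 options
flat_probs = np.array([0.03, 0.01, 0.002, 0.006])
normalized = normalize_fixed_batches(flat_probs)
predictions = normalized.argmax(axis=1)  # index of the more likely option

print(normalized)
print(predictions)
```

Normalizing doesn't change the argmax, so it has no effect on accuracy. It just makes each row readable as a probability distribution over the two options.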

For WSC, the scoring metric is accuracy.

In [20]:
batch_sizes = df["options"].apply(len)  # ik they're all 2
pred_probs_norm = process_probs(pred_probs[:, 0], batch_sizes)
(pred_probs_norm.argmax(axis=1) == df["label"]).mean()

0.8937728937728938

This roughly matches the performance in the GPT-3 paper (section 3.4). I guess there wasn't much to learn from this b/c we're both basically using the same method. Nice to see that the code kinda works I guess.

While we're here, let's see how `text-curie-001` performs.

In [21]:
# $0.02
pred_probs_curie = openai.classify.predict_proba_examples(
    examples, model="text-curie-001", ask_if_ok=True
)


In [22]:
pred_probs_norm_curie = process_probs(pred_probs_curie[:, 0], batch_sizes)
(pred_probs_norm_curie.argmax(axis=1) == df["label"]).mean()

0.8424908424908425

Not too shabby. But again, WSC is contaminated.