**Description**: Demonstrates that the zero-shot text classification method [described here](https://stats.stackexchange.com/q/601159/337906) works well on the [Winograd Schema Challenge (WSC)](https://cs.nyu.edu/~davise/papers/WinogradSchemas/WS.html). WSC is one of the [SuperGLUE tasks](https://super.gluebenchmark.com/tasks) where, in some sense, the labels span multiple tokens: each label is a noun phrase like "the city councilmen".

**Estimated run time**: ~1 min.

**Environment**: See the [Setup section in the README](https://github.com/kddubey/lm-classification/#setup).

**Other**: You have to have an OpenAI API key stored in the environment variable `OPENAI_API_KEY`. [Sign up here](https://openai.com/api/). This notebook will warn you about cost before incurring any. It'll cost ya about <span>$</span>0.30.

[Load data](#load-data)

[Write prompt](#write-prompt)

[Run model](#run-model)

[Score](#score)

In [1]:
import logging
import sys

import datasets as nlp_datasets
from IPython.display import display
import numpy as np
import pandas as pd

from lm_classification import classify
from lm_classification.utils import gpt2_tokenizer, batch_variable

None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
## When hitting the OpenAI endpoints, we'll log any server errors
logging.basicConfig(level=logging.INFO,
                    handlers=[logging.StreamHandler(stream=sys.stdout)],
                    format='%(asctime)s :: %(name)s :: %(levelname)s :: '
                           '%(message)s')
logger = logging.getLogger(__name__)

# Load data

Given a passage with a (marked) ambiguous pronoun, the classification problem is to pick which of 2 alternatives the pronoun refers to.

See the [example on the website](https://cs.nyu.edu/~davise/papers/WinogradSchemas/WS.html). It's pretty cool.

The test set labels are hidden, so I'll score this zero-shot classifier on the 273 examples in the `wsc273` subset of the challenge.

In [3]:
wsc_df: pd.DataFrame = (nlp_datasets
                        .load_dataset('winograd_wsc', 'wsc273') ## the other config is 'wsc285'
                        ['test'] ## only available split
                        .data.to_pandas())



100%|██████████| 1/1 [00:00<00:00, 167.11it/s]


In [4]:
len(wsc_df)

273

In [5]:
wsc_df.head()

|   | text | pronoun | pronoun_loc | quote | quote_loc | options | label | source |
|---|---|---|---|---|---|---|---|---|
| 0 | The city councilmen refused the demonstrators ... | they | 63 | they feared violence | 63 | [The city councilmen, The demonstrators] | 0 | (Winograd 1972) |
| 1 | The city councilmen refused the demonstrators ... | they | 63 | they advocated violence | 63 | [The city councilmen, The demonstrators] | 1 | (Winograd 1972) |
| 2 | The trophy doesn't fit into the brown suitcase... | it | 55 | it is too large | 55 | [the trophy, the suitcase] | 0 | Hector Levesque |
| 3 | The trophy doesn't fit into the brown suitcase... | it | 55 | it is too small | 55 | [the trophy, the suitcase] | 1 | Hector Levesque |
| 4 | Joan made sure to thank Susan for all the help... | she | 47 | she had received | 47 | [Joan, Susan] | 0 | Hector Levesque |


# Write prompt

The method we'll use is described in [this paper](https://arxiv.org/abs/1806.02847)<sup>1</sup>; see Table 1. We'll use the "partial" scoring method because the authors demonstrate that it performs better than "full" scoring. The idea there is super similar (if not identical) to the motivation behind this package. In fact, section 3.4 of the [GPT-3 paper](https://arxiv.org/abs/2005.14165)<sup>2</sup> doesn't actually use sampling for WSC! It uses the same partial method. I guess my algorithm isn't so novel after all, heh.

1. Trinh, Trieu H., and Quoc V. Le. "A simple method for commonsense reasoning." arXiv preprint arXiv:1806.02847 (2018).

2. Brown, Tom, et al. "Language models are few-shot learners." Advances in neural information processing systems 33 (2020): 1877-1901.
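To make the difference between "full" and "partial" scoring concrete, here's a minimal sketch with made-up per-token log-probabilities (none of these numbers come from a real model). Full scoring sums the log-probs of the substituted option *and* the rest of the sentence; partial scoring sums only the tokens after the option. The text before the pronoun is identical for both options, so it's left out: it would add the same constant to both full scores.

```python
# Hypothetical per-token log-probs, made up for illustration (not model output).
scores = {
    "the trophy": {
        "option": [-1.0, -6.0],              # "the", "trophy"
        "target": [-1.0, -2.0, -1.5, -0.5],  # "is", "too", "large", "."
    },
    "the suitcase": {
        "option": [-1.0, -2.0],              # "the", "suitcase"
        "target": [-1.0, -2.0, -4.0, -0.5],  # "is", "too", "large", "."
    },
}


def full_score(logprobs: dict) -> float:
    """log Pr(option tokens, target tokens | shared prefix)."""
    return sum(logprobs["option"]) + sum(logprobs["target"])


def partial_score(logprobs: dict) -> float:
    """log Pr(target tokens | shared prefix + option): only the target counts."""
    return sum(logprobs["target"])


full_pick = max(scores, key=lambda option: full_score(scores[option]))
partial_pick = max(scores, key=lambda option: partial_score(scores[option]))
print(full_pick, "|", partial_pick)  # the suitcase | the trophy
```

With these numbers, full scoring prefers "the suitcase" mostly because the word "trophy" is assigned a low probability, which has nothing to do with whether it's the referent. Partial scoring only asks how well "is too large." follows each option.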

To create the partial prompts and their completions, I'll just take some of the code from [here](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/wsc273.py).

In [6]:
_upper_pronouns = ["A","An","The","She","He","It","They","My","His","Her","Their"]


def _normalize_option(doc, option):
    # Append `'s` to possessive determiner based options.
    if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
        option += "'s"
    # Appropriately lowercase the pronoun in the option.
    pronoun = option.split()[0]
    start_of_sentence = doc["text"][doc["pronoun_loc"] - 2] == "."
    if not start_of_sentence and pronoun in _upper_pronouns:
        return option.replace(pronoun, pronoun.lower())
    return option


def _process_doc(doc):
    # The HF implementation of `wsc273` is not `partial evaluation` friendly.
    doc["text"] = doc["text"].replace("  ", " ")
    doc["options"][0] = _normalize_option(doc, doc["options"][0])
    doc["options"][1] = _normalize_option(doc, doc["options"][1])
    return doc


def partial_context(doc, option):
    # Substitute the pronoun in the original text with the specified
    # option and ignore everything after.
    return doc["text"][: doc["pronoun_loc"]] + option


def partial_target(doc):
    # The target is everything after the document specified pronoun.
    start_index = doc["pronoun_loc"] + len(doc["pronoun"])
    return " " + doc["text"][start_index:].strip()
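Here's a self-contained illustration of the slicing that `partial_context` and `partial_target` do, on the trophy example from the table above. The `doc` dict is written by hand to mimic one `wsc273` record; `pronoun_loc` is the character offset of the pronoun in `text`.

```python
# A hand-written record mimicking one row of `wsc273`.
doc = {
    "text": "The trophy doesn't fit into the brown suitcase because it is too large.",
    "pronoun": "it",
    "pronoun_loc": 55,  # character offset of the pronoun in `text`
    "options": ["the trophy", "the suitcase"],
    "label": 0,  # index into `options`
}

# Sanity check: the offset really points at the pronoun.
assert doc["text"][doc["pronoun_loc"]:].startswith(doc["pronoun"])

# Same slicing as `partial_context`: keep the text before the pronoun,
# then substitute each option for it.
prefix = doc["text"][: doc["pronoun_loc"]]
prompts = [prefix + option for option in doc["options"]]

# Same slicing as `partial_target`: everything after the pronoun.
target = doc["text"][doc["pronoun_loc"] + len(doc["pronoun"]):].strip()

for prompt in prompts:
    print(f"{prompt!r} -> {target!r}")
```

Both prompts share the same target, so the only thing that differs between the two scores is which option the target is conditioned on.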

In [7]:
wsc_exploded_df = (pd.DataFrame([_process_doc(doc)
                                 for doc in wsc_df.to_dict('records')])
                   .explode(column='options')
                   .rename(columns={'options': 'option'}))
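A toy version of the `explode` step, using two made-up rows in the original format. `explode` emits one row per list element and repeats the original index label, so each example's option rows stay adjacent and share an index. That's why `.loc[i]` returns two rows later on, and why grouping consecutive records recovers the original examples.

```python
import pandas as pd

# Two toy examples in the original format: one row per example,
# with the list of options in a single cell.
df = pd.DataFrame(
    {
        "text": [
            "The trophy doesn't fit into the brown suitcase because it is too large.",
            "The trophy doesn't fit into the brown suitcase because it is too small.",
        ],
        "options": [["the trophy", "the suitcase"], ["the trophy", "the suitcase"]],
        "label": [0, 1],
    }
)

# One row per (example, option) pair; the index label is repeated per option.
exploded = df.explode(column="options").rename(columns={"options": "option"})
print(exploded[["option", "label"]])
print(exploded.loc[1])  # both rows that came from example 1
```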

In [8]:
wsc_exploded_df['prompt'] = [partial_context(doc, option)
                             for doc, option
                             in zip(wsc_exploded_df.to_dict('records'),
                                    wsc_exploded_df['option'])]
wsc_exploded_df['completion'] = [partial_target(doc)
                                 for doc in wsc_exploded_df.to_dict('records')]
## just in case
wsc_exploded_df['prompt'] = wsc_exploded_df['prompt'].str.strip()
wsc_exploded_df['completion'] = wsc_exploded_df['completion'].str.strip()

Let's look at the first 4 examples (8 records in the exploded df).

In [9]:
_num_examples_displayed = 2 * 4
with pd.option_context('max_colwidth', None): ## None = no truncation; -1 is rejected by newer pandas
    display(wsc_exploded_df
            [['prompt', 'completion', 'label']]
            .head(_num_examples_displayed))

|   | prompt | completion | label |
|---|---|---|---|
| 0 | The city councilmen refused the demonstrators a permit because the city councilmen | feared violence. | 0 |
| 0 | The city councilmen refused the demonstrators a permit because the demonstrators | feared violence. | 0 |
| 1 | The city councilmen refused the demonstrators a permit because the city councilmen | advocated violence. | 1 |
| 1 | The city councilmen refused the demonstrators a permit because the demonstrators | advocated violence. | 1 |
| 2 | The trophy doesn't fit into the brown suitcase because the trophy | is too large. | 0 |
| 2 | The trophy doesn't fit into the brown suitcase because the suitcase | is too large. | 0 |
| 3 | The trophy doesn't fit into the brown suitcase because the trophy | is too small. | 1 |
| 3 | The trophy doesn't fit into the brown suitcase because the suitcase | is too small. | 1 |


I was dubious about how well the code worked, so I scanned more examples. There's a potential problem with the example at index 54:

In [10]:
with pd.option_context('max_colwidth', None):
    display(wsc_exploded_df
            [['option', 'prompt', 'completion', 'label']]
            .loc[54])

|    | option | prompt | completion | label |
|---|---|---|---|---|
| 54 | the gap | There is a gap in the wall. You can see the garden through the gap | . | 0 |
| 54 | the wall | There is a gap in the wall. You can see the garden through the wall | . | 0 |


Let's see how many examples have this problem.

In [11]:
_mask_corrupt = wsc_exploded_df['completion'] == '.'
sum(_mask_corrupt) / 2 ## in the exploded df, there are 2 records per example

18.0

It seems like a systematic issue we need to correct. The problem is that, for these examples, the only scored token is the period, which is plausible after either option, so Pr('.' | prompt) would barely discriminate between them. The `option` does discriminate. So let's take the `option` out of the `prompt` and move it to the `completion`.
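Here's the same fix as a standalone function, applied to the gap example. It's a sketch; `move_option_to_completion` is a hypothetical helper name, not part of this package. It needs Python 3.9+ for `str.removesuffix`, same as the notebook code.

```python
def move_option_to_completion(prompt: str, option: str, completion: str) -> tuple[str, str]:
    """If the completion is only a period, score the option instead of the period."""
    if completion != ".":
        return prompt, completion
    # The prompt ends with the substituted option (the notebook asserts this too).
    assert prompt.endswith(option)
    # Move the option out of the prompt and into the completion.
    # As in the notebook's fix, the lone period is dropped.
    return prompt.removesuffix(option).strip(), option


prompt = "There is a gap in the wall. You can see the garden through the gap"
print(move_option_to_completion(prompt, "the gap", "."))
```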

In [12]:
_mask_corrupt = wsc_exploded_df['completion'] == '.'
_wsc_corrupt = wsc_exploded_df.copy()[_mask_corrupt]
assert all(prompt.endswith(option)
           for prompt, option
           in zip(_wsc_corrupt['prompt'], _wsc_corrupt['option']))

In [13]:
_prompts_fixed = [prompt.removesuffix(option)
                  for prompt, option
                  in zip(_wsc_corrupt['prompt'], _wsc_corrupt['option'])]
_completions_fixed = wsc_exploded_df.loc[_mask_corrupt, 'option']

In [14]:
wsc_exploded_df.loc[_mask_corrupt, 'prompt']     = _prompts_fixed
wsc_exploded_df.loc[_mask_corrupt, 'completion'] = _completions_fixed

## just in case
wsc_exploded_df['prompt']     = wsc_exploded_df['prompt'].str.strip()
wsc_exploded_df['completion'] = wsc_exploded_df['completion'].str.strip()

In [15]:
with pd.option_context('max_colwidth', None):
    display(wsc_exploded_df
            [['option', 'prompt', 'completion', 'label']]
            .loc[54])

|    | option | prompt | completion | label |
|---|---|---|---|---|
| 54 | the gap | There is a gap in the wall. You can see the garden through | the gap | 0 |
| 54 | the wall | There is a gap in the wall. You can see the garden through | the wall | 0 |


There, all better.

# Run model

For WSC, the probability distribution over classes (alternative 1 vs. alternative 2) is uniform. So we'll use `prior=None`.

In [16]:
wsc_examples = [classify.Example(prompt=record['prompt'],
                                 completions=(record['completion'],),
                                 prior=None)
                for record in wsc_exploded_df.to_dict('records')]

In [17]:
len(wsc_examples)

546

This next cell warns you about the cost and asks if you're ready to pay.

(In Jupyter, `input` shows a text box below the cell and waits for you to press Enter. I use VS Code notebooks now, which also just ask you to press the Enter key to proceed.)

In [18]:
all_tokens = gpt2_tokenizer([example.prompt + ' ' + completion
                             for example in wsc_examples
                             for completion in example.completions])
num_tokens = sum(len(tokens) for tokens in all_tokens['input_ids'])
cost_per_1k_tokens = 0.02 ## https://openai.com/api/pricing/
cost = round(num_tokens * cost_per_1k_tokens / 1_000, 2)

output = input(f'The next cell will cost you ${cost}. Proceed?')
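The cell above warns but doesn't stop anything if you decline. Here's a minimal sketch of a cost check that also aborts, assuming the same $0.02 per 1K tokens rate. `estimate_cost_usd` and `confirm_or_abort` are hypothetical helpers, and the token counts in the usage line are made up.

```python
def estimate_cost_usd(token_counts, usd_per_1k_tokens: float = 0.02) -> float:
    """Cost of sending every (prompt + completion) once, rounded to cents."""
    return round(sum(token_counts) * usd_per_1k_tokens / 1_000, 2)


def confirm_or_abort(answer: str) -> None:
    """Raise unless the answer is empty (just Enter), 'y', or 'yes'."""
    if answer.strip().lower() not in {"", "y", "yes"}:
        # An exception stops "Run All" before the cell that calls the API.
        raise RuntimeError("Aborted before calling the OpenAI API.")


# Usage with made-up token counts:
cost = estimate_cost_usd([10_000, 5_000])
print(f"The next cell will cost you ${cost}. Proceed?")
confirm_or_abort("y")  # in the notebook: confirm_or_abort(input(...))
```

Accepting an empty answer keeps the "press Enter to proceed" behavior; anything else raises, which halts a "Run All" before any tokens are billed.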

In [19]:
pred_probs = classify.predict_proba_examples(wsc_examples,
                                             model='text-davinci-003')

Computing probs: 100%|██████████| 546/546 [00:12<00:00, 42.92it/s]


# Score

We flattened/exploded the examples so that there's one record for each (example, option) pair. To go back to the original format, we just need to batch `pred_probs`.

In [20]:
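def process_probs(probs: np.ndarray, batch_sizes) -> np.ndarray:
    ## probs has shape (num records, 1): 1 completion per exploded record.
    ## Regroup consecutive records into 1 row per example, 1 column per option
    pred_probs_unnorm = np.array(list(batch_variable(probs[:,0], batch_sizes)))
    ## Normalize so that each example's option probabilities sum to 1
    return pred_probs_unnorm / pred_probs_unnorm.sum(axis=1, keepdims=True)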
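To make the regrouping concrete, here's a toy run with made-up probabilities for 2 examples × 2 options. `split_into_batches` is a hypothetical stand-in for `batch_variable`; I'm assuming it yields consecutive slices of the given sizes, which is consistent with how it's called above.

```python
import numpy as np


def split_into_batches(values, batch_sizes):
    """Hypothetical stand-in for `batch_variable`: consecutive slices of the given sizes."""
    start = 0
    for size in batch_sizes:
        yield values[start : start + size]
        start += size


# Made-up probabilities in exploded (flat) order, shaped (num_records, 1)
# like the notebook's `pred_probs`.
probs = np.array([[0.375], [0.125], [0.0625], [0.1875]])

unnorm = np.array(list(split_into_batches(probs[:, 0], [2, 2])))  # shape (2, 2)
norm = unnorm / unnorm.sum(axis=1, keepdims=True)
print(norm)                 # each row sums to 1
print(norm.argmax(axis=1))  # predicted option index per example
```

Normalizing doesn't change which option wins; it only rescales each example's two scores so they're comparable as probabilities.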

For WSC, the scoring metric is accuracy.

In [21]:
batch_sizes = wsc_df['options'].apply(len) ## they're all 2, but no need to hardcode that
pred_probs_norm = process_probs(pred_probs, batch_sizes)
(pred_probs_norm.argmax(axis=1) == wsc_df['label']).mean()

0.8901098901098901

That's 243 of the 273 examples classified correctly. This roughly matches the zero-shot performance reported in section 3.4 of the GPT-3 paper (88.3%). I guess there wasn't much to learn from this because we're both basically using the same method. Nice to see that the code works, I guess :-)

For transparency: some of the WSC examples were included in GPT-3's training data. The authors [studied this contamination](https://arxiv.org/pdf/2005.14165.pdf#page=31&zoom=100,96,89) and found that it isn't much of an issue for WSC in particular.

While we're here, let's see how `text-curie-001` performs.

In [23]:
pred_probs_curie = classify.predict_proba_examples(wsc_examples,
                                                   model='text-curie-001')

Computing probs: 100%|██████████| 546/546 [00:07<00:00, 75.01it/s]


In [24]:
pred_probs_norm_curie = process_probs(pred_probs_curie, batch_sizes)
(pred_probs_norm_curie.argmax(axis=1) == wsc_df['label']).mean()

0.8131868131868132