**Description**: a first attempt at comparing computational performance, using the COPA data framed as a (binary) classification task.

**Estimated run time**: ~1 min.

**Environment**: dev

**Other**: run [this notebook in Google Colab](https://drive.google.com/file/d/1ehO5YfLDGawtajEe2E3QnWPQjXe7cOZs/view?usp=sharing) on a GPU.

**TODO**: figure out what exactly I want to do here

[Motivation](#motivation)

[Setup](#setup)

[Load data and model](#load-data-and-model)

[Write prompt](#write-prompt)

[Run model](#run-model)

# Motivation

What's the computational performance of this method?

For a refresher on this method, see the [Example section here](https://stats.stackexchange.com/q/601159/337906).
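
As a rough sketch of the scoring idea (my reading of the linked example, not the `callm` implementation): for each candidate completion, take the log-probabilities the model assigns to that completion's tokens given the prompt, average them, exponentiate, and normalize across classes. The averaging rule is an assumption, `completion_scores` is a hypothetical helper, and the log-probability values are made up for illustration.

```python
import math


def completion_scores(token_log_probs_per_class: list[list[float]]) -> list[float]:
    """Turn per-token log-probs for each class's completion into
    normalized class probabilities (assumed rule: average the token
    log-probs, exponentiate, then normalize across classes)."""
    avg_log_probs = [sum(lps) / len(lps) for lps in token_log_probs_per_class]
    likelihoods = [math.exp(avg) for avg in avg_log_probs]
    total = sum(likelihoods)
    return [likelihood / total for likelihood in likelihoods]


# Hypothetical log Pr(token | prompt, previous completion tokens)
# for two candidate completions of different lengths.
log_probs_choice1 = [-1.2, -0.7, -2.1]
log_probs_choice2 = [-3.0, -2.5, -1.9, -2.2]

probs = completion_scores([log_probs_choice1, log_probs_choice2])
predicted_class = max(range(len(probs)), key=probs.__getitem__)
```

Averaging rather than summing keeps longer completions from being penalized just for having more tokens.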

Naively, prompt-completion classification requires as many `model()` calls as there are classes, and each call re-processes the same prompt. Because GPT-x models are auto-regressive, the computation for the prompt doesn't depend on the completion, so it can be done once and reused across classes. I'll add more on how this is done at a high level. The low-level implementation is in `fast.py`.

Sampling is auto-regressive and therefore sequential: each new token depends on the ones before it, so generation can't be parallelized across tokens.

There are three factors which determine runtime (all else equal):
  1. The number of classes in the classification problem
  2. The number of tokens in the prompt
  3. The number of tokens in each completion/class.
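
To make the role of these three factors concrete, here is a back-of-the-envelope token count, assuming the slow approach re-processes the prompt for every class while the fast approach processes it once. Both functions are hypothetical helpers, and token counts are only a proxy: they ignore attention over the cached prompt, batching, padding, and fixed overhead, so they are not runtimes.

```python
def tokens_processed_slow(num_prompt_tokens: int,
                          completion_token_counts: list[int]) -> int:
    """Every class re-processes the full prompt plus its own completion."""
    return sum(num_prompt_tokens + n for n in completion_token_counts)


def tokens_processed_fast(num_prompt_tokens: int,
                          completion_token_counts: list[int]) -> int:
    """The prompt is processed once and reused; only completions repeat."""
    return num_prompt_tokens + sum(completion_token_counts)


# Hypothetical sizes: a 20-token prompt and two 5-token completions.
slow_count = tokens_processed_slow(20, [5, 5])  # 2 * (20 + 5) = 50
fast_count = tokens_processed_fast(20, [5, 5])  # 20 + 5 + 5 = 30
```

Under this simplification, the gap grows with the number of classes and with the prompt length relative to the completion length.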

We *could* exhaustively collect runtimes for many combinations of (1), (2), and (3). But that's expensive and doesn't directly answer the question we care about: across all *real-world* text classification tasks and datasets, what's the computational performance of this method vs CVS? I think it's safe to assume that the real-world distributions of (1), (2), and (3) are somewhat concentrated, so we don't need to cover the whole grid of combinations uniformly. To answer the question efficiently and directly (if approximately), we should sample a bunch of real-world text datasets and collect runtimes on those.

# Setup

Since GPT-x models are pretty much exclusively run on GPUs, this notebook should be run on a GPU machine. So run [this notebook in Google Colab](https://drive.google.com/file/d/1ehO5YfLDGawtajEe2E3QnWPQjXe7cOZs/view?usp=sharing) on a GPU. To get a (spot) GPU machine in Colab: Runtime -> Change runtime type -> Hardware accelerator: GPU.

If you're in Google Colab, you should uncomment and run this cell:

In [1]:
# !python -m pip install datasets transformers --quiet

In [1]:
from __future__ import annotations
from typing import Literal

import datasets as nlp_datasets
import pandas as pd
from tqdm.auto import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2LMHeadModel
from transformers import pipeline

from callm import classify

import fast
import slow

In [2]:
# assert torch.cuda.is_available(), 'This experiment needs to run on a GPU'

In [3]:
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load data and model

In [4]:
def load_super_glue(task_id: str, split: str):
    return pd.DataFrame(nlp_datasets
                        .load_dataset('super_glue', task_id, split=split))


## takes about 12 seconds, sorry
df = (pd.concat((load_super_glue('copa', 'train'),
                 load_super_glue('copa', 'validation')))
      .reset_index(drop=True).head(50)) ## the idx column is only unique w/in splits! fuhgetaboutit

Found cached dataset super_glue (C:/Users/kushd/.cache/huggingface/datasets/super_glue/copa/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed)


In [13]:
gpt2_name = 'gpt2' ## smallest in the GPT-2 line
gpt2 = AutoModelForCausalLM.from_pretrained(gpt2_name).to(DEVICE)

tokenizer = AutoTokenizer.from_pretrained(gpt2_name)
tokenizer.pad_token_id = tokenizer.eos_token_id ## allow padding -> allow batching
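
GPT-2 ships without a pad token, which is why the EOS token is reused above. As a rough illustration of what padding does when texts are batched (a sketch, not the tokenizer's actual implementation), the hypothetical `pad_batch` below pads token id sequences to a common length and builds the attention mask that tells the model which positions to ignore.

```python
def pad_batch(token_id_seqs: list[list[int]], pad_token_id: int,
              side: str = 'right') -> tuple[list[list[int]], list[list[int]]]:
    """Pad sequences to a common length and build the matching
    attention mask (1 = real token, 0 = padding)."""
    if side not in ('left', 'right'):
        raise ValueError(f"side must be 'left' or 'right'. Got {side!r}.")
    max_len = max(len(seq) for seq in token_id_seqs)
    input_ids, attention_mask = [], []
    for seq in token_id_seqs:
        padding = [pad_token_id] * (max_len - len(seq))
        mask_padding = [0] * len(padding)
        if side == 'right':
            input_ids.append(seq + padding)
            attention_mask.append([1] * len(seq) + mask_padding)
        else:
            input_ids.append(padding + seq)
            attention_mask.append(mask_padding + [1] * len(seq))
    return input_ids, attention_mask


EOS_TOKEN_ID = 50256  # GPT-2's end-of-text token id
input_ids, attention_mask = pad_batch([[10, 11, 12], [20]], EOS_TOKEN_ID)
```

Right padding suits scoring a fixed text. For generation, left padding is the usual choice so that new tokens are appended directly after the real text.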

# Write prompt

In [6]:
def _conjunction(question: Literal['cause', 'effect']):
    if question == 'cause':
        return ' because'
    elif question == 'effect':
        return ', so'
    else:
        raise ValueError( "question must be 'cause' or 'effect'. Got "
                         f'{question}.')


def prompt(premise: str, question: Literal['cause', 'effect']):
    conjunction = _conjunction(question)
    return f'{premise.strip(". ")}{conjunction}'

In [7]:
def prompt_mc(premise: str, question: Literal['cause', 'effect'],
              choice1: str, choice2: str):
    return (f'{prompt(premise, question)}\n'
            f'A. {choice1}\n'
            f'B. {choice2}\n'
             'Answer A or B.')
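
To see what the two prompt formats look like, here is a self-contained example. The helpers are condensed copies of the ones above so the snippet runs on its own, and the record is made up in the style of COPA rather than taken from the dataset.

```python
CONJUNCTIONS = {'cause': ' because', 'effect': ', so'}


def prompt(premise: str, question: str) -> str:
    # Condensed copy of the helper above: drop the trailing period
    # and append the conjunction that matches the question type.
    return f'{premise.strip(". ")}{CONJUNCTIONS[question]}'


def prompt_mc(premise: str, question: str, choice1: str, choice2: str) -> str:
    return (f'{prompt(premise, question)}\n'
            f'A. {choice1}\n'
            f'B. {choice2}\n'
            'Answer A or B.')


# A made-up COPA-style record (hypothetical, not from the dataset).
record = dict(premise='The ground was wet.', question='cause',
              choice1='It rained overnight.', choice2='The sun was out.')

completion_prompt = prompt(record['premise'], record['question'])
multiple_choice_prompt = prompt_mc(**record)
```

Here `completion_prompt` is `'The ground was wet because'`, which each choice then completes, while `multiple_choice_prompt` lists both choices and asks the model to generate the answer letter.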

In [8]:
df['prompt'] = [prompt(premise, question)
                for premise, question
                in zip(df['premise'], df['question'])]


df['prompt_mc'] = [prompt_mc(record['premise'], record['question'],
                             record['choice1'], record['choice2'])
                   for record in df.to_dict('records')]

In [9]:
examples = [classify.Example(prompt=record['prompt'],
                             completions=(record['choice1'].lower(),
                                          record['choice2'].lower()),
                             prior=None)
            for record in df.to_dict('records')]

In [10]:
df['choices'] = list(zip(df['choice1'], df['choice2']))

texts = [prompt + ' ' + choice.lower()
         for prompt, choices in zip(df['prompt'], df['choices'])
         for choice in choices]
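
The nested comprehension above flattens the texts so that all choices for one example sit next to each other. A small sketch with placeholder strings (hypothetical stand-ins for the DataFrame columns) shows the ordering and how the flat list could be regrouped per example:

```python
num_classes = 2  # constant for COPA

# Placeholder prompts and choices standing in for the DataFrame columns.
prompts = ['p0', 'p1', 'p2']
choices_per_prompt = [('a0', 'b0'), ('a1', 'b1'), ('a2', 'b2')]

flat_texts = [prompt + ' ' + choice
              for prompt, choices in zip(prompts, choices_per_prompt)
              for choice in choices]

# Regroup: row i holds the num_classes texts for example i.
grouped = [flat_texts[i:i + num_classes]
           for i in range(0, len(flat_texts), num_classes)]
```

Flat index `k` therefore corresponds to example `k // num_classes` and class `k % num_classes`.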

# Run model

In [16]:
token_logits, encodings = slow._logits_completions_given_prompts_examples(
                              gpt2, tokenizer, examples, batch_size=10
                          )

logits (slow):   0%|          | 0/100 [00:00<?, ?it/s]

In [17]:
token_logits, encodings = fast._logits_completions_given_prompts_examples(
                              gpt2, tokenizer, examples, batch_size=10
                          )

logits (fast):   0%|          | 0/50 [00:00<?, ?it/s]

In [19]:
class TextsDataset(Dataset):
    def __init__(self, texts: list[str]):
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index: int):
        return self.texts[index]

In [20]:
def gpt_log_probs(texts_batch: list[str]):
    inputs = tokenizer(texts_batch, padding=True,
                       return_tensors='pt').to(DEVICE)
    with torch.no_grad():
        out = gpt2(**inputs)
        return F.log_softmax(out.logits, dim=2)
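
`gpt_log_probs` returns log-probabilities over the whole vocabulary at every position. The notebook only times this step, but to score a specific text one would still have to pick out the log-probability of each actual token. A plain-Python sketch of that selection, assuming the standard causal-LM convention that the output at position `i - 1` scores the token at position `i` (the logits and token ids are made up, and `token_log_probs` is a hypothetical helper):

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    max_logit = max(logits)  # subtract the max for numerical stability
    log_norm = max_logit + math.log(sum(math.exp(x - max_logit)
                                        for x in logits))
    return [x - log_norm for x in logits]


def token_log_probs(logits_per_position: list[list[float]],
                    token_ids: list[int]) -> list[float]:
    """log Pr(token_ids[i] | token_ids[:i]) for i = 1, ..., len - 1.

    Position i - 1's logits score the token at position i, so the
    first token gets no score.
    """
    return [log_softmax(logits_per_position[i - 1])[token_ids[i]]
            for i in range(1, len(token_ids))]


# Hypothetical logits for a 3-token text over a 4-token vocabulary.
logits = [[0.0, 0.0, 0.0, 0.0],
          [2.0, 0.0, 0.0, 0.0],
          [0.0, 1.0, 0.0, 0.0]]
token_ids = [3, 1, 0]
log_probs = token_log_probs(logits, token_ids)
```

The last position's logits go unused here because they score a token that comes after the end of the text.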

In [21]:
num_obs_to_classify = 500 ## <= 500. If df has fewer rows, the slices below just take everything
num_classes = 2 ## constant for COPA

texts_sample = texts[:(num_obs_to_classify*num_classes)]
texts_sample_mc = df['prompt_mc'].head(num_obs_to_classify).tolist()

In [22]:
dataloader_kwargs = dict(batch_size=32,
                         shuffle=False)

texts_dataset = TextsDataset(texts_sample)
texts_dataloader = DataLoader(texts_dataset, **dataloader_kwargs)

texts_mc_dataset = TextsDataset(texts_sample_mc)
texts_mc_dataloader = DataLoader(texts_mc_dataset, **dataloader_kwargs)

In [23]:
## quick check
_ = gpt_log_probs(['test test'])

In [24]:
%%timeit
with tqdm(total=len(texts_dataset)) as progress_bar:
    for texts_batch in texts_dataloader:
        _ = gpt_log_probs(texts_batch)
        progress_bar.update(len(texts_batch))

  0%|          | 0/60 [00:00<?, ?it/s]

In [None]:
max_new_tokens = len(tokenizer('\n\nAnswer A').input_ids)
generator = pipeline('text-generation', model=gpt2_name,
                     max_new_tokens=max_new_tokens,
                     device=DEVICE)
generator.tokenizer.padding_side = 'left'
generator.tokenizer.pad_token_id = generator.model.config.eos_token_id

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [None]:
%%timeit
for samples_batch in tqdm(generator(texts_mc_dataset,
                                    ## suppress "Setting pad_token_id..." output
                                    pad_token_id=generator.tokenizer.eos_token_id,
                                    batch_size=dataloader_kwargs['batch_size']),
                          total=len(texts_mc_dataset)):
    pass

  0%|          | 0/500 [00:00<?, ?it/s]



2.77 s ± 498 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
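
`%%timeit` only works inside IPython. As a sketch of how a comparable measurement could be collected in a plain script, the standard library's `timeit.repeat` can mirror the "7 runs, 1 loop each" setup. The `workload` function is a hypothetical stand-in for one full pass over a dataloader.

```python
import statistics
import timeit


def workload():
    # Stand-in for one full pass over a dataloader (hypothetical).
    return sum(i * i for i in range(10_000))


# 7 runs of 1 loop each, like the %%timeit report above.
run_times = timeit.repeat(workload, repeat=7, number=1)
mean_time = statistics.mean(run_times)
std_time = statistics.pstdev(run_times)
summary = f'{mean_time:.3g} s ± {std_time:.3g} s per loop'
```

On a GPU, kernels launch asynchronously, so calling `torch.cuda.synchronize()` before reading the clock is usually needed for the measurement to reflect the work actually done.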
