# IN PROGRESS

**Description**: first attempt at a computational comparison using the COPA data and (binary) classification task.

**Estimated run time**: ~1 min.

**Environment**: dev

**Other**: run [this notebook in Google Colab](https://colab.research.google.com/drive/1Xw95EUt7SxlfN3Hjq0GtLZIy0O0J27nX?usp=sharing) on a GPU.

**TODO**: figure out what exactly I want to do here.

[Motivation](#motivation)

[Setup](#setup)

[Load data and model](#load-data-and-model)

[Write prompt](#write-prompt)

[Run CALLM](#run-callm)

[CVS](#cvs)

# Motivation

What's the computational performance of CALLM?

For a refresher on this method, see the [Example section here](https://stats.stackexchange.com/q/601159/337906).

Prompt-completion classification requires as many `model()` calls as there are classes. But we can exploit the auto-regressive nature of GPT to avoid recomputing the prompt for every class: the prompt is processed once, and that computation is reused for each completion. I'll add more on how this is done at a high level. The low-level implementation is in [`callm.huggingface.classify`](https://github.com/kddubey/callm/blob/main/callm/huggingface/classify.py).
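To make the scoring step concrete, here's a rough sketch of how per-class completion scores could be turned into probabilities. This is not the code in `callm.huggingface.classify`. It assumes each completion has already been scored token by token, that a class's score is the exponentiated average token log-probability, and that scores are normalized across classes. The name `completion_probs` and the toy numbers are made up for illustration.

```python
import math


def completion_probs(token_logprobs_per_class: list[list[float]]) -> list[float]:
    """Turn per-token log-probs (one list per class) into class probabilities.

    Assumption (not taken from callm): a class's score is
    exp(mean token log-prob of its completion), normalized to sum to 1.
    """
    scores = [math.exp(sum(logprobs) / len(logprobs))
              for logprobs in token_logprobs_per_class]
    total = sum(scores)
    return [score / total for score in scores]


## toy log-probs for 2 completions of different lengths
probs = completion_probs([[-1.0, -2.0], [-0.5, -0.5, -0.5]])
print(probs)
```

Averaging over tokens (instead of summing) keeps longer completions from being penalized just for being longer.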

Greedy sampling, on the other hand, can't be parallelized across generated tokens: each new token depends on the ones sampled before it.
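A toy sketch of why that is. The `next_token` argument is a hypothetical stand-in for "run the model on everything so far and take the argmax"; no real model is involved.

```python
from typing import Callable


def greedy_decode(next_token: Callable[[list[int]], int],
                  prompt: list[int],
                  max_new_tokens: int) -> list[int]:
    """Generate tokens one at a time. Step t can't start before step t-1 ends."""
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        ## the next token is a function of ALL tokens so far, including the
        ## ones generated in previous iterations of this loop
        tokens.append(next_token(tokens))
    return tokens[len(prompt):]


## hypothetical "model": next token is the last token + 1, modulo 5
toy_next_token = lambda tokens: (tokens[-1] + 1) % 5
generated = greedy_decode(toy_next_token, prompt=[0, 1], max_new_tokens=3)
print(generated)  # [2, 3, 4]
```

Scoring a fixed completion is different: all of its tokens are known up front, so they can go through the model in a single forward pass.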

There are three factors which determine runtime:
  1. The number of classes in the classification problem
  2. The number of tokens in the prompt
  3. The number of tokens in each completion/class.
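As a back-of-the-envelope illustration of how these three factors interact, the sketch below counts how many token positions are fed through the model for a single example, with and without reusing the prompt's computation. The function `tokens_processed` is hypothetical, and token count is only a crude proxy: actual runtime won't scale linearly with it because of attention cost, batching, and GPU overhead.

```python
def tokens_processed(prompt_len: int, completion_lens: list[int],
                     cache_prompt: bool) -> int:
    """Count token positions fed through the model for one example."""
    if cache_prompt:
        ## prompt is processed once, then each completion is added on top of it
        return prompt_len + sum(completion_lens)
    ## prompt is re-processed together with every completion
    return sum(prompt_len + completion_len
               for completion_len in completion_lens)


## 10 classes, 20-token prompt, 3-token completions
naive = tokens_processed(20, [3] * 10, cache_prompt=False)
cached = tokens_processed(20, [3] * 10, cache_prompt=True)
print(naive, cached)  # 230 50
```

The gap grows with the number of classes and with the length of the prompt relative to the completions.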

We *could* exhaustively collect runtimes for many combinations of (1), (2), and (3). But that's expensive and doesn't directly answer the question we care about: across all *real-world* text classification tasks and datasets, what's the computational performance of this method vs CVS? I think it's safe to assume that the real-world distributions of (1), (2), and (3) are somewhat concentrated, so we don't need to cover the whole grid of combinations uniformly. To efficiently and directly (but approximately) answer the question we care about, we should sample a bunch of real-world text datasets and collect runtimes.
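A minimal sketch of what collecting a runtime could look like outside of the `%%timeit` magic. The helper `time_runs` is hypothetical and only uses the standard library. On a GPU, the timed function should wait for queued work to finish (e.g., by calling `torch.cuda.synchronize()`) before returning, otherwise the timings can come out too small.

```python
import statistics
import time
from typing import Callable


def time_runs(fn: Callable[[], object], repeats: int = 7) -> tuple[float, float]:
    """Call fn `repeats` times. Return (mean, std. dev.) of wall time in seconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()  ## for GPU work, fn should synchronize before it returns
        durations.append(time.perf_counter() - start)
    return statistics.mean(durations), statistics.pstdev(durations)


## toy workload standing in for a classification run on one dataset
mean_sec, std_sec = time_runs(lambda: sum(range(100_000)))
print(f'{mean_sec:.4f} s ± {std_sec:.4f} s per run')
```

Running something like this once per sampled dataset and per method would give one (mean, std. dev.) pair for each combination.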

# Setup

Since GPT-x models are pretty much exclusively run on GPUs, this notebook should be run on a GPU machine. So run [this notebook in Google Colab](https://drive.google.com/file/d/1ehO5YfLDGawtajEe2E3QnWPQjXe7cOZs/view?usp=sharing) on a GPU. To get a (spot) GPU machine in Colab: Runtime -> Change runtime type -> Hardware accelerator: GPU.

If you're in Google Colab, run this cell to install the package:

In [None]:
!python -m pip install "callm[demos] @ git+https://github.com/kddubey/callm.git"

In [2]:
from __future__ import annotations
import string
from typing import Literal, Sequence

import datasets as nlp_datasets
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline

from callm.example import Example
from callm.huggingface import classify as fast
from callm.huggingface import classify_slow as slow
from callm.utils import batch

In [3]:
assert torch.cuda.is_available(), 'This experiment should be run on a GPU'

In [4]:
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load data and model

In [5]:
def load_super_glue(task_id: str, split: str):
    return pd.DataFrame(nlp_datasets
                        .load_dataset('super_glue', task_id, split=split))


## takes about 12 seconds, sorry
df = (pd.concat((load_super_glue('copa', 'train'),
                 load_super_glue('copa', 'validation')))
      .reset_index(drop=True)) ## the idx column is only unique w/in splits! fuhgetaboutit



In [6]:
gpt2_name = 'gpt2' ## smallest in the GPT-2 line
gpt2 = AutoModelForCausalLM.from_pretrained(gpt2_name).to(DEVICE)

tokenizer = AutoTokenizer.from_pretrained(gpt2_name)
tokenizer.pad_token_id = tokenizer.eos_token_id ## allow padding -> allow batching

# Write prompt

In [7]:
def _conjunction(question: Literal['cause', 'effect']):
    if question == 'cause':
        return ' because'
    elif question == 'effect':
        return ', so'
    else:
        raise ValueError( "question must be 'cause' or 'effect'. Got "
                         f'{question}.')


def prompt(premise: str, question: Literal['cause', 'effect']):
    conjunction = _conjunction(question)
    return f'{premise.strip(". ")}{conjunction}'

In [8]:
english_alphabet = string.ascii_uppercase

## NOT meant for actual use! I just need something basic to test computation
def mc(*choices) -> str:
    if len(choices) > len(english_alphabet):
        raise ValueError('Nope!')
    choices_str = '\n'.join([f'{letter}. {choice}'
                             for choice, letter
                             in zip(choices, english_alphabet)])
    return choices_str + '\n' + 'Pick one from above.'

In [9]:
print(mc('Green eggs', 'Ham', 'Scooby Dooby'))

A. Green eggs
B. Ham
C. Scooby Dooby
Pick one from above.


In [10]:
def prompt_mc(premise: str, question: Literal['cause', 'effect'], *choices):
    return (f'{prompt(premise, question)}\n'
            f'{mc(*choices)}')

In [11]:
df['prompt'] = [prompt(premise, question)
                for premise, question
                in zip(df['premise'], df['question'])]

df['choices_2'] = list(zip(df['choice1'].str.lower().str.lstrip(),
                           df['choice2'].str.lower().str.lstrip()))
## repeat the 2 choices 5 times to fake a 10-class problem (for timing only)
df['choices_10'] = list(zip(*[df['choice1'].str.lower().str.lstrip(),
                              df['choice2'].str.lower().str.lstrip()]*5))


df['prompt_mc_2'] = [prompt_mc(record['premise'], record['question'],
                               *record['choices_2'])
                     for record in df.to_dict('records')]

df['prompt_mc_10'] = [prompt_mc(record['premise'], record['question'],
                                *record['choices_10'])
                      for record in df.to_dict('records')]

In [12]:
examples_2 = [Example(prompt=record['prompt'],
                      completions=record['choices_2'],
                      prior=None)
              for record in df.to_dict('records')]

In [13]:
examples_10 = [Example(prompt=record['prompt'],
                       completions=record['choices_10'],
                       prior=None)
               for record in df.to_dict('records')]

# Run CALLM

In [14]:
dataloader_kwargs = dict(batch_size=32,
                         shuffle=False)

In [15]:
%%timeit
pred_probs_slow = slow.predict_proba_examples(examples_10,
                                              model_and_tokenizer=(gpt2, tokenizer),
                                              batch_size=dataloader_kwargs['batch_size'])

9.17 s ± 181 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
%%timeit
pred_probs_fast = fast.predict_proba_examples(examples_10,
                                              model_and_tokenizer=(gpt2, tokenizer),
                                              batch_size=dataloader_kwargs['batch_size'])

5.85 s ± 84.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# CVS

In [17]:
class TextsDataset(Dataset):
    def __init__(self, texts: list[str]):
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index: int):
        return self.texts[index]

texts_mc_dataset_2 = TextsDataset(df['prompt_mc_2'].tolist())
texts_mc_dataloader_2 = DataLoader(texts_mc_dataset_2, **dataloader_kwargs)

texts_mc_dataset_10 = TextsDataset(df['prompt_mc_10'].tolist())
texts_mc_dataloader_10 = DataLoader(texts_mc_dataset_10, **dataloader_kwargs)

In [18]:
max_new_tokens = len(tokenizer('\n\nAnswer A').input_ids)
generator = pipeline('text-generation', model=gpt2_name,
                     max_new_tokens=max_new_tokens,
                     device=DEVICE)
generator.tokenizer.padding_side = 'left' ## for sampling
generator.tokenizer.pad_token_id = generator.model.config.eos_token_id

In [19]:
%%timeit
for samples_batch in tqdm(generator(texts_mc_dataset_10,
                                    ## suppress "Setting pad_token_id..." output
                                    pad_token_id=generator.tokenizer.eos_token_id,
                                    batch_size=dataloader_kwargs['batch_size']),
                          total=len(texts_mc_dataset_10)):
    pass

7.43 s ± 175 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
