
bnativi/Parallel-Context-Windows

 
 


Parallel Context Windows (PCW)

This repo contains the code for reproducing the classification experiments on GPT2 models from AI21 Labs' paper Parallel Context Windows for Large Language Models.
The code was tested with Python 3.10 on CPU and on a single GPU.

Setup

To install the required libraries, run:

pip install -r requirements.txt

If you need a PyTorch build that matches your CUDA version, install it before running the command above.

Evaluation

Because the paper's results were obtained with an earlier implementation of PCW rather than with HuggingFace Transformers, the results produced by this code may differ slightly from those reported in the paper. To reproduce results similar to those in the appendix for GPT2-XL on a specific dataset (for example, SST2), run:

python run_evaluation.py \
--dataset sst2 \
--model gpt2-xl \
--n-windows 1 \
--n-windows 3 \
--subsample-test-set 250 \
--n-runs 30 \
--output-dir $OUTPUT_DIR

This run evaluates PCW on a subsample of 250 examples from the full test set. For each number of windows (here, one and three), the experiment is repeated 30 times, each time with a different random sample of training examples. By default, the script places as many examples in each window as possible. Using a single window is equivalent to standard in-context learning (ICL), so this run should give results similar to those shown in Table 5 for SST2 with GPT2-XL.
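The repetition protocol described above (for each number of windows, repeat n_runs times: draw a fresh random set of training examples and split it across the windows) can be sketched as follows. This is a hypothetical illustration, not the code in run_evaluation.py: the helper names `sample_windows` and `plan_experiment`, the fixed `shots_per_window`, and the toy training set are assumptions, whereas the real script packs as many examples into each window as the context length allows.

```python
import random
from typing import Dict, List, Sequence


def sample_windows(train_set: Sequence[str], n_windows: int,
                   shots_per_window: int, rng: random.Random) -> List[List[str]]:
    """Hypothetical helper: draw disjoint few-shot examples and split them into windows."""
    n_total = n_windows * shots_per_window
    # Sample without replacement so no example appears in two windows of the same run.
    picked = rng.sample(list(train_set), n_total)
    return [picked[i * shots_per_window:(i + 1) * shots_per_window]
            for i in range(n_windows)]


def plan_experiment(train_set: Sequence[str], windows_options: List[int], n_runs: int,
                    shots_per_window: int, seed: int = 0) -> Dict[int, List[List[List[str]]]]:
    """Hypothetical helper: one list of windows per (number of windows, run) pair."""
    rng = random.Random(seed)
    plan = {}
    for n_windows in windows_options:
        # Each run uses a different random sample of training examples.
        plan[n_windows] = [sample_windows(train_set, n_windows, shots_per_window, rng)
                           for _ in range(n_runs)]
    return plan


# Toy stand-in for a real training set.
train = [f"example {i}" for i in range(100)]
plan = plan_experiment(train, windows_options=[1, 3], n_runs=30, shots_per_window=8)
print(len(plan[1]), len(plan[3][0]))  # → 30 3
```

Each entry of `plan[n]` would then be formatted into n context windows and scored on the same fixed test subsample.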

The evaluation output is a NumPy file in $OUTPUT_DIR containing an array of shape [2, 30]: the mean accuracy for each number of windows (rows) and each repetition (columns). You can read the file directly with np.load, or use the functions in utils.py to load and plot the results. Run the script with --help for further options.
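A minimal sketch of summarizing that array with NumPy directly, assuming row i corresponds to the i-th --n-windows value in the order given on the command line (here 1, then 3) and each column is one repetition. The file name `results.npy` is a hypothetical placeholder, since the actual name written to $OUTPUT_DIR is not given here, and the array is filled with random placeholder numbers only so the snippet runs; they are not results.

```python
import os
import tempfile

import numpy as np

n_windows_values = [1, 3]  # the --n-windows values passed on the command line
n_runs = 30

# Placeholder data standing in for the real output file; NOT actual results.
rng = np.random.default_rng(0)
fake = rng.uniform(0.5, 0.9, size=(len(n_windows_values), n_runs))
path = os.path.join(tempfile.mkdtemp(), "results.npy")  # hypothetical file name
np.save(path, fake)

accuracies = np.load(path)  # shape [len(n_windows_values), n_runs]
assert accuracies.shape == (2, 30)

summary = {}
for n_windows, row in zip(n_windows_values, accuracies):
    # Mean and standard deviation of accuracy across the repetitions.
    summary[n_windows] = (float(row.mean()), float(row.std()))
    print(f"{n_windows} window(s): mean={row.mean():.3f} std={row.std():.3f}")
```

With a real run, replace `path` with the file in $OUTPUT_DIR and drop the placeholder-data lines.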

PCW Usage examples

The evaluation code covers only classification tasks. The snippet below shows how PCW can be used for both classification and generation:

from transformers import AutoConfig
from modeling_gpt2_with_pcw import RestrictiveTokensLogitsProcessor, GPT2LMHeadWithPCWModel
import numpy as np

model_name = 'gpt2-large'
num_windows = 2

config = AutoConfig.from_pretrained(model_name)
config.n_positions = config.n_positions * num_windows
model = GPT2LMHeadWithPCWModel.from_pretrained(model_name, config=config, ignore_mismatched_sizes=True)

# use PCW for few-shot classification:
labels = ['positive', 'negative']
labels_input_ids = np.array(
    [model.tokenizer.encode(f' {label.lstrip()}', add_special_tokens=False) for label in labels])
# using RestrictiveTokensLogitsProcessor forces the output to be one of the labels:
logit_processor = RestrictiveTokensLogitsProcessor(labels_input_ids, eos_token_id=model.tokenizer.eos_token_id)
output = model.pcw_generate(contexts=["Review: Great movie! Sentiment: positive\n",
                                      "Review: Horrible film Sentiment: negative\n"],
                            task_text="Review: I liked it Sentiment:",
                            restrictive_logit_preprocessor=logit_processor,
                            temperature=0,
                            max_new_tokens=1)
print(output.strip())

# use PCW for generation:
output = model.pcw_generate(contexts=["Review: Great movie!\n", "Review: Horrible film\n"],
                            task_text="Review:",
                            temperature=1,
                            do_sample=True,
                            max_new_tokens=16)
print(output)
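To make the split between `contexts` and `task_text` concrete, the sketch below builds position ids and an attention mask following the paper's description of PCW: each context window is encoded with position ids restarting at 0, so all windows reuse the same positional embeddings; context tokens attend causally only within their own window; and the task tokens get positions continuing after the longest window while attending to every context token and causally to each other. It is a standalone NumPy illustration with a hypothetical function name (`pcw_positions_and_mask`), not the code in modeling_gpt2_with_pcw, and it ignores implementation details such as padding or a shared beginning-of-sequence token.

```python
from typing import List, Tuple

import numpy as np


def pcw_positions_and_mask(window_lengths: List[int],
                           task_length: int) -> Tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: position ids and boolean attention mask for PCW.

    Tokens are laid out as [window_1, ..., window_k, task].
    mask[i, j] is True when token i may attend to token j.
    """
    total = sum(window_lengths) + task_length
    positions = np.zeros(total, dtype=np.int64)
    mask = np.zeros((total, total), dtype=bool)

    start = 0
    for length in window_lengths:
        end = start + length
        # Every window restarts its positions at 0, reusing the same embeddings.
        positions[start:end] = np.arange(length)
        # Causal attention restricted to the window itself.
        mask[start:end, start:end] = np.tril(np.ones((length, length), dtype=bool))
        start = end

    # Task tokens continue after the longest window.
    n_ctx = max(window_lengths)
    positions[start:] = n_ctx + np.arange(task_length)
    # Task tokens see every context token...
    mask[start:, :start] = True
    # ...and attend causally among themselves.
    mask[start:, start:] = np.tril(np.ones((task_length, task_length), dtype=bool))
    return positions, mask


positions, mask = pcw_positions_and_mask([3, 2], task_length=2)
print(positions)  # → [0 1 2 0 1 3 4]
print(mask.astype(int))
```

The design point is that the largest position id depends only on the longest window plus the task length, not on the number of windows.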

Citation

If you find our paper or code helpful, please consider citing our paper:

@misc{ratner2023parallel,
      title={Parallel Context Windows for Large Language Models}, 
      author={Nir Ratner and Yoav Levine and Yonatan Belinkov and Ori Ram and Inbal Magar and Omri Abend and Ehud Karpas and Amnon Shashua and Kevin Leyton-Brown and Yoav Shoham},
      year={2023},
      eprint={2212.10947},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
