Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CodeParrot 馃 codebase #14536

Merged
merged 26 commits into from Dec 2, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
66 changes: 66 additions & 0 deletions examples/research_projects/codeparrot/README.md
@@ -0,0 +1,66 @@
# CodeParrot 馃
<p align="center">
<img src="images/code-highlighting-streamlit.png" alt="drawing" width="350"/>
</p>

## What is this about?
lvwerra marked this conversation as resolved.
Show resolved Hide resolved
This is an open-source effort to train and evaluate code generation models. CodeParrot 馃 is a GPT-2 model trained from scratch on Python code. The highlights of this repo are:
lvwerra marked this conversation as resolved.
Show resolved Hide resolved
- initialize and train a GPT-2 language model from scratch for code generation
- clean and deduplicate a large (>100GB) dataset with `datasets`
- train with `accelerate` on multiple GPUs using data parallelism and mixed precision
- continuously push checkpoints to the hub with `huggingface_hub`
- stream the dataset with `datasets` during training to avoid disk bottlenecks
- apply `code_eval` metric in `datasets` to evaluate on OpenAI's HumanEval benchmark
lvwerra marked this conversation as resolved.
Show resolved Hide resolved

## Installation
To install the dependencies simply run the following command:
```bash
pip install -r requirements.txt
```

## Dataset
The source of the dataset is the GitHub dump available on Google's [BigQuery](https://cloud.google.com/blog/topics/public-datasets/github-on-bigquery-analyze-all-the-open-source-code). The database was queried for all Python files resulting in a 180GB dataset with over 20M files. The dataset is available on the Hugging Face Hub [here](https://huggingface.co/datasets/transformersbook/codeparrot).
lvwerra marked this conversation as resolved.
Show resolved Hide resolved

## Preprocessing
The raw dataset contains many duplications so the dataset was deduplicated and filtered using the heuristics proposed in the Codex [paper](https://arxiv.org/abs/2107.03374):

- exact deduplication
lvwerra marked this conversation as resolved.
Show resolved Hide resolved
- filtering files with max line length > 1000
- filtering files with mean line length > 100
- fraction of alphanumeric characters < 0.25
- containing the word "auto-generated" or similar in the first 5 lines

The script to process the full dataset can be found in `scripts/preprocessing.py`. Executing the script on 16 CPUs takes roughly 3h and removes 70% of the original dataset. The cleaned [train](https://huggingface.co/datasets/lvwerra/codeparrot-clean-train) and [validation](https://huggingface.co/datasets/lvwerra/codeparrot-clean-valid) splits are also available on the Hub.
lvwerra marked this conversation as resolved.
Show resolved Hide resolved
lvwerra marked this conversation as resolved.
Show resolved Hide resolved

## Training
The models are randomly initialized and trained from scratch. The initialization script can be found at `scripts/initialize.py`. The main training script is built with 馃 Accelerate to scale across a wide range of platforms and infrastructure scales.
lvwerra marked this conversation as resolved.
Show resolved Hide resolved
lvwerra marked this conversation as resolved.
Show resolved Hide resolved

We train two models with [110M](https://huggingface.co/lvwerra/codeparrot-small/) and [1.5B](https://huggingface.co/lvwerra/codeparrot/) parameters for 25-30B tokens on a 16xA100 (40GB) machine which takes 1 day and 1 week, respectively. The training script can be found in `scripts/codeparrot_training.py`.

## Evaluation
The validation loss can be calculate with the `scripts/validation_loss.py` script. In addition we evaluate the model on OpenAI's _HumanEval_ benchmark. The evaluation script can be found in `scripts/human_eval.py`. The results as well as reference values are shown in the following table:
lvwerra marked this conversation as resolved.
Show resolved Hide resolved

| Model | pass@1 | pass@10 | pass@100|
|-------|--------|---------|---------|
|CodeParrot 馃 (110M) | 3.80% | 6.57% | 12.78% |
|CodeParrot 馃 (1.5B) | 3.58% | 8.03% | 14.96% |
|||||
|Codex (25M)| 3.21% | 7.1% | 12.89%|
|Codex (85M)| 8.22% | 12.81% | 22.40% |
|Codex (300M)| 13.17%| 20.37% | 36.27% |
|Codex (12B)| 28.81%| 46.81% | 72.31% |
|||||
|GPT-neo (125M)| 0.75% | 1.88% | 2.97% |
|GPT-neo (1.5B)| 4.79% | 7.47% | 16.30% |
|GPT-neo (2.7B)| 6.41% | 11.27% | 21.37% |
|GPT-J (6B)| 11.62% | 15.74% | 27.74% |

Both CodeParrot 馃 models are still underfitted and longer training would likely improve the performance, especially for the large model.

## Demo
Give the model a shot yourself! There are two demos to interact with the model:
- [Code generation](https://huggingface.co/spaces/lvwerra/codeparrot-generation)
- [Code highlighting](https://huggingface.co/spaces/lvwerra/codeparrot-highlighting)

## Further Resources
A detailed description of the project can be found in the chapter "Training Transformers from Scratch" in the upcoming O'Reilly book [Natural Language Processing with Transformers](https://learning.oreilly.com/library/view/natural-language-processing/9781098103231/).
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 7 additions & 0 deletions examples/research_projects/codeparrot/requirements.txt
@@ -0,0 +1,7 @@
transformers==4.12.2
datasets==1.16.0
accelerate==0.5.1
wandb==0.12.0
tensorboard==2.6.0
torch==1.9.0
huggingface-hub==0.0.19
232 changes: 232 additions & 0 deletions examples/research_projects/codeparrot/scripts/codeparrot_training.py
@@ -0,0 +1,232 @@
import logging
from argparse import Namespace

import datasets
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter

import transformers
import wandb
from accelerate import Accelerator
from huggingface_hub import Repository
from transformers import AdamW, AutoTokenizer, GPT2LMHeadModel, get_scheduler, set_seed
lvwerra marked this conversation as resolved.
Show resolved Hide resolved


class ConstantLengthDataset(IterableDataset):
lvwerra marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self, tokenizer, dataset, infinite=False, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.bos_token_id
self.dataset = dataset
self.seq_length = seq_length
self.input_characters = seq_length * chars_per_token * num_of_sequences
self.epoch = 0
self.infinite = infinite

def __iter__(self):
iterator = iter(self.dataset)
more_examples = True
while more_examples:
buffer, buffer_len = [], 0
while True:
if buffer_len >= self.input_characters:
break
try:
buffer.append(next(iterator)["content"])
buffer_len += len(buffer[-1])
except StopIteration:
if self.infinite:
iterator = iter(self.dataset)
self.epoch += 1
logger.info(f"Dataset epoch: {self.epoch}")
else:
more_examples = False
break
tokenized_inputs = tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
for i in range(0, len(all_token_ids), self.seq_length):
input_ids = all_token_ids[i : i + self.seq_length]
if len(input_ids) == self.seq_length:
yield torch.tensor(input_ids)


def setup_logging(project_name):
lvwerra marked this conversation as resolved.
Show resolved Hide resolved
logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
handlers=[logging.FileHandler(f"log/debug_{accelerator.process_index}.log"), logging.StreamHandler()],
)
if accelerator.is_main_process: # we only want to setup logging once
wandb.init(project=project_name, config=args)
run_name = wandb.run.name
tb_writer = SummaryWriter()
tb_writer.add_hparams(vars(args), {"0": 0})
logger.setLevel(logging.INFO)
datasets.utils.logging.set_verbosity_info()
transformers.utils.logging.set_verbosity_info()
else:
tb_writer = None
run_name = ""
logger.setLevel(logging.ERROR)
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
return logger, tb_writer, run_name


def create_dataloaders(dataset_name, args):
ds_kwargs = {"streaming": True}
train_data = load_dataset(dataset_name + "-train", split="train", **ds_kwargs)
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
valid_data = load_dataset(dataset_name + "-valid", split="train", **ds_kwargs)
train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True, seq_length=args.seq_length)
valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False, seq_length=args.seq_length)
train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
return train_dataloader, eval_dataloader


def get_grouped_params(model, args, no_decay=["bias", "LayerNorm.weight"]):
params_with_wd, params_without_wd = [], []
for n, p in model.named_parameters():
if any(nd in n for nd in no_decay):
params_without_wd.append(p)
else:
params_with_wd.append(p)
return [
{"params": params_with_wd, "weight_decay": args.weight_decay},
{"params": params_without_wd, "weight_decay": 0.0},
]


def log_metrics(step, metrics):
logger.info(f"Step {step}: {metrics}")
if accelerator.is_main_process:
wandb.log(metrics)
[tb_writer.add_scalar(k, v, step) for k, v in metrics.items()]


def evaluate(args):
model.eval()
losses = []
for step, batch in enumerate(eval_dataloader):
with torch.no_grad():
outputs = model(batch, labels=batch)
loss = outputs.loss.repeat(args.valid_batch_size)
losses.append(accelerator.gather(loss))
if args.max_eval_steps > 0 and step >= args.max_eval_steps:
break
loss = torch.mean(torch.cat(losses))
try:
perplexity = torch.exp(loss)
except OverflowError:
perplexity = float("inf")
return loss.item(), perplexity.item()


# Accelerator
accelerator = Accelerator()
acc_state = {str(k): str(v) for k, v in accelerator.state.__dict__.items()}

# Hyperparameters (codeparrot-small configs are in comments)
project_name = "lvwerra/codeparrot"
lvwerra marked this conversation as resolved.
Show resolved Hide resolved
dataset_name = "../codeparrot-clean"
config = {
lvwerra marked this conversation as resolved.
Show resolved Hide resolved
"train_batch_size": 2, # 16
"valid_batch_size": 2, # 16
"weight_decay": 0.1,
"shuffle_buffer": 1_000,
"learning_rate": 2e-4, # 5e-4
"lr_scheduler_type": "cosine",
"num_warmup_steps": 750, # 2000
"gradient_accumulation_steps": 16, # 1
"gradient_checkpointing": True, # False
"max_train_steps": 50_000, # 150_000
"max_eval_steps": -1,
"seq_length": 1024,
"seed": 1,
"save_checkpoint_steps": 50_000,
} # 15_000
lvwerra marked this conversation as resolved.
Show resolved Hide resolved
args = Namespace(**config, **acc_state)
samples_per_step = accelerator.state.num_processes * args.train_batch_size
set_seed(args.seed)

# Logging
logger, tb_writer, run_name = setup_logging(project_name.split("/")[1])
logger.info(accelerator.state)

# Load model and tokenizer
if accelerator.is_main_process:
hf_repo = Repository("./", clone_from=project_name, revision=run_name)
model = GPT2LMHeadModel.from_pretrained("./")
if args.gradient_checkpointing:
model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained("./")

# Load dataset and dataloader
train_dataloader, eval_dataloader = create_dataloaders(dataset_name, args)

# Prepare the optimizer and learning rate scheduler
optimizer = AdamW(get_grouped_params(model, args), lr=args.learning_rate)
lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=args.num_warmup_steps,
num_training_steps=args.max_train_steps,
)


def get_lr():
return optimizer.param_groups[0]["lr"]


# Prepare everything with our `accelerator`.
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader
)

# Train model
model.train()
completed_steps = 0
for step, batch in enumerate(train_dataloader, start=1):
loss = model(batch, labels=batch, use_cache=False).loss
log_metrics(
step, {"lr": get_lr(), "samples": step * samples_per_step, "steps": completed_steps, "loss/train": loss.item()}
)
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
if step % args.gradient_accumulation_steps == 0:
accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
completed_steps += 1
if step % args.save_checkpoint_steps == 0:
logger.info("Evaluating and saving model checkpoint")
eval_loss, perplexity = evaluate(args)
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained("./", save_function=accelerator.save)
if accelerator.is_main_process:
hf_repo.push_to_hub(commit_message=f"step {step}")
model.train()
if completed_steps >= args.max_train_steps:
break

# Evaluate and save the last checkpoint
logger.info("Evaluating and saving model after training")
eval_loss, perplexity = evaluate(args)
log_metrics(step, {"loss/eval": eval_loss, "perplexity": perplexity})
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained("./", save_function=accelerator.save)
if accelerator.is_main_process:
hf_repo.push_to_hub(commit_message="final model")
73 changes: 73 additions & 0 deletions examples/research_projects/codeparrot/scripts/human_eval.py
@@ -0,0 +1,73 @@
import json
import multiprocessing
import os
import re

from datasets import load_dataset, load_metric
from tqdm import tqdm

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, set_seed


transformers.logging.set_verbosity_error()
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
lvwerra marked this conversation as resolved.
Show resolved Hide resolved
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def first_block(string):
"""Split off first block of code by scanning for class, def etc. on newlines."""
return re.split("\nclass|\ndef|\n#|\n@|\nprint|\nif", string)[0].rstrip()


def complete_code(pipe, prompt, num_completions=1, **gen_kwargs):
"""Complete prompt with text generation pipeline and return num_completions."""
prompt = pipe.tokenizer.eos_token + prompt
code_gens = pipe(prompt, num_return_sequences=num_completions, **gen_kwargs)
return [first_block(code_gen["generated_text"][len(prompt) :]) for code_gen in code_gens]


# Settings
gen_kwargs = {
lvwerra marked this conversation as resolved.
Show resolved Hide resolved
"do_sample": True,
"temperature": 0.2,
"max_new_tokens": 256,
"top_p": 0.95,
"top_k": 0,
}

bs = 10
samples = 200
num_workers = multiprocessing.cpu_count()
model_ckpt = "lvwerra/codeparrot"
set_seed(1)

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModelForCausalLM.from_pretrained(model_ckpt)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)

# Load evaluation dataset and metric
human_eval = load_dataset("openai_humaneval")
code_eval_metric = load_metric("code_eval")

# Generate completions for evaluation set
n_tasks = len(human_eval["test"])
generations, references = [], []
for task in tqdm(range(n_tasks)):
task_generations = []
prompt = human_eval["test"][task]["prompt"].strip()
for batch in range(samples // bs):
task_generations.extend(complete_code(pipe, prompt, num_completions=bs, **gen_kwargs))
generations.append([prompt + gen for gen in task_generations])
test_func = human_eval["test"][task]["test"]
entry_point = f"check({human_eval['test'][task]['entry_point']})"
references.append("\n" + test_func + "\n" + entry_point)

# Evaluate completions with "code_eval" metric
pass_at_k, _ = code_eval_metric.compute(references=references, predictions=generations, num_workers=num_workers)
print(f"Results: {pass_at_k}")

# Save results to json file
with open("eval_results.json", "w") as fp:
json.dump(pass_at_k, fp)