Merged
38 commits
6d791d2
Create filter_v2.py
Muennighoff Mar 30, 2023
560409b
Update filter_v2.py
Muennighoff Mar 30, 2023
692361e
Update filter_v2.py
Muennighoff Mar 30, 2023
7fab270
Update filter_v2.py
Muennighoff Mar 30, 2023
3cf3295
Update filter_v2.py
Muennighoff Mar 31, 2023
d2a47e6
Add
Muennighoff Apr 2, 2023
45bfeca
Remove dup
Muennighoff Apr 2, 2023
36f1e40
Merge branch 'main' into filterv2
Muennighoff Apr 2, 2023
4771171
specify mods
Muennighoff Apr 2, 2023
5b453b0
Diff filter v1
Muennighoff Apr 3, 2023
18e8c98
Length filtering
Muennighoff Apr 3, 2023
a03bc01
Rmv todo
Muennighoff Apr 3, 2023
53b07b4
Add bloomz
Muennighoff Apr 6, 2023
777c6ba
Fixes
Muennighoff Apr 7, 2023
1ad5557
Add
Muennighoff Apr 9, 2023
47f9cf8
Fixes
Muennighoff Apr 14, 2023
25b4478
More fixes
Muennighoff Apr 14, 2023
2fcdd57
finetuning scripts
Muennighoff Apr 17, 2023
6335848
Updates
Muennighoff Apr 18, 2023
f4bd5b5
Cleanups
Muennighoff Apr 18, 2023
873a475
Deprec. fracs
Muennighoff Apr 18, 2023
d0a1e05
Fix
Muennighoff Apr 18, 2023
b702805
Add pp, shard & co
Muennighoff Apr 21, 2023
a73af46
Updates
Muennighoff Apr 22, 2023
4ce53ac
Add
Muennighoff Apr 23, 2023
0539e45
Add
Muennighoff May 3, 2023
a066d17
Fin script
Muennighoff May 6, 2023
b86299b
fix ckpt
Muennighoff May 6, 2023
25c1796
Update
Muennighoff May 6, 2023
52bf373
Add
Muennighoff May 9, 2023
121cff5
Add samples
Muennighoff Jun 6, 2023
d01372a
Updates
Muennighoff Jun 8, 2023
19a654b
Add
Muennighoff Jun 12, 2023
8c96bc8
Rename
Muennighoff Jun 12, 2023
7648c52
Fix
Muennighoff Jun 12, 2023
f594b10
Various updates
Muennighoff Jun 14, 2023
d6c2d86
Add inst
Muennighoff Jun 14, 2023
c3d1fbb
Clarify
Muennighoff Jun 14, 2023
52 changes: 51 additions & 1 deletion README.md
@@ -1,3 +1,53 @@

### WIP
### Fine-tuning


1. Get the StarCoderBase Megatron-LM checkpoint: `git clone https://huggingface.co/bigcode/starcoderbase-megatron`
2. Get Megatron-LM: `git clone -b mtf https://github.com/bigcode-project/Megatron-LM`
3. Prepare a Python environment with PyTorch. (TODO: some additional packages may be required; you will find out which ones when training fails.)
4. Prepare the dataset: create a fine-tuning dataset as a single jsonl file with two keys, `inputs` & `targets`. `inputs` should contain the prompt and instruction, while `targets` contains the target completion; loss is only computed over `targets`. See `dataset/commits_to_jsonl.py` for an example of producing this format (and the sketch after this list). In that example we put the instruction (commit message) in the target, but it's better to put it in the input.
5. Tokenize the fine-tuning dataset by modifying `dataset/preprocess.sh` to point to your jsonl dataset. Also modify the tokenizer path; in our case it points to StarCoder's `tokenizer.json` (`wget https://huggingface.co/bigcode/starcoderbase/raw/main/tokenizer.json`). Finally, specify an output prefix where the tokenized data will be stored, then run it with `bash dataset/preprocess.sh`.
6. Create two files, `train_data_paths.txt.tmp` and `valid_data_paths.txt.tmp`, that contain the paths to the tokenized dataset created above. For example, they could look like `"train: 1.0 0:0.95 output_prefix"` and `"valid: 1.0 0.95:1.0 output_prefix"`. In this example the dataset is split into 95% training and 5% validation: the first number is the weight of the dataset, and the second and third numbers are the start and end fractions of the dataset.
7. Rename the downloaded checkpoint to `release`, i.e. `mv starcoderbase-megatron/iter* starcoderbase-megatron/release`, and create a file `starcoderbase-megatron/latest_checkpointed_iteration.txt` that contains simply `release` (`echo release > starcoderbase-megatron/latest_checkpointed_iteration.txt`).
8. Modify `training/finetune_starcoderbase.sh`: set `CHECKPOINT_PATH` to the downloaded Megatron-LM checkpoint, `WEIGHTS_TRAIN` & `WEIGHTS_VALID` to the two txt files created above, and `TOKENIZER_FILE` to StarCoder's `tokenizer.json`; point it to your environment and cache locations, and adapt the SBATCH settings to your setup. Then run it with `bash training/finetune_starcoderbase.sh`. You can interrupt and resume training; however, if you resume, remove `--no_load_optim` and `--no_load_rng` from the command-line arguments in the script so that the optimizer and random-number-generator state are loaded from the newly saved checkpoint (we only want to avoid loading them from the original StarCoderBase checkpoint).
9. Convert the saved checkpoint using the instructions below.
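
To make steps 4 and 6 concrete, below is a minimal sketch (not part of the repository) that writes a one-record jsonl dataset in the `inputs`/`targets` format and the two path files. The file name `finetune_data.jsonl`, the toy commit record, and the output prefix `output_prefix` are illustrative assumptions; the 95%/5% split matches the example in step 6.

```python
# Minimal sketch, assuming illustrative file names and the 95%/5% split from step 6.
import json

# One fine-tuning record in the format produced by dataset/commits_to_jsonl.py;
# loss is only computed over "targets".
records = [
    {
        "inputs": "<commit_before>def add(a, b):\n    return a - b<commit_msg>",
        "targets": "Fix addition<commit_after>def add(a, b):\n    return a + b<|endoftext|>",
    },
]

with open("finetune_data.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

# After tokenizing with dataset/preprocess.sh (here assuming the output prefix "output_prefix"),
# point Megatron-LM at a 95% train / 5% validation split of the single tokenized dataset.
with open("train_data_paths.txt.tmp", "w") as f:
    f.write('"train: 1.0 0:0.95 output_prefix"')
with open("valid_data_paths.txt.tmp", "w") as f:
    f.write('"valid: 1.0 0.95:1.0 output_prefix"')
```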


#### Checkpoint conversion

1. Update the paths in `convert_large.sh`, download the repositories marked in it, and run it.

#### Other

A quick sanity check that the converted shards match the reference `starcoderbase` shards:

```python
import torch

for idx in ["00001", "00002", "00003", "00004", "00005", "00006", "00007"]:
    x = torch.load(f"/gpfsscratch/rech/ajs/commun/Bigcode-large-megatron_conv/base/shard2/pytorch_model-{idx}-of-00007.bin")
    y = torch.load(f"/gpfsscratch/rech/ajs/commun/starcoderbase/pytorch_model-{idx}-of-00007.bin")
    assert x.keys() == y.keys()
    for k in x.keys():
        # Report the first tensor that differs between the two shards
        if not (x[k] == y[k]).all():
            print(k)
            print(x[k].shape)
            print(y[k].shape)
            break
```

```python
# pip install -q transformers
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Earlier conversion output, kept for reference:
# checkpoint = "/gpfsscratch/rech/ajs/commun/Bigcode-large-megatron_conv/base/shard"
checkpoint = "/gpfsscratch/rech/ajs/commun/Bigcode-large-megatron_conv/base3/"
device = "cuda"  # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
).to(device)

inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=1)
print(tokenizer.decode(outputs[0]))
```
27 changes: 27 additions & 0 deletions dataset/commits_to_jsonl.py
@@ -0,0 +1,27 @@
import datasets
import random

NUM_PROC = 32
# Commit dataset with old/new file contents and the commit subject line
ds = datasets.load_dataset("commits-8192")["train"]

# Plain <commit_before>/<commit_msg>/<commit_after> format
def prepare(example):
    example["inputs"] = f"<commit_before>{example['old_contents']}<commit_msg>"
    example["targets"] = f"{example['subject']}<commit_after>{example['new_contents']}<|endoftext|>"
    return example

# Markdown-code-block variant of the same data
def prepare_code(example):
    example["inputs"] = f"```\n{example['old_contents']}\n```\n"
    example["targets"] = f"{example['subject']}\n```\n{example['new_contents']}\n```<|endoftext|>"
    return example

# Variant used below; optionally prefixes the filename
def prepare_bigcode(example):
    # With 50% probability add filename
    if random.random() < 0.5:
        example["inputs"] = f"<filename>{example['old_file'].split('/')[-1]}<commit_before>{example['old_contents']}<commit_msg>"
    else:
        example["inputs"] = f"<commit_before>{example['old_contents']}<commit_msg>"
    example["targets"] = f"{example['subject']}<commit_after>{example['new_contents']}<|endoftext|>"
    return example

ds = ds.map(prepare_bigcode, num_proc=NUM_PROC).select_columns(["inputs", "targets"])
ds.to_json("out.jsonl", orient="records", lines=True, force_ascii=False, num_proc=NUM_PROC)
373 changes: 373 additions & 0 deletions dataset/filter.py

Large diffs are not rendered by default.
