Adding a fine-tuning guide (#50)
CurtisASmith authored Jul 13, 2021
1 parent 831f405 commit 006e1bd
Showing 5 changed files with 104 additions and 1 deletion.
1 change: 0 additions & 1 deletion .gitignore
@@ -1,4 +1,3 @@
/data/
*.pyc
*.pprof
/ckpt/
2 changes: 2 additions & 0 deletions README.md
@@ -131,6 +131,8 @@ shards in the case of GPT-J-6B) down to a smaller number, such as for when runni

### Fine-tuning

**Added July 12 2021:** Please read the new guide in the repo, `howto_finetune.md`, for thorough fine-tuning instructions. Below are the original instructions.

To fine-tune the model, run `device_train.py` on a TPU VM. If you use a TPU v3-8, you can fine-tune at a rate of ~5000 tokens/second, which should be sufficient for small-to-medium-sized datasets.

For usage information, run `python3 device_train.py --help`.
40 changes: 40 additions & 0 deletions configs/example_config.json
@@ -0,0 +1,40 @@
{
"layers": 28,
"d_model": 4096,
"n_heads": 16,
"n_vocab": 50400,
"norm": "layernorm",
"pe": "rotary",
"pe_rotary_dims": 64,

"seq": 2048,
"cores_per_replica": 8,
"per_replica_batch": 1,
"gradient_accumulation_steps": 16,

"warmup_steps": 7,
"anneal_steps": 65,
"lr": 5e-5,
"end_lr": 1e-5,
"weight_decay": 0.1,
"total_steps": 72,

"tpu_size": 8,

"bucket": "your-bucket",
"model_dir": "finetune_dir",

"train_set": "example.train.index",
"val_set": {},

"eval_harness_tasks": [
],

"val_batches": 0,
"val_every": 80,
"ckpt_every": 72,
"keep_every": 72,

"name": "example_model",
"comment": ""
}
1 change: 1 addition & 0 deletions data/example.train.index
@@ -0,0 +1 @@
gs://your-bucket/datasets/your.tfrecords
61 changes: 61 additions & 0 deletions howto_finetune.md
@@ -0,0 +1,61 @@
# How to Fine-Tune GPT-J - The Basics

Before anything else, you'll likely want to apply for access to the TPU Research Cloud (TRC). Combined with a Google Cloud free trial, that should let you do everything in this guide for free. Once you're accepted into TRC, create a project, then fill out the form that was emailed to you with the name of the new project. Use `create_tfrecords.py` from the [GPT-NEO](https://github.com/EleutherAI/gpt-neo/blob/master/data/create_tfrecords.py) repo to prepare your data as tfrecords; I might do a separate guide on that. You may also want to fork the mesh-transformer-jax repo to make it easier to add and modify the config files.

0. [Install the Google Cloud SDK](https://cloud.google.com/sdk/docs/install). We'll need it later.

1. If you didn't make a project and activate TPU access through TRC yet (or if you plan on paying out of pocket), [make one now](https://console.cloud.google.com/projectcreate).

2. TPUs use Google Cloud buckets for storage, so go ahead and [create one now](https://console.cloud.google.com/storage/create-bucket). Make sure it's in the same region the TPU VM will be in; the email from TRC will tell you which region(s) you can use free TPUs in.

3. You'll need the full pretrained weights in order to fine-tune the model. [Download those here](https://the-eye.eu/public/AI/GPT-J-6B/step_383500.tar.zstd).

Now that you have a bucket on the cloud and the weights on your PC, you need to upload the weights to the bucket in two steps:

4. Decompress and extract `GPT-J-6B/step_383500.tar.zstd` so you're left with the uncompressed folder containing the sharded checkpoint.

5. Open the Google Cloud SDK and run the following command, replacing the path names as appropriate: `gsutil -m cp -R LOCAL_PATH_TO/step_383500 gs://YOUR-BUCKET`. If it works, the console will show the files being uploaded. *Note: this took about 12 hours for me, uploading to the Netherlands from California; hopefully you'll have a better geographic situation than I did! I also initially made the mistake of uploading the still-packed .tar. Don't do that; TPU VMs don't have enough local storage for you to unpack it. To avoid needing to re-upload, I had to unpack it in Colab.*

You'll want to upload the tfrecords of your data as well. You can do that now or through the web interface later, but trust me when I say you don't want to upload the nearly 70GB of weights through the web interface.

Note that steps 6 and 7, preparing the index and config files, can also be done later by editing the base repo in the VM's text editor, but it's more efficient to make these changes in your own fork of the repo as follows:

6. In the `data` folder, create a new file `foo.train.index`, replacing `foo` with whatever you want to call your dataset. For each tfrecord in your bucket that you intend to train with, add its path as a line in the index. Make a `foo.val.index` and do the same for your validation dataset (if you have one). See the existing files for examples.
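If you have many shards, a short script can write the index for you. This is just a sketch; the bucket and shard names below are hypothetical placeholders for your own paths:

```python
import os

# Hypothetical tfrecord paths in your bucket -- replace with your own.
shards = [
    "gs://your-bucket/datasets/foo_0.tfrecords",
    "gs://your-bucket/datasets/foo_1.tfrecords",
]

# Write one path per line, matching the format of the existing index files.
os.makedirs("data", exist_ok=True)
with open("data/foo.train.index", "w") as f:
    f.write("\n".join(shards) + "\n")
```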

7. Duplicate the config file `6B_roto_256.json` and rename it to something appropriate for your project. Open it up and make these edits:
- `tpu_size`: Change from `256` to `8`
- `bucket`: Change to your bucket
- `model_dir`: Change to the directory you'd like to save your checkpoints in
- `train_set` and `val_set`: Change to the index files from the last step
- `eval_harness_tasks`: Can be removed if you don't plan on using the eval harness
- `val_batches` & `val_every` & `ckpt_every` & `keep_every`: Usage should be intuitive, but don't set any of the `foo_every` values to `0` or you'll get a divide-by-zero error. If you don't have a `val_set`, just set `val_every` to something higher than `total_steps`.
- `name`: Change to a name for your model
- `warmup_steps`, `lr`, etc.: see the *Learning Rate Notes* section at the end of this guide
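A quick sanity check on the edited values can catch the divide-by-zero pitfall before you push. This is only an illustrative sketch; the dict below stands in for your own config loaded with `json.load`:

```python
# Hypothetical sanity check for an edited config. The values mirror the
# example config in this commit; substitute json.load(open("configs/...")).
config = {
    "tpu_size": 8,
    "total_steps": 72,
    "val_set": {},       # empty: no validation set
    "val_batches": 0,
    "val_every": 80,     # should exceed total_steps when val_set is empty
    "ckpt_every": 72,
    "keep_every": 72,
}

# None of the *_every values may be 0, or training hits a divide-by-zero.
for key in ("val_every", "ckpt_every", "keep_every"):
    assert config[key] != 0, f"{key} must not be 0"

# With no validation set, val_every should be higher than total_steps.
if not config["val_set"]:
    assert config["val_every"] > config["total_steps"]

print("config looks sane")
```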

8. Push the changes to your GitHub repo.

9. Follow [this guide](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) up to and including the step **"Connect to your Cloud TPU VM"**.

At this point you should have remote access to the TPU VM!

10. In the new VM terminal, run `git clone https://github.com/kingoflolz/mesh-transformer-jax` (or, preferably, clone your own fork after pushing the config and index files).

11. Move to the new directory with `cd mesh-transformer-jax` and run `pip install -r requirements.txt`. For whatever reason, the requirements file doesn't *seem* to install the correct version of Jax; that is, it does, but something must override it later, and I haven't figured out what. That's okay: just run `pip install jax==0.2.12` and you'll be all set.

12. Finally, run `python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/`. If everything is set up correctly, this will begin the fine-tuning process. First the model has to be loaded into memory; after `loading network` appeared on the console, it took about 10-15 minutes before the next step, setting up WandB for logging (option 3 lets you skip that if you aren't using WandB). A step-1 checkpoint will be saved, and then the real training will start. If you have a small dataset, this will go by quickly; TPU VMs can train at a rate of ~5000 tokens/second.

13. You did it! Now don't forget any clean-up steps you need to take, like shutting down your TPU VM or removing unneeded data from buckets, so that you don't get any unexpected charges from Google later.

### Now what?

This guide is labeled "The Basics", so anything we haven't covered so far is out of scope, but go check out the rest of the repository! Try `python3 device_sample.py --config=YOUR_CONFIG.json` for a basic sampling interface. Use `slim_model.py` to prepare an easier-to-deploy slim version of your new weights for inference. Experiment!

## Learning Rate Notes

**Thanks to nostalgebraist for talking about this!** They're the one who explained this part on Discord; I'm just paraphrasing, really:

The first thing you want to determine is how long a training epoch will be. `gradient_accumulation_steps` is your batch size; it defaults to `16`, though nostalgebraist recommends `32`. Your .tfrecord files should have a number in the file name indicating how many sequences are in the dataset. Divide that number by the batch size (rounding up) and the result is how many steps are in an epoch. Now we can write the schedule.
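In code, that division is a one-liner. Using the numbers from the worked example at the end of this guide (1147 sequences, batch size 16):

```python
import math

# Steps per epoch = sequences / batch size, rounded up so no data is dropped.
sequences = 1147    # sequence count from the tfrecord file name
batch_size = 16     # gradient_accumulation_steps

steps_per_epoch = math.ceil(sequences / batch_size)
print(steps_per_epoch)  # 72
```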

`lr` is recommended to be between `1e-5` and `5e-5`, with `end_lr` set to 1/5 or 1/10 of `lr`. `weight_decay` can remain `0.1`. `total_steps` should cover at least one epoch, or longer if you have a validation set to monitor your loss with. `warmup_steps` should be 5-10% of the total, and finally `anneal_steps` should be `total_steps - warmup_steps`.

To illustrate: I have a small dataset that tokenized into 1147 sequences in a .tfrecord. Dividing by `gradient_accumulation_steps` (set to `16`) and rounding up to ensure I use all the data gives 72 steps per epoch. I'll set `lr` to `5e-5` and `end_lr` to a fifth of that, `1e-5`; that may be too much, as it's on the high end of the recommended range. I'll set `total_steps` to `72` for one epoch, since I don't have a validation set. Then I'll set `anneal_steps` to `65` and `warmup_steps` to `7`. Simple as that, but you may need to fiddle with the specifics on your own.
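The whole recipe can be sketched as one small helper. This is just an illustration of the rules above, not code from the repo; `warmup_frac=0.1` picks the high end of the suggested 5-10% range, and `end_lr` is a fifth of `lr` as in the example:

```python
import math

def lr_schedule(sequences, batch_size, lr=5e-5, warmup_frac=0.1):
    """Derive a one-epoch fine-tuning schedule per the rules in this guide."""
    total_steps = math.ceil(sequences / batch_size)   # one full epoch
    warmup_steps = round(warmup_frac * total_steps)   # ~10% of total
    return {
        "lr": lr,
        "end_lr": lr / 5,                             # 1/5 of lr
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
        "anneal_steps": total_steps - warmup_steps,   # remainder anneals
    }

schedule = lr_schedule(sequences=1147, batch_size=16)
print(schedule["total_steps"], schedule["warmup_steps"], schedule["anneal_steps"])
# 72 7 65
```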
