Skip to content

Commit

Permalink
Add W&B lr logging and project naming (#80)
Browse files Browse the repository at this point in the history
  • Loading branch information
morganmcg1 authored Aug 10, 2021
1 parent a678e30 commit 8aaa4f9
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 6 deletions.
1 change: 1 addition & 0 deletions configs/6B_roto_256.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,6 @@
"keep_every": 10000,

"name": "GPT3_6B_pile_rotary",
"wandb_project": "mesh-transformer-jax",
"comment": ""
}
1 change: 1 addition & 0 deletions configs/example_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@
"keep_every": 72,

"name": "example_model",
"wandb_project": "mesh-transformer-jax",
"comment": ""
}
13 changes: 9 additions & 4 deletions device_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def parse_args():
- set `tpu_size` to 8 (if on a v3-8)
- set `warmup_steps`, `anneal_steps`, `lr`, `end_lr` to the lr schedule for your finetuning run
- the global step will reset to 0, keep that in mind when writing your lr schedule
- set `name` to specify the name of the Weights & Biases run
- set `wandb_project` to specify the Weights & Biases project to log to
To prepare data in the expected data format:
- use the script `create_finetune_tfrecords.py` in this repo to create data in the expected format
- upload the .tfrecords files to GCS
Expand Down Expand Up @@ -165,17 +167,19 @@ def eval_step(network, data):
lr = params["lr"]
end_lr = params["end_lr"]
weight_decay = params["weight_decay"]

# alpha parameter for the exponential moving averages used to compute B_simple
noise_scale_alpha = params.get("noise_scale_alpha", 0.01)

scheduler = util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr)

opt = optax.chain(
optax.scale(1 / gradient_accumulation_steps),
clip_by_global_norm(1),
optax.scale_by_adam(),
additive_weight_decay(weight_decay),
optax.scale(-1),
optax.scale_by_schedule(util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr))
optax.scale_by_schedule(scheduler)
)

params["optimizer"] = opt
Expand Down Expand Up @@ -286,7 +290,8 @@ def eval_step(network, data):
val_set.reset()
print(f"Eval fn compiled in {time.time() - start:.06}s")

wandb.init(project='mesh-transformer-jax', name=params["name"], config=params)
project = params.get("wandb_project", "mesh-transformer-jax")
wandb.init(project=project, name=params["name"], config=params)

G_noise_avg = None
S_noise_avg = None
Expand Down Expand Up @@ -326,7 +331,6 @@ def eval_step(network, data):

steps_per_sec = 1 / (time.time() - start)
tokens_per_sec = tokens_per_step * steps_per_sec

sequences_processed = sequences_per_step * step
tokens_processed = tokens_per_step * step

Expand Down Expand Up @@ -386,6 +390,7 @@ def eval_step(network, data):
"train/steps_per_sec": steps_per_sec,
"train/tokens_per_sec": tokens_per_sec,
"train/grad_norm": grad_norm,
"train/learning_rate": float(scheduler(network.state["opt_state"][-1].count[0].item())),
"sequences_processed": sequences_processed,
"tokens_processed": tokens_processed,
}
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ numpy~=1.19.5
transformers~=4.8.2
tqdm~=4.45.0
setuptools~=51.3.3
wandb~=0.10.22
wandb>=0.11.2
einops~=0.3.0
requests~=2.25.1
fabric~=2.6.0
Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def parse_args():
t.eval(val_set.get_samples())
print(f"Eval fn compiled in {time.time() - start:.06}s")

wandb.init(project='mesh-transformer-jax', entity="eleutherai", name=params["name"], config=params)
project = params.get("wandb_project", "mesh-transformer-jax")
wandb.init(project=project, entity="eleutherai", name=params["name"], config=params)

eval_task_dict = tasks.get_task_dict(eval_tasks)

Expand Down

0 comments on commit 8aaa4f9

Please sign in to comment.