-
Notifications
You must be signed in to change notification settings - Fork 889
/
6B_roto_256.json
50 lines (42 loc) · 905 Bytes
/
6B_roto_256.json
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
{
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,
  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
  "gradient_accumulation_steps": 16,
  "warmup_steps": 3000,
  "anneal_steps": 300000,
  "lr": 1.2e-4,
  "end_lr": 1.2e-5,
  "weight_decay": 0.1,
  "total_steps": 350000,
  "tpu_size": 256,
  "bucket": "neo-models",
  "model_dir": "mesh_jax_pile_6B_rotary",
  "train_set": "pile.train.index",
  "val_set": {
    "pile": "pile.val.index",
    "owt": "openwebtext2_new_inputs.val.index"
  },
  "eval_harness_tasks": [
    "lambada",
    "piqa",
    "hellaswag",
    "winogrande",
    "mathqa",
    "pubmedqa"
  ],
  "val_batches": 100,
  "val_every": 500,
  "ckpt_every": 500,
  "keep_every": 10000,
  "name": "GPT3_6B_pile_rotary",
  "wandb_project": "mesh-transformer-jax",
  "comment": ""
}