/
reformer_imagenet64.gin
113 lines (103 loc) · 3.99 KB
/
reformer_imagenet64.gin
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Copyright 2024 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import trax.data
import trax.layers
import trax.models
import trax.optimizers
import trax.supervised.trainer_lib
# Parameters that will vary between experiments:
# ==============================================================================
# Model configurable passed to train(). The trailing note names the
# alternative, ReformerShortenLM, which has its own parameter section in this
# file.
train.model = @trax.models.ReformerLM # ShortenLM
# Macros: referenced as %attn_type / %attn_kv / %n_layers by both the
# ReformerLM and ReformerShortenLM sections, so editing a value here changes
# both model configurations. %attn_kv feeds both d_attention_key and
# d_attention_value. The trailing note on attn_type names an alternative
# attention layer.
attn_type = @LSHSelfAttention # @CausalFavor
attn_kv = 64
n_layers = 6
# Parameters for batcher:
# ==============================================================================
# Input pipeline settings. max_eval_length matches the 12288-token sequence
# length (64 * 64 * 3) used for max_len in the model sections of this file.
batcher.data_streams = @data.data_streams
batcher.batch_size_per_device = 1
batcher.eval_batch_size = 8
batcher.max_eval_length = 12288 # 64 * 64 * 3
# Parameters for data_streams:
# ==============================================================================
# Dataset selection. data_dir is left as None here.
# NOTE(review): presumably None makes the loader fall back to its default data
# directory -- confirm against trax.data.
# input_name = 'targets' selects the 'targets' feature as the model input.
data_streams.data_dir = None
data_streams.dataset_name = 't2t_image_imagenet64_gen_flat_rev'
data_streams.input_name = 'targets'
# Parameters for multifactor:
# ==============================================================================
# Learning-rate schedule, described by the factors string below as the product
# of a constant, a linear warmup and an rsqrt decay.
# 0.03125 = 1024^-0.5 = d_model^-0.5 (exact: 1/32, with d_model = 1024 as set
# in the model sections of this file)
multifactor.constant = 0.03125
multifactor.factors = 'constant * linear_warmup * rsqrt_decay'
multifactor.warmup_steps = 8000
# Parameters for Adam:
# ==============================================================================
# Optimizer hyperparameters for the Adam instance selected by train.optimizer.
# Weight decay is turned off (rate 0.0).
Adam.weight_decay_rate=0.0
Adam.b1 = 0.9
Adam.b2 = 0.98
Adam.eps = 1e-9
# Parameters for train:
# ==============================================================================
# Training-loop settings. The final entry of checkpoints_at (500000) coincides
# with train.steps, so the last listed checkpoint falls on the final step.
# NOTE(review): eval_frequency is presumably measured in training steps and
# eval_steps in batches per evaluation -- confirm against trainer_lib.train.
train.eval_frequency = 500
train.eval_steps = 64
# train.model: see top
train.optimizer = @trax.optimizers.Adam
train.steps = 500000
train.save_graphs = False
# The list below is one logical line: the trailing backslash continues the
# binding, so nothing may be inserted between it and the list.
train.checkpoints_at = \
[1000, 5000, 10000, 20000, 40000, 60000, 80000,
100000, 200000, 300000, 400000, 500000]
# Parameters for LSHSelfAttention:
# ==============================================================================
# Settings for the attention layer selected by the attn_type macro at the top
# of this file. predict_mem_len equals the full 12288-token sequence length
# (64 * 64 * 3) used for max_len in the model sections.
# NOTE(review): with chunk_len = 128 a 12288-token sequence splits into 96
# chunks, and n_buckets = 192 is twice that -- presumably intentional; confirm
# against the LSHSelfAttention documentation before changing either value.
# NOTE(review): n_chunks_before = 1 / n_chunks_after = 0 presumably restricts
# each chunk to itself and the preceding chunk -- confirm.
LSHSelfAttention.attention_dropout = 0.2
LSHSelfAttention.chunk_len = 128
LSHSelfAttention.n_buckets = 192
LSHSelfAttention.n_chunks_after = 0
LSHSelfAttention.n_chunks_before = 1
LSHSelfAttention.n_hashes = 2
LSHSelfAttention.n_parallel_heads = 1
LSHSelfAttention.predict_drop_len = 256
LSHSelfAttention.predict_mem_len = 12288
# Parameters for ReformerLM:
# ==============================================================================
# Configuration of the model currently selected by train.model. The attention
# type, key/value depth and layer count come from the macros defined at the top
# of this file (%attn_type, %attn_kv, %n_layers).
# Consistency of the values below: pos_axial_shape multiplies out to
# 64 * 64 * 3 = 12288 = max_len, and pos_d_axial_embs sums to
# 384 + 384 + 256 = 1024 = d_model.
# NOTE(review): those two equalities are presumably required by the axial
# positional encoding -- keep them in sync when editing; confirm in Trax docs.
ReformerLM.attention_type = %attn_type
ReformerLM.d_attention_key = %attn_kv
ReformerLM.d_attention_value = %attn_kv
ReformerLM.d_model = 1024
ReformerLM.d_ff = 4096
ReformerLM.dropout = 0.0
ReformerLM.ff_activation = @trax.layers.FastGelu
ReformerLM.max_len = 12288 # 64 * 64 * 3
ReformerLM.mode = 'train'
ReformerLM.n_heads = 8
ReformerLM.n_layers = %n_layers
ReformerLM.vocab_size = 256
ReformerLM.pos_axial_shape = (64, 64, 3)
ReformerLM.pos_d_axial_embs= (384, 384, 256)
# Parameters for ReformerShortenLM:
# ==============================================================================
# Configuration of the alternative model. As written, train.model at the top of
# this file selects ReformerLM, so these bindings only matter once train.model
# is switched to ReformerShortenLM. The same macros (%attn_type, %attn_kv,
# %n_layers) are reused here.
# Consistency of the values below: pos_axial_shape multiplies out to
# 64 * 64 * 3 = 12288 = max_len, and pos_d_axial_embs sums to
# 96 + 96 + 64 = 256 = d_embedding (not d_model, unlike the ReformerLM
# section).
# NOTE(review): presumably the positional embedding is applied at d_embedding
# width before shortening by shorten_factor -- confirm in Trax docs.
ReformerShortenLM.attention_type = %attn_type
ReformerShortenLM.d_attention_key = %attn_kv
ReformerShortenLM.d_attention_value = %attn_kv
ReformerShortenLM.shorten_factor = 3
ReformerShortenLM.d_embedding = 256
ReformerShortenLM.d_model = 1024
ReformerShortenLM.d_ff = 4096
ReformerShortenLM.dropout = 0.0
ReformerShortenLM.ff_activation = @trax.layers.FastGelu
ReformerShortenLM.max_len = 12288 # 64 * 64 * 3
ReformerShortenLM.mode = 'train'
ReformerShortenLM.n_heads = 8
ReformerShortenLM.n_layers = %n_layers
ReformerShortenLM.vocab_size = 256
ReformerShortenLM.pos_axial_shape = (64, 64, 3)
ReformerShortenLM.pos_d_axial_embs= (96, 96, 64)