81M parameter GPT model trained on WikiText-103, using `torch.distributed` for DDP and `torch.cuda.amp` for mixed precision. Used youtokentome for tokenizing the corpus.
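The training step looks roughly like the sketch below (illustrative only; `model`, `loader`, and `optimizer` are placeholders rather than the repo's actual objects):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group('nccl')          # one process per GPU, e.g. launched via torchrun
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = DDP(model.cuda(rank), device_ids=[rank])
scaler = torch.cuda.amp.GradScaler()

for x, y in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():      # run the forward pass in mixed precision
        loss = model(x.cuda(rank), y.cuda(rank))
    scaler.scale(loss).backward()        # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
```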
| setting | value |
| --- | --- |
| d_model | 768 |
| n_head | 16 |
| n_layer | 8 |
| ctx_len | 512 |
| vocab_size | 16384 |
| max_lr | 3e-4 |
| min_lr | 5e-5 |
| weight_decay | 1e-2 |
| train_steps | 15000 |
| lr_warmup_steps | 1500 |
| lr_schedule | inverse square root |
| batch_size | 128 |
| gpu | RTX 3090 |
| n_gpu | 4 |
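For reference, one common way to parameterize warmup plus inverse square root decay with the values above (an assumption about the exact form; the repo may implement it differently, e.g. in how `min_lr` enters):

```python
def lr_at(step, max_lr=3e-4, min_lr=5e-5, warmup=1500):
    if step < warmup:
        return max_lr * step / warmup                     # linear warmup to max_lr
    return max(min_lr, max_lr * (warmup / step) ** 0.5)   # then decay proportional to 1/sqrt(step)
```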
A few changes from the vanilla architecture: RMSNorm instead of LayerNorm, RoPE instead of learned position embeddings, SwiGLU instead of GELU, and the Lion optimizer instead of Adam(W).
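Minimal sketches of the norm and activation swaps (illustrative, not the repo's code; the hidden width and `eps` are arbitrary choices here):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    # LayerNorm without mean-centering or bias: rescale by the RMS, then apply a learned gain
    def __init__(self, d, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(d))

    def forward(self, x):
        return self.g * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class SwiGLU(nn.Module):
    # gated MLP: silu(x @ W1) gates x @ W2, then project back down
    def __init__(self, d, hidden):
        super().__init__()
        self.w1 = nn.Linear(d, hidden, bias=False)
        self.w2 = nn.Linear(d, hidden, bias=False)
        self.w3 = nn.Linear(hidden, d, bias=False)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))
```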
The difference between AdamW and Lion: Lion saves memory by keeping only the first moment (an EMA of the gradients), whereas Adam and its variants also track the second moment (an EMA of the squared gradients).
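A sketch of a single Lion update on one parameter tensor, following the published update rule (the function and its defaults are illustrative, not the repo's code):

```python
import torch

@torch.no_grad()
def lion_step(p, grad, m, lr=3e-4, beta1=0.9, beta2=0.99, wd=1e-2):
    # update direction: sign of an interpolation between momentum and the current gradient
    update = (beta1 * m + (1 - beta1) * grad).sign()
    p.mul_(1 - lr * wd)                          # decoupled weight decay, as in AdamW
    p.add_(update, alpha=-lr)                    # sign() means every coordinate steps by exactly lr
    m.mul_(beta2).add_(grad, alpha=1 - beta2)    # the only state kept: an EMA of gradients
```

Note the `sign()`: the step magnitude is the same for every coordinate, which is why Lion typically wants a smaller learning rate than AdamW.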
A few days ago, 4chan anon kaiokendev made a two-line edit to his RoPE implementation which appeared to double the model's context length at test time. The equivalent edit in the implementation here would be:
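(The original two-line snippet isn't reproduced here; below is a minimal sketch of what the module might look like with the edit applied, assuming interleaved RoPE and the `RotaryEmbedding(seq_len, d, scale)` signature used in the extension snippet at the end of this post.)

```python
import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    def __init__(self, l, d, base=10000, scale=1.0):
        super().__init__()
        theta = 1.0 / base ** (torch.arange(0, d, 2) / d)  # d/2 frequencies
        theta = theta.repeat_interleave(2)                 # [θ_1, θ_1, ..., θ_{d/2}, θ_{d/2}]
        # the edit: scale down the positions before taking the outer product
        pos = torch.arange(l) / scale
        self.register_buffer('ltheta', torch.einsum('i, j -> ij', pos, theta))

    def forward(self, x):
        # x: (batch_size, seq_len, d_model)
        # swap adjacent pairs and negate the even-numbered entries: [-x_2, x_1, -x_4, x_3, ...]
        x_rot = torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1).flatten(-2)
        l = x.shape[-2]
        return x * self.ltheta[:l].cos() + x_rot * self.ltheta[:l].sin()
```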
There's no fine-tuning going on; all we have to do is scale `torch.arange(l)` by `scale`. What's happening? Here, `x` is the query `q` or key `k` batch on which we apply RoPE. `x` has shape `(batch_size, seq_len, d_model)`; we only need to care about the last two dimensions (assume we're working with a single batch, which is what we do at test time). We first define a vector `theta` of `d` angle values like so

[θ_1, θ_1, θ_2, θ_2, ..., θ_{d/2}, θ_{d/2}]

where `d` is assumed to be even, so each of the `d/2` frequencies appears twice. Then we consider `torch.arange(l)`, which is simply

[0, 1, ..., l-1]

where `l` is `seq_len`. `torch.einsum('i, j -> ij', torch.arange(l), theta)` is the outer product between the two arguments, so you get `ltheta`, a matrix that looks like

[[0 * θ_1, 0 * θ_1, ..., 0 * θ_{d/2}, 0 * θ_{d/2}],
 [1 * θ_1, 1 * θ_1, ..., 1 * θ_{d/2}, 1 * θ_{d/2}],
 ...
 [(l-1) * θ_1, (l-1) * θ_1, ..., (l-1) * θ_{d/2}, (l-1) * θ_{d/2}]]
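To make this concrete, a tiny example with made-up values (`l = 4` and two distinct frequencies, so `d = 4`):

```python
import torch

theta = torch.tensor([1.0, 1.0, 0.1, 0.1])  # [θ_1, θ_1, θ_2, θ_2]
l = 4
ltheta = torch.einsum('i, j -> ij', torch.arange(l, dtype=torch.float), theta)
print(ltheta)
# tensor([[0.0000, 0.0000, 0.0000, 0.0000],
#         [1.0000, 1.0000, 0.1000, 0.1000],
#         [2.0000, 2.0000, 0.2000, 0.2000],
#         [3.0000, 3.0000, 0.3000, 0.3000]])
```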
So `ltheta` has the same shape as `x`: `(seq_len, d_model)`. `x_rot` is just `x` with adjacent columns swapped in pairs and the even-numbered ones negated, i.e. `[-x_2, x_1, -x_4, x_3, ...]`. We apply `sin` and `cos` elementwise to `ltheta`, and elementwise multiply the outputs with `x_rot` and `x` respectively, so the result is `x * cos(ltheta) + x_rot * sin(ltheta)`. To double the context length, we halve the frequencies:
[[0 * θ_1, 0 * θ_1, ..., 0 * θ_{d/2}, 0 * θ_{d/2}],
 [1/2 * θ_1, 1/2 * θ_1, ..., 1/2 * θ_{d/2}, 1/2 * θ_{d/2}],
 ...
 [(2l-1)/2 * θ_1, (2l-1)/2 * θ_1, ..., (2l-1)/2 * θ_{d/2}, (2l-1)/2 * θ_{d/2}]]
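A quick numeric check of what the edit buys: the doubled context's scaled positions interpolate between the positions seen in training instead of extrapolating far past them (a small illustrative snippet, not repo code):

```python
import torch

l, scale = 512, 2
trained = torch.arange(l)                   # training-time positions: 0 .. 511
extended = torch.arange(scale * l) / scale  # positions after the edit: 0, 0.5, 1, ...
print(trained.max().item(), extended.max().item())  # 511 vs. 511.5
```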
Intuitively, the model has learned to map sections of a context to specific sections of the `sin` and `cos` curves. Imagine dividing an arbitrary context into uniform chunks. Then, no matter the context length, we want each chunk to be mapped to the same section of the trig curves, and we do this by scaling down the frequencies. A paper laying out the exact trick came out just recently, and it has this helpful visualization of what's going on:
If we have a trained model with context length `l`, we can make it work with a longer context length `k * l` by doing:

```python
k = 2
model = GPT(...)
model.load_state_dict(torch.load('trained_model.pt'))
# update RoPE with scaled-down frequencies for the longer context
for layer in model.layers:
    layer.pe = RotaryEmbedding(k * model.l, model.d, scale=k)
model.l *= k
torch.save(model.state_dict(), 'extended_model.pt')
```
(Update again: Jianlin Su's explanation)
Attention scores visualized for 64 tokens in the middle of a completion: