-
Notifications
You must be signed in to change notification settings - Fork 88
/
llama.ex
440 lines (367 loc) · 13.3 KB
/
llama.ex
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
defmodule Bumblebee.Text.Llama do
  alias Bumblebee.Shared

  # Option definitions are bound to a module-body variable so they can be
  # reused both for the generated @moduledoc and for the struct defaults.
  options =
    [
      vocab_size: [
        default: 32000,
        doc: """
        the vocabulary size of the token embedding. This corresponds to the number of distinct
        tokens that can be represented in model input and output
        """
      ],
      max_positions: [
        default: 1024,
        doc: """
        the maximum sequence length that this model can process. This value is passed to the
        rotary position embedding configuration (the model has no learned position embedding
        table). Typically this is set to a large value just in case, such as 512, 1024 or 2048
        """
      ],
      hidden_size: [
        default: 4096,
        doc: "the dimensionality of hidden layers"
      ],
      intermediate_size: [
        default: 11008,
        doc: "the dimensionality of intermediate layers"
      ],
      num_blocks: [
        default: 32,
        doc: "the number of Transformer blocks in the model"
      ],
      num_attention_heads: [
        default: 32,
        doc: "the number of attention heads for each attention layer in the model"
      ],
      num_key_value_heads: [
        default: nil,
        doc: "the number of key value heads for each attention layer in the model"
      ],
      activation: [
        default: :silu,
        doc: "the activation function"
      ],
      rotary_embedding_base: [
        default: 10_000,
        doc: "base for computing rotary embedding frequency"
      ],
      rotary_embedding_scaling_strategy: [
        default: nil,
        doc: """
        scaling configuration for rotary embedding. Currently the supported values are:

          * `%{type: :linear, factor: number()}`

          * `%{type: :dynamic, factor: number()}`

        For more details see https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases
        """
      ],
      layer_norm_epsilon: [
        default: 1.0e-12,
        doc: "the epsilon used by RMS normalization layers"
      ],
      initializer_scale: [
        default: 0.02,
        doc:
          "the standard deviation of the normal initializer used for initializing kernel parameters"
      ]
    ] ++
      Shared.common_options([:num_labels, :id_to_label]) ++ Shared.token_options(pad_token_id: 0)

  @moduledoc """
  LLaMA model family.

  ## Architectures

    * `:base` - plain LLaMA without any head on top

    * `:for_causal_language_modeling` - LLaMA with a language modeling
      head. The head returns logits for each token in the original
      sequence

    * `:for_sequence_classification` - LLaMA with a sequence
      classification head. The head returns logits corresponding to
      possible classes

  ## Inputs

    * `"input_ids"` - `{batch_size, sequence_length}`

      Indices of input sequence tokens in the vocabulary.

    * `"attention_mask"` - `{batch_size, sequence_length}`

      Mask indicating which tokens to attend to. This is used to ignore
      padding tokens, which are added when processing a batch of sequences
      with different lengths.

    * `"position_ids"` - `{batch_size, sequence_length}`

      Indices of positions of each input sequence token, used when
      computing the rotary position embedding.

    * `"attention_head_mask"` - `{num_blocks, num_attention_heads}`

      Mask to nullify selected heads of the self-attention blocks in
      the decoder.

    * `"input_embeddings"` - `{batch_size, sequence_length, hidden_size}`

      Embedded representation of `"input_ids"`, which can be specified
      for more control over how `"input_ids"` are embedded than the
      model's internal embedding lookup. If `"input_embeddings"` are present,
      then `"input_ids"` will be ignored.

    * `"cache"`

      A container with cached layer results used to speed up sequential
      decoding (autoregression). With cache, certain hidden states are
      taken from the cache, rather than recomputed on every decoding
      pass. The cache should be treated as opaque and initialized with
      `Bumblebee.Text.Generation.init_cache/4`.

  ## Global layer options

  #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])}

  ## Configuration

  #{Shared.options_doc(options)}
  """

  defstruct [architecture: :base] ++ Shared.option_defaults(options)

  @behaviour Bumblebee.ModelSpec
  @behaviour Bumblebee.Configurable
  @behaviour Bumblebee.Text.Generation

  import Bumblebee.Utils.Model, only: [join: 2]

  alias Bumblebee.Layers

  @impl true
  def architectures(),
    do: [
      :base,
      :for_causal_language_modeling,
      :for_sequence_classification
    ]

  @impl true
  def config(spec, opts) do
    spec
    |> Shared.put_config_attrs(opts)
    |> Shared.validate_label_options()
  end

  @impl true
  def input_template(_spec) do
    %{
      "input_ids" => Nx.template({1, 1}, :s64)
    }
  end

  @impl true
  def init_cache(spec, batch_size, max_length, _inputs) do
    Layers.Decoder.init_cache(batch_size, max_length,
      hidden_size: spec.hidden_size,
      decoder_num_attention_heads: spec.num_attention_heads,
      decoder_num_blocks: spec.num_blocks
    )
  end

  @impl true
  def traverse_cache(_spec, cache, fun) do
    Layers.Decoder.traverse_cache(cache, fun)
  end

  @impl true
  def model(%__MODULE__{architecture: :base} = spec) do
    inputs = inputs(spec)

    inputs
    |> core(spec)
    |> Layers.output()
  end

  def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do
    inputs = inputs(spec)

    outputs = core(inputs, spec)
    logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head")

    Layers.output(%{
      logits: logits,
      hidden_states: outputs.hidden_states,
      attentions: outputs.attentions,
      cache: outputs.cache
    })
  end

  def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do
    inputs = inputs(spec)

    outputs = core(inputs, spec)

    # Per-token class logits; a single position per sequence is pooled below.
    logits =
      Axon.dense(outputs.hidden_state, spec.num_labels,
        kernel_initializer: kernel_initializer(spec),
        name: "sequence_classification_head.output",
        use_bias: false
      )

    # When token ids are available, the pooled index is
    # (number of non-padding tokens - 1), which points at the last real
    # token only when padding is on the right. Without token ids (only
    # "input_embeddings" given), the token at index -1 is taken instead.
    pooled_logits =
      Layers.if_present inputs["input_ids"] do
        Axon.layer(
          fn logits, input_ids, _opts ->
            indices =
              input_ids
              |> Nx.not_equal(spec.pad_token_id)
              |> Nx.sum(axes: [-1])
              |> Nx.subtract(1)
              |> Nx.as_type({:s, 64})

            Bumblebee.Utils.Nx.batched_take(logits, indices)
          end,
          [logits, inputs["input_ids"]]
        )
      else
        Layers.take_token(logits, axis: 1, index: -1)
      end

    Layers.output(%{
      logits: pooled_logits,
      hidden_states: outputs.hidden_states,
      attentions: outputs.attentions,
      cache: outputs.cache
    })
  end

  # Declares the model inputs. All of them are optional, so callers may pass
  # either "input_ids" or "input_embeddings" (see embedder/4).
  defp inputs(spec) do
    shape = {nil, nil}
    hidden_shape = {nil, nil, spec.hidden_size}
    attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads}

    Bumblebee.Utils.Model.inputs_to_map([
      Axon.input("input_ids", optional: true, shape: shape),
      Axon.input("attention_mask", optional: true, shape: shape),
      Axon.input("position_ids", optional: true, shape: shape),
      Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape),
      Axon.input("input_embeddings", optional: true, shape: hidden_shape),
      Axon.input("cache", optional: true)
    ])
  end

  # Shared backbone for all architectures: embeddings -> decoder blocks ->
  # final RMS norm. The normalized output is also appended to the collected
  # hidden states.
  defp core(inputs, spec) do
    embeddings =
      embedder(
        inputs["input_ids"],
        inputs["input_embeddings"],
        spec,
        name: "embedder"
      )

    position_ids =
      Layers.default inputs["position_ids"] do
        Layers.default_position_ids(embeddings)
      end

    decoder_outputs =
      decoder(
        embeddings,
        position_ids,
        inputs["attention_mask"],
        inputs["attention_head_mask"],
        inputs["cache"],
        spec,
        name: "decoder"
      )

    hidden_state =
      Layers.rms_norm(decoder_outputs.hidden_state,
        name: "output_norm",
        epsilon: spec.layer_norm_epsilon
      )

    %{
      hidden_state: hidden_state,
      hidden_states: Layers.append(decoder_outputs.hidden_states, hidden_state),
      attentions: decoder_outputs.attentions,
      cache: decoder_outputs.cache
    }
  end

  # Uses "input_embeddings" when given; otherwise looks "input_ids" up in the
  # token embedding table. There is no position embedding layer here, since
  # positions are handled by the rotary embedding in the decoder blocks.
  defp embedder(input_ids, input_embeddings, spec, opts) do
    name = opts[:name]

    Layers.default input_embeddings do
      Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size,
        kernel_initializer: kernel_initializer(spec),
        name: join(name, "token_embedding")
      )
    end
  end

  # Stack of causal, norm-first Transformer blocks with RMS normalization,
  # a gated feed-forward network, rotary position embedding and no bias in
  # the attention projections.
  defp decoder(
         hidden_state,
         position_ids,
         attention_mask,
         attention_head_mask,
         cache,
         spec,
         opts
       ) do
    name = opts[:name]

    Layers.Transformer.blocks(hidden_state,
      attention_mask: attention_mask,
      attention_head_mask: attention_head_mask,
      cache: cache,
      num_blocks: spec.num_blocks,
      num_attention_heads: spec.num_attention_heads,
      num_key_value_heads: spec.num_key_value_heads,
      hidden_size: spec.hidden_size,
      kernel_initializer: kernel_initializer(spec),
      layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon),
      ffn:
        &gated_ffn(&1, spec.intermediate_size, spec.hidden_size,
          name: &2,
          activation: spec.activation
        ),
      block_type: :norm_first,
      causal: true,
      rotary_embedding: [
        position_ids: position_ids,
        max_positions: spec.max_positions,
        base: spec.rotary_embedding_base,
        scaling_strategy: spec.rotary_embedding_scaling_strategy
      ],
      query_use_bias: false,
      key_use_bias: false,
      value_use_bias: false,
      output_use_bias: false,
      name: join(name, "blocks")
    )
  end

  # Gated feed-forward network:
  # output(intermediate(x) * activation(gate(x))), all projections bias-free.
  defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do
    name = opts[:name]
    activation = opts[:activation]

    intermediate =
      Axon.dense(hidden_state, intermediate_size,
        name: join(name, "intermediate"),
        use_bias: false
      )

    gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false)

    hidden_state = Axon.multiply(intermediate, Axon.activation(gate, activation))

    Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false)
  end

  # Projects hidden states to vocabulary logits.
  defp language_modeling_head(hidden_state, spec, opts) do
    name = opts[:name]

    # TODO: Tie lm-head to word embedding as a spec option
    Layers.dense_transposed(hidden_state, spec.vocab_size,
      kernel_initializer: kernel_initializer(spec),
      name: join(name, "output")
    )
  end

  defp kernel_initializer(spec) do
    Axon.Initializers.normal(scale: spec.initializer_scale)
  end

  defimpl Bumblebee.HuggingFace.Transformers.Config do
    # Builds spec options from a Hugging Face config map and applies them.
    def load(spec, data) do
      import Shared.Converters

      # "rope_scaling" must be a map with a "type" of "linear" or "dynamic"
      # and a numeric "factor"; any other shape is reported as an error.
      scaling_strategy_converter = fn name, value ->
        case value do
          %{"type" => "linear", "factor" => factor} when is_number(factor) ->
            {:ok, %{type: :linear, factor: factor}}

          %{"type" => "dynamic", "factor" => factor} when is_number(factor) ->
            {:ok, %{type: :dynamic, factor: factor}}

          _other ->
            {:error, "invalid format for #{inspect(name)}, got: #{inspect(value)}"}
        end
      end

      opts =
        convert!(data,
          vocab_size: {"vocab_size", number()},
          max_positions: {"max_position_embeddings", number()},
          hidden_size: {"hidden_size", number()},
          num_blocks: {"num_hidden_layers", number()},
          num_attention_heads: {"num_attention_heads", number()},
          num_key_value_heads: {"num_key_value_heads", number()},
          intermediate_size: {"intermediate_size", number()},
          activation: {"hidden_act", activation()},
          rotary_embedding_base: {"rope_theta", number()},
          rotary_embedding_scaling_strategy:
            {"rope_scaling", optional(scaling_strategy_converter)},
          initializer_scale: {"initializer_range", number()},
          layer_norm_epsilon: {"rms_norm_eps", number()}
        ) ++ Shared.common_options_from_transformers(data, spec)

      @for.config(spec, opts)
    end
  end

  defimpl Bumblebee.HuggingFace.Transformers.Model do
    # Maps Axon layer names (left) to Hugging Face parameter names (right).
    def params_mapping(_spec) do
      %{
        "embedder.token_embedding" => "model.embed_tokens",
        "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj",
        "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj",
        "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj",
        "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj",
        "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm",
        "decoder.blocks.{n}.self_attention.rotary_embedding" =>
          "model.layers.{n}.self_attn.rotary_emb",
        "decoder.blocks.{n}.ffn.gate" => "model.layers.{n}.mlp.gate_proj",
        "decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.up_proj",
        "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj",
        "decoder.blocks.{n}.output_norm" => "model.layers.{n}.post_attention_layernorm",
        "output_norm" => "model.norm",
        "language_modeling_head.output" => "lm_head",
        "sequence_classification_head.output" => "score"
      }
    end
  end
end