/
model.py
369 lines (319 loc) · 13.2 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
import json
import logging
import math
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
import torch
from torch import nn
from simple_parsing.helpers import Serializable
from mistral.rope import precompute_freqs_cis, apply_rotary_emb
from mistral.cache import CacheView, RotatingBufferCache
from mistral.moe import MoeArgs, MoeLayer
from xformers.ops.fmha import memory_efficient_attention
@dataclass
class ModelArgs(Serializable):
    """Architecture hyper-parameters of the transformer.

    ``Transformer.from_folder`` builds an instance from ``params.json`` via
    ``ModelArgs.from_dict`` and then overrides ``max_batch_size``.
    """

    # Width of the token embeddings / residual stream.
    dim: int
    # Number of TransformerBlock layers in the whole model.
    n_layers: int
    # Size of each attention head.
    head_dim: int
    # Inner width of each FeedForward (w1/w3 output size).
    hidden_dim: int
    # Number of query heads.
    n_heads: int
    # Number of key/value heads (each shared by n_heads // n_kv_heads query heads).
    n_kv_heads: int
    # Epsilon used by every RMSNorm.
    norm_eps: float
    # Size of the token vocabulary (embedding rows / output logits).
    vocab_size: int
    # Maximum number of sequences per forward call. Defaults to 0, so it must be
    # set (from_folder does) before forward_partial's batch-size check can pass.
    max_batch_size: int = 0
    # Rotary-embedding base. When None it is inferred from sliding_window:
    # 1e6 without a sliding window, 1e4 with one.
    rope_theta: Optional[float] = None
    # When set, use sliding-window attention backed by a rotating KV cache.
    sliding_window: Optional[int] = None
    # When set, every block uses a mixture-of-experts layer instead of a dense
    # feed-forward layer.
    moe: Optional[MoeArgs] = None
@dataclass
class SimpleInputMetadata:
    """Per-token metadata used when the model runs without a KV cache.

    It only carries the rotary-embedding positions of the packed batch.
    """

    # Absolute RoPE position of every token in the packed batch. Each sequence
    # restarts at 0, e.g. seqlens [2, 3] -> [0, 1, 0, 1, 2].
    positions: torch.Tensor

    @classmethod
    def from_seqlens(
        cls, seqlens: List[int], device: torch.device
    ) -> "SimpleInputMetadata":
        """Build metadata for sequences of the given lengths packed end to end.

        Args:
            seqlens: length of each sequence in the packed batch.
            device: device on which ``positions`` is placed.

        Returns:
            Metadata whose ``positions`` is a 1-D ``torch.long`` tensor of
            length ``sum(seqlens)``. It is empty when ``seqlens`` is empty.
        """
        if not seqlens:
            # torch.cat raises on an empty list, so the empty batch is built
            # explicitly instead.
            positions = torch.empty(0, dtype=torch.long)
        else:
            positions = torch.cat([torch.arange(0, seqlen) for seqlen in seqlens])
        return cls(positions=positions.to(device=device, dtype=torch.long))
def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int):
    """Repeat every slice of ``keys`` and ``values`` ``repeats`` times along ``dim``.

    Each slice is repeated in place (``a, b`` -> ``a, a, b, b``), not tiled.
    The inputs are not modified.

    Returns:
        A ``(keys, values)`` tuple of the expanded tensors.
    """

    def _expand(tensor: torch.Tensor) -> torch.Tensor:
        return tensor.repeat_interleave(repeats, dim=dim)

    return _expand(keys), _expand(values)
class Attention(nn.Module):
    """Self-attention with rotary embeddings and shared key/value heads.

    Queries use ``n_heads`` heads; keys/values use ``n_kv_heads`` heads. Each
    key/value head is repeated ``n_heads // n_kv_heads`` times (``repeat_kv``
    on the head axis) so every query head has a matching key/value head. The
    attention itself is computed by xformers' ``memory_efficient_attention``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.n_heads: int = args.n_heads
        self.head_dim: int = args.head_dim
        self.n_kv_heads: int = args.n_kv_heads
        # Query heads per key/value head. Integer division — NOTE(review): assumes
        # n_heads is a multiple of n_kv_heads; this is not validated here.
        self.repeats = self.n_heads // self.n_kv_heads
        # 1/sqrt(head_dim). NOTE(review): not used by forward below, which passes
        # no explicit scale to memory_efficient_attention.
        self.scale = self.args.head_dim**-0.5
        # Bias-free projections: dim -> heads * head_dim for q/k/v, and back for wo.
        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        cache: Optional[CacheView],
    ) -> torch.Tensor:
        """Attend over the tokens of ``x``.

        Args:
            x: 2-D tensor of shape (seqlen_sum, args.dim). Presumably all
                sequences of the batch packed along the first axis, as
                Transformer.forward_partial does with input_ids.
            freqs_cis: rotary-embedding factors for the tokens of ``x``, passed
                unchanged to ``apply_rotary_emb``.
            cache: this layer's KV-cache view, or None to attend only over ``x``.
                When given, it is updated in place with the new keys/values.

        Returns:
            Tensor of shape (seqlen_sum, args.dim), the output of ``wo``.
        """
        seqlen_sum, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        # Split the projections into heads: (tokens, heads, head_dim).
        xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
        xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        if cache is None:
            # No cache: keys/values are just those of the current tokens.
            key, val = xk, xv
        elif cache.prefill:
            # Prefill: interleave_kv is called BEFORE update. Presumably it merges
            # previously cached entries with the new ones — TODO confirm in
            # mistral.cache.
            key, val = cache.interleave_kv(xk, xv)
            cache.update(xk, xv)
        else:
            # Decode: write the new keys/values first, then attend over the cache
            # buffer. The view assumes the buffer holds sliding_window slots per
            # row of x (presumably one new token per sequence) — confirm against
            # CacheView.
            cache.update(xk, xv)
            key, val = cache.key, cache.value
            key = key.view(
                seqlen_sum * cache.sliding_window, self.n_kv_heads, self.head_dim
            )
            val = val.view(
                seqlen_sum * cache.sliding_window, self.n_kv_heads, self.head_dim
            )
        # Repeat keys and values to match number of query heads
        key, val = repeat_kv(key, val, self.repeats, dim=1)
        # xformers requires (B=1, S, H, D)
        xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
        # NOTE(review): with cache=None the attention bias is None. Under xformers'
        # attn_bias semantics that means no causal or per-sequence mask: every
        # token of x attends to every other token, including later ones and
        # tokens of other packed sequences. Confirm the cache-less path is never
        # used where causal attention is required.
        output = memory_efficient_attention(
            xq, key, val, None if cache is None else cache.mask
        )
        # Merge heads back to (seqlen_sum, n_heads * head_dim) and project to dim.
        return self.wo(output.view(seqlen_sum, self.n_heads * self.head_dim))
class FeedForward(nn.Module):
    """Gated feed-forward layer computing ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` and ``w3`` both project ``args.dim`` -> ``args.hidden_dim``. The
    SiLU-activated ``w1`` branch multiplies the ``w3`` branch element-wise, and
    ``w2`` projects the product back to ``args.dim``. None of the projections
    has a bias.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        model_dim, inner_dim = args.dim, args.hidden_dim
        self.w1 = nn.Linear(model_dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, model_dim, bias=False)
        self.w3 = nn.Linear(model_dim, inner_dim, bias=False)

    def forward(self, x) -> torch.Tensor:
        gate = nn.functional.silu(self.w1(x))
        projected = self.w3(x)
        return self.w2(gate * projected)
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    Each vector is divided by ``sqrt(mean(x**2) + eps)``, computed in float32.
    The result is cast back to the input dtype and then scaled by ``weight``,
    which starts as all ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        # Normalise in float32 for stability, then return to the caller's dtype.
        normalised = self._norm(x.float()).type_as(x)
        return normalised * self.weight
class TransformerBlock(nn.Module):
    """One pre-norm decoder layer: an attention sub-layer, then a feed-forward one.

    Each sub-layer reads an RMS-normalised copy of its input, and its output
    is added back to that (un-normalised) input. The feed-forward sub-layer
    is a ``MoeLayer`` when ``args.moe`` is set, otherwise a dense
    ``FeedForward``.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.attention = Attention(args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args
        self.feed_forward: nn.Module
        if args.moe is not None:
            # One dense FeedForward per expert, plus a bias-free linear gate that
            # emits one score per expert. How experts are selected and combined
            # is up to MoeLayer (mistral.moe).
            self.feed_forward = MoeLayer(
                experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)],
                gate=nn.Linear(args.dim, args.moe.num_experts, bias=False),
                moe_args=args.moe,
            )
        else:
            self.feed_forward = FeedForward(args=args)

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView]
    ) -> torch.Tensor:
        """Apply attention and feed-forward, each with a residual connection.

        Args:
            x: token activations; the shape is whatever ``Attention`` accepts.
            freqs_cis: rotary-embedding factors forwarded to ``Attention``.
            cache: this layer's KV-cache view, or None.

        Returns:
            Activations with the same shape as ``x``.
        """
        # Invoke the sub-modules themselves (not ``.forward``) so any registered
        # nn.Module hooks run, as they do for the layers called by Transformer.
        h = x + self.attention(self.attention_norm(x), freqs_cis, cache)
        return h + self.feed_forward(self.ffn_norm(h))
class Transformer(nn.Module):
    """Decoder-only transformer, optionally split across pipeline-parallel ranks.

    Each rank keeps a contiguous slice of the layers. Only rank 0 owns the token
    embedding, and only the last rank owns the final norm and output projection.
    Activations are passed between ranks with ``torch.distributed.send``/``recv``,
    and the final logits are broadcast from the last rank to all ranks.
    """

    def __init__(
        self,
        args: ModelArgs,
        pipeline_rank: int = 0,
        num_pipeline_ranks: int = 1,
    ):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        # Built lazily by the ``freqs_cis`` property.
        self._precomputed_freqs_cis: Optional[torch.Tensor] = None
        assert self.vocab_size > 0
        assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks)
        self.pipeline_rank = pipeline_rank
        self.num_pipeline_ranks = num_pipeline_ranks
        # Modules specific to some ranks:
        self.tok_embeddings: Optional[nn.Embedding] = None
        self.norm: Optional[RMSNorm] = None
        self.output: Optional[nn.Linear] = None
        if pipeline_rank == 0:
            self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        if pipeline_rank == num_pipeline_ranks - 1:
            self.norm = RMSNorm(args.dim, eps=args.norm_eps)
            self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        # Initialize all layers but slice off those not of this rank.
        layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
        num_layers_per_rank = math.ceil(self.n_layers / self.num_pipeline_ranks)
        offset = self.pipeline_rank * num_layers_per_rank
        end = min(self.n_layers, offset + num_layers_per_rank)
        # Keys are *global* layer indices, matching checkpoint keys
        # "layers.<i>.<...>" that load_state_dict filters on.
        self.layers = nn.ModuleDict({str(i): layers[i] for i in range(offset, end)})
        self.n_local_layers = len(self.layers)

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of this rank's first parameter."""
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        """Device of this rank's first parameter."""
        return next(self.parameters()).device

    @property
    def freqs_cis(self) -> torch.Tensor:
        """Rotary-embedding table, built on first use and kept on ``self.device``."""
        # We cache freqs_cis but need to take care that it is on the right device
        # and has the right dtype (complex64). The fact that the dtype is different
        # from the module's dtype means we cannot register it as a buffer.
        if self._precomputed_freqs_cis is None:
            # Default theta when unset: 1e6 without a sliding window (a larger
            # sequence length is assumed), 1e4 with one.
            theta = self.args.rope_theta
            if theta is None:
                theta = 1000000.0 if self.args.sliding_window is None else 10000.0
            # 128_000 is passed as precompute_freqs_cis's second argument,
            # presumably the number of positions covered — confirm in mistral.rope.
            self._precomputed_freqs_cis = precompute_freqs_cis(
                self.args.head_dim, 128_000, theta
            )
        if self._precomputed_freqs_cis.device != self.device:
            self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(
                device=self.device
            )
        return self._precomputed_freqs_cis

    def forward_partial(
        self,
        input_ids: torch.Tensor,
        seqlens: List[int],
        cache: Optional[RotatingBufferCache] = None,
    ) -> torch.Tensor:
        """Local forward pass.

        If doing pipeline parallelism, this will return the activations of the
        last layer of this stage (after sending them to the next rank). For the
        last stage, this will return the normalized final embeddings.

        Args:
            input_ids: 1-D tensor of token ids for all sequences packed end to
                end. Only rank 0 reads it; other ranks receive activations from
                the previous rank instead.
            seqlens: length of each packed sequence; must sum to ``len(input_ids)``.
            cache: optional KV cache. When given, it is advanced by ``seqlens``
                after all local layers have run.
        """
        assert (
            len(seqlens) <= self.args.max_batch_size
        ), f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}"
        (num_toks,) = input_ids.shape
        assert sum(seqlens) == num_toks, (sum(seqlens), num_toks)
        if cache is not None:
            input_metadata = cache.get_input_metadata(seqlens)
        else:
            input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device)
        if self.pipeline_rank == 0:
            assert self.tok_embeddings is not None
            h = self.tok_embeddings(input_ids)
        else:
            h = torch.empty(
                num_toks, self.args.dim, device=self.device, dtype=self.dtype
            )
            torch.distributed.recv(h, src=self.pipeline_rank - 1)
        # Select the rotary factors for each token's absolute position.
        freqs_cis = self.freqs_cis[input_metadata.positions]
        for local_layer_id, layer in enumerate(self.layers.values()):
            # Cache views are indexed by *local* layer id on this rank.
            cache_view = (
                None if cache is None else cache.get_view(local_layer_id, input_metadata)
            )
            h = layer(h, freqs_cis, cache_view)
        if cache is not None:
            cache.update_seqlens(seqlens)
        if self.pipeline_rank < self.num_pipeline_ranks - 1:
            torch.distributed.send(h, dst=self.pipeline_rank + 1)
            return h
        # Last rank has a final normalization step.
        assert self.norm is not None
        return self.norm(h)

    def forward(
        self,
        input_ids: torch.Tensor,
        seqlens: List[int],
        cache: Optional[RotatingBufferCache] = None,
    ) -> torch.Tensor:
        """Run the whole model and return float32 logits of shape (num_toks, vocab_size).

        With several pipeline ranks, every rank returns the same logits: the
        last rank computes them and broadcasts them to the others.
        """
        h = self.forward_partial(input_ids, seqlens, cache=cache)
        if self.pipeline_rank < self.num_pipeline_ranks - 1:
            # ignore the intermediate activations as we'll get the final output from
            # the last stage
            outs = torch.empty(
                h.shape[0], self.vocab_size, device=h.device, dtype=h.dtype
            )
        else:
            assert self.output is not None
            outs = self.output(h)
        if self.num_pipeline_ranks > 1:
            torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
        return outs.float()

    def _owns_key(self, key: str) -> bool:
        """Return whether checkpoint entry ``key`` belongs to this pipeline rank.

        Raises:
            ValueError: if ``key`` has none of the expected prefixes.
        """
        if key.startswith("tok_embeddings"):
            return self.pipeline_rank == 0
        if key.startswith(("norm", "output")):
            return self.pipeline_rank == self.num_pipeline_ranks - 1
        if key.startswith("layers"):
            # "layers.<global_id>.<...>"; self.layers is keyed by global id.
            return key.split(".")[1] in self.layers
        raise ValueError(f"Unexpected key {key}")

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load only the entries of ``state_dict`` that belong to this rank.

        Entries owned by other pipeline ranks are skipped (logged at DEBUG level).
        Remaining ``args``/``kwargs`` (e.g. ``strict``, ``assign``) are forwarded
        to ``nn.Module.load_state_dict``.

        Returns:
            The result of ``nn.Module.load_state_dict`` for the filtered dict
            (its missing/unexpected keys).

        Raises:
            ValueError: if a key has an unexpected prefix.
        """
        state_to_load = {}
        skipped = set()
        for k, v in state_dict.items():
            if self._owns_key(k):
                state_to_load[k] = v
            else:
                logging.debug(
                    "Skipping parameter %s at pipeline rank %d",
                    k,
                    self.pipeline_rank,
                )
                skipped.add(k)
        assert set(state_dict.keys()) == skipped.union(set(state_to_load.keys()))
        return super().load_state_dict(state_to_load, *args, **kwargs)

    @staticmethod
    def from_folder(
        folder: Path,
        max_batch_size: int = 1,
        num_pipeline_ranks: int = 1,
        device="cuda",
        dtype=torch.float16,
    ) -> "Transformer":
        """Build a model from ``params.json`` and ``consolidated.00.pth`` in ``folder``.

        The model is created on the meta device. This rank's weights are
        assigned from the checkpoint, and the model is then moved to
        ``device``/``dtype``.

        Args:
            folder: directory containing the files above (``str`` is accepted too).
            max_batch_size: stored in ``ModelArgs.max_batch_size``.
            num_pipeline_ranks: when > 1, this process's pipeline rank is taken
                from ``torch.distributed.get_rank()`` — NOTE(review): assumes
                global rank == pipeline rank.
            device: target device for the loaded model.
            dtype: target dtype for the loaded model.
        """
        folder = Path(folder)
        with open(folder / "params.json", "r", encoding="utf-8") as f:
            model_args = ModelArgs.from_dict(json.load(f))
        model_args.max_batch_size = max_batch_size
        if num_pipeline_ranks > 1:
            pipeline_rank = torch.distributed.get_rank()
        else:
            pipeline_rank = 0
        with torch.device("meta"):
            model = Transformer(
                model_args,
                pipeline_rank=pipeline_rank,
                num_pipeline_ranks=num_pipeline_ranks,
            )
        # NOTE(security): torch.load unpickles the checkpoint. Only load trusted
        # files, or consider weights_only=True for checkpoints of unknown origin.
        loaded = torch.load(str(folder / "consolidated.00.pth"), mmap=True)
        model.load_state_dict(loaded, assign=True)
        return model.to(device=device, dtype=dtype)