-
Notifications
You must be signed in to change notification settings - Fork 25.2k
/
modeling_gptsan_japanese.py
1354 lines (1149 loc) · 65.6 KB
/
modeling_gptsan_japanese.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2023 Toshiyuki Sakamoto(tanreinama) and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch GPTSANJapanese model."""
import copy
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ...activations import ACT2FN
from ...modeling_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPastAndCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...utils import (
DUMMY_INPUTS,
DUMMY_MASK,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_torch_fx_proxy,
logging,
)
from .configuration_gptsan_japanese import GPTSanJapaneseConfig
# Module-level logger, created through the library's `logging` helper and keyed by this module's name.
logger = logging.get_logger(__name__)
# Config class name and reference checkpoint id kept as strings; presumably consumed by the docstring
# decorators imported above (their use is not visible in this part of the file) - TODO confirm.
_CONFIG_FOR_DOC = "GPTSanJapaneseConfig"
_CHECKPOINT_FOR_DOC = "Tanrei/GPTSAN-japanese"
####################################################
# This list contains the ids of the pretrained
# weights provided with the models
####################################################
GPTSAN_JAPANESE_PRETRAINED_MODEL_ARCHIVE_LIST = [
"Tanrei/GPTSAN-japanese",
# See all GPTSAN-japanese models at https://huggingface.co/models?filter=gptsan-japanese
]
# Copied from transformers.models.switch_transformers.modeling_switch_transformers.router_z_loss_func
def router_z_loss_func(router_logits: torch.Tensor) -> float:
    r"""
    Compute the router z-loss for a batch of router logits.

    The z-loss, introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906),
    penalises large router logits: it is the squared log-sum-exp taken over the expert axis, averaged over every
    token of every group.

    Args:
        router_logits (`torch.Tensor`):
            Router logits of shape [num_groups, tokens_per_group, num_experts]. The tensor must be 3-dimensional,
            otherwise unpacking its shape raises a `ValueError`.

    Returns:
        Scalar tensor holding the router z-loss.
    """
    group_count, group_size, _ = router_logits.shape
    token_count = group_count * group_size
    # log(sum(exp(logits))) per token, squared.
    squared_log_partition = torch.logsumexp(router_logits, dim=-1).pow(2)
    return squared_log_partition.sum() / token_count
# Copied from transformers.models.switch_transformers.modeling_switch_transformers.load_balancing_loss_func
def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
    r"""
    Compute the auxiliary load balancing loss of the Switch Transformer, implemented in PyTorch.

    This follows equations (4) - (6) of the Switch Transformer paper (https://arxiv.org/abs/2101.03961): the loss
    grows when the fraction of tokens dispatched to each expert and the mean router probability of each expert are
    jointly unbalanced.

    Args:
        router_probs (`torch.Tensor`):
            Probability assigned to each expert per token. Shape: [batch_size, sequence_length, num_experts].
        expert_indices (`torch.Tensor`):
            Indices tensor of shape [batch_size, sequence_length] identifying the selected expert for a given
            token. A tensor that already carries a third axis is used as is.

    Returns:
        The auxiliary loss, as a scalar tensor.
    """
    expert_count = router_probs.shape[-1]

    # `one_hot` only accepts int64 indices; `.to` is a no-op when the dtype already matches.
    indices = expert_indices.to(torch.int64)
    if indices.dim() == 2:
        indices = indices.unsqueeze(2)

    # One-hot over experts, then reduce the inserted axis: 1 where a token was routed to a given expert.
    routed = torch.nn.functional.one_hot(indices, expert_count).max(dim=-2).values
    # `mean` is not defined for integer tensors.
    routed = routed.to(torch.float32)

    token_fraction_per_expert = routed.mean(dim=-2)
    prob_fraction_per_expert = router_probs.mean(dim=-2)
    return torch.mean(token_fraction_per_expert * prob_fraction_per_expert) * (expert_count**2)
class GPTSanJapaneseDenseActDense(nn.Module):
    """
    Two-layer feed-forward network shared by the Switch Transformer layers and the extra layers.

    GPTSAN can mix Switch Transformer layers and normal Transformer layers. This class serves both as the expert of
    a Switch Transformer layer and as the FFN of a regular Transformer layer; `ext_layer` selects the variant:

    - `ext_layer=False`: inner width `config.d_ff`, no biases, ReLU activation, dropout with `config.dropout_rate`.
    - `ext_layer=True`: inner width `config.d_ext`, biases enabled, Swish activation, no dropout.
    """

    def __init__(self, config: GPTSanJapaneseConfig, ext_layer=False):
        super().__init__()
        if ext_layer:
            inner_dim = config.d_ext
            activation_name = "swish"
            dropout = nn.Identity()
        else:
            inner_dim = config.d_ff
            activation_name = "relu"
            dropout = nn.Dropout(config.dropout_rate)

        self.wi = nn.Linear(config.d_model, inner_dim, bias=ext_layer)
        self.wo = nn.Linear(inner_dim, config.d_model, bias=ext_layer)
        self.dropout = dropout
        self.act = ACT2FN[activation_name]

    def forward(self, hidden_states):
        r"""
        Apply projection -> activation -> dropout -> projection.

        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.

        Returns:
            torch.Tensor[num_groups, tokens_per_group, hidden_dim]
        """
        expanded = self.act(self.wi(hidden_states))
        return self.wo(self.dropout(expanded))
# Copied from transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router with SwitchTransformers->GPTSanJapanese
class GPTSanJapaneseTop1Router(nn.Module):
    """
    Router using tokens choose top-1 experts assignment.

    This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE
    (https://arxiv.org/abs/2106.05974): every token picks the expert with the highest router probability. Tokens are
    then admitted to their expert in sequence order until that expert's `expert_capacity` is reached. **There is no
    guarantee that each token is processed by an expert**, or that each expert receives at least one token.
    """

    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.expert_capacity = config.expert_capacity
        self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias)
        self.jitter_noise = config.router_jitter_noise
        self.ignore_padding_tokens = config.router_ignore_padding_tokens
        # `config.router_dtype` is the name of a torch dtype attribute (resolved with `getattr(torch, ...)`).
        self.dtype = getattr(torch, config.router_dtype)

    def _compute_router_probabilities(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Computes router probabilities from input hidden states.

        Args:
            hidden_states (`torch.Tensor`):
                (batch_size, sequence_length, hidden_dim) from which router probabilities are computed. The tensor
                is never modified in place.

        Returns:
            router_probabilities (`torch.Tensor`):
                Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each
                token and expert, cast back to the dtype of `hidden_states`. Used for routing tokens to experts.
            router_logits (`torch.Tensor`):
                Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits,
                in the router dtype. This is used later for computing router z-loss.
        """
        # The router runs in `self.dtype` for stability (see the discussion of "selective precision" in
        # https://arxiv.org/abs/2101.03961). The incoming dtype is remembered to cast the probabilities back.
        self.input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(self.dtype)

        # Jitter is a training-time regulariser only; inference must stay deterministic.
        if self.training and self.jitter_noise > 0:
            # Draw multiplicative noise uniformly between (1 - jitter_noise) and (1 + jitter_noise).
            distrib_lower_bound = 1.0 - self.jitter_noise
            distrib_upper_bound = 1.0 + self.jitter_noise
            uniform_distrib = torch.rand(hidden_states.shape, device=hidden_states.device, dtype=self.dtype)
            uniform_distrib = uniform_distrib * (distrib_lower_bound - distrib_upper_bound) + distrib_upper_bound
            # Multiply out of place: `.to()` returns the very same tensor when the dtype already matches, so an
            # in-place `*=` would leak the noise into the caller's hidden states (which are later fed to the experts).
            hidden_states = hidden_states * uniform_distrib

        # Shape: [num_groups, tokens_per_group, num_experts]
        self._cast_classifier()
        router_logits = self.classifier(hidden_states)

        # Apply Softmax and cast back to the original `dtype`
        router_probabilities = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype)
        return router_probabilities, router_logits

    def _cast_classifier(self):
        r"""
        Cast the classifier to the router dtype, unless it exposes the `SCB` or `CB` attribute.

        `bitsandbytes` `Linear8bitLt` layers do not support manual casting, so they are detected through those
        special attributes and left untouched.
        """
        if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")):
            self.classifier = self.classifier.to(self.dtype)

    def forward(self, hidden_states: torch.Tensor) -> Tuple:
        r"""
        Route every token to its top-1 expert, dropping tokens that exceed the expert capacity.

        The router probabilities and logits are computed by `_compute_router_probabilities`. Each token is assigned
        to the expert with the highest probability; a running count along the sequence axis then masks out every
        token that arrives after its expert already holds `expert_capacity` tokens.

        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.

        Returns:
            Tuple[`torch.Tensor`, `torch.Tensor`, `torch.Tensor`] Tuple containing the one-hot expert assignment
            (all zeros for a token dropped by the capacity limit), the highest router probability of each token with
            a trailing axis of size 1, and the router logits. The router probabilities and logits are required to
            compute the loss.
        """
        router_probs, router_logits = self._compute_router_probabilities(hidden_states)

        expert_index = torch.argmax(router_probs, dim=-1)
        expert_index = torch.nn.functional.one_hot(expert_index, num_classes=self.num_experts)

        # Running number of tokens assigned to each expert along the sequence axis.
        token_priority = torch.cumsum(expert_index, dim=-2)
        # Keep a token only while its expert has not overflowed.
        expert_capacity_mask = token_priority <= self.expert_capacity
        expert_index = expert_index * expert_capacity_mask

        router_probs = torch.max(router_probs, dim=-1).values.unsqueeze(-1)
        return expert_index, router_probs, router_logits
# Copied from transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP with SwitchTransformers->GPTSanJapanese
class GPTSanJapaneseSparseMLP(nn.Module):
    r"""
    Mixture-of-experts MLP: a top-1 router in front of `config.num_experts` expert networks.
    """

    def __init__(self, config: GPTSanJapaneseConfig, expert_class: nn.Module = GPTSanJapaneseDenseActDense):
        super().__init__()
        # The router decides which expert handles each token.
        self.router = GPTSanJapaneseTop1Router(config)

        # Experts are registered under the keys "expert_0", "expert_1", ...
        self.experts = nn.ModuleDict()
        for expert_id in range(config.num_experts):
            self.experts[f"expert_{expert_id}"] = expert_class(config)

    def forward(self, hidden_states):
        r"""
        Run the router, dispatch each token to its expert and scale the result by the router probability.

        1- The router returns a mask of shape `(batch_size, sequence_length, num_expert)` marking the expert chosen
        for each token, together with the probability of that choice and the raw logits.

        2- Each expert processes the tokens flagged for it. Tokens the router assigned to no expert keep their
        incoming hidden state, which is why the output buffer starts as a clone of the input.

        3- The result is multiplied by the router probabilities (broadcast over the hidden axis).

        Returns:
            A tuple `(hidden_states, (router_logits, expert_index))`.
        """
        routing_mask, routing_probs, routing_logits = self.router(hidden_states)
        chosen_expert = torch.argmax(routing_mask, dim=-1)

        dispatched = hidden_states.clone()
        for slot, expert in enumerate(self.experts.values()):
            selected = routing_mask[:, :, slot].bool()
            dispatched[selected] = expert(hidden_states[selected])

        return routing_probs * dispatched, (routing_logits, chosen_expert)
class GPTSanJapaneseLayerSparseFF(nn.Module):
    r"""
    Sparse feed-forward layer: a Mixture of Experts MLP plus a bias-free linear "soft bypass", followed by layer
    normalisation and a residual connection.

    Parameters:
        config : ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__()
        self.mlp = GPTSanJapaneseSparseMLP(config)
        self.soft_bypass_mlp = nn.Linear(config.d_model, config.d_model, bias=False)
        self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states, output_router_logits):
        r"""
        Compute `hidden_states + norm(mlp(hidden_states) + tanh(soft_bypass_mlp(hidden_states)))`.

        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.
            output_router_logits (`bool`) :
                When truthy, also return the router tuple produced by the experts module.

        Returns:
            torch.Tensor[num_groups, tokens_per_group, hidden_dim], or a `(tensor, router_tuple)` pair when the
            router output is requested.
        """
        expert_states, router_tuple = self.mlp(hidden_states)
        bypass_states = torch.tanh(self.soft_bypass_mlp(hidden_states))
        output = hidden_states + self.norm(expert_states + bypass_states)

        if not output_router_logits or router_tuple is None:
            return output
        return output, router_tuple
class GPTSanJapaneseLayerDenseFF(nn.Module):
    r"""
    Dense feed-forward layer used by the extra Transformer layers: an FFN followed by layer normalisation and a
    residual connection.

    Parameters:
        config : ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    def __init__(self, config: GPTSanJapaneseConfig):
        super().__init__()
        # This layer is always dense: the FFN is built in its "extra layer" variant.
        self.mlp = GPTSanJapaneseDenseActDense(config, ext_layer=True)
        self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states):
        r"""
        Compute `hidden_states + norm(mlp(hidden_states))`.

        Args:
            hidden_states (`torch.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to the feed-forward network.

        Returns:
            torch.Tensor[num_groups, tokens_per_group, hidden_dim]
        """
        normalized = self.norm(self.mlp(hidden_states))
        return hidden_states + normalized
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->GPTSanJapanese
class GPTSanJapaneseAttention(nn.Module):
    """Multi-headed scaled dot-product attention, as in the 'Attention Is All You Need' paper."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        # The embedding must split evenly across the heads.
        if self.head_dim * num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split the channel axis into heads: (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Input shape: Batch x Time x Channel

        Returns the attention output, the per-head attention weights (`None` unless `output_attentions` is set) and
        the key/value cache (refreshed only when the module was built with `is_decoder=True`).
        """
        bsz, tgt_len, _ = hidden_states.size()

        # Supplying `key_value_states` turns this layer into a cross-attention layer.
        is_cross_attention = key_value_states is not None

        query_states = self.q_proj(hidden_states) * self.scaling

        # A cached cross-attention projection is reusable only when its sequence length matches the provided
        # `key_value_states` (this keeps prefix tuning working).
        reuse_cross_cache = (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        )
        if reuse_cross_cache:
            key_states, value_states = past_key_value[0], past_key_value[1]
        else:
            kv_source = key_value_states if is_cross_attention else hidden_states
            key_states = self._shape(self.k_proj(kv_source), -1, bsz)
            value_states = self._shape(self.v_proj(kv_source), -1, bsz)
            if not is_cross_attention and past_key_value is not None:
                # Self-attention with a cache: append the new positions to the cached ones.
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

        if self.is_decoder:
            # Hand the (possibly extended) keys/values back so that a later call can either reuse them directly
            # (cross-attention) or concatenate new positions to them (uni-directional self-attention).
            # When the module is not a decoder, `past_key_value` is returned exactly as it was received.
            past_key_value = (key_states, value_states)

        merged_heads_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*merged_heads_shape)
        key_states = key_states.reshape(*merged_heads_shape)
        value_states = value_states.reshape(*merged_heads_shape)

        src_len = key_states.size(1)
        merged_weights_shape = (bsz * self.num_heads, tgt_len, src_len)
        split_weights_shape = (bsz, self.num_heads, tgt_len, src_len)

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        if attn_weights.size() != merged_weights_shape:
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # The mask is additive and broadcast over the head axis.
            attn_weights = (attn_weights.view(*split_weights_shape) + attention_mask).view(*merged_weights_shape)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            per_head_scale = layer_head_mask.view(1, -1, 1, 1)
            attn_weights = (per_head_scale * attn_weights.view(*split_weights_shape)).view(*merged_weights_shape)

        attn_weights_reshaped = None
        if output_attentions:
            # Slightly awkward, but needed so that the returned weights keep their gradient: the 4-D view is
            # created first and the tensor used below is derived from it.
            attn_weights_reshaped = attn_weights.view(*split_weights_shape)
            attn_weights = attn_weights_reshaped.view(*merged_weights_shape)

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2)
        # Use the stored `embed_dim` rather than the size of `hidden_states`, because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = self.out_proj(attn_output.reshape(bsz, tgt_len, self.embed_dim))

        return attn_output, attn_weights_reshaped, past_key_value
class GPTSanJapaneseLayerSelfAttention(nn.Module):
    """
    Self Attention and Normalization Unit
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.self_attn = GPTSanJapaneseAttention(
            embed_dim=config.d_model,
            num_heads=config.num_heads,
            is_decoder=True,
            # `has_relative_attention_bias` is forwarded as the `bias` flag of the attention projections.
            bias=has_relative_attention_bias,
        )
        self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        r"""
        Self-attention and normalize block: computes `hidden_states + norm(self_attn(hidden_states))`.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Input hidden states of the layer.
            past_key_value (`tuple(torch.FloatTensor)`, *optional*):
                Cached key and value states of this layer. Only the first two entries are used; the new positions
                are concatenated to them inside the attention module.
            attention_mask (`torch.FloatTensor`, *optional*):
                Multiplicative mask with values in `[0, 1]`:

                - 1 for positions that are **not masked**,
                - 0 for positions that are **masked**.

                It is converted to an additive mask (`0` where attended, the most negative value of the hidden
                state dtype where masked) before being handed to the attention module, which validates that its
                size is `(batch_size, 1, target_length, source_length)`. When `None`, no masking is applied.
            head_mask (`torch.FloatTensor` of shape `(num_heads,)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            use_cache (`bool`, *optional*):
                If set to `True`, the updated key value states are returned and can be used to speed up decoding.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attention weights of this layer.

        Returns:
            Tuple `(hidden,)`, followed by the present key/value pair when `use_cache` is set, followed by the
            attention weights when `output_attentions` is set.
        """
        # The self-attention cache occupies the first two positions of the cached tuple.
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        # Convert the 0/1 mask into an additive one. `attention_mask` defaults to `None`, in which case the
        # conversion is skipped and the attention module runs unmasked.
        additive_mask = None
        if attention_mask is not None:
            additive_mask = (1 - attention_mask) * torch.finfo(hidden_states.dtype).min

        atten_out = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=additive_mask,
            layer_head_mask=head_mask,
            output_attentions=output_attentions,
        )

        attn_weights = (atten_out[1],) if output_attentions else ()
        attention_output = atten_out[0]

        # Residual connection around the normalised attention output.
        hidden = hidden_states + self.norm(attention_output)

        if use_cache:
            outputs = (hidden, atten_out[2])  # hidden, present, (attentions)
        else:
            outputs = (hidden,)  # hidden, (attentions)

        return outputs + attn_weights
class GPTSanJapaneseBlock(nn.Module):
    """
    One GPTSAN transformer block: a self-attention unit followed by a feed-forward unit (dense for extra layers,
    sparse mixture-of-experts otherwise).
    """

    def __init__(self, config, ext_layer=False):
        super().__init__()
        self.self_attn = GPTSanJapaneseLayerSelfAttention(config)
        if ext_layer:
            self.feed_forward = GPTSanJapaneseLayerDenseFF(config)
        else:
            self.feed_forward = GPTSanJapaneseLayerSparseFF(config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        output_router_tuple: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        r"""
        GPTSAN transformer block.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Input hidden states of the block.
            past_key_value (`tuple(torch.FloatTensor)`, *optional*):
                Cached key and value states, forwarded unchanged to the self-attention unit.
            attention_mask (`torch.FloatTensor`, *optional*):
                Mask forwarded unchanged to the self-attention unit. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            head_mask (`torch.FloatTensor`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            use_cache (`bool`, *optional*):
                If set to `True`, the key value states are returned and can be used to speed up decoding.
            output_attentions (`bool`) :
                Whether to return the attention probabilities.
            output_router_tuple:
                Whether to return the experts router logits and expert ids (sparse feed-forward only).

        Returns:
            Tuple `(hidden,)` followed by whatever extra entries the self-attention unit returned, followed by the
            router tuple when the feed-forward unit is sparse and `output_router_tuple` is set.
        """
        attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attended, attn_extras = attn_outputs[0], attn_outputs[1:]

        # Dense feed-forward: there is never a router tuple to report.
        if not isinstance(self.feed_forward, GPTSanJapaneseLayerSparseFF):
            return (self.feed_forward(attended),) + attn_extras

        ff_result = self.feed_forward(attended, output_router_tuple)
        if not output_router_tuple:
            return (ff_result,) + attn_extras

        # The sparse layer returned a (hidden, router_tuple) pair.
        hidden, router_tuple = ff_result
        return (hidden,) + attn_extras + (router_tuple,)
class GPTSanJapanesePreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = GPTSanJapaneseConfig
base_model_prefix = "gptsan_japanese"
supports_gradient_checkpointing = False
_no_split_modules = ["GPTSanJapaneseBlock"]
_skip_keys_device_placement = "past_key_values"
@property
def dummy_inputs(self):
    """Return a small dict with `input_ids` and `attention_mask` tensors built from the library's dummy constants."""
    return {
        "input_ids": torch.tensor(DUMMY_INPUTS),
        "attention_mask": torch.tensor(DUMMY_MASK),
    }
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_factor # Used for testing weights initialization
if isinstance(module, nn.LayerNorm):
module.weight.data.fill_(factor * 1.0)
module.bias.data.zero_()
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=factor * 1.0)
elif isinstance(module, GPTSanJapaneseModel):
# Mesh TensorFlow embeddings initialization
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
module.embed_tokens.weight.data.normal_(mean=0.0, std=factor * 1.0)
module.position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0)
if hasattr(module, "extra_position_embeddings") and module.extra_position_embeddings is not None:
module.extra_position_embeddings.weight.data.normal_(mean=0.0, std=factor * 1.0)
elif isinstance(module, (GPTSanJapaneseModel, GPTSanJapaneseForConditionalGeneration)):
# Mesh TensorFlow embeddings initialization
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
module.final_logits_bias.data.normal_(mean=0.0, std=factor * 1.0)
if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
elif isinstance(module, GPTSanJapaneseDenseActDense):
# Mesh TensorFlow FF initialization
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
# and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.wi, "bias") and module.wi.bias is not None:
module.wi.bias.data.zero_()
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.wo, "bias") and module.wo.bias is not None:
module.wo.bias.data.zero_()
elif isinstance(module, GPTSanJapaneseAttention):
# Multi-headed attention
d_model = self.config.d_model
key_value_proj_dim = self.config.d_model
n_heads = self.config.num_heads
module.k_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
module.v_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
module.q_proj.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
elif isinstance(module, GPTSanJapaneseSparseMLP):
# Mesh TensorFlow attention initialization to avoid scaling before softmax
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
d_model = self.config.d_model
key_value_proj_dim = self.config.d_model
n_heads = self.config.num_heads
module.router.classifier.weight.data.normal_(mean=0.0, std=factor * 1)
for idx in range(self.config.num_experts):
module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (GPTSanJapaneseAttention,)):
module.gradient_checkpointing = value
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
pad_token_id = self.config.pad_token_id
if decoder_start_token_id is None:
raise ValueError(
"self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
"See T5 docs for more information."
)
# shift inputs to the right
if is_torch_fx_proxy(input_ids):
# Item assignment is not supported natively for proxies.
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
else:
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
shifted_input_ids[..., 0] = decoder_start_token_id
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
return shifted_input_ids
# Introductory docstring text for the GPTSAN-japanese model classes; passed to `add_start_docstrings`.
GPTSAN_JAPANESE_START_DOCSTRING = r"""
The [GPTSAN-japanese](https://github.com/tanreinama/GPTSAN) model was proposed in General-purpose Switch transformer
based Japanese language model
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`GPTSanJapaneseConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Argument documentation for the model `forward`; passed to `add_start_docstrings_to_model_forward`.
GPTSAN_JAPANESE_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. GPTSAN-japanese is a model that generates sentence
continuations or predicts tokens at mask positions. Special tokens required for inputs to the model are
automatically appended.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
An input that masks the Prefix part in the Prefix-LM input. Mask values selected in `[0, 1]`:
- 1 for tokens that are **prefix** input,
- 0 for tokens that are **not-prefix** input.
spout (`torch.Tensor` of shape `(batch_size, config.d_spout)`):
This vector is transformed through an 8-layer FFN and can be used instead of `past_key_values`.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
input (see `past_key_values`). This is useful if you want more control over how to convert
`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models.
"""
@add_start_docstrings(
"The bare GPTSAN-japanese Model transformer outputting raw hidden-states without any specific head on top.",
GPTSAN_JAPANESE_START_DOCSTRING,
)
class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel):
def __init__(self, config: GPTSanJapaneseConfig):
    """
    Build the embeddings, transformer blocks and optional `spout` projection described by `config`.

    Args:
        config ([`GPTSanJapaneseConfig`]): Model configuration. A deep copy is stored on `self.config`.
    """
    super().__init__(config)
    self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
    self.config = copy.deepcopy(config)
    self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
    self.last_project = nn.Linear(config.d_model, config.d_model, bias=True)
    self.act = ACT2FN["swish"]

    # Switch layers come first, followed by the extra layers (constructed in that order).
    switch_layers = [GPTSanJapaneseBlock(config) for _ in range(config.num_switch_layers)]
    extra_layers = [GPTSanJapaneseBlock(config, ext_layer=True) for _ in range(config.num_ext_layers)]
    self.blocks = torch.nn.ModuleList(switch_layers + extra_layers)

    # The additional position table is only created when extra layers are configured.
    if config.num_ext_layers > 0:
        self.extra_position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)

    if config.d_spout:
        # Eight bias-free Linear + Tanh stages, then a projection to `num_layers * 2 * d_model` features.
        spout_stack = []
        for _ in range(8):
            spout_stack += [nn.Linear(config.d_spout, config.d_spout, bias=False), nn.Tanh()]
        spout_stack.append(nn.Linear(config.d_spout, config.num_layers * 2 * config.d_model, bias=False))
        self.spout = nn.Sequential(*spout_stack)

    self.post_init()
def get_input_embeddings(self):
    """Return the token embedding module stored in `self.embed_tokens`."""
    token_embeddings = self.embed_tokens
    return token_embeddings
def set_input_embeddings(self, new_embeddings):
    """Replace the token embedding module (`self.embed_tokens`) with `new_embeddings`."""
    setattr(self, "embed_tokens", new_embeddings)
@add_start_docstrings_to_model_forward(GPTSAN_JAPANESE_INPUTS_DOCSTRING)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.FloatTensor] = None,
spout: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
num_precontext: Optional[torch.LongTensor] = None,
) -> Union[MoEModelOutputWithPastAndCrossAttentions, Tuple[torch.FloatTensor]]:
r"""
num_precontext (`torch.LongTensor` of shape `(batch_size,1)`):
length of `hybrid` input tokens in the input. Tokens up to this length refer to both front and back like
BERT, tokens after that refer only to front like GPT. see also:
https://github.com/tanreinama/GPTSAN/blob/main/report/model.md
Returns:
`MoEModelOutputWithPastAndCrossAttentions` or `tuple` if `return_dict` returns
MoEModelOutputWithPastAndCrossAttentions instead of tuple
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
device = self.position_embeddings.weight.device
if input_ids is None:
input_ids = torch.zeros([1, 1]).int().to(device) # dummy for input_ids was None
num_pasts_contexts = 0
num_batch = input_ids.shape[0]
pasts_or_spout_value = None
if past_key_values is not None:
num_pasts_contexts = past_key_values[0][0].shape[2]
elif self.config.d_spout and spout is not None:
# `spout` is a special input vector specific to GPTSAN
# This controls the output by projecting embedded information such as the class of sentences during learning.
# It should be passed instead of the first past_key_value.
# See the original GPTSAN repository for details
num_pasts_contexts += 1
# If there is an attention_mask, increase first one for spout
if self.config.d_spout and spout is not None and attention_mask is not None:
attention_mask_with_spout = torch.ones(num_batch, attention_mask.shape[1] + 1, device=device)
attention_mask_with_spout[:, 1:] -= 1 - attention_mask # 1st token should be spout
attention_mask = attention_mask_with_spout # update attention_mask
if num_precontext is not None:
# `num_precontext` is the number of tokens that refer to each other in prefix-lm
# created per batch, so dimension of num_precontext should be [batch, 1]
if not (
len(num_precontext.shape) == 2 and num_precontext.shape[1] == 1
): # num_precontext Should be [batch,1]
raise ValueError("num_precontext should be [batch, 1] size.")
num_precontext = torch.reshape(num_precontext, [-1])
else:
num_precontext = torch.zeros([num_batch]).int().to(device)
num_input_contexts = input_ids.shape[1]
num_output_contexts = num_input_contexts + num_pasts_contexts
hidden_states = self.embed_tokens(input_ids)
if past_key_values is not None:
pasts_or_spout_value = past_key_values
elif self.config.d_spout and spout is not None:
# Make vector from `spout` of GPTSAN to the same shape as past_key_values
pasts_or_spout_value = self.spout(spout) # projecting `spout` vector
pasts_or_spout_value = torch.reshape(
pasts_or_spout_value,
[
num_batch,
self.config.num_layers,
2,
self.config.num_heads,
num_pasts_contexts,
self.config.d_model // self.config.num_heads,
],
)
pasts_or_spout_value = torch.split(pasts_or_spout_value, [1] * self.config.num_layers, dim=1)
# make same shape as past_key_values
pasts_or_spout_value = tuple(
tuple([b.squeeze(1) for b in torch.split(a.squeeze(1), [1, 1], dim=1)]) for a in pasts_or_spout_value
)
else:
pasts_or_spout_value = [None] * self.config.num_layers
# Token position considering spout and pasts
token_position = torch.arange(num_input_contexts).to(device) + num_pasts_contexts
if attention_mask is None:
attention_mask = torch.ones(num_batch, num_input_contexts, device=device)
# positions for get position_embeddings
gather_position = (
(