from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
- from ..attention import Attention, FeedForward
- from ..embeddings import PatchEmbed, MochiAttentionPool, TimestepEmbedding, Timesteps
+ from ..attention import Attention, FeedForward, JointAttnProcessor2_0
+ from ..embeddings import PatchEmbed, MochiCombinedTimestepCaptionEmbedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
- from ..normalization import MochiRMSNormZero, RMSNorm
+ from ..normalization import AdaLayerNormContinuous, MochiRMSNormZero, RMSNorm


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -38,61 +38,73 @@ def __init__(
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
-         caption_dim: int,
+         pooled_projection_dim: int,
+         qk_norm: str = "rms_norm",
        activation_fn: str = "swiglu",
-         update_captions: bool = True,
+         context_pre_only: bool = True,
    ) -> None:
        super().__init__()

-         self.update_captions = update_captions
+         self.context_pre_only = context_pre_only

        self.norm1 = MochiRMSNormZero(dim, 4 * dim)

-         if update_captions:
-             self.norm_context1 = MochiRMSNormZero(dim, 4 * caption_dim)
+         if not context_pre_only:
+             self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim)
        else:
-             self.norm_context1 = RMSNorm(caption_dim, eps=1e-5, elementwise_affine=False)
+             self.norm1_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)

        self.attn = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            out_dim=4 * dim,
-             qk_norm="rms_norm",
-             eps=1e-5,
-             elementwise_affine=False,
-         )
-         self.attn_context = Attention(
-             query_dim=dim,
-             heads=num_attention_heads,
-             attention_head_dim=attention_head_dim,
-             out_dim=4 * caption_dim if update_captions else caption_dim,
-             qk_norm="rms_norm",
-             eps=1e-5,
+             qk_norm=qk_norm,
+             eps=1e-6,
            elementwise_affine=False,
+             processor=JointAttnProcessor2_0(),
        )
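+         # NOTE: `self.attn` replaces the separate `attn_context` module; with
+         # JointAttnProcessor2_0 it attends jointly over the visual and caption
+         # streams and returns one output per stream (see `forward` below).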

+         self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
+         self.norm2_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
+
+         self.norm3 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
+         self.norm3_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
+
        self.ff = FeedForward(dim, mult=4, activation_fn=activation_fn)
-         self.ff_context = FeedForward(caption_dim, mult=4, activation_fn=activation_fn)
+         self.ff_context = FeedForward(pooled_projection_dim, mult=4, activation_fn=activation_fn)
+
+         self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
+         self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
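+         # Attention and feed-forward outputs are RMS-normalized and folded back into
+         # each stream through tanh-gated residuals, with gates and scales derived from `temb`.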
        norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)

-         if self.update_captions:
-             norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm_context1(encoder_hidden_states, temb)
+         if not self.context_pre_only:
+             norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(encoder_hidden_states, temb)
        else:
-             norm_encoder_hidden_states = self.norm_context1(encoder_hidden_states)
+             norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)

-         attn_hidden_states = self.attn(
+         attn_hidden_states, context_attn_hidden_states = self.attn(
            hidden_states=norm_hidden_states,
-             encoder_hidden_states=None,
+             encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )
-         attn_encoder_hidden_states = self.attn_context(
-             hidden_states=norm_encoder_hidden_states,
-             encoder_hidden_states=None,
-             image_rotary_emb=None,
-         )
+
+         hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
+         norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
+         if not self.context_pre_only:
+             encoder_hidden_states = encoder_hidden_states + self.norm2_context(context_attn_hidden_states) * torch.tanh(enc_gate_msa).unsqueeze(1)
+             norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
+
+         ff_output = self.ff(norm_hidden_states)
+         context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+         hidden_states = hidden_states + ff_output * torch.tanh(gate_mlp).unsqueeze(1)
+         if not self.context_pre_only:
+             encoder_hidden_states = encoder_hidden_states + context_ff_output * torch.tanh(enc_gate_mlp).unsqueeze(1)
+
+         return hidden_states, encoder_hidden_states


@maybe_allow_in_graph
@@ -106,32 +118,35 @@ def __init__(
        num_attention_heads: int = 24,
        attention_head_dim: int = 128,
        num_layers: int = 48,
-         caption_dim=1536,
-         mlp_ratio_x=4.0,
-         mlp_ratio_y=4.0,
+         pooled_projection_dim: int = 1536,
        in_channels=12,
-         qk_norm=True,
-         qkv_bias=False,
-         out_bias=True,
+         out_channels: Optional[int] = None,
+         qk_norm: str = "rms_norm",
        timestep_mlp_bias=True,
        timestep_scale=1000.0,
-         text_embed_dim=4096,
+         text_embed_dim: int = 4096,
+         time_embed_dim: int = 256,
        activation_fn: str = "swiglu",
-         max_sequence_length=256,
+         max_sequence_length: int = 256,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
+         out_channels = out_channels or in_channels
+
+         self.time_embed = MochiCombinedTimestepCaptionEmbedding(
+             embedding_dim=text_embed_dim,
+             pooled_projection_dim=pooled_projection_dim,
+             time_embed_dim=time_embed_dim,
+             num_attention_heads=8,
+         )
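+         # Replaces the previous MochiAttentionPool + caption projection pair: the caption
+         # tokens are attention-pooled and fused with the timestep embedding into `temb`,
+         # and a projected caption sequence is returned for the transformer blocks.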

        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
        )

-         self.caption_embedder = MochiAttentionPool(num_attention_heads=8, embed_dim=text_embed_dim, output_dim=inner_dim)
-         self.caption_proj = nn.Linear(text_embed_dim, caption_dim)
-
        self.pos_frequencies = nn.Parameter(
            torch.empty(3, num_attention_heads, attention_head_dim // 2)
        )
@@ -141,9 +156,53 @@ def __init__(
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
-                 caption_dim=caption_dim,
+                 pooled_projection_dim=pooled_projection_dim,
+                 qk_norm=qk_norm,
                activation_fn=activation_fn,
-                 update_captions=i < num_layers - 1,
+                 context_pre_only=i == num_layers - 1,
            )
            for i in range(num_layers)
        ])
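+
+         # Only the final block is built with context_pre_only=True: it still attends over
+         # the caption stream but no longer updates it.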
+
+         self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm")
+         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         timestep: torch.LongTensor,
+         encoder_attention_mask: torch.Tensor,
+         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         return_dict: bool = True,
+     ) -> torch.Tensor:
+         batch_size, num_channels, num_frames, height, width = hidden_states.shape
+         p = self.config.patch_size
+
+         post_patch_height = height // p
+         post_patch_width = width // p
+
+         temb, encoder_hidden_states = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask)
+
+         hidden_states = self.patch_embed(hidden_states)
+
+         for i, block in enumerate(self.transformer_blocks):
+             hidden_states, encoder_hidden_states = block(
+                 hidden_states=hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 temb=temb,
+                 image_rotary_emb=image_rotary_emb,
+             )
+
+         # TODO(aryan): do something with self.pos_frequencies
+
+         hidden_states = self.norm_out(hidden_states, temb)
+         hidden_states = self.proj_out(hidden_states)
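+
+         # Unpatchify: fold each token's (p * p * out_channels) prediction back into the
+         # (batch, channels, frames, height, width) video layout.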
+         hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
+         hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
+         output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
+
+         if not return_dict:
+             return (output,)
+         return Transformer2DModelOutput(sample=output)