 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn

 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward, JointAttnProcessor2_0
-from ..embeddings import PatchEmbed, MochiCombinedTimestepCaptionEmbedding
+from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, MochiRMSNormZero, RMSNorm
@@ -46,14 +46,14 @@ def __init__(
         super().__init__()

         self.context_pre_only = context_pre_only
-
+
         self.norm1 = MochiRMSNormZero(dim, 4 * dim)

         if context_pre_only:
             self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim)
         else:
             self.norm1_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
-
+
         self.attn = Attention(
             query_dim=dim,
             heads=num_attention_heads,
@@ -67,7 +67,7 @@ def __init__(

         self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
         self.norm2_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
-
+
         self.norm3 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
         self.norm3_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False)

@@ -76,15 +76,23 @@ def __init__(

         self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
         self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False)
-
-    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)

         if self.context_pre_only:
-            norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(encoder_hidden_states, temb)
+            norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
+                encoder_hidden_states, temb
+            )
         else:
             norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
-
+
         attn_hidden_states, context_attn_hidden_states = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
@@ -94,16 +102,20 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens
         hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
         hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
         if not self.context_pre_only:
-            encoder_hidden_states = encoder_hidden_states + self.norm2_context(context_attn_hidden_states) * torch.tanh(enc_gate_msa).unsqueeze(1)
-            encoder_hidden_states = encoder_hidden_states + self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
-
+            encoder_hidden_states = encoder_hidden_states + self.norm2_context(
+                context_attn_hidden_states
+            ) * torch.tanh(enc_gate_msa).unsqueeze(1)
+            encoder_hidden_states = encoder_hidden_states + self.norm3_context(encoder_hidden_states) * (
+                1 + enc_scale_mlp.unsqueeze(1)
+            )
+
         ff_output = self.ff(hidden_states)
         context_ff_output = self.ff_context(encoder_hidden_states)
-
+
         hidden_states = hidden_states + ff_output * torch.tanh(gate_mlp).unsqueeze(1)
         if not self.context_pre_only:
             encoder_hidden_states = encoder_hidden_states + context_ff_output * torch.tanh(enc_gate_mlp).unsqueeze(0)
-
+
         return hidden_states, encoder_hidden_states

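Side note on the gating pattern used throughout this block: each residual branch is rescaled by a per-sample gate squashed through `tanh` and broadcast over the token axis with `unsqueeze(1)`. A minimal, self-contained sketch of that update, using toy shapes and PyTorch's built-in `nn.RMSNorm` (available in PyTorch ≥ 2.4) as a stand-in for the repo's `RMSNorm` helper:

```python
import torch
import torch.nn as nn

# Toy shapes for illustration only: batch of 2, 8 tokens, model dim 16.
hidden_states = torch.randn(2, 8, 16)        # token stream entering the block
attn_hidden_states = torch.randn(2, 8, 16)   # stand-in for the attention branch output
gate_msa = torch.randn(2, 16)                # per-sample gate from the modulation layer

# Non-affine RMSNorm, mirroring elementwise_affine=False in the block above.
norm2 = nn.RMSNorm(16, eps=1e-6, elementwise_affine=False)

# tanh keeps the gate in (-1, 1); unsqueeze(1) broadcasts it over every token.
hidden_states = hidden_states + norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
print(hidden_states.shape)  # torch.Size([2, 8, 16])
```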
@@ -140,33 +152,35 @@ def __init__(
             time_embed_dim=time_embed_dim,
             num_attention_heads=8,
         )
-
+
         self.patch_embed = PatchEmbed(
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=inner_dim,
         )

-        self.pos_frequencies = nn.Parameter(
-            torch.empty(3, num_attention_heads, attention_head_dim // 2)
+        self.pos_frequencies = nn.Parameter(torch.empty(3, num_attention_heads, attention_head_dim // 2))
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                MochiTransformerBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    pooled_projection_dim=pooled_projection_dim,
+                    qk_norm=qk_norm,
+                    activation_fn=activation_fn,
+                    context_pre_only=i < num_layers - 1,
+                )
+                for i in range(num_layers)
+            ]
         )

-        self.transformer_blocks = nn.ModuleList([
-            MochiTransformerBlock(
-                dim=inner_dim,
-                num_attention_heads=num_attention_heads,
-                attention_head_dim=attention_head_dim,
-                pooled_projection_dim=pooled_projection_dim,
-                qk_norm=qk_norm,
-                activation_fn=activation_fn,
-                context_pre_only=i < num_layers - 1,
-            )
-            for i in range(num_layers)
-        ])
-
-        self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm")
+        self.norm_out = AdaLayerNormContinuous(
+            inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm"
+        )
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
-
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -193,13 +207,13 @@ def forward(
                 temb=temb,
                 image_rotary_emb=image_rotary_emb,
             )
-
+
         # TODO(aryan): do something with self.pos_frequencies

         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)

-        hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_height, p, p, -1)
+        hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
         hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
         output = hidden_states.reshape(batch_size, -1, num_frames, height, width)

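The reshape fix above (`post_patch_height, post_patch_width` instead of `post_patch_height` twice) is easy to sanity-check in isolation. A minimal sketch with made-up toy shapes, assuming `proj_out` emits one token per (frame, patch) position with `p * p * out_channels` features:

```python
import torch

# Toy shapes for illustration only.
batch_size, num_frames, out_channels, p = 1, 2, 4, 2
post_patch_height, post_patch_width = 3, 5
height, width = post_patch_height * p, post_patch_width * p

# Stand-in for the proj_out output: (B, T * H' * W', p * p * C).
hidden_states = torch.randn(batch_size, num_frames * post_patch_height * post_patch_width, p * p * out_channels)

hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)  # (B, C, T, H', p, W', p)
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)

print(output.shape)  # torch.Size([1, 4, 2, 6, 10])
```

With the old code the fourth reshape dimension would be `post_patch_height` again, which only happens to work when the latent grid is square.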