In the predict start method, you pass `t` first and then `x_t`, but in the model's `forward`:
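```python
import torch

def forward(self, x, t, **kwargs):
    # Embed the diffusion timestep and project it with an MLP.
    t = self.time_pos_emb(t)
    t = self.mlp(t)
    # One embedding slice per (block, layer): (batch, 1, emb_dim, n_blocks, depth).
    time_embed = t.view(x.size(0), 1, self.emb_dim, self.n_blocks, self.depth)
    # Project the input and add axial positional embeddings.
    x = self.first(x)
    x_embed_axial = x + self.axial_pos_emb(x).type(x.type())
    # x_embed_axial_time = x_embed_axial + time_embed
    h = torch.zeros_like(x_embed_axial)
    for i, block in enumerate(self.transformer_blocks):
        # Re-inject the input embedding at the start of every block.
        h = h + x_embed_axial
        for j, transformer in enumerate(block):
            # Add the (batch, 1, emb_dim) time embedding for block i, layer j;
            # the singleton dimension broadcasts over the second dimension of h.
            h = transformer(h + time_embed[..., i, j])
    h = self.norm(h)
    return self.out(h)
```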
`x` is first and `t` is second.
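A minimal sketch of one way to make this kind of mismatch harder to hit, assuming the fix is simply to call the model with keyword arguments. `ToyDenoiser` and `predict_start` below are hypothetical stand-ins, not the repository's classes: the toy `forward(x, t)` mirrors the signature above and adds a guard that fails loudly when the arguments arrive swapped.

```python
class ToyDenoiser:
    """Hypothetical stand-in for the model; forward(x, t) mirrors the issue's signature."""

    def forward(self, x, t):
        # Guard: t should be a flat list of integer timesteps, one per sample.
        if not all(isinstance(step, int) for step in t):
            raise TypeError("t must be a list of integer timesteps; were x and t swapped?")
        # Guard: x should be a batch of samples (here, lists of floats).
        if not all(isinstance(sample, list) for sample in x):
            raise TypeError("x must be a batch of samples; were x and t swapped?")
        if len(x) != len(t):
            raise ValueError("x and t must have the same batch size")
        # Dummy "prediction": add each sample's timestep to its elements.
        return [[v + step for v in sample] for sample, step in zip(x, t)]

    # Calling the object runs forward, loosely like nn.Module.__call__.
    __call__ = forward


def predict_start(model, x_t, t):
    # Hypothetical caller: keyword arguments bind by name, so this call
    # stays correct regardless of the order in which forward declares x and t.
    return model(x=x_t, t=t)


model = ToyDenoiser()
print(predict_start(model, [[0.0, 1.0], [2.0, 3.0]], [3, 7]))
```

With a PyTorch `nn.Module`, the equivalent is calling `self.model(x=x_t, t=t)` (or `self.model(x_t, t)` positionally), since `nn.Module.__call__` forwards both positional and keyword arguments to `forward`.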