Skip to content

Commit

Permalink
fix layer norm in feedforward
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 22, 2023
1 parent 3280b3d commit 4430a87
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions make_a_video_pytorch/make_a_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def forward(self, x):

# layernorm 3d

class LayerNorm(nn.Module):
class ChanLayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.g = nn.Parameter(torch.ones(dim))
self.g = nn.Parameter(torch.ones(dim, 1, 1, 1))

def forward(self, x):
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
Expand Down Expand Up @@ -79,7 +79,7 @@ def __init__(self, dim, mult = 4):
)

self.proj_out = nn.Sequential(
LayerNorm(inner_dim),
ChanLayerNorm(inner_dim),
nn.Conv3d(inner_dim, dim, 1, bias = False)
)

Expand Down Expand Up @@ -175,7 +175,7 @@ def __init__(
self.scale = dim_head ** -0.5
inner_dim = dim_head * heads

self.norm = LayerNorm(dim)
self.norm = nn.LayerNorm(dim)

self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'make-a-video-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.2',
version = '0.2.0',
license='MIT',
description = 'Make-A-Video - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 4430a87

Please sign in to comment.