diff --git a/rela_transformer/rela_transformer.py b/rela_transformer/rela_transformer.py
index 4a8a47d..3e4b5db 100644
--- a/rela_transformer/rela_transformer.py
+++ b/rela_transformer/rela_transformer.py
@@ -19,13 +19,13 @@ def __init__(
         super().__init__()
         self.scale = dim ** -0.5
         self.eps = eps
-        self.to_gate = nn.Linear(dim, dim, bias = False)
+        self.w = nn.Parameter(torch.ones(dim))
         self.g = nn.Parameter(torch.ones(dim))

     def forward(self, x):
         norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
         normed_x = x / norm.clamp(min = self.eps) * self.g
-        return normed_x * self.to_gate(x).sigmoid()
+        return normed_x * (x * self.w).sigmoid()

 def FeedForward(dim, mult = 4):
     return nn.Sequential(
@@ -59,9 +59,9 @@ def __init__(
         self.mem_k = nn.Parameter(torch.randn(num_memory_kv, inner_dim))
         self.mem_v = nn.Parameter(torch.randn(num_memory_kv, inner_dim))

+        self.norm_values = GatedRMSNorm(dim_head)
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
-            GatedRMSNorm(dim)
         )

     def forward(self, x, mask = None):
@@ -95,6 +95,8 @@ def forward(self, x, mask = None):
             attn = attn.masked_fill(causal_mask, 0.)

         out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = self.norm_values(out)
+
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)

diff --git a/setup.py b/setup.py
index 4d6e132..0bb886a 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'rela-transformer',
   packages = find_packages(exclude=[]),
-  version = '0.0.6',
+  version = '0.0.7',
   license='MIT',
   description = 'ReLA Transformer',
   author = 'Phil Wang',
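
Note: the two rela_transformer.py hunks (a) change GatedRMSNorm from a learned linear gate (self.to_gate) to an elementwise gate parameter (self.w), and (b) move the gated norm off the final output projection and onto the per-head attention values (GatedRMSNorm(dim_head), applied before heads are merged). Below is a minimal, self-contained sketch of the module as it stands after this diff; the eps default and the toy usage at the bottom are assumptions for illustration, not part of the change.

    # Sketch of GatedRMSNorm after this diff. The eps default and the usage
    # example are assumptions; the forward pass mirrors the hunks above.
    import torch
    from torch import nn

    class GatedRMSNorm(nn.Module):
        def __init__(self, dim, eps = 1e-8):
            super().__init__()
            self.scale = dim ** -0.5
            self.eps = eps
            # elementwise gate weight replaces the former nn.Linear(dim, dim) gate
            self.w = nn.Parameter(torch.ones(dim))
            self.g = nn.Parameter(torch.ones(dim))

        def forward(self, x):
            # RMS-normalize over the last dimension, then gate with a sigmoid
            # computed from the unnormalized input
            norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
            normed_x = x / norm.clamp(min = self.eps) * self.g
            return normed_x * (x * self.w).sigmoid()

    # toy usage: normalize per-head attention values of shape (batch, heads, seq, dim_head)
    if __name__ == '__main__':
        dim_head = 64
        norm_values = GatedRMSNorm(dim_head)
        out = torch.randn(2, 8, 128, dim_head)
        print(norm_values(out).shape)  # torch.Size([2, 8, 128, 64])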