diff --git a/isab_pytorch/isab_pytorch.py b/isab_pytorch/isab_pytorch.py
index 6000055..6aadf7d 100644
--- a/isab_pytorch/isab_pytorch.py
+++ b/isab_pytorch/isab_pytorch.py
@@ -18,6 +18,7 @@ def __init__(self, dim, heads = 8):
         self.scale = (dim // heads) ** -0.5
         self.to_q = nn.Linear(dim, dim, bias = False)
         self.to_kv = nn.Linear(dim, dim * 2, bias = False)
+        self.to_out = nn.Linear(dim, dim)

     def forward(self, x, context, mask = None):
         h, scale = self.heads, self.scale
@@ -36,7 +37,7 @@ def forward(self, x, context, mask = None):

         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)', h = h)
-        return out
+        return self.to_out(out)

 class ISAB(nn.Module):
     def __init__(
diff --git a/setup.py b/setup.py
index 96571f6..a449a14 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'isab-pytorch',
   packages = find_packages(),
-  version = '0.0.3',
+  version = '0.1.0',
   license='MIT',
   description = 'Induced Set Attention Block - Pytorch',
   author = 'Phil Wang',
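
For context (not part of the patch): the change adds a learned output projection applied after the heads are re-concatenated, the standard final linear layer of multi-head attention. Previously the per-head outputs were only concatenated and returned, never mixed. Since this alters the module's parameter set (checkpoints trained against 0.0.3 would no longer load), the version is bumped to 0.1.0. Below is a minimal, self-contained sketch of the patched attention block; it mirrors the hunks above but fills in the unchanged lines from surrounding context and omits the mask handling, so treat it as an illustrative reconstruction rather than the repo's exact code.

```python
# Sketch of the attention block after this patch (mask handling omitted).
import torch
from torch import nn, einsum
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.heads = heads
        self.scale = (dim // heads) ** -0.5
        self.to_q = nn.Linear(dim, dim, bias = False)
        self.to_kv = nn.Linear(dim, dim * 2, bias = False)
        self.to_out = nn.Linear(dim, dim)  # new in this patch: mixes information across heads

    def forward(self, x, context):
        h, scale = self.heads, self.scale

        # project queries from x, keys/values from the context set
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # scaled dot-product attention per head
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale
        attn = dots.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)  # previously `return out`: heads were concatenated but not mixed

# Illustrative shapes (assumptions, not from the diff): a batch of 2 sets of
# 64 elements attending to 16 context elements, e.g. induced points in ISAB.
x = torch.randn(2, 64, 512)
context = torch.randn(2, 16, 512)
print(Attention(512)(x, context).shape)  # torch.Size([2, 64, 512])
```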