Commit 005279f
go with what @cfoster0 says
lucidrains committed Mar 5, 2021
1 parent cf4d8d9 commit 005279f
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions glom_pytorch/glom_pytorch.py
@@ -93,7 +93,9 @@ def forward(self, img, iters = None, return_all = False):
         n = tokens.shape[1]
 
         pos_embs = self.pos_emb(torch.arange(n, device = device))
-        bottom_level = tokens + rearrange(pos_embs, 'n d -> () n d')
+        pos_embs = rearrange(pos_embs, 'n d -> () n () d')
+
+        bottom_level = tokens
         bottom_level = rearrange(bottom_level, 'b n d -> b n () d')
 
         levels = repeat(self.init_levels, 'l d -> b n l d', b = b, n = n)
@@ -104,7 +106,7 @@
         levels_with_input = torch.cat((bottom_level, levels), dim = -2) # each iteration, attach original input (with positional embedding) at the bottom level
 
         bottom_up_out = self.bottom_up(levels_with_input[..., 1:-1, :])
-        top_down_out = self.top_down(levels_with_input[..., 2:, :])
+        top_down_out = self.top_down(levels_with_input[..., 2:, :] + pos_embs) # positional embeddings given to top-down networks
 
         bottom_up_out = torch.cat((bottom_level, bottom_up_out), dim = -2)
         top_down_out = F.pad(top_down_out, (0, 0, 0, 1), value = 0.)
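In effect, the commit stops adding positional embeddings to the bottom-level tokens and instead injects them into the input of the top-down networks. A minimal sketch of the broadcasting involved, with toy dimensions that are assumptions for illustration (not taken from the repository):

```python
import torch
from einops import rearrange

# toy dimensions, assumed for illustration only
b, n, l, d = 2, 16, 6, 32

tokens   = torch.randn(b, n, d)  # per-patch tokens
pos_embs = torch.randn(n, d)     # one embedding per token position

# as in the diff: add singleton batch and level axes so the embeddings
# broadcast over both axes when added to a (b, n, levels, d) tensor
pos_embs = rearrange(pos_embs, 'n d -> () n () d')     # (1, n, 1, d)

bottom_level = rearrange(tokens, 'b n d -> b n () d')  # (b, n, 1, d); no pos emb added here anymore
levels = torch.randn(b, n, l, d)                       # column of level embeddings per token

levels_with_input = torch.cat((bottom_level, levels), dim = -2)  # (b, n, l + 1, d)

# positional embeddings now enter only through the top-down pathway's input,
# broadcasting across the batch and the sliced levels
top_down_in = levels_with_input[..., 2:, :] + pos_embs
print(top_down_in.shape)  # torch.Size([2, 16, 5, 32])
```

The `() n () d` pattern is einops shorthand for inserting size-1 axes, equivalent to `pos_embs[None, :, None, :]`.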
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'glom-pytorch',
   packages = find_packages(),
-  version = '0.0.5',
+  version = '0.0.6',
   license='MIT',
   description = 'Glom - Pytorch',
   author = 'Phil Wang',
