Skip to content

Commit

Permalink
use GRUCell instead of GRU
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 23, 2020
1 parent 272cab9 commit cb3a182
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'slot_attention',
packages = find_packages(),
version = '0.0.6',
version = '0.1.0',
license='MIT',
description = 'Implementation of Slot Attention in Pytorch',
author = 'Phil Wang',
Expand Down
8 changes: 4 additions & 4 deletions slot_attention/slot_attention.py
Expand Up @@ -16,7 +16,7 @@ def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
self.to_k = nn.Linear(dim, dim, bias = False)
self.to_v = nn.Linear(dim, dim, bias = False)

self.gru = nn.GRU(dim, dim)
self.gru = nn.GRUCell(dim, dim)

hidden_dim = max(dim, hidden_dim)

Expand Down Expand Up @@ -53,9 +53,9 @@ def forward(self, inputs, num_slots = None):

updates = torch.einsum('bjd,bij->bid', v, attn)

slots, _ = self.gru(
updates.reshape(1, -1, d),
slots_prev.reshape(1, -1, d)
slots = self.gru(
updates.reshape(-1, d),
slots_prev.reshape(-1, d)
)

slots = slots.reshape(b, -1, d)
Expand Down
8 changes: 4 additions & 4 deletions slot_attention/slot_attention_experimental.py
Expand Up @@ -42,17 +42,17 @@ def forward(self, x):
class GatedResidual(nn.Module):
    """Gated residual connection: merges the output of `fn` back into its
    first input through a per-token GRU cell, instead of a plain addition.

    Uses nn.GRUCell (single step, 2-D input) rather than nn.GRU (sequence
    module), since only one recurrent update is performed per call.
    """
    def __init__(self, dim, fn):
        super().__init__()
        self.gru = nn.GRUCell(dim, dim)
        self.fn = fn

    def forward(self, *args):
        # First positional argument is the residual stream; any extra args
        # are forwarded untouched to the wrapped function.
        inputs = args[0]
        b, _, d = inputs.shape

        updates = self.fn(*args)

        # GRUCell expects 2-D (batch, dim) tensors, so fold the batch and
        # token dimensions together: updates act as the cell input, the
        # original inputs act as the previous hidden state.
        inputs = self.gru(
            updates.reshape(-1, d),
            inputs.reshape(-1, d)
        )
        # Restore the (batch, tokens, dim) shape.
        return inputs.reshape(b, -1, d)

Expand Down

0 comments on commit cb3a182

Please sign in to comment.