diff --git a/setup.py b/setup.py
index 3e84569..4bfc67b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'slot_attention',
   packages = find_packages(),
-  version = '0.0.6',
+  version = '0.1.0',
   license='MIT',
   description = 'Implementation of Slot Attention in Pytorch',
   author = 'Phil Wang',
diff --git a/slot_attention/slot_attention.py b/slot_attention/slot_attention.py
index 3363e4d..aa4dd75 100644
--- a/slot_attention/slot_attention.py
+++ b/slot_attention/slot_attention.py
@@ -16,7 +16,7 @@ def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
         self.to_k = nn.Linear(dim, dim, bias = False)
         self.to_v = nn.Linear(dim, dim, bias = False)

-        self.gru = nn.GRU(dim, dim)
+        self.gru = nn.GRUCell(dim, dim)

         hidden_dim = max(dim, hidden_dim)
@@ -53,9 +53,9 @@ def forward(self, inputs, num_slots = None):

             updates = torch.einsum('bjd,bij->bid', v, attn)

-            slots, _ = self.gru(
-                updates.reshape(1, -1, d),
-                slots_prev.reshape(1, -1, d)
+            slots = self.gru(
+                updates.reshape(-1, d),
+                slots_prev.reshape(-1, d)
             )

             slots = slots.reshape(b, -1, d)
diff --git a/slot_attention/slot_attention_experimental.py b/slot_attention/slot_attention_experimental.py
index 03f7677..8ce42a5 100644
--- a/slot_attention/slot_attention_experimental.py
+++ b/slot_attention/slot_attention_experimental.py
@@ -42,7 +42,7 @@ def forward(self, x):
 class GatedResidual(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
-        self.gru = nn.GRU(dim, dim)
+        self.gru = nn.GRUCell(dim, dim)
         self.fn = fn

     def forward(self, *args):
         inputs = args[0]
@@ -50,9 +50,9 @@ def forward(self, *args):

         updates = self.fn(*args)

-        inputs, _ = self.gru(
-            updates.reshape(1, -1, d),
-            inputs.reshape(1, -1, d)
+        inputs = self.gru(
+            updates.reshape(-1, d),
+            inputs.reshape(-1, d)
         )

         return inputs.reshape(b, -1, d)
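
Note on the change above: nn.GRUCell applies a single GRU step to a 2-D batch of vectors (input and hidden state both shaped (batch, dim)) and returns the new hidden state directly, whereas nn.GRU expects a 3-D sequence tensor and returns an (output, h_n) tuple, which the removed reshape(1, -1, d) calls were working around. A minimal standalone sketch of the per-slot update after this change; the sizes and variable names here are illustrative and not part of the patched files:

import torch
from torch import nn

b, num_slots, d = 4, 5, 64                 # illustrative sizes: batch, slots, slot dimension
updates = torch.randn(b, num_slots, d)     # attention-weighted inputs, one vector per slot
slots_prev = torch.randn(b, num_slots, d)  # slots from the previous iteration

gru = nn.GRUCell(d, d)

# Each slot of each batch element is updated independently with one GRU step:
# flatten (b, num_slots, d) -> (b * num_slots, d), step, then restore the slot axis.
slots = gru(updates.reshape(-1, d), slots_prev.reshape(-1, d))
slots = slots.reshape(b, num_slots, d)
print(slots.shape)                         # torch.Size([4, 5, 64])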