place sigma parameters in log space, as in official repo
lucidrains committed Jan 30, 2021
1 parent 6ee89b1 commit 94b3390
Showing 3 changed files with 15 additions and 7 deletions.
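
Why this change: the commit replaces the directly-learned slots_sigma with slots_logsigma, so the standard deviation of the slot-initialization Gaussian is learned in log space and exponentiated before use, which keeps sigma strictly positive by construction and matches the official repo referenced in the commit message. Below is a minimal sketch (not part of the commit; batch size, slot count, and dimension are arbitrary) condensing the parameterization that the hunks introduce:

    import torch
    from torch import nn
    from torch.nn import init

    batch, num_slots, dim = 2, 7, 64

    # learnable mean and log-std of the Gaussian that initializes the slots
    slots_mu = nn.Parameter(torch.randn(1, 1, dim))
    slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim))
    init.xavier_uniform_(slots_logsigma)

    # exp() maps the unconstrained parameter back to sigma > 0 before sampling
    mu = slots_mu.expand(batch, num_slots, -1)
    sigma = slots_logsigma.exp().expand(batch, num_slots, -1)
    slots = mu + sigma * torch.randn(mu.shape)
    print(slots.shape)  # torch.Size([2, 7, 64])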
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'slot_attention',
   packages = find_packages(),
-  version = '1.0.2',
+  version = '1.1.0',
   license='MIT',
   description = 'Implementation of Slot Attention in Pytorch',
   author = 'Phil Wang',
10 changes: 7 additions & 3 deletions slot_attention/slot_attention.py
@@ -1,5 +1,6 @@
-from torch import nn
 import torch
+from torch import nn
+from torch.nn import init

 class SlotAttention(nn.Module):
     def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
@@ -10,7 +11,9 @@ def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
         self.scale = dim ** -0.5

         self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
-        self.slots_sigma = nn.Parameter(torch.randn(1, 1, dim) * self.scale)
+
+        self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim))
+        init.xavier_uniform_(self.slots_logsigma)

         self.to_q = nn.Linear(dim, dim)
         self.to_k = nn.Linear(dim, dim)
@@ -35,7 +38,8 @@ def forward(self, inputs, num_slots = None):
         n_s = num_slots if num_slots is not None else self.num_slots

         mu = self.slots_mu.expand(b, n_s, -1)
-        sigma = self.slots_sigma.expand(b, n_s, -1)
+        sigma = self.slots_logsigma.exp().expand(b, n_s, -1)
+
         slots = mu + sigma * torch.randn(mu.shape)

         inputs = self.norm_input(inputs)
10 changes: 7 additions & 3 deletions slot_attention/slot_attention_experimental.py
@@ -1,5 +1,6 @@
-from torch import nn
 import torch
+from torch import nn
+from torch.nn import init

 class WeightedAttention(nn.Module):
     def __init__(self, dim, eps = 1e-8, softmax_dim = 1, weighted_mean_dim = 2):
@@ -82,7 +83,9 @@ def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, hidden_dim = 128):
         self.norm_inputs = nn.LayerNorm(dim)

         self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
-        self.slots_sigma = nn.Parameter(torch.randn(1, 1, dim) * scale)
+
+        self.slots_logsigma = nn.Parameter(torch.zeros(1, 1, dim))
+        init.xavier_uniform_(self.slots_logsigma)

         self.slots_to_inputs_attn = GatedResidual(dim, WeightedAttention(dim, eps = eps))
         self.slots_ff = GatedResidual(dim, FeedForward(dim, hidden_dim))
@@ -95,7 +98,8 @@ def forward(self, inputs, num_slots = None):
         n_s = num_slots if num_slots is not None else self.num_slots

         mu = self.slots_mu.expand(b, n_s, -1)
-        sigma = self.slots_sigma.expand(b, n_s, -1)
+        sigma = self.slots_logsigma.exp().expand(b, n_s, -1)
+
         slots = mu + sigma * torch.randn(mu.shape)

         inputs = self.norm_inputs(inputs)
