add a structured dropout for incoming data being perceived, after seeing the PerAct paper use perceiver io successfully for robotics
lucidrains committed Dec 5, 2022
1 parent abbb5d5 commit 441c6c6
Showing 3 changed files with 35 additions and 3 deletions.
README.md (3 changes: 2 additions & 1 deletion)
@@ -61,7 +61,8 @@ model = PerceiverIO(
     latent_heads = 8, # number of heads for latent self attention, 8
     cross_dim_head = 64, # number of dimensions per cross attention head
     latent_dim_head = 64, # number of dimensions per latent self attention head
-    weight_tie_layers = False # whether to weight tie layers (optional, as indicated in the diagram)
+    weight_tie_layers = False, # whether to weight tie layers (optional, as indicated in the diagram)
+    seq_dropout_prob = 0.2 # fraction of the tokens from the input sequence to dropout (structured dropout, for saving compute and regularizing effects)
 )
 
 seq = torch.randn(1, 512, 32)
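A minimal training-mode sketch of the new flag (the constructor arguments elided above this hunk, such as dim, queries_dim and depth, are assumed to follow the full README example; the exact values here are illustrative):

import torch
from perceiver_pytorch import PerceiverIO

model = PerceiverIO(
    dim = 32,                 # input token dimension, matching seq below
    queries_dim = 32,
    logits_dim = 100,
    depth = 6,
    num_latents = 256,
    latent_dim = 512,
    seq_dropout_prob = 0.2    # new: keep only a random 80% of input tokens while training
)

model.train()                 # the structured dropout is only applied in training mode

seq     = torch.randn(1, 512, 32)
mask    = torch.ones(1, 512, dtype = torch.bool)
queries = torch.randn(128, 32)

logits = model(seq, mask = mask, queries = queries)   # (1, 128, 100)

The output shape is unchanged; only the token sequence fed to the single cross attention is subsampled, which is where the compute saving comes from.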
perceiver_pytorch/perceiver_io.py (32 changes: 31 additions & 1 deletion)
@@ -28,6 +28,28 @@ def cached_fn(*args, _cache = True, **kwargs):
         return cache
     return cached_fn
 
+# structured dropout, more effective than traditional attention dropouts
+
+def dropout_seq(seq, mask, dropout):
+    b, n, *_, device = *seq.shape, seq.device
+    logits = torch.randn(b, n, device = device)
+
+    if exists(mask):
+        logits = logits.masked_fill(~mask, -torch.finfo(logits.dtype).max)
+
+    num_keep = max(1, int((1 - dropout) * n))
+    keep_indices = logits.topk(num_keep, dim = 1).indices
+
+    batch_indices = torch.arange(b, device = device)
+    batch_indices = rearrange(batch_indices, 'b -> b 1')
+
+    seq = seq[batch_indices, keep_indices]
+
+    if exists(mask):
+        mask = mask[batch_indices, keep_indices]
+
+    return seq, mask
+
 # helper classes
 
 class PreNorm(nn.Module):
@@ -117,9 +139,12 @@ def __init__(
         cross_dim_head = 64,
         latent_dim_head = 64,
         weight_tie_layers = False,
-        decoder_ff = False
+        decoder_ff = False,
+        seq_dropout_prob = 0.
     ):
         super().__init__()
+        self.seq_dropout_prob = seq_dropout_prob
+
         self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
 
         self.cross_attend_blocks = nn.ModuleList([
@@ -157,6 +182,11 @@ def forward(

         cross_attn, cross_ff = self.cross_attend_blocks
 
+        # structured dropout (as done in perceiver AR https://arxiv.org/abs/2202.07765)
+
+        if self.training and self.seq_dropout_prob > 0.:
+            data, mask = dropout_seq(data, mask, self.seq_dropout_prob)
+
         # cross attention only happens once for Perceiver IO
 
         x = cross_attn(x, context = data, mask = mask) + x
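To make the effect of dropout_seq concrete, here is a small sketch (assuming the helper is importable from perceiver_pytorch.perceiver_io at this version): random logits score every token, padding positions are pushed to the bottom, the top (1 - dropout) fraction of positions is kept per batch element, and both the sequence and the mask are gathered at those positions.

import torch
from perceiver_pytorch.perceiver_io import dropout_seq   # module-level helper added above

seq  = torch.randn(2, 512, 32)                  # (batch, tokens, dim)
mask = torch.ones(2, 512, dtype = torch.bool)
mask[1, 400:] = False                           # second sequence is padding beyond token 400

# keeps int(0.8 * 512) = 409 randomly chosen positions per sequence; real tokens are
# preferred over padding, and any padding that survives stays False in the returned mask
out, out_mask = dropout_seq(seq, mask, 0.2)
print(out.shape, out_mask.shape)                # torch.Size([2, 409, 32]) torch.Size([2, 409])

Since cross attention cost scales with the number of context tokens, dropping a fraction of them up front both regularizes and saves compute, as the new README comment notes.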
setup.py (3 changes: 2 additions & 1 deletion)
@@ -3,9 +3,10 @@
 setup(
   name = 'perceiver-pytorch',
   packages = find_packages(),
-  version = '0.8.3',
+  version = '0.8.4',
   license='MIT',
   description = 'Perceiver - Pytorch',
+  long_description_content_type = 'text/markdown',
   author = 'Phil Wang',
   author_email = 'lucidrains@gmail.com',
   url = 'https://github.com/lucidrains/perceiver-pytorch',
