so Phil - I love your work. I wish you could go a few extra steps to help out users.
I found this class by François-Guillaume @frgfm, which adds in clear math comments:
https://github.com/frgfm/Holocron/blob/bcc3ea19a477e4b28dc5973cdbe92a9b05c690bb/holocron/nn/modules/lambda_layer.py
I want to merge it, but there's a bit of code drift and I don't want to introduce any bugs.
I beseech you to go the extra step to help users bridge from papers to code.
E.g. please articulate return types:
def forward(self, x: torch.Tensor) -> torch.Tensor:
Please give some clarity in the arguments:
# Project input and context to get queries, keys & values
Throw in some maths as a comment / this is great, as it bridges the paper to the code:
B x (num_heads * dim_k) x H x W -> B x num_heads x dim_k x (H * W)
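For reference, here's my reading of the lambdas in the paper that those comments should line up with (my notation, so double-check it; m indexes context positions, n query positions, and the softmax is over m):

\bar{K} = \operatorname{softmax}_m(K), \quad
\lambda^c = \bar{K}^\top V, \quad
\lambda^p_n = E_n^\top V, \quad
y_n = (\lambda^c + \lambda^p_n)^\top q_n

The two einsums for the content and position paths in the code below are exactly these contractions.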
import torch
from torch import nn, einsum
import torch.nn.functional as F
from typing import Optional

__all__ = ['LambdaLayer']


class LambdaLayer(nn.Module):
    """Lambda layer from `"LambdaNetworks: Modeling long-range interactions without attention"
    <https://openreview.net/pdf?id=xTJEN-ggl1b>`_. The implementation was adapted from `lucidrains'
    <https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py>`.

    Args:
        in_channels (int): input channels
        out_channels (int): output channels
        dim_k (int): key dimension
        n (int, optional): number of input pixels
        r (int, optional): receptive field for relative positional encoding
        num_heads (int, optional): number of attention heads
        dim_u (int, optional): intra-depth dimension
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dim_k: int,
        n: Optional[int] = None,
        r: Optional[int] = None,
        num_heads: int = 4,
        dim_u: int = 1
    ) -> None:
        super().__init__()
        self.u = dim_u
        self.num_heads = num_heads
        if out_channels % num_heads != 0:
            raise AssertionError('values dimension must be divisible by number of heads for multi-head query')
        dim_v = out_channels // num_heads

        # Project input and context to get queries, keys & values
        self.to_q = nn.Conv2d(in_channels, dim_k * num_heads, 1, bias=False)
        self.to_k = nn.Conv2d(in_channels, dim_k * dim_u, 1, bias=False)
        self.to_v = nn.Conv2d(in_channels, dim_v * dim_u, 1, bias=False)

        self.norm_q = nn.BatchNorm2d(dim_k * num_heads)
        self.norm_v = nn.BatchNorm2d(dim_v * dim_u)

        self.local_contexts = r is not None
        if r is not None:
            if r % 2 != 1:
                raise AssertionError('Receptive kernel size should be odd')
            self.padding = r // 2
            self.R = nn.Parameter(torch.randn(dim_k, dim_u, 1, r, r))
        else:
            if n is None:
                raise AssertionError('You must specify the total sequence length (h x w)')
            self.pos_emb = nn.Parameter(torch.randn(n, n, dim_k, dim_u))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # einsum letters: b = batch, h = heads, k = dim_k, v = dim_v,
        # u = intra-depth, m / n = flattened pixel positions (H * W)
        b, c, h, w = x.shape

        # Project inputs & context to retrieve queries, keys and values
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        # Normalize queries & values
        q = self.norm_q(q)
        v = self.norm_v(v)

        # B x (num_heads * dim_k) x H x W -> B x num_heads x dim_k x (H * W)
        q = q.reshape(b, self.num_heads, -1, h * w)
        # B x (dim_k * dim_u) x H x W -> B x dim_u x dim_k x (H * W)
        k = k.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)
        # B x (dim_v * dim_u) x H x W -> B x dim_u x dim_v x (H * W)
        v = v.reshape(b, -1, self.u, h * w).permute(0, 2, 1, 3)

        # Normalized keys (softmax over context positions)
        k = k.softmax(dim=-1)

        # Content function
        λc = einsum('b u k m, b u v m -> b k v', k, v)
        Yc = einsum('b h k n, b k v -> b n h v', q, λc)

        # Position function
        if self.local_contexts:
            # B x dim_u x dim_v x (H * W) -> B x dim_u x dim_v x H x W
            v = v.reshape(b, self.u, v.shape[2], h, w)
            λp = F.conv3d(v, self.R, padding=(0, self.padding, self.padding))
            Yp = einsum('b h k n, b k v n -> b n h v', q, λp.flatten(3))
        else:
            λp = einsum('n m k u, b u v m -> b n k v', self.pos_emb, v)
            Yp = einsum('b h k n, b n k v -> b n h v', q, λp)
        Y = Yc + Yp

        # B x (H * W) x num_heads x dim_v -> B x (num_heads * dim_v) x H x W
        out = Y.permute(0, 2, 3, 1).reshape(b, self.num_heads * v.shape[2], h, w)

        return out
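And since I'm worried about code drift, here's the kind of smoke test I'd run before merging (my sketch; the hyperparameters are arbitrary, not from the paper):

# Hypothetical smoke test: the layer should preserve spatial size
# and emit out_channels channels in both positional-encoding modes.
x = torch.randn(2, 32, 8, 8)  # B x C x H x W

# Local-context variant: relative positions via conv3d with receptive field r
local = LambdaLayer(in_channels=32, out_channels=32, dim_k=16, r=7, num_heads=4, dim_u=1)
assert local(x).shape == (2, 32, 8, 8)

# Global variant: learned positional embeddings, n must equal H * W
glob = LambdaLayer(in_channels=32, out_channels=32, dim_k=16, n=8 * 8, num_heads=4, dim_u=1)
assert glob(x).shape == (2, 32, 8, 8)

That would at least pin down the tensor contract while the two implementations get reconciled.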
johndpope changed the title from "Lack of polish on code" to "Please add clarity to code" on Feb 26, 2021.