In [2]:
from math import log
from einops import rearrange

import torch
import torch.nn as nn

class Attention(nn.Module):
    """Multi-head attention layer usable for self- or cross-attention.

    Queries are projected from ``x``; keys and values are projected from
    ``context`` when it is supplied and from ``x`` otherwise. All projections
    are bias-free linear layers.

    Args:
        input_size: size of the last axis of ``x`` / ``context``; also the size
            of the last axis of the returned tensor.
        layer_size: output width of the key and value projections (and of the
            query projection when ``latent_size`` is None). Must be divisible
            by ``heads``.
        latent_size: optional output width of the query projection. Must be
            divisible by ``heads`` when given.
        heads: number of heads the projected features are split into.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self,input_size,layer_size=64,latent_size=None,heads=1,dropout=0.5):
        super().__init__()
        self.heads = heads
        # The query projection (and therefore the output projection's input)
        # uses `latent_size` when it is given, `layer_size` otherwise.
        inner_size = layer_size if latent_size is None else latent_size
        self.Q = nn.Linear(input_size, inner_size, bias=False)
        self.K = nn.Linear(input_size, layer_size, bias=False)
        self.V = nn.Linear(input_size, layer_size, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.output = nn.Sequential(
            nn.Linear(inner_size, input_size, bias=False),
            nn.Dropout(dropout)
        )

    def _split_heads(self, t):
        """Fold the head axis into the batch axis: (b, y, h*x) -> (b*h, y, x)."""
        return rearrange(t, 'b y (h x) -> (b h) y x', h=self.heads)

    def forward(self,x,context=None,mask=None):
        """Apply attention to ``x``.

        Args:
            x: 3-D tensor whose last axis has ``input_size`` entries.
            context: optional 3-D tensor that keys and values are computed
                from. Its second axis must have the same length as ``x``'s,
                because the score einsum contracts over that axis.
            mask: optional mask; after ``unsqueeze(1)`` it is expanded to the
                score shape and the positions it selects are filled with
                -1e18 before the softmax.
                NOTE(review): presumably a boolean tensor of shape
                (batch*heads, query_width_per_head) -- confirm with callers.

        Returns:
            Tensor with the same leading axes as the value tensor and a last
            axis of ``input_size``.
        """
        # Keys/values come from `context` for cross-attention, else from `x`.
        source = x if context is None else context
        q = self._split_heads(self.Q(x))       # Query
        k = self._split_heads(self.K(source))  # Key
        v = self._split_heads(self.V(source))  # Value

        # Scale by the square root of the input feature width. The width is
        # read from `x` itself so the module does not depend on any
        # notebook-level variable being defined.
        scale = x.size(-1) ** 0.5

        # b: batch*heads, y: shared second axis (summed out),
        # q: per-head query width, k: per-head key width.
        z = torch.einsum('b y q, b y k -> b k q', q, k) / scale

        if mask is not None:
            mask_expanded = mask.unsqueeze(1).expand_as(z)
            z = z.masked_fill(mask_expanded, -1e18)

        z = self.softmax(z)
        # Here `y` is the per-head key/value width (summed out), `v` is the
        # second axis of the value tensor and `z` is the per-head query width.
        z = torch.einsum('b y z, b v y -> b v z', z, v)
        # Undo the head split: (b*h, y, x) -> (b, y, h*x).
        z = rearrange(z, '(b h) y x -> b y (h x)', h=self.heads)
        return self.output(z)

# Smoke test: run a random 3x8x4 tensor through Attention and compare sizes.
i = torch.rand(3, 8, 4)
print(i.shape)
print('------------')

attention = Attention(i.shape[-1], layer_size=30, latent_size=10, heads=2)

a_out = attention(i)
print(a_out.shape)

torch.Size([3, 8, 4])
------------
torch.Size([3, 8, 4])


In [3]:
class Transformer(nn.Module):
    """Self-attention followed by a bias-free linear layer, each normalised.

    Args:
        input_size: size of the last axis of the input and of the output.
        layer_size: projection width forwarded to the attention layer.
        heads: number of attention heads.
        dropout: dropout probability used by the attention layer and after
            the linear layer.
    """

    def __init__(self,input_size,layer_size=64,heads=1,dropout=0.5):
        super().__init__()
        self.self_attention = Attention(
            input_size,
            layer_size=layer_size,
            heads=heads,
            dropout=dropout,
        )
        projection = nn.Linear(input_size, input_size, bias=False)
        self.linear = nn.Sequential(projection, nn.Dropout(dropout))

    def forward(self,x):
        """Return ``x`` after attention -> normalise -> linear -> normalise."""
        # `normalize` is called with its defaults (L2 norm along dim=1).
        # NOTE(review): dim=1 is the second axis of the input, not the
        # feature axis -- confirm this is the intended normalisation.
        attended = nn.functional.normalize(self.self_attention(x))
        projected = nn.functional.normalize(self.linear(attended))
        return projected
    
# Smoke test: the Transformer block should return a tensor sized like its input.
transformer = Transformer(i.shape[-1], layer_size=20, heads=5, dropout=0.5)

print(i.shape)
print('---------')
t_out = transformer(i)
print(t_out.shape)

torch.Size([3, 8, 4])
---------
torch.Size([3, 8, 4])


In [4]:
class Perceiver(nn.Module):
    """Alternates one shared cross-attention layer with one shared Transformer.

    The same ``cross_attention`` and ``latent_transformer`` modules are reused
    on every iteration, so their weights are shared across recursions.

    Args:
        input_size: size of the last axis of the input tensor.
        recursions: number of (transformer stack, cross-attention) rounds.
        transformer_blocks: times the transformer is applied per round.
        layer_size: key/value projection width of the cross-attention layer;
            also the transformer's projection width when ``latent_size`` is
            None.
        latent_size: optional query width of the cross-attention layer and,
            when given, the transformer's projection width.
        cross_heads: number of heads in the cross-attention layer.
        self_heads: number of heads in the transformer's attention layer.
        cross_dropout: dropout probability for the cross-attention layer.
        self_dropout: dropout probability for the transformer.
    """

    def __init__(self,input_size,recursions=1,transformer_blocks=1,layer_size=64,latent_size=None,cross_heads=1,self_heads=1,cross_dropout=0.5,self_dropout=0.5):
        super().__init__()
        self.recursions = recursions
        self.transformer_blocks = transformer_blocks
        self.cross_attention = Attention(
            input_size,
            layer_size=layer_size,
            latent_size=latent_size,
            heads=cross_heads,
            dropout=cross_dropout
        )
        # With the default latent_size=None the transformer used to receive
        # None as its layer width and failed while building its linear
        # layers; fall back to `layer_size` in that case.
        transformer_width = layer_size if latent_size is None else latent_size
        self.latent_transformer = Transformer(
            input_size,
            layer_size=transformer_width,
            heads=self_heads,
            dropout=self_dropout
        )

    def forward(self,x):
        """Run the recursion over ``x`` and return the final transformer output."""
        # Initial pass: the cross-attention layer attends over `x` alone.
        z = self.cross_attention(x)
        for _round in range(self.recursions):
            for _block in range(self.transformer_blocks):
                z = self.latent_transformer(z)
            # Re-attend to the original input after each transformer stack.
            z = self.cross_attention(z, context=x)
        return self.latent_transformer(z)

# Smoke test: the Perceiver should return a tensor sized like its input.
perceiver = Perceiver(
    i.shape[-1],
    recursions=5,
    transformer_blocks=3,
    layer_size=20,
    latent_size=10,
    cross_heads=5,
    self_heads=2,
)

print(i.shape)
print('---------')
p_out = perceiver(i)
print(p_out.shape)

torch.Size([3, 8, 4])
---------
torch.Size([3, 8, 4])


In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a tensor, then applies dropout.

    The encoding table is stored as a non-trainable buffer ``pe`` of shape
    (max_len, 1, d_model): even feature columns hold sines and odd columns
    hold cosines of ``position * exp(-log(10000) * 2k / d_model)``.

    Args:
        d_model: size of the feature axis; odd values are supported.
        dropout: dropout probability applied to the summed tensor.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angles to match; unsliced, this assignment raised a shape
        # mismatch error.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model) so it broadcasts over axis 1.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encodings for the first ``x.size(0)`` positions to ``x``.

        The table is indexed by ``x``'s first axis and broadcast over its
        second.
        NOTE(review): this assumes ``x`` is laid out (seq_len, batch, d_model);
        the other modules in this notebook look batch-first -- confirm before
        combining them.
        """
        x = x + self.pe[:x.size(0),:]
        return self.dropout(x)