$$Attention(Q,K,V) = \mbox{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V$$

$$Q = \mbox{Query vector}$$
$$K = \mbox{Key vector}$$
$$V = \mbox{Value vector}$$
$$d_{k} = \mbox{Size of key vector}$$

<center>
    <img src="img/multi_head_attention.png" alt="Multi Head Attention" width="450" height="250" />
</center>

In [1]:
from math import log
from einops import rearrange

import torch
import torch.nn as nn

class Attention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Queries are projected from ``x``; keys and values are projected from
    ``context`` when it is given, otherwise from ``x`` itself. The result is
    projected back to ``input_size`` and passed through dropout, so the output
    has the same shape as ``x``.

    Args:
        input_size: Size of the last dimension of ``x`` (and of the output).
        context_size: Size of the last dimension of ``context``. ``None``
            means keys/values use ``input_size``.
        layer_size: Total width of the Q/K/V projections, split evenly
            across ``heads``.
        heads: Number of attention heads.
        dropout: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``layer_size``.
    """

    def __init__(self,input_size,context_size=None,layer_size=64,heads=1,dropout=0.5):
        super(Attention, self).__init__()
        if heads < 1 or layer_size % heads != 0:
            raise ValueError(
                "layer_size (%r) must be divisible by a positive number of heads (%r)"
                % (layer_size, heads)
            )
        self.heads = heads
        # Keys and values read from the context stream when one is configured.
        kv_size = input_size if context_size is None else context_size
        self.Q = nn.Linear(input_size, layer_size, bias=False)
        self.K = nn.Linear(kv_size, layer_size, bias=False)
        self.V = nn.Linear(kv_size, layer_size, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.output = nn.Sequential(
            nn.Linear(layer_size, input_size, bias=False),
            nn.Dropout(dropout)
        )

    def _split_heads(self, t):
        """Reshape (batch, length, heads*dim) into (batch*heads, length, dim)."""
        batch, length, width = t.shape
        head_dim = width // self.heads
        t = t.reshape(batch, length, self.heads, head_dim).permute(0, 2, 1, 3)
        return t.reshape(batch * self.heads, length, head_dim)

    def _merge_heads(self, t):
        """Reshape (batch*heads, length, dim) back into (batch, length, heads*dim)."""
        folded, length, head_dim = t.shape
        batch = folded // self.heads
        t = t.reshape(batch, self.heads, length, head_dim).permute(0, 2, 1, 3)
        return t.reshape(batch, length, self.heads * head_dim)

    def forward(self,x,context=None,mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: Tensor of shape (batch, query_len, input_size).
            context: Optional tensor of shape (batch, key_len, context_size).
            mask: Optional boolean tensor of shape (batch, key_len); ``True``
                marks key positions that must receive no attention. A batch
                size of 1 broadcasts over the whole batch.

        Returns:
            Tensor of shape (batch, query_len, input_size).
        """
        source = x if context is None else context
        q = self._split_heads(self.Q(x))
        k = self._split_heads(self.K(source))
        v = self._split_heads(self.V(source))

        # Scaled dot-product [QK.T/sqrt(dk)], laid out as (batch*heads, query, key).
        # d_k is the per-head key width, taken from the tensor itself rather
        # than from any variable outside this module.
        scores = torch.einsum('bqd,bkd->bqk', q, k) / (k.size(-1) ** 0.5)

        if mask is not None:
            key_mask = mask.unsqueeze(1)
            if key_mask.size(0) != 1:
                # Heads are folded into the batch axis as (b0h0, b0h1, ...), so
                # every batch row of the mask is repeated once per head.
                key_mask = key_mask.repeat_interleave(self.heads, dim=0)
            scores = scores.masked_fill(key_mask.expand_as(scores), -1e18)

        # Softmax runs over the last axis, i.e. over keys, so each query's
        # weights sum to one.
        weights = self.softmax(scores)
        z = torch.einsum('bqk,bkd->bqd', weights, v)  # Dot product [ZV]
        z = self._merge_heads(z)  # Concat heads
        return self.output(z)

# Smoke test: cross-attend a small query tensor `i` to a larger context `c`.
# `i` and `c` are reused by the demo cells further down.
i = torch.rand(1, 3, 2)
c = torch.rand(1, 20, 20)
print(i.shape)
print(c.shape)
print('------------')

attention = Attention(
    input_size=i.shape[-1],
    context_size=c.shape[-1],
    layer_size=30,
    heads=2,
)

# The output is expected to keep the shape of the query input.
a_out = attention(i, context=c)
print(a_out.shape)

torch.Size([1, 3, 2])
torch.Size([1, 20, 20])
------------
torch.Size([1, 3, 2])


<center>
    <img src="img/transformer.png" alt="Transformer" width="450" height="250" />
</center>

In [2]:
class DecoderOnlyTransformer(nn.Module):
    """Self-attention followed by a bias-free linear layer, each L2-normalised.

    Args:
        input_size: Size of the last dimension of the input and output.
        layer_size: Inner width handed to the self-attention module.
        heads: Number of self-attention heads.
        dropout: Dropout probability used in both sub-layers.
    """

    def __init__(self, input_size, layer_size=64, heads=1, dropout=0.5):
        super().__init__()
        self.self_attention = Attention(
            input_size,
            layer_size=layer_size,
            heads=heads,
            dropout=dropout,
        )
        projection = nn.Linear(input_size, input_size, bias=False)
        self.linear = nn.Sequential(projection, nn.Dropout(dropout))

    def forward(self, x):
        """Run attention then the linear layer, normalising after each step."""
        # dim=1 is the default of nn.functional.normalize, spelled out here.
        # NOTE(review): for the (batch, seq, feature) tensors used in the demo
        # this normalises across sequence positions, not features - confirm
        # that is intended.
        attended = nn.functional.normalize(self.self_attention(x), dim=1)
        projected = self.linear(attended)
        return nn.functional.normalize(projected, dim=1)
    
# Smoke test: the transformer block should preserve the shape of `i`.
transformer = DecoderOnlyTransformer(
    input_size=i.shape[-1],
    layer_size=20,
    heads=5,
    dropout=0.5,
)

print(i.shape)
print('---------')
t_out = transformer(i)
print(t_out.shape)

torch.Size([1, 3, 2])
---------
torch.Size([1, 3, 2])


<center>
    <img src="img/perceiver.png" alt="Perceiver" width="750" height="550" />
</center>

In [3]:
class Perceiver(nn.Module):
    """Latent array refined by alternating cross-attention and self-attention.

    One cross-attention module and one latent transformer are created and
    reused on every iteration, so their weights are shared across recursions.

    Args:
        input_size: Feature size of the input (context) tensor.
        latent_size: Feature size of the latent tensor.
        recursions: Number of (transformer blocks + cross-attention) rounds.
        transformer_blocks: Latent transformer applications per round.
        layer_size: Inner width of both attention modules.
        cross_heads: Heads used by the cross-attention.
        self_heads: Heads used by the latent transformer.
        cross_dropout: Dropout probability of the cross-attention.
        self_dropout: Dropout probability of the latent transformer.
    """

    def __init__(
        self,
        input_size,
        latent_size,
        recursions=1,
        transformer_blocks=1,
        layer_size=64,
        cross_heads=1,
        self_heads=1,
        cross_dropout=0.5,
        self_dropout=0.5,
    ):
        super().__init__()
        self.recursions = recursions
        self.transformer_blocks = transformer_blocks
        self.cross_attention = Attention(
            latent_size,
            context_size=input_size,
            layer_size=layer_size,
            heads=cross_heads,
            dropout=cross_dropout,
        )
        self.latent_transformer = DecoderOnlyTransformer(
            latent_size,
            layer_size=layer_size,
            heads=self_heads,
            dropout=self_dropout,
        )

    def forward(self, x, latent):
        """Return the latent after repeatedly attending to the input ``x``."""
        # Initial read of the input into the latent.
        state = self.cross_attention(latent, context=x)
        for _round in range(self.recursions):
            for _block in range(self.transformer_blocks):
                state = self.latent_transformer(state)
            # Re-read the input after each stack of latent blocks.
            state = self.cross_attention(state, context=x)
        return self.latent_transformer(state)

# Smoke test: `c` is the input stream, `i` the latent; the output should
# keep the latent's shape.
perceiver = Perceiver(
    input_size=c.shape[-1],
    latent_size=i.shape[-1],
    recursions=5,
    transformer_blocks=3,
    layer_size=20,
    cross_heads=5,
    self_heads=2,
)

print(i.shape)
print(c.shape)
print('---------')
p_out = perceiver(c, i)
print(p_out.shape)

torch.Size([1, 3, 2])
torch.Size([1, 20, 20])
---------
torch.Size([1, 3, 2])


<center>
    <img src="img/perceiver_IO.png" alt="Perceiver IO" width="750" height="550" />
</center>

In [4]:
class PerceiverIO(nn.Module):
    """Perceiver encoder followed by a cross-attention decoder.

    The latent is refined exactly as in a Perceiver (shared cross-attention
    and latent transformer), then an output query attends to the final latent.

    Args:
        input_size: Feature size of the input (context) tensor.
        latent_size: Feature size of the latent tensor.
        output_size: Feature size of the output query tensor.
        recursions: Number of (transformer blocks + cross-attention) rounds.
        transformer_blocks: Latent transformer applications per round.
        layer_size: Inner width of all attention modules.
        cross_heads: Heads used by the encoder cross-attention and the decoder.
        self_heads: Heads used by the latent transformer.
        cross_dropout: Dropout probability of the cross-attention and decoder.
        self_dropout: Dropout probability of the latent transformer.
    """

    def __init__(
        self,
        input_size,
        latent_size,
        output_size,
        recursions=1,
        transformer_blocks=1,
        layer_size=64,
        cross_heads=1,
        self_heads=1,
        cross_dropout=0.5,
        self_dropout=0.5,
    ):
        super().__init__()
        self.recursions = recursions
        self.transformer_blocks = transformer_blocks
        self.cross_attention = Attention(
            latent_size,
            context_size=input_size,
            layer_size=layer_size,
            heads=cross_heads,
            dropout=cross_dropout,
        )
        self.latent_transformer = DecoderOnlyTransformer(
            latent_size,
            layer_size=layer_size,
            heads=self_heads,
            dropout=self_dropout,
        )
        self.decode = Attention(
            output_size,
            context_size=latent_size,
            layer_size=layer_size,
            heads=cross_heads,
            dropout=cross_dropout,
        )

    def forward(self, x, latent, output):
        """Encode ``x`` into the latent, then decode it with the ``output`` query."""
        state = self.cross_attention(latent, context=x)
        for _round in range(self.recursions):
            for _block in range(self.transformer_blocks):
                state = self.latent_transformer(state)
            state = self.cross_attention(state, context=x)
        encoded = self.latent_transformer(state)
        # The output query reads from the encoded latent.
        return self.decode(output, context=encoded)
    
# Smoke test: `o` is the output query; the result should keep its shape.
o = torch.rand(1, 1, 3)

perceiverIO = PerceiverIO(
    input_size=c.shape[-1],
    latent_size=i.shape[-1],
    output_size=o.shape[-1],
    recursions=5,
    transformer_blocks=3,
    layer_size=20,
    cross_heads=5,
    self_heads=2,
)

print(i.shape)
print(c.shape)
print(o.shape)
print('---------')
pio_out = perceiverIO(c, i, o)
print(pio_out.shape)

torch.Size([1, 3, 2])
torch.Size([1, 20, 20])
torch.Size([1, 1, 3])
---------
torch.Size([1, 1, 3])


<center>
    <img src="img/positional_encoding.png" alt="Positional Encoding" width="150" height="150" />
</center>

$$PE_{\left(\mbox{pos},2i\right)}=\mbox{sin}\left(\mbox{pos}/10000^{2i/d_{model}}\right)$$
$$PE_{\left(\mbox{pos},2i+1\right)}=\mbox{cos}\left(\mbox{pos}/10000^{2i/d_{model}}\right)$$

In [5]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    The table is precomputed for ``max_len`` positions and stored as the
    non-trainable buffer ``pe`` of shape (max_len, 1, d_model). Even feature
    indices hold sines and odd indices hold cosines.

    Args:
        d_model: Feature size of the tensors to encode (odd values allowed).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 10000^(-2i/d_model), computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so the
        # cosine half is trimmed to fit.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        # Singleton middle axis lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus its position encodings, after dropout.

        Positions are taken along dim 0 of ``x``, so ``x`` is expected to be
        (seq_len, batch, d_model).
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [10]:
class CrossAttentionMap(nn.Module):
    """Single cross-attention layer in which a latent attends to a context.

    Args:
        input_size: Feature size of the context tensor.
        latent_size: Feature size of the latent tensor (and of the output).
        layer_size: Inner width of the attention module.
        heads: Number of attention heads.
        dropout: Dropout probability of the attention module.
    """

    def __init__(self, input_size, latent_size, layer_size=64, heads=1, dropout=0.5):
        super().__init__()
        self.decode = Attention(
            latent_size,
            context_size=input_size,
            layer_size=layer_size,
            heads=heads,
            dropout=dropout,
        )

    def forward(self, latent, context):
        """Return ``latent`` after it has attended to ``context``."""
        return self.decode(latent, context=context)

# Smoke test: an all-ones latent `l` attends to an all-ones context `x`;
# the output should keep the latent's shape.
x = torch.ones((1, 2, 3), dtype=torch.float32)
l = torch.ones((1, 8, 64), dtype=torch.float32)

print(x.shape)
print(l.shape)
print('--------')
c_map = CrossAttentionMap(
    input_size=x.shape[-1],
    latent_size=l.shape[-1],
    layer_size=64,
    heads=1,
    dropout=0.5,
)

z = c_map(l, x)
print(z.shape)

torch.Size([1, 2, 3])
torch.Size([1, 8, 64])
--------
torch.Size([1, 8, 64])


In [20]:
class ChappieZero(nn.Module):
    """Token encoder with a Perceiver trunk and separate value/policy heads.

    Token ids are embedded and position-encoded. A latent query is mapped
    onto the embedded tokens and refined by a Perceiver. Reward and policy
    queries are likewise mapped onto the tokens and then attend to the
    Perceiver output to produce the two results.

    Args:
        input_size: Accepted for interface compatibility; not used here.
        latent_size: Feature size of the latent query.
        reward_size: Feature size of the reward query (and of ``v``).
        policy_size: Feature size of the policy query (and of ``p``).
        ntoken: Vocabulary size of the embedding.
        embedding_size: Width of the token embeddings.
        padding_idx: Token id treated as padding by the embedding.
        encoder_dropout: Dropout of the positional encoder.
        latent_inner, latent_heads, latent_dropout: Settings of the map from
            the latent query onto the tokens.
        perceiver_inner, recursions, transformer_blocks, cross_heads,
        self_heads, cross_dropout, self_dropout: Settings of the Perceiver.
        reward_inner, reward_heads, reward_dropout: Settings shared by the
            reward map and the reward network.
        policy_inner, policy_heads, policy_dropout: Settings shared by the
            policy map and the policy network.
    """

    def __init__(
        self,
        input_size,
        latent_size,
        reward_size,
        policy_size,
        ntoken = 30,
        embedding_size=64,
        padding_idx=29,
        encoder_dropout=0.5,
        latent_inner=64,
        latent_heads=1,
        latent_dropout=0.5,
        perceiver_inner=64,
        recursions=1,
        transformer_blocks=1,
        cross_heads=1,
        self_heads=1,
        cross_dropout=0.5,
        self_dropout=0.5,
        reward_inner=64,
        reward_heads=1,
        reward_dropout=0.5,
        policy_inner=64,
        policy_heads=1,
        policy_dropout=0.5
    ):
        super(ChappieZero, self).__init__()
        self.Embedding = nn.Embedding(ntoken,embedding_size,padding_idx=padding_idx)
        self.PosEncoder = PositionalEncoding(embedding_size,encoder_dropout)
        # Maps from each query (latent / reward / policy) onto the tokens.
        self.LatentMap = CrossAttentionMap(
            embedding_size,
            latent_size,
            layer_size=latent_inner,
            heads=latent_heads,
            dropout=latent_dropout
        )
        self.RewardMap = CrossAttentionMap(
            embedding_size,
            reward_size,
            layer_size=reward_inner,
            heads=reward_heads,
            dropout=reward_dropout
        )
        self.PolicyMap = CrossAttentionMap(
            embedding_size,
            policy_size,
            layer_size=policy_inner,
            heads=policy_heads,
            dropout=policy_dropout
        )
        self.Perceiver = Perceiver(
            embedding_size,
            latent_size,
            recursions=recursions,
            transformer_blocks=transformer_blocks,
            layer_size=perceiver_inner,
            cross_heads=cross_heads,
            self_heads=self_heads,
            cross_dropout=cross_dropout,
            self_dropout=self_dropout
        )
        # Heads that read the Perceiver output (feature size latent_size).
        self.RewardNetwork = CrossAttentionMap(
            latent_size,
            reward_size,
            layer_size=reward_inner,
            heads=reward_heads,
            dropout=reward_dropout
        )
        self.PolicyNetwork = CrossAttentionMap(
            latent_size,
            policy_size,
            layer_size=policy_inner,
            heads=policy_heads,
            dropout=policy_dropout
        )

    def forward(self,x,latent,reward,policy):
        """Compute the value and policy outputs for a batch of token ids.

        Args:
            x: Integer token ids, batch-first (the attention layers pair its
                first axis with the first axis of the query tensors).
            latent: Latent query, (batch, latent_len, latent_size).
            reward: Reward query, (batch, reward_len, reward_size).
            policy: Policy query, (batch, policy_len, policy_size).

        Returns:
            Tuple ``(v, p)`` shaped like ``reward`` and ``policy``.
        """
        x_emb = self.Embedding(x)
        # PositionalEncoding indexes positions along dim 0, but the embeddings
        # are batch-first. Swap to sequence-first for the encoder and back,
        # so each token gets the encoding of its own position.
        x_emb = self.PosEncoder(x_emb.transpose(0, 1)).transpose(0, 1)
        latent = self.LatentMap(latent,x_emb)
        enc = self.Perceiver(x_emb,latent)

        reward = self.RewardMap(reward,x_emb)
        v = self.RewardNetwork(reward,enc)

        policy = self.PolicyMap(policy,x_emb)
        p = self.PolicyNetwork(policy,enc)
        return v,p

# Smoke test: one sequence of 8 token ids plus all-ones latent, reward and
# policy queries. The outputs should keep the reward and policy shapes.
x = torch.tensor([[2, 3, 5, 4, 5, 2, 5, 6]])
l = torch.ones((1, 2, 3), dtype=torch.float32)
r = torch.ones((1, 1, 3), dtype=torch.float32)
p = torch.ones((1, 64, 64), dtype=torch.float32)
print(x.shape)
print(l.shape)
print(r.shape)
print(p.shape)

chappie = ChappieZero(
    x.shape[-1],
    l.shape[-1],
    r.shape[-1],
    p.shape[-1],
    ntoken=30,
    embedding_size=64,
    padding_idx=29,
    encoder_dropout=0.5,
    latent_inner=10,
    latent_heads=1,
    latent_dropout=0.5,
    perceiver_inner=64,
    recursions=1,
    transformer_blocks=1,
    cross_heads=1,
    self_heads=1,
    cross_dropout=0.5,
    self_dropout=0.5,
    reward_inner=64,
    reward_heads=1,
    reward_dropout=0.5,
    policy_inner=64,
    policy_heads=1,
    policy_dropout=0.5,
)

# `p` is rebound from the policy query to the policy output here.
v, p = chappie(x, l, r, p)
print('-----')
print(v.shape)
print(p.shape)


torch.Size([1, 8])
torch.Size([1, 2, 3])
torch.Size([1, 1, 3])
torch.Size([1, 64, 64])
-----
torch.Size([1, 1, 3])
torch.Size([1, 64, 64])
