In [73]:
from math import log
from einops import rearrange

import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self,input_size,layer_size=64,latent_size=None,heads=1,dropout=0.5):
        super(Attention, self).__init__()
        self.heads = heads
        self.Q = nn.Linear(input_size,layer_size,bias=False) if latent_size is None else nn.Linear(input_size,latent_size,bias=False)
        self.K = nn.Linear(input_size,layer_size,bias=False)
        self.V = nn.Linear(input_size,layer_size,bias=False)
        self.softmax = nn.Softmax(dim=-1)
        inner_size = layer_size if latent_size is None else latent_size
        self.output = nn.Sequential(
            nn.Linear(inner_size,input_size),
            nn.Dropout(dropout)
        )
        
    def forward(self,x,context=None,mask=None):
        #h:heads, b:batches, y:y-axis x:x-axis
        q = rearrange(self.Q(x), 'b y (h x) -> (b h) y x', h=self.heads) #Query
        if context is None:
            k,v = map(lambda x:rearrange(x, 'b y (h x) -> (b h) y x', h=self.heads),(self.K(x),self.V(x)))
        else:
            k,v = map(lambda x:rearrange(x, 'b y (h x) -> (b h) y x', h=self.heads),(self.K(context),self.V(context)))

        #b:batches, y:y-axis, q:q x-axis, k: k x-axis
        z = torch.einsum('b y q, b y k -> b k q', q, k) / (i.size(-1)**(0.5)) #Scaled dot product [QK.T/sqrt(dk)]

        if mask is not None:
            mask_expanded = mask.unsqueeze(1).expand_as(z)
            z = z.masked_fill(mask_expanded, -1e18)
            
        z = self.softmax(z)
        #b:batches, c:common dim, z:z x-axis, v:v x-axis
        z = torch.einsum('b y z, b v y -> b v z', z, v) #Dot product [ZV]
        #h:heads, b:batches, y:y-axis, x:x-axis
        z = rearrange(z, '(b h) y x -> b y (h x)', h=self.heads) #Concat data
        z = self.output(z)
        return z

i = torch.rand(5,200,200)
#print(i)
print(i.size())
print('------------')

model = Attention(
    i.size(-1),
    layer_size=30,
    latent_size=5,
    heads=5
)

z = model(i)
#print(z)
print(z.size())

torch.Size([5, 200, 200])
------------
torch.Size([5, 200, 200])


In [74]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0),:]
        return self.dropout(x)

In [75]:
class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        
    def forward(self,x):
        return True

In [78]:
class Perceiver(nn.Module):
    def __init__(self,input_size,recursions=1,transformer_blocks=1,layer_size=64,latent_size=None,cross_heads=1,self_heads=1,cross_dropout=0.5,self_dropout=0.5):
        super(Perceiver, self).__init__()
        self.recursions = recursions
        self.transformer_blocks = transformer_blocks
        self.cross_attention = Attention(input_size,layer_size=layer_size,latent_size=latent_size,heads=cross_heads,dropout=cross_dropout)
        self.self_attention = Attention(input_size,layer_size=latent_size,latent_size=latent_size,heads=self_heads,dropout=self_dropout)
        self.pos_encoding = PositionalEncoding(input_size) #Remove from perceiver and implement in other models
    
    def forward(self,x):
        enc = self.pos_encoding(x)
        z = self.cross_attention(enc)
        for _ in range(self.recursions):
            for _ in range(self.transformer_blocks):
                z = self.self_attention(z) #Switch to latent transformer
            z = self.cross_attention(z,context=enc)
        z = self.self_attention(z) #Switch to latent transformer
        return z

p_model = Perceiver(
    i.size(-1),
    recursions = 5,
    transformer_blocks = 50,
    layer_size = 300000,
    latent_size = 500,
    cross_heads = 5,
    self_heads = 5
)

#print(i)
print(i.size())
print('---------')
out = p_model(i)
#print(out)
print(out.size())

torch.Size([5, 200, 200])
---------


RuntimeError: [enforce fail at ..\c10\core\CPUAllocator.cpp:73] data. DefaultCPUAllocator: not enough memory: you tried to allocate 1200000000 bytes. Buy new RAM!