In [1]:
%%capture
!pip install einops

In [1]:
from torchvision.io import read_image
from einops import rearrange, reduce
import torch
import math
from torchvision.transforms import transforms
import matplotlib.pyplot as plt

In [2]:
# Synthetic input batch: number of images and their (square) spatial size in pixels.
Batch_Size = 5
Image_Height = Image_Width = 224
# Number of Fourier frequency bands per spatial axis.
K = 4
# Output width of the query / key / value projections.
dmodel = 512

In [3]:
# Stand-in batch of random RGB images. Real images will be used later to check
# whether the network can learn; for now the goal is just to build the model.
batch_image = torch.rand((Batch_Size, 3, Image_Height, Image_Width))
print(batch_image.shape)

torch.Size([5, 3, 224, 224])


In [4]:
def normaliseImage(bm):
  """Standardise a batch of images and collapse the channels into one plane.

  Every image is standardised per channel (zero mean, unit standard deviation
  over its spatial axes), divided by its largest absolute value so that all
  entries lie in [-1, 1], and finally averaged over the channel axis.

  The input tensor is not modified; a new tensor is returned.

  Args:
    bm: floating-point tensor of shape (batch, channels, height, width).

  Returns:
    Tensor of shape (batch, height, width).

  Raises:
    ValueError: if any channel of any image has zero standard deviation
      (for example a constant channel), because it cannot be standardised.
  """
  # Per-image, per-channel statistics, kept broadcastable against `bm`.
  mean = bm.mean(dim=(2, 3), keepdim=True)
  std = bm.std(dim=(2, 3), keepdim=True)
  if (std == 0).any():
    raise ValueError("std evaluated to zero, leading to division by zero.")

  # Out-of-place arithmetic so the caller's tensor is left untouched.
  standardised = (bm - mean) / std

  # Largest magnitude of each image across all channels and pixels.
  peak = standardised.abs().amax(dim=(1, 2, 3), keepdim=True)

  # Rescale into [-1, 1], then average the channels away.
  return (standardised / peak).mean(dim=1)

In [5]:
def fourier_encoder(batch_image, K=7, max_freq=256):
    """Attach Fourier position features to every pixel and flatten the grid.

    Row and column indices are mapped to coordinates in [-1, 1]. For each of
    the K frequency bands f (log-spaced from 1 to max_freq / 2) the features
    sin(pi * f * pos) and cos(pi * f * pos) are computed for both axes.

    Layout of the last axis (size 4*K + 1):
        [0, 2K)   row axis, interleaved as sin, cos per band
        [2K, 4K)  column axis, interleaved as sin, cos per band
        4K        the original pixel value

    Args:
        batch_image: tensor of shape (batch, height, width).
        K: number of frequency bands per axis.
        max_freq: the highest band is max_freq / 2; must be positive.

    Returns:
        Tensor of shape (batch, height * width, 4*K + 1) in the default
        floating-point dtype, with positions flattened row by row.

    Raises:
        ValueError: if batch_image is not 3-dimensional.
    """
    if batch_image.dim() != 3:
        raise ValueError(
            "batch_image must have shape (batch, height, width), got "
            + str(tuple(batch_image.shape))
        )
    batch, height, width = batch_image.shape
    n_features = 4 * K + 1

    # torch.logspace takes *exponents*: the bands run from 2**0 = 1 up to
    # 2**log2(max_freq / 2) = max_freq / 2.
    band_frequency = torch.logspace(
        start=0, end=math.log2(max_freq / 2), steps=K, base=2, dtype=torch.float64
    )
    row_coord = torch.linspace(start=-1, end=1, steps=height, dtype=torch.float64)
    col_coord = torch.linspace(start=-1, end=1, steps=width, dtype=torch.float64)

    # (height, K) and (width, K) angle tables, computed once for the whole batch.
    angle_x = row_coord[:, None] * band_frequency * math.pi
    angle_y = col_coord[:, None] * band_frequency * math.pi

    # Every channel is written below, so no random initialisation is needed
    # (and the global RNG state is left alone).
    pe = torch.empty(batch, height, width, n_features)
    pe[..., 0:2 * K:2] = angle_x.sin()[None, :, None, :]
    pe[..., 1:2 * K:2] = angle_x.cos()[None, :, None, :]
    pe[..., 2 * K:4 * K:2] = angle_y.sin()[None, None, :, :]
    pe[..., 2 * K + 1:4 * K:2] = angle_y.cos()[None, None, :, :]
    pe[..., -1] = batch_image

    # Equivalent to rearranging 'b h w c -> b (h w) c'.
    return pe.reshape(batch, height * width, n_features)

In [6]:
# batch_image is rebound here: (batch, 3, height, width) in, channel-averaged
# (batch, height, width) out.
batch_image = normaliseImage(bm=batch_image)

In [7]:
# Image tokens: one row of 4*K + 1 features per pixel.
batch_image = fourier_encoder(batch_image, K=K)
# Latent array: a random 32 x 32 grid encoded the same way (32 * 32 = 1024 rows).
latent_array = fourier_encoder(torch.rand((Batch_Size, 32, 32)), K=K)

In [8]:
# Each token carries 4*K + 1 features (17 for K = 4); every projection maps them to dmodel = 512.
to_q, to_k, to_v = (torch.nn.Linear(4 * K + 1, dmodel) for _ in range(3))

In [9]:
# Queries come from the latent array; keys and values come from the image tokens.
q, k, v = to_q(latent_array), to_k(batch_image), to_v(batch_image)

In [10]:
# Scaled dot-product scores. Strictly the scale is dk**0.5 rather than dmodel**0.5,
# but the keys are projected straight to dmodel here, so dk == dmodel == 512.
I = torch.einsum('b q d, b k d -> b q k', q, k).div(dmodel ** 0.5)

In [11]:
# Normalise the scores over the key axis so each query's weights sum to 1.
weight = I.softmax(dim=-1)

In [12]:
# Weighted sum of the values for every query position.
attention = torch.einsum('b q k, b k d -> b q d', weight, v)

In [13]:
attention.shape

torch.Size([5, 1024, 512])