In [1]:
from torchvision.io import read_image
from einops import rearrange, reduce
import torch
import math
from torchvision.transforms import transforms
import matplotlib.pyplot as plt 
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
# Notebook-wide configuration: number of images per batch and the (square) image size in pixels.
Batch_Size = 5
Image_Height, Image_Width = 224, 224

In [3]:
# Build a batch of random 3-channel "images" as a stand-in for real data.
# Real images will be swapped in later to check whether the network can learn;
# for now the goal is only to get the model wired up.
fake_batch_shape = (Batch_Size, 3, Image_Height, Image_Width)
batch_image = torch.rand(fake_batch_shape)
print(batch_image.shape)

torch.Size([5, 3, 224, 224])


In [4]:
# Normalise each image: standardise every channel with that image's own
# per-channel mean/std, then divide by the largest absolute value so the
# result lies in [-1, 1].
mean = batch_image.mean([2, 3])
std = batch_image.std([2, 3])
for image_index, (channel_mean, channel_std) in enumerate(zip(mean, std)):
    standardise = transforms.Normalize(channel_mean, channel_std)
    standardised_image = standardise(batch_image[image_index])
    peak = max(abs(standardised_image.max()), abs(standardised_image.min()))
    batch_image[image_index] = standardised_image / peak

# Average over the channel axis: (b, c, h, w) -> (b, h, w).
batch_image = reduce(batch_image, 'b c h w -> b h w', 'mean')

In [30]:
def fourier_encoder(batch_image, K=7):
    """Attach Fourier position features to every element of a batch of 2-D grids.

    For each position (x, y) of each grid, a feature vector of length
    ``4*K + 1`` is produced with the following layout:

    * ``[0, 2K)``   -- ``sin``/``cos`` of the x coordinate, interleaved per band
      (``sin`` at even offsets, ``cos`` at odd offsets);
    * ``[2K, 4K)``  -- ``sin``/``cos`` of the y coordinate, interleaved the same way;
    * ``4K``        -- the original value ``batch_image[b, x, y]``.

    Coordinates along each axis are evenly spaced in ``[-1, 1]``; "x" indexes
    dimension 1 of ``batch_image`` and "y" indexes dimension 2.

    Args:
        batch_image: tensor indexed as ``[batch, x, y]`` (three dimensions are
            read via ``size(0)``, ``size(1)``, ``size(2)``).
        K: number of frequency bands.

    Returns:
        Tensor of shape ``(batch, size(1) * size(2), 4*K + 1)`` in torch's
        default dtype, with positions flattened in row-major (x, then y) order.
    """
    m = 256
    batch_size = batch_image.size(0)
    height = batch_image.size(1)
    width = batch_image.size(2)
    feature_count = 4 * K + 1

    # NOTE(review): logspace treats start/end as exponents, so these bands run
    # from 2**1 up to 2**(m/2) = 2**128. That looks larger than intended
    # (possibly log2(m/2) was meant) -- left unchanged here, please confirm.
    band_frequency = torch.logspace(start=1, end=m / 2, steps=K, base=2, dtype=torch.float64)
    x_normalised_coordinate = torch.linspace(start=-1, end=1, steps=height, dtype=torch.float64)
    y_normalised_coordinate = torch.linspace(start=-1, end=1, steps=width, dtype=torch.float64)

    # One row of K angles per coordinate: shapes (height, K) and (width, K).
    angle_x = x_normalised_coordinate[:, None] * math.pi * band_frequency
    angle_y = y_normalised_coordinate[:, None] * math.pi * band_frequency

    # Stack on a trailing axis and flatten so the layout per coordinate is
    # [sin_0, cos_0, sin_1, cos_1, ...]. Every band uses its own cosine; the
    # previous loop wrote cos_x[0] for all bands of the x features.
    x_features = torch.stack((torch.sin(angle_x), torch.cos(angle_x)), dim=-1).reshape(height, 2 * K)
    y_features = torch.stack((torch.sin(angle_y), torch.cos(angle_y)), dim=-1).reshape(width, 2 * K)

    # Every entry is overwritten below, so no random initialisation is needed.
    pe = torch.empty(batch_size, height, width, feature_count)
    pe[..., :2 * K] = x_features[None, :, None, :]       # constant along y and batch
    pe[..., 2 * K:4 * K] = y_features[None, None, :, :]  # constant along x and batch
    pe[..., -1] = batch_image

    # Flatten the two spatial axes: (b, h, w, c) -> (b, h*w, c).
    return pe.reshape(batch_size, height * width, feature_count)

In [31]:
# Encode the image batch: each pixel becomes a (4*K + 1)-wide feature vector
# (K=7 is the encoder's default, spelled out here because later layers depend on it).
pe = fourier_encoder(batch_image, K=7)

In [40]:
# Random latent array: one 32x32 grid per batch element, which the encoder
# flattens to 32*32 = 1024 latent positions. The batch dimension is taken from
# Batch_Size (rather than a literal 5) so it always matches the image batch
# used in the batched attention product below.
latent_array = torch.rand(Batch_Size, 32, 32)
latent_pe = fourier_encoder(latent_array)

In [34]:
# Linear projections used for attention: queries are computed from the latent
# encoding, keys and values from the image encoding. The input width is read
# from the encoded tensors (4*K + 1 features, i.e. 29 for the default K=7)
# instead of being hard-coded, so changing K does not silently break the layers.
Attention_Dim = 512
to_q = torch.nn.Linear(latent_pe.size(-1), Attention_Dim)
to_k = torch.nn.Linear(pe.size(-1), Attention_Dim)
to_v = torch.nn.Linear(pe.size(-1), Attention_Dim)

In [35]:
# Project the latent encoding to queries and the image encoding to keys/values.
q, k, v = to_q(latent_pe), to_k(pe), to_v(pe)

In [36]:
# Scaled dot-product scores between every query i and every key j.
# The scale is sqrt of the projected feature width, read from q itself so it
# stays correct if the projection size changes (it is 512 here).
I = torch.einsum('b i d , b j d -> b i j', q, k) / q.size(-1) ** 0.5

In [37]:
# Turn the scores into attention weights: softmax over the key axis (last dim).
weight = I.softmax(dim=-1)

In [38]:
# Weighted sum of the values for every query: (b, q, k) x (b, k, d) -> (b, q, d).
attention = torch.einsum('b q k , b k d -> b q d', weight, v)

In [39]:
# Sanity check of the output shape: (batch, latent positions, projected width).
attention.shape

torch.Size([5, 1024, 512])