In [1]:
%%capture
!pip install einops transformers

In [2]:
from torchvision.io import read_image
from einops import rearrange, reduce
import torch
import math
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
from transformers import GPT2Model, GPT2Config

In [3]:
# Experiment configuration.
Batch_Size = 5
Image_Height = 224
Image_Width = 224
K = 7          # number of Fourier frequency bands per spatial axis
dmodel = 512   # output width of the q/k/v projections
SEED = 0       # fixed so random inputs and weight initialisations are reproducible
torch.manual_seed(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device type is:',device)

device type is: cpu


In [4]:
def normaliseImage(bm):
  """Normalise a batch of images and collapse the channel axis.

  Each image is standardised per channel with its own mean and (unbiased)
  standard deviation over the spatial axes. It is then divided by its largest
  absolute value, so every image lies in [-1, 1]. Finally the channels are
  averaged.

  Args:
    bm: tensor of shape (b, c, h, w). Integer tensors (e.g. uint8 images from
      ``read_image``) are converted to the default float dtype first.

  Returns:
    A new tensor of shape (b, h, w). ``bm`` itself is not modified.

  Raises:
    ValueError: if any channel of any image has zero standard deviation
      (e.g. a constant channel), which would otherwise divide by zero.
  """
  if not bm.is_floating_point():
    # mean()/std() are not defined for integer dtypes such as uint8.
    bm = bm.to(torch.get_default_dtype())
  # keepdim=True gives (b, c, 1, 1) statistics that broadcast over h and w.
  mean = bm.mean(dim=(2, 3), keepdim=True)
  std = bm.std(dim=(2, 3), keepdim=True)
  if (std == 0).any():
    raise ValueError('normaliseImage: a channel has zero standard deviation')
  standardised = (bm - mean) / std
  # Per-image peak magnitude over all channels and pixels, shape (b, 1, 1, 1).
  scale = standardised.abs().amax(dim=(1, 2, 3), keepdim=True)
  return (standardised / scale).mean(dim=1)

In [5]:
def fourier_encoder(batch_image, K=7, max_freq=256):
    """Attach 2-D Fourier positional features to every pixel of a batch.

    For an input of shape (b, h, w), the result has shape (b, h*w, 4*K + 1),
    flattened row-major over (h, w). Channel layout per pixel:

    * ``[0:2K]``  - interleaved sin/cos of pi * f_k * x (row coordinate)
    * ``[2K:4K]`` - interleaved sin/cos of pi * f_k * y (column coordinate)
    * ``[4K]``    - the raw pixel value

    x and y are evenly spaced in [-1, 1] along the height and width axes.
    The K frequencies f_k are spaced linearly between 1 and ``max_freq / 2``,
    as in the Perceiver Fourier position encoding.

    Args:
        batch_image: tensor of shape (b, h, w).
        K: number of frequency bands per axis.
        max_freq: sampling rate whose Nyquist frequency (max_freq / 2) is the
            highest band.

    Returns:
        Tensor of shape (b, h*w, 4*K + 1) in the default float dtype, on the
        same device as ``batch_image``.
    """
    b, h, w = batch_image.shape
    dev = batch_image.device
    out_dtype = torch.get_default_dtype()

    # Linear bands in [1, max_freq/2]. The previous logspace(1, max_freq/2, base=2)
    # treated max_freq/2 as an exponent and produced bands up to 2**128.
    freqs = torch.linspace(1.0, max_freq / 2, steps=K, dtype=torch.float64, device=dev)
    xs = torch.linspace(-1, 1, steps=h, dtype=torch.float64, device=dev)
    ys = torch.linspace(-1, 1, steps=w, dtype=torch.float64, device=dev)

    def _sin_cos(coords):
        # (n,) coordinates -> (n, 2K) with [sin f_0, cos f_0, sin f_1, cos f_1, ...].
        angles = torch.outer(coords, freqs) * math.pi
        return torch.stack((angles.sin(), angles.cos()), dim=-1).flatten(start_dim=1)

    enc_x = _sin_cos(xs)[:, None, :].expand(h, w, 2 * K)  # varies along rows only
    enc_y = _sin_cos(ys)[None, :, :].expand(h, w, 2 * K)  # varies along columns only
    positional = torch.cat((enc_x, enc_y), dim=-1).to(out_dtype)  # (h, w, 4K)

    pe = torch.cat(
        (positional.expand(b, h, w, 4 * K), batch_image.to(out_dtype)[..., None]),
        dim=-1,
    )
    return pe.reshape(b, h * w, 4 * K + 1)

In [6]:
# Fake input batch for wiring up the model. Real images will be used later to
# check whether the network can learn.
batch_image = torch.rand(Batch_Size, 3, Image_Height, Image_Width).to(device)
print(batch_image.size())
# Normalising and Fourier-encoding every pixel can take a while without a GPU.
batch_image = fourier_encoder(normaliseImage(batch_image), K)
# Placeholder latent array: a random 32x32 grid per sample, encoded the same way.
latent_array = fourier_encoder(torch.rand(Batch_Size, 32, 32).to(device), K)

torch.Size([5, 3, 224, 224])


In [7]:
# Project each (4*K + 1)-channel token to dmodel-wide queries, keys and values.
# That is 29 channels when K = 7: 4*K sin/cos features plus the raw value.
to_q, to_k, to_v = (torch.nn.Linear(4 * K + 1, dmodel).to(device) for _ in range(3))

In [8]:
# Queries come from the latent array; keys and values come from the encoded image.
q, k, v = to_q(latent_array), to_k(batch_image), to_v(batch_image)

In [9]:
# Dot-product scores between every latent (i) and every image token (j), scaled by sqrt(dmodel).
I = q @ k.transpose(1, 2) / dmodel**0.5

In [10]:
# Normalise the scores over the image-token axis so each latent's weights sum to 1.
weight = I.softmax(dim=-1)

In [11]:
# Weighted sum of the values for each latent: (b, i, j) x (b, j, d) -> (b, i, d).
attention = torch.bmm(weight, v)

In [12]:
attention.shape

torch.Size([5, 1024, 512])

In [13]:
# GPT-2 stack used as the latent transformer. Cross-attention lets it attend to the
# `attention` tensor passed as encoder_hidden_states. GPT-2 projects those states
# with n_embd-wide layers, so n_embd must equal their width, dmodel.
config = GPT2Config(
    add_cross_attention=True,
    n_head=8,
    n_layer=8,
    n_embd=dmodel,
)

In [14]:
# Randomly initialised (not pretrained) GPT-2 stack built from `config`.
latent_transformer = GPT2Model(config)
latent_transformer = latent_transformer.to(device)

In [15]:
CAPTION_LENGTH = 128  # tokens per placeholder caption
# Placeholder captions: random token ids drawn from [0, vocab_size).
random_text_input = torch.randint(
    low=0, high=config.vocab_size, size=(Batch_Size, CAPTION_LENGTH)
).to(device)

In [40]:
# Run the placeholder captions through GPT-2 while it cross-attends to the latent attention output.
output = latent_transformer(
    input_ids=random_text_input,
    encoder_hidden_states=attention,
)

In [41]:
# Swap the last two axes so the following Linear layer acts on the token axis.
output_last_hidden_state = output.last_hidden_state.transpose(1, 2)

In [18]:
# Map the caption-length axis onto one slot per latent (attention.size(1), 1024 here),
# so the block output has the same shape as the cross-attention output.
# .to(device) is required: the hidden states this layer consumes live on `device`.
ffn = torch.nn.Linear(CAPTION_LENGTH, attention.size(1)).to(device)

In [42]:
# Apply ffn along the token axis, then swap the last two axes back.
output_block_1 = ffn(output_last_hidden_state).transpose(1, 2)

In [44]:
output_block_1.shape

torch.Size([5, 1024, 512])

In [46]:
# TODO: repeat this block 8 more times. Should the repeats share weights?
# TODO: check that the inputs fed to latent_transformer are correct; this needs more research.
# TODO: consider GPT2LMHeadModel (instead of GPT2Model) to load the pretrained model and to train on the entire input_ids sequence. My current understanding is that GPT2Model only trains one generation step at a time; verify this.