In [1]:
from torchvision.io import read_image
from einops import rearrange, reduce
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
from transformers import GPT2Config, GPT2Tokenizer,GPT2LMHeadModel,GPT2Model
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions,CausalLMOutputWithCrossAttentions

In [2]:
Batch_Size = 5
Image_Height = 224
Image_Width = 224
K = 7
dmodel = 256  
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device type is:',device)

device type is: cpu


In [3]:
def normaliseImage(bm):
  #Normalise the image:
  mean, std = bm.mean([2,3]), bm.std([2,3])
  i = 0
  for m,s in zip(mean,std):
      normaliser = transforms.Normalize(m, s)
      normalised_image = normaliser(bm[i])
      scale = max(abs(normalised_image.max()),abs(normalised_image.min()))
      bm[i] = normalised_image/scale
      i += 1
  return reduce(bm,'b c h w -> b h w', 'mean')

In [4]:
def fourier_encoder(batch_image,K=7):
    m = 256    
    pe = torch.rand(batch_image.size(0),batch_image.size(1),batch_image.size(2),4*K + 1).to(device)
    band_frequency = torch.logspace(start=1, end= m/2,steps=K,base=2,dtype=torch.float64).to(device)
    x_normalised_coordinate = torch.linspace(start=-1, end=1,steps=batch_image.size(1),dtype=torch.float64).to(device)
    y_normalised_coordinate = torch.linspace(start=-1, end=1,steps=batch_image.size(2),dtype=torch.float64).to(device)
    b = 0
    for b in range(batch_image.size(0)):
        x = 0
        for i in x_normalised_coordinate:
            angle_x = i*math.pi*band_frequency
            angle_y = torch.einsum('i,j -> ij',y_normalised_coordinate,band_frequency)*math.pi
            pe[b][x][:,-1] = batch_image[b][x]
            pe[b][x][:,0:2*K:2] = angle_x.sin()
            pe[b][x][:,1:2*K:2] = angle_x.cos()
            pe[b][x][:,2*K:4*K:2] = angle_y.sin()
            pe[b][x][:,2*K + 1:4*K:2] = angle_y.cos()
            x += 1
    return rearrange(pe,'b h w c -> b (h w) c')

In [5]:
class SingleCrossAttentionHead(nn.Module):
    def __init__(self,dmodel,dk,dv,cuda_number='cuda:0'):
        super(SingleCrossAttentionHead,self).__init__()
        self.proj_key = nn.Linear(dmodel,dk).to(cuda_number)
        self.proj_query = nn.Linear(dmodel,dk).to(cuda_number)
        self.proj_value  = nn.Linear(dmodel,dv).to(cuda_number)
        self.dk = dk
        self.cuda_number = cuda_number
        
    def forward(self,byte_array,latten_array):
        byte_array = byte_array.to(self.cuda_number)
        latten_array = latten_array.to(self.cuda_number)
        
        k = self.proj_key(byte_array)
        q = self.proj_query(latten_array)
        v = self.proj_value(byte_array)
        
        I = torch.einsum('b i d , b j d -> b i j', q, k)
        attention = F.softmax(I/(self.dk**0.5), dim=-1)
        head = torch.einsum('b i j , b j d -> b i d', attention, v)

        return head

In [6]:
# Copied from HuggingFace with modified forward function
class CustomGPT2(GPT2Model):
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)


        # GPT2Attention mask.
        if attention_mask is not None:
            assert batch_size > 0, "batch_size has to be defined and > 0"
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, None, None, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # If a 2D ou 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = inputs_embeds

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure layer_past is on same device as hidden_states (might not be correct)
                if layer_past is not None:
                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                        "`use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

In [7]:
# Copied from HuggingFace with modified customGPT2 layer and forward function
class Latent_Transformers(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = CustomGPT2(config)
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
            ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        hidden_states = transformer_outputs[0]
        #print(hidden_states)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            hidden_states = hidden_states.to(self.lm_head.weight.device)

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            labelss = torch.zeros(labels.size(0),1024, dtype=torch.long)
            labelss[:,:labels.size(1)] = labels
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labelss[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )
         

In [8]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

In [9]:
# Create a fake batch image. Will use actual images later to see if the network can learn.For now, lets create the model first
batch_image = torch.rand(Batch_Size,3,Image_Height,Image_Width).to(device)
print(batch_image.size())
#If you dont have a GPU then this operation can take awhile. A GPU performs about 30 times faster than a CPU so you should see the result instantaneously
batch_image = normaliseImage(batch_image)
batch_image = fourier_encoder(batch_image,K)
latent_array = torch.rand(Batch_Size,32,32).to(device)
latent_array = fourier_encoder(latent_array,K)

torch.Size([5, 3, 224, 224])


In [10]:
config = GPT2Config()
config.add_cross_attention=False
config.n_head = 4 
config.n_layer = 4 
config.n_embd = 256

In [11]:
labels = [
    'This is a dog',
    'a cat',
    'a person working',
    'a construction defects',
    'an uncaption image'
]

labels = tokenizer.batch_encode_plus(labels, return_tensors='pt',padding=True, truncation=True).to(device)

In [12]:
fnn = nn.Linear(50176,1024)

In [13]:
cross_attention = SingleCrossAttentionHead(29,256,256,cuda_number=device)

In [14]:
cross_attention_output = cross_attention(latent_array,batch_image)
cross_attention_output = cross_attention_output.permute(0,2,1)
cross_attention_output = fnn(cross_attention_output)
cross_attention_output = cross_attention_output.permute(0,2,1)

#Havent added drop out or normalisation layer yet. Will add when creating the final wrapper CrossAttentionTransformer class

In [15]:
latent_transformer = Latent_Transformers(config).to(device)

In [16]:
latent_transformer_outputs = latent_transformer(input_ids=None,
                   inputs_embeds=cross_attention_output,
                   labels=labels['input_ids'])

In [17]:
latent_transformer_outputs.loss

tensor(10.6460, grad_fn=<NllLossBackward>)