In [8]:
import os 
import numpy as np
import h5py 
import json 
import torch
from imageio import  imread #or import matplotlib.image as mpimg and mpimg.imread
from PIL import Image #resized_img = Image.fromarray(orj_img).resize(size=(new_h, new_w))
from tqdm import tqdm
from collections import Counter
from random import seed ,choice ,sample
from torch.utils.data  import Dataset,DataLoader


In [16]:
import torchvision
from torch import nn

In [10]:
def _encode_caption(caption, vocabulary, max_len):
  """
  Turn one tokenised caption into a fixed-length list of vocabulary indices.

  Layout: <START>, one index per token (<UNK> for out-of-vocabulary words),
  <END>, then <PAD> repeated so every row has exactly max_len + 2 entries.
  """
  padding = [vocabulary["<PAD>"]] * (max_len - len(caption))
  return ([vocabulary["<START>"]]
          + [vocabulary.get(word, vocabulary["<UNK>"]) for word in caption]
          + [vocabulary["<END>"]]
          + padding)


def _load_image_chw(path, size=256):
  """
  Read the image at `path` and return it as a (3, size, size) array with values <= 255.
  """
  im = imread(path)
  if len(im.shape) == 2:
    # greyscale image : replicate the single channel three times
    im = im[:, :, np.newaxis]
    im = np.concatenate([im, im, im], axis=-1)
  # keep only the first three channels so images with an alpha channel still fit
  im = im[:, :, :3]
  im = np.array(Image.fromarray(im).resize(size=(size, size)))
  im = im.transpose(2, 0, 1)  # channels first

  assert im.shape == (3, size, size)
  assert np.max(im) <= 255
  return im


def create_input_files(karpathy_json_path, image_folder, captions_per_image,
                       min_word_freq, output_folder, max_len=50):
  """
  Function that prepares data for training, validation and testing and
  saves it in HDF5 and json format.

  Files written to output_folder (NAME is built from captions_per_image and min_word_freq):
  -Vocabulary_NAME.json : word -> index map
  -_{SPLIT}_IMAGES_NAME.hdf5 : dataset 'images' of shape (N, 3, 256, 256) plus the
   attribute "Captions_per_image" (only written when the file does not exist yet)
  -_Encoded_captions_{SPLIT}_NAME.json : one row of max_len + 2 indices per sampled caption
  -_captions_lengths_{SPLIT}_NAME.json : len(tokens) + 2 for every sampled caption

  Parameters :
  -karpathy_json_path : path for the splits and captions
  -image_folder : folder with downloaded images
  -captions_per_image : no of captions to sample per image
  -min_word_freq : a threshold; only words occurring more often than this are kept,
   the rest are encoded as <UNK>
  -output_folder : folder to save prepared files (created if missing)
  -max_len : maximum sequence length; longer captions are discarded
  """

  with open(karpathy_json_path, 'r') as f:
    data = json.load(f)

  image_paths_by_split = {"TRAIN": [], "VAL": [], "TEST": []}
  captions_by_split = {"TRAIN": [], "VAL": [], "TEST": []}
  word_frequencies = Counter()

  for image in data["images"]:
    captions = []  # a list that stores all appropriate captions of an image
    for each_caption in image["sentences"]:
      # word counts include the captions that are discarded below for being too long
      word_frequencies.update(each_caption["tokens"])
      if len(each_caption["tokens"]) <= max_len:  # discard too long captions
        captions.append(each_caption["tokens"])

    if len(captions) == 0:
      continue  # no usable caption : skip the image

    if image['split'] in {'train', 'restval'}:
      split = "TRAIN"
    elif image['split'] == 'val':
      split = "VAL"
    else:
      split = "TEST"
    image_paths_by_split[split].append(os.path.join(image_folder, image["filename"]))
    captions_by_split[split].append(captions)

  print(f'training dataset size :{len(image_paths_by_split["TRAIN"])}')
  print(f'validation dataset size :{len(image_paths_by_split["VAL"])}')
  print(f'test dataset size :{len(image_paths_by_split["TEST"])}')

  # Vocabulary
  words = ['<PAD>', '<UNK>', '<START>', '<END>']  # special tokens
  words = words + [w for w in word_frequencies.keys() if word_frequencies[w] > min_word_freq]
  vocabulary = {key: value for value, key in enumerate(words)}

  # name for all prepared files to save
  nameHDF5 = f"flicker8k_{captions_per_image}_captions_per_image_{min_word_freq}_min_freq_word"

  os.makedirs(output_folder, exist_ok=True)

  # saving vocabulary to a json file
  with open(os.path.join(output_folder, f"Vocabulary_{nameHDF5}.json"), 'w') as vocab_f:
    json.dump(vocabulary, vocab_f)

  seed(55)  # makes the caption sampling below reproducible

  for split in ("TRAIN", "VAL", "TEST"):
    image_paths = image_paths_by_split[split]
    image_captions = captions_by_split[split]
    assert len(image_paths) == len(image_captions)

    # images -> HDF5. An existing file is reused as-is and NOT validated, so delete
    # it by hand if a previous run was interrupted while writing it.
    hdf5_path = os.path.join(output_folder, f"_{split}_IMAGES_{nameHDF5}.hdf5")
    if not os.path.exists(hdf5_path):
      with h5py.File(hdf5_path, 'a') as hfile:
        hfile.attrs["Captions_per_image"] = captions_per_image
        # pixel values fit in one byte, so store them as uint8
        images = hfile.create_dataset('images', (len(image_paths), 3, 256, 256), dtype='uint8')
        for index, path in enumerate(tqdm(image_paths)):
          images[index] = _load_image_chw(path)

    # captions -> json, captions_per_image rows per image, in image order
    encoded_indexed_captions, caption_lengths = [], []
    for available_captions in image_captions:
      if len(available_captions) < captions_per_image:
        # re add captions from the available ones to reach captions_per_image
        missing = captions_per_image - len(available_captions)
        captions = available_captions + [choice(available_captions) for _ in range(missing)]
      else:
        captions = sample(available_captions, k=captions_per_image)

      assert len(captions) == captions_per_image

      for caption in captions:
        encoded_indexed_captions.append(_encode_caption(caption, vocabulary, max_len))
        caption_lengths.append(len(caption) + 2)  # + <START> and <END>

    assert len(encoded_indexed_captions) == len(caption_lengths) == len(image_paths) * captions_per_image
    print(f"\nEncoded captions for {split} : {len(encoded_indexed_captions)}")
    with open(os.path.join(output_folder, f"_Encoded_captions_{split}_{nameHDF5}.json"), 'w') as f:
      json.dump(encoded_indexed_captions, f)

    with open(os.path.join(output_folder, f"_captions_lengths_{split}_{nameHDF5}.json"), 'w') as f:
      json.dump(caption_lengths, f)
            
        

In [11]:
# Locations of the Karpathy split file, the raw images and the folder that
# receives the prepared files.
json_path, img_folder, out_folder = (
    "./drive/MyDrive/flicker8k/dataset_flickr8k.json",
    "./drive/MyDrive/flicker8k/Images",
    "./drive/MyDrive/flicker8k/processed_input_files",
)

# Left disabled on purpose: uncomment to (re)build the prepared files.
#create_input_files(json_path,img_folder,5,5,out_folder)

In [12]:
class DataSet(Dataset):
    """
    Dataset class to be used in a DataLoader to create batches.

    One item corresponds to one caption; the Nth caption belongs to the
    (N // captions_per_image)th image stored in the HDF5 file.
    """

    def __init__(self, dir, base_name, split, transform=None):
        """
        parameter :

          dir : folder with data files
          base_name : common base name used to save the hdf5 / json files
          split : in { "TRAIN" , "VAL" , "TEST" }
          transform : image transformations pipeline
        """
        self.dir = dir
        self.split = split
        assert self.split in {"TRAIN", "VAL", "TEST"}, "split is not specidied correctly"
        self.transform = transform

        # The file name and the attribute key below follow what create_input_files
        # writes: "_{split}_IMAGES_{base_name}.hdf5" and "Captions_per_image".
        # The file handle is kept open so images can be read lazily in __getitem__.
        self.data = h5py.File(os.path.join(dir, f"_{self.split}_IMAGES_{base_name}.hdf5"), 'r')
        self.images = self.data["images"]
        self.caps_per_im = int(self.data.attrs["Captions_per_image"])

        # loading encoded captions completely into memory
        with open(os.path.join(dir, f"_Encoded_captions_{self.split}_{base_name}.json"), 'r') as cap_f:
            self.captions = json.load(cap_f)

        # loading encoded captions length completely into memory
        with open(os.path.join(dir, f"_captions_lengths_{self.split}_{base_name}.json")) as cap_f:
            self.caps_len = json.load(cap_f)

        self.dataset_size = len(self.captions)

    def __len__(self):
        """Number of captions (not images) in this split."""
        return self.dataset_size

    def __getitem__(self, index):
        """
        Returns (image, caption, caption_length) for "TRAIN" and
        (image, caption, caption_length, all_captions) otherwise.

        image is a float32 tensor scaled by 1/255; caption, caption_length (shape [1])
        and all_captions (every caption of the same image) are int64 tensors.
        """
        image_index = index // self.caps_per_im
        # divide the pixel values, not the index, by 255
        image = torch.as_tensor(self.images[image_index] / 255.0, dtype=torch.float32)
        if self.transform is not None:
            image = self.transform(image)

        # token indices and lengths are integers
        caption_of_image = torch.LongTensor(self.captions[index])
        # wrap in a list: LongTensor(n) would allocate n uninitialised values
        caption_length = torch.LongTensor([self.caps_len[index]])

        if self.split == "TRAIN":
            return image, caption_of_image, caption_length

        # for validation and testing we need all captions to find BLEU-4 score
        first_caption = image_index * self.caps_per_im
        all_captions = torch.LongTensor(
            self.captions[first_caption:first_caption + self.caps_per_im])
        return image, caption_of_image, caption_length, all_captions
        

In [15]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device('cpu')

In [None]:
class Encoder(nn.Module):
    """
    Encoder that uses a pretrained model to extract feature embedding vectors from images.

    The backbone is a wide_resnet50_2 without its last two children; its parameters
    are frozen, so the encoder itself is not trained.
    """

    def __init__(self, image_embedding_size=14):
        """
        parameters :
            image_embedding_size : height and width of the pooled feature map
        """
        super(Encoder, self).__init__()

        self.embedding_size = image_embedding_size
        wide_resnet = torchvision.models.wide_resnet50_2(pretrained=True, progress=True)
        # we drop the last two children (avgpooling and linear layers)
        blocks_to_keep = list(wide_resnet.children())[:-2]
        self.features_extractor = nn.Sequential(*blocks_to_keep)

        # to allow inputs of various resolution we use adaptive avg pooling, which
        # always outputs a feature map of the given size
        self.AdaptiveAvgPool2d = nn.AdaptiveAvgPool2d((image_embedding_size, image_embedding_size))

        # freeze the backbone : requires_grad is an attribute, not a method
        for p in self.features_extractor.parameters():
            p.requires_grad = False

    def forward(self, images):
        """
        applies forward pass

        parameters : images -> a tensor of shape [ batch_size , channels , image_input_size , image_input_size ]

        returns : a tensor of shape [ batch_size , image_embedding_size * image_embedding_size , feature_channels ]
        """
        features = self.features_extractor(images)  # [batch_size , feature_channels , h , w]
        features = self.AdaptiveAvgPool2d(features)  # [batch_size , feature_channels , image_embedding_size , image_embedding_size]
        features = features.permute(0, 2, 3, 1)  # channels last : [batch_size , image_embedding_size , image_embedding_size , feature_channels]
        # merge the two spatial axes into one "pixel" axis
        features = features.reshape(features.shape[0], -1, features.shape[-1])
        return features

    

In [35]:
class AttentionNn(nn.Module):
    """ A Nn to learn the attention mapping """

    def __init__(self, encoder_output_dim, decoder_output_dim, attention_hidden_dim):
        """
        parameters:
            encoder_output_dim : dim of encoded image features embedding
            decoder_output_dim : dim of decoder output word embedding
            attention_hidden_dim : dim of attention neural network's hidden layer
        """
        super(AttentionNn, self).__init__()
        # transformation to same space
        self.transformed_enc = nn.Linear(encoder_output_dim, attention_hidden_dim)
        self.transformed_dec = nn.Linear(decoder_output_dim, attention_hidden_dim)

        # nn.Sequential takes the layers as positional arguments, not as a list
        self.attention_map = nn.Sequential(
            nn.Linear(attention_hidden_dim, attention_hidden_dim), nn.LeakyReLU(0.2),
            nn.Linear(attention_hidden_dim, 1), nn.LeakyReLU(0.2),
            nn.Softmax(dim=1),  # normalise over the pixel axis
        )

    def forward(self, image_embedding, word_embedding):
        """
        Performs the attention mapping
        Parameters :
            image_embedding : output of feature extractor of shape [batch_size , num_pixels ,encoder_output_dim]
            word_embedding : previous word output from decoder [batch_size , decoder_output_dim]
        Returns :
            attended_feature_embedding : [batch_size , encoder_output_dim]
            alpha_distribution : [batch_size , num_pixels , 1], sums to 1 over num_pixels
        """
        enc_transformed = self.transformed_enc(image_embedding)  # [batch_size , num_pixels , attention_hidden_dim]
        dec_transformed = self.transformed_dec(word_embedding)  # [batch_size , attention_hidden_dim]
        # unsqueeze(1) broadcasts the decoder state over every pixel
        combined = torch.tanh(enc_transformed + dec_transformed.unsqueeze(1))
        alpha_distribution = self.attention_map(combined)  # [batch_size , num_pixels , 1]

        # applying attention on pixels and summing to get the weighted-pixel embedding
        attended_feature_embedding = torch.sum(image_embedding * alpha_distribution, dim=1)
        return attended_feature_embedding, alpha_distribution



torch.Size([10, 20])
