<a href="https://colab.research.google.com/github/martinpius/Computer-Vission/blob/main/LOADING_AND_PREPROCESSING_CUSTOM_DATASET_IMAGES_AND_TEXTS_(IMAGE_CAPTION_DATASET).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount("/content/drive", force_remount = True)
try:
  COLAB = True
  import torch
  print(f">>>> You are on CoLaB with torch version {torch.__version__}")
except Exception as e:
  print(f">>>> {type(e)} {e}\n>>>> please correct {type(e)} and reload your device")
  COLAB = False
def time_fmt(t: float = 129.98)->float:
  h = int(t / (60 * 60))
  m = int(t % (60 * 60) / 60)
  s = int(t % 60)
  return f"hrs: {h} min: {m:>02} sec: {s:>05.2f}"
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print(f">>>> testing time formating function.....\n>>>> time elapsed\t{time_fmt()}")

Mounted at /content/drive
>>>> You are on CoLaB with torch version 1.9.0+cu102
>>>> testing time formating function.....
>>>> time elapsed	hrs: 0 min: 02 sec: 09.00


In [2]:
# In this notebook we are going to learn how to import and preprocess the dataset from 
# different platform (here being Google drive). For demo we will work with flickr30k dataset.
# since flickr30k consists of both images and texts we implement procedures to process the 
# data in both perspective.


In [21]:
import torch, os, spacy, random, time, datetime
from torch.nn.utils.rnn import pad_sequence
from PIL import Image
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

In [14]:
#set the seed for reproducability:
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
spacy_eng = spacy.load("en")

In [47]:
# A class to build vocabulary
class Vocabulary:
  def __init__(self, freq_threshold):
    '''
    We construct a dictionary which key-values as index-tring and vice-versa
    to convert the strings to indice and indices back to strings
    <UNK> is when the word doesnt bit the frequency threshold limit.
    '''
    self.freq_threshold = freq_threshold
    self.itos = {0:"<PAD>", 1:"<SOS>", 2:"<EOS>", 3:"<UNK>"}
    self.stoi = {"<PAD>":0, "<SOS>":1, "<EOS>":2, "<UNK>":3}
  
  def __len__(self):
    return len(self.itos)

  @staticmethod
  def eng_tokenizer(text):
    ''' 
    we use spacy-tokenizer to tokenize the texts and then change them to lower cases
    '''
    return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]
  
  def build_vocabulary(self, caption_list):
    '''
    we send in caption list and build a corpus / bag of vocabulary
    in every caption we inspect the each word count. If the word occured
    more than frequency threshold we assign an index otherwise it will be assigned
    to unknown index.
    '''
    frequencies = {} # a dictionary / place-holder to store the words
    idx = 4 # we start at 4 because 0 = PAD, 1 = SOS, 2 = EOS, 3 = UNK
    for caption in caption_list:
      for word in caption:
        if word not in frequencies:
          frequencies[word] = 1
        else:
          frequencies[word] += 1

        #here we do the conversion if the criteria is met
        if frequencies[word] == self.freq_threshold:
          self.stoi[word] = idx
          self.itos[idx] = word
          idx += 1

  def numericalize(self, text):
    '''
    we actually convert the texts to numerics using this method

    '''
    tokenized_text = self.eng_tokenizer(text) # get the tokens in lower cases
    return [
            self.stoi[token] if token in self.stoi else self.stoi['<UNK>']
            for token in tokenized_text
    ]
      

# a class to Load the data from google drive
class Flickr30kData(Dataset):
  def __init__(self, root_dir, csv_file, transform = None, freq_threshold = 5):
    '''
    root_dir == directory to images folder
    csv_file == csv file directory for image discription (id, caption)
    transform == if we will apply some transformations to images
    freq_threshold == frequency threshold to keep most frequent words in captions

    '''
    self.root_dir = root_dir
    self.dfm = pd.read_csv(csv_file, error_bad_lines = False) # we read the csv file from a specified directory
    self.transform = transform

    # Grab the image id and caption data from the panda dataframe:
    self.imgs = self.dfm['image']
    self.captions = self.dfm['caption']

    #initialize and buil a vocabulary
    self.vocab = Vocabulary(freq_threshold) # we use Vocabulary class (to be defined)
    self.vocab.build_vocabulary(self.captions.tolist())
  
  def __len__(self):
    ''' 
    we grasp total number of examples from our data frame to mark the end of our
    loop when we load one datapoint after the other

    '''
    return len(self.dfm)
  
  def __getitem__(self, idx):
    '''
    this method help to grasp one sample at a time
    a single image with a corresponding caption

    '''
    caption = self.captions[idx] # grab a caption from an image (texts) from image description csv_file
    img_id = self.imgs[idx] # grab the image id from an image description csv-file
    img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB") # grab an image from the image folder and convert to RGB
    # we apply some transformation to the image if needed
    if self.transform is not None:
      img = self.transform(img)
    
    # We pre-process the texts here: captions (change into numeric)
    numericalized_caption = [self.vocab.stoi['<SOS>']] # we start at the begining of the sentennce (SOS)
    numericalized_caption += self.vocab.numericalize(caption) # change the caption to numeric
    numericalized_caption.append(self.vocab.stoi["<EOS>"]) # mark the end of the sentence <EOS>
    # convert to a tensor and then return the image with the corresponding caption
    return img, torch.tensor(numericalized_caption)

# Since every caption is of specified legth, economically we need to pad-the generated 
# sequences with the maximum length of a sentence on a specified batch.
class MyCollate:
  def __init__(self, pad_idx):
    self.pad_idx = pad_idx 

  def __call__(self, batch):
    images = [item[0].unsqueeze(0) for item in batch] # list of images with an added batch dimension
    images = torch.cat(images, dim = 0) # combine images accross the batch dims
    targets = [item[1] for item in batch] # grab all captions 
    targets = pad_sequence(targets, batch_first = False,padding_value = self.pad_idx) # pad every batch with its max len
    return images, targets

# We finally define our iterator (dataloader method to stream the data during training)
def get_loader(images_dir,
               csv_dir,
               transform,
               batch_size = 64,
               shuffle = True,
               pin_memory = True):
  
  #instantiate the data-loader, splits the data into batches padded independntly with their max len
  my_flickrdata = Flickr30kData(images_dir, csv_dir, transform)
  pad_idx = my_flickrdata.vocab.stoi["<PAD>"] # to use in the custom- collate function
  loader = DataLoader(dataset = my_flickrdata, 
                      batch_size = batch_size, 
                      shuffle = shuffle,
                      pin_memory = pin_memory, 
                      collate_fn = MyCollate(pad_idx = pad_idx))
  return loader



In [None]:
# testing our codes by loading the data from google drive.

mytransform = transforms.Compose([
                                transforms.Resize((224,224)),
                                transforms.ToTensor(),])


tic = time.time()

loader = get_loader(images_dir= "/content/drive/MyDrive/flickr30k_images/flickr8k/images",
                    csv_dir = "/content/drive/MyDrive/flickr30k_images/flickr8k/captions.txt",
                    transform = mytransform)

for idx, (image, caption) in enumerate(loader):
  print(f">>> image_shape: {image.shape}\tcaption_shape: {caption.shape}")

tok = time.time()

print(f">>>> time elapsed: {tok - tic}")


>>> image_shape: torch.Size([64, 3, 224, 224])	caption_shape: torch.Size([30, 64])
>>> image_shape: torch.Size([64, 3, 224, 224])	caption_shape: torch.Size([27, 64])
>>> image_shape: torch.Size([64, 3, 224, 224])	caption_shape: torch.Size([29, 64])
>>> image_shape: torch.Size([64, 3, 224, 224])	caption_shape: torch.Size([22, 64])
>>> image_shape: torch.Size([64, 3, 224, 224])	caption_shape: torch.Size([29, 64])
>>> image_shape: torch.Size([64, 3, 224, 224])	caption_shape: torch.Size([28, 64])
>>> image_shape: torch.Size([64, 3, 224, 224])	caption_shape: torch.Size([25, 64])
>>> image_shape: torch.Size([64, 3, 224, 224])	caption_shape: torch.Size([21, 64])
>>> image_shape: torch.Size([64, 3, 224, 224])	caption_shape: torch.Size([26, 64])
>>> image_shape: torch.Size([64, 3, 224, 224])	caption_shape: torch.Size([27, 64])
>>> image_shape: torch.Size([64, 3, 224, 224])	caption_shape: torch.Size([23, 64])
>>> image_shape: torch.Size([64, 3, 224, 224])	caption_shape: torch.Size([28, 64])
>>> 