<a href="https://colab.research.google.com/github/martinpius/PYTORCH/blob/main/TextLoader_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
from timeit import default_timer as timer
t1 = timer()

# Hard dependencies are imported unconditionally: if they lived inside the
# Colab-only try-block below, a failed Drive mount (e.g. running outside
# Colab) would leave torch/pandas/spacy undefined and the next statements
# would crash with a NameError.
import os

import matplotlib.pyplot as plt
import pandas as pd
import spacy
import torch
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

try:
  # Only the Colab-specific part is guarded: mounting Google Drive.
  from google.colab import drive
  drive.mount("/content/drive/", force_remount = True)
  print(f">>>> You are on CoLaB with torch version: {torch.__version__}")
except Exception as e:
  # Notebook top-level boundary: report the problem and keep going so the
  # remaining cells can still run (e.g. against local files).
  print(f">>>> {type(e)}: {e}\n>>>> Please correct {type(e)} and reload the Google drive")

# Use the GPU when one is visible, with a larger batch size there.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
BATCH_SIZE = 64 if device.type == "cuda" else 32
print(f">>>> Available device: {device}")
def mytimer(t: float = timer())->str:
  """
  Format a duration of ``t`` seconds as "hrs: H, mins: MM, secs: SS.ss".

  Note: the default value is evaluated once, when this function is defined,
  so calling ``mytimer()`` with no argument formats that single timer
  reading rather than an elapsed time. Pass an explicit duration, e.g.
  ``mytimer(timer() - t1)``.
  """
  # Round to the displayed precision first, so that e.g. 59.999 s is shown
  # as 1 minute rather than "secs: 60.00".
  t = round(t, 2)
  h = int(t / (60 * 60))
  m = int(t % (60 * 60) / 60)
  # Keep the fractional part: the ".2f" format below would otherwise
  # always print ".00".
  s = t % 60
  return f"hrs: {h}, mins: {m:>02}, secs: {s:>05.2f}"
# IPython/Colab shell escape: print GPU status. On a CPU-only runtime the
# shell reports "command not found" and the cell carries on.
!nvidia-smi
# Wall-clock time spent in this cell, measured from t1 set at its top.
print(f">>>> Time elapsed: {mytimer(timer() - t1)}")

Mounted at /content/drive/
>>>> You are on CoLaB with torch version: 2.0.0+cu118
>>>> Available device: cpu
/bin/bash: nvidia-smi: command not found
>>>> Time elapsed: hrs: 0, mins: 00, secs: 06.00


In [2]:
# spaCy's small English pipeline, used below for caption tokenization.
spacy_en = spacy.load(name = "en_core_web_sm")

In [3]:
# Implement a PyTorch DataLoader for a user-customized image-caption dataset.

In [11]:
class WordsDictionary():
  """
  Vocabulary that maps tokens to integer indices (``stoi``) and back (``itos``).

  Indices 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and
  <UNK>. A regular token receives its own index once it has been seen at
  least ``min_frequency`` times in the captions given to ``tfdif``; rarer
  tokens are encoded as <UNK> by ``word2vec``.
  """
  def __init__(self,
               min_frequency = 5):
    # Special tokens occupy the first four indices.
    self.stoi = {"<PAD>": 0,
                 "<SOS>": 1,
                 "<EOS>": 2,
                 "<UNK>": 3}
    self.itos = {0: "<PAD>",
                 1: "<SOS>",
                 2: "<EOS>",
                 3: "<UNK>"}
    self.min_frequency = min_frequency

  def __len__(self):
    """Return the vocabulary size, special tokens included."""
    return len(self.stoi)

  @staticmethod
  def mytokenizer(caption: str)->list:
    """Tokenize ``caption`` with spaCy's English tokenizer and lower-case each token."""
    return [tok.text.lower() for tok in spacy_en.tokenizer(caption)]

  def _tokens(self, caption):
    # Raw strings are tokenized with ``mytokenizer``; anything else is taken
    # to already be an iterable of tokens.
    if isinstance(caption, str):
      return self.mytokenizer(caption)
    return list(caption)

  def tfdif(self, caption_list):
    """
    Build the vocabulary from ``caption_list`` using a minimum-frequency cutoff.

    Despite the name, this is plain term counting (no TF-IDF weighting).
    A term is appended to the vocabulary the moment its running count
    reaches ``min_frequency``. Indices are assigned consecutively, after the
    entries already present.

    Args:
      caption_list: iterable of captions. Each caption is either a raw string
        (tokenized with ``mytokenizer``, exactly like ``word2vec`` does) or
        an already-tokenized sequence of strings.
    """
    terms_freq = {}
    # Continue numbering after the existing entries (0-3 are the special
    # tokens), so a repeated call never reuses or overwrites an index.
    idx = len(self.stoi)
    for caption in caption_list:
      # Iterate tokens, not the raw string: iterating a str would count
      # single characters, which never match word2vec's tokens.
      for term in self._tokens(caption):
        terms_freq[term] = terms_freq.get(term, 0) + 1
        if terms_freq[term] >= self.min_frequency and term not in self.stoi:
          self.stoi[term] = idx
          self.itos[idx] = term
          idx += 1

  def word2vec(self, caption):
    """
    Encode ``caption`` as a list of vocabulary indices.

    ``caption`` may be a raw string or an already-tokenized sequence. Tokens
    missing from the vocabulary map to the <UNK> index.
    """
    unk = self.stoi["<UNK>"]
    return [self.stoi.get(tok, unk) for tok in self._tokens(caption)]
  
class PytorchTextDataReader(Dataset):
  """
  Image-caption dataset backed by a CSV file.

  Each CSV row gives an image file name (column ``image``, resolved relative
  to ``root_path``) and a caption (column ``caption``). A ``WordsDictionary``
  vocabulary is built from all captions at construction time.
  """
  def __init__(self,
               root_path: str,
               text_path: str,
               transform = None,
               min_frequency: int = 5)->None:
    """
    Args:
      root_path: directory that contains the image files.
      text_path: path of the CSV file with ``image`` and ``caption`` columns.
      transform: optional callable applied to each RGB PIL image
        (e.g. a ``transforms.Compose`` pipeline).
      min_frequency: minimum number of occurrences for a token to get its
        own vocabulary index; rarer tokens are encoded as <UNK>.
    """
    self.root_path = root_path
    self.text_path = text_path
    self.transform = transform
    self.min_frequency = min_frequency
    self.dfm = pd.read_csv(text_path)
    self.images = self.dfm["image"]  # pd.Series of image file names
    self.captions = self.dfm["caption"]  # pd.Series of caption texts
    # Build the vocabulary once, from every caption in the file.
    self.mydictionary = WordsDictionary(min_frequency = self.min_frequency)
    self.mydictionary.tfdif(caption_list = self.captions.tolist())

  def __len__(self):
    """Return the number of (image, caption) rows."""
    return len(self.dfm)

  def __getitem__(self, index):
    """
    Return ``(image, caption_indices)`` for row ``index``.

    ``image`` is the RGB image after ``self.transform`` (if any).
    ``caption_indices`` is a 1-D tensor holding <SOS>, the caption's token
    indices, then <EOS>.
    """
    # Positional access, independent of how the DataFrame is indexed.
    img_name = self.images.iloc[index]
    caption = self.captions.iloc[index]
    # Use a context manager so the image file handle is closed once the RGB
    # copy exists; otherwise every sample leaves a file open until GC.
    with Image.open(os.path.join(self.root_path, img_name)) as raw_img:
      img = raw_img.convert("RGB")
    if self.transform is not None:
      img = self.transform(img)

    stoi = self.mydictionary.stoi
    word_vec = [stoi["<SOS>"],
                *self.mydictionary.word2vec(caption = caption),
                stoi["<EOS>"]]

    return img, torch.tensor(word_vec)  # a (image, caption) tuple
  
class MyCollate_fn():
  """
  Batch collate function for (image, caption) samples.

  Images are stacked along a new leading batch dimension. Captions are
  right-padded with ``pad_index`` to the length of the longest caption in
  the same batch.
  """
  def __init__(self, pad_index):
    # Fill value for the tail of every shorter caption.
    self.pad_index = pad_index

  def __call__(self, batch):
    """
    Args:
      batch: sequence of ``(image, caption)`` samples. All images must have
        the same shape; captions are 1-D index tensors of varying length.

    Returns:
      ``(images, padded_captions)``, both with the batch dimension first.
    """
    pictures = torch.stack([sample[0] for sample in batch])
    sequences = [sample[1] for sample in batch]
    padded = pad_sequence(sequences = sequences,
                          batch_first = True,
                          padding_value = self.pad_index)
    return pictures, padded


# Default image pipeline: resize every image to 224x224, then convert it to
# a tensor, so that images in a batch share one shape.
transform = transforms.Compose(
    [
        transforms.Resize(size = (224, 224)),
        transforms.ToTensor(),
    ]
)

def textloader(
    root_path: str,
    text_path: str,
    num_workers: int = 2,
    pin_memmory: bool = True,
    shuffle: bool = True,
    batch_size: int = BATCH_SIZE,
    min_frequency: int = 5,
    image_transform = None):
  """
  Build a DataLoader over the image-caption dataset.

  Args:
    root_path: directory that contains the image files.
    text_path: CSV file with ``image`` and ``caption`` columns.
    num_workers: number of DataLoader worker processes.
    pin_memmory: request pinned host memory. It is only honoured when CUDA
      is available, because pinning only speeds up host-to-GPU copies.
    shuffle: reshuffle the samples every epoch.
    batch_size: samples per batch (the default is the module's BATCH_SIZE,
      evaluated when this function is defined).
    min_frequency: vocabulary frequency cutoff passed to the dataset.
    image_transform: callable applied to each image; when None, the
      module-level ``transform`` pipeline is used.

  Returns:
    A DataLoader yielding ``(images, padded_captions)`` batches.
  """
  mytext_dataset = PytorchTextDataReader(
      root_path = root_path,
      text_path = text_path,
      transform = transform if image_transform is None else image_transform,
      min_frequency = min_frequency)

  pad_index = mytext_dataset.mydictionary.stoi["<PAD>"]  # grab the padding index

  mytext_loader = DataLoader(dataset = mytext_dataset,
                             batch_size = batch_size,
                             shuffle = shuffle,
                             pin_memory = pin_memmory and torch.cuda.is_available(),
                             collate_fn = MyCollate_fn(pad_index = pad_index),
                             num_workers = num_workers
                             )

  return mytext_loader

In [None]:
if __name__ == "__main__":

  images_path = "/content/drive/MyDrive/flickr30k_images/flickr8k/images"
  texts_path = "/content/drive/MyDrive/flickr30k_images/flickr8k/captions.txt"
  # Bind the DataLoader to its own name. Rebinding ``textloader`` would
  # replace the factory function and break any re-run of this cell.
  loader = textloader(root_path = images_path,
                      text_path = texts_path)

  # Smoke test: print the shape of every batch.
  for image, caption in loader:
    print(f">>>> Image shape: {image.shape}\t Caption shape: {caption.shape}")