# Import libraries

In [21]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt

# Others
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

### TensorFlow dependencies ###
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam

### Some constants ###
# Root folder of the training images, organised as one sub-folder per class.
data_dir = "../data/DOG_CAT_SMALL/train"
# Number of images per batch (same value as DataLoader's default batch_size).
batch_size = 32

# Data Loader
The data loader for ViT should be able to:
- [x] Load images and resize them to 64 x 64 pixels.
- [ ] Divide each image into 16 patches of 16x16 pixels (four patches along each dimension).
- [ ] Flatten the patches -> Final dimension = (batch_size, 16, 16*16*3)

In [45]:
class DataLoader:
    """Build batched ``tf.data`` pipelines from a folder-per-class image tree.

    The loader expects ``<directory>/<class_name>/<image>.jpg``. On
    construction it scans the directory, integer-encodes the class names and
    splits the samples into a training and a validation subset. The actual
    ``tf.data.Dataset`` objects are created lazily and cached by
    ``get_train_dataset`` / ``get_val_dataset``.

    Args:
        directory: Root folder that contains one sub-folder per class.
        batch_size: Number of images per batch.
        split_ratio: Fraction of the samples reserved for validation
            (forwarded to ``train_test_split`` as ``test_size``).
        img_size: Images are resized to ``img_size x img_size``.
        shuffle: Whether the datasets are shuffled before batching.

    Raises:
        ValueError: If no ``*.jpg`` file is found below ``directory``.
    """

    def __init__(self, directory, batch_size=32, split_ratio=0.2,
                 img_size=64, shuffle=True):
        self.directory = directory
        self.shuffle = shuffle
        self.batch_size = batch_size
        # Honour the caller's value (it used to be hard-coded to 0.2).
        self.split_ratio = split_ratio
        self.img_size = img_size
        self.patch_size = img_size // 4
        self.n_classes = 10  # Placeholder; overwritten by parse_dataset()

        # Datasets are built on first request and then reused.
        self.train_dataset, self.val_dataset = None, None
        self.train_paths, self.train_labels = None, None
        self.val_paths, self.val_labels = None, None

        self.parse_dataset()

    def map_fn(self, img, img_size):
        """Resize a decoded image and scale its pixel values to [-1, 1].

        Args:
            img: Decoded image tensor with pixel values in [0, 255].
            img_size: Target height and width.

        Returns:
            The resized image, clipped to [0, 255] and then mapped to [-1, 1].
        """
        img = tf.image.resize(img, [img_size, img_size])
        # Clip so the scaling below cannot leave the [-1, 1] range.
        img = tf.clip_by_value(img, 0.0, 255.0)

        # [0, 255] -> [-1, 1]
        img = img / 127.5 - 1

        return img

    def parse_fn_with_label(self, path, label):
        """Load one image from ``path`` and one-hot encode its label.

        Args:
            path: Scalar string tensor holding the image file path.
            label: Scalar integer class index.

        Returns:
            A tuple ``(image, one_hot_label)``; the image is preprocessed by
            ``map_fn`` and the label has depth ``self.n_classes``.
        """
        img = tf.io.read_file(path)
        # The dataset is collected from "*.jpg" files, so use the JPEG decoder
        # and force 3 colour channels.
        img = tf.io.decode_jpeg(img, channels=3)

        img = self.map_fn(img, self.img_size)
        label = tf.one_hot(label, depth=self.n_classes)

        return img, label

    @staticmethod
    def _class_name_from_path(path):
        """Return the name of the folder that directly contains ``path``."""
        # os.path handles the platform's separator, unlike split('/').
        return os.path.basename(os.path.dirname(path))

    def parse_dataset(self):
        """Scan ``self.directory`` and create the train/validation split.

        Populates ``all_paths``, ``all_labels``, ``n_classes`` and the
        ``train_*`` / ``val_*`` path and label arrays.

        Raises:
            ValueError: If no image matches ``<directory>/*/*.jpg``.
        """
        # Use the directory given to the constructor, not a module-level global.
        pattern = os.path.join(self.directory, "*", "*.jpg")
        found = glob.glob(pattern)
        if not found:
            # Fail early with a clear message instead of an obscure error
            # raised from inside train_test_split.
            raise ValueError(f"No images found matching pattern: {pattern}")

        class_names = np.array([self._class_name_from_path(p) for p in found])
        img_paths = np.array(found)

        # Class names -> integer ids 0..n_classes-1
        img_labels = LabelEncoder().fit_transform(class_names)

        self.n_classes = len(np.unique(img_labels))
        self.all_paths = img_paths
        self.all_labels = img_labels

        (self.train_paths, self.val_paths,
         self.train_labels, self.val_labels) = train_test_split(
            self.all_paths, self.all_labels, test_size=self.split_ratio)

    def _build_dataset(self, paths, labels):
        """Create the shuffled / decoded / batched pipeline for one split."""
        dataset = tf.data.Dataset.from_tensor_slices((paths, labels))

        if self.shuffle:
            # The buffer only holds (path, label) pairs, so covering the whole
            # split is cheap and gives a full shuffle for any dataset size.
            dataset = dataset.shuffle(max(len(paths), 1))

        dataset = dataset.map(self.parse_fn_with_label)
        dataset = dataset.batch(self.batch_size)
        return dataset.prefetch(1)

    def get_train_dataset(self):
        """Return the training dataset, building it on first use."""
        if self.train_dataset is None:
            self.train_dataset = self._build_dataset(
                self.train_paths, self.train_labels)

        return self.train_dataset

    def get_val_dataset(self):
        """Return the validation dataset, building it on first use."""
        if self.val_dataset is None:
            self.val_dataset = self._build_dataset(
                self.val_paths, self.val_labels)

        return self.val_dataset
            

In [46]:
# Create a loader and get the validation dataset
loader = DataLoader(data_dir)
dataset = loader.get_val_dataset()

In [51]:
# Extract a sample batch and split every image into flattened patches
patch_size = loader.patch_size  # img_size // 4 -> 16 for the default 64x64 images
images, labels = next(iter(dataset))

# Non-overlapping patches: the stride equals the patch size and "VALID"
# padding drops nothing because the image size is a multiple of patch_size.
patches = tf.image.extract_patches(
    images=images,
    sizes=[1, patch_size, patch_size, 1],
    strides=[1, patch_size, patch_size, 1],
    rates=[1, 1, 1, 1],
    padding="VALID",
)
# Last axis holds one flattened patch: patch_size * patch_size * channels
patch_dims = patches.shape[-1]
# Reshape with the real batch dimension: the fetched batch can hold fewer
# images than the global `batch_size` (small dataset or last batch).
patches = tf.reshape(patches, [tf.shape(images)[0], -1, patch_dims])

print(patch_dims)
print(images.shape)
print(patches.shape)

768
(32, 64, 64, 3)
(32, 16, 768)
