# Computer Vision - 1st Assignment

## Names: 
 - Idan Dunsky 
 - Yaniv Kaveh-Shtul 


### imports and downloads

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import Dataset
%matplotlib inline
from PIL import Image


# download the MNIST dataset; root='' resolves relative to the current working directory.
# train=True is torchvision's default (training split) and is spelled out here for clarity.
mnist_data = torchvision.datasets.MNIST(
    root='',
    train=True,
    download=True,
)

In [None]:
import albumentations as A

from albumentations import (
    HorizontalFlip, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    GaussNoise, MotionBlur, MedianBlur, RandomBrightnessContrast, Flip, OneOf, Compose, Rotate, Affine, CenterCrop   
)

### some helper functions

In [None]:
#convert PIL image to numpy.array
def pil2np(img):
    '''
    Convert a PIL image to a numpy array.

    input:
        img - PIL.Image.Image, or any other object
    output:
        np.array(img) if img is a PIL image, otherwise img itself (returned as-is, not copied)
    '''
    if not isinstance(img, Image.Image):
        return img
    return np.array(img)

In [None]:
def get_single_sample(dataset_object: Dataset, sample_index = None, desired_label = None):
    '''
        Get a single sample from a given dataset.

        If sample_index is given, that sample is tried first. Otherwise (or if its label does
        not match desired_label) indices are drawn uniformly at random with np.random.randint
        until a sample whose label equals desired_label is found. With desired_label=None the
        first tried sample is returned.

        NOTE: if desired_label never occurs in the dataset this loops forever.

        input:
            dataset_object - Dataset whose items are indexable as (img, label, ...),
            sample_index - int or numpy integer, optional index to try first,
            desired_label - int, optional label the returned sample must have

        output:
            img - the sample image (converted to np.array if it is a PIL image),
            sample_index - index of the returned sample in the dataset,
            label - label of the returned sample.

        raises:
            TypeError - if sample_index is provided and is not an integer.
    '''
    while True:
        if sample_index is None:
            sample_index = np.random.randint(0, len(dataset_object))

        # numpy integers (e.g. from np.random.choice) are valid indices as well
        if not isinstance(sample_index, (int, np.integer)):
            raise TypeError(f'expected type of sample index (if provided )is int, yet {type(sample_index) = }')

        # index the dataset only once per try - a lookup may be expensive
        # (image decoding / dataset-level transforms) and the result is reused below
        sample = dataset_object[sample_index]
        img2augment, label = sample[0], sample[1]

        if desired_label is None or label == desired_label:
            return pil2np(img2augment), sample_index, label

        # wrong label - discard this index and draw a new random one
        sample_index = None

### print 5 examples for each class

In [None]:
num_classes = len(mnist_data.classes)
# integer label of every class name, in the dataset's class order
classes = [mnist_data.class_to_idx[class_name] for class_name in mnist_data.classes]

# one figure per class, each showing a row of 5 randomly drawn samples of that class
for label in classes:
    plt.figure(figsize=(25, 5))
    for i in range(5):
        img_np, sample_index, sample_label = get_single_sample(
            mnist_data,
            desired_label=label,
        )
        plt.subplot(1, 5, i + 1)  # subplot positions are 1-based
        plt.title(f'sample #{sample_index}, digit {sample_label}')
        plt.imshow(img_np, cmap='gray')
    plt.show()


## Creating a 3 digit number

In [None]:
def create_three_digit_img(imgs : np.array, labels : list[int]) -> tuple:
    '''
    Concatenate digit images side by side into one number image and build its numeric label.

    Designed for 3 images (hundreds, tens, units), but works for any number of digits.

    input:
        imgs - sequence of 2D digit images (np.array, PIL images or other array-likes),
               ordered from the most significant digit to the least significant one.
               All images must have the same height; their widths may differ.
        labels - list[int], the digit shown in each image, in the same order as imgs

    output:
        (img, label) - (np.array , int): the horizontally concatenated image and the number
                       whose decimal digits are `labels` (e.g. [1, 0, 7] -> 107)

    raises:
        ValueError - if imgs and labels are empty or have different lengths
    '''
    if len(labels) == 0 or len(imgs) != len(labels):
        raise ValueError(f'expected the same (non-zero) number of images and labels, got {len(imgs)} images and {len(labels)} labels')

    # concatenate directly: np.concatenate already copies into a new array, so there is no need
    # to stack the images first (stacking would also force all widths to be equal)
    new_img = np.concatenate([np.asarray(img) for img in imgs], axis=1)

    # positional decimal value: [1, 0, 7] -> ((1 * 10) + 0) * 10 + 7 = 107
    label = 0
    for digit in labels:
        label = 10 * label + digit

    return new_img, label



def plot_three_digit_image(img : np.array, label : int) -> None:
    '''
    Show one number image in its own 9x9 figure, titled with its label.

    input:
        img - np.array, the image to display (drawn with the 'gray' colormap)
        label - int, shown in the title as "Number - <label>"

    output:
        None - the figure is displayed with plt.show()
    '''
    title = f"Number - {label}"

    plt.figure(figsize=(9, 9))
    plt.title(title)
    plt.imshow(img, cmap='gray')
    plt.show()

### Augmentation function

In [None]:
def augment_digit(digit_imgs : list[np.array], transforms : Compose, num_of_aug : int) -> list[np.array]:
    '''
        Run `transforms` num_of_aug times on every image in digit_imgs.

        The output is grouped by source image: first all num_of_aug results for digit_imgs[0],
        then those for digit_imgs[1], and so on (len(digit_imgs) * num_of_aug images in total).
        For random transforms the copies of a single image can differ from each other.

        input:
            digit_imgs - list[np.array], images of the same digit
            transforms - Compose, called as transforms(image=img) and read via its 'image' key
            num_of_aug - int, number of augmented copies per image

        output:
            list[np.array] - the augmented images (digit_imgs itself is not modified)
    '''
    return [
        transforms(image=source_img)['image']
        for source_img in digit_imgs
        for _ in range(num_of_aug)
    ]

## Creating the MNIST101 DataSet

In [None]:
def create_mnist101(mnist_data : Dataset, transforms : A.Compose, num_samples : int = 3, num_of_aug : int = 5) -> list:
    '''
    Creates mnist101 dataset - a full dataset of three digit images within the range of 000-100.

    For every number 000..100 and every digit position, num_samples random images of that digit
    are drawn from mnist_data, and each of them is augmented num_of_aug times with `transforms`.
    This gives num_samples * (1 + num_of_aug) images per position. All combinations of the
    three positions are concatenated, i.e. (num_samples * (1 + num_of_aug)) ** 3 images per number.
    With the defaults this is 18 ** 3 = 5832 images per number - a total of 101 * 5832 = 589032 images.

    input:
        mnist_data - Dataset of single-digit (img, label) samples,
        transforms - Albumentations.Compose used to augment the single digits,
        num_samples - int, raw samples drawn per digit position (default 3),
        num_of_aug - int, augmentations generated per raw sample (default 5).

    output:
        mnist101 - list[(img, label)], ordered by label from 0 to 100.
    '''
    mnist101 = []

    def _digit_pool(digit):
        # num_samples fresh images of `digit`, followed by num_of_aug augmentations of each of them
        raw_imgs = [get_single_sample(mnist_data, desired_label = digit)[0] for _ in range(num_samples)]
        return raw_imgs + augment_digit(raw_imgs, transforms, num_of_aug)

    for number in range(101):
        hundreds, tens, units = number // 100, (number // 10) % 10, number % 10

        # new samples for every number and position - for diversity
        hundreds_imgs = _digit_pool(hundreds)
        tens_imgs = _digit_pool(tens)
        units_imgs = _digit_pool(units)

        # all combinations of images for the specific number
        mnist101.extend(
            create_three_digit_img([h_img, t_img, u_img], [hundreds, tens, units])
            for h_img in hundreds_imgs
            for t_img in tens_imgs
            for u_img in units_imgs
        )

        print(f"finished creating {hundreds}{tens}{units}, total size of dataset: {len(mnist101)}")

    return mnist101
                                
               

In [None]:
# define augmentation pipeline.
# This first pipeline deliberately also contains augmentations (flips, transpose, 90-degree
# rotations) that are discussed below as inappropriate for digits.
# NOTE(review): `Flip` may be deprecated/removed in newer albumentations releases - confirm
# against the installed version.
mnist_transforms = A.Compose(
    [
        OneOf(
            [
                Blur(blur_limit=5),
                Rotate(limit=45),
                HorizontalFlip(),
                CLAHE(),
                RandomRotate90(),
                Transpose(),
                ShiftScaleRotate(),
                OpticalDistortion(),
                GridDistortion(),
                HueSaturationValue(),
                GaussNoise(),
                MotionBlur(),
                MedianBlur(),
                RandomBrightnessContrast(),
                Affine(scale=0.5),
                Flip(),
            ],
            p=1,  # the OneOf block is always applied
        ),
    ]
)

# create MNIST101 (101 numbers x 5832 images each with the default settings;
# assuming 28x28 uint8 MNIST digits this is roughly 1.4 GB of image data)
mnist101 = create_mnist101(mnist_data, mnist_transforms)

### let's check the data

In [None]:
# plot 10 randomly chosen (img, label) pairs from the generated list (repeats are possible)
for _ in range(10):
    test_img = mnist101[np.random.choice(len(mnist101))]
    plot_three_digit_image(*test_img)

### Explaining the augmentations

The augmentations used in this notebook are:

- Flip
- Rotate
- Blur
- HorizontalFlip
- CLAHE
- Transpose
- RandomRotate90
- ShiftScaleRotate
- OpticalDistortion
- GridDistortion
- HueSaturationValue
- GaussNoise
- MotionBlur
- MedianBlur
- RandomBrightnessContrast
- Affine (scaling)

The use of augmentations such as Flip, HorizontalFlip, Transpose, RandomRotate90 and Rotate (by more than 45 degrees) is not appropriate for this specific dataset, because they can cause ambiguity by making digits look like other digits, such as 6 and 9, or 2 and 5, when flipped or rotated.

In [None]:
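# NOTE: MNIST101 is generated only from the MNIST split it is given - torchvision's MNIST loads the training split by default (train=True) and the test split with train=False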

class Mnist101(Dataset):
    '''
    MNIST101 - a Dataset of three digit images of the numbers 000-100 and their integer labels.

    The data is either generated from an MNIST Dataset (single digits + augmentations) or taken
    from an already generated sequence of (img, label) pairs. It is shuffled once, at construction.
    Items are (img, label) pairs; img is returned as a np.array copy of the stored image.
    '''

    def __init__(self, train_set, transforms = None):
        '''
        input:
            train_set - an MNIST Dataset to generate MNIST101 from, or an already generated
                        sequence of (img, label) pairs (e.g. the output of create_mnist101).
                        A given sequence is copied before shuffling, so the caller's list keeps
                        its original order.
            transforms - Albumentations.Compose used to augment single digits while generating
                         from MNIST; it is NOT applied in __getitem__. None disables augmentation.
        '''
        super().__init__()
        self.transforms = transforms

        # isinstance (rather than type ==) so MNIST subclasses are handled the same way
        if isinstance(train_set, torchvision.datasets.mnist.MNIST):
            train_set = self.__create101__(train_set)
        else:
            # copy: np.random.shuffle works in place and must not reorder the caller's data
            train_set = list(train_set)

        np.random.shuffle(train_set)
        self.xs = [item[0] for item in train_set]
        self.ys = [item[1] for item in train_set]

    def __getitem__(self, idx):
        # return a copy so callers cannot modify the stored image in place
        x = np.array(self.xs[idx])
        y = self.ys[idx]
        return x, y

    def __len__(self):
        return len(self.xs)

    def __augment_digit__(self, digit_imgs : list[np.array], num_of_aug : int) -> list[np.array]:
        '''
        Augment every image in digit_imgs num_of_aug times with self.transforms.

        input:
            digit_imgs - list[np.array], images of the same digit
            num_of_aug - int, number of augmented copies per image

        output:
            aug_imgs - list[np.array], grouped by source image; [] if self.transforms is None
        '''
        if self.transforms is None:
            return []
        return [self.transforms(image=img)['image'] for img in digit_imgs for _ in range(num_of_aug)]

    def __create101__(self, mnist10_data : Dataset, num_samples : int = 3, num_of_aug : int = 5) -> list:
        '''
        Generate MNIST101 from an MNIST10 Dataset + augmentations (self.transforms).

        For every number 000..100 and every digit position, num_samples fresh images of the digit
        are drawn and augmented num_of_aug times each; all combinations of the three positions
        are concatenated into (img, label) pairs.

        input:
            mnist10_data - Dataset of single-digit (img, label) samples,
            num_samples - int, raw samples drawn per digit position (default 3),
            num_of_aug - int, augmentations generated per raw sample (default 5).

        output:
            mnist101 - list[(img, label)], ordered by label from 0 to 100.
        '''
        mnist101 = []

        for number in range(101):
            digits = [number // 100, (number // 10) % 10, number % 10]

            # per position: fresh images of the digit (for diversity) followed by their augmentations
            pools = []
            for digit in digits:
                raw_imgs = [get_single_sample(mnist10_data, desired_label = digit)[0] for _ in range(num_samples)]
                pools.append(raw_imgs + self.__augment_digit__(raw_imgs, num_of_aug))
            hundreds_pool, tens_pool, units_pool = pools

            # all combinations of images for the specific number
            mnist101.extend(
                create_three_digit_img([h_img, t_img, u_img], digits)
                for h_img in hundreds_pool
                for t_img in tens_pool
                for u_img in units_pool
            )

        print(f"DataSet initialized successfully, with total size of : {len(mnist101)}")
        return mnist101
    
    
    

### Let's test our Dataset 

We will create MNIST101 using our custom Dataset class, from the original MNIST10 Dataset.

After testing different augmentations, some of which were inappropriate for this specific dataset, we now use only the appropriate ones.

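* In the assignment we were asked to use both the MNIST10 train and test sets. Note that this notebook builds MNIST101 only from the split loaded at the top: `torchvision.datasets.MNIST` loads the training split by default (`train=True`), and the test split can be loaded separately with `train=False`.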

In [None]:
# define augmentation pipeline - only augmentations that keep a digit recognizable:
# no flips, no transpose, no 90-degree rotations, and Rotate limited to 45 degrees
transforms = A.Compose(
    [
        OneOf(
            [
                Blur(blur_limit=5),
                Rotate(limit=45),
                CLAHE(),
                ShiftScaleRotate(),
                OpticalDistortion(),
                GridDistortion(),
                GaussNoise(),
                MotionBlur(),
                MedianBlur(),
                RandomBrightnessContrast(),
                Affine(scale=0.5),
            ],
            p=1,  # the OneOf block is always applied
        ),
    ]
)

# generate MNIST101 from the single-digit MNIST dataset using the pipeline above
mnist101_class = Mnist101(mnist_data, transforms)

In [None]:
# plot 5 randomly chosen items of the Dataset (repeats are possible)
for _ in range(5):
    test_image, test_label = mnist101_class[np.random.choice(len(mnist101_class))]
    plot_three_digit_image(
        test_image,
        test_label,
    )

## Summary

This project extends the classic MNIST dataset to create a more complex dataset of three-digit numbers (MNIST101). The main focus is on dataset generation and augmentation.

The project demonstrates skills in dataset creation, data augmentation, and custom dataset implementation in PyTorch.

This new dataset could potentially be used for various computer vision tasks involving multi-digit numbers.