In [1]:
import ipywidgets
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.transforms import InterpolationMode
from torch.utils.data import Dataset,DataLoader
import cv2
import os
import torch
import numpy as np
from PIL import Image

# create sequence

In [2]:
import gzip
import pickle
import os
import sys
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os

# True only when the interpreter running this cell is Python 2.
PY2 = sys.version_info[0] == 2

if not PY2:
    from urllib.request import urlretrieve

    def pickle_load(f, encoding):
        """Unpickle one object from the binary file `f`, passing `encoding` to pickle."""
        return pickle.load(f, encoding=encoding)
else:
    from urllib import urlretrieve

    def pickle_load(f, encoding):
        """Unpickle one object from the binary file `f`; `encoding` is accepted but ignored."""
        return pickle.load(f)

def _load_data(url, filename):
    """Return the unpickled content of the gzip file `filename`.

    If `filename` does not exist yet it is first downloaded from `url`
    (the parent directory is created when necessary).

    Parameters:
        url: address to download from; only used when `filename` is missing.
        filename: path of the gzip-compressed pickle on disk.

    Raises:
        IOError: `filename` is missing and no `url` was given.
    """
    if not os.path.exists(filename):
        if url is None:
            # Without this guard urlretrieve(None, ...) fails with an obscure error.
            raise IOError(
                "%r does not exist and no url was given to download it from"
                % (filename,))
        target_dir = os.path.dirname(filename)
        if target_dir and not os.path.isdir(target_dir):
            os.makedirs(target_dir)
        print("Downloading MNIST dataset")
        urlretrieve(url, filename)

    # NOTE(review): unpickling can execute arbitrary code; only point this at
    # files / URLs you trust.
    with gzip.open(filename, 'rb') as f:
        return pickle_load(f, encoding='latin-1')


def load_data(filename, url=None):
    """Get data with labels, split into training, validation and test set.

    Parameters:
        filename: path of the gzip-compressed pickle; downloaded from `url`
            when it does not exist.
        url: optional download address, see `_load_data`.

    Returns:
        dict with the arrays `X_train`, `y_train`, `X_valid`, `y_valid`,
        `X_test`, `y_test`, the three `num_examples_*` counts (first axis of
        each X array), `input_dim` (second axis of `X_train`) and a fixed
        `output_dim` of 10.

    Raises:
        IOError: `filename` is missing and `url` is None.
    """
    data = _load_data(url, filename)
    # The pickle is indexed as three (inputs, labels) pairs: train, valid, test.
    X_train, y_train = data[0]
    X_valid, y_valid = data[1]
    X_test, y_test = data[2]

    return dict(
        X_train=X_train,
        y_train=y_train,
        X_valid=X_valid,
        y_valid=y_valid,
        X_test=X_test,
        y_test=y_test,
        num_examples_train=X_train.shape[0],
        num_examples_valid=X_valid.shape[0],
        num_examples_test=X_test.shape[0],
        input_dim=X_train.shape[1],
        output_dim=10)

In [3]:
# Parameters:
#   ORG_SHP:            shape of one digit image
#   OUT_SHP:            shape of the cluttered-MNIST output image
#   NUM_DISTORTIONS:    number of distortion patches put into each output image
#   dist_size:          shape of each distortion patch
#   NUM_DISTORTIONS_DB: length of the `distortions` list built below
ORG_SHP = [28, 28]
OUT_SHP = [100, 100]
NUM_DISTORTIONS = 6
dist_size = (9, 9)
NUM_DISTORTIONS_DB = 100000

mnist_data = load_data('data/mnist.pkl.gz')
np.random.seed(1234)
# mnist.pkl.gz is expected to contain X_train (50000), X_valid (10000) and
# X_test (10000), each image flattened to 784 values; input_dim is 784 and
# output_dim is 10.

# Pool of digits the distortion patches are cut from: train + valid images,
# reshaped to (N, 28, 28) -- N = 60000 with the counts listed above.
all_digits = np.concatenate((mnist_data['X_train'], mnist_data['X_valid']), axis=0)
all_digits = all_digits.reshape([-1] + ORG_SHP)
num_digits = all_digits.shape[0]

# Build the patch database: NUM_DISTORTIONS_DB crops of shape dist_size, each
# taken at a random position of a randomly chosen digit.
distortions = []
for i in range(NUM_DISTORTIONS_DB):
    # Draw order (digit, column, row) determines the seeded result; keep it.
    rand_digit = np.random.randint(num_digits)
    rand_x = np.random.randint(ORG_SHP[1] - dist_size[1])
    rand_y = np.random.randint(ORG_SHP[0] - dist_size[0])

    row_window = slice(rand_y, rand_y + dist_size[0])
    col_window = slice(rand_x, rand_x + dist_size[1])
    distortion = all_digits[rand_digit][row_window, col_window]
    assert distortion.shape == dist_size
    distortions.append(distortion)
print("Created distortions")

Created distortions


In [4]:
def create_sample1(x, output_shp, num_distortions=NUM_DISTORTIONS):
    '''Paste `x` at a randomly jittered, roughly centred position of a blank canvas.

    Parameters:
    x (np.array): 2-D image to place on the canvas.
    output_shp: (rows, cols) of the returned canvas.
    num_distortions: number of patches `add_distortions` puts on top of the
        canvas; values <= 0 skip that step.

    Returns:
    np.array of shape `output_shp` (zeros outside the pasted image, before
    any distortions are added).

    Raises:
    ValueError: `x` does not fit into `output_shp`.
    '''
    n_rows, n_cols = x.shape
    out_rows, out_cols = int(output_shp[0]), int(output_shp[1])
    if n_rows > out_rows or n_cols > out_cols:
        raise ValueError(
            "image of shape %r does not fit into output shape %r"
            % ((n_rows, n_cols), (out_rows, out_cols)))

    # Top-left corner that centres x: columns pair with output_shp[1], rows
    # with output_shp[0] (identical to the previous formula for square shapes).
    x_center = (out_cols - n_cols) // 2
    y_center = (out_rows - n_rows) // 2

    # Jitter each axis by up to two thirds of its centring offset.
    x_start = x_center + _draw_offset(int(-2 * x_center / 3), int(2 * x_center / 3))
    y_start = y_center + _draw_offset(int(-2 * y_center / 3), int(2 * y_center / 3))
    # The result of this third draw is not used (it was an angle multiplied by
    # zero); the draw itself is kept so seeded runs yield the same images.
    _draw_offset(int(-n_cols * 0.5), int(n_cols * 0.5))

    # Keep the pasted window inside the canvas on both axes.
    x_start = min(max(x_start, 0), out_cols - n_cols)
    y_start = min(max(y_start, 0), out_rows - n_rows)

    output = np.zeros(output_shp)
    output[y_start:y_start + n_rows, x_start:x_start + n_cols] = x

    if num_distortions > 0:
        output = add_distortions(output, num_distortions)
    return output


def _draw_offset(low, high):
    '''Draw one integer from range(low, high) with np.random; return 0 if the range is empty.'''
    candidates = range(low, high)
    if len(candidates) == 0:
        # np.random.choice raises on an empty population.
        return 0
    return int(np.random.choice(candidates))



def add_distortions(digits, num_distortions):
    '''Overlay `num_distortions` random patches from `distortions` onto `digits`.

    Parameters:
    digits (np.array): 2-D canvas; it is not modified.
    num_distortions: how many patches to place (patches may overlap).

    Returns:
    np.array with the shape of `digits`: patches plus `digits`, clipped to [0, 1].
    '''
    canvas = np.zeros_like(digits)
    # Bounds come from the actual canvas / patch / list sizes instead of
    # module constants, so any canvas shape is handled correctly.
    canvas_h, canvas_w = digits.shape[0], digits.shape[1]
    for _ in range(num_distortions):
        patch = distortions[np.random.randint(len(distortions))]
        patch_h, patch_w = patch.shape[0], patch.shape[1]
        # Column is drawn before row (same draw order as before). max(..., 1)
        # keeps randint valid when the patch is exactly as large as the canvas.
        left = np.random.randint(max(canvas_w - patch_w, 1))
        top = np.random.randint(max(canvas_h - patch_h, 1))
        canvas[top:top + patch_h, left:left + patch_w] = patch
    canvas += digits

    return np.clip(canvas, 0, 1)

In [5]:
## Data Augmentation
class DataAugmentation:
    """Callable that turns one PIL image into two augmented tensor views.

    Three pipelines are built (`global_1`, `global_2`, `local`); `__call__`
    currently applies only `global_1`, once to each of two independently
    cluttered copies of the input image.
    """

    def __init__(self, global_crops_scale=(0.5, 1.0), n_local_crops=2, output_size=224):
        # Stored only; nothing in this class reads it.
        self.n_local_crops = n_local_crops

        def blur_with_probability(p):
            # 3x3 Gaussian blur that is applied with probability p.
            return transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2))], p=p)

        def random_resized_crop():
            # A fresh crop transform for each global pipeline.
            return transforms.RandomResizedCrop(
                output_size,
                scale=global_crops_scale,
                interpolation=InterpolationMode.BICUBIC)

        # Constructed but not part of any pipeline below.
        colorjitter = transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.2)
        shift = transforms.RandomAffine(degrees=6, translate=(0.2, 0.1))

        # The same rotation / normalisation objects are shared by all pipelines.
        small_rotation = transforms.RandomRotation(degrees=(6))
        to_normalized_tensor = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.13,), (0.3,)),
        ])
        center_crop = transforms.CenterCrop(56)
        upscale = transforms.Resize(
            (output_size, output_size), interpolation=InterpolationMode.BICUBIC)

        self.global_1 = transforms.Compose([
            random_resized_crop(),
            small_rotation,
            blur_with_probability(0.1),
            to_normalized_tensor,
        ])
        self.global_2 = transforms.Compose([
            random_resized_crop(),
            small_rotation,
            blur_with_probability(1.0),
            to_normalized_tensor,
        ])
        self.local = transforms.Compose([
            center_crop,
            upscale,
            small_rotation,
            blur_with_probability(0.5),
            to_normalized_tensor,
        ])

    def __call__(self, image):
        """Return a list with two augmented versions of `image`.

        image: PIL.Image; it is converted to greyscale and scaled to [0, 1].
        Each view is an independent `create_sample1` canvas of shape OUT_SHP,
        converted to an RGB PIL image and passed through `self.global_1`.
        """
        gray = (np.asarray(image.convert('L'))) / 255.0

        # Build both cluttered canvases first, then transform them, so the
        # random draws happen in the same order as they always did.
        views = []
        for _ in range(2):
            cluttered = create_sample1(gray, OUT_SHP)
            as_uint8 = (cluttered * 255.0).astype(np.uint8)
            views.append(Image.fromarray(cv2.cvtColor(as_uint8, cv2.COLOR_GRAY2RGB)))

        return [self.global_1(view) for view in views]

In [7]:
# Folder the images are read from.
PATH = 'data/MNIST'
# Every sample fetched from `dataset` is passed through this augmentation.
transform = DataAugmentation()
dataset = ImageFolder(root=PATH, transform=transform)


In [8]:
import numpy as np
def to_numpy(t):
    """Clamp tensor `t` to [0, 1], move its first axis to the end and return a numpy array."""
    clamped = torch.clamp(t, min=0, max=1)
    channels_last = clamped.permute(1, 2, 0)
    return channels_last.numpy()

In [9]:
#GUI
@ipywidgets.interact
def _(i=ipywidgets.IntSlider(min=0,max=len(dataset)-1,continuous_update=False),
    seed=ipywidgets.IntSlider(min=0,max=50,continuous_update=False),):
    '''Show sample `i` of `dataset` next to its augmented views.

    i: index into `dataset`, from 0 to len(dataset)-1
    seed: value from 0 to 50 used to seed both torch and numpy before the
        sample is fetched, so one (i, seed) pair always gives the same views.
        Note that this resets the global numpy and torch random state.
    '''
    torch.manual_seed(seed)
    # The clutter / placement code draws from np.random, which
    # torch.manual_seed does not touch.
    np.random.seed(seed)
    all_crops, _label = dataset[i]
    print("number of crops:",len(all_crops))

    orig_img = np.array(Image.open(dataset.samples[i][0]))
    fig, ax = plt.subplots()
    ax.imshow(orig_img)
    ax.set_title('original_image')
    ax.axis('off')

    # One panel per returned view; squeeze=False keeps `axs` two-dimensional
    # even when there is only a single view.
    fig, axs = plt.subplots(1, len(all_crops), figsize=(10,10), squeeze=False)
    for view_idx, (ax, crop) in enumerate(zip(axs[0], all_crops)):
        ax.imshow(to_numpy(crop))
        ax.set_title('view_%d' % view_idx)
        ax.axis('off')
    fig.tight_layout()

interactive(children=(IntSlider(value=0, continuous_update=False, description='i', max=59999), IntSlider(value…