In [1]:
import ipywidgets
import matplotlib.pyplot as plt
import torch
import torchvision

In [2]:
# Data Augmentation
import torch
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
class DataAugmentation:
    '''Multi-crop augmentation: two global views plus several local views of one image.

    Each call returns ``[global_1, global_2, local_1, ..., local_n]`` as
    normalized tensors. The two global pipelines differ in blur probability
    (1.0 vs 0.1), and ``global_2`` additionally solarizes with probability 0.2.
    Local views are blurred with probability 0.5.

    Args:
        global_crops_scale: (min, max) fraction of the input area kept by a
            global crop (``RandomResizedCrop`` ``scale``).
        local_crops_scale: (min, max) area fraction kept by a local crop.
            Both bounds are area fractions and should lie in (0, 1]. The old
            default upper bound of 4 looked like a typo for 0.4: it made most
            sampled crop areas exceed the image, so they were rejected.
        n_local_crops: number of local views appended after the two global views.
        output_size: side length in pixels of each global view.
        local_output_size: side length in pixels of each local view. Defaults
            to the previously hard-coded 224.
        blur_kernel_size: odd Gaussian-blur kernel size. The old hard-coded
            size of 1 gives a one-tap kernel, so the blur never changed the image.
    '''

    def __init__(self, global_crops_scale=(0.4, 1), local_crops_scale=(0.05, 0.4),
                 n_local_crops=2, output_size=112, local_output_size=224,
                 blur_kernel_size=9):
        self.n_local_crops = n_local_crops

        def random_gaussian_blur(p):
            # Blur with probability p; sigma is re-drawn from (0.1, 2) on every call.
            # NOTE(review): 9 is a reviewer-chosen default (about +/-2 sigma at the
            # largest sigma=2); tune if needed.
            return transforms.RandomApply(
                [transforms.GaussianBlur(kernel_size=blur_kernel_size, sigma=(0.1, 2))],
                p=p,
            )

        flip_and_rotation = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(degrees=10),  # uniform angle in [-10, 10]
        ])
        # (0.1307, 0.3081) are the MNIST mean/std. Any de-normalization for
        # display must use these same values.
        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        self.global_1 = transforms.Compose([
            transforms.RandomResizedCrop(output_size, scale=global_crops_scale,
                                         interpolation=InterpolationMode.BICUBIC),
            flip_and_rotation,
            random_gaussian_blur(1.0),
            normalize,
        ])
        self.global_2 = transforms.Compose([
            transforms.RandomResizedCrop(output_size, scale=global_crops_scale,
                                         interpolation=InterpolationMode.BICUBIC),
            flip_and_rotation,
            random_gaussian_blur(0.1),
            transforms.RandomSolarize(170, p=0.2),
            normalize,
        ])
        # NOTE(review): the local views (224 px by default) are larger than the
        # global views (112 px). Multi-crop setups usually make local views
        # smaller; confirm this is intended.
        self.local = transforms.Compose([
            transforms.RandomResizedCrop(local_output_size, scale=local_crops_scale,
                                         interpolation=InterpolationMode.BICUBIC),
            flip_and_rotation,
            random_gaussian_blur(0.5),
            normalize,
        ])

    def __call__(self, image):
        '''Return every augmented view of ``image``.

        Args:
            image: the input image. Presumably a PIL image as yielded by
                torchvision's MNIST dataset; TODO confirm for other callers.

        Returns:
            list of torch.Tensor: ``[global_1, global_2]`` followed by
            ``n_local_crops`` local views.
        '''
        views = [self.global_1(image), self.global_2(image)]
        views.extend(self.local(image) for _ in range(self.n_local_crops))
        return views

In [7]:
# Attach the multi-crop transform to the MNIST training split, so each
# dataset[i] yields (list_of_crops, label).
transform=DataAugmentation(n_local_crops=2)
DOWNLOAD_PATH = './data'
# download=True does not re-download when the files already exist under
# DOWNLOAD_PATH. On a fresh machine it fetches them, instead of raising
# "Dataset not found" and breaking Restart & Run All.
dataset = torchvision.datasets.MNIST(DOWNLOAD_PATH, train=True, download=True, transform=transform)

In [4]:
def to_numpy(t, mean=0.1307, std=0.3081):
    '''Convert a normalized image tensor back to a displayable NumPy array.

    Inverts ``transforms.Normalize(mean, std)`` and clips to [0, 1]. It then
    moves the channel axis last so the result can be passed to ``plt.imshow``.
    The defaults are the MNIST statistics the augmentation pipeline normalizes
    with. The previous 0.224/0.45 did not match them, so the displayed
    intensities were wrong.

    Args:
        t: image tensor laid out as (C, H, W).
        mean: mean that was subtracted during normalization.
        std: standard deviation the image was divided by.

    Returns:
        numpy.ndarray of shape (H, W, C) with values in [0, 1].
    '''
    restored = t.detach().cpu() * std + mean
    return torch.clip(restored, 0, 1).permute(1, 2, 0).numpy()

In [8]:
#GUI
@ipywidgets.interact
def _(i=ipywidgets.IntSlider(min=0,max=len(dataset)-1,continuous_update=False),
    seed=ipywidgets.IntSlider(min=0,max=50,continuous_update=False),):
    '''Interactive viewer: show sample ``i`` of ``dataset`` and its augmented crops.

    i: index of the sample in ``dataset``, from 0 to len(dataset)-1.
    seed: torch RNG seed (0 to 50), so the same slider values redraw the same crops.
    '''
    torch.manual_seed(seed)
    all_crops, _label = dataset[i]
    print("number of crops:", len(all_crops))

    # Original, un-augmented pixels stored by the dataset.
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(dataset.data[i], cmap='gray')
    ax.set_title('original_image')
    ax.axis('off')

    # The augmentation returns two global views first, then the local views.
    # Derive titles and a 2-column grid from the actual crop count instead of
    # assuming exactly 4 crops.
    n_global = 2
    n_cols = 2
    titles = [f'global_{k}' for k in range(min(n_global, len(all_crops)))]
    titles += [f'local_{k}' for k in range(len(all_crops) - n_global)]
    n_rows = max(1, -(-len(all_crops) // n_cols))  # ceil division
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 5 * n_rows), squeeze=False)
    for ax in axs.flat:
        ax.axis('off')  # also hides the unused cell when the crop count is odd
    for k, (crop, title) in enumerate(zip(all_crops, titles)):
        ax = axs[k // n_cols, k % n_cols]
        # to_numpy clips to [0, 1]; pin the colour range so views are comparable.
        ax.imshow(to_numpy(crop), cmap='gray', vmin=0, vmax=1)
        ax.set_title(title)
    fig.tight_layout()

interactive(children=(IntSlider(value=0, continuous_update=False, description='i', max=59999), IntSlider(value…