In [4]:
import sys
sys.path.append('../')

import torch
import torch.nn as nn
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader
from torch.utils.data import sampler

import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.nn.functional as F

import numpy as np
import pandas as pd
import os
import glob
import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn

import PIL

# Machine epsilon for float64; not referenced in the cells shown here.
eps = np.finfo(float).eps

# Default matplotlib figure size (width, height) in inches.
plt.rcParams['figure.figsize'] = 10, 10
%matplotlib inline



# Automatically reload imported modules before executing each cell,
# so edits to local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10

import numpy as np
from torch.utils.data import SubsetRandomSampler

from PIL import ImageFilter
from PIL import Image
import torch


def cifar_strong_transforms():
    """Build the strong augmentation pipeline for CIFAR-10 training images.

    Steps: random resized crop to 32x32, horizontal flip (p=0.5), colour jitter
    (brightness/contrast/saturation 0.4, hue 0.1) applied 80% of the time,
    random grayscale (p=0.2), then tensor conversion and per-channel
    normalisation with fixed mean/std.
    """
    cifar_mean = [0.4914, 0.4822, 0.4465]
    cifar_std = [0.2023, 0.1994, 0.2010]
    jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)

    steps = [
        transforms.RandomResizedCrop(32),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ]
    return transforms.Compose(steps)


def cifar_weak_transforms():
    """Build the 'weak' pipeline: tensor conversion plus per-channel normalisation.

    No random augmentation is applied.
    NOTE(review): semi-supervised setups often use flip/shift as the weak
    view; confirm that no augmentation here is intended.
    """
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    return transforms.Compose([to_tensor, normalize])


def cifar_test_transforms():
    """Evaluation pipeline: convert a PIL image to a tensor and normalise each channel."""
    pipeline = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
    ]
    return transforms.Compose(pipeline)


class CIFAR10C(CIFAR10):
    """CIFAR-10 dataset that yields two augmented views of every sample.

    Each item is ``(weak_view, strong_view, target)``. ``__getitem__`` does not
    use the base class's ``transform``; it applies ``weak_transform`` and
    ``strong_transform`` to the same PIL image instead. ``target_transform``
    is still honoured.
    """

    def __init__(self, weak_transform, strong_transform, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform

    def __getitem__(self, index):
        pil_img = Image.fromarray(self.data[index])
        label = self.targets[index]

        # Weak view is produced before the strong one (same order as RNG draws).
        weak_view = self.weak_transform(pil_img)
        strong_view = self.strong_transform(pil_img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        # The label is returned only for debugging purposes.
        return weak_view, strong_view, label


class LoaderCIFAR(object):
    """Semi-supervised CIFAR-10 data loaders.

    Attributes:
        train_labeled: ``(image, target)`` batches from the weakly transformed
            training set, restricted to ``num_labeled`` random samples per class.
        train_unlabeled: ``(weak_view, strong_view, target)`` batches from
            ``CIFAR10C`` over all remaining training samples.
        test: ``(image, target)`` batches from the test set, in order.
        img_shape: shape of one transformed test image (no batch dimension).
    """

    def __init__(self, file_path, download, batch_size, num_labeled, use_cuda):
        """
        Args:
            file_path: root directory handed to torchvision's ``CIFAR10``.
            download: forwarded as ``CIFAR10(download=...)``.
            batch_size: batch size used by all three loaders.
            num_labeled: number of labeled training samples kept per class.
            use_cuda: when True, loaders use 4 worker processes and pinned memory.
        """
        kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}

        (train_labeled_dataset, train_unlabeled_dataset, test_dataset,
         labeled_ind, unlabeled_ind) = self.get_dataset(file_path, download, num_labeled)

        # Both training loaders read the same training split; the samplers restrict
        # each one to its own index subset. shuffle stays False because DataLoader
        # does not allow shuffle=True together with a sampler.
        self.train_labeled = DataLoader(train_labeled_dataset, batch_size=batch_size, shuffle=False,
                                        sampler=SubsetRandomSampler(labeled_ind), **kwargs)
        self.train_unlabeled = DataLoader(train_unlabeled_dataset, batch_size=batch_size, shuffle=False,
                                          sampler=SubsetRandomSampler(unlabeled_ind), **kwargs)

        self.test = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, **kwargs)

        # Read a single sample instead of a whole batch so no worker processes are
        # spawned just to discover the image shape.
        self.img_shape = list(test_dataset[0][0].size())

    @staticmethod
    def get_dataset(file_path, download, num_labeled):
        """Create the CIFAR-10 datasets and a per-class labeled/unlabeled split.

        Args:
            file_path: dataset root directory.
            download: whether torchvision may download the data.
            num_labeled: number of labeled samples drawn from each class.

        Returns:
            ``(train_labeled_dataset, train_unlabeled_dataset, test_dataset,
            labeled_ind, unlabeled_ind)``; the index lists point into the
            training set.

        Raises:
            ValueError: if ``num_labeled`` is negative.
        """
        # Validate before any (possibly downloading) dataset construction.
        if num_labeled < 0:
            raise ValueError(f"num_labeled must be non-negative, got {num_labeled}")

        weak_transform = cifar_weak_transforms()
        strong_transform = cifar_strong_transforms()
        test_transform = cifar_test_transforms()

        # Training datasets: one labeled view, one two-view (weak/strong) unlabeled view.
        train_labeled_dataset = CIFAR10(root=file_path, train=True, download=download,
                                        transform=weak_transform,
                                        target_transform=None)
        train_unlabeled_dataset = CIFAR10C(weak_transform=weak_transform, strong_transform=strong_transform,
                                           root=file_path, train=True, download=download,
                                           transform=None,
                                           target_transform=None)

        test_dataset = CIFAR10(root=file_path, train=False, download=download,
                               transform=test_transform,
                               target_transform=None)

        labeled_ind, unlabeled_ind = LoaderCIFAR._split_indices(train_labeled_dataset.targets, num_labeled)

        return train_labeled_dataset, train_unlabeled_dataset, test_dataset, labeled_ind, unlabeled_ind

    @staticmethod
    def _split_indices(labels, num_labeled):
        """Randomly pick ``num_labeled`` indices per class as labeled; the rest are unlabeled.

        Args:
            labels: 1-D list, numpy array or CPU tensor of integer class labels.
            num_labeled: labeled samples per class; a class with fewer samples
                contributes all of them to the labeled set.

        Returns:
            ``(labeled_ind, unlabeled_ind)`` as lists of Python ints.

        Raises:
            ValueError: if ``num_labeled`` is negative.
        """
        if num_labeled < 0:
            raise ValueError(f"num_labeled must be non-negative, got {num_labeled}")

        if isinstance(labels, torch.Tensor):
            labels = labels.numpy()
        labels = np.asarray(labels)

        labeled_ind, unlabeled_ind = [], []
        # Iterate over the classes actually present (sorted) instead of assuming 10.
        for cl in np.unique(labels):
            class_indices = np.random.permutation(np.where(labels == cl)[0]).tolist()
            labeled_ind.extend(class_indices[:num_labeled])
            unlabeled_ind.extend(class_indices[num_labeled:])

        return labeled_ind, unlabeled_ind


- [ ] auto contrast
- [ ] brightness
- [ ] to grayscale with probability p
- [ ] contrast change
- [ ] equalize
- [ ] invert
- [ ] do nothing (identity)
- [ ] rescale
- [ ] posterize
- [ ] solarize
- [ ] rotate + rescale
- [ ] sharpness
- [ ] smooth
- [ ] flip x (horizontal)
- [ ] flip y (vertical)

In [7]:
# Scratch check: draws one integer from 1..7 (np.arange's stop is exclusive), the same expression posterize uses for its bit depth.
np.random.choice(np.arange(1, 8))

4

In [None]:
def cutout(image, p):
    """Randomly blank out one square patch of ``image`` with mid-gray.

    Args:
        image: PIL image or array-like of shape (H, W) or (H, W, C).
        p: probability in [0, 1] of applying the cutout.

    Returns:
        A new numpy array; the input is never modified. When cutout is applied,
        a square patch with side ``int(u * H)``, ``u ~ U[0, 0.5)`` (clamped to W),
        is filled with 0.5 for float images or with half the dtype's max value
        for integer images (127 for uint8).
    """
    image = np.asarray(image).copy()

    # rand() is in [0, 1), so `>= p` skips with probability exactly 1 - p
    # (never applied at p=0, always applied at p=1).
    if np.random.rand() >= p:
        return image

    h, w = image.shape[:2]

    # Scalar draw: calling int() on a size-1 array is deprecated in recent NumPy.
    # Clamp to the width so the patch also fits non-square images.
    patch_size = min(int(np.random.uniform(0, 0.5) * h), w)
    if patch_size == 0:
        return image

    # randint's upper bound is exclusive; +1 lets the patch touch the right/bottom edge.
    lu_x = np.random.randint(0, w - patch_size + 1)
    lu_y = np.random.randint(0, h - patch_size + 1)

    # A scalar fill broadcasts over any channel count (including 2-D grayscale).
    # For integer images a literal 0.5 would be truncated to 0 (black), so scale it.
    if np.issubdtype(image.dtype, np.integer):
        fill = np.iinfo(image.dtype).max // 2
    else:
        fill = 0.5

    image[lu_y:lu_y + patch_size, lu_x:lu_x + patch_size] = fill

    return image



def auto_contrast(image=None):
    """Return a transform that applies PIL's autocontrast to a PIL image.

    ``image`` is unused; it is kept (now optional) so existing ``auto_contrast(x)``
    calls keep working. Call the factory first and use the returned transform,
    e.g. inside ``transforms.RandomChoice``.
    """
    # `import PIL` alone does not load the ImageOps submodule, so import it explicitly.
    from PIL import ImageOps

    return transforms.Lambda(lambda img: ImageOps.autocontrast(img))


def random_brightness(image=None, min_factor=0.05, max_factor=0.95):
    """Return a transform that rescales a PIL image's brightness by a random factor.

    A factor is drawn uniformly from ``[min_factor, max_factor)`` on every call
    and passed to ``PIL.ImageEnhance.Brightness(...).enhance``. In that API,
    0 gives a black image and 1 gives the original. The default range therefore
    only darkens; raise ``max_factor`` above 1 to also brighten.

    Args:
        image: unused; kept so existing ``random_brightness(x)`` calls still work.
        min_factor, max_factor: bounds of the uniform enhancement factor.
    """
    from PIL import ImageEnhance

    def _random_brightness(img):
        factor = float(np.random.uniform(min_factor, max_factor))
        return ImageEnhance.Brightness(img).enhance(factor)

    return transforms.Lambda(lambda x: _random_brightness(x))
    
    
    
def random_grayscale(image=None, p=0.5):
    """Return a transform that converts a PIL image to grayscale with probability ``p``.

    The grayscale result is converted back to the input's mode, so an RGB input
    stays 3-channel (as the 3-channel ``Normalize`` in this notebook expects).

    Args:
        image: unused; kept so existing ``random_grayscale(x)`` calls still work.
        p: probability of applying the conversion.
    """
    from PIL import ImageOps

    def _random_grayscale(img):
        # rand() is in [0, 1): `< p` applies the conversion with probability exactly p.
        if np.random.rand() < p:
            return ImageOps.grayscale(img).convert(img.mode)
        return img

    return transforms.Lambda(lambda x: _random_grayscale(x))
    
    
def random_constrast_change(image=None, min_factor=0.05, max_factor=0.95):
    """Return a transform that changes a PIL image's contrast by a random factor.

    (The misspelled name is kept for compatibility with existing callers.)
    A factor is drawn uniformly from ``[min_factor, max_factor)`` on every call
    and passed to ``PIL.ImageEnhance.Contrast(...).enhance``. In that API,
    0 gives a solid gray image and 1 gives the original.

    Args:
        image: unused; kept so existing ``random_constrast_change(x)`` calls still work.
        min_factor, max_factor: bounds of the uniform enhancement factor.
    """
    from PIL import ImageEnhance

    def _random_contrast_change(img):
        factor = float(np.random.uniform(min_factor, max_factor))
        return ImageEnhance.Contrast(img).enhance(factor)

    return transforms.Lambda(lambda x: _random_contrast_change(x))
    
    
def equalize(image=None, alpha=0.3):
    """Return a transform that equalizes a PIL image's histogram (``PIL.ImageOps.equalize``).

    Args:
        image: unused; kept so existing ``equalize(x)`` calls still work.
        alpha: currently unused; kept for signature compatibility.
            NOTE(review): intended meaning (blend strength?) unknown; confirm.
    """
    from PIL import ImageOps

    return transforms.Lambda(lambda x: ImageOps.equalize(x))


def invert(image=None, alpha=0.3):
    """Return a transform that inverts (negates) a PIL image via ``PIL.ImageOps.invert``.

    Args:
        image: unused; kept so existing ``invert(x)`` calls still work.
        alpha: currently unused; kept for signature compatibility.
            NOTE(review): intended meaning (partial inversion?) unknown; confirm.
    """
    from PIL import ImageOps

    return transforms.Lambda(lambda x: ImageOps.invert(x))


def nop(image):
    """Return an identity transform (output is the input); ``image`` is ignored."""
    identity = transforms.Lambda(lambda x: x)
    return identity


def random_rescale(image):
    """Return a random resized crop to 32x32; ``image`` is ignored."""
    crop = transforms.RandomResizedCrop(size=32)
    return crop


def posterize(image=None):
    """Return a transform that posterizes a PIL image to a random bit depth.

    The number of bits kept per channel is drawn from 1..7 on every call
    (8 would leave the image unchanged).

    Args:
        image: unused; kept so existing ``posterize(x)`` calls still work.
    """
    from PIL import ImageOps

    def _posterize(img):
        num_bits = int(np.random.choice(np.arange(1, 8)))
        return ImageOps.posterize(img, num_bits)

    # Return the transform itself (no trailing comma), not a 1-tuple.
    return transforms.Lambda(lambda x: _posterize(x))


def solarize(image=None):
    """Return a transform that solarizes a PIL image with a random threshold.

    The threshold is drawn from 256 evenly spaced values in ``[0, max_val]``,
    where ``max_val`` is the image's largest pixel value over all bands.
    Pixels at or above the threshold are inverted by ``PIL.ImageOps.solarize``.

    Args:
        image: unused; kept so existing ``solarize(x)`` calls still work.
    """
    from PIL import ImageOps

    def _solarize(img):
        # A PIL image is not iterable, so use getextrema(): (min, max) for one band,
        # or one (min, max) pair per band for multi-band images.
        extrema = img.getextrema()
        if isinstance(extrema[0], tuple):
            max_val = max(hi for _, hi in extrema)
        else:
            max_val = extrema[1]
        threshold = float(np.random.choice(np.linspace(0, max_val, 256)))
        return ImageOps.solarize(img, threshold)

    # Return the transform itself (no trailing comma), not a 1-tuple.
    return transforms.Lambda(lambda x: _solarize(x))


def rotate(degrees=30):
    """Return a transform that rotates an image by a random angle in ``[-degrees, degrees]``.

    ``RandomRotation`` requires ``degrees``; the former ``resample=True`` was not a
    valid interpolation setting, so the default interpolation is used.

    Args:
        degrees: maximum absolute rotation angle in degrees.
    """
    return transforms.RandomRotation(degrees, expand=False)

def _sharpness(image=None, radius=2, percent=150, threshold=3):
    """Return a transform that sharpens a PIL image with an unsharp mask.

    Args:
        image: unused; kept so existing ``_sharpness(x)`` calls still work.
        radius, percent, threshold: forwarded to ``PIL.ImageFilter.UnsharpMask``.
    """
    from PIL import ImageFilter

    unsharp = ImageFilter.UnsharpMask(radius, percent, threshold)

    # Return the transform itself (no trailing comma), not a 1-tuple.
    return transforms.Lambda(lambda img: img.filter(unsharp))


def smooth(radius=2):
    """Return a transform that blurs a PIL image with a Gaussian filter.

    Args:
        radius: blur radius passed to ``PIL.ImageFilter.GaussianBlur``.
    """
    from PIL import ImageFilter

    blur = ImageFilter.GaussianBlur(radius)

    # Return the transform itself (no trailing comma), not a 1-tuple.
    return transforms.Lambda(lambda img: img.filter(blur))


def flipx(image):
    """Return a horizontal-flip transform applied with probability 0.5; ``image`` is ignored."""
    horizontal_flip = transforms.RandomHorizontalFlip(p=0.5)
    return horizontal_flip
 
    
def flipy(image=None):
    """Return a vertical-flip transform applied with probability 0.5.

    Args:
        image: unused; kept so existing ``flipy(x)`` calls still work.
    """
    # torchvision's class is RandomVerticalFlip (RandomVerticalFlipFlip does not exist).
    return transforms.RandomVerticalFlip(p=0.5)

In [None]:
# Each factory above returns a transform object when called, so call it here.
# RandomChoice then applies one of the resulting transforms to each image.
# Passing the bare factory would make RandomChoice return a transform, not an image.
transform_list = [auto_contrast(None)]

transform = transforms.RandomChoice(transform_list)