# Cutout

An exploratory analysis of how different variants of [Cutout](https://github.com/uoguelph-mlrg/Cutout) augmentation affect model performance on the CIFAR-10 dataset.



In [16]:
import torch

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms


from PIL import Image, ImageDraw


In [173]:
import torch
import numpy as np
from itertools import product

class CutoutDispersed:
    """Randomly mask out individual pixels from an image.

    The same pixel locations are zeroed in every channel, and the input
    tensor is modified in place.

    Args:
        n_pixels (int): Number of distinct pixels to mask. If it exceeds the
            number of pixels in the image, every pixel is masked.
    """

    def __init__(self, n_pixels):
        self.n_pixels = n_pixels

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: The same tensor with n_pixels pixel locations set to 0.
        """
        h = img.size(1)
        w = img.size(2)

        # randperm yields distinct flat indices, so exactly n_pixels
        # locations are chosen (or all of them when n_pixels > h * w).
        chosen = torch.randperm(h * w)[:self.n_pixels]

        # Build a 2-D boolean mask instead of writing through
        # img[i].flatten(): flatten() returns a copy for non-contiguous
        # tensors, in which case the write would silently be lost.
        mask = torch.zeros(h * w, dtype=torch.bool)
        mask[chosen] = True
        img[:, mask.view(h, w)] = 0

        return img

class CutoutOfficial:
    """Zero out square patches at random locations of an image.

    Args:
        n_holes (int): How many patches are erased from each image.
        length (int): Side length, in pixels, of every square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: A new tensor equal to img with n_holes patches of at most
            length x length pixels set to zero (patches are clipped at the
            image border).
        """
        height, width = img.size(1), img.size(2)
        half = self.length // 2

        # 1 where a pixel is kept, 0 where it is erased.
        keep = np.ones((height, width), np.float32)

        for _ in range(self.n_holes):
            # Patch centre: the row is drawn before the column.
            centre_y = np.random.randint(height)
            centre_x = np.random.randint(width)

            # Clip the patch to the image, so border patches are smaller.
            top, bottom = np.clip([centre_y - half, centre_y + half], 0, height)
            left, right = np.clip([centre_x - half, centre_x + half], 0, width)

            keep[top:bottom, left:right] = 0.

        # The (H, W) mask is broadcast over the channel dimension.
        return img * torch.from_numpy(keep).expand_as(img)
    
class CutoutVariable:
    """Randomly mask out one or more patches of random size from an image.

    Args:
        max_size (int or sequence of two ints): Maximum patch size. An int
            gives square patches whose side is drawn uniformly from
            [0, max_size]; a (height, width) pair draws the two sides
            independently. Sizes are capped at the image dimensions.
        n_masks (int): Number of patches to cut out of each image.
    """
    def __init__(self, max_size, n_masks=1):
        self.max_size = max_size
        self.n_masks = n_masks

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: The same tensor (modified in place) with n_masks
            randomly sized and placed patches set to 0.
        """
        h = img.size(1)
        w = img.size(2)

        for _ in range(self.n_masks):
            if isinstance(self.max_size, int):
                # Square patch: cap at the smaller image side so it fits.
                limit = min(self.max_size, h, w)
                size_h = size_w = int(torch.randint(0, high=limit + 1, size=(1,)))
            else:
                limit_h = min(self.max_size[0], h)
                limit_w = min(self.max_size[1], w)
                size_h = int(torch.randint(0, high=limit_h + 1, size=(1,)))
                size_w = int(torch.randint(0, high=limit_w + 1, size=(1,)))

            # randint's upper bound is exclusive, so "+ 1" lets the patch sit
            # flush with the bottom/right edge and keeps the range non-empty
            # when the patch is as large as the image.
            i = int(torch.randint(0, high=h - size_h + 1, size=(1,)))
            j = int(torch.randint(0, high=w - size_w + 1, size=(1,)))
            img[..., i:i + size_h, j:j + size_w] = 0

        return img

In [7]:
# One 4x4 square hole per image. `Cutout` is not defined in this notebook;
# CutoutOfficial is the class taking (n_holes, length).
a = CutoutOfficial(1, 4)

In [20]:
# Preprocessing pipeline: convert to a tensor, normalize each channel with
# mean 0.5 / std 0.5, then cut one 4x4 square hole out of every image.
# `Cutout` is not defined in this notebook; CutoutOfficial is the class
# taking (n_holes, length).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    CutoutOfficial(1, 4),
])

batch_size = 16  # batch size used for training and for the saved sample image
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified


In [21]:
# Save the first batch as a single image strip so the cutout can be inspected.
for i, (inputs, outputs) in enumerate(trainloader):
    # Split along dim 0 into single-image tensors, concatenate them along the
    # last dim (side by side), then drop the leftover leading singleton dim.
    strip = torch.cat(inputs.to('cpu').split(1, 0), 3).squeeze()
    # Map values from [-1, 1] (after the 0.5/0.5 normalization) to [0, 255].
    scaled = strip / 2 * 255 + .5 * 255
    # Channels-last uint8 array, as expected by PIL.
    pixels = scaled.permute(1, 2, 0).numpy().astype('uint8')
    im = Image.fromarray(pixels)
    im.save("cutout.png")
    break



In [172]:
# Smoke test: apply a variable-size cutout to an all-ones 3x5x5 image and
# display the result.
A = torch.ones(3, 5, 5)
m = CutoutVariable(max_size=5)
m(A)

tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1.],
         [0., 0., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1.],
         [0., 0., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1.],
         [0., 0., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])

In [148]:
# Take index 0 along dim 1 for every channel (a 3x5 slice) and flatten it
# into a 15-element vector.
A = torch.ones(3, 5, 5)
A[:, 0, :].flatten()



tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [117]:
# The first two entries of a random permutation of 0..4 are two distinct
# random indices.
perm = torch.randperm(5)
indices = perm[:2]
indices

tensor([1, 4])

In [118]:
# Display a full random permutation of the integers 0..4.
torch.randperm(5)

tensor([0, 4, 1, 3, 2])