In [46]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms

In [None]:
class Cutout(object):
  """Randomly mask out one or more square patches from an image tensor.

  Each patch is centred on a uniformly sampled pixel and clipped at the image
  border, so a patch that overlaps an edge covers fewer than
  ``length`` x ``length`` pixels.

  Args:
      n_holes (int): Number of patches to cut out of each image.
      length (int): The length (in pixels) of each square patch.
  """
  def __init__(self, n_holes, length):
      self.n_holes = n_holes
      self.length = length

  def __call__(self, img):
      """Zero out ``n_holes`` randomly placed patches of ``img``.

      Args:
          img (Tensor): Tensor image of size (C, H, W).
      Returns:
          Tensor: Image with n_holes of dimension length x length cut out of it.
      """
      h = img.size(1)
      w = img.size(2)

      # 1 = keep the pixel, 0 = cut it out; shared by every channel.
      mask = np.ones((h, w), np.float32)
      half = self.length // 2

      for _ in range(self.n_holes):
          # Sample the patch centre (row first, then column).
          y = np.random.randint(h)
          x = np.random.randint(w)

          # Anchor the far edge at `near edge + length` rather than
          # `centre + length // 2`, so that an odd `length` still yields a
          # patch that is `length` pixels wide instead of `length - 1`.
          top = y - half
          left = x - half
          y1 = np.clip(top, 0, h)
          y2 = np.clip(top + self.length, 0, h)
          x1 = np.clip(left, 0, w)
          x2 = np.clip(left + self.length, 0, w)

          mask[y1: y2, x1: x2] = 0.

      # Move the mask next to the image so the product also works for
      # tensors that do not live on the CPU.
      mask = torch.from_numpy(mask).to(img.device)
      mask = mask.expand_as(img)

      return img * mask

In [57]:
def load_CIFAR10(batch_size, train_ratio):
  """Build CIFAR-10 train / validation / test data loaders.

  The official training split is divided at random into a training part and
  a validation part. Training samples are augmented (random crop, horizontal
  flip, Cutout); validation and test samples are only converted to tensors
  and normalized, so evaluation is not perturbed by augmentation.

  Args:
      batch_size (int): Number of samples per batch for every loader.
      train_ratio (float): Fraction, in [0, 1], of the official training
          split that is kept for training; the remainder becomes the
          validation split.

  Returns:
      tuple: ``(train_iterator, valid_iterator, test_iterator)`` as
      ``torch.utils.data.DataLoader`` objects.

  Raises:
      ValueError: If ``train_ratio`` is outside [0, 1].
  """
  if not 0 <= train_ratio <= 1:
      raise ValueError("train_ratio must be in [0, 1], got {}".format(train_ratio))

  ROOT = './data'

  # Untransformed copy, used only to read the raw uint8 pixel array.
  raw_trainset = torchvision.datasets.CIFAR10(
      root = ROOT,
      train = True,
      download = True
  )

  # Compute per-channel means and standard deviations, rescaled to [0, 1]
  # to match the range produced by ToTensor().
  means = raw_trainset.data.mean(axis=(0,1,2)) / 255
  stds = raw_trainset.data.std(axis=(0,1,2)) / 255

  # Preprocess setting
  transform_train = transforms.Compose([
      transforms.RandomCrop(32, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(mean=means, std=stds),
      Cutout(n_holes=1, length=16)
  ])
  transform_test = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize(mean=means, std=stds)
  ])

  # Load the dataset. The training images are loaded twice: once with
  # augmentation (for training) and once with the evaluation transform
  # (for validation), because a split of `trainset` would inherit its
  # augmenting transform.
  trainset = torchvision.datasets.CIFAR10(
      root = ROOT,
      train = True,
      download = True,
      transform = transform_train
  )
  validset = torchvision.datasets.CIFAR10(
      root = ROOT,
      train = True,
      download = True,
      transform = transform_test
  )
  testset = torchvision.datasets.CIFAR10(
      root = ROOT,
      train = False,
      download = True,
      transform = transform_test
  )

  # Split trainset for validset
  n_train = int(len(trainset) * train_ratio)
  n_valid = len(trainset) - n_train
  train_dataset, valid_split = data.random_split(trainset, [n_train, n_valid])
  # Reuse the held-out indices, but read them through the non-augmented view.
  valid_dataset = data.Subset(validset, valid_split.indices)

  # Build dataloader. Only the training loader reshuffles every epoch; the
  # guard keeps an empty training split (train_ratio == 0) from reaching the
  # random sampler, which may reject empty datasets in some PyTorch
  # releases -- TODO confirm against the installed version.
  train_iterator = data.DataLoader(train_dataset, batch_size, shuffle=n_train > 0)
  valid_iterator = data.DataLoader(valid_dataset, batch_size)
  test_iterator = data.DataLoader(testset, batch_size)

  return train_iterator, valid_iterator, test_iterator

In [66]:
# train_ratio=1 keeps every training sample in the training split, so the
# validation loader returned here is empty.
trainLoader, validLoader, testLoader = load_CIFAR10(
    batch_size=16,
    train_ratio=1,
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [67]:
# Number of batches in each loader (train, validation, test), one per line.
print(len(trainLoader), len(validLoader), len(testLoader), sep="\n")

3125
0
625
