In [None]:
import torch
import torchvision
import torchvision.transforms as tr
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt

In [None]:
class ToTensor:
    """Convert an ``(image, label)`` sample of array-likes into torch tensors.

    The image is expected channels-last ``(H, W, C)`` (as produced by the toy
    ``train_images`` array) and is returned as a float32 ``(C, H, W)`` tensor.
    The label is returned as an int64 tensor with the label's own shape; a scalar
    label becomes a 0-d tensor.

    Both outputs are fresh copies, so in-place transforms applied afterwards
    (e.g. ``CutOut``) never write back into the source arrays.
    """

    def __call__(self, sample):
        inputs, labels = sample
        # torch.tensor always copies. The legacy torch.FloatTensor/LongTensor
        # constructors may share the numpy buffer and treat a bare int as a size.
        inputs = torch.tensor(inputs, dtype=torch.float32).permute(2, 0, 1)
        labels = torch.tensor(labels, dtype=torch.long)
        return inputs, labels


In [None]:
class CutOut:
    """Cutout augmentation: randomly zero a square patch of an image tensor.

    The patch is applied with probability ``1 / max(1, int(1 / ratio))``. This is
    exactly ``ratio`` when ``1 / ratio`` is an integer (e.g. the default 0.5), and
    ``ratio >= 1`` means "always". The patch side is a quarter of the shorter
    spatial side. Its top-left corner is drawn uniformly and independently along
    both spatial axes, so the patch can sit anywhere inside the image.

    Accepts any tensor whose last two dimensions are ``(H, W)``, e.g. ``(C, H, W)``.
    The tensor is modified in place and also returned. An ``(image, label, ...)``
    tuple is accepted too, so the transform can follow the sample-level
    ``ToTensor`` in a ``Compose``. Only the image is cut; the other elements pass
    through unchanged.

    Raises:
        ValueError: if ``ratio`` is not positive.
    """

    def __init__(self, ratio=.5):
        if ratio <= 0:
            raise ValueError(f"ratio must be positive, got {ratio}")
        # Apply when a draw from range(self.ratio) is 0; clamp so ratio >= 1 never
        # produces randint(0, 0), which raises.
        self.ratio = max(1, int(1 / ratio))

    def __call__(self, inputs):
        if isinstance(inputs, tuple):
            image, *rest = inputs
            return (self(image), *rest)

        if np.random.randint(self.ratio) == 0:
            height, width = inputs.shape[-2:]
            box_size = min(height, width) // 4
            # randint's upper bound is exclusive; +1 lets the patch reach the last row/col.
            top = np.random.randint(height - box_size + 1)
            left = np.random.randint(width - box_size + 1)
            inputs[..., top:top + box_size, left:left + box_size] = 0

        return inputs
    
# CIFAR-10 pipeline: Resize(128), convert the PIL image to a float tensor, then
# CutOut. CutOut indexes a (C, H, W) tensor, so it must come after ToTensor.
transf = tr.Compose([tr.Resize(128), tr.ToTensor(), CutOut()])
# download=True fetches the training split into ./data if it is not already there.
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform = transf)

In [None]:
class MyDatasets(Dataset):
    """Map-style dataset pairing ``x_data[index]`` with ``y_data[index]``.

    Each item is the tuple ``(x_data[index], y_data[index])`` passed through
    ``transform`` when a truthy one is given. Otherwise it goes through the
    notebook's ``ToTensor`` sample transform. The length is ``len(y_data)``.
    """

    def __init__(self, x_data, y_data, transform=None):
        self.x_data, self.y_data = x_data, y_data
        self.transform = transform
        self.len = len(y_data)
        # Fallback conversion used when no transform is supplied.
        self.tensor = ToTensor()

    def __getitem__(self, index):
        pair = (self.x_data[index], self.y_data[index])
        convert = self.transform if self.transform else self.tensor
        return convert(pair)

    def __len__(self):
        return self.len

In [None]:
# Toy data: 100 random channels-last (32, 32, 3) images with values in [0, 1]
# and binary labels of shape (1,) per sample.
SEED = 0
np.random.seed(SEED)      # drives the data below and CutOut's patch placement
torch.manual_seed(SEED)   # drives DataLoader shuffling

train_images = np.random.randint(256, size=(100, 32, 32, 3)) / 255
train_labels = np.random.randint(2, size=(100, 1))

cutout = CutOut()


def cutout_image(sample):
    """Apply ``cutout`` to the image of an ``(image, label)`` sample; the label passes through."""
    image, label = sample
    return cutout(image), label


# ToTensor yields an (image, label) tuple, so CutOut must only be given the image tensor.
trans = tr.Compose([ToTensor(), cutout_image])
dataset1 = MyDatasets(train_images, train_labels, transform=trans)
train_loader1 = DataLoader(dataset1, batch_size=10, shuffle=True)

In [None]:
import torchvision
# Draw one shuffled batch of 10 (image, label) pairs from the toy-data loader.
images1, labels1 = next(iter(train_loader1))

def imshow(img, width=20):
    """Display a channels-first image tensor, e.g. the output of ``make_grid``.

    Args:
        img: tensor shaped ``(C, H, W)`` with C of 1, 3 or 4. It is detached and
            moved to the CPU before plotting.
        width: figure width in inches. The height follows the image's aspect ratio
            (at least 1 inch).
    """
    _, height_px, width_px = img.shape
    _, ax = plt.subplots(figsize=(width, max(1, width * height_px / width_px)))
    # matplotlib expects channels-last (H, W, C) numpy arrays.
    array = img.detach().cpu().permute(1, 2, 0).numpy()
    if array.shape[-1] == 1:
        # matplotlib rejects (H, W, 1); plot single-channel images as 2-D greyscale.
        array = array[..., 0]
    # cmap only affects 2-D input; it is ignored for RGB(A) arrays.
    ax.imshow(array, cmap="gray")
    ax.axis("off")
    plt.show()

# Tile the 10-image batch with nrow=10, i.e. all images in a single row, and display it.
imshow(torchvision.utils.make_grid(images1, nrow=10))

In [None]:
class MyTransform:
    """Sample transform: numpy ``(image, label)`` -> resized float image tensor and float label.

    The image is expected channels-last ``(H, W, C)``, with float values in
    ``[0, 1]`` as in the toy ``train_images``. It is converted to a ``(C, H, W)``
    tensor and round-tripped through ``ToPILImage``, which quantises a float
    tensor to 8 bits. It is then resized with ``Resize(size)`` and converted back
    with ``ToTensor``. Labels become float32 tensors of the label's own shape.

    Args:
        size: forwarded to ``torchvision.transforms.Resize``; defaults to 128.
    """

    def __init__(self, size=128):
        # Build the pipeline once rather than on every __getitem__ call.
        self.pipeline = tr.Compose([tr.ToPILImage(), tr.Resize(size), tr.ToTensor()])

    def __call__(self, sample):
        inputs, labels = sample
        # torch.tensor also handles scalar labels; torch.FloatTensor(3) would instead
        # allocate an uninitialised length-3 tensor.
        inputs = torch.tensor(inputs, dtype=torch.float32).permute(2, 0, 1)
        labels = torch.tensor(labels, dtype=torch.float32)
        return self.pipeline(inputs), labels

In [None]:
# Same toy data, but each sample goes through MyTransform (resize to 128, float labels).
dataset2 = MyDatasets(train_images, train_labels, transform = MyTransform())
train_loader2 = DataLoader(dataset2, batch_size=10, shuffle = True)

In [None]:
class CutOut:
    """Cutout augmentation: randomly zero a square patch of an image tensor.

    The patch is applied with probability ``1 / max(1, int(1 / ratio))``. This is
    exactly ``ratio`` when ``1 / ratio`` is an integer (e.g. the default 0.5), and
    ``ratio >= 1`` means "always". The patch side is a quarter of the shorter
    spatial side. Its top-left corner is drawn uniformly and independently along
    both spatial axes, so the patch can sit anywhere inside the image.

    Accepts any tensor whose last two dimensions are ``(H, W)``, e.g. ``(C, H, W)``.
    The tensor is modified in place and also returned. An ``(image, label, ...)``
    tuple is accepted too, so the transform can follow the sample-level
    ``ToTensor`` in a ``Compose``. Only the image is cut; the other elements pass
    through unchanged.

    Raises:
        ValueError: if ``ratio`` is not positive.
    """

    def __init__(self, ratio=.5):
        if ratio <= 0:
            raise ValueError(f"ratio must be positive, got {ratio}")
        # Apply when a draw from range(self.ratio) is 0; clamp so ratio >= 1 never
        # produces randint(0, 0), which raises.
        self.ratio = max(1, int(1 / ratio))

    def __call__(self, inputs):
        if isinstance(inputs, tuple):
            image, *rest = inputs
            return (self(image), *rest)

        if np.random.randint(self.ratio) == 0:
            height, width = inputs.shape[-2:]
            box_size = min(height, width) // 4
            # randint's upper bound is exclusive; +1 lets the patch reach the last row/col.
            top = np.random.randint(height - box_size + 1)
            left = np.random.randint(width - box_size + 1)
            inputs[..., top:top + box_size, left:left + box_size] = 0

        return inputs
    
# CIFAR-10 pipeline: Resize(128), convert the PIL image to a float tensor, then
# CutOut. CutOut indexes a (C, H, W) tensor, so it must come after ToTensor.
transf = tr.Compose([tr.Resize(128), tr.ToTensor(), CutOut()])
# download=True fetches the training split into ./data if it is not already there.
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform = transf)

In [None]:
# Load one shuffled batch of 10 augmented CIFAR-10 images and show them in a single row.
trainloader = DataLoader(trainset, batch_size = 10, shuffle = True)
images, labels = next(iter(trainloader))
imshow(torchvision.utils.make_grid(images, nrow=10))
# Batch tensor shape; presumably (10, 3, 128, 128) given Resize(128) — confirm on run.
print(images.size())