In [32]:
import torch
import torchvision.transforms as tr
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [33]:
# Synthetic stand-in data: 100 channels-last RGB "images" of shape (32, 32, 3)
# with pixel values in [0, 255] (uint8, like real image data), plus one binary
# label per image (shape (100, 1)).
# Seeding NumPy (data) and torch (DataLoader shuffling) makes Restart & Run All
# reproduce the same batches.
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

train_images = rng.integers(0, 256, size=(100, 32, 32, 3), dtype=np.uint8)
train_labels = rng.integers(0, 2, size=(100, 1))

In [34]:
class MyDataset(Dataset):
    """Map-style dataset that pairs inputs ``x`` with targets ``y``.

    Args:
        x: Indexable, sized collection of inputs (e.g. a NumPy array of images).
        y: Indexable, sized collection of targets; must be as long as ``x``.
        transform: Optional callable applied to each ``(x[i], y[i])`` pair.

    Raises:
        ValueError: If ``x`` and ``y`` have different lengths.
    """

    def __init__(self, x, y, transform=None) -> None:
        if len(x) != len(y):
            # Fail fast on misaligned data instead of an IndexError (or silently
            # unused samples) surfacing later inside a DataLoader.
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y
        self.transform = transform
        self.len = len(y)

    def __getitem__(self, index):
        """Return the (optionally transformed) ``(input, target)`` pair at ``index``."""
        sample = self.x[index], self.y[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        """Return the number of samples."""
        return self.len

In [35]:
class ToTensor:
    """Convert an ``(image, label)`` pair of array data into torch tensors.

    The image must be channels-last ``(H, W, C)`` (required by the ``permute``
    below) and is returned as a float32 ``(C, H, W)`` tensor. Unlike
    ``torchvision.transforms.ToTensor``, pixel values are NOT rescaled.
    The label is returned as an int64 tensor with its original shape.
    """

    def __call__(self, sample):
        inputs, labels = sample
        # torch.tensor always copies and treats its argument as *data*. The
        # legacy torch.FloatTensor / torch.LongTensor constructors treat a bare
        # integer as a *size*, so a scalar label such as 3 would become an
        # uninitialized length-3 tensor instead of tensor(3).
        inputs = torch.tensor(inputs, dtype=torch.float32).permute(2, 0, 1)
        labels = torch.tensor(labels, dtype=torch.long)
        return inputs, labels

In [36]:
class LinearTensor:
    """Apply the affine map ``slope * inputs + bias`` to the input half of a sample.

    The label half of the ``(inputs, labels)`` pair is returned unchanged.
    """

    def __init__(self, slope=1, bias=0) -> None:
        self.slope, self.bias = slope, bias

    def __call__(self, sample):
        features, target = sample
        return self.slope * features + self.bias, target

In [37]:
# Per-sample pipeline: arrays -> float CHW tensor, then the affine map 2 * x + 5.
trans = tr.Compose([ToTensor(), LinearTensor(slope=2, bias=5)])

dataset = MyDataset(x=train_images, y=train_labels, transform=trans)
train_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True)


In [38]:
# Draw a single shuffled batch and display the shapes produced by collation.
images, labels = next(iter(train_loader))
(images.shape, labels.shape)

(torch.Size([10, 3, 32, 32]), torch.Size([10, 1]))

In [60]:
class MyTransform:
    """Turn an ``(image, label)`` array pair into a resized, normalized tensor pair.

    Image pipeline: channels-last ``(H, W, C)`` pixels -> uint8 ``(C, H, W)``
    tensor -> PIL image -> ``Resize(size)`` -> float tensor in [0, 1] ->
    ``Normalize(mean, std)`` (the defaults map [0, 1] to [-1, 1]).
    The label is returned as an int64 tensor with its original shape.

    Args:
        size: Target size forwarded to ``torchvision.transforms.Resize``.
        mean: Per-channel means forwarded to ``Normalize``.
        std: Per-channel standard deviations forwarded to ``Normalize``.
    """

    def __init__(self, size=128, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) -> None:
        # Build the pipeline once instead of on every sample.
        self.image_transform = tr.Compose([
            tr.ToPILImage(),
            tr.Resize(size),
            tr.ToTensor(),
            tr.Normalize(mean, std),
        ])

    def __call__(self, sample):
        inputs, labels = sample
        # ToPILImage multiplies *floating-point* tensors by 255 before casting
        # to uint8, so float pixels already in [0, 255] would overflow and
        # scramble the image. A uint8 (C, H, W) tensor is converted unscaled.
        # Assumes pixel values lie in [0, 255].
        image = torch.tensor(inputs, dtype=torch.uint8).permute(2, 0, 1)
        # torch.tensor treats a scalar label as data, not as a size.
        labels = torch.tensor(labels, dtype=torch.long)
        return self.image_transform(image), labels

In [61]:
# Rebuild the dataset/loader around MyTransform (128x128 normalized images).
dataset = MyDataset(x=train_images, y=train_labels, transform=MyTransform())
train_loader = DataLoader(dataset=dataset, shuffle=True, batch_size=15)

In [62]:
# Number of samples and number of batches per epoch (ceil(len / batch_size)).
(len(dataset), len(train_loader))

(100, 7)

In [63]:
# Pull one batch through MyTransform and print the image and label batch shapes.
images, labels = next(iter(train_loader))
print(*(t.shape for t in (images, labels)))

torch.Size([15, 3, 128, 128]) torch.Size([15, 1])
