In [1]:
# import libraries
import numpy as np
import torch

# import transformations and dataset/loader
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import matplotlib_inline.backend_inline

# render inline matplotlib figures in this notebook as SVG (vector graphics)
matplotlib_inline.backend_inline.set_matplotlib_formats("svg")

In [2]:
# Load a small MNIST subset from CSV: column 0 is the digit label and the remaining
# 784 columns are pixel values (one 28x28 image per row). This is a local copy of
# the MNIST sample CSV that ships with Colab.
DATA_PATH = "../Datasets/mnist_train_small.csv"
N_SAMPLES = 8  # only the first few images are needed for this demo

# Pass the path (not an open() handle) so numpy opens AND closes the file itself;
# max_rows avoids parsing the whole CSV just to keep the first N_SAMPLES rows, and
# ndmin=2 keeps the result 2-D even if N_SAMPLES == 1.
mnist_raw = np.loadtxt(DATA_PATH, delimiter=",", max_rows=N_SAMPLES, ndmin=2)

# split into labels (first column) and pixel data (everything else)
labels = mnist_raw[:, 0]
data = mnist_raw[:, 1:]

# normalize the data to a range of [0 1] (scaled by the largest pixel in this subset)
dataNorm = data / np.max(data)

# reshape each flat row into a 2D image: (samples, 1 channel, 28, 28)
dataNorm = dataNorm.reshape(dataNorm.shape[0], 1, 28, 28)

# sanity checks: one image per label, and values really are in [0, 1]
# (an all-zero subset would make np.max(data) == 0 and produce NaNs)
assert dataNorm.shape == (labels.shape[0], 1, 28, 28)
assert 0.0 <= dataNorm.min() and dataNorm.max() <= 1.0, "normalization failed (all-zero pixels?)"

# check sizes
print(dataNorm.shape)
print(labels.shape)

# convert to torch tensor format
dataT = torch.tensor(dataNorm).float()
labelsT = torch.tensor(labels).long()

(8, 1, 28, 28)
(8,)


In [3]:
class customDataset(Dataset):
    """Map-style dataset over parallel tensors, with an optional transform on the inputs.

    Args:
        tensors: sequence of at least two tensors; ``tensors[0]`` holds the inputs and
            ``tensors[1]`` the labels. All tensors must have the same size along dim 0.
        transform: optional callable applied to each input sample in ``__getitem__``.
            Labels are returned untransformed.

    Raises:
        AssertionError: if fewer than two tensors are given, or their leading
            dimensions differ.
    """

    def __init__(self, tensors, transform=None):
        # __getitem__ always reads tensors[1], so require at least (data, labels) here
        # rather than failing later with an IndexError on first access.
        assert len(tensors) >= 2, "Expected at least (data, labels) tensors"
        # Every tensor must have one entry per sample along dim 0.
        assert all(
            tensors[0].size(0) == t.size(0) for t in tensors
        ), "Size mismatch between tensors"

        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(data, label)`` pair at ``index``, transforming the data if configured."""
        x = self.tensors[0][index]
        # Compare against None instead of relying on truthiness: a callable transform
        # object whose __len__/__bool__ is falsy would otherwise be silently skipped.
        if self.transform is not None:
            x = self.transform(x)
        y = self.tensors[1][index]
        return x, y

    def __len__(self):
        """Number of samples, i.e. the size of the first tensor's leading dimension."""
        return self.tensors[0].size(0)