# Creating custom Dataset Classes in PyTorch

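The Dataset and DataLoader classes are two central pieces of PyTorch's data pipeline: a Dataset stores the samples and their labels, and a DataLoader wraps a dataset and serves those samples in batches.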

In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

## Creating the Dataset

A custom dataset class defines three methods: \_\_init\_\_(), \_\_len\_\_(), and \_\_getitem\_\_(). Because these are methods of a Python class, each one takes self as its first argument.

### \_\_init\_\_()

This function can take any number of user-defined inputs, so it is the place to pass in the path to a data file or any options. Its purpose is to initialize the dataset with the correct data so that the functions below work properly. Anything that needs to be kept can be stored as an attribute of self. For example, self.data = data stores the variable data in the class instance. Storing the data this way gives \_\_len\_\_() and \_\_getitem\_\_() access to it without having to reload it.
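As a minimal sketch of the idea above, the class below shows only \_\_init\_\_(). The class name InitOnlyDataset and the scale option are hypothetical, chosen for illustration, and the arrays are built in memory instead of being read from a file so the snippet runs on its own.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class InitOnlyDataset(Dataset):
    """Hypothetical example: only __init__ is shown here."""

    def __init__(self, data, labels, scale=1.0):
        # An illustrative user-defined option
        self.scale = scale
        # Store tensors on self so other methods can reach them later
        self.data = torch.as_tensor(data, dtype=torch.float32) * scale
        self.labels = torch.as_tensor(labels)


# 4 samples with 3 features each, plus one label per sample
features = np.arange(12, dtype=np.float32).reshape(4, 3)
targets = np.array([0, 1, 0, 1])

ds = InitOnlyDataset(features, targets, scale=0.5)
print(ds.data.shape)    # torch.Size([4, 3])
print(ds.labels.shape)  # torch.Size([4])
```

Everything assigned to self in \_\_init\_\_() stays available for the lifetime of the object, which is why the data only has to be loaded once.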


### \_\_len\_\_()

This function takes no inputs besides self. Its purpose is to report how many samples the dataset contains. The easiest way to get this value is to take the length of either the labels or the data matrix (for the data matrix, this assumes the samples are indexed along the first dimension).
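One detail worth checking is what len() returns for a multi-dimensional tensor: it is the size of the first dimension only. The short sketch below uses made-up shapes (5 samples, 12 features) to show why the orientation of the data matrix matters.

```python
import torch

data = torch.zeros(5, 12)  # 5 samples, 12 features each
labels = torch.zeros(5)    # one label per sample

print(len(data))                 # 5, the size of dimension 0
print(len(labels))               # 5
print(len(data.permute(1, 0)))   # 12, because the features now come first
```

If the data matrix were stored with features along the first dimension, taking its length would count features rather than samples, so the labels are often the safer choice.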

### \_\_getitem\_\_()

This function takes one additional input: the index of the sample being requested. It should use that index to look up the corresponding element of the data and its label, and return both.
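Putting the three methods together, a toy dataset might look like the sketch below. The class name ToyDataset and its synthetic contents are assumptions made for illustration. Indexing the object with square brackets calls \_\_getitem\_\_(), and len() calls \_\_len\_\_().

```python
import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset filled with synthetic values."""

    def __init__(self, n_samples=6, n_features=4):
        # Row i holds the values i*n_features ... i*n_features + n_features - 1
        self.data = torch.arange(
            n_samples * n_features, dtype=torch.float32
        ).reshape(n_samples, n_features)
        # Alternating 0/1 labels
        self.labels = torch.arange(n_samples) % 2

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Return one sample as a dictionary
        return {'index': idx, 'data': self.data[idx], 'label': self.labels[idx]}


toy = ToyDataset()
sample = toy[2]         # calls toy.__getitem__(2)

print(len(toy))         # 6
print(sample['index'])  # 2
print(sample['data'])   # tensor([ 8.,  9., 10., 11.])
print(sample['label'])  # tensor(0)
```

Returning a dictionary is one option among several; a tuple such as (data, label) works as well, but named keys make the batches easier to read later on.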


## Example Dataset for Class Project


In [2]:
class NaturalImages_Dataset(Dataset):
    def __init__(self, path_2_data, path_2_labels):
      
        # Load the data and labels from .npy files as NumPy arrays;
        # squeeze drops any singleton dimensions from the labels
        data = np.load(path_2_data)
        labels = np.squeeze(np.load(path_2_labels))
        
        # Convert data & labels from NumPy arrays to PyTorch tensors
        data = torch.from_numpy(data)
        # Swap the two axes so that samples lie along dimension 0
        data = data.permute(1, 0)
        labels = torch.from_numpy(labels)

        # Store data as part of the class        
        self.data = data
        self.labels = labels
        
    def __len__(self):
      
        # Calculate and return the number of samples in the dataset
        return len(self.labels)
        
    def __getitem__(self, idx):
      
        # Return the requested sample as a dictionary containing
        # its index, the data, and the label
        sample = dict()
        sample['index'] = idx
        sample['data'] = self.data[idx]
        sample['label'] = self.labels[idx]
        
        return sample
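To see what the squeeze and permute steps in \_\_init\_\_() do without the real files, the sketch below writes two small stand-in arrays to a temporary directory and loads them the same way. The file names, the (features, samples) layout of the data file, and the (samples, 1) layout of the label file are assumptions made for this illustration.

```python
import os
import tempfile

import numpy as np
import torch

n_features, n_samples = 12, 5

with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical stand-ins for the real data and label files
    data_path = os.path.join(tmp, 'toy_data.npy')
    labels_path = os.path.join(tmp, 'toy_labels.npy')
    np.save(data_path, np.random.rand(n_features, n_samples))
    np.save(labels_path, np.ones((n_samples, 1)))

    # Same loading steps as in the dataset class
    raw = np.load(data_path)
    labels = np.squeeze(np.load(labels_path))

data = torch.from_numpy(raw).permute(1, 0)
labels = torch.from_numpy(labels)

print(raw.shape)      # (12, 5)  features first, as stored on disk
print(data.shape)     # torch.Size([5, 12])  samples first after permute
print(labels.shape)   # torch.Size([5])  singleton dimension removed
print(labels.dtype)   # torch.float64, inherited from the NumPy array
```

After these steps, row i of the data and entry i of the labels refer to the same sample, which is what \_\_getitem\_\_() relies on.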

Now that we have defined the ```NaturalImages_Dataset``` class, we need to instantiate it. I have set this dataset up to take the paths to ```data_train.npy``` and ```labels_train.npy``` as inputs. Once it is initialized, we can see that the data and labels attributes of the ```NaturalImages_Dataset``` instance have the expected shapes: one row of data and one label per sample.

In [3]:
# Initialize the dataset
NaturalImages_data = NaturalImages_Dataset('data_train.npy', 'labels_train.npy')

print(NaturalImages_data.data.shape)
print(NaturalImages_data.labels.shape)

torch.Size([3124, 270000])
torch.Size([3124])


## Creating the DataLoader

With the dataset created, we can wrap it in a dataloader, which handles batching and shuffling. Here I simply print what the dataloader returns for each batch: the sample indices, the labels, and the shape of the data.
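Before looping over the real data, it can help to know what to expect. The sketch below uses a hypothetical ToyDataset with 25 synthetic samples to show two behaviours of DataLoader with its default settings: the per-sample dictionaries are merged into one dictionary of batched tensors, and the last batch is smaller when the dataset size is not a multiple of the batch size.

```python
import math

import torch
from torch.utils.data import Dataset, DataLoader


class ToyDataset(Dataset):
    """Hypothetical dataset returning dictionary samples."""

    def __init__(self, n_samples, n_features):
        self.data = torch.rand(n_samples, n_features)
        self.labels = torch.arange(n_samples) % 2

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {'index': idx, 'data': self.data[idx], 'label': self.labels[idx]}


toy = ToyDataset(n_samples=25, n_features=4)
loader = DataLoader(toy, batch_size=10, shuffle=True)

# Number of samples in each batch
sizes = [batch['data'].shape[0] for batch in loader]
print(len(loader))  # 3
print(sizes)        # [10, 10, 5]

# The integer indices are collated into a single tensor per batch
first = next(iter(loader))
print(first['index'].dtype)  # torch.int64
print(first['data'].shape)   # torch.Size([10, 4])

# The same arithmetic for 3124 samples and a batch size of 10
expected_batches = math.ceil(3124 / 10)
print(expected_batches)      # 313
```

By the same arithmetic, 3124 samples with a batch size of 10 should give 313 batches, the last one holding 4 samples. Passing drop_last=True to DataLoader would discard that final partial batch. The loop on the real dataset follows.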

In [4]:
bSize = 10
NaturalImages_loader = DataLoader(NaturalImages_data, batch_size=bSize, shuffle=True)

for i, data in enumerate(NaturalImages_loader):
    batch_indices = data['index']
    batch_data = data['data']
    batch_labels = data['label']
    
    print('Batch %d' % (i))
    print(batch_indices)
    print(batch_labels)
    print(batch_data.shape)
    print()

Batch 0
tensor([ 591, 1706, 2081,  534, 1340,  472,  740,  499, 2939,  380])
tensor([ 2.,  1.,  9.,  8.,  1.,  3.,  5.,  9., 10.,  5.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 1
tensor([ 211,  985, 2394, 1840,  506,  450,  353,  172, 1033, 2206])
tensor([4., 9., 7., 7., 1., 3., 4., 3., 6., 3.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 2
tensor([ 356, 2393, 1107,  495, 2988, 1674, 3063, 1350,  734, 1545])
tensor([ 6.,  6.,  1., 10.,  2.,  7.,  2.,  5.,  2., 10.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 3
tensor([1993, 1593,  617,  715,  179, 2559, 2691, 2160, 1563, 1290])
tensor([8., 3., 4., 5., 8., 3., 5., 6., 6., 8.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 4
tensor([1399, 3041, 2983,   87,  668,  886, 1129,  255, 1665,  294])
tensor([ 7.,  8., 10.,  9.,  1.,  3.,  8.,  3.,  8.,  5.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 5
tensor([2607,  406, 3014, 2345, 1626,  797, 1351, 1490, 1499, 1253])
tensor([ 7.,  5.,  6.,  8., ...
...

tensor([2749, 3027,  900, 1784, 1546,  850, 2107, 2031,  290, 1781])
tensor([7., 1., 9., 9., 4., 9., 9., 1., 3., 9.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 47
tensor([1841, 2470, 2995, 2836, 2783, 1361, 1650,  496,  667,  307])
tensor([ 4.,  8., 10.,  1.,  3.,  4.,  8.,  9.,  5.,  8.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 48
tensor([  79, 2476, 1491, 2402, 1547, 2421, 1417, 2489,  242, 3101])
tensor([ 2.,  8., 10.,  6.,  3.,  9.,  1.,  6.,  4.,  3.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 49
tensor([ 787, 2977, 2782,  596, 1398, 1890, 1700, 2665,  933, 2857])
tensor([ 5., 10.,  4.,  8.,  2.,  9.,  6.,  9.,  8.,  5.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 50
tensor([ 694,  538,  915, 1695, 2748, 2847,  793, 1994,  849,  618])
tensor([10.,  3.,  8.,  5., 10., 10.,  8.,  9.,  2.,  4.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 51
tensor([2248, 1358, 1992, 1474, 1508, 1082, 2058,  814, 1594,  422])
tensor([ 6.,  9., 10. ...
...

Batch 97
tensor([1262, 2154, 2587,  989, 2169,  272, 2148, 1189, 1658, 1541])
tensor([9., 6., 3., 4., 1., 9., 5., 9., 1., 9.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 98
tensor([2738, 2892, 1298,   26, 3069, 1767, 2937,  227,  326, 1998])
tensor([ 3.,  3.,  2., 10.,  9.,  1.,  1.,  5., 10.,  5.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 99
tensor([2698, 1332, 1585, 2942, 2428, 1702, 2437, 2358, 2589, 2009])
tensor([10.,  6.,  8.,  9., 10.,  4.,  6., 10.,  6.,  8.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 100
tensor([1958,  239,  871, 2047, 2730, 1495, 2704,   88, 2881, 2516])
tensor([ 7.,  7.,  5.,  2., 10.,  7.,  2.,  7.,  2., 10.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 101
tensor([1190, 1690,   39,  714, 2522, 2884, 2487,  486, 1276,  345])
tensor([ 2.,  3.,  4.,  3.,  3., 10.,  1., 10.,  5.,  9.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 102
tensor([1562, 1768,   29, 1970,  723, 2741,    6,  474, 2657,  300])
...

Batch 143
tensor([1220, 1379, 2718, 1916,  264,  536, 2605, 2998,  455,  846])
tensor([6., 8., 2., 2., 5., 3., 9., 7., 9., 6.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 144
tensor([2498,  508, 2274, 2643,  516, 1114,  684,  379,  225, 1437])
tensor([ 4.,  6.,  5.,  3.,  7.,  5.,  2.,  6., 10.,  5.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 145
tensor([1482, 2517, 1975,  562, 1797,   11, 1581, 1803, 1921, 2193])
tensor([ 3.,  4.,  8.,  2.,  8.,  1.,  1.,  8., 10.,  5.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 146
tensor([2689, 1331, 3087, 1267, 2403, 2705, 2140, 2400,  918, 2548])
tensor([ 8.,  5.,  6.,  2.,  6., 10.,  8.,  3.,  4.,  1.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 147
tensor([2959,  224,  137, 2535,  831, 2372, 1122, 1462,  106,  314])
tensor([6., 2., 8., 6., 2., 2., 3., 2., 2., 9.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 148
tensor([  33, 3100, 1616,  142, 2578, 2580, 1847, 2459, 2151,  173])
tensor([ 8.,  1. ...
...

Batch 192
tensor([1982, 2506, 1364, 1538,  346, 1884, 1304,  960, 1819, 1947])
tensor([4., 1., 9., 1., 2., 5., 8., 5., 6., 7.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 193
tensor([1155, 2550, 2802, 2161, 1131,  285,   83, 1445, 1167, 1621])
tensor([2., 8., 9., 6., 8., 8., 1., 3., 6., 1.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 194
tensor([2770,   55, 1857,  330, 1228,  396, 1752,   38, 2588, 1909])
tensor([9., 5., 7., 7., 6., 6., 2., 1., 8., 1.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 195
tensor([ 520,  851,  497,  899, 1273, 3017, 2612, 2456,   71, 1791])
tensor([3., 5., 5., 7., 9., 7., 2., 8., 3., 3.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 196
tensor([1955, 3029, 1458,  869,  162,  768, 2069,   95,  870, 2889])
tensor([ 9.,  3., 10.,  5., 10.,  1., 10.,  7.,  8.,  9.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 197
tensor([1087, 1750, 1178, 2261,  561,  743, 2349, 1375, 2798, 1448])
tensor([ 8.,  2.,  2.,  3., 10.,  5. ...
...

Batch 240
tensor([2224,  950, 2444, 2473, 1922,  123, 2832, 2063, 2287, 2018])
tensor([ 6.,  2.,  5.,  2.,  9., 10.,  3.,  8.,  3.,  6.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 241
tensor([ 961, 1386,   62,  823,  184, 2784,  171, 2708,  371,  174])
tensor([ 4.,  1.,  8., 10.,  3.,  1.,  8., 10.,  6.,  9.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 242
tensor([1532,  577, 2329, 2622, 3077, 3012, 1668,  166, 2526, 2728])
tensor([6., 8., 3., 2., 8., 2., 3., 1., 8., 7.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 243
tensor([2596,   60, 1128, 2057, 2575, 2168, 2591, 2593, 1275, 2561])
tensor([ 7.,  2.,  5.,  6., 10.,  2.,  1.,  3., 10.,  9.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 244
tensor([2511, 1887, 1906, 1852,  354,  271, 2088, 2000, 1528, 1321])
tensor([10.,  8.,  5., 10.,  1.,  9., 10.,  5.,  2.,  5.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 245
tensor([2127, 2118, 1246, 2299, 1020, 2555, 2153,  663, 2638, 2351])
...

Batch 285
tensor([ 381, 2475, 2979,  447, 1573, 2046,  498,  701, 1801,  270])
tensor([2., 8., 9., 3., 9., 2., 3., 4., 9., 9.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 286
tensor([ 411,  607, 2432, 2395, 2332, 2707, 1740, 1699, 1954, 1578])
tensor([2., 7., 7., 1., 6., 8., 7., 3., 1., 2.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 287
tensor([1064, 2401, 1625,  263, 1152,  537, 2852,  213, 1722, 2117])
tensor([1., 7., 7., 5., 9., 9., 4., 7., 1., 7.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 288
tensor([ 868,  306,  212, 1225,  813,  476, 2228, 1627,  403, 1372])
tensor([7., 7., 5., 8., 9., 6., 3., 2., 2., 1.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 289
tensor([ 911, 1015, 2706, 1415, 3030, 2034, 3035,  259,  208,  801])
tensor([ 4.,  3.,  2.,  9.,  8.,  8.,  3.,  5., 10.,  4.], dtype=torch.float64)
torch.Size([10, 270000])

Batch 290
tensor([2244, 2021, 2776, 1896, 2714, 1493, 1075,  654, 1045,  436])
tensor([ 7.,  6.,  1., 10.,  9.,  7.