Old Documentation:

- [`import`](https://docs.python.org/3/reference/simple_stmts.html#the-import-statement)
- [`len`](https://docs.python.org/3/library/functions.html#len)
- [`numpy`](https://numpy.org/doc/1.19/user/whatisnumpy.html)
- [`numpy.array`](https://numpy.org/doc/stable/reference/generated/numpy.array.html)
- [numpy indexing](https://numpy.org/doc/stable/reference/arrays.indexing.html)
- [`torch`](https://pytorch.org/docs/stable/index.html)
- [`torch.Tensor`](https://pytorch.org/docs/stable/tensors.html#torch.Tensor)
- [`torch.utils.data`](https://pytorch.org/docs/stable/data.html#torch.utils.data)
- [`torch.utils.data.Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)

Import the NumPy (`numpy`) and PyTorch (`torch`) modules.

In [1]:
import numpy as np
import torch

Load the sample spectrograms used in this notebook.

In [2]:
spectrograms = np.load("./samples/spectrograms/linear/all_spectrograms.npy", allow_pickle=True)
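
The `allow_pickle=True` flag suggests the file holds a NumPy object array: one 2-D spectrogram per recording, each with its own number of timesteps. The sketch below is a hypothetical stand-in (random data, made-up shapes and a temporary file path, none of them taken from the real dataset) that saves and reloads such a ragged array and shows that `np.load` refuses it without the flag.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in: 3 recordings with 4, 2 and 3 timesteps, 5 frequency bins each
rng = np.random.default_rng(0)
specs = [rng.random((n_frames, 5)) for n_frames in (4, 2, 3)]

# Store the ragged spectrograms in a 1-D object array (one element per recording)
ragged = np.empty(len(specs), dtype=object)
for k, spec in enumerate(specs):
    ragged[k] = spec

path = os.path.join(tempfile.mkdtemp(), "all_spectrograms.npy")
np.save(path, ragged)  # object arrays are pickled when saved

# Loading an object array fails unless allow_pickle=True is passed
try:
    np.load(path)
    load_without_pickle_failed = False
except ValueError:
    load_without_pickle_failed = True

loaded = np.load(path, allow_pickle=True)
print(loaded.dtype, [s.shape for s in loaded])  # object [(4, 5), (2, 5), (3, 5)]
```

If every recording had the same number of timesteps, the data could instead be a regular 3-D array and the flag would not be needed.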

**Step 1:** Code to be executed once at the beginning for initialization

In [3]:
def init_(spectrograms):
    
    # Create a global variable holding the spectrograms so other functions can access them
    global global_spectrograms
    global_spectrograms = spectrograms
    
    # Create a global variable to store the list of all (recording, timestep) index pairs.
    global global_index_map
    global_index_map = []
    
    # Use two for loops to create a list of all (recording, timestep) index pairs.
    for i, spectrogram in enumerate(spectrograms):
        for j, frame in enumerate(spectrogram):
            index_pair = (i, j)
            global_index_map.append(index_pair)
    
    # Create a global variable holding the total number of frames (index pairs) for use in other functions
    global global_length
    global_length = len(global_index_map)
    
    return None
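
To make the index map concrete, the sketch below applies the same nested loop to two hypothetical spectrograms with 3 and 2 timesteps (the shapes and the `toy_` names are invented for illustration, chosen so the notebook's own variables are not overwritten). Every frame gets one flat index, and each flat index points back to its (recording, timestep) pair.

```python
import numpy as np

# Two hypothetical spectrograms: 3 and 2 timesteps, 4 frequency bins each
toy_spectrograms = [np.zeros((3, 4)), np.ones((2, 4))]

# Same nested loop as init_, written as a list comprehension
toy_index_map = [(i, j)
                 for i, spectrogram in enumerate(toy_spectrograms)
                 for j in range(len(spectrogram))]

print(toy_index_map)       # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
print(len(toy_index_map))  # 5
```

For example, flat index 3 maps to `(1, 0)`, the first frame of the second recording.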

**Step 2:** Code to return the number of items as length

In [4]:
def len_():
    
    # Return the total number of frames across all spectrograms
    return global_length
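
In the class version later in this notebook, `len_` becomes `__len__`. Defining that method is what lets Python's built-in `len()` work on an object, and PyTorch's default samplers call `len()` on a map-style dataset to know how many items it has. A minimal sketch with a hypothetical class `FrameCounter` and made-up shapes:

```python
import numpy as np


class FrameCounter:
    """Hypothetical minimal class: len() on it returns the total frame count."""

    def __init__(self, spectrograms):
        # init_ appends one pair per frame, so the total equals the sum of timesteps
        self.length = sum(len(s) for s in spectrograms)

    def __len__(self):
        # Python's built-in len() calls this method
        return self.length


counter = FrameCounter([np.zeros((4, 6)), np.zeros((2, 6)), np.zeros((3, 6))])
print(len(counter))  # 9
```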

**Step 3:** Code to return one item (a single frame) at recording i, timestep j

In [5]:
def getitem_(index):
    
    # Get the recording index i and timestep index j that correspond to the flat index
    i, j = global_index_map[index]
    
    # Index the global spectrograms with recording index i and timestep index j
    frame = global_spectrograms[i][j, :]
    
    return frame
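
Storing one tuple per frame is simple, but the index map grows with the total number of frames. One alternative (a swapped-in technique, not the method used in this notebook) keeps only the starting offset of each recording and recovers (i, j) with `np.searchsorted`. The helper names `locate_frame` and `getitem_offsets` and the toy data are hypothetical.

```python
import numpy as np

# Hypothetical recordings: 3, 1 and 2 timesteps, 4 frequency bins each
toy_spectrograms = [np.arange(12).reshape(3, 4),
                    np.arange(4).reshape(1, 4) + 100,
                    np.arange(8).reshape(2, 4) + 200]

# offsets[k] is the flat index of the first frame of recording k
lengths = np.array([len(s) for s in toy_spectrograms])
offsets = np.concatenate(([0], np.cumsum(lengths)[:-1]))


def locate_frame(index):
    """Map a flat frame index to (recording i, timestep j) by binary search."""
    i = int(np.searchsorted(offsets, index, side="right")) - 1
    j = int(index - offsets[i])
    return i, j


def getitem_offsets(index):
    # Same result as getitem_, without a stored list of index pairs
    i, j = locate_frame(index)
    return toy_spectrograms[i][j, :]


print([locate_frame(k) for k in range(int(lengths.sum()))])
# [(0, 0), (0, 1), (0, 2), (1, 0), (2, 0), (2, 1)]
```

This trades a list with one entry per frame for an array with one entry per recording, at the cost of a binary search on every lookup.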

**Step 4:** Code to return the collated list of items

In [6]:
def collate_fn_(batch):
    
    # Stack the list of 1-D frames into one (batch_size, n_freq) array, then convert it to a tensor
    batch = torch.as_tensor(np.stack(batch))
    
    return batch
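
For the collate step to produce a single 2-D tensor, every frame must have the same number of frequency bins. A minimal sketch with hypothetical frames: stacking them with `np.stack` first means `torch.as_tensor` receives one contiguous array rather than a Python list of arrays.

```python
import numpy as np
import torch

# Three hypothetical frames, each with 5 frequency bins
frames = [np.full(5, k, dtype=np.float32) for k in range(3)]

# Stack into shape (batch_size, n_freq), then wrap as a tensor
stacked = torch.as_tensor(np.stack(frames))
print(stacked.shape, stacked.dtype)  # torch.Size([3, 5]) torch.float32

# Frames with different numbers of bins cannot be stacked
try:
    np.stack([np.zeros(5), np.zeros(4)])
except ValueError as err:
    print("stack failed:", err)
```

`torch.as_tensor` can share memory with a CPU NumPy array of a compatible dtype instead of copying it, so changes to the stacked array may show up in the tensor.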

Example of how to use init, len, getitem, and collate as functions

In [7]:
init_(spectrograms)

batch = []
for index in range(len_()):
    batch.append(getitem_(index))
    
batch = collate_fn_(batch)
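
The cell above collates every frame into one batch. A `DataLoader` instead groups indices into mini-batches; a rough hand-written equivalent of that loop is sketched below. The toy data, the `batch_size` value and the `toy_` helper names are assumptions for illustration.

```python
import numpy as np
import torch

# Hypothetical recordings: 3 and 2 timesteps, 4 frequency bins each
toy_spectrograms = [np.zeros((3, 4), dtype=np.float32),
                    np.ones((2, 4), dtype=np.float32)]
toy_index_map = [(i, j) for i, s in enumerate(toy_spectrograms) for j in range(len(s))]


def toy_getitem(index):
    i, j = toy_index_map[index]
    return toy_spectrograms[i][j, :]


def toy_collate(frames):
    return torch.as_tensor(np.stack(frames))


batch_size = 2  # assumed value for illustration
batches = []
for start in range(0, len(toy_index_map), batch_size):
    # Up to batch_size consecutive indices; the last batch may be smaller
    indices = range(start, min(start + batch_size, len(toy_index_map)))
    batches.append(toy_collate([toy_getitem(k) for k in indices]))

print([tuple(b.shape) for b in batches])  # [(2, 4), (2, 4), (1, 4)]
```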

Example of how to create a Dataset class using init, len, getitem, and collate

In [8]:
class ExampleDataset(torch.utils.data.Dataset):
    
    def __init__(self, spectrograms):
        
        ### Code to be executed once at the beginning for initialization
        self.spectrograms = spectrograms
        
        self.index_map = []
        for i, spectrogram in enumerate(spectrograms):
            for j, frame in enumerate(spectrogram):
                index_pair = (i, j)
                self.index_map.append(index_pair)
        
        self.length = len(self.index_map)
        
    def __len__(self):
        
        ### Return the number of items as length
        return self.length
    
    def __getitem__(self, index):
        
        ### Return one item at recording i, timestep j
        i, j = self.index_map[index]
        frame = self.spectrograms[i][j, :]

        return frame
    
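    # Static method: it takes no self, so it can be passed directly as collate_fn to a DataLoader
    @staticmethod
    def collate_fn(batch):
        
        ### Specify how to collate the list of items and what to return
        batch = torch.as_tensor(np.stack(batch))

        return batch
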
dataset = ExampleDataset(spectrograms)
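
To feed the dataset to a model, it would typically be wrapped in `torch.utils.data.DataLoader`, with the collate function passed through the `collate_fn` argument. The self-contained sketch below restates a corrected version of the class under the hypothetical name `FrameDataset` and uses random toy data; `batch_size=3` and `shuffle=True` are illustrative choices, not values from this notebook.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class FrameDataset(Dataset):
    """Each item is one timestep (frame) from one of several spectrograms."""

    def __init__(self, spectrograms):
        self.spectrograms = spectrograms
        # One (recording, timestep) pair per frame
        self.index_map = [(i, j)
                          for i, spectrogram in enumerate(spectrograms)
                          for j in range(len(spectrogram))]

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, index):
        i, j = self.index_map[index]
        return self.spectrograms[i][j, :]

    @staticmethod
    def collate_fn(batch):
        # Stack the list of 1-D frames into a (batch_size, n_freq) tensor
        return torch.as_tensor(np.stack(batch))


# Hypothetical data: 4 and 3 timesteps, 8 frequency bins each
rng = np.random.default_rng(0)
toy_spectrograms = [rng.random((4, 8), dtype=np.float32),
                    rng.random((3, 8), dtype=np.float32)]

toy_dataset = FrameDataset(toy_spectrograms)
loader = DataLoader(toy_dataset, batch_size=3, shuffle=True,
                    collate_fn=FrameDataset.collate_fn)

batch_shapes = [tuple(batch.shape) for batch in loader]
print(batch_shapes)  # [(3, 8), (3, 8), (1, 8)]
```

Shuffling changes which frames land in each batch but not the batch sizes: 7 frames with `batch_size=3` give batches of 3, 3 and 1, because `drop_last` defaults to `False`.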