In [98]:
import pickle
import torch
# from torchvision import transforms # no torchvision because we are not working with images but with points
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import train_test_split

In [99]:
# Load the prepared dataset; `with` guarantees the file handle is closed.
# NOTE(review): pickle.load can execute arbitrary code - only load files from a trusted source.
with open("./extended_prepared_dataset.pickle", "rb") as data_file:
    data = pickle.load(data_file)

print(len(data))
print(data[0][1])
# Each label is a one-hot vector over the five symbols: alpha is the first position, epsilon the last.
# Samples are grouped by class: the first 60 are alpha, the next 60 beta, and so on.

300
[1. 0. 0. 0. 0.]


In [100]:
# Experiment configuration shared by the cells below.
random_seed, test_split, batch_size = (
    69,   # seed for the random train/validation split
    0.2,  # fraction of samples intended for the validation set
    32,   # samples per mini-batch
)

In [101]:
class CustomSymbolDataset(Dataset):
    """Map-style dataset pairing each symbol with its label by position.

    Args:
        symbols: indexable sequence of per-symbol inputs (e.g. arrays of
            normalized points); ``symbols[i]`` pairs with ``labels[i]``.
        labels: indexable sequence of labels (e.g. one-hot vectors), the
            same length as ``symbols``.
        transform: optional callable applied to a symbol when it is fetched.
        target_transform: optional callable applied to a label when it is fetched.

    Raises:
        ValueError: if ``symbols`` and ``labels`` differ in length.
    """

    def __init__(self, symbols, labels, transform=None, target_transform=None):
        # Fail fast: __len__ is driven by `labels`, so a shorter `symbols`
        # would otherwise surface later as an IndexError mid-epoch.
        if len(symbols) != len(labels):
            raise ValueError(
                f"symbols and labels must have the same length "
                f"(got {len(symbols)} and {len(labels)})"
            )
        self.symbols = symbols
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of (symbol, label) pairs."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(symbol, label)`` pair at ``idx`` with any transforms applied."""
        symbol = self.symbols[idx]
        label = self.labels[idx]
        # Compare against None rather than relying on truthiness, so a
        # callable that happens to be falsy is still applied.
        if self.transform is not None:
            symbol = self.transform(symbol)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return symbol, label
    
class ToTensor:
    """Callable transform that wraps a NumPy array as a ``torch.Tensor``.

    Uses ``torch.from_numpy``, so the returned tensor shares memory with the
    input array (no copy) and keeps its dtype.
    """

    def __call__(self, sample):
        tensor = torch.from_numpy(sample)
        return tensor

In [102]:
# Split the data into a train and validation set.
# `dataset` was never defined in this notebook, so build the parallel
# symbol/label sequences directly from the loaded `data` items.
# NOTE(review): assumes item[0] is the point array and item[1] the one-hot
# label (item[1] matches the one-hot vector printed when loading) - confirm.
symbols = [item[0] for item in data]
labels = [item[1] for item in data]

X_train, X_val, y_train, y_val = train_test_split(
    symbols,
    labels,
    test_size=test_split,
    random_state=random_seed,
    shuffle=True,
)

# Seed the shuffling DataLoader as well, so batch order is reproducible across runs.
train_generator = torch.Generator().manual_seed(random_seed)

# Define the train dataset and dataloader
train_dataset = CustomSymbolDataset(X_train, y_train, transform=ToTensor(), target_transform=ToTensor())
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    generator=train_generator,
)

# Define the validation dataset and dataloader (no shuffling, so evaluation order is stable)
val_dataset = CustomSymbolDataset(X_val, y_val, transform=ToTensor(), target_transform=ToTensor())
val_dataloader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)