In [1]:
import numpy as np
import pandas as pd
from matplotlib.pyplot import cm
# import matplotlib.pyplot as plt
# from mpl_toolkits import mplot3d
import seaborn as sns
import h5py

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [2]:
path = '/home2/datasets/3dmnist'
with h5py.File(f'{path}/full_dataset_vectors.h5') as hf:
    X_train = hf["X_train"][:]
    y_train = hf["y_train"][:]    
    X_test = hf["X_test"][:]  
    y_test = hf["y_test"][:] 
X_train.shape, y_train.shape, X_test.shape, y_test.shape

  


((10000, 4096), (10000,), (2000, 4096), (2000,))

In [3]:
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.15,
                                                                    random_state=1, shuffle=True)
X_train.shape, X_val.shape, y_train.shape, y_val.shape

((8500, 4096), (1500, 4096), (8500,), (1500,))

In [4]:
transforms.ToTensor()(OneHotEncoder().fit_transform(y_val.reshape(-1, 1)).toarray()).shape

torch.Size([1, 1500, 10])

In [8]:
class MNIST3dDataset(Dataset):
    '''
    '''
    def __init__(self, X_data, y_data):
        '''
        add RGB dimension and reshape to 1 + 4D
        convert labels to ohe
        '''
        self.X_data = self.add_rgb_to_data(X_data)
        #self.X_data = self.X_data.reshape(-1,16,16,16,3)
        self.y_data = OneHotEncoder().fit_transform(y_data.reshape(-1, 1))
    
    def __getitem__(self, idx):
        X_data = transforms.ToTensor()(self.X_data[idx]).squeeze(0)
        y_data = transforms.ToTensor()(self.y_data[idx].toarray()).squeeze(0).squeeze(0)
        return X_data, y_data
               
    
    def __len__(self):
        return self.X_data.shape[0]
    
    def add_rgb_dimention(self, array):
        '''
        translate data to color
        '''
        scaler_map = cm.ScalarMappable(cmap="Oranges")
        array = scaler_map.to_rgba(array)[:, : -1]
        return array
    
    def add_rgb_to_data(self, data):
        '''
        iterate dataset, add rgb dimension
        '''
        data_w_rgb = np.ndarray((data.shape[0], data.shape[1], 3))
        for i in range(data.shape[0]):
            data_w_rgb[i] = self.add_rgb_dimention(data[i])
        return data_w_rgb

def make_dataloaders(batch_size=64, n_workers=4, pin_memory=True, **kwargs):
    '''
    A handy function to make our dataloaders
    '''
    
    dataset = MNIST3dDataset(**kwargs)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
                            pin_memory=pin_memory)
    return dataloader

In [9]:
# train_dl = make_dataloaders(X_data=X_train, y_data=y_train)
val_dl = make_dataloaders(X_data=X_val, y_data=y_val)

In [10]:
for X, y in val_dl:
    print(X.shape, y.shape)

torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([64, 4096, 3]) torch.Size([64, 10])
torch.Size([6