In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt


In [None]:
# Load the data set from disk: training samples with their labels, plus the
# test samples (no labels are loaded for the test split).
data_dir = './data/'

training_data, training_labels = (
    np.load(os.path.join(data_dir, file_name))
    for file_name in ('training.npy', 'training_labels.npy')
)
test_data = np.load(os.path.join(data_dir, 'test.npy'))


In [None]:
# Data information: print the array shapes and the class balance of the training set.
unique_labels, label_counts = np.unique(training_labels, return_counts=True)

print(f'Training data shape:\t{training_data.shape}')
print(f'Training labels shape:\t{training_labels.shape}')

# One row per class: label, absolute count, share of the training set.
# The share is rendered as a percentage so it matches the "percent" header
# and never falls back to scientific notation for rare classes.
n_labels = len(training_labels)
label_rows = [
    f'{label: >5}{count: >7}  {count / n_labels:.1%}'
    for label, count in zip(unique_labels, label_counts)
]
print('Label counts percent:\n' + '\n'.join(label_rows))

print(f'Test data shape:\t{test_data.shape}')

In [None]:
# Data samples: for every label, show a grid of randomly chosen training images.
np.random.seed(42)  # fixed seed so the same images are drawn on every run
n_samples = (4, 6)  # (rows, columns) of the image grid shown per label

fig = plt.figure(figsize=(4*n_samples[1], 2*len(unique_labels)*n_samples[0]))
# squeeze=False always returns an array; with the default, a data set that has
# a single label yields one bare SubFigure, which cannot be iterated by zip().
figs = fig.subfigures(nrows=len(unique_labels), squeeze=False).flatten()
for row, label in zip(figs, unique_labels):

    row.suptitle(f'Label: {label}', size=24)

    all_label_idxs = np.flatnonzero(training_labels == label)
    # Sampling without replacement raises ValueError when more items are
    # requested than exist, so cap the draw at the size of this class.
    n_chosen = min(int(np.prod(n_samples)), all_label_idxs.size)
    chosen_label_idxs = np.random.choice(all_label_idxs, n_chosen, replace=False)

    # squeeze=False keeps a 2-D array of Axes even for a 1x1 grid.
    axs = row.subplots(*n_samples, squeeze=False)

    for chosen_label_idx, ax in zip(chosen_label_idxs, axs.flatten()):
        ax.set_title(f'Img index: {chosen_label_idx}')
        ax.imshow(training_data[chosen_label_idx])

    # Blank out grid cells left over when a class has fewer images than slots.
    for unused_ax in axs.flatten()[n_chosen:]:
        unused_ax.axis('off')