In [None]:
import os
import sys
import random

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

sys.path.append(os.path.join(os.getcwd(),'..', '..', '..'))
from settings import DATA_PATH
from src.data_processing.io import load_mnist

In [None]:
# Load the three MNIST splits, then unpack each (x, y) pair. x_* are presumably the
# flattened images (the sampling cell reshapes them to 28x28) and y_* the labels.
train, val, test = load_mnist(DATA_PATH)
(x_train, y_train), (x_val, y_val), (x_test, y_test) = train, val, test

# MNIST examples

In [None]:
n_samples = 8
indices = random.sample(list(range(len(x_train))), n_samples)
fig, axes = plt.subplots(2, n_samples // 2, figsize=(16, 8))
for ax, im_idx in zip(axes.flatten(), indices):
    ax.imshow(x_train[im_idx].reshape([28, 28]), cmap='gray')
    ax.tick_params(axis=u'both', which=u'both', bottom=False, left=False, top=False, labelbottom=False, labelleft=False)
plt.plot()

# Data distribution

In [None]:
data_sizes = {'Train set': len(x_train), 'Validation set': len(x_val), 'Test set': len(x_test)}
data_sizes = pd.DataFrame(data_sizes, index=['Number of examples'])
data_sizes.head()

## Class distribution

In [None]:
train_labels, train_counts = np.unique(y_train, return_counts=True)
val_labels, val_counts = np.unique(y_val, return_counts=True)
test_labels, test_counts = np.unique(y_test, return_counts=True)

class_distribution = {'Train':{label: num for label, num in zip(train_labels, train_counts)},
                     'Validation':{label: num for label, num in zip(val_labels, val_counts)},
                     'Test':{label: num for label, num in zip(test_labels, test_counts)}}
class_distribution = pd.DataFrame(class_distribution, index=None).reset_index().melt('index', 
                                                                                     var_name='cols', value_name='vals')
class_distribution.columns = ['Label', 'Set', 'Number of examples']

In [None]:
_, ax = plt.subplots(1, 1, figsize=(12, 6))
sns.barplot(x='Label', y='Number of examples', hue='Set', data=class_distribution, ax=ax)
plt.show()