[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jabascal/notes-on-machine-learning/blob/master/notebook/exploratory_analysis_histograms.ipynb)

# Exploratory image analysis - Part 1 : Advanced density plots

Exploratory data analysis and visualization techniques are essential to get insight from the data. Unlock the full power of AI approaches by understanding and focusing on data quality!

## Introduction

Exploratory data analysis uses descriptive statistics and visualization techniques to provide insights into the data. Descriptive statistics aim at summarizing and analyzing data, and visualization techniques allow highlighting patterns, correlations, trends, outliers and errors in the data, as well as communicating the results. Typical visualization techniques include histograms, scatter plots, box plots, non-linear dimensionality reduction techniques, projection embeddings, dataset sprite plots and interactive versions of these plots. In brief, descriptive statistics and visualization techniques are key to gaining insight into the characteristics and distribution of the data.

In this article, we look at advanced density visualization techniques, which are useful to understand the data distribution, identify shifts in test data, and reduce prediction errors. Specifically, we focus on image data, the CIFAR-10 dataset, and the `joypy` library, which yields advanced density plots. Let's get started!

## Installation

First, we'll install the necessary libraries within an environment. 

In [None]:
if True:
    !pip install matplotlib==3.8.2 \
        numpy==1.26.3 \
        pillow==10.2.0 \
        pandas==2.2.0 \
        seaborn==0.13.1 \
        joypy==0.2.6 \
        tensorflow

In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
import PIL.Image as Image

np.random.seed(0)

## Data

We look at the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), a collection of 60,000 images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images. We will use the test set. Images have 32x32 pixels and 3 channels (RGB).

### Data download and loading

The data can be downloaded from the [CIFAR-10 website](https://www.cs.toronto.edu/~kriz/cifar.html) or [Kaggle](https://www.kaggle.com/c/cifar-10/data), but the simplest option is to download it using keras or pytorch. We show how to download the data using keras, which requires installing tensorflow.

In [None]:
from keras.datasets import cifar10

# Download the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)

# We retain only the test set
images = x_test
labels_ids = y_test

We define CIFAR-10 labels, as given on the website. We also define different groups for the classes that are expected to present similarities.

In [None]:
# Channels and CIFAR-10 classes
channels = ['r', 'g', 'b']
cifar10_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 
cifar10_labels_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
#cifar10_labels_idx = os.listdir(data_dir)

# Labels names
labels_names = [cifar10_labels[label_id[0]] for label_id in labels_ids]

# Define groups with similar classes
cifar10_groups = {'transport': ['airplane', 'ship', 'automobile', 'truck'],
                      'pet': ['cat', 'dog'], 
                       'wild': ['deer', 'horse', 'frog', 'bird']}

# Name for dataset (for saving results)
result_name = 'cifar10'

# Path for results
result_dir = '../Results/cifar10/data_anal/'    
if not os.path.exists(result_dir):
    os.makedirs(result_dir)
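
As a quick sanity check, the class balance can be verified by counting the labels. The sketch below uses a small synthetic label array with the same `(N, 1)` layout as the labels returned by `cifar10.load_data()`, so it runs without downloading anything. The helper `count_per_class` and the `demo_*` variables are hypothetical names introduced for illustration; they are not part of keras.

```python
import numpy as np

def count_per_class(labels_ids, class_names):
    """Return {class name: number of samples} for integer labels of shape (N, 1) or (N,)."""
    counts = np.bincount(np.asarray(labels_ids).ravel(), minlength=len(class_names))
    return {name: int(count) for name, count in zip(class_names, counts)}

# Synthetic stand-in for the label array: 3 classes, labels stored as a column vector
demo_names = ['airplane', 'automobile', 'bird']
demo_labels = np.array([[0], [2], [1], [2], [0], [2]])
demo_counts = count_per_class(demo_labels, demo_names)
print(demo_counts)  # {'airplane': 2, 'automobile': 1, 'bird': 3}
```

On the real data, `count_per_class(labels_ids, cifar10_labels)` should report 1,000 images for each class of the test set.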

After reading the data, we display the first few images and then a sprite of the dataset.

In [None]:
def display_grid_images_labels(images, labels, dim_resize=None, num_subplots=(3, 3), 
                      figsize = (6, 6), path_file=None):
    """
    Display a grid of images (arrays) with their labels as titles.
    The figure is saved to path_file if given; otherwise it is shown.
    """
    # Create a grid to display the images
    fig, axes = plt.subplots(num_subplots[0], num_subplots[1], figsize=figsize)
    axes = axes.flatten()  # Flatten the axes array
    # Load and display the selected images on the grid
    for i, (img, label) in enumerate(zip(images, labels)):        
        #ax = axes[i // num_subplots[0], i % num_subplots[1]]
        ax = axes[i]
        ax.imshow(img)
        ax.set_title(label)
        ax.axis('off')
    if path_file is not None:
        fig.savefig(path_file)
    else:
        plt.show()

num_selected = 16
images_selected = [images[i] for i in range(num_selected)]
labels_selected = [cifar10_labels[labels_ids[i][0]] for i in range(num_selected)]
file_save = os.path.join(result_dir, f'{result_name}_img.png')
display_grid_images_labels(images_selected, labels_selected, 
                           path_file=file_save, figsize=(5,6), num_subplots=(4,4))

To easily inspect an image:

In [None]:
ind = 1000
print(f'Label: {cifar10_labels[labels_ids[ind][0]]}')
Image.fromarray(images[ind]).show()

It is common to create a sprite plot of the dataset, which is a single image that contains all or a subset of the images in the dataset. 

In [None]:
def images_to_sprite(data, invert_colors=False):
    if len(data.shape) == 3:
        data = np.tile(data[...,np.newaxis], (1,1,1,3))
    data = data.astype(np.float32)
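    # Normalize each image to [0, 1]; names avoid shadowing the built-ins min/max
    data_min = np.min(data.reshape((data.shape[0], -1)), axis=1)
    data = (data.transpose(1,2,3,0) - data_min).transpose(3,0,1,2)
    data_max = np.max(data.reshape((data.shape[0], -1)), axis=1)
    data_max[data_max == 0] = 1  # constant images: avoid division by zero
    data = (data.transpose(1,2,3,0) / data_max).transpose(3,0,1,2)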
    # Inverting the colors seems to look better for MNIST
    if invert_colors:
       data = 1 - data

    n = int(np.ceil(np.sqrt(data.shape[0])))
    padding = ((0, n ** 2 - data.shape[0]), (0, 0),
            (0, 0)) + ((0, 0),) * (data.ndim - 3)
    data = np.pad(data, padding, mode='constant',
            constant_values=0)
    # Tile the individual thumbnails into an image.
    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3)
            + tuple(range(4, data.ndim + 1)))
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
    data = (data * 255).astype(np.uint8)
    return data, n

# Subsample the dataset
num_selected = 25*25
images_2d = np.array(images[:num_selected,:,:,:]).reshape(-1, 32, 32, 3)

# Create the sprite image
sprite, n = images_to_sprite(images_2d)
sprite_path = os.path.join(result_dir, f'{result_name}_sprite.png')
Image.fromarray(sprite).save(sprite_path)
#Image.fromarray(sprite).show()
fig = plt.figure(figsize=(10,10))
plt.imshow(sprite)
plt.axis('off')
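
The tiling step inside `images_to_sprite` is compact and can be hard to follow. The following is a minimal sketch of the same reshape-and-transpose idea on its own, without the per-image normalization. The function `tile_images` and the `demo_*` variables are hypothetical names used only for this illustration.

```python
import numpy as np

def tile_images(images, n_cols):
    """Tile images of shape (N, H, W, C) into a single (rows*H, n_cols*W, C) array."""
    num, height, width, channels = images.shape
    n_rows = int(np.ceil(num / n_cols))
    # Pad with black images so that the grid is completely filled
    padded = np.zeros((n_rows * n_cols, height, width, channels), dtype=images.dtype)
    padded[:num] = images
    # (rows, cols, H, W, C) -> (rows, H, cols, W, C) -> (rows*H, cols*W, C)
    grid = padded.reshape(n_rows, n_cols, height, width, channels)
    grid = grid.transpose(0, 2, 1, 3, 4)
    return grid.reshape(n_rows * height, n_cols * width, channels)

# Four constant 2x2 RGB images with values 0, 1, 2, 3
demo_images = np.stack([np.full((2, 2, 3), i, dtype=np.uint8) for i in range(4)])
demo_sprite = tile_images(demo_images, n_cols=2)
print(demo_sprite.shape)     # (4, 4, 3)
print(demo_sprite[:, :, 0])  # image 0 top-left, 1 top-right, 2 bottom-left, 3 bottom-right
```

Images are placed row by row, so image `k` lands in grid row `k // n_cols` and grid column `k % n_cols`.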

## Descriptive statistics

### Summary statistics

Summary statistics provide a quantitative summary of the data with a few numbers. They are a fast way to get a general idea of the data. We use pandas to create a dataframe of the dataset, with one row per pixel and one column per channel. Then, we group the data by `class` and use the `describe` method to summarize the central tendency, dispersion and shape of the distribution of each class.

For large image datasets, we would need to subsample the dataset/images or preprocess the data. 
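
One possible way to subsample, sketched below under simple assumptions, is to keep a random subset of images and a random subset of pixel positions from each of them before building the dataframe. The function `subsample_pixels`, its parameters and the synthetic `demo_*` arrays are hypothetical and only illustrate the idea; other strategies (for instance, downscaling the images) may be more appropriate depending on the application.

```python
import numpy as np

def subsample_pixels(images, labels_ids, num_images, pixels_per_image, seed=0):
    """Randomly keep `num_images` images and `pixels_per_image` pixels from each.

    Returns the pixels, with shape (num_images * pixels_per_image, channels),
    and the label of the image that each kept pixel comes from.
    """
    rng = np.random.default_rng(seed)
    num_total, height, width, num_channels = images.shape
    image_idx = rng.choice(num_total, size=num_images, replace=False)
    flat = images[image_idx].reshape(num_images, height * width, num_channels)
    # The same random pixel positions are used for every selected image
    pixel_idx = rng.choice(height * width, size=pixels_per_image, replace=False)
    pixels = flat[:, pixel_idx, :].reshape(-1, num_channels)
    pixel_labels = np.repeat(np.asarray(labels_ids).ravel()[image_idx], pixels_per_image)
    return pixels, pixel_labels

# Synthetic stand-in for an image dataset: 20 RGB images of 8x8 pixels, 4 classes
demo_images = np.random.default_rng(1).integers(0, 256, size=(20, 8, 8, 3), dtype=np.uint8)
demo_labels = np.arange(20).reshape(-1, 1) % 4
pixels, pixel_labels = subsample_pixels(demo_images, demo_labels,
                                        num_images=5, pixels_per_image=16)
print(pixels.shape, pixel_labels.shape)  # (80, 3) (80,)
```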

In [None]:
def create_df_class(images, labels_ids, labels_names, channels, class_name='class'):
    """
    Given an images array, a labels ids array, a list of labels names and a list of channels,
    create a dataframe of pixels for each class and then concatenate all dataframes.
    """
    labels_ids = np.asarray(labels_ids).flatten()
    dfs = []
    for label_id in range(len(labels_names)):
        # Combine all pixels of the images that belong to the same class
        images_subgroup = images[labels_ids == label_id]
        images_subgroup = images_subgroup.reshape(-1, len(channels))

        # Create a dataframe with one row per pixel and one column per channel
        images_subgroup_df = pd.DataFrame(images_subgroup, columns=channels)
        images_subgroup_df[class_name] = labels_names[label_id]
        dfs.append(images_subgroup_df)

    # Concatenate once at the end instead of growing the dataframe in the loop
    return pd.concat(dfs, axis=0, ignore_index=True)

# Create a dataframe with the pixels of the first num_selected images
num_selected = 500
images_df = create_df_class(images[:num_selected,:,:,:], labels_ids[:num_selected], cifar10_labels, channels)
# print(images_df)

# Group by class
images_df_group = images_df.groupby('class')

# Display stats
stats = images_df_group.describe()
stats = stats.transpose()
stats

### Data visualization with common plots

Common plots that provide a general description of the data are histograms, scatter plots, box plots, and correlation plots. Pandas and seaborn are common libraries for these plots. Here, we focus on histograms. We start with two examples showing that it is quite tricky to get useful insight from these plots. Then, we look at more advanced, dedicated tools for displaying densities.

**Histograms** provide a visual representation of the distribution of the data, as they represent the frequency of the values in the dataset. They highlight the shape and skewness of the distribution and can help to identify outliers. The bins are the intervals into which the range of values is divided; the height of each bar is the number of values that fall into the corresponding bin.
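
To make the notion of bins concrete, here is a small sketch using `np.histogram` on five made-up pixel intensities. The values and the choice of four bins are assumptions for illustration only.

```python
import numpy as np

# Five pixel intensities in the 8-bit range [0, 255]
values = np.array([0, 10, 20, 200, 255])

# Four equal-width bins covering [0, 256]: edges at 0, 64, 128, 192, 256
counts, bin_edges = np.histogram(values, bins=4, range=(0, 256))
print(counts)     # [3 0 0 2]
print(bin_edges)  # [  0.  64. 128. 192. 256.]
```

Three values fall into the first bin and two into the last one, while the two middle bins are empty. Changing the number of bins changes the level of detail of the histogram.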

As a first example, we use *pandas* `hist` method to plot the histogram for one class. 

In [None]:
# Select class bird and plot histogram using pandas
class_selected = 'bird'
images_df_class = images_df[images_df['class'] == class_selected]
fig, axes = plt.subplots(1, 3, figsize=(7, 4))
for i, channel in enumerate(channels):
    images_df_class.hist(column=channel, ax=axes[i], bins=20, alpha=0.5)   

Other libraries such as *seaborn* yield higher quality plots. [Seaborn](https://seaborn.pydata.org/index.html) is a well known Python library based on *matplotlib* that provides a high-level interface for displaying statistical graphics. 

A useful plot to start the exploratory analysis is the **pairplot**, also known as scatterplot matrix. It creates a grid of scatter plots in which each variable in the dataset is plotted against every other variable. The diagonal plots show the distribution of each variable: a histogram by default, or a kernel density estimate when `hue` is set, as in the call below.

In [None]:
fig = sns.pairplot(images_df.sample(10000), hue='class' )
fig.savefig(os.path.join(result_dir, f'{result_name}_pairplot.png'))

Can't see much? Well, neither can I!

- The per-class distributions on the diagonal are cluttered!
- Scatter plots do not provide much information in this case either! Scatter plots are useful to find relations between variables, but here the variables are the RGB channels.

Even though the plots look quite nice, they do not provide great insight for image data. We need another tool to come to the rescue for comparing histograms.

### Advanced histogram plots

A useful library to plot histograms is [joypy](https://github.com/leotac/joypy). It is a matplotlib- and pandas-based library for drawing partially overlapping density plots, known as joyplots or ridgeline plots.

To get further insight into the data, we separate the plots by groups of classes, both for clarity and to better identify similar distributions. It may take a minute to estimate the density for each class, depending on the number of images.

In [None]:
import joypy

# Plot per subclass
for subclass in cifar10_groups.keys():
    images_subclass = images_df[images_df['class'].isin(cifar10_groups[subclass])]
    fig, axes = joypy.joyplot(images_subclass, 
                              legend=True,
                              color=['r', 'g', 'b'],
                              fade=True,
                              by='class', 
                              figsize=(6,3),)                       
    fig.savefig(os.path.join(result_dir, f'{result_name}_jp_{subclass}.png'))
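
Each curve in a joyplot is a smoothed density estimate of the values in one group. To illustrate what such an estimate computes, the sketch below implements a basic Gaussian kernel density estimate with numpy only, on synthetic data. It is not joypy's internal implementation; the function `gaussian_kde_1d`, the bandwidth rule (Scott's rule of thumb) and the `demo_groups` data are assumptions made for this example.

```python
import numpy as np

def gaussian_kde_1d(samples, grid, bandwidth=None):
    """Evaluate a Gaussian kernel density estimate of `samples` at the points in `grid`."""
    samples = np.asarray(samples, dtype=float)
    grid = np.asarray(grid, dtype=float)
    if bandwidth is None:
        # Scott's rule of thumb for one-dimensional data
        bandwidth = samples.std(ddof=1) * len(samples) ** (-1 / 5)
    # One Gaussian bump per sample, averaged over all samples
    z = (grid[:, None] - samples[None, :]) / bandwidth
    kernels = np.exp(-0.5 * z ** 2) / np.sqrt(2 * np.pi)
    return kernels.mean(axis=1) / bandwidth

# Two synthetic groups of intensities, standing in for two classes
rng = np.random.default_rng(0)
demo_groups = {
    'dark': rng.normal(80, 20, size=500),
    'bright': rng.normal(180, 25, size=500),
}
grid = np.linspace(-100, 400, 1001)
densities = {name: gaussian_kde_1d(values, grid) for name, values in demo_groups.items()}

for name, density in densities.items():
    print(name, 'peak near', grid[np.argmax(density)])
```

Plotting each entry of `densities` against `grid`, with a vertical offset per group, gives the overlapping layout of a joyplot.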

Do you notice the similarities between classes? Cat and dog have very similar distributions, as do automobile and truck, and deer and horse. We also notice that ship and bird are skewed to the left (left-tailed) because of the blue colors of sea and sky, respectively. These distributions contrast with those of the other classes and with the overall distribution of the dataset, which are mostly right-tailed.

Skewness and kurtosis are two important measures of the shape of a distribution. Skewness measures the asymmetry of the distribution: it is positive for right-tailed distributions and negative for left-tailed ones. Kurtosis measures how heavy the tails of the distribution are.

In [None]:
skewness = images_df_group.skew()
print(skewness)
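
To make the sign conventions concrete, the following sketch computes moment-based skewness and excess kurtosis with numpy on synthetic samples. The functions `skewness_moment` and `excess_kurtosis_moment` and the synthetic samples are hypothetical names and data for illustration. Note that pandas applies a bias correction in `skew` and `kurt`, so its values differ slightly from these plain moment ratios.

```python
import numpy as np

def skewness_moment(x):
    """Moment-based skewness m3 / m2**1.5 (0 for a symmetric distribution)."""
    x = np.asarray(x, dtype=float)
    d = x - x.mean()
    return np.mean(d ** 3) / np.mean(d ** 2) ** 1.5

def excess_kurtosis_moment(x):
    """Moment-based excess kurtosis m4 / m2**2 - 3 (0 for a normal distribution)."""
    x = np.asarray(x, dtype=float)
    d = x - x.mean()
    return np.mean(d ** 4) / np.mean(d ** 2) ** 2 - 3.0

rng = np.random.default_rng(0)
right_tailed = rng.exponential(scale=40.0, size=10_000)  # mostly low values, long right tail
left_tailed = 255.0 - right_tailed                       # mirrored: long left tail
symmetric = rng.normal(128.0, 30.0, size=10_000)

for name, sample in [('right-tailed', right_tailed),
                     ('left-tailed', left_tailed),
                     ('symmetric', symmetric)]:
    print(f'{name}: skewness={skewness_moment(sample):.2f}, '
          f'excess kurtosis={excess_kurtosis_moment(sample):.2f}')
```

The right-tailed sample has positive skewness, its mirrored version has the same magnitude with a negative sign, and the symmetric sample is close to zero for both measures.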

We can compare to the distribution of the entire dataset.

In [None]:
# Plot the densities of the whole dataset
fig, axes = joypy.joyplot(images_df, 
                          legend=True,
                          color=['r', 'g', 'b'], 
                          figsize=(6,3))
fig.savefig(os.path.join(result_dir, f'{result_name}_jp.png'))

The previous metrics are global measures, which we have already explored across classes. We can also display histograms of individual images; for instance, of the worst-performing cases.

In the following, we select a few images from a few classes. We can see the large variability across image samples.

In [None]:
# Plot densities for a few individual images of selected classes
# cifar10_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 
# cifar10_labels_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
num_images_new = 5
images_selected_classes = []
for class_id_new in [0, 5, 3]:
    class_id_new = int(class_id_new)
    images_selected = []
    labels_selected = []
    # Take a few random images from the selected class (without repetition,
    # so that every image gets a unique label in the plot)
    labels_selected_idx = np.where(labels_ids == class_id_new)[0]    
    labels_selected_idx = np.random.choice(labels_selected_idx, num_images_new, replace=False)
    for label_selected_idx in labels_selected_idx:
        image_new = images[label_selected_idx,:,:,:]
        images_selected.append(image_new)
        labels_selected.append(cifar10_labels[class_id_new] + f'_{label_selected_idx}')
    images_selected = np.array(images_selected)
    labels_selected = np.array(labels_selected)
    images_selected_classes.append(images_selected)

    # Create a dataframe with the selected images (one "class" per image)
    images_selected_df = create_df_class(images_selected, np.array(range(len(labels_selected))), labels_selected, channels)
    
    fig, axes = joypy.joyplot(images_selected_df, 
                            legend=True,
                            color=['r', 'g', 'b'], 
                            fade=True, 
                            by='class', 
                            figsize=(5,3))
    fig.savefig(os.path.join(result_dir, f'{result_name}_jp_new_{cifar10_labels[class_id_new]}.png'))
images_selected_classes = np.array(images_selected_classes)

We then visualize the selected images using a sprite plot.

In [None]:
# Sprites of selected image
images_2d = np.array(images_selected_classes).reshape(-1, 32, 32, 3)
sprite, n = images_to_sprite(images_2d)
sprite_path = os.path.join(result_dir, f'{result_name}_selected_sprite.png')
Image.fromarray(sprite).save(sprite_path)
#Image.fromarray(sprite).show()
fig = plt.figure(figsize=(5,5))
plt.imshow(sprite)
plt.axis('off')

### Interactive plots

Several libraries provide interactive plots, such as [plotly](https://plotly.com/python), [bokeh](https://docs.bokeh.org/en/latest/index.html), [Vega-Altair](https://altair-viz.github.io), and [pygal](http://www.pygal.org/en/stable), among others. These will be the subject of another notebook.

## Summary and discussion

Histograms can provide first insights into the data, especially for applications in which different classes can have different ranges of values. Histograms are also key in machine learning, as all tasks can be seen as explicitly or implicitly learning the distribution of the data. In addition, this analysis can also be useful for error analysis; for instance, by looking at cases that are not well classified.
 
 In this dataset, we consider the channels as variables and the pixels as observations. This choice is an example for illustration purposes, and other more meaningful variables can be defined depending on the application. 
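
As one example of alternative variables, each image can be treated as a single observation described by a few summary features. The sketch below computes per-image channel means, brightness and contrast with numpy. The function `image_level_features`, the chosen features and the tiny `demo_images` array are assumptions for illustration, not the variables used in this notebook.

```python
import numpy as np

def image_level_features(images):
    """Compute simple per-image features from an array of shape (N, H, W, 3)."""
    images = np.asarray(images, dtype=float)
    channel_means = images.mean(axis=(1, 2))  # shape (N, 3)
    gray = images.mean(axis=3)                # plain average of the three channels
    return {
        'mean_r': channel_means[:, 0],
        'mean_g': channel_means[:, 1],
        'mean_b': channel_means[:, 2],
        'brightness': gray.mean(axis=(1, 2)),  # average intensity of the image
        'contrast': gray.std(axis=(1, 2)),     # spread of intensities within the image
    }

# Three tiny 2x2 images: black, white and pure red
demo_images = np.zeros((3, 2, 2, 3), dtype=np.uint8)
demo_images[1] = 255
demo_images[2, :, :, 0] = 255
features = image_level_features(demo_images)
print(features['brightness'])  # [  0. 255.  85.]
print(features['mean_r'])      # [  0. 255. 255.]
```

The resulting dictionary can be passed to `pd.DataFrame` together with the class labels and explored with the same plots as above, now with one row per image instead of one row per pixel.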

A word of caution: summary statistics provide only global metrics, and over-relying on them can be misleading [Alabi 2023]. However, they are a great starting point to get insight, and they can be used to alert of potential shifts in the data.
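
As a rough illustration of how histograms could flag a shift, the sketch below compares the normalized histograms of two samples using the total variation distance. This is only one possible choice of distance; the function `histogram_tv_distance`, the number of bins and the synthetic samples are assumptions made for this example, and any alert threshold would need to be chosen for the application at hand.

```python
import numpy as np

def histogram_tv_distance(a, b, bins=32, value_range=(0, 256)):
    """Total variation distance between normalized histograms (0 = identical, 1 = disjoint)."""
    hist_a, _ = np.histogram(a, bins=bins, range=value_range)
    hist_b, _ = np.histogram(b, bins=bins, range=value_range)
    p = hist_a / hist_a.sum()
    q = hist_b / hist_b.sum()
    return 0.5 * np.abs(p - q).sum()

# Synthetic intensities: two samples from the same source and one brighter sample
rng = np.random.default_rng(0)
reference = rng.normal(120, 40, size=20_000).clip(0, 255)
same_source = rng.normal(120, 40, size=20_000).clip(0, 255)
shifted = rng.normal(150, 40, size=20_000).clip(0, 255)  # simulated brightness shift

print('same source:', histogram_tv_distance(reference, same_source))
print('shifted:    ', histogram_tv_distance(reference, shifted))
```

The distance between the two samples from the same source is small but not exactly zero because of sampling noise, while the shifted sample gives a clearly larger value.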

In this notebook, we have introduced descriptive statistics and data exploration and have focused on histograms. Other great techniques for image data are projection embeddings and non-linear dimensionality reduction techniques, which allow visualizing all data points in a 2D or 3D space. We will look at these techniques in future notebooks.

## References

[Krizhevsky 2009] Alex Krizhevsky. Learning Multiple Layers of Features from Tiny Images, 2009. [Link](https://api.semanticscholar.org/CorpusID:18268744)

[Alabi 2023] Olubunmi Alabi and Tosin Bukola. Introduction to Descriptive statistics, Recent Advances in Biostatistics, 2023. [Link](https://www.intechopen.com/online-first/1141192)

[Kaur 2018] P Kaur et al. Descriptive statistics. International Journal of Academic Medicine 4(1):p 60-63, Jan–Apr, 2018.

[Nick 2007] T Nick. Topics in Biostatistics. Methods in Molecular Biology. Chapter 3: Descriptive statistics, 2007. [Link](https://www.researchgate.net/profile/Douglas-Case/publication/5402496_Power_and_Sample_Size/links/00b49539f39fea3b24000000/Power-and-Sample-Size.pdf?_sg%5B0%5D=started_experiment_milestone&origin=journalDetail#page=42)
