# Inspect Classification Training Dataset

This notebook is meant to be run after the classification dataset has been created but before training a classifier. Copy this notebook to the same folder as the classification dataset, for example:

```
CameraTraps/
    classification/
        BASE_LOGDIR/
            classification_ds.csv
            inspect_dataset.ipynb  # COPY THIS NOTEBOOK TO HERE
            splits.json
```

In [None]:
# Reload imported modules automatically before each cell executes, so edits to
# local modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline
# Show the working directory; the relative file paths configured below are
# resolved against it.
!pwd

## Imports and Constants

In [None]:
import json
import os

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import pandas as pd
import seaborn as sns

from classification.train_utils import load_splits, plot_img_grid


# Reusable pandas display settings for the summary tables below: floats are
# formatted with 2 decimal places and up to 1000 rows are shown.
# Usage: `with disp_context: display(some_df)`.
disp_context = pd.option_context(
    'display.float_format', '{:0.2f}'.format,
    'display.max_rows', 1000)
# Use seaborn's 'darkgrid' style for all plots in this notebook.
sns.set(style='darkgrid')

In [None]:
# Split names, in the order used for table columns and plot hues below.
SPLITS = ['train', 'val', 'test']
# Dataset CSV and splits JSON. These are relative paths, so they are looked up
# in the current working directory.
csv_path = 'classification_ds.csv'
splits_json_path = 'splits.json'

# Root directory of the image crops; values from the CSV's 'path' column are
# joined onto it when sample crops are loaded. Placeholder -- edit before running.
crops_dir = '/path/to/crops'

## Load dataset and splits files

In [None]:
# Read the dataset CSV. index_col=False keeps pandas from using the first
# column as the index; float_precision='high' selects pandas' higher-precision
# float converter.
df = pd.read_csv(csv_path, index_col=False, float_precision='high')

# merge dataset and location into a (dataset, location) tuple
df['dataset_location'] = list(zip(df['dataset'], df['location']))

# Sorted distinct labels: reused below to order plot axes (label_order) and to
# scale figure heights (num_labels).
label_order = sorted(df['label'].unique())
num_labels = len(label_order)

display(df.head())

In [None]:
split_to_locs = load_splits(splits_json_path)

# Invert {split: locations} into {location: split}. If a location is listed
# under more than one split, the split visited last wins.
# NOTE(review): presumably each location is a (dataset, location) tuple that
# matches the values in df['dataset_location'] -- confirm against load_splits.
loc_to_split = {
    loc: split
    for split, locs in split_to_locs.items()
    for loc in locs
}

# Mapping through __getitem__ (rather than passing the dict itself) raises a
# KeyError for any location missing from the splits file instead of silently
# producing NaN.
df['split'] = df['dataset_location'].map(loc_to_split.__getitem__)

## (Optional) Compare against another set of splits

In [None]:
def compare_splits(splits_json_path1: str, splits_json_path2: str,
                   name1: str = 'this', name2: str = 'other') -> None:
    """Print per-split location counts for two splits files and their overlap.

    For every split in SPLITS, prints how many locations each splits file
    assigns to that split and how many locations the two files share.

    Args:
        splits_json_path1: str, path to the first splits.json
        splits_json_path2: str, path to the second splits.json
        name1: str, display name for the first splits file
        name2: str, display name for the second splits file
    """
    # NOTE(review): the `&` below requires load_splits() to return a set-like
    # collection of locations per split -- confirm against its implementation.
    locs_by_split_a = load_splits(splits_json_path1)
    locs_by_split_b = load_splits(splits_json_path2)

    for split_name in SPLITS:
        locs_a = locs_by_split_a[split_name]
        print(f'{name1} # of {split_name} locs:', len(locs_a))
        locs_b = locs_by_split_b[split_name]
        print(f'{name2} # of {split_name} locs:', len(locs_b))
        shared_locs = locs_a & locs_b
        print(f'number of overlap {split_name} locs:', len(shared_locs))
        print('===')

In [None]:
# compare_splits(splits_json_path, '/path/to/other/splits.json')

## Sample crops from each label

In [None]:
# Number of crops to show per label; also the number of columns in each grid.
max_samples_per_label = 5

for label, group_df in df.groupby('label'):
    # DataFrame.sample() raises ValueError when asked for more rows than the
    # group contains (sampling is without replacement by default), so cap the
    # sample size for labels with fewer than max_samples_per_label crops.
    group_df = group_df.sample(min(max_samples_per_label, len(group_df)))

    # Load the sampled crops; 'path' values are joined onto crops_dir.
    imgs = [
        mpimg.imread(os.path.join(crops_dir, file))
        for file in group_df['path']
    ]

    # NOTE(review): assumes plot_img_grid accepts fewer images than ncols --
    # confirm against classification.train_utils.
    fig = plot_img_grid(imgs=imgs, row_h=3, col_w=3, ncols=max_samples_per_label)
    print(label)
    display(group_df)
    display(fig)

## Crop confidence

In [None]:
# Summary statistics of the 'confidence' column over all rows of the dataset.
df['confidence'].describe()

In [None]:
# One horizontal box per label showing the spread of 'confidence' values.
# The figure height scales with the number of labels.
fig, ax = plt.subplots(1, 1, figsize=(10, num_labels/4), tight_layout=True)
sns.boxplot(data=df, y='label', x='confidence', ax=ax)

## View distribution of locations and labels by locations

In [None]:
# Number of distinct (dataset, location) pairs in each split, in SPLITS order.
locs_per_split = df.groupby('split')['dataset_location'].nunique()[SPLITS]
# Append a 'total' row summing the per-split counts.
locs_per_split.loc['total'] = locs_per_split.sum()
display(locs_per_split.to_frame())

In [None]:
# Rows: labels. Columns: splits (in SPLITS order). Values: number of distinct
# (dataset, location) pairs containing that label in that split; label/split
# combinations that do not occur are filled with 0.
locations = (
    df.groupby(['label', 'split'])['dataset_location'].nunique()
    .unstack('split')[SPLITS]
    .fillna(0)
    .astype(int)
)
# Per-label total across the splits.
locations['total'] = locations.sum(axis=1)

# Fraction of each label's locations that fall into each split.
locations_frac = locations[SPLITS].div(locations['total'], axis=0)
# Side-by-side table with two-level columns: ('frac', split) and
# ('counts', split | 'total').
locations_all = pd.concat(
    [locations_frac, locations], axis=1,
    keys=['frac', 'counts'], sort=False)

with disp_context:
    display(locations_all)

In [None]:
# identify labels with extreme distributions
# Each entry: (heading to print, split, location-count threshold,
# location-fraction threshold). A label is listed for a split when it falls
# below EITHER threshold.
extreme_location_checks = [
    ('(test set < 5) or (test set < 10%)', 'test', 5, 0.1),
    ('(val set < 5) or (val set < 10%)', 'val', 5, 0.1),
    ('(train set < 10) or (train set < 40%)', 'train', 10, 0.4),
]

with disp_context:
    for heading, split_name, count_limit, frac_limit in extreme_location_checks:
        print(heading)
        loc_counts = locations_all.loc[:, ('counts', split_name)]
        loc_fracs = locations_all.loc[:, ('frac', split_name)]
        mask = (loc_counts < count_limit) | (loc_fracs < frac_limit)
        display(locations_all.loc[mask])

## View distribution of labels by split

In [None]:
# Rows: labels. Columns: splits (in SPLITS order). Values: number of dataset
# rows with that label in that split; missing combinations are filled with 0.
labels_dist = (
    df.groupby(['label', 'split']).size()
    .unstack('split')[SPLITS]
    .fillna(0)
    .astype(int)
)

# Same table with an extra 'total' row holding the per-split sums.
labels_dist_with_total = labels_dist.copy()
labels_dist_with_total.loc['total'] = labels_dist.sum(axis=0)

# Row-normalize: fraction of each label's rows (and of the 'total' row) that
# fall into each split.
labels_dist_frac = labels_dist_with_total.div(labels_dist_with_total.sum(axis=1), axis=0)

# Side-by-side table with two-level columns: ('frac', split) and ('counts', split).
labels_dist_all = pd.concat([labels_dist_frac, labels_dist_with_total], axis=1,
                            keys=['frac', 'counts'], sort=False)
# Add a ('counts', 'total') column: the per-row sum of the split counts.
labels_dist_all.loc[:, ('counts', 'total')] = labels_dist_all.loc[:, 'counts'].sum(axis=1)

with disp_context:
    display(labels_dist_all)

In [None]:
# identify labels with extreme distributions
def _small_split_mask(split_name, count_limit, frac_limit):
    """Boolean mask over labels_dist_all rows whose count AND fraction in
    `split_name` are both below the given limits."""
    split_counts = labels_dist_all.loc[:, ('counts', split_name)]
    split_fracs = labels_dist_all.loc[:, ('frac', split_name)]
    return (split_counts < count_limit) & (split_fracs < frac_limit)


with disp_context:
    print('(test set < 300) and (test set < 9%)')
    test_mask = _small_split_mask('test', 300, 0.09)
    print(test_mask.sum())

    print('(val set < 300) and (val set < 9%)')
    val_mask = _small_split_mask('val', 300, 0.09)
    print(val_mask.sum())

    print('(train set < 1000) and (train set < 40%)')
    train_mask = _small_split_mask('train', 1000, 0.4)
    print(train_mask.sum())
    # display(labels_dist_all.loc[train_mask])

    # combined: rows flagged for at least one split
    any_split_mask = train_mask | val_mask | test_mask
    print(any_split_mask.sum())
    display(labels_dist_all.loc[any_split_mask])

In [None]:
# approximate sample weights: inverse label frequency, scaled so that a label
# holding exactly 1/num_labels of the rows gets weight 1
num_rows = len(df)
rows_per_label = df['label'].value_counts()
num_distinct_labels = df['label'].nunique()
sample_weights = num_rows / (rows_per_label * num_distinct_labels)
with disp_context:
    display(sample_weights)

In [None]:
# Horizontal grouped bar chart: number of rows per label, one bar per split.
# The figure height scales with the number of labels.
fig, ax = plt.subplots(1, 1, figsize=(10, num_labels/2), tight_layout=True)
sns.countplot(y='label', hue='split', data=df, order=label_order, ax=ax, hue_order=SPLITS)

# roughly equivalent to:
# labels_dist.plot(kind='barh', figsize=(10, num_labels/2), width=0.8, ax=ax)
# ax.invert_yaxis()
# ax.grid(axis='y')
# ax.set_xlabel('count')

# Annotate the bars in the first third of ax.patches with their width (i.e.
# the count), placed just past the end of each bar.
# NOTE(review): presumably seaborn adds patches one hue level at a time, so the
# first third corresponds to the 'train' bars (3 == len(SPLITS)) -- confirm
# for the installed seaborn version.
for i, p in enumerate(ax.patches):
    if i < len(ax.patches) / 3:
        ax.annotate(str(p.get_width()), (p.get_width() * 1.005, p.get_y() + 0.2))

In [None]:
# if necessary, zoom in the x-axis from the plot above
# fig, ax = plt.subplots(1, 1, figsize=(10, num_labels/2))
# ax = sns.countplot(data=df, y='label', hue='split', order=label_order, ax=ax, hue_order=SPLITS)
# ax.set_xlim(0, 5000)
# plt.show()

In [None]:
# Column-normalize: percentage of each split's rows that belong to each label.
labels_dist_norm = labels_dist / labels_dist.sum(axis=0) * 100
with disp_context:
    display(labels_dist_norm)

# Reshape to long format (one row per label/split pair) for seaborn. Note that
# this rebinds labels_dist_norm, replacing the wide table displayed above.
labels_dist_norm = labels_dist_norm.stack('split').rename('% of split').reset_index()
fig, ax = plt.subplots(1, 1, figsize=(10, num_labels/2), tight_layout=True)
ax.set_title('How much each class contributes to each split')
sns.barplot(data=labels_dist_norm, y='label', x='% of split', hue='split', ax=ax)

## View distribution of labels by split and dataset

In [None]:
# which datasets are represented in each split?
# Table of label x split; each cell holds the array of distinct 'dataset'
# values among the rows with that label in that split.
with disp_context:
    display(df.groupby(['label', 'split'])['dataset'].unique().unstack('split')[SPLITS])

In [None]:
# Row counts per (label, split, dataset), as a Series named 'count'.
labels_by_split_ds = df.groupby(['label', 'split', 'dataset']).size().rename('count')
# Display with one column per split (in SPLITS order); combinations that do not
# occur are shown as 0.
with disp_context:
    display(labels_by_split_ds.unstack('split')[SPLITS].fillna(0).astype(int))

In [None]:
# One bar-chart facet per dataset, stacked vertically (col_wrap=1); each facet
# gets its own x-axis scale (sharex=False). Bars show the per-label counts,
# colored by split.
sns.catplot(data=labels_by_split_ds.reset_index(),
            x='count', y='label', hue='split', col='dataset',
            col_wrap=1, kind='bar', sharex=False)

## View distribution of labels by split, dataset, and location

For each label, dataset, and split:
* plot the number of crops per location as a strip plot (one point per location).

In [None]:
# Row counts per (label, dataset, location, split), as a flat DataFrame with a
# 'count' column: one row per location in which a label occurs.
labels_by_split_ds_loc = df.groupby(['label', 'dataset', 'location', 'split']).size().rename('count').reset_index()
with disp_context:
    display(labels_by_split_ds_loc.head())
# Convert 'split' to a categorical column before using it as the plot hue.
labels_by_split_ds_loc['split'] = labels_by_split_ds_loc['split'].astype('category')
# One facet per label (5 facets per row), with independent axes per facet.
# Within a facet, each point is one row of the table above, placed by dataset
# (y) and count (x); dodge=True separates the splits.
sns.catplot(data=labels_by_split_ds_loc,
            col='label', y='dataset', x='count', hue='split',
            kind='strip', dodge=True,
            col_wrap=5, sharex=False, sharey=False)