In [2]:
from math import sqrt
import matplotlib
SPINE_COLOR = 'gray'

def latexify(fig_width=None, fig_height=None, columns=1):
    """Set up matplotlib's RC params for LaTeX plotting.

    Call this before plotting a figure.

    Parameters
    ----------
    fig_width : float, optional
        Figure width in inches. Defaults to 3.39 when ``columns == 1`` and
        6.9 when ``columns == 2``.
    fig_height : float, optional
        Figure height in inches. Defaults to ``fig_width`` scaled by the
        golden mean ``(sqrt(5) - 1) / 2``. Values above 8.0 are reduced to
        8.0 and a warning is printed.
    columns : {1, 2}
        Number of text columns the figure spans; only used to pick the
        default width.

    Raises
    ------
    AssertionError
        If ``columns`` is neither 1 nor 2.
    """

    # code adapted from http://www.scipy.org/Cookbook/Matplotlib/LaTeX_Examples

    # Width and max height in inches for IEEE journals taken from
    # computer.org/cms/Computer.org/Journal%20templates/transactions_art_guide.pdf

    assert columns in (1, 2), "columns must be 1 or 2"

    if fig_width is None:
        fig_width = 3.39 if columns == 1 else 6.9  # width in inches

    if fig_height is None:
        golden_mean = (sqrt(5) - 1.0) / 2.0  # Aesthetic ratio
        fig_height = fig_width * golden_mean  # height in inches

    MAX_HEIGHT_INCHES = 8.0
    if fig_height > MAX_HEIGHT_INCHES:
        # Format the numbers into the message; concatenating a float to a str
        # would raise TypeError before the height could be clamped.
        print(f"WARNING: fig_height too large: {fig_height} "
              f"so will reduce to {MAX_HEIGHT_INCHES} inches.")
        fig_height = MAX_HEIGHT_INCHES

    params = {'backend': 'ps',
              # A single string: recent matplotlib releases no longer accept
              # a list of strings for the LaTeX preamble.
              'text.latex.preamble': '\\usepackage{gensymb}',
              'axes.labelsize': 8,  # fontsize for x and y labels (was 10)
              'axes.titlesize': 8,
              'lines.linewidth': 0.5,
              'axes.linewidth': 0.0,
              'legend.fontsize': 8,  # was 10
              'xtick.labelsize': 8,
              'ytick.labelsize': 8,
              'lines.markersize': 2,
              'text.usetex': True,
              'figure.figsize': [fig_width, fig_height],
              'font.family': 'serif'
    }

    matplotlib.rcParams.update(params)


def format_axes(ax):
    """Restyle an axes object and hand it back.

    The top and right spines are hidden; the left and bottom spines are
    recoloured with ``SPINE_COLOR`` and given a line width of 0.5. Ticks are
    placed on the bottom and left sides only, pointing outwards and drawn in
    ``SPINE_COLOR``.

    Parameters
    ----------
    ax : object exposing ``spines``, ``xaxis`` and ``yaxis``
        The axes to restyle; it is modified in place.

    Returns
    -------
    The same ``ax`` that was passed in.
    """
    hidden_sides = ('top', 'right')
    styled_sides = ('left', 'bottom')

    for side in hidden_sides:
        ax.spines[side].set_visible(False)

    for side in styled_sides:
        edge = ax.spines[side]
        edge.set_color(SPINE_COLOR)
        edge.set_linewidth(0.5)

    # Ticks only on the two spines that remain visible.
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    for tick_axis in (ax.xaxis, ax.yaxis):
        tick_axis.set_tick_params(direction='out', color=SPINE_COLOR)

    return ax

# latexify()


In [3]:
import matplotlib.pyplot as plt
from gorillatracker.classification.clustering import EXT_MERGED_DF

df = EXT_MERGED_DF
# Model and datasets whose label distributions are inspected
m = "ViT-Finetuned"
datasets = ["Bristol", "SPAC+min3", "SPAC+min3+max10"]

for d in datasets:
    # Rows belonging to this model / dataset combination
    is_selected = (df["model"] == m) & (df["dataset"] == d)
    frame = df[is_selected]

    # How many distinct labels the selection contains
    num_labels = frame["label"].nunique()
    print(f"Dataset: {d}")
    print(f"Number of unique labels: {num_labels}")

    # Images per label, plus the largest and smallest count
    label_counts = frame["label"].value_counts()
    max_images = label_counts.max()
    min_images = label_counts.min()
    print(f"Max number of images per label: {max_images}")
    print(f"Min number of images per label: {min_images}")

    # Bar chart with one bar per label, showing its image count
    fig, ax = plt.subplots(figsize=(8, 6))
    label_counts.plot(kind="bar", ax=ax)
    ax.set_title(f"Number of Images per Label for Dataset: {d}")
    ax.set_xlabel("Labels")
    ax.set_ylabel("Number of Images")
    plt.xticks(rotation=90)
    plt.show()

Dataset: Bristol
Number of unique labels: 7
Max number of images per label: 90
Min number of images per label: 25
Dataset: SPAC+min3
Number of unique labels: 110
Max number of images per label: 67
Min number of images per label: 3
Dataset: SPAC+min3+max10
Number of unique labels: 110
Max number of images per label: 10
Min number of images per label: 3


In [4]:
import pandas as pd
import matplotlib.pyplot as plt

# Define the model and datasets
m = "ViT-Finetuned"
# Names must match the values of the "dataset" column exactly; the last one
# was previously written without its second "+", so the filter matched no rows.
datasets = ["Bristol", "SPAC+min3", "SPAC+min3+max10"]

# Images-per-label counts, computed once per dataset and reused by the table,
# the bar charts and the box plot below.
label_counts_by_dataset = {}
for d in datasets:
    frame = df[(df["model"] == m) & (df["dataset"] == d)]
    # Fail loudly on a misspelled name instead of producing NaN statistics
    # and empty plots.
    assert not frame.empty, f"No rows for model {m!r} and dataset {d!r}"
    label_counts_by_dataset[d] = frame["label"].value_counts()

# Summary statistics, one record per dataset
summary_data = []
for d, label_counts in label_counts_by_dataset.items():
    summary_data.append({
        "Dataset": d,
        # value_counts() has one entry per distinct label
        "Num Labels": len(label_counts),
        "Max Images per Label": label_counts.max(),
        "Min Images per Label": label_counts.min()
    })

# Create a summary table using Matplotlib
fig, ax = plt.subplots(figsize=(8, 4))
ax.axis('tight')
ax.axis('off')

# Header row followed by one row per dataset
summary_columns = ["Dataset", "Num Labels", "Max Images per Label", "Min Images per Label"]
summary_table_data = [summary_columns]
summary_table_data += [[row[col] for col in summary_columns] for row in summary_data]

# Add table to the plot
ax.table(cellText=summary_table_data, loc='center')

ax.set_title("Dataset Label Statistics")
plt.show()

# Side-by-side bar charts of images per label, sharing the y axis so the
# datasets can be compared directly
fig, axs = plt.subplots(1, 3, figsize=(18, 6), sharey=True)

for i, d in enumerate(datasets):
    label_counts = label_counts_by_dataset[d]
    axs[i].bar(label_counts.index, label_counts.values)
    axs[i].set_title(f"Dataset: {d}")
    axs[i].set_xlabel("Labels")
    axs[i].set_ylabel("Number of Images")
    axs[i].tick_params(axis='x', rotation=90)

plt.tight_layout()
plt.show()

# Prepare the data for the box plot (distribution comparison)
label_counts_per_dataset = [label_counts_by_dataset[d].values for d in datasets]

# Create a box plot. Tick labels are set on the axis rather than through a
# boxplot() keyword, because the name of that keyword differs between
# matplotlib versions.
box_fig, box_ax = plt.subplots(figsize=(8, 6))
box_ax.boxplot(label_counts_per_dataset)
box_ax.set_xticklabels(datasets, rotation=45)
box_ax.set_title("Distribution of Images per Label by Dataset")
box_ax.set_ylabel("Number of Images per Label")
box_ax.set_xlabel("Dataset")
plt.show()


  plt.boxplot(label_counts_per_dataset, labels=datasets)
