In [None]:
import numpy as np
import pandas as pd

# Imports
import utils
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import seaborn as sns

from pathlib import Path
from glob import glob

# Directory containing this notebook: the notebook file name is resolved against the
# kernel's working directory (symlinks resolved), and its parent directory is taken.
notebook_path = Path(os.path.realpath("99_descriptive_dataset_metrics.ipynb")).parent

# Data download / load from disk
# W&B settings are read from the environment (WANDB_ENTITY / WANDB_PROJECT) so they do not
# have to be typed into, and committed with, the notebook. "..." is only a placeholder.
wandb_entity = os.environ.get("WANDB_ENTITY", "...")                                                             # <-- Or add your credentials for W&B here.
wandb_project = os.environ.get("WANDB_PROJECT", "...")                                                           # <-- Or add your credentials for W&B here.
wandb_run_filter_keywords = ["baseline"]


In [None]:
# Seaborn configuration
# All options go into ONE set_theme call: every set_theme / sns.set call falls back to the
# seaborn default for any context, style, palette, font or font_scale it is not given.
# The palette is applied here via `palette=`; sns.color_palette() alone only returns a palette.
# Alternative font: Linux Libertine
sns.set_theme(
	context="paper",
	style="whitegrid",
	palette="colorblind",
	font="Times New Roman",
	font_scale=1.8,
	rc={"figure.figsize": (6, 4)},
)

In [None]:
dataset_paths = {
	"blond": f"{notebook_path.parent}/pipelines/blond/data/processed/dirichlet",
	"mnist": f"{notebook_path.parent}/pipelines/mnist/data/processed/dirichlet",
	"shakespeare": f"{notebook_path.parent}/pipelines/shakespeare/data/processed/dirichlet"
}

# Build one row per (dataset, client, fold) holding how many samples that client has in
# that split. Each CSV file in a dataset directory is one client; its "fold" column marks
# the split of every sample.
FOLDS = ["train", "val", "test"]
split_rows = []
for dataset_name, data_path in dataset_paths.items():
	# os.listdir returns entries in arbitrary order; sort for a reproducible row order.
	for file in sorted(os.listdir(data_path)):
		if not file.endswith(".csv"):
			continue
		client_file = pd.read_csv(os.path.join(data_path, file))
		# Client id = token after the last "_" in the file stem
		# (presumably files are named "<prefix>_<id>.csv" — confirm).
		client_id = file.split(".")[0].split("_")[-1]
		# Count all folds in one pass instead of filtering the frame once per fold.
		fold_counts = client_file["fold"].value_counts()
		for fold in FOLDS:
			split_rows.append({
				"dataset": dataset_name,
				"client_id": client_id,
				"fold": fold,
				"records": int(fold_counts.get(fold, 0)),
			})

# Create the frame once: growing it with pd.concat inside the loop is quadratic.
# Explicit columns keep the schema even when no CSV files are found.
df = pd.DataFrame(split_rows, columns=["dataset", "client_id", "fold", "records"])

In [None]:
# One violin plot per dataset showing how many samples each client holds per split,
# plus the total number of training samples printed to the output.
split_order = ["train", "val", "test"]
split_labels = ["Train", "Validation", "Test"]
for dataset in dataset_paths:
	subset = df.loc[df["dataset"] == dataset]
	if subset.empty:
		print(f"{dataset}: no client CSV files found, skipping plot")
		continue
	print(f"{dataset}: {subset.loc[subset['fold'] == 'train', 'records'].sum()}")

	# A fresh figure per dataset so plots never draw on top of each other when
	# plt.show() does not close the figure (e.g. non-inline backends).
	fig, ax = plt.subplots()
	# Fix the category order explicitly so the tick labels below are guaranteed to match.
	sns.violinplot(data=subset, x="fold", y="records", order=split_order, ax=ax)
	ax.set(ylabel="Samples", xlabel="Split")
	ax.set_xticklabels(split_labels)

	# For large counts, show y ticks in units of 1,000 or 100 and say so in the axis label.
	max_records = subset["records"].max()
	if max_records > 1000:
		divisor = 1000
	elif max_records > 100:
		divisor = 100
	else:
		divisor = None
	if divisor is not None:
		ax.set(ylabel=f"Samples (in {divisor:,})")
		# Bind the divisor as a default argument so the formatter does not depend on
		# the loop variable's value at draw time.
		tick_scaler = ticker.FuncFormatter(lambda x, pos, d=divisor: "{0:g}".format(x / d))
		ax.yaxis.set_major_formatter(tick_scaler)

	utils.write_figure_to_disk(plt, file_name=f"{dataset}_data_distribution", chapter_name="datasets")
	plt.show()
	# Release the figure; only a handful are created, but this keeps the loop leak-free.
	plt.close(fig)