# Plot heatmaps
We plot various heatmaps of trial and trial site numbers per phase for different stratifications.

In [None]:
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(color_codes=True)
sns.set_style("whitegrid", {"grid.color": "gainsboro"})

from tools import visualization

In [None]:
# To suppress division by zero warning when using a log scale
# NOTE(review): the filter below has no category or message restriction, so it
# silences every warning raised afterwards, not only the division-by-zero one.
import warnings
warnings.filterwarnings('ignore')

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%load_ext blackcellmagic

## Load data

In [None]:
# Names of the grouping columns the counts are stratified by. Each name is
# appended to an "ovr_" / "phs_" prefix to select a sheet of the results workbook
# and is also used as the dictionary key under which that sheet is stored.
grouper_column_names = [
    "country_ISO",
    "country_continent",
    "subregion",
    "economy_level",
    "consolidated_economy_level",
    "income_group",
    "consolidated_income_group",
    "hdi_category",
]

In [None]:
# Read every "ovr_*" sheet in a single call so the workbook is opened and parsed
# once instead of once per grouping column. With a list of sheet names,
# pd.read_excel returns a dict keyed by sheet name.
overall_sheets = pd.read_excel(
    "data/results/trials_sites_counts.xlsx",
    sheet_name=["ovr_" + column_name for column_name in grouper_column_names],
)

# Re-key the sheets by grouping column name (i.e. drop the "ovr_" prefix).
consolidated_counts_overall = {
    column_name: overall_sheets["ovr_" + column_name]
    for column_name in grouper_column_names
}

In [None]:
# Display the overall counts loaded for the continent grouping
consolidated_counts_overall["country_continent"]

In [None]:
# Read every "phs_*" sheet in a single call so the workbook is opened and parsed
# once instead of once per grouping column. With a list of sheet names,
# pd.read_excel returns a dict keyed by sheet name.
per_phase_sheets = pd.read_excel(
    "data/results/trials_sites_counts.xlsx",
    sheet_name=["phs_" + column_name for column_name in grouper_column_names],
)

# Re-key the sheets by grouping column name (i.e. drop the "phs_" prefix).
consolidated_counts_per_phase = {
    column_name: per_phase_sheets["phs_" + column_name]
    for column_name in grouper_column_names
}

In [None]:
# Preview the first rows of the per-phase counts for the continent grouping
consolidated_counts_per_phase["country_continent"].head()

## Heatmaps for trials

### Trial counts

In [None]:
# List the grouping keys for which per-phase counts were loaded
list(consolidated_counts_per_phase)

#### Continent level

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: raw `n_trials` values; right panel: `of_total_trials_phase`
# rendered as percentages. Everything else is shared between the two panels.
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["country_continent"],
        index_column_name="country_continent",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for (value_column, number_format), panel_ax in zip(
        [("n_trials", "g"), ("of_total_trials_phase", ".2%")], axes
    )
)

g_left.set_title("A) Number of trials with at least one site in a given continent")
g_right.set_title("B) Percentage of trials with at least one site in a given continent")

fig.tight_layout()
plt.show()

#### Subregion level

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: raw `n_trials` values; right panel: `of_total_trials_phase`
# rendered as percentages. Everything else is shared between the two panels.
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["subregion"],
        index_column_name="subregion",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for (value_column, number_format), panel_ax in zip(
        [("n_trials", "g"), ("of_total_trials_phase", ".2%")], axes
    )
)

g_left.set_title("A) Number of trials with at least one site in a given region")
g_right.set_title("B) Percentage of trials with at least one site in a given region")

fig.tight_layout()
plt.show()

#### Economy level

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: raw `n_trials` values; right panel: `of_total_trials_phase`
# rendered as percentages. Everything else is shared between the two panels.
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["consolidated_economy_level"],
        index_column_name="consolidated_economy_level",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for (value_column, number_format), panel_ax in zip(
        [("n_trials", "g"), ("of_total_trials_phase", ".2%")], axes
    )
)

g_left.set_title("A) Number of trials with at least one site in a given development region")
g_right.set_title("B) Percentage of trials with at least one site in a given development region")

fig.tight_layout()
plt.show()

#### Income level

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: raw `n_trials` values; right panel: `of_total_trials_phase`
# rendered as percentages. Everything else is shared between the two panels.
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["consolidated_income_group"],
        index_column_name="consolidated_income_group",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for (value_column, number_format), panel_ax in zip(
        [("n_trials", "g"), ("of_total_trials_phase", ".2%")], axes
    )
)

g_left.set_title("A) Number of trials with at least one site in a given income class")
g_right.set_title("B) Percentage of trials with at least one site in a given income class")

fig.tight_layout()
plt.show()

#### Human Development Index

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: raw `n_trials` values; right panel: `of_total_trials_phase`
# rendered as percentages. Everything else is shared between the two panels.
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["hdi_category"],
        index_column_name="hdi_category",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for (value_column, number_format), panel_ax in zip(
        [("n_trials", "g"), ("of_total_trials_phase", ".2%")], axes
    )
)

g_left.set_title("A) Number of trials with at least one site in a given HDI class")
g_right.set_title("B) Percentage of trials with at least one site in a given HDI class")

fig.tight_layout()
plt.show()

### Trials, expected trials, and deviation

#### Continent level

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One heatmap per panel, in row-major order (top-left, top-right, bottom-left,
# bottom-right). Each entry gives the value column, the annotation format and
# any colour-scale option specific to that panel.
g_top_left, g_top_right, g_bottom_left, g_bottom_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["country_continent"],
        index_column_name="country_continent",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
        **scale_options,
    )
    for (value_column, number_format, scale_options), panel_ax in zip(
        [
            ("n_trials", "g", {}),
            ("n_trials_expected", ".2f", {}),
            ("factor_deviation_n_trials_from_expected", ".2f", {"linear_diverging": True}),
            ("factor_deviation_n_trials_from_expected", ".2f", {"log_scaled": True}),
        ],
        axes.flat,
    )
)

g_top_left.set_title("A) Number of trials with at least one site in a given continent")
g_top_right.set_title("B) Expected number of trials with at least one site in a given continent")
g_bottom_left.set_title("C) Disproportionality of actual number of trials vs. expected")
g_bottom_right.set_title("D) Log-disproportionality of actual number of trials vs. expected")

fig.tight_layout()
plt.show()

#### Subregion level

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One heatmap per panel, in row-major order (top-left, top-right, bottom-left,
# bottom-right). Each entry gives the value column, the annotation format and
# any colour-scale option specific to that panel.
g_top_left, g_top_right, g_bottom_left, g_bottom_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["subregion"],
        index_column_name="subregion",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
        **scale_options,
    )
    for (value_column, number_format, scale_options), panel_ax in zip(
        [
            ("n_trials", "g", {}),
            ("n_trials_expected", ".2f", {}),
            ("factor_deviation_n_trials_from_expected", ".2f", {"linear_diverging": True}),
            ("factor_deviation_n_trials_from_expected", ".2f", {"log_scaled": True}),
        ],
        axes.flat,
    )
)

g_top_left.set_title("A) Number of trials with at least one site in a given region")
g_top_right.set_title("B) Expected number of trials with at least one site in a given region")
g_bottom_left.set_title("C) Disproportionality of actual number of trials vs. expected")
g_bottom_right.set_title("D) Log-disproportionality of actual number of trials vs. expected")

fig.tight_layout()
plt.show()

#### Economy level

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One heatmap per panel, in row-major order (top-left, top-right, bottom-left,
# bottom-right). Each entry gives the value column, the annotation format and
# any colour-scale option specific to that panel.
g_top_left, g_top_right, g_bottom_left, g_bottom_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["consolidated_economy_level"],
        index_column_name="consolidated_economy_level",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
        **scale_options,
    )
    for (value_column, number_format, scale_options), panel_ax in zip(
        [
            ("n_trials", "g", {}),
            ("n_trials_expected", ".2f", {}),
            ("factor_deviation_n_trials_from_expected", ".2f", {"linear_diverging": True}),
            ("factor_deviation_n_trials_from_expected", ".2f", {"log_scaled": True}),
        ],
        axes.flat,
    )
)

# Panel A title previously read "in given development region"; the missing
# article is restored so it matches panel B and the other stratifications.
g_top_left.set_title("A) Number of trials with at least one site in a given development region")
g_top_right.set_title("B) Expected number of trials with at least one site in a given development region")
g_bottom_left.set_title("C) Disproportionality of actual number of trials vs. expected")
g_bottom_right.set_title("D) Log-disproportionality of actual number of trials vs. expected")

fig.tight_layout()
plt.show()

#### Income level

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One heatmap per panel, in row-major order (top-left, top-right, bottom-left,
# bottom-right). Each entry gives the value column, the annotation format and
# any colour-scale option specific to that panel.
g_top_left, g_top_right, g_bottom_left, g_bottom_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["consolidated_income_group"],
        index_column_name="consolidated_income_group",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
        **scale_options,
    )
    for (value_column, number_format, scale_options), panel_ax in zip(
        [
            ("n_trials", "g", {}),
            ("n_trials_expected", ".2f", {}),
            ("factor_deviation_n_trials_from_expected", ".2f", {"linear_diverging": True}),
            ("factor_deviation_n_trials_from_expected", ".2f", {"log_scaled": True}),
        ],
        axes.flat,
    )
)

g_top_left.set_title("A) Number of trials with at least one site in a given income group")
g_top_right.set_title("B) Expected number of trials with at least one site in a given income group")
g_bottom_left.set_title("C) Disproportionality of actual number of trials vs. expected")
g_bottom_right.set_title("D) Log-disproportionality of actual number of trials vs. expected")

fig.tight_layout()
plt.show()

#### Human Development Index

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One heatmap per panel, in row-major order (top-left, top-right, bottom-left,
# bottom-right). Each entry gives the value column, the annotation format and
# any colour-scale option specific to that panel.
g_top_left, g_top_right, g_bottom_left, g_bottom_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["hdi_category"],
        index_column_name="hdi_category",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
        **scale_options,
    )
    for (value_column, number_format, scale_options), panel_ax in zip(
        [
            ("n_trials", "g", {}),
            ("n_trials_expected", ".2f", {}),
            ("factor_deviation_n_trials_from_expected", ".2f", {"linear_diverging": True}),
            ("factor_deviation_n_trials_from_expected", ".2f", {"log_scaled": True}),
        ],
        axes.flat,
    )
)

g_top_left.set_title("A) Number of trials with at least one site in a given HDI category")
g_top_right.set_title("B) Expected number of trials with at least one site in a given HDI category")
g_bottom_left.set_title("C) Disproportionality of actual number of trials vs. expected")
g_bottom_right.set_title("D) Log-disproportionality of actual number of trials vs. expected")

fig.tight_layout()
plt.show()

### Trials per capita

#### Continent level

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Same `trials_per_capita` column in both panels; only the scale factor differs
# (10^6 on the left, 10^9 on the right).
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["country_continent"],
        index_column_name="country_continent",
        value_column_name="trials_per_capita",
        scale_factor=population_unit,
        annotation_format=".2f",
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for population_unit, panel_ax in zip((1_000_000, 1_000_000_000), axes)
)

g_left.set_title("A) Number of trials per million ($10^{6}$) with at least one site in a given continent")
g_right.set_title("B) Number of trials per billion ($10^{9}$) with at least one site in a given continent")

fig.tight_layout()
plt.show()

#### Subregion level

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Same `trials_per_capita` column in both panels; only the scale factor differs
# (10^6 on the left, 10^9 on the right).
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["subregion"],
        index_column_name="subregion",
        value_column_name="trials_per_capita",
        scale_factor=population_unit,
        annotation_format=".2f",
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for population_unit, panel_ax in zip((1_000_000, 1_000_000_000), axes)
)

g_left.set_title("A) Number of trials per million ($10^{6}$) with at least one site in a given region")
g_right.set_title("B) Number of trials per billion ($10^{9}$) with at least one site in a given region")

fig.tight_layout()
plt.show()

#### Economy level

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Same `trials_per_capita` column in both panels; only the scale factor differs
# (10^6 on the left, 10^9 on the right).
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["consolidated_economy_level"],
        index_column_name="consolidated_economy_level",
        value_column_name="trials_per_capita",
        scale_factor=population_unit,
        annotation_format=".2f",
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for population_unit, panel_ax in zip((1_000_000, 1_000_000_000), axes)
)

g_left.set_title("A) Number of trials per million ($10^{6}$) with at least one site in a given development region")
g_right.set_title("B) Number of trials per billion ($10^{9}$) with at least one site in a given development region")

fig.tight_layout()
plt.show()

#### Income group

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Same `trials_per_capita` column in both panels; only the scale factor differs
# (10^6 on the left, 10^9 on the right).
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["consolidated_income_group"],
        index_column_name="consolidated_income_group",
        value_column_name="trials_per_capita",
        scale_factor=population_unit,
        annotation_format=".2f",
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for population_unit, panel_ax in zip((1_000_000, 1_000_000_000), axes)
)

g_left.set_title("A) Number of trials per million ($10^{6}$) with at least one site in a given income group")
g_right.set_title("B) Number of trials per billion ($10^{9}$) with at least one site in a given income group")

fig.tight_layout()
plt.show()

#### HDI category

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Same `trials_per_capita` column in both panels; only the scale factor differs
# (10^6 on the left, 10^9 on the right).
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["hdi_category"],
        index_column_name="hdi_category",
        value_column_name="trials_per_capita",
        scale_factor=population_unit,
        annotation_format=".2f",
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for population_unit, panel_ax in zip((1_000_000, 1_000_000_000), axes)
)

g_left.set_title("A) Number of trials per million ($10^{6}$) with at least one site in a given HDI category")
g_right.set_title("B) Number of trials per billion ($10^{9}$) with at least one site in a given HDI category")

fig.tight_layout()
plt.show()

## Heatmaps for trial sites

### Site counts

In [None]:
# List the columns available in the per-phase counts of the country-level table
list(consolidated_counts_per_phase["country_ISO"].columns)

#### Continent level

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: raw `n_sites` values; right panel: `of_total_sites_phase`
# rendered as percentages. Everything else is shared between the two panels.
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["country_continent"],
        index_column_name="country_continent",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for (value_column, number_format), panel_ax in zip(
        [("n_sites", "g"), ("of_total_sites_phase", ".2%")], axes
    )
)

g_left.set_title("A) Number of trial sites per continent")
g_right.set_title("B) Distribution of trial sites")

fig.tight_layout()
plt.show()

#### Subregion level

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: raw `n_sites` values; right panel: `of_total_sites_phase`
# rendered as percentages. Everything else is shared between the two panels.
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["subregion"],
        index_column_name="subregion",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for (value_column, number_format), panel_ax in zip(
        [("n_sites", "g"), ("of_total_sites_phase", ".2%")], axes
    )
)

g_left.set_title("A) Number of trial sites per region")
g_right.set_title("B) Distribution of trial sites")

fig.tight_layout()
plt.show()

#### Economy level

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: raw `n_sites` values; right panel: `of_total_sites_phase`
# rendered as percentages. Everything else is shared between the two panels.
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["consolidated_economy_level"],
        index_column_name="consolidated_economy_level",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for (value_column, number_format), panel_ax in zip(
        [("n_sites", "g"), ("of_total_sites_phase", ".2%")], axes
    )
)

g_left.set_title("A) Number of trial sites per development region")
g_right.set_title("B) Distribution of trial sites")

fig.tight_layout()
plt.show()

#### Income level

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: raw `n_sites` values; right panel: `of_total_sites_phase`
# rendered as percentages. Everything else is shared between the two panels.
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["consolidated_income_group"],
        index_column_name="consolidated_income_group",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for (value_column, number_format), panel_ax in zip(
        [("n_sites", "g"), ("of_total_sites_phase", ".2%")], axes
    )
)

g_left.set_title("A) Number of trial sites per income class")
g_right.set_title("B) Distribution of trial sites")

fig.tight_layout()
plt.show()

#### Human Development Index

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left panel: raw `n_sites` values; right panel: `of_total_sites_phase`
# rendered as percentages. Everything else is shared between the two panels.
g_left, g_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["hdi_category"],
        index_column_name="hdi_category",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
    )
    for (value_column, number_format), panel_ax in zip(
        [("n_sites", "g"), ("of_total_sites_phase", ".2%")], axes
    )
)

g_left.set_title("A) Number of trial sites per HDI category")
g_right.set_title("B) Distribution of trial sites")

fig.tight_layout()
plt.show()

### Trial sites, expected trial sites, and deviation

#### Continent level

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One heatmap per panel, in row-major order (top-left, top-right, bottom-left,
# bottom-right). Each entry gives the value column, the annotation format and
# any colour-scale option specific to that panel.
g_top_left, g_top_right, g_bottom_left, g_bottom_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["country_continent"],
        index_column_name="country_continent",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
        **scale_options,
    )
    for (value_column, number_format, scale_options), panel_ax in zip(
        [
            ("n_sites", "g", {}),
            ("n_sites_expected", ".2f", {}),
            ("factor_deviation_n_sites_from_expected", ".2f", {"linear_diverging": True}),
            ("factor_deviation_n_sites_from_expected", ".2f", {"log_scaled": True}),
        ],
        axes.flat,
    )
)

g_top_left.set_title("A) Number of trial sites per continent")
g_top_right.set_title("B) Expected number of trial sites per continent")
g_bottom_left.set_title("C) Disproportionality of actual number of trial sites vs. expected")
g_bottom_right.set_title("D) Log-disproportionality of actual number of trial sites vs. expected")

fig.tight_layout()
plt.show()

#### Subregion level

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One heatmap per panel, in row-major order (top-left, top-right, bottom-left,
# bottom-right). Each entry gives the value column, the annotation format and
# any colour-scale option specific to that panel.
g_top_left, g_top_right, g_bottom_left, g_bottom_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["subregion"],
        index_column_name="subregion",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
        **scale_options,
    )
    for (value_column, number_format, scale_options), panel_ax in zip(
        [
            ("n_sites", "g", {}),
            ("n_sites_expected", ".2f", {}),
            ("factor_deviation_n_sites_from_expected", ".2f", {"linear_diverging": True}),
            ("factor_deviation_n_sites_from_expected", ".2f", {"log_scaled": True}),
        ],
        axes.flat,
    )
)

g_top_left.set_title("A) Number of trial sites per region")
g_top_right.set_title("B) Expected number of trial sites per region")
g_bottom_left.set_title("C) Disproportionality of actual number of trial sites vs. expected")
g_bottom_right.set_title("D) Log-disproportionality of actual number of trial sites vs. expected")

fig.tight_layout()
plt.show()

#### Economy level

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One heatmap per panel, in row-major order (top-left, top-right, bottom-left,
# bottom-right). Each entry gives the value column, the annotation format and
# any colour-scale option specific to that panel.
g_top_left, g_top_right, g_bottom_left, g_bottom_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["consolidated_economy_level"],
        index_column_name="consolidated_economy_level",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
        **scale_options,
    )
    for (value_column, number_format, scale_options), panel_ax in zip(
        [
            ("n_sites", "g", {}),
            ("n_sites_expected", ".2f", {}),
            ("factor_deviation_n_sites_from_expected", ".2f", {"linear_diverging": True}),
            ("factor_deviation_n_sites_from_expected", ".2f", {"log_scaled": True}),
        ],
        axes.flat,
    )
)

g_top_left.set_title("A) Number of trial sites per development region")
g_top_right.set_title("B) Expected number of trial sites per development region")
g_bottom_left.set_title("C) Disproportionality of actual number of trial sites vs. expected")
g_bottom_right.set_title("D) Log-disproportionality of actual number of trial sites vs. expected")

fig.tight_layout()
plt.show()

#### Income level

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One heatmap per panel, in row-major order (top-left, top-right, bottom-left,
# bottom-right). Each entry gives the value column, the annotation format and
# any colour-scale option specific to that panel.
g_top_left, g_top_right, g_bottom_left, g_bottom_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["consolidated_income_group"],
        index_column_name="consolidated_income_group",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
        **scale_options,
    )
    for (value_column, number_format, scale_options), panel_ax in zip(
        [
            ("n_sites", "g", {}),
            ("n_sites_expected", ".2f", {}),
            ("factor_deviation_n_sites_from_expected", ".2f", {"linear_diverging": True}),
            ("factor_deviation_n_sites_from_expected", ".2f", {"log_scaled": True}),
        ],
        axes.flat,
    )
)

g_top_left.set_title("A) Number of trial sites per income group")
g_top_right.set_title("B) Expected number of trial sites per income group")
g_bottom_left.set_title("C) Disproportionality of actual number of trial sites vs. expected")
g_bottom_right.set_title("D) Log-disproportionality of actual number of trial sites vs. expected")

fig.tight_layout()
plt.show()

#### Human Development Index

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One heatmap per panel, in row-major order (top-left, top-right, bottom-left,
# bottom-right). Each entry gives the value column, the annotation format and
# any colour-scale option specific to that panel.
g_top_left, g_top_right, g_bottom_left, g_bottom_right = (
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase["hdi_category"],
        index_column_name="hdi_category",
        value_column_name=value_column,
        annotation_format=number_format,
        phase_column_name="phase",
        cbar=False,
        ax=panel_ax,
        **scale_options,
    )
    for (value_column, number_format, scale_options), panel_ax in zip(
        [
            ("n_sites", "g", {}),
            ("n_sites_expected", ".2f", {}),
            ("factor_deviation_n_sites_from_expected", ".2f", {"linear_diverging": True}),
            ("factor_deviation_n_sites_from_expected", ".2f", {"log_scaled": True}),
        ],
        axes.flat,
    )
)

g_top_left.set_title("A) Number of trial sites per HDI category")
g_top_right.set_title("B) Expected number of trial sites per HDI category")
g_bottom_left.set_title("C) Disproportionality of actual number of trial sites vs. expected")
g_bottom_right.set_title("D) Log-disproportionality of actual number of trial sites vs. expected")

fig.tight_layout()
plt.show()

### Trial sites per capita

#### Continent level

In [None]:
# Trial sites per capita by continent and phase, drawn side by side with two
# scale factors: 10^6 on the left panel and 10^9 on the right panel.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Everything except the scale factor and target axis is shared by both panels.
per_capita_kwargs = dict(
    trial_data=consolidated_counts_per_phase["country_continent"],
    index_column_name="country_continent",
    value_column_name="sites_per_capita",
    annotation_format=".2f",
    phase_column_name="phase",
    cbar=False,
)
g_left, g_right = [
    visualization.plot_heatmap_per_phase(
        scale_factor=factor, ax=panel_ax, **per_capita_kwargs
    )
    for panel_ax, factor in zip(axes, (1_000_000, 1_000_000_000))
]

g_left.set_title("A) Number of trial sites per million (" + r"$10^{6}$" + ")")
g_right.set_title("B) Number of trial sites per billion (" + r"$10^{9}$" + ")")

fig.tight_layout()
plt.show()

#### Subregion level

In [None]:
# Trial sites per capita by subregion and phase, drawn side by side with two
# scale factors: 10^6 on the left panel and 10^9 on the right panel.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Everything except the scale factor and target axis is shared by both panels.
per_capita_kwargs = dict(
    trial_data=consolidated_counts_per_phase["subregion"],
    index_column_name="subregion",
    value_column_name="sites_per_capita",
    annotation_format=".2f",
    phase_column_name="phase",
    cbar=False,
)
g_left, g_right = [
    visualization.plot_heatmap_per_phase(
        scale_factor=factor, ax=panel_ax, **per_capita_kwargs
    )
    for panel_ax, factor in zip(axes, (1_000_000, 1_000_000_000))
]

g_left.set_title("A) Number of trial sites per million (" + r"$10^{6}$" + ")")
g_right.set_title("B) Number of trial sites per billion (" + r"$10^{9}$" + ")")

fig.tight_layout()
plt.show()

#### Economy level

In [None]:
# Trial sites per capita by consolidated economy level and phase, drawn side
# by side with two scale factors: 10^6 on the left and 10^9 on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Everything except the scale factor and target axis is shared by both panels.
per_capita_kwargs = dict(
    trial_data=consolidated_counts_per_phase["consolidated_economy_level"],
    index_column_name="consolidated_economy_level",
    value_column_name="sites_per_capita",
    annotation_format=".2f",
    phase_column_name="phase",
    cbar=False,
)
g_left, g_right = [
    visualization.plot_heatmap_per_phase(
        scale_factor=factor, ax=panel_ax, **per_capita_kwargs
    )
    for panel_ax, factor in zip(axes, (1_000_000, 1_000_000_000))
]

g_left.set_title("A) Number of trial sites per million (" + r"$10^{6}$" + ")")
g_right.set_title("B) Number of trial sites per billion (" + r"$10^{9}$" + ")")

fig.tight_layout()
plt.show()

#### Income group

In [None]:
# Trial sites per capita by consolidated income group and phase, drawn side
# by side with two scale factors: 10^6 on the left and 10^9 on the right.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Everything except the scale factor and target axis is shared by both panels.
per_capita_kwargs = dict(
    trial_data=consolidated_counts_per_phase["consolidated_income_group"],
    index_column_name="consolidated_income_group",
    value_column_name="sites_per_capita",
    annotation_format=".2f",
    phase_column_name="phase",
    cbar=False,
)
g_left, g_right = [
    visualization.plot_heatmap_per_phase(
        scale_factor=factor, ax=panel_ax, **per_capita_kwargs
    )
    for panel_ax, factor in zip(axes, (1_000_000, 1_000_000_000))
]

g_left.set_title("A) Number of trial sites per million (" + r"$10^{6}$" + ")")
g_right.set_title("B) Number of trial sites per billion (" + r"$10^{9}$" + ")")

fig.tight_layout()
plt.show()

#### HDI category

In [None]:
# Trial sites per capita by HDI category and phase, drawn side by side with
# two scale factors: 10^6 on the left panel and 10^9 on the right panel.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Everything except the scale factor and target axis is shared by both panels.
per_capita_kwargs = dict(
    trial_data=consolidated_counts_per_phase["hdi_category"],
    index_column_name="hdi_category",
    value_column_name="sites_per_capita",
    annotation_format=".2f",
    phase_column_name="phase",
    cbar=False,
)
g_left, g_right = [
    visualization.plot_heatmap_per_phase(
        scale_factor=factor, ax=panel_ax, **per_capita_kwargs
    )
    for panel_ax, factor in zip(axes, (1_000_000, 1_000_000_000))
]

g_left.set_title("A) Number of trial sites per million (" + r"$10^{6}$" + ")")
g_right.set_title("B) Number of trial sites per billion (" + r"$10^{9}$" + ")")

fig.tight_layout()
plt.show()