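# Figures for publication

This notebook generates Figures 1 and 2 of the publication and the summary figure for the GitHub README from `data/results/trials_sites_counts.xlsx`. Each figure is saved under `figures/` as PDF, SVG, PNG and TIFF.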

In [None]:
# Re-import edited project modules (e.g. `tools.visualization`) automatically,
# so changes to the plotting helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Development-only code-formatting extension; no cell below depends on it.
%load_ext blackcellmagic

In [None]:
import numpy as np
import pandas as pd

import geopandas as gpd

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(color_codes=True)
sns.set_style("whitegrid", {"grid.color": "gainsboro"})

from tools import visualization

## Load data

In [None]:
# Grouping levels for which `data/results/trials_sites_counts.xlsx` holds one sheet
# of overall counts ("ovr_<name>") and one sheet of per-phase counts ("phs_<name>").
grouper_column_names = [
    "country_ISO",
    "country_continent",
    "subregion",
    "economy_level",
    "consolidated_economy_level",
    "income_group",
    "consolidated_income_group",
    "hdi_category",
]

In [None]:
# Overall counts: one DataFrame per grouping level, keyed by the grouping column
# name and read from the sheet "ovr_<grouping column>".
consolidated_counts_overall = {
    column_name: pd.read_excel(
        "data/results/trials_sites_counts.xlsx",
        sheet_name=f"ovr_{column_name}",
    )
    for column_name in grouper_column_names
}

In [None]:
# Quick look at the overall counts aggregated per continent.
consolidated_counts_overall["country_continent"]

In [None]:
# Per-phase counts: one DataFrame per grouping level, keyed by the grouping column
# name and read from the sheet "phs_<grouping column>".
consolidated_counts_per_phase = {
    column_name: pd.read_excel(
        "data/results/trials_sites_counts.xlsx",
        sheet_name=f"phs_{column_name}",
    )
    for column_name in grouper_column_names
}

In [None]:
# Quick look at the first rows of the per-phase counts aggregated per continent.
consolidated_counts_per_phase["country_continent"].head()

## Figure 1 – Log-disproportionality of trial sites: world map and regression on HDI

In [None]:
# Per-phase counts at country level. This is an alias of the loaded sheet, not a copy.
phase_country_data = consolidated_counts_per_phase["country_ISO"]

In [None]:
# Keep only (country, phase) rows with at least one trial. Copy, so that the
# columns added further down do not write into the loaded sheet.
phase_country_data_nonzero = phase_country_data.loc[
    phase_country_data["n_trials"].gt(0)
].copy()

In [None]:
# Log10 of the deviation factor of the number of sites from the expected number
# (presumably actual / expected, so 0 would mean proportional representation).
phase_country_data_nonzero = phase_country_data_nonzero.assign(
    log10_factor_deviation_n_sites_from_expected=lambda frame: np.log10(
        frame["factor_deviation_n_sites_from_expected"]
    )
)

### Add continent and socioeconomic info

In [None]:
# Base dataset sheet. Below it is only used to derive per-country attributes
# (continent, subregion, income group, economy level, HDI category).
ms_trials_socioeconomic = pd.read_excel(
    "data/results/trials_sites_counts.xlsx", sheet_name="Base_dataset"
)

In [None]:
# One row of geographical/socioeconomic attributes per country, extracted from the
# base dataset. Each country must appear exactly once. Otherwise the left merges on
# "country_ISO" would silently duplicate country rows (and their counts).
country_socioeconomic_data = (
    ms_trials_socioeconomic[
        [
            "country_ISO",
            "country_continent",
            "subregion",
            "consolidated_income_group",
            "consolidated_economy_level",
            "hdi_category",
        ]
    ]
    .drop_duplicates()
    .reset_index(drop=True)
)
assert country_socioeconomic_data["country_ISO"].is_unique, (
    "Some countries have conflicting socioeconomic attributes in Base_dataset"
)

In [None]:
# Attach continent, subregion, income, economy and HDI category to every row;
# rows without a match keep NaN in the added columns (left join).
phase_country_data_nonzero = phase_country_data_nonzero.merge(
    country_socioeconomic_data, on="country_ISO", how="left"
)

In [None]:
# Alphabetically sorted continents, used as the hue order of the Figure 1 scatter
# plots. The merge above is a left join, so unmatched countries carry NaN here.
# Drop NaN first: sorting a mix of strings and NaN raises a TypeError.
continents = sorted(phase_country_data_nonzero["country_continent"].dropna().unique())

### Load and prepare geometry data from Natural Earth
Source: [Natural Earth](https://www.naturalearthdata.com/), 1:110m cultural vectors (admin-0 countries).

In [None]:
# Natural Earth 1:110m admin-0 country polygons.
# NOTE(review): the path names the `.shx` index member of the shapefile set; `.shp`
# is the conventional entry point — confirm the reader resolves the full set.
geo_data = gpd.read_file("data/source/naturalearth/110m_cultural/ne_110m_admin_0_countries.shx")

In [None]:
# Remove Antarctica and open-ocean entries from the base map.
geo_data = geo_data.loc[
    lambda frame: ~frame["CONTINENT"].isin(["Antarctica", "Seven seas (open ocean)"])
].copy()

In [None]:
# Keep the country code, subregion and geometry, renamed to the column names
# used by the trial data so that both can be joined on "country_ISO".
geo_data = geo_data.loc[:, ["ADM0_A3", "SUBREGION", "geometry"]].rename(
    columns={"ADM0_A3": "country_ISO", "SUBREGION": "subregion"}
)

### Load and add HDI data

In [None]:
# HDI time series. The columns used below are "Entity", "Code", "Year" and
# "Human Development Index".
hdi_raw = pd.read_csv("data/source/unstats/human-development-index.csv")

In [None]:
# Most recent year with an HDI value for every (Entity, Code) pair. Rows with a
# missing Code are dropped by groupby's default NaN handling.
max_year_hdi_per_country = hdi_raw.groupby(["Entity", "Code"], as_index=False)[
    "Year"
].max()

In [None]:
# Restrict the HDI table to each country's most recent year (inner join on the
# (Entity, Code, Year) triples computed above).
hdi_raw = hdi_raw.merge(
    max_year_hdi_per_country,
    on=["Entity", "Code", "Year"],
    how="inner",
)

In [None]:
# Latest HDI per country, with the country code renamed to match the trial data.
hdi_data = hdi_raw.loc[:, ["Entity", "Code", "Human Development Index"]].rename(
    columns={
        "Entity": "hdi_country",
        "Code": "country_ISO",
        "Human Development Index": "hdi",
    }
)

In [None]:
# Add the HDI value; the inner join drops countries that have no HDI entry.
phase_country_data_nonzero = phase_country_data_nonzero.merge(
    hdi_data[["country_ISO", "hdi"]], on="country_ISO", how="inner"
)

### Plot

In [None]:
# Scatter palette for Figure 1, passed together with `scatter_hue_order=continents`,
# so colors pair up with the sorted continent names by position.
# NOTE(review): presumably needs at least as many colors as continents — confirm.
continent_colors = ["black", "indigo", "deeppink", "orange", "olivedrab", "teal"]

In [None]:
# Figure 1: one row per trial phase (1-4). The left column shows the world map of the
# log10-disproportionality of sites; the right column shows its regression on HDI.
fig, axes = plt.subplots(4, 2, figsize=(20, 20), width_ratios=[15, 5])

# Raw string: "\m" is an invalid escape sequence in a regular string literal
# (SyntaxWarning since Python 3.12).
log10_label = r"$\mathregular{Log}_{10}$"

for i in range(1, 5):
    phase_data = phase_country_data_nonzero[
        phase_country_data_nonzero["phase"] == "PHASE" + str(i)
    ]
    map_ax, regression_ax = axes[i - 1]

    visualization.plot_choropleth_map_country_level(
        trial_data=phase_data,
        column_to_plot="factor_deviation_n_sites_from_expected",
        log_scale=True,
        log_scale_diverging_palette=True,
        colormap_minimum_value=-2.5,
        colormap_maximum_value=1.5,
        show_colorbar=True,
        edgecolor="black",
        edges_linewidth=0.25,
        geometry_base_dataframe=geo_data,
        geometry_column_name="geometry",
        country_id_column_name="country_ISO",
        base_color="silver",
        base_edgecolor="white",
        ax=map_ax,
    )

    visualization.linear_regression_and_scatter_plot(
        data=phase_data,
        x_column="hdi",
        y_column="log10_factor_deviation_n_sites_from_expected",
        scatter_palette=continent_colors,
        scatter_hue_column="country_continent",
        scatter_hue_order=continents,
        scatter_alpha=0.75,
        xlim=(0.55, 0.975),
        ylim=(-2.75, 2.25),
        n_points=100,
        ax=regression_ax,
    )

    regression_ax.legend(title="Continent", loc="upper left")
    regression_ax.set_xlabel("Human Development Index")
    regression_ax.set_ylabel(log10_label + "-disproportionality of trial sites")

    # Panel letters run A, C, E, G down the left column and B, D, F, H down the right.
    map_ax.set_title(
        f"{'ACEG'[i - 1]}) {log10_label}-disproportionality of actual number of phase {i}"
        f" trial sites vs. expected number of phase {i} trial sites based on population"
    )
    regression_ax.set_title(
        f"{'BDFH'[i - 1]}) {log10_label}-disproportionality of phase {i}"
        " trial sites as function of HDI, linear regression"
    )

# NOTE(review): fig.tight_layout() recomputes the subplot parameters, including
# hspace. hspace=0.75 therefore only survives if tight_layout cannot handle this
# layout — confirm which vertical spacing is intended.
fig.subplots_adjust(hspace=0.75)
fig.tight_layout()

save_path = "figures/figure_1"
for file_format in ("pdf", "svg", "png"):
    fig.savefig(f"{save_path}.{file_format}", format=file_format, bbox_inches="tight")
fig.savefig(
    save_path + ".tiff",
    format="tiff",
    dpi=600,
    pil_kwargs={"compression": "tiff_lzw"},
    bbox_inches="tight",
)

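## Figure 2 – Combined heatmaps (per trial phase)
Panels:
* Percentage of trials with at least one site on each continent (panel A)
* Distribution of trial sites over continents (panel B)
* Distribution of trial sites over geographical regions (panel D)
* Distribution of trial sites over HDI categories (panel C)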

In [None]:
# Figure 2: four per-phase heatmaps on one figure.
fig = plt.figure(figsize=(14, 12))

# The subpanels have different heights, so they share a 16 x 2 grid. Panels A, B
# and C stack in the left column (6 + 6 + 4 rows); panel D spans all 16 rows of
# the right column.
ax_l1 = plt.subplot2grid((16, 2), (0, 0), rowspan=6, fig=fig)
ax_l2 = plt.subplot2grid((16, 2), (6, 0), rowspan=6, fig=fig)
ax_l3 = plt.subplot2grid((16, 2), (12, 0), rowspan=4, fig=fig)
# Span exactly the 16 grid rows; matplotlib silently clips larger rowspans
# (the previous value was 20).
ax_r1 = plt.subplot2grid((16, 2), (0, 1), rowspan=16, fig=fig)

# One entry per panel:
# (axis, grouping column, value column, panel title, extra plotting kwargs).
heatmap_panels = [
    (
        ax_l1,
        "country_continent",
        "of_total_trials_phase",
        "A) Percentage of trials with at least one site on a given continent",
        {"linear_palette": "viridis"},
    ),
    (
        ax_l2,
        "country_continent",
        "of_total_sites_phase",
        "B) Distribution of trial sites over continents",
        {},
    ),
    (
        ax_l3,
        "hdi_category",
        "of_total_sites_phase",
        "C) Distribution of trial sites over HDI categories",
        {},
    ),
    (
        ax_r1,
        "subregion",
        "of_total_sites_phase",
        "D) Distribution of trial sites over geographical regions",
        {},
    ),
]

for ax, panel_grouper, panel_value_column, panel_title, extra_kwargs in heatmap_panels:
    visualization.plot_heatmap_per_phase(
        trial_data=consolidated_counts_per_phase[panel_grouper],
        index_column_name=panel_grouper,
        value_column_name=panel_value_column,
        annotation_format=".2%",
        phase_column_name="phase",
        cbar=False,
        ax=ax,
        **extra_kwargs,
    )
    ax.set_title(panel_title)
    # Clear any axis labels set by the plotting helper; the panel titles carry them.
    ax.set_xlabel(None)
    ax.set_ylabel(None)
fig.tight_layout()

save_path = "figures/figure_2"
for file_format in ("pdf", "svg", "png"):
    fig.savefig(f"{save_path}.{file_format}", format=file_format, bbox_inches="tight")
fig.savefig(
    save_path + ".tiff",
    format="tiff",
    dpi=600,
    pil_kwargs={"compression": "tiff_lzw"},
    bbox_inches="tight",
)

## Figure for GitHub README summary

### Prepare data

In [None]:
# Overall (all phases) counts at country level. This is an alias of the loaded sheet, not a copy.
overall_country_data = consolidated_counts_overall["country_ISO"]

In [None]:
# Keep only countries with at least one trial; copy so the merge below works
# on an independent frame.
overall_country_data_nonzero = overall_country_data.loc[
    overall_country_data["n_trials"].gt(0)
].copy()

### Add continent and socioeconomic info

In [None]:
# Attach continent, subregion, income, economy and HDI category (left join).
overall_country_data_nonzero = overall_country_data_nonzero.merge(
    country_socioeconomic_data, on="country_ISO", how="left"
)

### Plot

In [None]:
# README summary figure: world map of the log10-disproportionality of trial sites,
# all phases combined.
fig, ax = plt.subplots(figsize=(15, 15))

# Draw into the axes created above, so `fig` (laid out and saved below) is the
# figure that actually contains the map.
g = visualization.plot_choropleth_map_country_level(
    trial_data=overall_country_data_nonzero,
    column_to_plot="factor_deviation_n_sites_from_expected",
    log_scale=True,
    log_scale_diverging_palette=True,
    colormap_minimum_value=-2.5,
    colormap_maximum_value=1.5,
    show_colorbar=True,
    edgecolor="black",
    edges_linewidth=0.25,
    geometry_base_dataframe=geo_data,
    geometry_column_name="geometry",
    country_id_column_name="country_ISO",
    base_color="silver",
    base_edgecolor="white",
    ax=ax,
)
# Raw string: "\m" is an invalid escape sequence in a regular string literal.
g.set_title(
    r"$\mathregular{Log}_{10}$"
    + "-disproportionality of actual number of trial sites vs. expected number of trial sites based on population"
)

fig.tight_layout()
save_path = "figures/log10_disproportionality_all_phases"
for file_format in ("pdf", "svg", "png"):
    fig.savefig(f"{save_path}.{file_format}", format=file_format, bbox_inches="tight")
fig.savefig(
    save_path + ".tiff",
    format="tiff",
    dpi=600,
    pil_kwargs={"compression": "tiff_lzw"},
    bbox_inches="tight",
)