In [None]:
%load_ext autoreload
%autoreload 2
%aimport utils_1_1

import pandas as pd
import numpy as np
import altair as alt
from altair_saver import save
import datetime
import dateutil.parser
from os.path import join

from constants_1_1 import SITE_FILE_TYPES
from utils_1_1 import (
    get_site_file_paths,
    get_site_file_info,
    get_site_ids,
    get_visualization_subtitle,
    get_country_color_map,
)
from theme import apply_theme
from web import for_website

alt.data_transformers.disable_max_rows(); # Lift Altair's 5000-row limit on chart data; the trailing ';' suppresses the cell output

In [None]:
# Release date of the summary tables used in this notebook.
data_release='2021-05-28'

# Raw lab identifiers used in the R summary tables -> display names with units.
consistent_loinc = {
    "C_reactive_protein_CRP_Normal_Sensitivity": "C-reactive protein (Normal Sensitivity) (mg/dL)",
    "creatinine": "Creatinine (mg/dL)",
    "Ferritin": "Ferritin (ng/mL)",
    "D_dimer": "D-dimer (ng/mL)",
    "albumin": "Albumin (g/dL)",
    "Fibrinogen": "Fibrinogen (mg/dL)",
    "alanine_aminotransferase_ALT": "Alanine aminotransferase (U/L)",
    "aspartate_aminotransferase_AST": "Aspartate aminotransferase (U/L)",
    "total_bilirubin": "Total bilirubin (mg/dL)",
    "lactate_dehydrogenase_LDH": "Lactate dehydrogenase (U/L)",
    "cardiac_troponin_High_Sensitivity": "Cardiac troponin High Sensitivity (ng/mL)",
    "cardiac_troponin_Normal_Sensitivity": "Cardiac troponin Normal Sensitivity (ng/mL)",
    "prothrombin_time_PT": "Prothrombin time (s)",
    "white_blood_cell_count_Leukocytes": "White blood cell count (10*3/uL)",
    "lymphocyte_count": "Lymphocyte count (10*3/uL)",
    "neutrophil_count": "Neutrophil count (10*3/uL)",
    "procalcitonin": "Procalcitonin (ng/mL)",
}

continents = ['NORTH AMERICA', 'EUROPE']
continent_colors = ['#D45E00', '#57B4E9']

countries = ['USA', 'FRANCE', 'GERMANY', 'ITALY']
country_colors = ['#D45E00', '#0072B2', '#029F73', '#B2AA2F'] # '#E5DA3E']

# Per-site palette and markers, aligned index-by-index with `sites`.
# Colour is the colour of the site's country (see `countries`/`country_colors`);
# shape is circle for the European sites and diamond for the US sites.
sites = ['APHP', 'FRBDX', 'UKFR', 'BIDMC', 'MGB', 'NWU', 'UCLA', 'UMICH', 'UPENN', 'UPITT', 'VA1', 'VA2', 'VA3', 'VA4', 'VA5']
site_colors = ['#0072B2', '#0072B2', '#029F73', '#D45E00', '#D45E00', '#D45E00', '#D45E00', '#D45E00', '#D45E00', '#D45E00', '#D45E00', '#D45E00', '#D45E00', '#D45E00', '#D45E00']
site_shapes = ['circle', 'circle', 'circle', 'diamond', 'diamond', 'diamond', 'diamond', 'diamond', 'diamond', 'diamond', 'diamond', 'diamond', 'diamond', 'diamond', 'diamond']

# Site id (and META-* meta-analysis id) -> country label used to colour the
# site-level points. UKFR -> GERMANY and BIDMC -> USA, consistent with their
# colours and shapes above.
site_to_contry = {
    'META-USA': 'USA',
    'META-EUROPE': 'EUROPE',
    'META-FRANCE': 'FRANCE',
    'META-GERMANY': 'GERMANY',
    'META-ITALY': 'ITALY',
    'APHP': 'FRANCE',
    'FRBDX': 'FRANCE',
    'UKFR': 'GERMANY',
    'BIDMC': 'USA',
    'MGB': 'USA',
    'NWU': 'USA',
    'UCLA': 'USA',
    'UMICH': 'USA',
    'UPENN': 'USA',
    'UPITT': 'USA',
    'VA1': 'USA',
    'VA2': 'USA',
    'VA3': 'USA',
    'VA4': 'USA',
    'VA5': 'USA'
}

# Sanity checks: the palette lists line up with `sites`, and every site's
# palette colour is the colour of the country it maps to.
assert len(sites) == len(site_colors) == len(site_shapes), "site palette lists must align with `sites`"
_country_color = dict(zip(countries, country_colors))
assert all(
    _country_color[site_to_contry[site]] == color
    for site, color in zip(sites, site_colors)
), "site_colors and site_to_contry disagree on a site's country"

# Prediction Baseline

In [None]:
# Baseline single-lab prediction summary -> long format (one row per site, lab, day).
df = pd.read_csv(join("..", "data", "Phase2.1SurvivalRSummariesPublic", "ToShare", "table.R1.prediction.baselineLab.csv"))

# 'Unnamed: 0' is a leftover row-index column (presumably R row names -- confirm).
df = df.drop(columns=['Unnamed: 0'])
df = df.rename(columns={"nm.lab": "lab"})
df = pd.melt(df, id_vars=['siteid', 'lab'], var_name='day', value_name='value')
# Normalise site ids: fix the "Eurpoe" typo, upper-case, and label NORTH AMERICA as USA.
df.siteid = df.siteid.apply(lambda x: x.replace('Eurpoe', 'Europe').upper().replace('NORTH AMERICA', 'USA'))
# Melted column names carry a "day" prefix; keep only the number (still a string).
df.day = df.day.apply(lambda x: x.replace('day', ''))

# Translate raw lab ids to display names. Report every id that has no entry in
# consistent_loinc at once, rather than failing on the first one.
unknown_labs = [raw for raw in df.lab.unique() if raw not in consistent_loinc]
if unknown_labs:
    raise KeyError(f"lab id(s) missing from consistent_loinc: {unknown_labs}")
df.lab = df.lab.map(consistent_loinc)

unique_labs = df.lab.unique().tolist()
print(unique_labs)

unique_sites = df.siteid.unique().tolist()
print(unique_sites)

# Attach each site's total sample size N (missing when the site is absent from
# the totals table, e.g. META rows).
sdf = pd.read_csv(join("..", "data", "Phase2.1SurvivalRSummariesPublic", "ToShare", "table.totalN.csv"))
dic = pd.Series(sdf.N.values, index=sdf.siteid).to_dict()

df['N'] = df.siteid.apply(lambda x: dic.get(x))

countries = ['META', 'USA', 'EUROPE', 'FRANCE', 'GERMANY']
country_colors = ['black', '#D45E00', '#57B4E9', '#0072B2', '#029F73']

df

# Two-row layout: meta-analysis + continent level (top), site level + country level (bottom)

In [None]:
def plot_lab(df=None, lab=None, is_country=True, y_domain=(0.4, 0.9)):
    """Build the two-row prediction figure for a single lab.

    Top row: the overall meta-analysis line ("All Countries", black) layered
    with the continent-level meta-analysis lines (North America, Europe).
    Bottom row: one point per site (coloured by country, sized by ``N``)
    layered with dashed country-level meta-analysis lines.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format summary with columns ``siteid``, ``lab``, ``day`` (string
        such as ``'3'``), ``value`` and ``N``. Meta-analysis rows have a
        ``siteid`` containing ``'META'``; every other ``siteid`` must be a key
        of ``site_to_contry``.
    lab : str
        Display name of the lab to plot (a value of ``df.lab``).
    is_country : bool
        Unused; kept so existing calls keep working.
    y_domain : tuple of float
        Fixed y-axis range of the continent-level and site-level layers.

    Returns
    -------
    altair.VConcatChart
        The two rows stacked vertically and titled with ``lab``.
    """
    day_domain = [3, 7, 14]
    # Keep only the three reported horizons and convert them to int so the
    # data values match `day_domain`, the x-scale domain of every layer.
    d = df[df.lab == lab]
    d = d[d.day.isin(['3', '7', '14'])].copy()
    d['day'] = d.day.astype(int)

    # Legends are drawn only for these labs -- presumably the last panel of
    # each row in the lab grid (TODO confirm against the calling cell).
    show_legend = lab in (
        'Aspartate aminotransferase (U/L)',
        'Total bilirubin (mg/dL)',
        'Neutrophil count (10*3/uL)',
    )

    country_domain = ['FRANCE', 'GERMANY', 'USA']
    country_range = ['#0072B2', '#029F73', '#D6641E']
    meta_rows = d[d.siteid.str.contains('META')]

    def x_days(hide_axis=False):
        """Ordinal day encoding; `hide_axis` drops the title, labels and axis line."""
        axis_kwargs = dict(labelAngle=0, tickCount=10)
        if hide_axis:
            axis_kwargs.update(labels=False, domain=False)
        return alt.X(
            "day:O",
            title=None if hide_axis else 'Days Since Admission',
            axis=alt.Axis(**axis_kwargs),
            scale=alt.Scale(clamp=True, nice=False, padding=0.3, domain=day_domain),
        )

    def y_value(fixed=False, **extra):
        """Quantitative value encoding; `fixed` pins the scale to `y_domain`."""
        scale_kwargs = dict(zero=False, nice=False, padding=10)
        if fixed:
            scale_kwargs['domain'] = list(y_domain)
        scale_kwargs.update(extra)
        return alt.Y("value:Q", title=None, scale=alt.Scale(**scale_kwargs))

    def meta_layer(relabel, title, domain, colors, line_size, legend_kwargs,
                   hide_x_axis=False, fixed_y=False, dashed=False):
        """Line + point layer of the meta-analysis rows relabelled by `relabel`.

        Only rows whose relabelled siteid is in `domain` are kept, so the layer
        draws exactly the categories its colour legend describes (values outside
        an explicit ordinal domain may otherwise get recycled colours).
        """
        data = meta_rows.copy()
        data['siteid'] = data.siteid.apply(relabel)
        data = data[data.siteid.isin(domain)]
        color_scale = alt.Scale(domain=domain, range=colors)
        encoding = dict(
            x=x_days(hide_x_axis),
            y=y_value(fixed_y),
            color=alt.Color(
                "siteid:N", title=title, scale=color_scale,
                legend=alt.Legend(**legend_kwargs) if show_legend else None,
            ),
        )
        if dashed:
            encoding['strokeDash'] = alt.value([3, 3])
        line = alt.Chart(data).mark_line(size=line_size, opacity=1).encode(**encoding).properties(
            width=450,
            height=250,
        )
        points = line.mark_point(filled=True, opacity=1, size=50).encode(
            color=alt.Color("siteid:N", scale=color_scale, legend=None)
        )
        return alt.layer(line, points).resolve_scale(color='independent')

    # Meta analysis over all countries ('META' -> 'All Countries').
    meta_chart = meta_layer(
        lambda x: x.replace('META-', '').replace('META', 'All Countries'),
        'Meta Analysis', ['All Countries'], ['black'],
        line_size=3, legend_kwargs={},
    )

    # Continent-level meta analysis ('META-USA' -> 'NORTH AMERICA').
    continent_chart = meta_layer(
        lambda x: x.replace('META-', '').replace('USA', 'NORTH AMERICA'),
        'Continent Level', ['NORTH AMERICA', 'EUROPE'], ['#D45E00', '#57B4E9'],
        line_size=3, legend_kwargs=dict(symbolStrokeWidth=4, symbolSize=300),
        hide_x_axis=True, fixed_y=True,
    )

    # Country-level meta analysis, drawn dashed.
    country_chart = meta_layer(
        lambda x: x.replace('META-', ''),
        'Country Level', country_domain, country_range,
        line_size=2, legend_kwargs=dict(symbolDash=[3, 3], symbolStrokeWidth=4, symbolSize=300),
        dashed=True,
    )

    # Site level: the opacity-0 line chart is only a template for the shared
    # encodings; the point re-mark below is what ends up in the figure.
    site_rows = d[~d.siteid.str.contains('META')].copy()
    site_rows['country'] = site_rows.siteid.apply(lambda x: site_to_contry[x])
    site_template = alt.Chart(site_rows).mark_line(size=2.5, opacity=0).encode(
        x=x_days(),
        y=y_value(fixed=True, clamp=True),
        color=alt.Color("siteid:N", title=None, scale=alt.Scale(domain=sites, range=site_colors), legend=None),
    ).properties(
        width=450,
        height=250,
    )
    site_chart = site_template.mark_point(filled=True, opacity=0.5, size=150).encode(
        color=alt.Color(
            "country:N", title='Site Level',
            scale=alt.Scale(domain=country_domain, range=country_range),
            legend=alt.Legend() if show_legend else None,
        ),
        size=alt.Size('N:Q', title='Sample Size', legend=alt.Legend(symbolFillColor='black') if show_legend else None),
    )

    # Combine: meta + continent on top, sites + countries below.
    chart = alt.vconcat(
        meta_chart + continent_chart,
        site_chart + country_chart,
        spacing=10,
    ).resolve_scale(color='independent', size='independent', shape='independent')

    return chart.properties(
        title={
            "text": [
                f"{lab}",
            ],
            "fontSize": 18,
            "dx": 30,
            "subtitleColor": "gray",
        }
    )

# Lay the per-lab figures out as a grid: the first N_LABS_SHOWN labs,
# LABS_PER_ROW figures per row. Trailing rows are only created when there are
# labs to fill them, so no empty hconcat is built.
LABS_PER_ROW = 3
N_LABS_SHOWN = 9

labs_shown = unique_labs[:N_LABS_SHOWN]
lab_rows = [
    alt.hconcat(
        *(plot_lab(df=df, lab=lab, is_country=True) for lab in labs_shown[start:start + LABS_PER_ROW]),
        spacing=30,
    )
    for start in range(0, len(labs_shown), LABS_PER_ROW)
]
plot = alt.vconcat(*lab_rows, spacing=30)

# Apply the project theme once, then show and export the figure.
plot = apply_theme(
    plot,
    axis_y_title_font_size=16,
    title_anchor='start',
    legend_orient='right',
    legend_title_orient='top',
    axis_label_font_size=14,
    header_label_font_size=16,
    subtitle_font_size=18,
)

plot.display()
save(plot, join("..", "result", "R1-prediction-singleLab.png"), scalefactor=8.0)



# Prediction Cov

In [None]:
# Load the prediction.cov summary and melt it to one row per (site, day).
df = pd.read_csv(
    join("..", "data", "Phase2.1SurvivalRSummariesPublic", "ToShare", "table.R1.prediction.cov.csv")
).melt(id_vars=['siteid'], var_name='day', value_name='value')

# Normalise site ids ("Eurpoe" typo, upper case, NORTH AMERICA -> USA) and
# strip the "day" prefix from the melted column names.
df['siteid'] = df['siteid'].map(
    lambda site: site.replace('Eurpoe', 'Europe').upper().replace('NORTH AMERICA', 'USA')
)
df['day'] = df['day'].map(lambda col: col.replace('day', ''))

unique_sites = df['siteid'].unique().tolist()
print(unique_sites)

# Attach each site's total sample size N (missing for sites not in the totals table).
sdf = pd.read_csv(join("..", "data", "Phase2.1SurvivalRSummariesPublic", "ToShare", "table.totalN.csv"))
dic = sdf.set_index('siteid')['N'].to_dict()
df['N'] = df['siteid'].map(lambda site: dic.get(site))

countries = ['META', 'USA', 'EUROPE', 'FRANCE', 'GERMANY']
country_colors = ['black', '#D45E00', '#57B4E9', '#0072B2', '#029F73']

df

# Two-row layout: meta-analysis + continent level (top), site level + country level (bottom)

In [None]:
def plot_lab(df=None, is_country=True, y_domain=(0.4, 0.9)):
    """Build the two-row prediction figure for a single-metric summary table.

    Top row: the overall meta-analysis line ("All Countries", black) layered
    with the continent-level meta-analysis lines (North America, Europe).
    Bottom row: one point per site (coloured by country, sized by ``N``)
    layered with dashed country-level meta-analysis lines. Legends are always
    shown.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format summary with columns ``siteid``, ``day`` (string such as
        ``'3'``), ``value`` and ``N``. Meta-analysis rows have a ``siteid``
        containing ``'META'``; every other ``siteid`` must be a key of
        ``site_to_contry``.
    is_country : bool
        Unused; kept so existing calls keep working.
    y_domain : tuple of float
        Fixed y-axis range of the continent-level and site-level layers.

    Returns
    -------
    altair.VConcatChart
        The two rows stacked vertically.
    """
    day_domain = [3, 7, 14]
    # Keep only the three reported horizons and convert them to int so the
    # data values match `day_domain`, the x-scale domain of every layer.
    d = df[df.day.isin(['3', '7', '14'])].copy()
    d['day'] = d.day.astype(int)

    show_legend = True

    country_domain = ['FRANCE', 'GERMANY', 'USA']
    country_range = ['#0072B2', '#029F73', '#D6641E']
    meta_rows = d[d.siteid.str.contains('META')]

    def x_days(hide_axis=False):
        """Ordinal day encoding; `hide_axis` drops the title, labels and axis line."""
        axis_kwargs = dict(labelAngle=0, tickCount=10)
        if hide_axis:
            axis_kwargs.update(labels=False, domain=False)
        return alt.X(
            "day:O",
            title=None if hide_axis else 'Days Since Admission',
            axis=alt.Axis(**axis_kwargs),
            scale=alt.Scale(clamp=True, nice=False, padding=0.3, domain=day_domain),
        )

    def y_value(fixed=False, **extra):
        """Quantitative value encoding; `fixed` pins the scale to `y_domain`."""
        scale_kwargs = dict(zero=False, nice=False, padding=10)
        if fixed:
            scale_kwargs['domain'] = list(y_domain)
        scale_kwargs.update(extra)
        return alt.Y("value:Q", title=None, scale=alt.Scale(**scale_kwargs))

    def meta_layer(relabel, title, domain, colors, line_size, legend_kwargs,
                   hide_x_axis=False, fixed_y=False, dashed=False):
        """Line + point layer of the meta-analysis rows relabelled by `relabel`.

        Only rows whose relabelled siteid is in `domain` are kept, so the layer
        draws exactly the categories its colour legend describes (values outside
        an explicit ordinal domain may otherwise get recycled colours).
        """
        data = meta_rows.copy()
        data['siteid'] = data.siteid.apply(relabel)
        data = data[data.siteid.isin(domain)]
        color_scale = alt.Scale(domain=domain, range=colors)
        encoding = dict(
            x=x_days(hide_x_axis),
            y=y_value(fixed_y),
            color=alt.Color(
                "siteid:N", title=title, scale=color_scale,
                legend=alt.Legend(**legend_kwargs) if show_legend else None,
            ),
        )
        if dashed:
            encoding['strokeDash'] = alt.value([3, 3])
        line = alt.Chart(data).mark_line(size=line_size, opacity=1).encode(**encoding).properties(
            width=450,
            height=250,
        )
        points = line.mark_point(filled=True, opacity=1, size=50).encode(
            color=alt.Color("siteid:N", scale=color_scale, legend=None)
        )
        return alt.layer(line, points).resolve_scale(color='independent')

    # Meta analysis over all countries ('META' -> 'All Countries').
    meta_chart = meta_layer(
        lambda x: x.replace('META-', '').replace('META', 'All Countries'),
        'Meta Analysis', ['All Countries'], ['black'],
        line_size=3, legend_kwargs={},
    )

    # Continent-level meta analysis ('META-USA' -> 'NORTH AMERICA').
    continent_chart = meta_layer(
        lambda x: x.replace('META-', '').replace('USA', 'NORTH AMERICA'),
        'Continent Level', ['NORTH AMERICA', 'EUROPE'], ['#D45E00', '#57B4E9'],
        line_size=3, legend_kwargs=dict(symbolStrokeWidth=4, symbolSize=300),
        hide_x_axis=True, fixed_y=True,
    )

    # Country-level meta analysis, drawn dashed.
    country_chart = meta_layer(
        lambda x: x.replace('META-', ''),
        'Country Level', country_domain, country_range,
        line_size=2, legend_kwargs=dict(symbolDash=[3, 3], symbolStrokeWidth=4, symbolSize=300),
        dashed=True,
    )

    # Site level: the opacity-0 line chart is only a template for the shared
    # encodings; the point re-mark below is what ends up in the figure.
    site_rows = d[~d.siteid.str.contains('META')].copy()
    site_rows['country'] = site_rows.siteid.apply(lambda x: site_to_contry[x])
    site_template = alt.Chart(site_rows).mark_line(size=2.5, opacity=0).encode(
        x=x_days(),
        y=y_value(fixed=True, clamp=True),
        color=alt.Color("siteid:N", title=None, scale=alt.Scale(domain=sites, range=site_colors), legend=None),
    ).properties(
        width=450,
        height=250,
    )
    site_chart = site_template.mark_point(filled=True, opacity=0.5, size=150).encode(
        color=alt.Color(
            "country:N", title='Site Level',
            scale=alt.Scale(domain=country_domain, range=country_range),
            legend=alt.Legend() if show_legend else None,
        ),
        size=alt.Size('N:Q', title='Sample Size', legend=alt.Legend(symbolFillColor='black') if show_legend else None),
    )

    # Combine: meta + continent on top, sites + countries below.
    chart = alt.vconcat(
        meta_chart + continent_chart,
        site_chart + country_chart,
        spacing=10,
    ).resolve_scale(color='independent', size='independent', shape='independent')

    return chart.properties(
        title={
            "text": [],
            "fontSize": 18,
            "dx": 30,
            "subtitleColor": "gray",
        }
    )

# Render the prediction.cov figure, apply the project theme, show it, and
# export it as a high-resolution PNG.
plot = apply_theme(
    plot_lab(df=df, is_country=True),
    **dict(
        axis_y_title_font_size=16,
        axis_label_font_size=14,
        header_label_font_size=16,
        subtitle_font_size=18,
        title_anchor='start',
        legend_orient='right',
        legend_title_orient='top',
    ),
)

plot.display()
save(plot, join("..", "result", "R1-prediction-cov.png"), scalefactor=8.0)



# Prediction TPR

In [None]:
# Load the prediction.cov.tpr summary and melt it to one row per (site, day).
df = pd.read_csv(
    join("..", "data", "Phase2.1SurvivalRSummariesPublic", "ToShare", "table.R1.prediction.cov.tpr.csv")
).melt(id_vars=['siteid'], var_name='day', value_name='value')

# Normalise site ids ("Eurpoe" typo, upper case, NORTH AMERICA -> USA) and
# strip the "day" prefix from the melted column names.
df['siteid'] = df['siteid'].map(
    lambda site: site.replace('Eurpoe', 'Europe').upper().replace('NORTH AMERICA', 'USA')
)
df['day'] = df['day'].map(lambda col: col.replace('day', ''))

unique_sites = df['siteid'].unique().tolist()
print(unique_sites)

# Attach each site's total sample size N (missing for sites not in the totals table).
sdf = pd.read_csv(join("..", "data", "Phase2.1SurvivalRSummariesPublic", "ToShare", "table.totalN.csv"))
dic = sdf.set_index('siteid')['N'].to_dict()
df['N'] = df['siteid'].map(lambda site: dic.get(site))

countries = ['META', 'USA', 'EUROPE', 'FRANCE', 'GERMANY']
country_colors = ['black', '#D45E00', '#57B4E9', '#0072B2', '#029F73']

df

In [None]:
def plot_lab(df=None, is_country=True, y_domain=(0, 0.9)):
    """Build the two-row prediction figure for a single-metric summary table.

    Top row: the overall meta-analysis line ("All Countries", black) layered
    with the continent-level meta-analysis lines (North America, Europe).
    Bottom row: one point per site (coloured by country, sized by ``N``)
    layered with dashed country-level meta-analysis lines. Legends are always
    shown.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format summary with columns ``siteid``, ``day`` (string such as
        ``'3'``), ``value`` and ``N``. Meta-analysis rows have a ``siteid``
        containing ``'META'``; every other ``siteid`` must be a key of
        ``site_to_contry``.
    is_country : bool
        Unused; kept so existing calls keep working.
    y_domain : tuple of float
        Fixed y-axis range of the continent-level and site-level layers.

    Returns
    -------
    altair.VConcatChart
        The two rows stacked vertically.
    """
    day_domain = [3, 7, 14]
    # Keep only the three reported horizons and convert them to int so the
    # data values match `day_domain`, the x-scale domain of every layer.
    d = df[df.day.isin(['3', '7', '14'])].copy()
    d['day'] = d.day.astype(int)

    show_legend = True

    country_domain = ['FRANCE', 'GERMANY', 'USA']
    country_range = ['#0072B2', '#029F73', '#D6641E']
    meta_rows = d[d.siteid.str.contains('META')]

    def x_days(hide_axis=False):
        """Ordinal day encoding; `hide_axis` drops the title, labels and axis line."""
        axis_kwargs = dict(labelAngle=0, tickCount=10)
        if hide_axis:
            axis_kwargs.update(labels=False, domain=False)
        return alt.X(
            "day:O",
            title=None if hide_axis else 'Days Since Admission',
            axis=alt.Axis(**axis_kwargs),
            scale=alt.Scale(clamp=True, nice=False, padding=0.3, domain=day_domain),
        )

    def y_value(fixed=False, **extra):
        """Quantitative value encoding; `fixed` pins the scale to `y_domain`."""
        scale_kwargs = dict(zero=False, nice=False, padding=10)
        if fixed:
            scale_kwargs['domain'] = list(y_domain)
        scale_kwargs.update(extra)
        return alt.Y("value:Q", title=None, scale=alt.Scale(**scale_kwargs))

    def meta_layer(relabel, title, domain, colors, line_size, legend_kwargs,
                   hide_x_axis=False, fixed_y=False, dashed=False):
        """Line + point layer of the meta-analysis rows relabelled by `relabel`.

        Only rows whose relabelled siteid is in `domain` are kept, so the layer
        draws exactly the categories its colour legend describes (values outside
        an explicit ordinal domain may otherwise get recycled colours).
        """
        data = meta_rows.copy()
        data['siteid'] = data.siteid.apply(relabel)
        data = data[data.siteid.isin(domain)]
        color_scale = alt.Scale(domain=domain, range=colors)
        encoding = dict(
            x=x_days(hide_x_axis),
            y=y_value(fixed_y),
            color=alt.Color(
                "siteid:N", title=title, scale=color_scale,
                legend=alt.Legend(**legend_kwargs) if show_legend else None,
            ),
        )
        if dashed:
            encoding['strokeDash'] = alt.value([3, 3])
        line = alt.Chart(data).mark_line(size=line_size, opacity=1).encode(**encoding).properties(
            width=450,
            height=250,
        )
        points = line.mark_point(filled=True, opacity=1, size=50).encode(
            color=alt.Color("siteid:N", scale=color_scale, legend=None)
        )
        return alt.layer(line, points).resolve_scale(color='independent')

    # Meta analysis over all countries ('META' -> 'All Countries').
    meta_chart = meta_layer(
        lambda x: x.replace('META-', '').replace('META', 'All Countries'),
        'Meta Analysis', ['All Countries'], ['black'],
        line_size=3, legend_kwargs={},
    )

    # Continent-level meta analysis ('META-USA' -> 'NORTH AMERICA').
    continent_chart = meta_layer(
        lambda x: x.replace('META-', '').replace('USA', 'NORTH AMERICA'),
        'Continent Level', ['NORTH AMERICA', 'EUROPE'], ['#D45E00', '#57B4E9'],
        line_size=3, legend_kwargs=dict(symbolStrokeWidth=4, symbolSize=300),
        hide_x_axis=True, fixed_y=True,
    )

    # Country-level meta analysis, drawn dashed.
    country_chart = meta_layer(
        lambda x: x.replace('META-', ''),
        'Country Level', country_domain, country_range,
        line_size=2, legend_kwargs=dict(symbolDash=[3, 3], symbolStrokeWidth=4, symbolSize=300),
        dashed=True,
    )

    # Site level: the opacity-0 line chart is only a template for the shared
    # encodings; the point re-mark below is what ends up in the figure.
    site_rows = d[~d.siteid.str.contains('META')].copy()
    site_rows['country'] = site_rows.siteid.apply(lambda x: site_to_contry[x])
    site_template = alt.Chart(site_rows).mark_line(size=2.5, opacity=0).encode(
        x=x_days(),
        y=y_value(fixed=True, clamp=True),
        color=alt.Color("siteid:N", title=None, scale=alt.Scale(domain=sites, range=site_colors), legend=None),
    ).properties(
        width=450,
        height=250,
    )
    site_chart = site_template.mark_point(filled=True, opacity=0.5, size=150).encode(
        color=alt.Color(
            "country:N", title='Site Level',
            scale=alt.Scale(domain=country_domain, range=country_range),
            legend=alt.Legend() if show_legend else None,
        ),
        size=alt.Size('N:Q', title='Sample Size', legend=alt.Legend(symbolFillColor='black') if show_legend else None),
    )

    # Combine: meta + continent on top, sites + countries below.
    chart = alt.vconcat(
        meta_chart + continent_chart,
        site_chart + country_chart,
        spacing=10,
    ).resolve_scale(color='independent', size='independent', shape='independent')

    return chart.properties(
        title={
            "text": [],
            "fontSize": 18,
            "dx": 30,
            "subtitleColor": "gray",
        }
    )

# Render the prediction.cov.tpr figure, apply the project theme, show it, and
# export it as a high-resolution PNG.
plot = apply_theme(
    plot_lab(df=df, is_country=True),
    **dict(
        axis_y_title_font_size=16,
        axis_label_font_size=14,
        header_label_font_size=16,
        subtitle_font_size=18,
        title_anchor='start',
        legend_orient='right',
        legend_title_orient='top',
    ),
)

plot.display()
save(plot, join("..", "result", "R1-prediction-cov-tpr.png"), scalefactor=8.0)

# Prediction PPV

In [None]:
# Load the prediction.cov.ppv summary and melt it to one row per (site, day).
df = pd.read_csv(
    join("..", "data", "Phase2.1SurvivalRSummariesPublic", "ToShare", "table.R1.prediction.cov.ppv.csv")
).melt(id_vars=['siteid'], var_name='day', value_name='value')

# Normalise site ids ("Eurpoe" typo, upper case, NORTH AMERICA -> USA) and
# strip the "day" prefix from the melted column names.
df['siteid'] = df['siteid'].map(
    lambda site: site.replace('Eurpoe', 'Europe').upper().replace('NORTH AMERICA', 'USA')
)
df['day'] = df['day'].map(lambda col: col.replace('day', ''))

unique_sites = df['siteid'].unique().tolist()
print(unique_sites)

# Attach each site's total sample size N (missing for sites not in the totals table).
sdf = pd.read_csv(join("..", "data", "Phase2.1SurvivalRSummariesPublic", "ToShare", "table.totalN.csv"))
dic = sdf.set_index('siteid')['N'].to_dict()
df['N'] = df['siteid'].map(lambda site: dic.get(site))

countries = ['META', 'USA', 'EUROPE', 'FRANCE', 'GERMANY']
country_colors = ['black', '#D45E00', '#57B4E9', '#0072B2', '#029F73']

df

In [None]:
def plot_lab(df=None, is_country=True):
    """Build the two-row PPV summary figure.

    Top row: the pooled meta-analysis series ("All Countries") layered with the
    continent-level meta series (NORTH AMERICA / EUROPE, y fixed to [0, 0.4]).
    Bottom row: one point per site (coloured by country, sized by ``N``) layered
    with dashed country-level meta series (FRANCE / GERMANY / USA).

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``siteid``, ``day``, ``value`` and ``N`` columns. Only rows whose
        ``day`` is one of the strings '3', '7', '14' are kept; those are turned
        into integers. Rows whose ``siteid`` contains 'META' are meta-analysis
        results; every other row is a site and must be a key of the global
        ``site_to_contry`` mapping.
    is_country : bool
        Unused; kept so existing calls keep working.

    Returns
    -------
    altair.VConcatChart
        The titled two-row chart.
    """
    show_legend = True
    panel_size = dict(width=450, height=250)
    day_codes = {'3': 3, '7': 7, '14': 14}

    # Keep only the three reported horizons and convert their labels to ints.
    days = df.copy()
    days = days[days.day.isin(list(day_codes))]
    days['day'] = days['day'].apply(day_codes.get)

    def day_x(title='Days Since Admission', **axis_kwargs):
        # Shared ordinal day axis; extra kwargs (e.g. labels=False) tweak the axis.
        # NOTE(review): `day` is numeric here while the domain lists strings;
        # this relies on the renderer matching 3 with '3' -- confirm the ticks render.
        return alt.X(
            "day:O",
            title=title,
            axis=alt.Axis(labelAngle=0, tickCount=10, **axis_kwargs),
            scale=alt.Scale(clamp=True, nice=False, padding=0.3, domain=['3', '7', '14']),
        )

    def value_y(**scale_kwargs):
        # Shared untitled value axis; extra kwargs (domain, clamp) go to the scale.
        return alt.Y("value:Q", title=None, scale=alt.Scale(zero=False, nice=False, padding=10, **scale_kwargs))

    def meta_rows(relabel):
        # META rows only, with siteid rewritten by `relabel`.
        rows = days[days.siteid.str.contains('META')].copy()
        rows['siteid'] = rows['siteid'].apply(relabel)
        return rows

    def with_points(line, color_scale):
        # Overlay filled markers on a line chart, colouring both with `color_scale`.
        dots = line.mark_point(filled=True, opacity=1, size=50).encode(
            color=alt.Color("siteid:N", scale=color_scale, legend=None)
        )
        return alt.layer(line, dots).resolve_scale(color='independent')

    # --- Meta analysis: the pooled 'META' row becomes 'All Countries'.
    meta_scale = alt.Scale(domain=['All Countries'], range=['black'])
    meta_line = alt.Chart(
        meta_rows(lambda s: s.replace('META-', '').replace('META', 'All Countries'))
    ).mark_line(size=3, opacity=1).encode(
        x=day_x(),
        y=value_y(),
        color=alt.Color("siteid:N", title='Meta Analysis', scale=meta_scale,
                        legend=alt.Legend() if show_legend else None),
    ).properties(**panel_size)
    meta_panel = with_points(meta_line, meta_scale)

    # --- Continent level: META-USA is shown as NORTH AMERICA.
    continent_scale = alt.Scale(domain=['NORTH AMERICA', 'EUROPE'], range=['#D45E00', '#57B4E9'])
    continent_line = alt.Chart(
        meta_rows(lambda s: s.replace('META-', '').replace('USA', 'NORTH AMERICA'))
    ).mark_line(size=3, opacity=1).encode(
        x=day_x(title=None, labels=False, domain=False),
        y=value_y(domain=[0, 0.4]),
        color=alt.Color("siteid:N", title='Continent Level', scale=continent_scale,
                        legend=alt.Legend(symbolStrokeWidth=4, symbolSize=300) if show_legend else None),
    ).properties(**panel_size)
    continent_panel = with_points(continent_line, continent_scale)

    # --- Country level: dashed lines.
    country_scale = alt.Scale(domain=['FRANCE', 'GERMANY', 'USA'], range=['#0072B2', '#029F73', '#D6641E'])
    country_line = alt.Chart(
        meta_rows(lambda s: s.replace('META-', ''))
    ).mark_line(size=2, opacity=1).encode(
        x=day_x(),
        y=value_y(),
        color=alt.Color("siteid:N", title='Country Level', scale=country_scale,
                        legend=alt.Legend(symbolDash=[3, 3], symbolStrokeWidth=4, symbolSize=300) if show_legend else None),
        strokeDash=alt.value([3, 3]),
    ).properties(**panel_size)
    country_panel = with_points(country_line, country_scale)

    # --- Site level: one point per site, coloured by country and sized by N.
    site_rows = days[~days.siteid.str.contains('META')].copy()
    site_rows['country'] = site_rows.siteid.apply(lambda s: site_to_contry[s])
    # Invisible (opacity 0) line that only serves as the template for the points.
    site_template = alt.Chart(site_rows).mark_line(size=2.5, opacity=0).encode(
        x=day_x(),
        y=value_y(domain=[0, 0.4], clamp=True),
        color=alt.Color("siteid:N", title=None, scale=alt.Scale(domain=sites, range=site_colors), legend=None),
    ).properties(**panel_size)
    site_panel = site_template.mark_point(filled=True, opacity=0.5, size=150).encode(
        color=alt.Color("country:N", title='Site Level',
                        scale=alt.Scale(domain=['FRANCE', 'GERMANY', 'USA'], range=['#0072B2', '#029F73', '#D6641E']),
                        legend=alt.Legend() if show_legend else None),
        size=alt.Size('N:Q', title='Sample Size',
                      legend=alt.Legend(symbolFillColor='black') if show_legend else None),
    )

    # --- Combine: meta + continent on top, sites + countries below.
    figure = alt.vconcat(
        meta_panel + continent_panel,
        site_panel + country_panel,
        spacing=10,
    ).resolve_scale(color='independent', size='independent', shape='independent')

    return figure.properties(
        title={"text": [], "fontSize": 18, "dx": 30, "subtitleColor": "gray"}
    )

# Render the PPV figure: build it with the `plot_lab` defined in this cell,
# apply the shared theme, show it inline, then export a high-resolution PNG.
plot = plot_lab(df=df, is_country=True)

plot = apply_theme(
    plot,
    axis_y_title_font_size=16,
    title_anchor='start',
    legend_orient='right',
    legend_title_orient='top',
    axis_label_font_size=14,
    header_label_font_size=16,
    subtitle_font_size=18,
)

# A bare `plot` is only rendered when it is the cell's last expression; the
# `save` call follows it, so display explicitly (as the TPR and cases cells do).
plot.display()
save(plot, join("..", "result", "R1-prediction-cov-ppv.png"), scalefactor=8.0)

# Prediction Cases

In [None]:
# Load the predicted-case table. Unlike the TPR/PPV tables it is not melted here;
# `plot_lab` below reads `siteid`, `day` (numeric) and `value` columns directly,
# so the file is presumably already in long format -- confirm if the schema changes.
df = pd.read_csv(join("..", "data", "Phase2.1SurvivalRSummariesPublic", "ToShare", "table.R1.prediction.cases.csv"))

# Normalise site labels: fix the 'Eurpoe' typo present in the source labels,
# upper-case everything, and report North America as USA.
df['siteid'] = df['siteid'].apply(
    lambda x: x.replace('Eurpoe', 'Europe').upper().replace('NORTH AMERICA', 'USA')
)

unique_sites = df.siteid.unique().tolist()
print(unique_sites)

# Per-site sample sizes; sites missing from the totals table get N = None.
sdf = pd.read_csv(join("..", "data", "Phase2.1SurvivalRSummariesPublic", "ToShare", "table.totalN.csv"))
dic = pd.Series(sdf.N.values, index=sdf.siteid).to_dict()
df['N'] = df['siteid'].apply(dic.get)

countries = ['META', 'USA', 'EUROPE', 'FRANCE', 'GERMANY']
country_colors = ['black', '#D45E00', '#57B4E9', '#0072B2', '#029F73']

# Rich display of the prepared frame (replaces a duplicate plain-text print).
df

In [None]:
def plot_lab(df=None, is_country=True):
    """Build the two-row summary figure for predicted case counts.

    Top row: the pooled meta-analysis series ("All Countries") layered with the
    continent-level meta series (NORTH AMERICA / EUROPE, y fixed to [0, 200]).
    Bottom row: one point per site (coloured by country, sized by ``N``) layered
    with dashed country-level meta series (FRANCE / GERMANY / USA).

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``siteid``, ``day``, ``value`` and ``N`` columns. Only rows whose
        ``day`` equals 3, 7 or 14 are kept. Rows whose ``siteid`` contains 'META'
        are meta-analysis results; every other row is a site and must be a key
        of the global ``site_to_contry`` mapping (a missing key raises KeyError).
    is_country : bool
        Unused; kept so existing calls keep working.

    Returns
    -------
    altair.VConcatChart
        The titled two-row chart.
    """
    show_legend = True
    panel_size = dict(width=450, height=250)

    # Keep only the three reported horizons, stored as plain ints.
    days = df[df.day.isin([3, 7, 14])].copy()
    days['day'] = days['day'].astype(int)

    def day_x(title='Days Since Admission', **axis_kwargs):
        # Shared ordinal day axis; extra kwargs (e.g. labels=False) tweak the axis.
        # NOTE(review): `day` is numeric while the domain lists strings; this relies
        # on the renderer matching 3 with '3' -- confirm the ticks render.
        return alt.X(
            "day:O",
            title=title,
            axis=alt.Axis(labelAngle=0, tickCount=10, **axis_kwargs),
            scale=alt.Scale(clamp=True, nice=False, padding=0.3, domain=['3', '7', '14']),
        )

    def value_y(**scale_kwargs):
        # Shared untitled value axis; extra kwargs (domain, clamp) go to the scale.
        return alt.Y("value:Q", title=None, scale=alt.Scale(zero=False, nice=False, padding=10, **scale_kwargs))

    def meta_rows(relabel):
        # META rows only, with siteid rewritten by `relabel`.
        rows = days[days.siteid.str.contains('META')].copy()
        rows['siteid'] = rows['siteid'].apply(relabel)
        return rows

    def with_points(line, color_scale):
        # Overlay filled markers on a line chart, colouring both with `color_scale`.
        dots = line.mark_point(filled=True, opacity=1, size=50).encode(
            color=alt.Color("siteid:N", scale=color_scale, legend=None)
        )
        return alt.layer(line, dots).resolve_scale(color='independent')

    # --- Meta analysis: the pooled 'META' row becomes 'All Countries'.
    meta_scale = alt.Scale(domain=['All Countries'], range=['black'])
    meta_line = alt.Chart(
        meta_rows(lambda s: s.replace('META-', '').replace('META', 'All Countries'))
    ).mark_line(size=3, opacity=1).encode(
        x=day_x(),
        y=value_y(),
        # Fix: with no colour channel, the line had nothing to group by and ran
        # through every META row (pooled, continent and country) as one zigzag
        # path. Grouping by siteid restores one line per series. The one-value
        # domain restricts colouring to the pooled series, as the point layer
        # already does. The legend stays suppressed, as in the previous version.
        color=alt.Color("siteid:N", title='Meta Analysis', scale=meta_scale, legend=None),
    ).properties(**panel_size)
    meta_panel = with_points(meta_line, meta_scale)

    # --- Continent level: META-USA is shown as NORTH AMERICA.
    continent_scale = alt.Scale(domain=['NORTH AMERICA', 'EUROPE'], range=['#D45E00', '#57B4E9'])
    continent_line = alt.Chart(
        meta_rows(lambda s: s.replace('META-', '').replace('USA', 'NORTH AMERICA'))
    ).mark_line(size=3, opacity=1).encode(
        x=day_x(title=None, labels=False, domain=False),
        y=value_y(domain=[0, 200]),
        color=alt.Color("siteid:N", title='Continent Level', scale=continent_scale,
                        legend=alt.Legend(symbolStrokeWidth=4, symbolSize=300) if show_legend else None),
    ).properties(**panel_size)
    continent_panel = with_points(continent_line, continent_scale)

    # --- Country level: dashed lines.
    country_scale = alt.Scale(domain=['FRANCE', 'GERMANY', 'USA'], range=['#0072B2', '#029F73', '#D6641E'])
    country_line = alt.Chart(
        meta_rows(lambda s: s.replace('META-', ''))
    ).mark_line(size=2, opacity=1).encode(
        x=day_x(),
        y=value_y(),
        color=alt.Color("siteid:N", title='Country Level', scale=country_scale,
                        legend=alt.Legend(symbolDash=[3, 3], symbolStrokeWidth=4, symbolSize=300) if show_legend else None),
        strokeDash=alt.value([3, 3]),
    ).properties(**panel_size)
    country_panel = with_points(country_line, country_scale)

    # --- Site level: one point per site, coloured by country and sized by N.
    site_rows = days[~days.siteid.str.contains('META')].copy()
    site_rows['country'] = site_rows.siteid.apply(lambda s: site_to_contry[s])
    # Invisible (opacity 0) line that only serves as the template for the points.
    site_template = alt.Chart(site_rows).mark_line(size=2.5, opacity=0).encode(
        x=day_x(),
        y=value_y(domain=[0, 200], clamp=True),
        color=alt.Color("siteid:N", title=None, scale=alt.Scale(domain=sites, range=site_colors), legend=None),
    ).properties(**panel_size)
    site_panel = site_template.mark_point(filled=True, opacity=0.5, size=150).encode(
        color=alt.Color("country:N", title='Site Level',
                        scale=alt.Scale(domain=['FRANCE', 'GERMANY', 'USA'], range=['#0072B2', '#029F73', '#D6641E']),
                        legend=alt.Legend() if show_legend else None),
        size=alt.Size('N:Q', title='Sample Size',
                      legend=alt.Legend(symbolFillColor='black') if show_legend else None),
    )

    # --- Combine: meta + continent on top, sites + countries below.
    figure = alt.vconcat(
        meta_panel + continent_panel,
        site_panel + country_panel,
        spacing=10,
    ).resolve_scale(color='independent', size='independent', shape='independent')

    return figure.properties(
        title={"text": [], "fontSize": 18, "dx": 30, "subtitleColor": "gray"}
    )

# Render the predicted-cases figure: build it with `plot_lab`, apply the shared
# theme, show it inline, then export a high-resolution PNG into ../result.
plot = apply_theme(
    plot_lab(df=df, is_country=True),
    title_anchor='start',
    legend_orient='right',
    legend_title_orient='top',
    axis_label_font_size=14,
    axis_y_title_font_size=16,
    header_label_font_size=16,
    subtitle_font_size=18,
)
plot.display()
save(plot, join("..", "result", "R1-prediction-cov-cases.png"), scalefactor=8.0)

In [None]:
# Inspect the predicted-cases frame with Jupyter's rich (HTML) display rather
# than a plain-text print.
df