In [None]:
%load_ext autoreload
%autoreload 2
%aimport utils_1_1

import datetime
import os
from os.path import join

import altair as alt
import dateutil.parser
import numpy as np
import pandas as pd
from altair_saver import save

from constants_1_1 import SITE_FILE_TYPES
from utils_1_1 import (
    get_site_file_paths,
    get_site_file_info,
    get_site_ids,
    get_visualization_subtitle,
    get_country_color_map,
)
from theme import apply_theme
from web import for_website

alt.data_transformers.disable_max_rows(); # Lift Altair's default 5,000-row cap on inline data; the trailing ';' suppresses this cell's output.

# AUC Matrix

In [None]:
# Load the R1 AUC transportability summary and give its columns short names.
df = pd.read_csv(join("..", "data", "Phase2.1SurvivalRSummariesPublic", "ToShare", "table.auc.port.R1.toShare.csv"))
df = df.rename(columns={"site.from": "from", 'site.to': 'to', 'N.from': 'from_size'})
print(df)

# Map real site ids to anonymised ("fake") ids.
idf = pd.read_csv(join('..', 'data', "SiteID_MAP_Phase1.1.csv"), sep=",", header=0)
mapping = dict(zip(idf.siteid, idf.fakeid))

# Aggregate targets are not in the site-id map; they keep their own names.
mapping.update({key: key for key in ('all', 'US', 'France', 'Germany', 'Spain')})

# Country of each anonymised site; aggregate targets are grouped under 'Meta'.
site_to_country = {
    'SITE1': 'France',
    'SITE2': 'France',
    'SITE3': 'France',
    'SITE4': 'Germany',
    'SITE5': 'USA',
    'SITE6': 'USA',
    'SITE7': 'USA',
    'SITE8': 'USA',
    'SITE9': 'USA',
    'SITE10': 'USA',
    'SITE11': 'USA',
    'SITE12': 'USA',
    'SITE13': 'USA',
    'SITE14': 'USA',
    'SITE15': 'USA',
    'SITE16': 'Spain',
    'all': 'Meta',
    'US': 'Meta',
    'France': 'Meta',
    'Germany': 'Meta',
    'Spain': 'Meta',
}
site_to_contry = site_to_country  # alias kept for the original (misspelled) name

# Fail fast and list every problem id, instead of a bare KeyError on the first one.
raw_ids = pd.concat([df['from'], df['to']])
unmapped_ids = sorted({str(site) for site in raw_ids if site not in mapping})
if unmapped_ids:
    raise KeyError(f"Site ids missing from SiteID_MAP_Phase1.1.csv: {unmapped_ids}")
unknown_fake_ids = sorted({str(mapping[site]) for site in raw_ids if mapping[site] not in site_to_country})
if unknown_fake_ids:
    raise KeyError(f"Fake ids missing from site_to_country: {unknown_fake_ids}")

from_fake = df['from'].map(mapping)
to_fake = df['to'].map(mapping)

df['country'] = from_fake.map(site_to_country)
df['to-country'] = to_fake.map(site_to_country)

# Axis labels: "SITEk →" for source rows and "→ SITEk" for target columns.
# Aggregate targets are renamed by exact match rather than str.replace, so an id
# that merely contains "all" or "US" is never rewritten.
TARGET_DISPLAY_NAMES = {'all': 'All', 'US': 'USA'}
df['from'] = from_fake.map(lambda fake_id: f"{fake_id} →")
df['to'] = to_fake.map(lambda fake_id: f"→ {TARGET_DISPLAY_NAMES.get(fake_id, fake_id)}")

# Colour domain/range for the per-country sample-size bars.
countries = ['USA', 'France', 'Germany', "Spain"]
country_colors = ['#D45E00', '#0072B2', '#029F73', '#B2AA2F']

unique_sites = df['from'].unique().tolist()
print(unique_sites)
print(len(unique_sites))

df

In [None]:
def plot(df=None):
    d = df.copy()
      
#     sort_from = ['APHP →', 'FRBDX →', 'UKFR →', 'BIDMC →', 'NWU →', 'MGB →', 'VA2 →', 'VA3 →',  'UPENN →',    'VA4 →', 'VA1 →', 'VA5 →', 'UPITT →','UMICH →','UCLA →']
#     sort_to = ['→ APHP', '→ FRBDX', '→ UKFR', '→ BIDMC', '→ NWU', '→ MGB', '→ VA2', '→ VA3', '→ UPENN',     '→ VA4', '→ VA1','→ VA5',  '→ UPITT','→ UMICH', '→ UCLA']

    sort_from = ['SITE1 →', 'SITE2 →', 'SITE3 →', 'SITE4 →', 'SITE5 →', 'SITE6 →', 'SITE7 →', 'SITE8 →',  'SITE9 →',    'SITE10 →', 'SITE11 →', 'SITE12 →', 'SITE13 →','SITE14 →','SITE15 →','SITE16 →']
    sort_to = ['→ SITE1', '→ SITE2', '→ SITE3', '→ SITE4', '→ SITE5', '→ SITE6', '→ SITE7', '→ SITE8', '→ SITE9',     '→ SITE10', '→ SITE11','→ SITE12',  '→ SITE13','→ SITE14', '→ SITE15', '→ SITE16']
    
    plot = alt.Chart(
        d[d['to-country'] != 'Meta']
    ).mark_square(
        opacity=1
    ).encode(
        x=alt.X("to:N", title='To', axis=alt.Axis(labelAngle=-55, domain=True, orient='top'), scale=alt.Scale(), sort=sort_to),
        y=alt.Y("from:N", title='From', axis=alt.Axis(labelAngle=0, domain=True), scale=alt.Scale(), sort=sort_from),
        size=alt.Size('from_size:Q', title='Sample Size', scale=alt.Scale(range=[100, 2000], type='log'), legend=alt.Legend(direction='horizontal', symbolFillColor='lightgray')),
#         size=alt.value(2400),
        color=alt.Color("auc:Q", title='AUC', scale=alt.Scale(scheme='redpurple', domain=[0.7, 0.9]), legend=alt.Legend(direction='horizontal', gradientLength=440, gradientThickness=30)),
    ).properties(
        width=750,
        height=750
    )
    
    meta = alt.Chart(
        d[d['to-country'] == 'Meta']
    ).mark_square(
        opacity=1
    ).encode(
        x=alt.X("to:N", title=None, axis=alt.Axis(labelAngle=-55, domain=True, orient='top'), scale=alt.Scale(), sort=sort_to),
        y=alt.Y("from:N", title=None, axis=alt.Axis(labelAngle=0, domain=False, labels=False), scale=alt.Scale(), sort=sort_from),
        size=alt.Size('from_size:Q', title='Sample Size', scale=alt.Scale(range=[100, 2000], type='log'), legend=alt.Legend(direction='horizontal', symbolFillColor='lightgray')),
#         size=alt.value(2400),
        color=alt.Color("auc:Q", title='AUC', scale=alt.Scale(scheme='redpurple', domain=[0.7, 0.9]), legend=alt.Legend(direction='horizontal', gradientLength=440, gradientThickness=30)),
    ).properties(
        width=200,
        height=750
    )
    
    n = alt.Chart(
        d
    ).mark_bar(
        opacity=1
    ).encode(
        x=alt.X("mean(from_size):Q", title='Sample Size', axis=alt.Axis(labelAngle=0, tickCount=2, domain=True, orient='top'), scale=alt.Scale(domain=[0, 20000])),
        y=alt.Y("from:N", title=None, axis=alt.Axis(labelAngle=0, labels=False, domain=True), scale=alt.Scale(), sort=sort_from),
        color=alt.Color("country:N", title='Country', scale=alt.Scale(range=country_colors, domain=countries), legend=alt.Legend(direction='vertical', gradientLength=440, gradientThickness=30)),
    ).properties(
        width=100,
        height=750
    )
    
    text = n.mark_text(
        angle=90,
        align='center',
        baseline='bottom',
        dy=-6,
    ).encode(
        text='from_size'
    )

    
    plot = alt.hconcat(alt.hconcat(plot, meta).resolve_scale(color='shared'), (n + text), spacing=0).properties(
#         title={
#             "text": [
#                 f"Transportability Of Cox Regression Model Across Different Sites",
#             ],
#             "dx": 50,
#             "fontSize": 18
#         }
    )
    
    return plot

# Build the AUC matrix figure and apply the shared project theme.
res = apply_theme(
    plot(df),
    axis_y_title_font_size=16,
    title_anchor='start',
    legend_orient='bottom',
    legend_title_orient='top',
    axis_label_font_size=14,
    header_label_font_size=16,
    subtitle_font_size=18,
)

res.display()

# Create the output directory up front so the export does not fail on a fresh
# checkout. NOTE(review): assumes altair_saver does not create it itself.
out_path = join("..", "result", "R1-auc-matrix.png")
os.makedirs(os.path.dirname(out_path), exist_ok=True)
save(res, out_path, scalefactor=8.0)

