# In [19]:
import pandas as pd
import numpy as np
from bokeh.io import output_notebook, show
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, FactorRange, HoverTool, ColorBar, LinearColorMapper, BasicTicker
from bokeh.transform import dodge, transform
from bokeh.layouts import gridplot, column, row
from bokeh.palettes import RdYlGn11 as Palette
from bokeh.models import LabelSet
from bokeh.palettes import RdYlGn11

# Initialize Bokeh: direct show() output into the Jupyter notebook.
output_notebook()

# ==========================================
# 1. DATA PREPARATION
# ==========================================
# Each dict below maps "<category>_<metric>" -> score, where <metric> is one of
# image_roc_auc, pixel_roc_auc, image_ap, pixel_ap (the suffixes parse_metrics
# recognises).

# 1. Federated Stage 1 (New Data)
# NOTE: keys here are NOT listed in the same per-category metric order as the
# two dicts below, so values must be matched by key, never by position.
# NOTE(review): every ROC-AUC value in this dict is below 0.5 (worse than
# chance) -- confirm these are the intended numbers.
stage1_results = {
    'tank_screw_image_roc_auc': 0.1528, 'engine_wiring_image_ap': 0.1802, 
    'engine_wiring_pixel_roc_auc': 0.2323, 'tank_screw_pixel_roc_auc': 0.3018, 
    'pipe_staple_image_ap': 0.1499, 'underbody_screw_pixel_roc_auc': 0.2943, 
    'pipe_clip_pixel_ap': 0.0029, 'underbody_screw_image_roc_auc': 0.1042, 
    'underbody_screw_image_ap': 0.0130, 'underbody_screw_pixel_ap': 0.0002, 
    'engine_wiring_pixel_ap': 0.0019, 'pipe_staple_pixel_roc_auc': 0.1838, 
    'engine_wiring_image_roc_auc': 0.1679, 'pipe_clip_image_roc_auc': 0.1838, 
    'pipe_staple_pixel_ap': 0.0045, 'tank_screw_image_ap': 0.0691, 
    'underbody_pipes_pixel_ap': 0.0114, 'underbody_pipes_pixel_roc_auc': 0.2277, 
    'pipe_clip_pixel_roc_auc': 0.2194, 'pipe_clip_image_ap': 0.1538, 
    'underbody_pipes_image_ap': 0.1839, 'pipe_staple_image_roc_auc': 0.1937, 
    'underbody_pipes_image_roc_auc': 0.1793, 'tank_screw_pixel_ap': 0.0032
}

# 2. Federated Final (Current Results from previous context)
current_results = {
    "engine_wiring_image_roc_auc": 0.632, "engine_wiring_pixel_roc_auc": 0.864,
    "engine_wiring_image_ap": 0.670,      "engine_wiring_pixel_ap": 0.015,
    "tank_screw_image_roc_auc": 0.480,    "tank_screw_pixel_roc_auc": 0.823,
    "tank_screw_image_ap": 0.205,         "tank_screw_pixel_ap": 0.004,
    "pipe_clip_image_roc_auc": 0.565,     "pipe_clip_pixel_roc_auc": 0.802,
    "pipe_clip_image_ap": 0.456,          "pipe_clip_pixel_ap": 0.015,
    "underbody_pipes_image_roc_auc": 0.972, "underbody_pipes_pixel_roc_auc": 0.878,
    "underbody_pipes_image_ap": 0.979,      "underbody_pipes_pixel_ap": 0.244,
    "underbody_screw_image_roc_auc": 0.642, "underbody_screw_pixel_roc_auc": 0.988,
    "underbody_screw_image_ap": 0.068,      "underbody_screw_pixel_ap": 0.012,
    # Adding missing pipe_staple keys to current_results for parity
    # (unrounded, unlike the 3-decimal values above).
    "pipe_staple_image_roc_auc": 0.5260047281323876, "pipe_staple_pixel_roc_auc": 0.8037051468561723,
    "pipe_staple_image_ap": 0.48773680025241245, "pipe_staple_pixel_ap": 0.034420750483208036
}

# 3. Centralized (Reference Data from previous context)
# Baseline against which the federated runs are compared (heatmap, bar charts).
reference_results = {
    "engine_wiring_image_roc_auc": 0.5478, "engine_wiring_pixel_roc_auc": 0.8568,
    "engine_wiring_image_ap": 0.5961,      "engine_wiring_pixel_ap": 0.0140,
    "pipe_clip_image_roc_auc": 0.5111,     "pipe_clip_pixel_roc_auc": 0.8084,
    "pipe_clip_image_ap": 0.4253,          "pipe_clip_pixel_ap": 0.0156,
    "pipe_staple_image_roc_auc": 0.5869,   "pipe_staple_pixel_roc_auc": 0.7934,
    "pipe_staple_image_ap": 0.4650,        "pipe_staple_pixel_ap": 0.0329,
    "tank_screw_image_roc_auc": 0.4728,    "tank_screw_pixel_roc_auc": 0.8171,
    "tank_screw_image_ap": 0.2163,         "tank_screw_pixel_ap": 0.0035,
    "underbody_pipes_image_roc_auc": 0.8291, "underbody_pipes_pixel_roc_auc": 0.8773,
    "underbody_pipes_image_ap": 0.7432,      "underbody_pipes_pixel_ap": 0.2437,
    "underbody_screw_image_roc_auc": 0.6051, "underbody_screw_pixel_roc_auc": 0.9881,
    "underbody_screw_image_ap": 0.0579,      "underbody_screw_pixel_ap": 0.0142
}

def parse_metrics(data_dict, label):
    """Convert a flat ``{"<category>_<metric>": value}`` dict into a long-format DataFrame.

    Keys are matched against the four known metric suffixes
    (``image_roc_auc``, ``pixel_roc_auc``, ``image_ap``, ``pixel_ap``); keys
    without one of these suffixes are ignored.

    Args:
        data_dict: Mapping such as ``{"engine_wiring_image_ap": 0.67, ...}``.
        label: Run label written to the ``Type`` column (e.g. ``"Centralized"``).

    Returns:
        DataFrame with columns ``Category`` (underscores -> spaces, title-cased),
        ``Metric`` (underscores -> spaces, upper-cased, e.g. ``"IMAGE ROC AUC"``),
        ``Value`` and ``Type``; one row per recognised key, in ``data_dict``
        iteration order. The columns exist even when no key is recognised, so
        downstream column access does not raise ``KeyError``.
    """
    metrics = ["image_roc_auc", "pixel_roc_auc", "image_ap", "pixel_ap"]
    rows = []
    for key, value in data_dict.items():
        for metric in metrics:
            suffix = f"_{metric}"
            if not key.endswith(suffix):
                continue
            # Strip only the trailing suffix; str.replace would also remove an
            # identical substring occurring earlier in the key.
            category = key[: -len(suffix)]
            rows.append({
                "Category": category.replace("_", " ").title(),
                "Metric": metric.replace("_", " ").upper(),
                "Value": value,
                "Type": label,
            })
            break
    return pd.DataFrame(rows, columns=["Category", "Metric", "Value", "Type"])

# One long-format frame per run; the label becomes that frame's 'Type' column.
df_ref, df_st1, df_cur = (
    parse_metrics(results, label)
    for results, label in (
        (reference_results, "Centralized"),
        (stage1_results, "Fed Stage 1"),
        (current_results, "Fed Final"),
    )
)
# All three runs stacked together (each frame keeps its own row index).
df_main = pd.concat([df_ref, df_st1, df_cur])

# ==========================================
# 2. PLOT TYPE 1: PER-CATEGORY COMPARISON
# ==========================================

def create_category_plots(df):
    """Build one grouped bar chart per category comparing the three runs.

    Each chart has the four metrics on the x axis and, for every metric, three
    side-by-side bars (Centralized, Fed Stage 1, Fed Final), each annotated with
    its value to two decimals. A metric missing for a run is left blank.

    Args:
        df: Long-format frame with columns ``Category``, ``Metric``, ``Value``
            and ``Type`` (``Type`` in {"Centralized", "Fed Stage 1", "Fed Final"}).

    Returns:
        List of Bokeh figures, one per category, in alphabetical category order.
    """
    plots = []
    categories = sorted(df['Category'].unique())
    # You can customize metric order here:
    metrics = ["IMAGE ROC AUC", "PIXEL ROC AUC", "IMAGE AP", "PIXEL AP"]

    # --- POSITION / STYLE CONFIGURATION ---
    w = 0.2
    # (Type value in df / legend label, CDS column prefix, x dodge, colour),
    # listed left -> right: Centralized, Stage 1, Final.
    series = [
        ("Centralized", "centralized", -w - 0.05, "#5CB879"),
        ("Fed Stage 1", "stage1", 0, "#F4A582"),
        ("Fed Final", "final", w + 0.05, "#D9534F"),
    ]

    for cat in categories:
        cat_data = df[df['Category'] == cat]

        data = {'metrics': metrics}
        for run_type, col, _offset, _color in series:
            # Look values up by metric name and align them to the x-axis order.
            # Taking `.values` positionally would follow each source dict's key
            # order, which differs between runs (stage1_results is shuffled),
            # placing bars under the wrong metric label.
            values = (
                cat_data[cat_data['Type'] == run_type]
                .set_index('Metric')['Value']
                .reindex(metrics)
            )
            data[col] = values.to_numpy()
            data[f"{col}_txt"] = ["" if pd.isna(x) else f"{x:.2f}" for x in values]
        source = ColumnDataSource(data=data)

        p = figure(x_range=metrics, height=350, width=450, title=f"{cat}",
                   toolbar_location=None, tools="hover")

        # --- BARS --- (all bars first, then labels, so labels draw on top)
        for run_type, col, offset, color in series:
            p.vbar(x=dodge('metrics', offset, range=p.x_range), top=col, width=w,
                   source=source, color=color, legend_label=run_type)

        # --- LABELS ---
        for _run_type, col, offset, _color in series:
            p.add_layout(LabelSet(x=dodge('metrics', offset, range=p.x_range), y=col,
                                  text=f"{col}_txt", level='glyph', x_offset=0, y_offset=2,
                                  source=source, text_align='center', text_font_size='7pt'))

        # Styling
        p.y_range.start = 0
        p.y_range.end = 1.3
        p.legend.location = "top_right"
        p.legend.orientation = "horizontal"
        p.legend.label_text_font_size = "7pt"
        p.xaxis.major_label_orientation = 0.2

        plots.append(p)

    return plots

# ==========================================
# 3. PLOT TYPE 2: HEATMAP (Final vs Centralized)
# ==========================================
def create_heatmap_final_v2(df_ref, df_curr):
    """Heatmap of the relative change of each federated metric vs. the centralized baseline.

    Each cell shows ``(Fed - Centralized) / Centralized * 100`` for one
    (Category, Metric) pair, coloured on a reversed RdYlGn11 palette mapped
    over [-30, +30]. A pair whose centralized value is 0 has no defined
    relative change: it is labelled ``n/a`` and its value is NaN (drawn with
    the mapper's NaN colour) instead of producing ``inf``.

    Args:
        df_ref: Long-format centralized results (``Category``, ``Metric``, ``Value``).
        df_curr: Long-format federated results with the same columns.

    Returns:
        A Bokeh figure.
    """
    # 1. Merge (inner join: only pairs present in both frames are shown)
    df_merge = pd.merge(df_ref, df_curr, on=['Category', 'Metric'], suffixes=('_Ref', '_Fed'))

    # 2. Calculate Percentage Change
    # Formula: (Fed - Cent) / Cent * 100
    # A zero reference would give +/-inf (or NaN for 0/0); mask it to NaN instead.
    ref = df_merge['Value_Ref'].where(df_merge['Value_Ref'] != 0)
    df_merge['Change'] = ((df_merge['Value_Fed'] - ref) / ref) * 100

    # Format text (signed, one decimal; "n/a" where the change is undefined)
    df_merge['Change_txt'] = df_merge['Change'].map(
        lambda x: "n/a" if pd.isna(x) else f"{x:+.1f}"
    )

    source = ColumnDataSource(df_merge)

    # 3. DEFINE PALETTE (REVERSED)
    # We flip the palette using [::-1].
    # Now: Low Values (Negative/Red) -> High Values (Positive/Green)
    palette = RdYlGn11[::-1]

    # 4. CONFIGURE MAPPER
    # low=-30 (Dark Red), high=30 (Dark Green)
    mapper = LinearColorMapper(palette=palette, low=-30, high=30)

    p = figure(title="Relative Performance Change (%) - Fed Final vs Centralized",
               x_range=sorted(df_merge['Metric'].unique()),
               y_range=sorted(df_merge['Category'].unique(), reverse=True),
               height=400, width=550,
               toolbar_location=None, tools="hover")

    p.rect(x="Metric", y="Category", width=1, height=1, source=source,
           line_color='white', fill_color=transform('Change', mapper))

    # Add Color Bar
    color_bar = ColorBar(color_mapper=mapper, location=(0, 0),
                         ticker=BasicTicker(desired_num_ticks=10))
    p.add_layout(color_bar, 'right')

    # Add Text Labels
    p.text(x="Metric", y="Category", text="Change_txt", source=source,
           text_align="center", text_baseline="middle", text_color="black",
           text_font_size="10pt")

    # Add Hover Tool (the figure was created with tools="hover")
    hover = p.select(dict(type=HoverTool))
    hover.tooltips = [("Category", "@Category"), ("Metric", "@Metric"),
                      ("Change", "@Change_txt%")]

    return p
    
# ==========================================
# 4. PLOT TYPE 3: SUMMARY (Ordered)
# ==========================================

def create_summary_plot(df):
    """Grouped bar chart of each run's mean metric value across all categories.

    Args:
        df: Long-format frame with columns ``Metric``, ``Type`` and ``Value``
            (``Type`` in {"Centralized", "Fed Stage 1", "Fed Final"}).

    Returns:
        A Bokeh figure with one group of three bars per metric. A run with no
        rows yields NaN means (no bar) rather than a length mismatch.
    """
    metrics = ["IMAGE ROC AUC", "PIXEL ROC AUC", "IMAGE AP", "PIXEL AP"]
    runs = ["Centralized", "Fed Stage 1", "Fed Final"]

    # Mean per (Metric, Type), reshaped to rows=metric / columns=run and
    # reindexed onto the x-axis order. groupby sorts metrics alphabetically
    # (IMAGE AP, IMAGE ROC AUC, ...), so reading its values positionally would
    # put every average under the wrong label.
    avg = (
        df.groupby(['Metric', 'Type'])['Value'].mean()
        .unstack('Type')
        .reindex(index=metrics, columns=runs)
    )

    source = ColumnDataSource(data={
        'metrics': metrics,
        'centralized': avg['Centralized'].to_numpy(),
        'stage1': avg['Fed Stage 1'].to_numpy(),
        'final': avg['Fed Final'].to_numpy(),
    })

    p = figure(x_range=metrics, height=400, width=600, title="Average Performance Progression",
               toolbar_location=None, tools="hover")

    w = 0.2
    # Same ordering as the per-category plots: Centralized -> Stage 1 -> Final
    p.vbar(x=dodge('metrics', -w-0.05, range=p.x_range), top='centralized', width=w, source=source,
           color="#5CB879", legend_label="Centralized")

    p.vbar(x=dodge('metrics', 0, range=p.x_range), top='stage1', width=w, source=source,
           color="#F4A582", legend_label="Fed Stage 1")

    p.vbar(x=dodge('metrics', w+0.05, range=p.x_range), top='final', width=w, source=source,
           color="#D9534F", legend_label="Fed Final")

    p.y_range.start = 0
    p.y_range.end = 1.1
    p.legend.location = "top_right"
    p.legend.orientation = "horizontal"

    return p

# ==========================================
# 4b. METRIC-SPECIFIC COMPARISON CHARTS (2x2 GRID LAYOUT)
# ==========================================
# Image- and pixel-level ROC-AUC bar charts (Centralized vs Fed Final).
# An earlier version built these as a single combined row. They are now
# returned as two separate figures, so the execution section below can place
# them in a clean 2x2 grid together with the heatmap and the summary chart,
# keeping everything aligned.

def create_individual_metric_charts(df_ref, df_curr):
    """Build the two Centralized-vs-Federated bar charts (image and pixel ROC-AUC).

    Args:
        df_ref: Long-format centralized results (``Category``, ``Metric``, ``Value``).
        df_curr: Long-format federated results. Rows with ``Type == 'Fed Final'``
            are used, falling back to ``Type == 'Federated'`` when there are none.

    Returns:
        Tuple ``(p1, p2)`` of Bokeh figures for IMAGE ROC AUC and PIXEL ROC AUC.
        A value missing on either side is drawn as no bar with an empty label.
    """
    def _label(value):
        # Two-decimal label; blank for a missing (NaN) value.
        return "" if pd.isna(value) else f"{value:.2f}"

    def get_source_data(metric_name):
        # Per-category centralized/federated values for one metric.
        ref = df_ref[df_ref['Metric'] == metric_name].set_index('Category')['Value']
        fed = df_curr[df_curr['Type'] == 'Fed Final']
        if fed.empty:
            fed = df_curr[df_curr['Type'] == 'Federated']
        fed = fed[fed['Metric'] == metric_name].set_index('Category')['Value']

        # Union, so a category present on only one side is still shown.
        cats = sorted(set(ref.index) | set(fed.index))
        # NaN (not 0) for missing values, so no misleading "0.00" bar is drawn.
        cent_vals = [ref.get(c, float("nan")) for c in cats]
        fed_vals = [fed.get(c, float("nan")) for c in cats]
        data = {
            'categories': cats,
            'centralized': cent_vals,
            'federated': fed_vals,
            'centralized_txt': [_label(v) for v in cent_vals],
            'federated_txt': [_label(v) for v in fed_vals],
        }
        return ColumnDataSource(data), cats

    def build_chart(metric_name, title, cent_color, fed_color):
        # Grouped bars per category: centralized on the left, federated on the right.
        source, cats = get_source_data(metric_name)
        p = figure(x_range=cats, height=350, width=500, title=title,
                   toolbar_location=None, tools="hover")
        bars = (
            ('centralized', -0.15, cent_color, "Centralized"),
            ('federated', 0.15, fed_color, "Federated"),
        )
        for col, offset, color, legend in bars:
            p.vbar(x=dodge('categories', offset, range=p.x_range), top=col, width=0.25,
                   source=source, color=color, legend_label=legend)
        for col, offset, _color, _legend in bars:
            p.add_layout(LabelSet(x=dodge('categories', offset, range=p.x_range), y=col,
                                  text=f"{col}_txt", level='glyph', x_offset=0, y_offset=2,
                                  source=source, text_align='center', text_font_size='8pt'))
        # Start at 0 like the other bar charts; headroom above 1 for the labels.
        p.y_range.start = 0
        p.y_range.end = 1.15
        p.legend.location = "top_right"
        return p

    p1 = build_chart("IMAGE ROC AUC", "Image-Level ROC-AUC Comparison", "#50C878", "#F05F5F")
    p2 = build_chart("PIXEL ROC AUC", "Pixel-Level ROC-AUC Comparison", "#5DA5DA", "#FAA43A")
    return p1, p2

# ==========================================
# 5. EXECUTE
# ==========================================
# Build every figure.
cat_plots = create_category_plots(df_main)                      # per-category detail charts
p_img, p_pix = create_individual_metric_charts(df_ref, df_cur)  # image / pixel ROC-AUC charts
heatmap_plot = create_heatmap_final_v2(df_ref, df_cur)          # relative-change heatmap
summary_plot = create_summary_plot(df_main)                     # averages across categories

# The four figures of the bottom grid share one width so their columns line up.
for fig in (p_img, p_pix, heatmap_plot, summary_plot):
    fig.width = 500

# Upper section: the per-category charts, three per row.
grid_top = gridplot(cat_plots, ncols=3)

# Lower section, a 2x2 matrix:
#   [ Image ROC | Pixel ROC ]
#   [ Heatmap   | Summary   ]
grid_bottom = gridplot([[p_img, p_pix], [heatmap_plot, summary_plot]])

# Stack both sections vertically and render.
final_layout = column(grid_top, grid_bottom)
show(final_layout)