# In [4]:
import pandas as pd
import numpy as np
from bokeh.io import output_notebook, show
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource, FactorRange, HoverTool, ColorBar, LinearColorMapper, BasicTicker
from bokeh.transform import dodge, transform
from bokeh.layouts import gridplot, column, row
from bokeh.palettes import RdYlGn11 as Palette
from bokeh.models import LabelSet

# Initialize Bokeh: direct subsequent show() calls to render inline in the
# notebook output cell.
output_notebook()

# ==========================================
# 1. DATA PREPARATION
# ==========================================

# 1. Federated Stage 1 (New Data)
# Flat "<category>_<metric>" keys; the key order differs per category, so
# consumers must align values by metric name, not by position.
stage1_results = {
    'tank_screw_image_roc_auc': 0.1528, 'engine_wiring_image_ap': 0.1802, 
    'engine_wiring_pixel_roc_auc': 0.2323, 'tank_screw_pixel_roc_auc': 0.3018, 
    'pipe_staple_image_ap': 0.1499, 'underbody_screw_pixel_roc_auc': 0.2943, 
    'pipe_clip_pixel_ap': 0.0029, 'underbody_screw_image_roc_auc': 0.1042, 
    'underbody_screw_image_ap': 0.0130, 'underbody_screw_pixel_ap': 0.0002, 
    'engine_wiring_pixel_ap': 0.0019, 'pipe_staple_pixel_roc_auc': 0.1838, 
    'engine_wiring_image_roc_auc': 0.1679, 'pipe_clip_image_roc_auc': 0.1838, 
    'pipe_staple_pixel_ap': 0.0045, 'tank_screw_image_ap': 0.0691, 
    'underbody_pipes_pixel_ap': 0.0114, 'underbody_pipes_pixel_roc_auc': 0.2277, 
    'pipe_clip_pixel_roc_auc': 0.2194, 'pipe_clip_image_ap': 0.1538, 
    'underbody_pipes_image_ap': 0.1839, 'pipe_staple_image_roc_auc': 0.1937, 
    'underbody_pipes_image_roc_auc': 0.1793, 'tank_screw_pixel_ap': 0.0032
}

# 2. Federated Final (Current Results from previous context)
current_results = {
    "engine_wiring_image_roc_auc": 0.632, "engine_wiring_pixel_roc_auc": 0.864,
    "engine_wiring_image_ap": 0.670,      "engine_wiring_pixel_ap": 0.015,
    "tank_screw_image_roc_auc": 0.480,    "tank_screw_pixel_roc_auc": 0.823,
    "tank_screw_image_ap": 0.205,         "tank_screw_pixel_ap": 0.004,
    "pipe_clip_image_roc_auc": 0.565,     "pipe_clip_pixel_roc_auc": 0.802,
    "pipe_clip_image_ap": 0.456,          "pipe_clip_pixel_ap": 0.015,
    "underbody_pipes_image_roc_auc": 0.972, "underbody_pipes_pixel_roc_auc": 0.878,
    "underbody_pipes_image_ap": 0.979,      "underbody_pipes_pixel_ap": 0.244,
    "underbody_screw_image_roc_auc": 0.642, "underbody_screw_pixel_roc_auc": 0.988,
    "underbody_screw_image_ap": 0.068,      "underbody_screw_pixel_ap": 0.012,
    # pipe_staple has no Fed Final result. The keys exist for parity with the
    # other runs, but the values are NaN (not 0): a 0 would be plotted as a
    # real score, show up as a 100% gap and pull the averages down, whereas
    # pandas' mean skips NaN and the plots show it as missing.
    "pipe_staple_image_roc_auc": np.nan, "pipe_staple_pixel_roc_auc": np.nan,
    "pipe_staple_image_ap": np.nan,      "pipe_staple_pixel_ap": np.nan
}

# 3. Centralized (Reference Data from previous context)
reference_results = {
    "engine_wiring_image_roc_auc": 0.5478, "engine_wiring_pixel_roc_auc": 0.8568,
    "engine_wiring_image_ap": 0.5961,      "engine_wiring_pixel_ap": 0.0140,
    "pipe_clip_image_roc_auc": 0.5111,     "pipe_clip_pixel_roc_auc": 0.8084,
    "pipe_clip_image_ap": 0.4253,          "pipe_clip_pixel_ap": 0.0156,
    "pipe_staple_image_roc_auc": 0.5869,   "pipe_staple_pixel_roc_auc": 0.7934,
    "pipe_staple_image_ap": 0.4650,        "pipe_staple_pixel_ap": 0.0329,
    "tank_screw_image_roc_auc": 0.4728,    "tank_screw_pixel_roc_auc": 0.8171,
    "tank_screw_image_ap": 0.2163,         "tank_screw_pixel_ap": 0.0035,
    "underbody_pipes_image_roc_auc": 0.8291, "underbody_pipes_pixel_roc_auc": 0.8773,
    "underbody_pipes_image_ap": 0.7432,      "underbody_pipes_pixel_ap": 0.2437,
    "underbody_screw_image_roc_auc": 0.6051, "underbody_screw_pixel_roc_auc": 0.9881,
    "underbody_screw_image_ap": 0.0579,      "underbody_screw_pixel_ap": 0.0142
}

def parse_metrics(data_dict, label):
    """Parse flat ``<category>_<metric>`` keys into a long-format DataFrame.

    Recognised metric suffixes are ``_image_roc_auc``, ``_pixel_roc_auc``,
    ``_image_ap`` and ``_pixel_ap``; everything before the suffix is the
    category. Keys matching none of them (e.g. ``loss``) are skipped.

    Parameters
    ----------
    data_dict : dict
        Mapping of flat metric keys to their values.
    label : str
        Stored in the ``Type`` column of every row to identify the run.

    Returns
    -------
    pandas.DataFrame
        One row per recognised key, in input order, with columns
        ``Category`` (underscores -> spaces, title-cased), ``Metric``
        (underscores -> spaces, upper-cased), ``Value`` and ``Type``.
        The columns are present even when no key is recognised, so callers
        can always filter on them.
    """
    metrics = ["image_roc_auc", "pixel_roc_auc", "image_ap", "pixel_ap"]
    rows = []

    for key, value in data_dict.items():
        for metric in metrics:
            suffix = f"_{metric}"
            if key.endswith(suffix):
                # Strip only the trailing suffix; str.replace would also
                # remove any earlier occurrence inside the category name.
                category = key[: -len(suffix)]
                rows.append({
                    "Category": category.replace("_", " ").title(),
                    "Metric": metric.replace("_", " ").upper(),
                    "Value": value,
                    "Type": label,
                })
                break

    return pd.DataFrame(rows, columns=["Category", "Metric", "Value", "Type"])

# Create and Combine DataFrames: one long-format frame per run, stacked into a
# single frame. ignore_index gives df_main a fresh 0..n-1 index instead of
# repeating each part's row labels.
df_ref = parse_metrics(reference_results, "Centralized")
df_st1 = parse_metrics(stage1_results, "Fed Stage 1")
df_cur = parse_metrics(current_results, "Fed Final")

df_main = pd.concat([df_ref, df_st1, df_cur], ignore_index=True)

# ==========================================
# 2. PLOT TYPE 1: PER-CATEGORY COMPARISON (3 BARS)
# ==========================================

def _metric_table(cat_data, metrics, type_names):
    """Return a Metric x Type table of values for one category.

    Rows follow ``metrics`` and columns follow ``type_names``; any
    (metric, type) pair absent from ``cat_data`` becomes NaN.
    """
    return (cat_data.groupby(['Metric', 'Type'])['Value'].mean()
            .unstack('Type')
            .reindex(index=metrics, columns=type_names))


def _format_label(value):
    """Format a bar value for its on-chart label (blank when missing)."""
    return "" if pd.isna(value) else f"{value:.2f}"


def create_category_plots(df):
    """Build one grouped bar chart per category, comparing the three runs.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format frame with ``Category``, ``Metric``, ``Value`` and
        ``Type`` columns. The ``Type`` values plotted are "Fed Stage 1",
        "Fed Final" and "Centralized"; missing values are drawn as gaps.

    Returns
    -------
    list
        One Bokeh figure per category, in sorted category order. Each figure
        has bars labelled with their value (2 decimals).
    """
    plots = []
    categories = sorted(df['Category'].unique())
    metrics = df['Metric'].unique().tolist()

    w = 0.2
    # (source column, Type value / legend label, dodge offset, colour):
    # Stage 1 on the left, Final in the centre, Centralized on the right.
    series = [
        ('stage1', 'Fed Stage 1', -w - 0.05, "#F4A582"),
        ('final', 'Fed Final', 0, "#D9534F"),
        ('centralized', 'Centralized', w + 0.05, "#5CB879"),
    ]
    type_names = [type_name for _, type_name, _, _ in series]
    tooltips = [("Metric", "@metrics")] + [
        (type_name, f"@{field}{{0.000}}") for field, type_name, _, _ in series
    ]

    for cat in categories:
        # Align every run's values to the shared x-axis metric order. Using
        # raw row order would mislabel bars, because the input dicts list a
        # category's metrics in different orders.
        table = _metric_table(df[df['Category'] == cat], metrics, type_names)

        data = {'metrics': metrics}
        for field, type_name, _, _ in series:
            values = table[type_name].to_numpy()
            data[field] = values
            data[f'{field}_txt'] = [_format_label(v) for v in values]
        source = ColumnDataSource(data=data)

        p = figure(x_range=metrics, height=350, width=450, title=str(cat),
                   toolbar_location=None, tools="hover", tooltips=tooltips)

        for field, type_name, offset, color in series:
            p.vbar(x=dodge('metrics', offset, range=p.x_range), top=field, width=w,
                   source=source, color=color, legend_label=type_name)
            # Same dodge offset as the bar, so the label sits centred above it.
            p.add_layout(LabelSet(x=dodge('metrics', offset, range=p.x_range), y=field,
                                  text=f'{field}_txt', level='glyph',
                                  x_offset=0, y_offset=2, source=source,
                                  text_align='center', text_font_size='7pt'))

        # Styling
        p.y_range.start = 0
        p.y_range.end = 1.3  # Headroom above 1.0 for the value labels and legend
        p.legend.location = "top_left"
        p.legend.orientation = "horizontal"
        p.legend.label_text_font_size = "7pt"
        p.xaxis.major_label_orientation = 0.2

        plots.append(p)

    return plots

# One grouped-bar figure per category, built from the combined frame of all three runs.
cat_plots = create_category_plots(df_main)

# ==========================================
# 3. PLOT TYPE 2: DEGRADATION HEATMAP
# ==========================================
# Compares "Fed Final" against "Centralized": the gap federated training aims to close.

def create_heatmap(df_ref, df_curr, max_gap=50):
    """Plot the relative gap (%) between a reference run and a federated run.

    Each cell shows ``(ref - fed) / ref * 100`` for one (Category, Metric)
    pair. Positive values mean the federated run scores lower than the
    reference. Negative values mean it scores higher.

    Parameters
    ----------
    df_ref, df_curr : pandas.DataFrame
        Long-format frames with ``Category``, ``Metric`` and ``Value``
        columns. Only pairs present in both frames are shown (inner join).
    max_gap : float, optional
        Half-width of the colour scale in percent. Gaps beyond
        ``-max_gap``/``+max_gap`` saturate at the end colours.

    Returns
    -------
    bokeh.plotting.figure
        The heatmap, with a colour bar and per-cell hover tooltips.
    """
    df_merge = pd.merge(df_ref, df_curr, on=['Category', 'Metric'], suffixes=('_Ref', '_Fed'))
    # A zero reference has no meaningful relative gap: mask it to NaN rather
    # than producing +/-inf.
    ref = df_merge['Value_Ref'].where(df_merge['Value_Ref'] != 0)
    df_merge['Degradation'] = (ref - df_merge['Value_Fed']) / ref * 100
    # Pre-format the cell labels to one decimal; missing cells get no label.
    df_merge['Degradation_txt'] = [
        "" if pd.isna(v) else f"{v:.1f}" for v in df_merge['Degradation']
    ]

    source = ColumnDataSource(df_merge)
    # Bokeh orders RdYlGn11 from dark green to dark red. On a range centred at
    # 0, this colours improvements (negative gap) green and degradations
    # (positive gap) red. Reversing the palette on a 0..max range, as before,
    # coloured a 0% gap dark red. NaN cells use the mapper's nan_color.
    mapper = LinearColorMapper(palette=Palette, low=-max_gap, high=max_gap)

    p = figure(title="Performance Gap (%) - Fed Final vs Centralized",
               x_range=sorted(df_merge['Metric'].unique()),
               y_range=sorted(df_merge['Category'].unique(), reverse=True),
               height=400, width=500,
               toolbar_location=None, tools="")

    cells = p.rect(x="Metric", y="Category", width=1, height=1, source=source,
                   line_color='white', fill_color=transform('Degradation', mapper))

    color_bar = ColorBar(color_mapper=mapper, location=(0, 0), title="Gap (%)",
                         ticker=BasicTicker(desired_num_ticks=len(Palette)))
    p.add_layout(color_bar, 'right')

    p.text(x="Metric", y="Category", text="Degradation_txt", source=source,
           text_align="center", text_baseline="middle", text_color="black", text_font_size="10pt")

    # Hover only the heatmap cells, not the overlaid text glyphs.
    p.add_tools(HoverTool(renderers=[cells],
                          tooltips=[("Cat", "@Category"), ("Met", "@Metric"),
                                    ("Degradation", "@Degradation{0.0}%")]))

    return p

# Gap heatmap: "Fed Final" (df_cur) measured against the "Centralized" reference (df_ref).
heatmap_plot = create_heatmap(df_ref, df_cur)

# ==========================================
# 4. PLOT TYPE 3: SUMMARY (3 BARS)
# ==========================================

def create_summary_plot(df):
    """Plot each run's mean score per metric, averaged over all categories.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format frame with ``Metric``, ``Type`` and ``Value`` columns.
        The ``Type`` values plotted are "Fed Stage 1", "Fed Final" and
        "Centralized". NaN values are skipped by the mean.

    Returns
    -------
    bokeh.plotting.figure
        Grouped bars (one group per metric) with hover tooltips.
    """
    # Same metric order as the per-category plots (order of first appearance).
    metrics = df['Metric'].unique().tolist()
    type_names = ['Fed Stage 1', 'Fed Final', 'Centralized']
    # Metric x Type table of means. Reindexing pins the rows to `metrics`, so
    # each bar lines up with its axis label even if a run lacks a metric.
    avg = (df.groupby(['Metric', 'Type'])['Value'].mean()
           .unstack('Type')
           .reindex(index=metrics, columns=type_names))

    source = ColumnDataSource(data={
        'metrics': metrics,
        'stage1': avg['Fed Stage 1'].to_numpy(),
        'final': avg['Fed Final'].to_numpy(),
        'centralized': avg['Centralized'].to_numpy(),
    })

    p = figure(x_range=metrics, height=400, width=600, title="Average Performance Progression",
               toolbar_location=None, tools="hover",
               tooltips=[("Metric", "@metrics"), ("Fed Stage 1", "@stage1{0.000}"),
                         ("Fed Final", "@final{0.000}"), ("Centralized", "@centralized{0.000}")])

    w = 0.2
    p.vbar(x=dodge('metrics', -w-0.05, range=p.x_range), top='stage1', width=w, source=source,
           color="#F4A582", legend_label="Fed Stage 1")
    p.vbar(x=dodge('metrics', 0, range=p.x_range), top='final', width=w, source=source,
           color="#D9534F", legend_label="Fed Final")
    p.vbar(x=dodge('metrics', w+0.05, range=p.x_range), top='centralized', width=w, source=source,
           color="#5CB879", legend_label="Centralized")

    p.y_range.start = 0
    p.y_range.end = 1.1
    p.legend.location = "top_right"
    p.legend.orientation = "horizontal"

    return p

# Per-metric averages across all categories, one bar per run.
summary_plot = create_summary_plot(df_main)

# ==========================================
# 5. LAYOUT
# ==========================================

# Per-category charts three to a row on top; the gap heatmap and the
# averaged summary side by side underneath.
grid = gridplot(cat_plots, ncols=3)
bottom_row = row(heatmap_plot, summary_plot)
layout = column(grid, bottom_row)

show(layout)