In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from federated_inference.common.utils import set_seed
from federated_inference.configs.data_config import DataConfiguration
from federated_inference.configs.transform_config import DataTransformConfiguration

# Hybrid

In [None]:

from federated_inference.simulations.hybrid.simulation import HybridSplitSimulation
DATASET = 'MNIST'
SEEDS = [4,13,27]

In [None]:
from federated_inference.simulations.hybrid.v6.models import GlobalHybridSplitClassifierHead, HybridSplitBase, LocalHybridSplitClassifierHead, RouterHead
# Hybrid V1: run inference with the stored "v61" models once per seed and
# collect the outputs in dictionaries keyed by seed.
VERSION = "v61"
h1_clients_preds = {}    # seed -> client-level predictions returned by run_infernces
h1_clients_routers = {}  # seed -> client-level router outputs returned by run_infernces
h1_server_preds = {}     # seed -> server-level predictions returned by run_infernces
h1_targets = {}          # seed -> `targets` of the dataset behind the server's first test loader

for seed in SEEDS:
    # Seed first so that everything constructed below sees the same RNG state.
    set_seed(seed)
    data_config = DataConfiguration(DATASET)
    transform_config = DataTransformConfiguration()
    simulation_hybrid_1 = HybridSplitSimulation(seed, VERSION,  data_config, transform_config, GlobalHybridSplitClassifierHead, HybridSplitBase, LocalHybridSplitClassifierHead, RouterHead)
    # NOTE(review): presumably restores previously trained weights - confirm in server.load().
    simulation_hybrid_1.server.load()
    # One request per client; these are handed to the server as the inference inputs.
    _testsets = [client.request_pred(pred_all = True, keep_label = True) for client in simulation_hybrid_1.clients]
    # `run_infernces` (sic) is the method name exposed by the project; do not "fix" the spelling here.
    h1_clients_pred, h1_clients_router, h1_server_pred = simulation_hybrid_1.server.run_infernces(_testsets)
    # NOTE(review): h1_clients_targets is not read by any other cell of this notebook;
    # the same attribute is looked up again for h1_targets[seed] below.
    h1_clients_targets =  simulation_hybrid_1.server.testloaders[0].dataset.targets 
    h1_clients_preds[seed] = h1_clients_pred
    h1_clients_routers[seed] = h1_clients_router
    h1_server_preds[seed] =  h1_server_pred 
    h1_targets[seed] = simulation_hybrid_1.server.testloaders[0].dataset.targets 

In [None]:
len(h1_clients_preds[4][0])

In [None]:
len(h1_server_preds[4])

In [None]:

from federated_inference.simulations.hybrid.v7.models import GlobalHybridSplitClassifierHead, HybridSplitBase, LocalHybridSplitClassifierHead, RouterHead
# Hybrid V2: same procedure as the V1 cell, but with the v7 model classes
# imported at the top of this cell and the stored "v71" models.
VERSION = "v71"
h2_clients_preds = {}    # seed -> client-level predictions returned by run_infernces
h2_clients_routers = {}  # seed -> client-level router outputs returned by run_infernces
h2_server_preds = {}     # seed -> server-level predictions returned by run_infernces
h2_targets = {}          # seed -> `targets` of the dataset behind the server's first test loader

for seed in SEEDS:
    # Seed first so that everything constructed below sees the same RNG state.
    set_seed(seed)
    data_config = DataConfiguration(DATASET)
    transform_config = DataTransformConfiguration()
    simulation_hybrid_2 = HybridSplitSimulation(seed, VERSION,  data_config, transform_config, GlobalHybridSplitClassifierHead, HybridSplitBase, LocalHybridSplitClassifierHead, RouterHead)
    # NOTE(review): presumably restores previously trained weights - confirm in server.load().
    simulation_hybrid_2.server.load()
    # One request per client; these are handed to the server as the inference inputs.
    _testsets = [client.request_pred(pred_all = True, keep_label = True) for client in simulation_hybrid_2.clients]
    # `run_infernces` (sic) is the method name exposed by the project; do not "fix" the spelling here.
    h2_clients_pred, h2_clients_router, h2_server_pred = simulation_hybrid_2.server.run_infernces(_testsets)
    # NOTE(review): h2_clients_targets is not read by any other cell of this notebook;
    # the same attribute is looked up again for h2_targets[seed] below.
    h2_clients_targets =  simulation_hybrid_2.server.testloaders[0].dataset.targets 
    h2_clients_preds[seed] = h2_clients_pred
    h2_clients_routers[seed] = h2_clients_router
    h2_server_preds[seed] =  h2_server_pred 
    h2_targets[seed] = simulation_hybrid_2.server.testloaders[0].dataset.targets 

## On Device

In [None]:
from federated_inference.simulations.ondevice.simulation import OnDeviceVerticalSimulation
from federated_inference.simulations.ondevice.models import OnDeviceMNISTModel
# Local-only baseline: every client runs its own on-device model.
VERSION = "v1"
on_client_preds = {}  # seed -> result of simulation.run_inference on the clients' test sets
for seed in SEEDS:
    # NOTE(review): unlike the hybrid cells, set_seed(seed) is not called here.
    # Confirm that the simulation seeds itself from the `seed` argument.
    data_config = DataConfiguration(DATASET)
    transform_config = DataTransformConfiguration()
    simulation_on_client = OnDeviceVerticalSimulation(seed, VERSION, data_config, transform_config, OnDeviceMNISTModel, exist=False)
    # Plain loop instead of a list comprehension: load() is called only for its
    # side effect, so there is no point in building a throwaway list of results.
    for client in simulation_on_client.clients:
        client.load()
    _testsets = [client.data.test_dataset for client in simulation_on_client.clients]
    on_client_pred = simulation_on_client.run_inference(_testsets)
    on_client_preds[seed] = on_client_pred
    # Notebook-level `targets` is what the accuracy / F1 helpers in the
    # evaluation cells read; it is rebound on every iteration.
    targets =  simulation_on_client.clients[0].data.test_dataset.dataset.targets


## On Cloud

In [None]:
from federated_inference.simulations.oncloud.simulation import OnCloudVerticalSimulation
from federated_inference.simulations.oncloud.models.model import OnCloudMNISTModel
# Remote-only baseline: a single server-side model sees the test sets of all clients.
VERSION = "v1"
on_cloud_preds = {}  # seed -> result of simulation.run_inference on the clients' test sets
for seed in SEEDS:
    # NOTE(review): unlike the hybrid cells, set_seed(seed) is not called here.
    # Confirm that the simulation seeds itself from the `seed` argument.
    data_config = DataConfiguration(DATASET)
    transform_config = DataTransformConfiguration()
    simulation_on_cloud = OnCloudVerticalSimulation(seed, VERSION, data_config, transform_config, OnCloudMNISTModel, exist=False)
    # NOTE(review): presumably restores previously trained weights - confirm in server.load().
    simulation_on_cloud.server.load()
    _testsets = [client.data.test_dataset for client in simulation_on_cloud.clients]
    on_cloud_pred = simulation_on_cloud.run_inference(_testsets)
    on_cloud_preds[seed] = on_cloud_pred
    # Notebook-level `targets` is what compute_client_accs / compute_server_accs
    # (evaluation cell) read as ground truth; it is rebound on every iteration.
    targets =  simulation_on_cloud.clients[0].data.test_dataset.dataset.targets

# Evaluation

In [None]:
import plotly.graph_objects as go
from sklearn.metrics import accuracy_score
import plotly.express as px
import itertools
import numpy as np

# Axis labels for the grouped bar chart.
# client_map: one label per client index (0..3), used as the group tick text.
client_map = ['LT', 'RT', 'LB', "RB"]
# sub_labels: the three bars shown inside each client group.
sub_labels = ["V1", "V2", "Local-Only"]
# server_labels: tick text of the three server-side bars (the stray leading
# space in the second label was removed).
server_labels = ["Remote V1", "Remote V2", "Remote-Only"]

def compute_client_accs(client_preds_dict):
    """Per-client mean and standard deviation of the accuracy across seeds.

    Args:
        client_preds_dict: mapping ``seed -> {client_id: predictions}``.

    Returns:
        ``(client_ids, means, stds)`` where ``client_ids`` is ``range(4)`` and
        ``means`` / ``stds`` hold one entry per client id. An entry is ``None``
        when no seed contains predictions for that client.

    Note:
        Ground truth is read from the notebook-level variable ``targets``.
    """
    client_ids = range(4)

    means, stds = [], []
    for client in client_ids:
        # One accuracy per seed that actually has predictions for this client.
        accs = [
            accuracy_score(targets, seed_preds[client])
            for seed_preds in client_preds_dict.values()
            if client in seed_preds
        ]
        means.append(np.mean(accs) if accs else None)
        stds.append(np.std(accs) if accs else None)
    return client_ids, means, stds


def compute_server_accs(server_preds_dict):
    """Mean and standard deviation of the server accuracy across seeds.

    Args:
        server_preds_dict: mapping ``seed -> predictions``.

    Returns:
        ``(mean, std)`` of the per-seed accuracies.

    Note:
        Ground truth is read from the notebook-level variable ``targets``.
    """
    accs = []
    for preds in server_preds_dict.values():
        accs.append(accuracy_score(targets, preds))
    return np.mean(accs), np.std(accs)


# --- Which client ids occur in any of the client-level results --------------
all_client_ids = sorted({
    client_id
    for preds_dict in (h1_clients_preds, h2_clients_preds, on_client_preds)
    for seed_preds in preds_dict.values()
    for client_id in seed_preds.keys()
})
n_clients = len(all_client_ids)

# --- Accuracy statistics ------------------------------------------------------
_, h1_mean, h1_std = compute_client_accs(h1_clients_preds)
_, h2_mean, h2_std = compute_client_accs(h2_clients_preds)
_, local_mean, local_std = compute_client_accs(on_client_preds)

# Every client owns three neighbouring x slots: V1, V2 and Local-Only.
h1_x = [3 * i for i in range(n_clients)]
h2_x = [3 * i + 1 for i in range(n_clients)]
local_x = [3 * i + 2 for i in range(n_clients)]

server_means = []
server_stds = []
for server_preds in (h1_server_preds, h2_server_preds, on_cloud_preds):
    mean, std = compute_server_accs(server_preds)
    server_means.append(mean)
    server_stds.append(std)
# The three server bars are placed directly after the last client group.
server_x = [n_clients * 3 + i for i in range(3)]

# One tick per client group (centred on the middle bar) plus one per server bar.
tickvals = [i * 3 + 1 for i in range(n_clients)] + server_x
ticktext = client_map + server_labels

# Small per-bar labels ("V1", "V2", "Local-Only") below the client bars.
annotations = [
    dict(
        x=i * 3 + j,
        y=-0.1,
        xref='x',
        yref='paper',
        text=sublabel,
        showarrow=False,
        font=dict(size=12),
        yshift=-20,
    )
    for i in range(n_clients)
    for j, sublabel in enumerate(sub_labels)
]

fig = go.Figure()

# --- Client bars: one trace per method --------------------------------------
client_traces = [
    (h1_x, h1_mean, h1_std, "HF-NN V1", "blue"),
    (h2_x, h2_mean, h2_std, "HF-NN V2", "green"),
    (local_x, local_mean, local_std, "Local-Only", "orange"),
]
for xs, ys, errs, trace_name, trace_color in client_traces:
    fig.add_trace(go.Bar(
        x=xs,
        y=ys,
        error_y=dict(type='data', array=errs, visible=True),
        name=trace_name,
        marker_color=trace_color,
    ))

# --- Server bars: one single-point trace per method -------------------------
# error_y.array is matched to the points of a trace by index, so every
# single-point trace must receive only ITS OWN standard deviation. Passing the
# complete server_stds list made all three bars show server_stds[0].
server_traces = [
    ("Remote V1", "blue"),
    ("Remote V2", "green"),
    ("Remote Only", "purple"),
]
for x_pos, mean_acc, std_acc, (trace_name, trace_color) in zip(
        server_x, server_means, server_stds, server_traces):
    fig.add_trace(go.Bar(
        x=[x_pos],
        y=[mean_acc],
        error_y=dict(type='data', array=[std_acc], visible=True),
        name=trace_name,
        marker_color=trace_color,
        text=[f"{mean_acc:.2%}"],
        textposition="auto",
    ))

fig.update_layout(
    #title="Accuracy Comparison (Mean ± Std over Seeds)",
    xaxis=dict(
        title="Client ID / Server",
        tickmode='array',
        tickvals=tickvals,
        ticktext=ticktext,
        tickangle=-80,
    ),
    yaxis_title="Accuracy",
    yaxis_tickformat=".0%",
    barmode="group",
    height=500,
    margin=dict(l=40, r=40, t=60, b=140),  # large bottom margin leaves room for the sub-labels
    font=dict(size=14),
    annotations=annotations,
    showlegend=False
)

fig.show()


In [None]:
server_x

In [None]:
h1_mean, h1_std

In [None]:
h2_mean, h2_std

In [None]:
local_mean, local_std 

In [None]:
server_means, server_stds 

In [None]:
max(server_means)-min(server_means)

In [None]:
max(server_means)

In [None]:
np.mean(local_std), np.max(local_std)

In [None]:
import plotly.graph_objects as go
from sklearn.metrics import f1_score
import numpy as np

def compute_client_f1s(client_preds_dict, target_class):
    """Per-client mean and std of the F1 score of one class across seeds.

    Args:
        client_preds_dict: mapping ``seed -> {client_id: predictions}``.
        target_class: class label whose F1 score is computed.

    Returns:
        ``(client_ids, means, stds)`` with ``client_ids == range(4)``; an entry
        of ``means`` / ``stds`` is ``None`` when no seed has predictions for
        that client.

    Note:
        Ground truth is read from the notebook-level variable ``targets``.
    """
    client_ids = range(4)
    means, stds = [], []
    for client in client_ids:
        f1s = [
            f1_score(targets, seed_preds[client],
                     labels=[target_class], average='macro', zero_division=0)
            for seed_preds in client_preds_dict.values()
            if client in seed_preds
        ]
        means.append(np.mean(f1s) if f1s else None)
        stds.append(np.std(f1s) if f1s else None)
    return client_ids, means, stds


def compute_server_f1s(server_preds_dict, target_class):
    """Mean and std across seeds of the server F1 score of one class.

    Ground truth is read from the notebook-level variable ``targets``.
    """
    f1s = [f1_score(targets, preds, labels=[target_class], average='macro', zero_division=0)
           for preds in server_preds_dict.values()]
    return np.mean(f1s), np.std(f1s)


# --- Layout that does not depend on the class: built once --------------------
client_map = ['LT', 'RT', 'LB', "RB"]
sub_labels = ["V1", "V2", "Local-Only"]
server_labels = ["Remote V1", "Remote V2", "Remote-Only"]

all_client_ids = sorted({
    client_id
    for preds_dict in (h1_clients_preds, h2_clients_preds, on_client_preds)
    for seed_preds in preds_dict.values()
    for client_id in seed_preds.keys()
})
n_clients = len(all_client_ids)

# Every client owns three neighbouring x slots: V1, V2 and Local-Only.
h1_x = [3 * k for k in range(n_clients)]
h2_x = [3 * k + 1 for k in range(n_clients)]
local_x = [3 * k + 2 for k in range(n_clients)]
server_x = [n_clients * 3 + k for k in range(3)]

tickvals = [k * 3 + 1 for k in range(n_clients)] + server_x
ticktext = client_map + server_labels

annotations = [
    dict(
        x=k * 3 + j,
        y=-0.1,
        xref='x',
        yref='paper',
        text=sublabel,
        showarrow=False,
        font=dict(size=12),
        yshift=-20,
    )
    for k in range(n_clients)
    for j, sublabel in enumerate(sub_labels)
]

# --- One figure per class -----------------------------------------------------
# Two entries are appended per class: (V1 std - local std) and (V2 std - local std).
std_deviations_differences = []
for target_class in range(10):  # the ten digit classes
    _, f1_h1_mean, f1_h1_std = compute_client_f1s(h1_clients_preds, target_class)
    _, f1_h2_mean, f1_h2_std = compute_client_f1s(h2_clients_preds, target_class)
    _, f1_local_mean, f1_local_std = compute_client_f1s(on_client_preds, target_class)
    std_deviations_differences.append(np.array(f1_h1_std)-np.array(f1_local_std))
    std_deviations_differences.append(np.array(f1_h2_std)-np.array(f1_local_std))

    f1_server_means = []
    f1_server_stds = []
    for server_preds in (h1_server_preds, h2_server_preds, on_cloud_preds):
        mean, std = compute_server_f1s(server_preds, target_class)
        f1_server_means.append(mean)
        f1_server_stds.append(std)

    fig = go.Figure()

    client_traces = [
        (h1_x, f1_h1_mean, f1_h1_std, "H1", "blue"),
        (h2_x, f1_h2_mean, f1_h2_std, "H2", "green"),
        (local_x, f1_local_mean, f1_local_std, "Local-Only", "orange"),
    ]
    for xs, ys, errs, trace_name, trace_color in client_traces:
        fig.add_trace(go.Bar(
            x=xs,
            y=ys,
            error_y=dict(type='data', array=errs, visible=True),
            name=trace_name,
            marker_color=trace_color,
        ))

    # Server bars: error bars and text labels must come from the F1 statistics
    # of THIS class (previously the accuracy values server_stds / server_means
    # of the earlier cell were shown), and every single-point trace must get
    # only its own standard deviation.
    server_traces = [
        ("Remote V1", "blue"),
        ("Remote V2", "green"),
        ("Remote Only", "purple"),
    ]
    for x_pos, mean_f1, std_f1, (trace_name, trace_color) in zip(
            server_x, f1_server_means, f1_server_stds, server_traces):
        fig.add_trace(go.Bar(
            x=[x_pos],
            y=[mean_f1],
            error_y=dict(type='data', array=[std_f1], visible=True),
            name=trace_name,
            marker_color=trace_color,
            text=[f"{mean_f1:.2%}"],
            textposition="auto",
        ))

    fig.update_layout(
        #title=f"F1-Score for Class {target_class} (Mean ± Std over Seeds)",
        xaxis=dict(
            title="Client ID / Server",
            tickmode='array',
            tickvals=tickvals,
            ticktext=ticktext,
            tickangle=-80,
        ),
        yaxis_title="F1 Score",
        yaxis_tickformat=".0%",
        barmode="group",
        height=500,
        margin=dict(l=40, r=40, t=60, b=140),
        annotations=annotations,
        showlegend=False
    )

    fig.show()


In [None]:
np.mean(local_mean), np.mean(h1_mean), np.mean(h2_mean)

In [None]:
np.mean(local_std), np.mean(h1_std), np.mean(h2_std)

In [None]:
np.array(std_deviations_differences)

In [None]:
import numpy as np
def compute_coverage_and_accuracy_single_seed(seed_router_logits, seed_client_preds, targets):
    """
    Sweep router thresholds for one seed and report, per client, the coverage
    and the accuracy on the samples above the threshold.

    The candidate thresholds are the distinct router values of all clients
    pooled together, in descending order; only every 100th one is evaluated.

    NOTE(review): an identical definition follows in the next cell and replaces
    this one when the notebook is run top to bottom.

    Args:
        seed_router_logits: dict {client_id: router values, one per sample}
        seed_client_preds: dict {client_id: predicted labels, one per sample}
        targets: true labels (truncated to the number of predictions)

    Returns:
        coverage: dict {threshold: {client_id: fraction of samples with value <= threshold}}
        hybrid_accuracy: dict {client_id: {threshold: accuracy of the client's
            predictions on the samples with value > threshold, 0.0 if none}}
    """
    client_ids = sorted(seed_router_logits.keys())

    # np.unique returns ascending distinct values; reverse, then subsample.
    pooled = np.concatenate([seed_router_logits[cid] for cid in client_ids])
    sampled_thresholds = np.unique(pooled)[::-1][::100]

    coverage = {}
    hybrid_accuracy = {cid: {} for cid in client_ids}

    for th in sampled_thresholds:
        cov_at_th = {}
        coverage[th] = cov_at_th
        for cid in client_ids:
            logit_arr = np.array(seed_router_logits[cid])
            pred_arr = np.array(seed_client_preds[cid])
            target_arr = np.array(targets[:len(pred_arr)])  # align lengths

            above = logit_arr > th
            cov_at_th[cid] = 1.0 - np.mean(above)

            hybrid_accuracy[cid][th] = (
                np.mean(pred_arr[above] == target_arr[above]) if np.any(above) else 0.0
            )

    return coverage, hybrid_accuracy


In [None]:
import numpy as np
def compute_coverage_and_accuracy_single_seed(seed_router_logits, seed_client_preds, targets,
                                              threshold_stride=100):
    """
    Sweep router thresholds for one seed and report, per client, the coverage
    and the accuracy on the samples above the threshold.

    The candidate thresholds are the distinct router values of all clients
    pooled together, in descending order; only every ``threshold_stride``-th
    one is evaluated to keep the sweep cheap.

    Args:
        seed_router_logits: dict {client_id: router values, one per sample}
        seed_client_preds: dict {client_id: predicted labels, one per sample}
        targets: true labels (truncated to the number of predictions)
        threshold_stride: evaluate every n-th candidate threshold (default 100,
            the value that used to be hard-coded). Must be >= 1.

    Returns:
        coverage: dict {threshold: {client_id: fraction of samples with value <= threshold}}
        hybrid_accuracy: dict {client_id: {threshold: accuracy of the client's
            predictions on the samples with value > threshold, 0.0 if none}}
        Both are empty dicts when ``seed_router_logits`` is empty.

    Raises:
        ValueError: if ``threshold_stride`` is smaller than 1.
    """
    if threshold_stride < 1:
        raise ValueError("threshold_stride must be >= 1")

    client_ids = sorted(seed_router_logits.keys())
    if not client_ids:
        # np.concatenate rejects an empty list; there is nothing to sweep anyway.
        return {}, {}

    # Gather the distinct router values of all clients, descending, subsampled.
    all_logits = np.concatenate([seed_router_logits[cid] for cid in client_ids])
    unique_thresholds = np.unique(all_logits)[::-1][::threshold_stride]

    # The array conversions do not depend on the threshold: do them once per
    # client instead of once per (threshold, client) pair.
    per_client = {}
    for cid in client_ids:
        preds_arr = np.array(seed_client_preds[cid])
        per_client[cid] = (
            np.array(seed_router_logits[cid]),
            preds_arr,
            np.array(targets[:len(preds_arr)]),  # match size
        )

    coverage = {}
    hybrid_accuracy = {cid: {} for cid in client_ids}

    for th in unique_thresholds:
        coverage[th] = {}
        for cid in client_ids:
            logits_arr, preds_arr, targets_arr = per_client[cid]

            selected = logits_arr > th
            coverage[th][cid] = 1.0 - np.mean(selected)

            if np.any(selected):
                acc = np.mean(preds_arr[selected] == targets_arr[selected])
            else:
                acc = 0.0  # nothing above the threshold
            hybrid_accuracy[cid][th] = acc

    return coverage, hybrid_accuracy


In [None]:
f1_h1_mean, f1_h1_std

In [None]:
f1_h2_mean, f1_h2_std

In [None]:
f1_local_mean, f1_local_std

In [None]:
f1_server_means, f1_server_stds

In [None]:
for s in h1_server_preds.keys():
    print(accuracy_score(h1_server_preds[s], h1_targets[s]))

In [None]:
for s in h2_server_preds.keys():
    print(accuracy_score(h2_server_preds[s], h2_targets[s]))

In [None]:
h1_clients_routers[4][0]
h1_targets[4]

In [None]:
import plotly.graph_objects as go
import numpy as np

# Data and settings
bins = 30
target_filter = list(range(10))  # Change to a single value (e.g., 0) or a list (e.g., [0, 2])
hist_seed = 13     # seed whose router outputs are shown
hist_client = 3    # client index, see client_labels
client_labels = ['LT', 'RT', 'LB', 'RB']

# Logits and labels are taken from the SAME seed (the labels used to be read
# from seed 4 while the logits came from seed 13).
logits = np.array(h1_clients_routers[hist_seed][hist_client])
targets = np.asarray(h1_targets[hist_seed])
classes = np.unique(targets)

# Ensure target_filter is a list
if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

# Compute histogram data for filtered classes
hist_data = []
for cls in classes:
    if cls not in target_filter:
        continue

    class_logits = logits[targets == cls]
    counts, edges = np.histogram(class_logits, bins=bins)
    bin_centers = 0.5 * (edges[:-1] + edges[1:])
    hist_data.append((cls, bin_centers, counts))

# Create Plotly figure: one semi-transparent bar trace per class
fig = go.Figure()

for cls, bin_centers, counts in hist_data:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls}',
        opacity=0.6
    ))

# Update layout; the title is derived from the selected client so it cannot
# drift out of sync (it used to say "LB" although client 3 is "RB").
fig.update_layout(
    title=f'Histogram of Router V1 Client {client_labels[hist_client]} by Class',
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white'
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

bins = 30
target_filter = list(range(10)) #  [1,3,5]
hist_seed = 27     # seed whose router outputs are shown
hist_client = 1    # client index, see client_labels
client_labels = ['LT', 'RT', 'LB', 'RB']

# Logits and labels are taken from the SAME seed (the labels used to be read
# from seed 4 while the logits came from seed 27).
logits = np.array(h1_clients_routers[hist_seed][hist_client])
targets = np.asarray(h1_targets[hist_seed])
classes = np.unique(targets)

if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

# Per-class histogram of the router values
hist_data = []
for cls in classes:
    if cls not in target_filter:
        continue

    class_logits = logits[targets == cls]
    counts, edges = np.histogram(class_logits, bins=bins)
    bin_centers = 0.5 * (edges[:-1] + edges[1:])
    hist_data.append((cls, bin_centers, counts))

fig = go.Figure()

for cls, bin_centers, counts in hist_data:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls}',
        opacity=0.6
    ))

# Update layout; the title is derived from the selected client (it used to say
# "LB" although client 1 is "RT").
fig.update_layout(
    title=f'Histogram of Router V1 Client {client_labels[hist_client]} by Class',
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white'
)

fig.show()

In [None]:
import plotly.graph_objects as go
import numpy as np

bins = 30
target_filter =  [1,3,5] # list(range(10)) #  [1,3,5]
hist_seed = 27     # seed whose router outputs are shown
hist_client = 2    # client index into ['LT', 'RT', 'LB', 'RB'] -> 'LB'

# Logits and labels are taken from the SAME seed (the labels used to be read
# from seed 4 while the logits came from seed 27).
logits = np.array(h2_clients_routers[hist_seed][hist_client])
targets = np.asarray(h2_targets[hist_seed])
classes = np.unique(targets)

if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

# Per-class histogram of the router values
hist_data = []
for cls in classes:
    if cls not in target_filter:
        continue

    class_logits = logits[targets == cls]
    counts, edges = np.histogram(class_logits, bins=bins)
    bin_centers = 0.5 * (edges[:-1] + edges[1:])
    hist_data.append((cls, bin_centers, counts))

fig = go.Figure()

for cls, bin_centers, counts in hist_data:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls}',
        opacity=0.6
    ))

# Update layout (the figure is deliberately rendered without a title)
fig.update_layout(
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white',
    font=dict(size=14)
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

bins = 30
target_filter = list(range(10))  # Filter specific classes
hist_seed = 27     # seed whose router outputs are shown
hist_client = 0    # client index, see client_labels
client_labels = ['LT', 'RT', 'LB', 'RB']

logits = np.array(h1_clients_routers[hist_seed][hist_client])  # Logits
targets = np.array(h1_targets[hist_seed])  # Ground truth labels
preds = np.array(h1_clients_preds[hist_seed][hist_client])  # Client predictions
classes = np.unique(targets)  # Available classes

if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

hist_data_correct = []
hist_data_incorrect = []

for cls in classes:
    if cls not in target_filter:
        continue

    mask = targets == cls
    class_logits = logits[mask]

    # Correct and incorrect masks (within this class)
    correct_mask = preds[mask] == targets[mask]
    incorrect_mask = ~correct_mask

    # Shared bin edges computed from ALL logits of the class. Deriving them
    # from the correct samples only silently dropped every incorrect sample
    # whose logit fell outside the range of the correct ones.
    edges = np.histogram_bin_edges(class_logits, bins=bins)
    bin_centers = 0.5 * (edges[:-1] + edges[1:])

    counts_correct, _ = np.histogram(class_logits[correct_mask], bins=edges)
    hist_data_correct.append((cls, bin_centers, counts_correct))

    counts_incorrect, _ = np.histogram(class_logits[incorrect_mask], bins=edges)
    hist_data_incorrect.append((cls, bin_centers, counts_incorrect))

# Plotting
fig = go.Figure()

for cls, bin_centers, counts in hist_data_correct:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls} (Correct)',
        opacity=0.6
    ))

# Hatched bars distinguish the misclassified samples
for cls, bin_centers, counts in hist_data_incorrect:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls} (Incorrect)',
        opacity=0.6,
        marker=dict(pattern_shape="/")
    ))

# The title is derived from the selected client (it used to say "LB" although
# client 0 is "LT").
fig.update_layout(
    title=f'Histogram of Router V1 Client {client_labels[hist_client]} by Class (Correct vs Incorrect)',
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white',
    legend_title='Legend'
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

bins = 30
target_filter =  list(range(10)) # [1, 3, 5] 
hist_seed = 13     # seed whose router outputs are shown
hist_client = 3    # client index, see client_labels
client_labels = ['LT', 'RT', 'LB', 'RB']

# Logits and labels are taken from the SAME seed (the labels used to be read
# from seed 4 while the logits came from seed 13).
logits = np.array(h1_clients_routers[hist_seed][hist_client])
targets = np.asarray(h1_targets[hist_seed])
classes = np.unique(targets)

if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

hist_data = []
class_means = {}        # class -> mean router value, x position of the dashed line
class_max_heights = {}  # class -> tallest bin count, height of the dashed line

for cls in classes:
    if cls not in target_filter:
        continue

    class_logits = logits[targets == cls]
    counts, edges = np.histogram(class_logits, bins=bins)
    bin_centers = 0.5 * (edges[:-1] + edges[1:])
    hist_data.append((cls, bin_centers, counts))

    # Store mean and max count for placing the line and annotation
    class_means[cls] = np.mean(class_logits)
    class_max_heights[cls] = max(counts)

fig = go.Figure()

# Plot histogram bars
for cls, bin_centers, counts in hist_data:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls}',
        opacity=0.6
    ))

# Add vertical mean lines and annotations
for cls, mean_val in class_means.items():
    max_y = class_max_heights[cls]

    # Mean line
    fig.add_trace(go.Scatter(
        x=[mean_val, mean_val],
        y=[0, max_y],
        mode='lines',
        line=dict(dash='dash', color='black'),
        name=f'Mean Class {cls}',
        showlegend=False
    ))

    # Mean value label near x-axis
    fig.add_annotation(
        x=mean_val,
        y=-0.01,  # Slightly below x-axis
        xref='x',
        yref='paper',
        text=f"{mean_val:.2f}",
        showarrow=False,
        xanchor='center',
        yanchor='top',
        font=dict(size=10, color='black')
    )

# Update layout; the title is derived from the selected client
fig.update_layout(
    title=f'Histogram of Router V1 Client {client_labels[hist_client]} by Class (with Mean Lines)',
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white'
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

bins = 30
target_filter = list(range(10)) # 
hist_seed = 4      # seed whose router outputs are shown

# Pool the router values of ALL clients of this seed. Every client contributes
# one value per test sample, so the label vector is repeated once per client.
# (Array concatenation replaces the previous element-by-element list building.)
client_keys = list(h1_clients_routers[hist_seed].keys())
logits = np.concatenate([np.asarray(h1_clients_routers[hist_seed][key]) for key in client_keys])
seed_targets = np.asarray(h1_targets[hist_seed])
targets = np.tile(seed_targets, len(client_keys))
classes = np.unique(seed_targets)

if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

# Per-class histogram of the pooled router values
hist_data = []
for cls in classes:
    if cls not in target_filter:
        continue

    class_logits = logits[targets == cls]
    counts, edges = np.histogram(class_logits, bins=bins)
    bin_centers = 0.5 * (edges[:-1] + edges[1:])
    hist_data.append((cls, bin_centers, counts))

fig = go.Figure()

for cls, bin_centers, counts in hist_data:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls}',
        opacity=0.6
    ))

fig.update_layout(
    title=f'Histogram of Router V1 by Class',
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white'
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

client_mapping = {
    0: "left top (LT)",
    1: "right top (RT)",
    2: "left bottom (LB)",
    3: "right bottom (RB)"
}

def plot_accuracy_vs_coverage(coverage, hybrid_accuracy):
    """Draw one accuracy curve per client over ``1 - coverage`` and show the figure.

    Args:
        coverage: dict ``{threshold: {client_id: coverage value}}``.
        hybrid_accuracy: dict ``{client_id: {threshold: accuracy value}}``.

    For every client the accuracies are plotted, in ascending threshold order,
    against ``1 - coverage``. If at least one client has a point whose coverage
    is (numerically) 1.0, a dashed red horizontal line marks the largest
    accuracy found at such a point. Client ids are turned into legend names via
    the notebook-level ``client_mapping``.
    """
    ordered_thresholds = sorted(coverage.keys())
    ordered_clients = sorted(hybrid_accuracy.keys())

    fig = go.Figure()
    accs_at_full_coverage = []

    for client_id in ordered_clients:
        covs = np.array([coverage[t][client_id] for t in ordered_thresholds])
        accs = np.array([hybrid_accuracy[client_id][t] for t in ordered_thresholds])

        curve = go.Scatter(
            x=(1-covs),
            y=accs,
            mode='lines+markers',
            name=f'Client {client_mapping[client_id]}'
        )
        fig.add_trace(curve)

        # Remember the first accuracy recorded at coverage == 1 for this client.
        at_full = np.isclose(covs, 1.0)
        if np.any(at_full):
            accs_at_full_coverage.append(accs[at_full][0])

    if accs_at_full_coverage:
        max_acc = max(accs_at_full_coverage)
        fig.add_hline(
            y=max_acc,
            line_dash="dash",
            line_color="red",
            annotation_text=f"Best local accuracy: {max_acc:.3f}",
            annotation_position="top right"
        )

    x_axis = dict(
        title='Remote Coverage (Fraction of samples processed by the Cloud)',
        range=[0, 1],
        autorange=False
    )
    y_axis = dict(
        title='Hybrid Accuracy',
        range=[0, 1],
        autorange=False
    )
    fig.update_layout(
        title='Hybrid Accuracy vs. Local Coverage per Client',
        xaxis=x_axis,
        yaxis=y_axis,
        legend_title='Clients',
        width=800,
        height=500,
        template='plotly_white'
    )

    fig.show()

In [None]:
import plotly.graph_objects as go
import numpy as np

bins = 30

# Router values of one V2 client for one seed, split by whether the client's
# own prediction was right.
seed = 13  
client_id = 0
logits = np.array(h2_clients_routers[seed][client_id])
preds = np.array(h2_clients_preds[seed][client_id])
targets = np.array(h2_targets[seed])[:len(preds)]

correct_mask = preds == targets
incorrect_mask = ~correct_mask

correct_logits = logits[correct_mask]
incorrect_logits = logits[incorrect_mask]

# Shared bin edges computed from ALL logits. Deriving them from the correct
# samples only silently dropped every incorrect sample whose logit fell
# outside the range of the correct ones.
bin_edges = np.histogram_bin_edges(logits, bins=bins)
c_counts, _ = np.histogram(correct_logits, bins=bin_edges)
ic_counts, _ = np.histogram(incorrect_logits, bins=bin_edges)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

fig = go.Figure()

fig.add_trace(go.Bar(
    x=bin_centers,
    y=c_counts,
    name='Correct',
    opacity=0.7,
    marker_color='green'
))

fig.add_trace(go.Bar(
    x=bin_centers,
    y=ic_counts,
    name='Incorrect',
    opacity=0.7,
    marker_color='crimson'
))

fig.update_layout(
    title=f'Router Logits Histogram for Client {client_id} (Seed {seed})',
    xaxis_title='Router Logit',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white'
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

bins = 30

# Router values of one V2 client for one seed, split by whether the client's
# own prediction was right.
seed = 13  
client_id = 1
logits = np.array(h2_clients_routers[seed][client_id])
preds = np.array(h2_clients_preds[seed][client_id])
targets = np.array(h2_targets[seed])[:len(preds)]

correct_mask = preds == targets
incorrect_mask = ~correct_mask

correct_logits = logits[correct_mask]
incorrect_logits = logits[incorrect_mask]

# Shared bin edges computed from ALL logits. Deriving them from the correct
# samples only silently dropped every incorrect sample whose logit fell
# outside the range of the correct ones.
bin_edges = np.histogram_bin_edges(logits, bins=bins)
c_counts, _ = np.histogram(correct_logits, bins=bin_edges)
ic_counts, _ = np.histogram(incorrect_logits, bins=bin_edges)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

fig = go.Figure()

fig.add_trace(go.Bar(
    x=bin_centers,
    y=c_counts,
    name='Correct',
    opacity=0.7,
    marker_color='green'
))

fig.add_trace(go.Bar(
    x=bin_centers,
    y=ic_counts,
    name='Incorrect',
    opacity=0.7,
    marker_color='crimson'
))

fig.update_layout(
    title=f'Router Logits Histogram for Client {client_id} (Seed {seed})',
    xaxis_title='Router Logit',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white'
)

fig.show()

In [None]:

def calculate_hybrid_accuracy(logits, preds, server, targets, thresholds): 
    """Hybrid accuracy and cloud coverage per client for a sweep of router thresholds.

    A sample of a client is routed to the cloud when that client's router
    value is strictly greater than the threshold; otherwise the client's own
    prediction is used. The hybrid accuracy counts a sample as correct when
    the prediction of whichever side handled it equals the target.

    Args:
        logits: dict {client_id: router values, one per sample}.
        preds: dict {client_id: local predictions, one per sample}.
        server: server predictions, one per sample.
        targets: true labels, one per sample.
        thresholds: iterable of router thresholds.

    Returns:
        (hybrid_accuracy, coverages): two dicts of the form
        {threshold rounded to 4 decimals (np.float64): {client_id: value}},
        where coverage is the fraction of samples routed to the cloud.
    """
    targets = np.asarray(targets)
    server_correct = np.asarray(server) == targets  # independent of client and threshold
    n_samples = len(targets)

    hybrid_accuracy = {}
    coverages = {} 
    for threshold in thresholds:
        th_key = np.float64(round(threshold, 4))
        hybrid_accuracy[th_key] = {}
        coverages[th_key] = {}
        for key in preds.keys():
            # Use THIS client's router output. Indexing with the constant 0
            # applied client 0's routing decisions to every client.
            cloud_routing_decision = np.asarray(logits[key]) > threshold
            local_routing_decision = ~cloud_routing_decision
            local_correct = np.asarray(preds[key]) == targets

            n_correct = (np.sum(local_correct & local_routing_decision)
                         + np.sum(server_correct & cloud_routing_decision))
            coverages[th_key][key] = np.sum(cloud_routing_decision) / n_samples
            hybrid_accuracy[th_key][key] = n_correct / n_samples
    return hybrid_accuracy, coverages


In [None]:
# Threshold sweep for Hybrid V1.
# Index with SEED itself: the lowercase `seed` used before was a leftover
# variable from an earlier cell, so changing SEED had no effect.
SEED = 13
thresholds = np.linspace(1, 0, 100)  # 100 evenly spaced thresholds from 1 down to 0
logits = h1_clients_routers[SEED]
preds = h1_clients_preds[SEED]
server = h1_server_preds[SEED]
targets = np.array(h1_targets[SEED])

hybrid_accuracy, coverages = calculate_hybrid_accuracy(logits, preds, server, targets, thresholds)

In [None]:
local_mean

In [None]:
# Hybrid accuracy against coverage, one curve per client, using the results
# (`preds`, `hybrid_accuracy`, `coverages`) of the preceding sweep cell.
clientmapper = ['LT', 'RT', 'LB', 'RB']  # legend label per client index
clients = list(preds.keys())
sorted_thresholds = sorted(hybrid_accuracy.keys())

fig = go.Figure()

for client in clients:
    # Walk the thresholds in ascending order so both series stay aligned.
    coverage_vals = [coverages[t][client] for t in sorted_thresholds]
    accuracies = [hybrid_accuracy[t][client] for t in sorted_thresholds]

    client_curve = go.Scatter(
        x=coverage_vals,
        y=accuracies,
        mode='lines+markers',
        name=f'Client {clientmapper[client]}'
    )
    fig.add_trace(client_curve)

fig.update_layout(
    #title='Hybrid V1 - Hybrid Accuracy vs Coverage by Clients',
    xaxis_title='Coverage',
    yaxis_title='Hybrid Accuracy',
    xaxis=dict(range=[0, 1]),
    yaxis=dict(range=[0, 1]),
    template='plotly_white',
    legend_title="Clients",
    font=dict(size=14)
)

# Reference line: best mean accuracy of the local-only baseline
# (`local_mean` comes from the accuracy evaluation cell).
max_local = max(local_mean)
fig.add_hline(
    y=max_local,
    line=dict(color='Red', dash='dash'),
    annotation_text=f'Max Local-Only Accuracy = {max_local:.3f}',
    annotation_position='top right'
)

fig.show()

In [None]:
# Threshold sweep for Hybrid V2.
# Index with SEED itself: the lowercase `seed` used before was a leftover
# variable from an earlier cell (value 13), so the sweep silently ran on a
# different seed than the one configured here.
SEED = 4
thresholds = np.linspace(1, 0, 100)  # 100 evenly spaced thresholds from 1 down to 0
logits = h2_clients_routers[SEED]
preds = h2_clients_preds[SEED]
server = h2_server_preds[SEED]
targets = np.array(h2_targets[SEED])
hybrid_accuracy, coverages = calculate_hybrid_accuracy(logits, preds, server, targets, thresholds)

In [None]:
# Hybrid accuracy against coverage, one curve per client, using the results
# (`logits`, `preds`, `targets`, `hybrid_accuracy`, `coverages`) of the
# preceding sweep cell.
clientmapper = ['LT', 'RT', 'LB', 'RB']  # legend label per client index
clients = list(preds.keys())
sorted_thresholds = sorted(hybrid_accuracy.keys())

fig = go.Figure()

for client in clients:
    # Walk the thresholds in ascending order so both series stay aligned.
    coverage_vals = [coverages[t][client] for t in sorted_thresholds]
    accuracies = [hybrid_accuracy[t][client] for t in sorted_thresholds]
    
    fig.add_trace(go.Scatter(
        x=coverage_vals, 
        y=accuracies,
        mode='lines+markers',
        name=f'Client {clientmapper[client]}'
    ))

fig.update_layout(
    #title='Hybrid V2 - Hybrid Accuracy vs Coverage by Clients',
    xaxis_title='Coverage',
    yaxis_title='Hybrid Accuracy',
    xaxis=dict(range=[0, 1]),
    yaxis=dict(range=[0, 1]),
    template='plotly_white',
    legend_title="Clients"
)

# Reference line: best mean accuracy of the local-only baseline
# (`local_mean` comes from the accuracy evaluation cell).
max_local = max(local_mean)
fig.add_hline(
    y=max_local,
    line=dict(color='Red', dash='dash'),
    annotation_text=f'Max Local-Only Accuracy = {max_local:.3f}',
    annotation_position='top right'
)


fig.show()

# `_coverage` / `_hybrid_accuracy` were never assigned anywhere in the
# notebook, so the call below raised NameError on a fresh kernel. They are now
# computed from the same router outputs, predictions and targets as the plot
# above.
# NOTE(review): confirm that these are the inputs the second plot was meant to show.
_coverage, _hybrid_accuracy = compute_coverage_and_accuracy_single_seed(logits, preds, targets)
plot_accuracy_vs_coverage(_coverage, _hybrid_accuracy)