In [None]:
# Re-import edited modules automatically before each cell runs (autoreload mode 2),
# so changes to the federated_inference package take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from federated_inference.common.utils import set_seed
from federated_inference.configs.data_config import DataConfiguration
from federated_inference.configs.transform_config import DataTransformConfiguration

# Hybrid

In [None]:

from federated_inference.simulations.hybrid.simulation import HybridSplitSimulation
# Dataset identifier handed to every DataConfiguration below.
DATASET = 'FMNIST'
# Every simulation is loaded and evaluated once per seed in this list.
SEEDS = [4, 13, 27]

In [None]:
from federated_inference.simulations.hybrid.v6.models import GlobalHybridSplitClassifierHead, HybridSplitBase, LocalHybridSplitClassifierHead, RouterHead
VERSION = "v61"
# Per-seed results of the v6 hybrid split model, each dict keyed by seed.
h1_clients_preds = {}    # {seed: per-client local predictions}
h1_clients_routers = {}  # {seed: per-client router outputs}
h1_server_preds = {}     # {seed: server predictions}
h1_targets = {}          # {seed: test labels of the server's first test loader}

for seed in SEEDS:
    set_seed(seed)
    data_config = DataConfiguration(DATASET)
    transform_config = DataTransformConfiguration()
    simulation_hybrid_1 = HybridSplitSimulation(
        seed, VERSION, data_config, transform_config,
        GlobalHybridSplitClassifierHead, HybridSplitBase,
        LocalHybridSplitClassifierHead, RouterHead,
    )
    simulation_hybrid_1.server.load()
    _testsets = [
        client.request_pred(pred_all=True, keep_label=True)
        for client in simulation_hybrid_1.clients
    ]
    # `run_infernces` (sic) is the project's method name.
    h1_clients_pred, h1_clients_router, h1_server_pred = simulation_hybrid_1.server.run_infernces(_testsets)
    h1_clients_preds[seed] = h1_clients_pred
    h1_clients_routers[seed] = h1_clients_router
    h1_server_preds[seed] = h1_server_pred
    h1_targets[seed] = simulation_hybrid_1.server.testloaders[0].dataset.targets

In [None]:
# Sanity check: number of local predictions of client index 0 for seed 4.
len(h1_clients_preds[4][0])

In [None]:
# Sanity check: number of server predictions for seed 4.
len(h1_server_preds[4])

In [None]:

from federated_inference.simulations.hybrid.v7.models import GlobalHybridSplitClassifierHead, HybridSplitBase, LocalHybridSplitClassifierHead, RouterHead
VERSION = "v71"
# Per-seed results of the v7 hybrid split model, each dict keyed by seed.
h2_clients_preds = {}    # {seed: per-client local predictions}
h2_clients_routers = {}  # {seed: per-client router outputs}
h2_server_preds = {}     # {seed: server predictions}
h2_targets = {}          # {seed: test labels of the server's first test loader}

for seed in SEEDS:
    set_seed(seed)
    data_config = DataConfiguration(DATASET)
    transform_config = DataTransformConfiguration()
    simulation_hybrid_2 = HybridSplitSimulation(
        seed, VERSION, data_config, transform_config,
        GlobalHybridSplitClassifierHead, HybridSplitBase,
        LocalHybridSplitClassifierHead, RouterHead,
    )
    simulation_hybrid_2.server.load()
    _testsets = [
        client.request_pred(pred_all=True, keep_label=True)
        for client in simulation_hybrid_2.clients
    ]
    # `run_infernces` (sic) is the project's method name.
    h2_clients_pred, h2_clients_router, h2_server_pred = simulation_hybrid_2.server.run_infernces(_testsets)
    h2_clients_preds[seed] = h2_clients_pred
    h2_clients_routers[seed] = h2_clients_router
    h2_server_preds[seed] = h2_server_pred
    h2_targets[seed] = simulation_hybrid_2.server.testloaders[0].dataset.targets

## On Device

In [None]:
from federated_inference.simulations.ondevice.simulation import OnDeviceVerticalSimulation
from federated_inference.simulations.ondevice.models import OnDeviceMNISTModel
VERSION = "v1"
on_client_preds = {}  # {seed: per-client Local-Only predictions}
for seed in SEEDS:
    # Seed globally like the hybrid loops so every approach starts from the same RNG state.
    set_seed(seed)
    data_config = DataConfiguration(DATASET)
    transform_config = DataTransformConfiguration()
    simulation_on_client = OnDeviceVerticalSimulation(seed, VERSION, data_config, transform_config, OnDeviceMNISTModel, exist=False)
    # Plain loop: load() is called only for its side effect.
    for client in simulation_on_client.clients:
        client.load()
    _testsets = [client.data.test_dataset for client in simulation_on_client.clients]
    on_client_pred = simulation_on_client.run_inference(_testsets)
    on_client_preds[seed] = on_client_pred
    # Notebook-global test labels; rebound by later cells.
    targets = simulation_on_client.clients[0].data.test_dataset.dataset.targets


## On Cloud

In [None]:
from federated_inference.simulations.oncloud.simulation import OnCloudVerticalSimulation
from federated_inference.simulations.oncloud.models.model import OnCloudMNISTModel
VERSION = "v1"
on_cloud_preds = {}  # {seed: Remote-Only server predictions}
for seed in SEEDS:
    # Seed globally like the hybrid loops so every approach starts from the same RNG state.
    set_seed(seed)
    data_config = DataConfiguration(DATASET)
    transform_config = DataTransformConfiguration()
    simulation_on_cloud = OnCloudVerticalSimulation(seed, VERSION, data_config, transform_config, OnCloudMNISTModel, exist=False)
    simulation_on_cloud.server.load()
    _testsets = [client.data.test_dataset for client in simulation_on_cloud.clients]
    on_cloud_pred = simulation_on_cloud.run_inference(_testsets)
    on_cloud_preds[seed] = on_cloud_pred
    # Notebook-global test labels; rebound by later cells.
    targets = simulation_on_cloud.clients[0].data.test_dataset.dataset.targets

# Evaluation

In [None]:
import plotly.graph_objects as go
from sklearn.metrics import accuracy_score
import plotly.express as px
import itertools
import numpy as np

client_map = ['LT', 'RT', 'LB', "RB"]
sub_labels = ["V1", "V2", "Local-Only"]
server_labels = ["Remote V1", "Remote V2", "Remote-Only"]

# Ground-truth test labels for every score in this cell. They are taken from the
# on-cloud simulation, which is the value the global `targets` holds right after the
# loading cells. Reading the global directly broke re-runs, because later cells
# rebind `targets` (e.g. to a 4x-concatenated array).
eval_targets = np.asarray(simulation_on_cloud.clients[0].data.test_dataset.dataset.targets)


def compute_client_accs(client_preds_dict, y_true=None, client_ids=range(4)):
    """Returns mean and std accuracy for each client across seeds.

    Args:
        client_preds_dict: ``{seed: {client_id: predictions}}``.
        y_true: ground-truth labels; falls back to the global ``targets`` when omitted.
        client_ids: client ids to score, in plot order.

    Returns:
        ``(client_ids, means, stds)``; a client absent from every seed gets ``None``.
    """
    if y_true is None:
        y_true = targets
    means = []
    stds = []
    for client in client_ids:
        accs = [
            accuracy_score(y_true, seed_preds[client])
            for seed_preds in client_preds_dict.values()
            if client in seed_preds
        ]
        if accs:
            means.append(np.mean(accs))
            stds.append(np.std(accs))
        else:
            means.append(None)
            stds.append(None)
    return client_ids, means, stds


def compute_server_accs(server_preds_dict, y_true=None):
    """Returns mean and std accuracy for server predictions ``{seed: preds}``.

    ``y_true`` falls back to the global ``targets`` when omitted.
    """
    if y_true is None:
        y_true = targets
    accs = [accuracy_score(y_true, preds) for preds in server_preds_dict.values()]
    return np.mean(accs), np.std(accs)


all_client_ids = sorted({
    client_id
    for preds_dict in (h1_clients_preds, h2_clients_preds, on_client_preds)
    for seed_preds in preds_dict.values()
    for client_id in seed_preds.keys()
})
n_clients = len(all_client_ids)

# Score exactly the clients that receive x positions below, so bars and values line up.
_, h1_mean, h1_std = compute_client_accs(h1_clients_preds, eval_targets, all_client_ids)
_, h2_mean, h2_std = compute_client_accs(h2_clients_preds, eval_targets, all_client_ids)
_, local_mean, local_std = compute_client_accs(on_client_preds, eval_targets, all_client_ids)

# Three bars per client (V1, V2, Local-Only), followed by the three server-side bars.
h1_x = [i * 3 for i in range(n_clients)]
h2_x = [i * 3 + 1 for i in range(n_clients)]
local_x = [i * 3 + 2 for i in range(n_clients)]

server_means = []
server_stds = []
for preds in [h1_server_preds, h2_server_preds, on_cloud_preds]:
    mean, std = compute_server_accs(preds, eval_targets)
    server_means.append(mean)
    server_stds.append(std)
server_x = [n_clients * 3 + i for i in range(3)]

tickvals = [i * 3 + 1 for i in range(n_clients)] + server_x
ticktext = client_map + server_labels

# Sub-labels (V1 / V2 / Local-Only) drawn under each client's group of bars.
annotations = [
    dict(
        x=i * 3 + j,
        y=-0.1,
        xref='x',
        yref='paper',
        text=sublabel,
        showarrow=False,
        font=dict(size=12),
        yshift=-20,
    )
    for i in range(n_clients)
    for j, sublabel in enumerate(sub_labels)
]

fig = go.Figure()

for x_vals, y_vals, err_vals, trace_name, color in [
    (h1_x, h1_mean, h1_std, "HF-NN V1", "blue"),
    (h2_x, h2_mean, h2_std, "HF-NN V2", "green"),
    (local_x, local_mean, local_std, "Local-Only", "orange"),
]:
    fig.add_trace(go.Bar(
        x=x_vals,
        y=y_vals,
        error_y=dict(type='data', array=err_vals, visible=True),
        name=trace_name,
        marker_color=color,
    ))

# One single-bar trace per server-side approach. Each error bar must carry that
# approach's own std; passing the whole `server_stds` list made every trace show
# server_stds[0].
for x_pos, mean, std, trace_name, color in zip(
    server_x,
    server_means,
    server_stds,
    ["HF-NN V1 - Remote", "HF-NN V2 - Remote", "Remote Only"],
    ["blue", "green", "purple"],
):
    fig.add_trace(go.Bar(
        x=[x_pos],
        y=[mean],
        error_y=dict(type='data', array=[std], visible=True),
        name=trace_name,
        marker_color=color,
        text=[f"{mean:.2%}"],
        textposition="auto",
    ))

fig.update_layout(
    #title="Accuracy Comparison (Mean ± Std over Seeds)",
    xaxis=dict(
        title="Client ID / Server",
        tickmode='array',
        tickvals=tickvals,
        ticktext=ticktext,
        tickangle=-80,
    ),
    yaxis_title="Accuracy",
    yaxis_tickformat=".0%",
    barmode="group",
    height=500,
    margin=dict(l=40, r=40, t=60, b=140),
    annotations=annotations,
    showlegend=False,
    font=dict(size=14),
)

fig.show()


In [None]:
# Per-client mean / std accuracy of hybrid V1 across seeds.
h1_mean, h1_std

In [None]:
# Average of the per-client hybrid V1 mean accuracies.
np.mean(h1_mean)

In [None]:
# Per-client mean / std accuracy of hybrid V2 across seeds.
h2_mean, h2_std

In [None]:
# Average of the per-client hybrid V2 mean accuracies.
np.mean(h2_mean)

In [None]:
# Per-client mean / std accuracy of the Local-Only models across seeds.
local_mean, local_std

In [None]:
# Average of the per-client Local-Only mean accuracies.
np.mean(local_mean)

In [None]:
# Server-side mean / std accuracy, ordered [hybrid V1, hybrid V2, on-cloud].
server_means, server_stds

In [None]:
# Spread between the best and the worst server-side approach.
max(server_means)-min(server_means)

In [None]:
# Server accuracy gain of hybrid V2 over hybrid V1.
server_means[1]-server_means[0]

In [None]:
# Best and worst server-side mean accuracy.
max(server_means), min(server_means)

In [None]:
# Average and largest across-seed std of the Local-Only client accuracies.
np.mean(local_std), np.max(local_std)

In [None]:
import plotly.graph_objects as go
from sklearn.metrics import f1_score
import numpy as np

# Per-class F1 comparison: one figure per class, same layout as the accuracy figure.
N_CLASSES = 10  # FMNIST has 10 classes
client_map = ['LT', 'RT', 'LB', "RB"]
sub_labels = ["V1", "V2", "Local-Only"]
server_labels = ["Remote V1", "Remote V2", "Remote-Only"]

# Ground-truth test labels for every score in this cell. They are taken from the
# on-cloud simulation, which is the value the global `targets` holds right after the
# loading cells. Later cells rebind `targets`, so reading the global broke re-runs.
f1_targets = np.asarray(simulation_on_cloud.clients[0].data.test_dataset.dataset.targets)


def compute_client_f1s(client_preds_dict, target_class, y_true=None, client_ids=range(4)):
    """Returns mean and std F1 score for each client across seeds for a specific class.

    Args:
        client_preds_dict: ``{seed: {client_id: predictions}}``.
        target_class: the single class whose F1 score is computed.
        y_true: ground-truth labels; falls back to the global ``targets`` when omitted.
        client_ids: client ids to score, in plot order.

    Returns:
        ``(client_ids, means, stds)``; a client absent from every seed gets ``None``.
    """
    if y_true is None:
        y_true = targets
    means = []
    stds = []
    for client in client_ids:
        f1s = [
            f1_score(y_true, seed_preds[client],
                     labels=[target_class], average='macro', zero_division=0)
            for seed_preds in client_preds_dict.values()
            if client in seed_preds
        ]
        if f1s:
            means.append(np.mean(f1s))
            stds.append(np.std(f1s))
        else:
            means.append(None)
            stds.append(None)
    return client_ids, means, stds


def compute_server_f1s(server_preds_dict, target_class, y_true=None):
    """Returns mean and std single-class F1 of server predictions ``{seed: preds}``.

    ``y_true`` falls back to the global ``targets`` when omitted.
    """
    if y_true is None:
        y_true = targets
    f1s = [f1_score(y_true, preds, labels=[target_class], average='macro', zero_division=0)
           for preds in server_preds_dict.values()]
    return np.mean(f1s), np.std(f1s)


# Layout pieces do not depend on the class, so they are built once.
all_client_ids = sorted({
    client_id
    for preds_dict in (h1_clients_preds, h2_clients_preds, on_client_preds)
    for seed_preds in preds_dict.values()
    for client_id in seed_preds.keys()
})
n_clients = len(all_client_ids)

h1_x = [k * 3 for k in range(n_clients)]
h2_x = [k * 3 + 1 for k in range(n_clients)]
local_x = [k * 3 + 2 for k in range(n_clients)]
server_x = [n_clients * 3 + k for k in range(3)]

tickvals = [k * 3 + 1 for k in range(n_clients)] + server_x
ticktext = client_map + server_labels

annotations = [
    dict(
        x=k * 3 + j,
        y=-0.1,
        xref='x',
        yref='paper',
        text=sublabel,
        showarrow=False,
        font=dict(size=12),
        yshift=-20,
    )
    for k in range(n_clients)
    for j, sublabel in enumerate(sub_labels)
]

server_trace_specs = [
    ("HF-NN V1 - Remote", "blue"),
    ("HF-NN V2 - Remote", "green"),
    ("Remote Only", "purple"),
]

std_deviations_differences = []
for target_class in range(N_CLASSES):
    _, f1_h1_mean, f1_h1_std = compute_client_f1s(h1_clients_preds, target_class, f1_targets, all_client_ids)
    _, f1_h2_mean, f1_h2_std = compute_client_f1s(h2_clients_preds, target_class, f1_targets, all_client_ids)
    _, f1_local_mean, f1_local_std = compute_client_f1s(on_client_preds, target_class, f1_targets, all_client_ids)
    # How much more each hybrid's per-client F1 varies across seeds than Local-Only's.
    std_deviations_differences.append(np.array(f1_h1_std) - np.array(f1_local_std))
    std_deviations_differences.append(np.array(f1_h2_std) - np.array(f1_local_std))

    f1_server_means = []
    f1_server_stds = []
    for preds in [h1_server_preds, h2_server_preds, on_cloud_preds]:
        mean, std = compute_server_f1s(preds, target_class, f1_targets)
        f1_server_means.append(mean)
        f1_server_stds.append(std)

    fig = go.Figure()

    for x_vals, y_vals, err_vals, trace_name, color in [
        (h1_x, f1_h1_mean, f1_h1_std, "HF-NN V1", "blue"),
        (h2_x, f1_h2_mean, f1_h2_std, "HF-NN V2", "green"),
        (local_x, f1_local_mean, f1_local_std, "Local-Only", "orange"),
    ]:
        fig.add_trace(go.Bar(
            x=x_vals,
            y=y_vals,
            error_y=dict(type='data', array=err_vals, visible=True),
            name=trace_name,
            marker_color=color,
        ))

    # Server bars: each carries its own F1 std and F1 label. The original reused the
    # accuracy lists `server_stds` / `server_means` here, always at index 0 for the error bar.
    for x_pos, f1_mean, f1_std, (trace_name, color) in zip(
        server_x, f1_server_means, f1_server_stds, server_trace_specs
    ):
        fig.add_trace(go.Bar(
            x=[x_pos],
            y=[f1_mean],
            error_y=dict(type='data', array=[f1_std], visible=True),
            name=trace_name,
            marker_color=color,
            text=[f"{f1_mean:.2%}"],
            textposition="auto",
        ))

    fig.update_layout(
        #title=f"F1-Score for Class {target_class} (Mean ± Std over Seeds)",
        xaxis=dict(
            title="Client ID / Server",
            tickmode='array',
            tickvals=tickvals,
            ticktext=ticktext,
            tickangle=-80,
        ),
        yaxis_title="F1 Score",
        yaxis_tickformat=".0%",
        barmode="group",
        height=500,
        margin=dict(l=40, r=40, t=60, b=140),
        annotations=annotations,
        showlegend=False,
        font=dict(size=14),
    )

    fig.show()


In [None]:
# Largest amount by which a hybrid model's per-client F1 std exceeds Local-Only's, over all classes.
max(np.array(std_deviations_differences).flatten())

In [None]:
# Hybrid V1 per-client F1 mean / std for the last class of the loop above (class 9).
f1_h1_mean, f1_h1_std

In [None]:
# Hybrid V2 per-client F1 mean / std for the last class of the loop above (class 9).
f1_h2_mean, f1_h2_std

In [None]:
# Local-Only per-client F1 mean / std for the last class of the loop above (class 9).
f1_local_mean, f1_local_std

In [None]:
# Server-side F1 mean / std [hybrid V1, hybrid V2, on-cloud] for the last class (class 9).
f1_server_means, f1_server_stds

In [None]:
# Per-seed accuracy of hybrid V1, client index 3 (RB), scored against the V1 labels.
# accuracy_score expects (y_true, y_pred).
for s in h1_clients_preds.keys():
    print(accuracy_score(h1_targets[s], h1_clients_preds[s][3]))

In [None]:
# Per-seed accuracy of the Local-Only model, client index 3 (RB).
# NOTE(review): the on-device loop does not store per-seed labels, so the hybrid V2
# labels are used; assumes both simulations share the same test-set order.
for s in on_client_preds.keys():
    print(accuracy_score(h2_targets[s], on_client_preds[s][3]))

In [None]:
# Per-seed server accuracy of hybrid V1, scored against the V1 labels (y_true first).
for s in h1_server_preds.keys():
    print(accuracy_score(h1_targets[s], h1_server_preds[s]))

In [None]:
# Per-seed server accuracy of hybrid V2 (accuracy_score expects y_true first).
for s in h2_server_preds.keys():
    print(accuracy_score(h2_targets[s], h2_server_preds[s]))

In [None]:
import plotly.graph_objects as go
import numpy as np

bins = 30
target_filter = [1, 4, 6]  # list(range(10)) to show every class
classes = np.unique(h2_targets[4])
logits = np.asarray(h2_clients_routers[4][1])  # router outputs of client index 1 (RT), seed 4
# asarray: comparing a plain list with `cls` yields a scalar, not an element-wise mask.
targets = np.asarray(h2_targets[4])

if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

# One set of bin edges shared by all selected classes so the overlaid bars line up.
# The original binned each class over its own range and never used `bin_edges`.
bin_edges = np.histogram_bin_edges(logits[np.isin(targets, target_filter)], bins=bins)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

hist_data = []
for cls in classes:
    if cls not in target_filter:
        continue
    counts, _ = np.histogram(logits[targets == cls], bins=bin_edges)
    hist_data.append((cls, bin_centers, counts))

fig = go.Figure()

for cls, bin_centers, counts in hist_data:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls}',
        opacity=0.6
    ))

fig.update_layout(
    title=f'Histogram of Router V2 Client RT by Class',
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white'
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

bins = 30
target_filter = list(range(10))
classes = np.unique(h2_targets[4])
# Pool every client's router outputs for seed 4. Each client's block is paired with a
# copy of the shared test labels, so the two arrays stay aligned sample by sample.
# np.asarray/concatenate also work when the labels are an array or tensor rather than a list.
router_outputs = h2_clients_routers[4]
logits = np.concatenate([np.asarray(router_outputs[key]) for key in router_outputs])
targets = np.concatenate([np.asarray(h2_targets[4]) for _ in router_outputs])

if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

# One set of bin edges shared by all selected classes so the overlaid bars line up.
bin_edges = np.histogram_bin_edges(logits[np.isin(targets, target_filter)], bins=bins)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

hist_data = []
for cls in classes:
    if cls not in target_filter:
        continue
    counts, _ = np.histogram(logits[targets == cls], bins=bin_edges)
    hist_data.append((cls, bin_centers, counts))

fig = go.Figure()

for cls, bin_centers, counts in hist_data:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls}',
        opacity=0.6
    ))

fig.update_layout(
    title=f'Histogram of Router V2 by Class',
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white'
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

bins = 30
target_filter = list(range(10))
classes = np.unique(h2_targets[4])
logits = np.asarray(h2_clients_routers[4][0])  # router outputs of client index 0 (LT), seed 4
targets = np.asarray(h2_targets[4])

if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

# One set of bin edges shared by all selected classes so the overlaid bars line up.
bin_edges = np.histogram_bin_edges(logits[np.isin(targets, target_filter)], bins=bins)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

hist_data = []
for cls in classes:
    if cls not in target_filter:
        continue
    counts, _ = np.histogram(logits[targets == cls], bins=bin_edges)
    hist_data.append((cls, bin_centers, counts))

fig = go.Figure()

for cls, bin_centers, counts in hist_data:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls}',
        opacity=0.6
    ))

# Index 0 is client LT (the original title said RT).
fig.update_layout(
    title=f'Histogram of Router V2 Client LT by Class',
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white'
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

# NOTE(review): this cell plots the same data (seed 4, client index 0) as the previous cell.
bins = 30
target_filter = list(range(10))
classes = np.unique(h2_targets[4])
logits = np.asarray(h2_clients_routers[4][0])  # router outputs of client index 0 (LT), seed 4
targets = np.asarray(h2_targets[4])

if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

# One set of bin edges shared by all selected classes so the overlaid bars line up.
bin_edges = np.histogram_bin_edges(logits[np.isin(targets, target_filter)], bins=bins)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

hist_data = []
for cls in classes:
    if cls not in target_filter:
        continue
    counts, _ = np.histogram(logits[targets == cls], bins=bin_edges)
    hist_data.append((cls, bin_centers, counts))

fig = go.Figure()

for cls, bin_centers, counts in hist_data:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls}',
        opacity=0.6
    ))

# Index 0 is client LT (the original title said RT).
fig.update_layout(
    title=f'Histogram of Router V2 Client LT by Class',
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white'
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

# Data and settings
bins = 30
target_filter = [1, 4, 6]  # list(range(10)) to show every class
classes = np.unique(h1_targets[4])
logits = np.asarray(h1_clients_routers[4][1])  # router outputs of client index 1 (RT), seed 4
targets = np.asarray(h1_targets[4])

# Ensure target_filter is a list
if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

# One set of bin edges shared by all selected classes so the overlaid bars line up.
bin_edges = np.histogram_bin_edges(logits[np.isin(targets, target_filter)], bins=bins)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

# Compute histogram data for filtered classes
hist_data = []
for cls in classes:
    if cls not in target_filter:
        continue
    counts, _ = np.histogram(logits[targets == cls], bins=bin_edges)
    hist_data.append((cls, bin_centers, counts))

# Create Plotly figure
fig = go.Figure()

for cls, bin_centers, counts in hist_data:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls}',
        opacity=0.6
    ))

# Update layout
fig.update_layout(
    #title=f'Histogram of Router V1 Client RT by Class',
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white',
    font=dict(size=14)
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

bins = 30
target_filter = list(range(10))
classes = np.unique(h1_targets[4])
# Pool every client's router outputs for seed 4. Each client's block is paired with a
# copy of the shared test labels, so the two arrays stay aligned sample by sample.
router_outputs = h1_clients_routers[4]
logits = np.concatenate([np.asarray(router_outputs[key]) for key in router_outputs])
targets = np.concatenate([np.asarray(h1_targets[4]) for _ in router_outputs])

if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

# One set of bin edges shared by all selected classes so the overlaid bars line up.
bin_edges = np.histogram_bin_edges(logits[np.isin(targets, target_filter)], bins=bins)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

hist_data = []
for cls in classes:
    if cls not in target_filter:
        continue
    counts, _ = np.histogram(logits[targets == cls], bins=bin_edges)
    hist_data.append((cls, bin_centers, counts))

fig = go.Figure()

for cls, bin_centers, counts in hist_data:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls}',
        opacity=0.6
    ))

fig.update_layout(
    title=f'Histogram of Router V1 by Class',
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white'
)

fig.show()


In [None]:
SEED = 4
thresholds = np.linspace(1, 0, 100)
# Index with SEED explicitly. The original used lowercase `seed`, i.e. the loop variable
# left over from the loading cells (always SEEDS[-1]), so SEED had no effect.
logits = h2_clients_routers[SEED]  # per-client router outputs (hybrid V2)
preds = h2_clients_preds[SEED]     # per-client local predictions (hybrid V2)
server = h2_server_preds[SEED]
targets = np.array(h2_targets[SEED])
def calculate_hybrid_accuracy(logits, preds, server, targets, thresholds):
    """Accuracy and cloud coverage of router-based local/cloud hand-off, per client and threshold.

    For each threshold ``t`` and client ``key``, a sample is sent to the cloud when that
    client's router output is strictly greater than ``t``. Cloud-routed samples are
    scored with the server prediction; the rest are scored with the client's own
    local prediction.

    Args:
        logits: ``{client_key: router outputs}``, one value per test sample.
        preds: ``{client_key: local predictions}``, one per test sample.
        server: server predictions, one per test sample.
        targets: ground-truth labels.
        thresholds: iterable of routing thresholds.

    Returns:
        ``(hybrid_accuracy, coverages)``. Both are ``{threshold rounded to 4 d.p.: {client_key: value}}``;
        coverage is the fraction of samples routed to the cloud.
    """
    targets = np.asarray(targets)
    n_samples = len(targets)
    server_correct = np.asarray(server) == targets
    # Threshold-independent per-client data, computed once instead of per threshold.
    # ravel() guards against (N, 1) router outputs, which would otherwise broadcast to (N, N).
    routers = {key: np.asarray(logits[key]).ravel() for key in preds.keys()}
    local_correct = {key: np.asarray(preds[key]) == targets for key in preds.keys()}

    hybrid_accuracy = {}
    coverages = {}
    for threshold in thresholds:
        t_key = np.float64(round(threshold, 4))
        hybrid_accuracy[t_key] = {}
        coverages[t_key] = {}
        for key in preds.keys():
            # Route with this client's own router; the original always used logits[0].
            cloud_routing_decision = routers[key] > threshold
            coverages[t_key][key] = cloud_routing_decision.sum() / n_samples
            correct = np.where(cloud_routing_decision, server_correct, local_correct[key])
            hybrid_accuracy[t_key][key] = correct.sum() / n_samples
    return hybrid_accuracy, coverages

# Accuracy / coverage grid for every client at every threshold.
hybrid_accuracy, coverages = calculate_hybrid_accuracy(
    logits, preds, server, targets, thresholds=thresholds
)

In [None]:
clientmapper = ['LT', 'RT', 'LB', 'RB']
clients = list(preds.keys())
sorted_thresholds = sorted(hybrid_accuracy.keys())

fig = go.Figure()

# One accuracy-vs-coverage curve per client, traced over all router thresholds.
for client in clients:
    coverage_vals = [coverages[t][client] for t in sorted_thresholds]
    accuracies = [hybrid_accuracy[t][client] for t in sorted_thresholds]

    fig.add_trace(go.Scatter(
        x=coverage_vals,
        y=accuracies,
        mode='lines+markers',
        name=f'Client {clientmapper[client]}'
    ))

fig.update_layout(
    # Curves come from the h2_* results (hybrid V2); the old title comment said V1.
    # title='Hybrid V2 - Hybrid Accuracy vs Coverage by Clients',
    xaxis_title='Coverage',
    yaxis_title='Hybrid Accuracy',
    xaxis=dict(range=[0, 1]),
    yaxis=dict(range=[0, 1]),
    template='plotly_white',
    legend_title="Clients",
    font=dict(size=14)
)

# Reference line: best per-client Local-Only accuracy (mean over seeds) from the evaluation cell.
max_local = max(local_mean)
fig.add_hline(
    y=max_local,
    line=dict(color='Red', dash='dash'),
    annotation_text=f'Max Local-Only Accuracy = {max_local:.3f}',
    annotation_position='top right'
)

fig.show()

In [None]:
SEED = 13
thresholds = np.linspace(1, 0, 100)
# Index with SEED explicitly. The original used lowercase `seed`, i.e. the loop variable
# left over from the loading cells (always SEEDS[-1]), so SEED had no effect.
logits = h1_clients_routers[SEED]  # per-client router outputs (hybrid V1)
preds = h1_clients_preds[SEED]     # per-client local predictions (hybrid V1)
server = h1_server_preds[SEED]
targets = np.array(h1_targets[SEED])

hybrid_accuracy, coverages = calculate_hybrid_accuracy(logits, preds, server, targets, thresholds)

In [None]:
clientmapper = ['LT', 'RT', 'LB', 'RB']
clients = list(preds.keys())
sorted_thresholds = sorted(hybrid_accuracy.keys())

fig = go.Figure()

# One accuracy-vs-coverage curve per client, traced over all router thresholds.
for client in clients:
    coverage_vals = [coverages[t][client] for t in sorted_thresholds]
    accuracies = [hybrid_accuracy[t][client] for t in sorted_thresholds]

    fig.add_trace(go.Scatter(
        x=coverage_vals,
        y=accuracies,
        mode='lines+markers',
        name=f'Client {clientmapper[client]}'
    ))

fig.update_layout(
    # Curves come from the h1_* results (hybrid V1); the old title comment said V2.
    # title='Hybrid V1 - Hybrid Accuracy vs Coverage by Clients',
    xaxis_title='Coverage',
    yaxis_title='Hybrid Accuracy',
    xaxis=dict(range=[0, 1]),
    yaxis=dict(range=[0, 1]),
    template='plotly_white',
    legend_title="Clients",
    font=dict(size=14)
)

# Reference line: best per-client Local-Only accuracy (mean over seeds) from the evaluation cell.
max_local = max(local_mean)
fig.add_hline(
    y=max_local,
    line=dict(color='Red', dash='dash'),
    annotation_text=f'Max Local-Only Accuracy = {max_local:.3f}',
    annotation_position='top right'
)

fig.show()

In [None]:
import plotly.graph_objects as go
import numpy as np

bins = 30
target_filter = list(range(10))  # Filter specific classes
classes = np.unique(h1_targets[27])  # Available classes
logits = np.array(h1_clients_routers[27][0])  # Router outputs of client index 0 (LT)
targets = np.array(h1_targets[27])  # Ground truth labels
preds = np.array(h1_clients_preds[27][0])  # Client predictions

if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

# Shared bin edges over every selected sample. The original binned each class on the
# range of its *correct* samples only and reused those edges for the incorrect ones.
# np.histogram ignores values outside the edges, so incorrect samples outside that
# range were silently dropped.
bin_edges = np.histogram_bin_edges(logits[np.isin(targets, target_filter)], bins=bins)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

hist_data_correct = []
hist_data_incorrect = []

for cls in classes:
    if cls not in target_filter:
        continue

    mask = targets == cls
    class_logits = logits[mask]
    class_preds = preds[mask]
    class_targets = targets[mask]

    # Correct and incorrect masks
    correct_mask = class_preds == class_targets
    incorrect_mask = ~correct_mask

    counts_correct, _ = np.histogram(class_logits[correct_mask], bins=bin_edges)
    hist_data_correct.append((cls, bin_centers, counts_correct))

    counts_incorrect, _ = np.histogram(class_logits[incorrect_mask], bins=bin_edges)
    hist_data_incorrect.append((cls, bin_centers, counts_incorrect))

# Plotting
fig = go.Figure()

for cls, bin_centers, counts in hist_data_correct:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls} (Correct)',
        opacity=0.6
    ))

for cls, bin_centers, counts in hist_data_incorrect:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls} (Incorrect)',
        opacity=0.6,
        marker=dict(pattern_shape="/")
    ))

# Index 0 is client LT (the original title said LB).
fig.update_layout(
    title='Histogram of Router V1 Client LT by Class (Correct vs Incorrect)',
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white',
    legend_title='Legend'
)

fig.show()


In [None]:
import plotly.graph_objects as go
import numpy as np

bins = 30
target_filter = list(range(10))  # Filter specific classes
classes = np.unique(h1_targets[27])  # Available classes
logits = np.array(h1_clients_routers[27][3])  # Router outputs of client index 3 (RB)
targets = np.array(h1_targets[27])  # Ground truth labels
preds = np.array(h1_clients_preds[27][3])  # Client predictions

if not isinstance(target_filter, (list, np.ndarray)):
    target_filter = [target_filter]

# Shared bin edges over every selected sample. The original binned each class on the
# range of its *correct* samples only and reused those edges for the incorrect ones.
# np.histogram ignores values outside the edges, so incorrect samples outside that
# range were silently dropped.
bin_edges = np.histogram_bin_edges(logits[np.isin(targets, target_filter)], bins=bins)
bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

hist_data_correct = []
hist_data_incorrect = []

for cls in classes:
    if cls not in target_filter:
        continue

    mask = targets == cls
    class_logits = logits[mask]
    class_preds = preds[mask]
    class_targets = targets[mask]

    # Correct and incorrect masks
    correct_mask = class_preds == class_targets
    incorrect_mask = ~correct_mask

    counts_correct, _ = np.histogram(class_logits[correct_mask], bins=bin_edges)
    hist_data_correct.append((cls, bin_centers, counts_correct))

    counts_incorrect, _ = np.histogram(class_logits[incorrect_mask], bins=bin_edges)
    hist_data_incorrect.append((cls, bin_centers, counts_incorrect))

# Plotting
fig = go.Figure()

for cls, bin_centers, counts in hist_data_correct:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls} (Correct)',
        opacity=0.6
    ))

for cls, bin_centers, counts in hist_data_incorrect:
    fig.add_trace(go.Bar(
        x=bin_centers,
        y=counts,
        name=f'Class {cls} (Incorrect)',
        opacity=0.6,
        marker=dict(pattern_shape="/")
    ))

# Index 3 is client RB (the original title said LB).
fig.update_layout(
    title='Histogram of Router V1 Client RB by Class (Correct vs Incorrect)',
    xaxis_title='Logit Value',
    yaxis_title='Frequency',
    barmode='overlay',
    template='plotly_white',
    legend_title='Legend'
)

fig.show()


In [None]:
import collections
dataset = simulation_on_client.dataset.train_dataset
# Tally how often each label occurs in the training split (labels in first-seen order).
labels = [lbl for _, lbl in dataset]
counter = collections.Counter(labels)

keys, values = list(counter), list(counter.values())

label_layout = go.Layout(
    title="Label Count",
    xaxis_title="Labels",
    yaxis_title="Amount",
    template="plotly_dark",
)
fig = go.Figure(
    data=[go.Bar(x=keys, y=values, marker_color='skyblue')],
    layout=label_layout,
)
fig.show()