In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from federated_inference.common.utils import set_seed
from federated_inference.configs.data_config import DataConfiguration
from federated_inference.configs.transform_config import DataTransformConfiguration

# Hybrid

In [None]:

from federated_inference.simulations.hybrid.simulation import HybridSplitSimulation

In [None]:
from federated_inference.simulations.hybrid.v6.models import GlobalHybridSplitClassifierHead, HybridSplitBase, LocalHybridSplitClassifierHead, RouterHead
# Hybrid split simulation on FMNIST using the v6 model heads; afterwards the
# server's `load()` is called (presumably restoring saved weights — confirm).
DATASET, VERSION, SEED = 'FMNIST', "v61", 4
set_seed(SEED)
data_config, transform_config = DataConfiguration(DATASET), DataTransformConfiguration()
simulation_hybrid_1 = HybridSplitSimulation(
    SEED,
    VERSION,
    data_config,
    transform_config,
    GlobalHybridSplitClassifierHead,
    HybridSplitBase,
    LocalHybridSplitClassifierHead,
    RouterHead,
)
simulation_hybrid_1.server.load()

In [None]:

from federated_inference.simulations.hybrid.v7.models import GlobalHybridSplitClassifierHead, HybridSplitBase, LocalHybridSplitClassifierHead, RouterHead
# Hybrid split simulation on MNIST using the v7 model heads; afterwards the
# server's `load()` is called (presumably restoring saved weights — confirm).
DATASET, VERSION, SEED = 'MNIST', "v71", 4
set_seed(SEED)
data_config, transform_config = DataConfiguration(DATASET), DataTransformConfiguration()
simulation_hybrid_2 = HybridSplitSimulation(
    SEED,
    VERSION,
    data_config,
    transform_config,
    GlobalHybridSplitClassifierHead,
    HybridSplitBase,
    LocalHybridSplitClassifierHead,
    RouterHead,
)
simulation_hybrid_2.server.load()

## On Device

In [None]:
from federated_inference.simulations.ondevice.simulation import OnDeviceVerticalSimulation
from federated_inference.simulations.ondevice.models import OnDeviceMNISTModel
# On-device vertical baseline on MNIST.
DATASET = 'MNIST'
VERSION = "v1"
SEED = 4
# Re-seed like the hybrid cells do, so this cell does not depend on the RNG
# state left behind by whichever cells happened to run before it.
set_seed(SEED)
data_config = DataConfiguration(DATASET)
transform_config = DataTransformConfiguration()
simulation_on_client = OnDeviceVerticalSimulation(SEED, VERSION, data_config, transform_config, OnDeviceMNISTModel, exist=False)
# Call load() on every client; a plain loop avoids building (and echoing as
# cell output) a throwaway list of return values.
for client in simulation_on_client.clients:
    client.load()


## On Cloud

In [None]:
from federated_inference.simulations.oncloud.simulation import OnCloudVerticalSimulation
from federated_inference.simulations.oncloud.models.model import OnCloudMNISTModel
# On-cloud vertical baseline on MNIST.
DATASET = 'MNIST'
VERSION = "v1"
SEED = 4
# Re-seed like the hybrid cells do, so this cell does not depend on the RNG
# state left behind by whichever cells happened to run before it.
set_seed(SEED)
data_config = DataConfiguration(DATASET)
transform_config = DataTransformConfiguration()
simulation_on_cloud = OnCloudVerticalSimulation(SEED, VERSION, data_config, transform_config, OnCloudMNISTModel, exist=False)
simulation_on_cloud.server.load()

# Visualization

In [None]:

import math
import plotly.graph_objects as go
import plotly.subplots as sp
from IPython.display import display

from federated_inference.dataset.client import ClientDataset
from federated_inference.simulations.utils import tensor_to_numpy_image


client_map = ['LT', 'RT', 'LB', "RB"]


def create_simulation_image_subplots(simulation, idx=0):
    """Plot sample ``idx`` as seen by every client of ``simulation`` in one grid.

    Parameters
    ----------
    simulation:
        Object exposing ``client_datasets``. Entries may be ``ClientDataset``
        instances (their ``train_dataset`` is used) or datasets that are
        directly indexable and return ``(image, label)`` pairs.
    idx:
        Index of the sample to display for every client.

    Returns
    -------
    A plotly figure with one image subplot per client. Each subplot is titled
    with the client's name from ``client_map``, or with its numeric index when
    there are more clients than named entries.

    Raises
    ------
    ValueError
        If the simulation has no client datasets.
    """
    datasets = list(simulation.client_datasets)
    if not datasets:
        raise ValueError("simulation has no client datasets to plot")
    if isinstance(datasets[0], ClientDataset):
        datasets = [d.train_dataset for d in datasets]

    n_clients = len(datasets)
    # More than three clients: n_clients // 2 columns (four clients -> 2x2);
    # otherwise stack the clients in a single column.
    cols = n_clients // 2 if n_clients > 3 else 1
    rows = math.ceil(n_clients / cols)
    # Fall back to the numeric index for clients beyond the names in client_map
    # instead of failing with an IndexError.
    titles = [
        f"Client {client_map[i]}" if i < len(client_map) else f"Client {i}"
        for i in range(n_clients)
    ]

    fig = sp.make_subplots(rows=rows, cols=cols,
                           subplot_titles=titles,
                           vertical_spacing=0.08,
                           horizontal_spacing=0.01)

    for i, dataset in enumerate(datasets):
        r, c = divmod(i, cols)
        img, _ = dataset[idx]
        fig.add_trace(go.Image(z=tensor_to_numpy_image(img)), row=r + 1, col=c + 1)

    fig.update_layout(height=250 * rows, width=250 * cols, showlegend=False)
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)

    return fig

In [None]:
# Sample 0 as seen by every client of the first hybrid simulation.
fig = create_simulation_image_subplots(simulation_hybrid_1, idx=0)
display(fig)

In [None]:
from federated_inference.transform.data_splitter import DataSplitter
import random
# Client index -> short client name used in plot titles (presumably the image
# quadrant each client sees: left/right, top/bottom — confirm).
client_mapping = dict(enumerate(("LT", "RT", "LB", "RB")))


def create_client_image_subplots(simulation, client_indices=(0,), n_images=3, keys=None, seed=None):
    """Plot random training samples per label, as seen by selected clients.

    For every label in ``keys`` that has at least ``n_images`` training samples,
    ``n_images`` sample indices are drawn once. They are then shown for every
    client in ``client_indices``, so the rows for different clients show the same
    samples. Labels are grouped using client 0's training dataset
    (``simulation.clients[0].data.train_dataset``). That index list is then
    applied to every client's dataset, which assumes all clients share sample
    indices and labels — TODO confirm for non-vertical splits.

    Parameters
    ----------
    simulation:
        Object exposing ``client_datasets`` and ``clients``.
    client_indices:
        Indices into ``simulation.client_datasets``; one row per client per label.
    n_images:
        Number of samples (columns) shown per label; must be at least 1.
    keys:
        Labels to plot, in row order. ``None`` plots every label present in the
        grouped training set, in sorted order.
    seed:
        Optional seed for a private ``random.Random``. ``None`` draws from the
        global ``random`` module state, as before.

    Returns
    -------
    (fig, k_indices):
        The plotly figure, and a dict mapping each plotted label to the list of
        sampled indices.

    Raises
    ------
    ValueError
        If ``n_images`` < 1 or no requested label has at least ``n_images`` samples.
    """
    if n_images < 1:
        raise ValueError(f"n_images must be at least 1, got {n_images}")

    datasets = simulation.client_datasets
    if isinstance(datasets[0], ClientDataset):
        datasets = [d.train_dataset for d in datasets]

    grouped_trainset = DataSplitter.group_dataset(simulation.clients[0].data.train_dataset)
    # Materialize once so generators/ranges can be iterated repeatedly.
    keys = sorted(grouped_trainset) if keys is None else list(keys)

    # Only labels with enough samples get rows, so skipped labels leave no blank rows.
    plottable = [key for key in keys if len(grouped_trainset.get(key, [])) >= n_images]
    if not plottable:
        raise ValueError(f"No label in {keys} has at least {n_images} samples to plot")

    rng = random if seed is None else random.Random(seed)

    n_clients = len(client_indices)
    rows, cols = n_clients * len(plottable), n_images
    # plotly rejects vertical_spacing > 1 / (rows - 1), so shrink it for tall grids.
    vertical_spacing = 0.03 if rows < 2 else min(0.03, 1 / (rows - 1))

    fig = sp.make_subplots(rows=rows, cols=cols,
                           vertical_spacing=vertical_spacing,
                           horizontal_spacing=0.01)

    k_indices = {}
    for k, key in enumerate(plottable):
        # list() keeps random.sample working even if the group is a set or other non-sequence.
        sampled_idx = rng.sample(list(grouped_trainset[key]), n_images)
        k_indices[key] = sampled_idx

        for r, client_idx in enumerate(client_indices):
            dataset = datasets[client_idx]
            row = k * n_clients + r + 1
            for c, sample_idx in enumerate(sampled_idx, start=1):
                img, _ = dataset[sample_idx]
                fig.add_trace(go.Image(z=tensor_to_numpy_image(img)), row=row, col=c)

            # Unnamed clients fall back to their numeric index instead of raising KeyError.
            client_name = client_mapping.get(client_idx, str(client_idx))
            fig.update_yaxes(title_text=f"Client {client_name} <br> Label {key}", row=row, col=1)

    # Approximate A4 page size in pixels at 96 DPI; keep the figure printable.
    a4_width_px = 794
    a4_height_px = 1123

    fig.update_layout(
        height=min(rows * 130, a4_height_px - 60),
        width=min(cols * 130, a4_width_px - 60),
        showlegend=False,
        margin=dict(t=20, b=20, l=20, r=20)
    )
    fig.update_xaxes(showticklabels=False)
    fig.update_yaxes(showticklabels=False)

    return fig, k_indices

# Client LT (index 0): four random samples each of labels 5, 7 and 9.
# Drawn once — the former per-client loop never used its loop variable and
# just redrew this same figure once per client.
fig, subplot_indices = create_client_image_subplots(simulation_hybrid_1, [0], 4, keys=[5, 7, 9])
display(fig)

In [None]:
# Client RB (index 3): four random samples each of labels 2, 4 and 6.
# Drawn once — the former per-client loop never used its loop variable and
# just redrew this same figure once per client.
fig, subplot_indices = create_client_image_subplots(simulation_hybrid_1, [3], 4, keys=[2, 4, 6])
display(fig)

In [None]:
# One figure per client of the first hybrid simulation, each showing four
# random samples for every key in range(10).
for client in simulation_hybrid_1.clients:
    fig, subplot_indices = create_client_image_subplots(
        simulation_hybrid_1,
        client_indices=[client.idx],
        n_images=4,
        keys=range(10),
    )
    display(fig)