# Analysis of `wandb` Logs and Metrics
This notebook collects the metrics that the server and client runs of one experiment group logged to `wandb`, and summarizes them with descriptive statistics.

## Setup

In [1]:
import numpy as np
import pprint

import warnings
# Suppress every warning for the rest of the session.
# NOTE(review): this also hides warnings that may be informative (e.g. NumPy
# runtime warnings raised while computing statistics) — confirm this is intended.
warnings.filterwarnings("ignore")

import wandb
# API client object; it is used below to query the runs to analyse.
api = wandb.Api()

# Identifier of the experiment group to analyse. Runs are selected by comparing
# this value against the "group" entry of each run's config.
GROUP = "experiment-e3h42p8f"

In [2]:
# Query the runs stored under the given path (presumably "<entity>/<project>");
# narrowing them down to the experiment group happens in the next cell.
runs = api.runs("uedyiuajxz-personal/qfl-prod")

In [16]:
# Keep only the runs whose config places them in the experiment group under analysis.
grouped_runs = [candidate for candidate in runs if candidate.config.get("group") == GROUP]

# The run that identifies itself as "server" is kept separately; every other
# run of the group is treated as a client.
server, clients = None, []

for run in grouped_runs:
    participant = run.config.get("participant")
    if participant != "server":
        clients.append(run)
    else:
        server = run
    print(f"Run ID: {run.id}, Name: {run.name}, Type: {participant}")

Run ID: xmqtfaqg, Name: morning-vortex-153, Type: server
Run ID: 1qy9v9a1, Name: ethereal-pyramid-160, Type: client18
Run ID: ihx10drh, Name: silvery-fog-160, Type: client10
Run ID: 5dngqr6u, Name: good-dust-160, Type: client9
Run ID: z3qllbpz, Name: royal-moon-157, Type: client17
Run ID: 0d6kdedf, Name: fast-water-170, Type: client15
Run ID: w9jfzk2m, Name: cool-serenity-172, Type: client0
Run ID: agagv85r, Name: avid-capybara-160, Type: client5
Run ID: d3taghvl, Name: soft-shadow-160, Type: client2
Run ID: d5rnpnsm, Name: charmed-oath-154, Type: client3
Run ID: tcunwwyq, Name: glorious-pond-170, Type: client7
Run ID: rxodotpo, Name: fine-resonance-173, Type: client12
Run ID: qqo1rt6o, Name: earthy-music-156, Type: client19
Run ID: fnm96510, Name: rural-field-168, Type: client11
Run ID: qkdjh3v2, Name: unique-planet-160, Type: client6
Run ID: g7zv198y, Name: devoted-snowball-160, Type: client16
Run ID: gp7d3wdp, Name: bright-sky-156, Type: client14
Run ID: gruj41ei, Name: jumping-bree

## Settings

In [4]:
# Pretty-print the configuration of the server run. With width=40 the printed
# dictionary is wrapped so that entries end up on separate lines.
pp = pprint.PrettyPrinter(indent=1, width=40)
pp.pprint(server.config)

{'batch_size': 32,
 'dataset': 'MRI',
 'fhe_enabled': False,
 'group': 'experiment-e3h42p8f',
 'learning_rate': '1e-3',
 'model': 'fednn',
 'number_clients': 20,
 'participant': 'server',
 'rounds': 20}


## Server and Aggregated Metrics

In [19]:
def compute_statistics(arr: np.ndarray) -> dict:
    """
    Compute descriptive statistics of a sequence of numbers.

    The function is intended for one-dimensional input. Anything that is not
    already a NumPy array (list, tuple, pandas Series, ...) is converted with
    ``np.array`` first.

    Parameters
    ----------
    arr : np.ndarray
        Input numerical array (or array-like).

    Returns
    -------
    dict
        A dictionary containing the following statistics, in this order:
        - "mean" : float
            The average of the array.
        - "median" : float
            The middle value of the array.
        - "std" : float
            The sample standard deviation (``ddof=1``). NaN when the input
            has fewer than two elements, because it is undefined there.
        - "25%" : float
            The 25th percentile (Q1).
        - "75%" : float
            The 75th percentile (Q3).
        - "min" : float
            The smallest value.
        - "max" : float
            The largest value.
        - "last_value" : float
            The final element of the array.

        For empty input every entry is NaN instead of raising, so callers
        that format the values (e.g. ``f"{value:.2f}"``) keep working.

    Examples
    --------
    >>> import numpy as np
    >>> data = np.array([10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
    >>> stats = compute_statistics(data)
    >>> sorted(stats)
    ['25%', '75%', 'last_value', 'max', 'mean', 'median', 'min', 'std']
    >>> float(stats["mean"]), float(stats["median"])
    (55.0, 55.0)
    >>> float(stats["25%"]), float(stats["75%"])
    (32.5, 77.5)
    >>> int(stats["min"]), int(stats["max"]), int(stats["last_value"])
    (10, 100, 100)
    >>> round(float(stats["std"]), 4)
    30.2765
    """
    if not isinstance(arr, np.ndarray):
        arr = np.array(arr)

    keys = ("mean", "median", "std", "25%", "75%", "min", "max", "last_value")

    # np.min / np.max / arr[-1] raise on empty input; report "no data" as NaN
    # so a metric without valid samples does not abort the whole analysis.
    if arr.size == 0:
        return {key: np.nan for key in keys}

    stats = {
        "mean": np.mean(arr),
        "median": np.median(arr),
        # With ddof=1 a single sample divides by zero; NumPy would return NaN
        # and emit a RuntimeWarning, so return NaN directly instead.
        "std": np.std(arr, ddof=1) if arr.size > 1 else np.nan,
        "25%": np.percentile(arr, 25),
        "75%": np.percentile(arr, 75),
        "min": np.min(arr),
        "max": np.max(arr),
        "last_value": arr[-1],
    }
    return stats

In [25]:
from rich.console import Console
from rich.table import Table

console = Console()

# Part 1 holds the central tendency and spread of each metric.
table1 = Table(title="Server Statistics (Part 1)")
table1.add_column("Metric", justify="left", style="cyan", min_width=20)
for heading in ("Mean", "Median", "Std"):
    table1.add_column(heading, justify="right", style="green", min_width=12)

# Part 2 holds the quartiles, the extremes and the final logged value.
table2 = Table(title="Server Statistics (Part 2)")
table2.add_column("Metric", justify="left", style="cyan", min_width=20)
for heading in ("25%", "75%", "Max", "Min", "Last"):
    table2.add_column(heading, justify="right", style="green", min_width=12)

server_history = server.history()
for metric in server_history:
    # Only the rows in which this metric was actually logged (non-NaN) count.
    metric_values = [value for value in server_history[metric] if not np.isnan(value)]
    stats = compute_statistics(metric_values)

    # One row per metric in each table; the key order mirrors the column order.
    table1.add_row(
        metric,
        *(f"{stats[key]:.2f}" for key in ("mean", "median", "std")),
    )
    table2.add_row(
        metric,
        *(f"{stats[key]:.2f}" for key in ("25%", "75%", "max", "min", "last_value")),
    )


console.print(table1)
console.print(table2)

## Client Metrics

In [None]:
def _is_valid_number(val) -> bool:
    """Return True when `val` is a scalar that `np.isnan` accepts and that is not NaN."""
    # Nested sequences/arrays cannot be pooled into a flat list of samples.
    if np.ndim(val) != 0:
        return False
    try:
        return not np.isnan(val)
    except TypeError:
        # np.isnan rejects None, strings, dicts and other non-numeric objects.
        return False


def compute_client_statistics(client_histories):
    """
    Computes statistics for each metric by merging values across all clients.

    The valid samples of a metric are pooled over all clients, in the order
    the clients are given, and the pooled samples are summarized with
    `compute_statistics`. Entries that are NaN or not numeric scalars (None,
    strings, dicts, nested sequences) are skipped rather than raising.

    Parameters
    ----------
    client_histories : list
        One metric history per client. Each history must provide `.items()`
        yielding ``(metric_name, values)`` pairs, e.g. a dict of lists or the
        object returned by `run.history()` (presumably a pandas DataFrame —
        confirm against the wandb version in use).

    Returns
    -------
    dict
        A dictionary where each key is a metric name and the value is its computed statistics.
        Metrics for which no client contributed a valid numeric sample are
        omitted, because no statistics can be computed for them.
    """
    merged_metrics = {}

    # Pool the valid samples of every metric across all client histories.
    for client_history in client_histories:
        for metric, values in client_history.items():
            merged_metrics.setdefault(metric, []).extend(
                val for val in values if _is_valid_number(val)
            )

    # Compute statistics only for metrics that actually have samples.
    stats_dict = {
        metric: compute_statistics(np.array(values))
        for metric, values in merged_metrics.items()
        if values
    }

    return stats_dict

# Download the logged history of every client run in the group.
client_histories = [client_run.history() for client_run in clients]

# Pool the histories and summarize each metric across all clients.
merged_stats = compute_client_statistics(client_histories)

console = Console()

# Part 1 holds the central tendency and spread of each metric.
table1 = Table(title="Merged Client Metrics Statistics (Part 1)")
table1.add_column("Metric", justify="left", style="cyan", min_width=20)
for heading in ("Mean", "Median", "Std"):
    table1.add_column(heading, justify="right", style="green", min_width=12)

# Part 2 holds the quartiles, the extremes and the last pooled value.
table2 = Table(title="Merged Client Metrics Statistics (Part 2)")
table2.add_column("Metric", justify="left", style="cyan", min_width=20)
for heading in ("25%", "75%", "Max", "Min", "Last"):
    table2.add_column(heading, justify="right", style="green", min_width=12)

# One row per metric in each table; the key order mirrors the column order.
for metric, stats in merged_stats.items():
    table1.add_row(
        metric,
        *(f"{stats[key]:.2f}" for key in ("mean", "median", "std")),
    )
    table2.add_row(
        metric,
        *(f"{stats[key]:.2f}" for key in ("25%", "75%", "max", "min", "last_value")),
    )

# Render both tables.
console.print(table1)
console.print(table2)
