In [1]:
import torch

In [65]:
import json
from pathlib import Path
from typing import Union, Dict, Any
import numpy as np
import torch

def load_results_json(path: Union[str, Path]) -> Dict[str, Any]:
    """
    Load the entire results.json file into a dict.
    """
    path = Path(path)
    with path.open("r", encoding="utf-8") as fp:
        return json.load(fp)

def get_all_client_weights(
    results: Dict[str, Any],
    as_tensor: bool = False
) -> Dict[int, Dict[str, Dict[str, Union[np.ndarray, torch.Tensor]]]]:
    """
    Extract every round’s client_weights.

    Returns:
        {
            round_no: {
                client_id: {
                    layer_name: array (np.ndarray or torch.Tensor),
                    …
                },
                …
            },
            …
        }
    """
    all_weights: Dict[int, Dict[str, Dict[str, Union[np.ndarray, torch.Tensor]]]] = {}
    for entry in results.get("client_weights", []):
        # Pull out the round number
        round_no = entry.get("round")
        if round_no is None:
            continue

        # Build the per-client dict for this round
        clients_dict: Dict[str, Dict[str, Union[np.ndarray, torch.Tensor]]] = {}
        for client_id, layer_dict in entry.items():
            if client_id == "round":
                continue
            layers_converted: Dict[str, Union[np.ndarray, torch.Tensor]] = {}
            for layer_name, weight_list in layer_dict.items():
                arr = np.array(weight_list)
                layers_converted[layer_name] = torch.from_numpy(arr) if as_tensor else arr
            clients_dict[client_id] = layers_converted

        all_weights[round_no] = clients_dict

    return all_weights

# Example usage:
results = load_results_json("outputs/2025-07-16/14-30-50/results.json")
# Get every round, as NumPy arrays:
weights_by_round = get_all_client_weights(results, as_tensor=False)

# Iterate:
for r, clients in sorted(weights_by_round.items()):
    print(f"Round {r}:")
    for cid, layers in clients.items():
        print(f"  Client {cid} — conv1.weight mean:", layers["conv1.weight"].mean())


Round 1:
  Client 7 — conv1.weight mean: 0.0017151853426670034
  Client 9 — conv1.weight mean: -0.003939488964776198
  Client 4 — conv1.weight mean: 0.012514181658884304
  Client 3 — conv1.weight mean: 0.01415559010166261
  Client 6 — conv1.weight mean: -0.0005082572608565291
  Client 5 — conv1.weight mean: -0.007678738078555195
  Client 0 — conv1.weight mean: 0.0037007537277208434
  Client 2 — conv1.weight mean: -0.004148651605290878
  Client 8 — conv1.weight mean: 0.004539177582288782
  Client 1 — conv1.weight mean: -0.011957944511216031
Round 2:
  Client 2 — conv1.weight mean: -0.00282718603240533
  Client 3 — conv1.weight mean: -0.0019404485625111394
  Client 1 — conv1.weight mean: -0.011056540314004652
  Client 0 — conv1.weight mean: 0.0020371859250331503
  Client 6 — conv1.weight mean: -0.01330082072578888
  Client 7 — conv1.weight mean: -0.005430049776461803
  Client 9 — conv1.weight mean: -0.0006182578652967803
  Client 5 — conv1.weight mean: -0.010638173868291132
  Client 4 — 

In [72]:
weights_by_round[1].keys()

dict_keys(['7', '9', '4', '3', '6', '5', '0', '2', '8', '1'])

array([[[[ 6.39094189e-02, -5.68399020e-02, -9.72044747e-03,
          -3.53243127e-02,  2.83537600e-02],
         [-9.34639722e-02, -1.54831037e-02, -2.92537250e-02,
           7.81003386e-03, -1.03339456e-01],
         [-1.15620218e-01, -4.79109213e-02, -4.13093865e-02,
          -3.35065871e-02, -1.05613999e-01],
         [-1.01565972e-01, -7.18552768e-02, -8.61203000e-02,
          -4.41386225e-03, -9.18394923e-02],
         [ 1.97833101e-03, -1.34661272e-01, -8.32171142e-02,
          -6.36611432e-02, -8.88871551e-02]],

        [[-1.35902479e-01, -1.22752741e-01, -9.90193039e-02,
          -7.94646963e-02, -1.24218062e-01],
         [-1.27602220e-02, -2.14089360e-02, -8.68321732e-02,
          -9.09018740e-02, -1.20593831e-01],
         [-1.71098456e-01, -1.93366613e-02, -6.24994822e-02,
          -1.77896395e-02, -5.15580401e-02],
         [-6.85420856e-02, -1.45369954e-02, -8.06151628e-02,
          -3.11836991e-02, -5.52945212e-02],
         [-7.16733560e-02, -6.80095330e-02, 