In [None]:
import logging
import multiprocessing as mp
from dataclasses import dataclass
from typing import Callable, Tuple
import torch
import torch.nn as nn
import numpy as np

In [None]:
# Training device used when none is given; resolved once, when this code is executed.
_DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@dataclass
class ConfigParameters:
    """
    Settings for Stale Synchronous Parallel asynchronous SGD (SSP-ASGD).

    :param num_workers: How many worker processes to launch.
    :param staleness: Largest version lag a worker tolerates before it
        re-pulls the global parameters.
    :param lr: Learning rate used for parameter updates.
    :param local_steps: Number of local training rounds per worker; in
        ``worker`` each round is one full pass over the worker's data loader.
    :param batch_size: Mini-batch size handed to the dataset builder.
    :param device: Torch device string used by the workers
        (``"cuda"`` when available, otherwise ``"cpu"``).
    :param log_level: Verbosity passed to ``logging.basicConfig`` in workers.
    """

    # Field order is part of the positional constructor signature.
    num_workers: int = 4
    staleness: int = 2
    lr: float = 0.01
    local_steps: int = 500
    batch_size: int = 128
    device: str = _DEFAULT_DEVICE
    log_level: int = logging.INFO

In [None]:
class ParameterServer:
    """
    Shared parameter store for SSP-ASGD.

    Every pushed gradient is applied immediately, under a single lock, with
    step size ``lr / num_workers``. The global version becomes the largest
    version any worker has pushed so far, and each worker's most recently
    pushed version is kept in a shared per-worker array.
    """

    def __init__(self, model: nn.Module, param: ConfigParameters) -> None:
        self.param = param

        # Detached tensors share storage with `model`'s parameters; moving that
        # storage into shared memory lets forked workers update one copy.
        shared = []
        for tensor in (q.detach() for q in model.parameters()):
            tensor.share_memory_()
            shared.append(tensor)
        self.theta = shared

        # One lock guards theta and the version counters; the condition
        # variable is built on the same lock.
        self._lock = mp.Lock()
        self._cv = mp.Condition(self._lock)

        # Shared int counters: highest version pushed, and last version per
        # worker (an integer size gives a zero-initialised array).
        self._current_version = mp.Value("i", 0)
        self._worker_versions = mp.Array("i", param.num_workers)

    def pull(self) -> Tuple[list[torch.Tensor], int]:
        """
        Snapshot the global state.

        :returns: ``(clones of the global parameters, global version)``,
            read atomically under the server lock.
        """
        with self._lock:
            snapshot = [t.clone() for t in self.theta]
            ver = self._current_version.value
        return snapshot, ver

    def push(self, wid: int, version: int, grads: list[torch.Tensor]) -> None:
        """
        Apply one worker's gradients to the global parameters.

        :param wid: Worker index; its slot in the per-worker version array
            is set to ``version``.
        :param version: Version tag of this update; the global version is
            raised to it when larger.
        :param grads: Gradients in parameter order; each is moved to the
            device of the matching global parameter before the update.
        """
        step = self.param.lr / self.param.num_workers
        with self._lock:
            for i, g in enumerate(grads):
                target = self.theta[i]
                target.sub_(step * g.to(target.device))

            self._current_version.value = max(self._current_version.value, version)
            self._worker_versions[wid] = version

            # Wake anything blocked on the condition variable (nothing in this
            # class waits on it).
            self._cv.notify_all()

    def get_version(self) -> int:
        """Return the current global version (read without the server lock)."""
        return self._current_version.value

In [None]:
def _load_state(model: nn.Module, state: list[torch.Tensor], device: torch.device) -> None:
    """Copy ``state`` (tensors in ``model.parameters()`` order) into ``model`` on ``device``."""
    with torch.no_grad():
        for p, s in zip(model.parameters(), state):
            p.copy_(s.to(device))


def worker(
    w_id: int,
    server: ParameterServer,
    model_fn: Callable[[int], nn.Module],
    input_dim: int,
    dataset_builder: Callable[[int, int, int], Tuple[torch.utils.data.DataLoader, int]],
    param: ConfigParameters,
) -> None:
    """
    Worker process for SSP-ASGD.

    Computes gradients on local mini-batches and pushes each one to
    ``server``. The worker also applies its own update to its local replica
    ("read-my-writes"), so between pulls it keeps training on up-to-date
    parameters instead of a frozen snapshot. The replica is re-synchronised
    from the server whenever the global version is more than
    ``param.staleness`` ahead of this worker's version.

    :param w_id: Index of this worker. It is passed to ``dataset_builder``
        and used as the worker's slot in the server's version record.
    :param server: Shared parameter server.
    :param model_fn: Builds a model from the input dimension.
    :param input_dim: Input feature dimension given to ``model_fn``.
    :param dataset_builder: ``(num_workers, batch_size, w_id) -> (loader, input_dim)``.
    :param param: Training configuration.
    """
    logging.basicConfig(
        level=param.log_level,
        format=f"%(asctime)s [Worker-{w_id}] %(message)s",
        datefmt="%H:%M:%S",
    )

    loader, _ = dataset_builder(param.num_workers, param.batch_size, w_id)
    device = torch.device(param.device)
    model = model_fn(input_dim).to(device)
    criterion = nn.BCELoss()
    # Same step size ParameterServer.push applies, so the local replica
    # reflects this worker's contribution exactly as the server does.
    scale = param.lr / param.num_workers

    state, local_ver = server.pull()
    _load_state(model, state, device)

    model.train()  # nothing below switches to eval mode, so set it once
    # NOTE(review): each "local step" is a full pass over `loader`, so a
    # worker performs local_steps * len(loader) updates; confirm intended.
    for _ in range(param.local_steps):
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            output = model(X_batch)
            loss = criterion(output, y_batch.float())
            loss.backward()

            # CPU copies of the gradients to push. A parameter that received no
            # gradient contributes zeros, so `grads` stays index-aligned with
            # the server's parameter list.
            grads = [
                p.grad.detach().cpu() if p.grad is not None
                else torch.zeros_like(p, device="cpu")
                for p in model.parameters()
            ]
            for p in model.parameters():
                p.grad = None

            # Re-sync when the global state has run too far ahead of us.
            t_max = server.get_version()
            if t_max - local_ver > param.staleness:
                state, local_ver = server.pull()
                _load_state(model, state, device)

            # Read-my-writes: apply our own update locally too. Without this,
            # a worker at (or near) the highest version never pulls and keeps
            # computing gradients at its initial parameters forever.
            with torch.no_grad():
                for p, g in zip(model.parameters(), grads):
                    p.sub_(scale * g.to(device))

            local_ver += 1
            server.push(w_id, local_ver, grads)


def run_ssp_training(
    dataset_builder: Callable[[int, int, int], Tuple[torch.utils.data.DataLoader, int]],
    model_fn: Callable[[int], nn.Module],
    param: "ConfigParameters | None" = None,
) -> Tuple[list[torch.Tensor], int]:
    """
    Run SSP-ASGD training across multiple worker processes.

    :param dataset_builder: ``(num_workers, batch_size, worker_id) -> (loader, input_dim)``;
        called once here with worker id 0 to discover ``input_dim``.
    :param model_fn: Builds a model from the input dimension.
    :param param: Training configuration; a fresh ``ConfigParameters()`` is
        created when omitted (avoids sharing one mutable default instance).
    :returns: ``(final global parameters, input_dim)``.
    :raises RuntimeError: if any worker exits with a non-zero exit code. All
        workers are joined first, so none is left running.
    """
    if param is None:
        param = ConfigParameters()

    _, input_dim = dataset_builder(param.num_workers, param.batch_size, 0)
    init_model = model_fn(input_dim)
    ps = ParameterServer(init_model, param)

    # "fork" lets children inherit the server's locks and shared tensors
    # directly (this start method is POSIX-only).
    ctx = mp.get_context("fork")
    procs = []
    try:
        for wid in range(param.num_workers):
            proc = ctx.Process(
                target=worker,
                args=(wid, ps, model_fn, input_dim, dataset_builder, param),
                daemon=False,
            )
            proc.start()
            procs.append(proc)

        for proc in procs:
            proc.join()
    finally:
        # If starting/joining was interrupted (e.g. KeyboardInterrupt in a
        # notebook), don't leave orphaned workers training in the background.
        for proc in procs:
            if proc.is_alive():
                proc.terminate()
                proc.join()

    failed = [(wid, proc.exitcode) for wid, proc in enumerate(procs) if proc.exitcode != 0]
    if failed:
        details = "; ".join(f"Worker {wid} crashed (exitcode {code})" for wid, code in failed)
        raise RuntimeError(details)

    theta, _ = ps.pull()
    return theta, input_dim


def build_model(theta: list[torch.Tensor], model_fn: Callable[[int], nn.Module], input_dim: int) -> nn.Module:
    """
    Build a fresh model with ``model_fn(input_dim)`` and load ``theta`` into it.

    :param theta: Parameter tensors in ``model.parameters()`` order. Extra or
        missing entries are ignored, because the tensors are paired with zip.
    :param model_fn: Factory taking the input dimension.
    :param input_dim: Input feature dimension.
    :returns: The newly created model holding copies of ``theta``.
    """
    net = model_fn(input_dim)
    pairs = zip(net.parameters(), theta)
    with torch.no_grad():  # in-place copy into leaf parameters must not be tracked
        for dst, src in pairs:
            dst.copy_(src)
    return net


def evaluate_model(name: str, model: nn.Module, X_eval: np.ndarray, y_eval: np.ndarray) -> float:
    """
    Compute and print the accuracy of a binary classifier.

    A sample is predicted positive when ``model(X) > 0.5``. This assumes the
    model outputs probabilities, which matches the BCELoss used in training.

    :param name: Label used in the printed message.
    :param model: Model to evaluate; it is switched to eval mode.
    :param X_eval: Input features.
    :param y_eval: Binary labels. Any shape with one entry per sample
        (e.g. ``(N,)`` or ``(N, 1)``) is accepted.
    :returns: Fraction of correctly classified samples.
    """
    model.eval()

    inputs = torch.as_tensor(X_eval)
    first_param = next(model.parameters(), None)
    if first_param is not None:
        # NumPy defaults to float64, which float32 layers reject; also follow
        # the model's device. Integer inputs (e.g. token ids) keep their dtype.
        if inputs.is_floating_point():
            inputs = inputs.to(dtype=first_param.dtype)
        inputs = inputs.to(device=first_param.device)

    with torch.no_grad():
        probs = model(inputs).cpu().numpy()

    # Flatten both sides: an (N, 1) output against (N,) labels would otherwise
    # broadcast to an (N, N) comparison and yield a meaningless accuracy.
    preds = probs.reshape(-1) > 0.5
    labels = np.asarray(y_eval).reshape(-1)
    acc = float(np.mean(preds == labels))
    print(f"{name} Test accuracy: {acc:.4f}")
    return acc
