In [None]:
# dp_mnist.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Hyperparameters for DP (fixed parameters)
fixed_clip_norm = 1.0         # fixed clipping threshold
noise_multiplier = 0.1        # fixed noise multiplier (σ)
batch_size = 64
epochs = 10
learning_rate = 0.01
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define a simple neural network for MNIST
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Prepare the MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True,
                               transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False,
                              transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Evaluation function
def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return correct / total

# Training loop implementing DP-SGD with fixed parameters
def train_dp():
    model = SimpleMLP().to(device)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    for epoch in range(1, epochs+1):
        model.train()
        total_loss = 0.0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()

            # Compute total gradient norm (as a proxy for per-sample norms)
            total_norm = 0.0
            for param in model.parameters():
                if param.grad is not None:
                    param_norm = param.grad.data.norm(2)
                    total_norm += param_norm.item() ** 2
            total_norm = total_norm ** 0.5

            # Clip gradients if needed
            clip_coef = fixed_clip_norm / (total_norm + 1e-6)
            if clip_coef < 1:
                for param in model.parameters():
                    if param.grad is not None:
                        param.grad.data.mul_(clip_coef)

            # Add Gaussian noise to each gradient
            for param in model.parameters():
                if param.grad is not None:
                    noise = torch.normal(mean=0, std=noise_multiplier * fixed_clip_norm, size=param.grad.data.shape, device=device)
                    param.grad.data.add_(noise)

            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        test_acc = evaluate(model, test_loader, device)
        print(f'Epoch {epoch:02d} - Loss: {avg_loss:.4f}, Test Accuracy: {test_acc*100:.2f}%')

    # (For demonstration, we print a fixed privacy budget message)
    print("Baseline DP-SGD finished. (Privacy accounting was not detailed here.)")

if __name__ == '__main__':
    train_dp()


Epoch 01 - Loss: 1.2169, Test Accuracy: 86.58%
Epoch 02 - Loss: 0.4888, Test Accuracy: 89.27%
Epoch 03 - Loss: 0.3923, Test Accuracy: 89.92%
Epoch 04 - Loss: 0.3530, Test Accuracy: 90.99%
Epoch 05 - Loss: 0.3298, Test Accuracy: 91.04%
Epoch 06 - Loss: 0.3148, Test Accuracy: 91.32%
Epoch 07 - Loss: 0.3022, Test Accuracy: 91.56%
Epoch 08 - Loss: 0.2927, Test Accuracy: 92.09%
Epoch 09 - Loss: 0.2833, Test Accuracy: 91.98%
Epoch 10 - Loss: 0.2751, Test Accuracy: 92.05%
Baseline DP-SGD finished. (Privacy accounting was not detailed here.)


In [None]:
# adp_mnist_fixed.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Initial hyperparameters for adaptive DP
init_clip_norm = 1.0          # starting clipping threshold
init_noise_multiplier = 0.1   # initial noise multiplier (σ)
alpha = 0.2                   # factor to scale the average gradient norm (must be <= 1 to dampen explosive growth)
tau = 0.1                     # smoothing factor for moving average update of clip_norm
beta = 0.9                    # decay factor for noise multiplier when validation improves

batch_size = 64
epochs = 20
learning_rate = 0.01
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define a simple neural network for MNIST
class SimpleMLP(nn.Module):
    def __init__(self):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Prepare the MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True,
                               transform=transform, download=True)
# Use part of the training set as a "validation" set
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=1000, shuffle=False)
test_dataset = datasets.MNIST(root='./data', train=False,
                              transform=transform, download=True)
test_loader  = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Evaluation function (for both validation and test)
def evaluate(model, loader, device):
    model.eval()
    total, correct = 0, 0
    loss_total = 0.0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            loss_total += F.cross_entropy(outputs, target, reduction='sum').item()
            preds = outputs.argmax(dim=1)
            correct += (preds == target).sum().item()
            total += target.size(0)
    return loss_total / total, correct / total

# Training loop implementing adaptive DP-SGD with stability fixes
def train_adp():
    model = SimpleMLP().to(device)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    clip_norm = init_clip_norm
    noise_multiplier = init_noise_multiplier
    val_loss_history = []  # record validation losses to check for consecutive decreases

    for epoch in range(1, epochs+1):
        model.train()
        total_loss = 0.0
        batch_grad_norms = []  # record gradient norms in this epoch

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            outputs = model(data)
            loss = F.cross_entropy(outputs, target)
            loss.backward()

            # Compute global gradient norm (all parameters combined)
            total_norm_sq = 0.0
            for param in model.parameters():
                if param.grad is not None:
                    total_norm_sq += param.grad.data.norm(2).item() ** 2
            total_norm = total_norm_sq ** 0.5
            batch_grad_norms.append(total_norm)

            # Clip gradients using the current adaptive clip_norm
            clip_coef = clip_norm / (total_norm + 1e-6)
            if clip_coef < 1:
                for param in model.parameters():
                    if param.grad is not None:
                        param.grad.data.mul_(clip_coef)

            # Add Gaussian noise using the current clip_norm and noise multiplier
            for param in model.parameters():
                if param.grad is not None:
                    noise = torch.normal(mean=0, std=noise_multiplier * clip_norm, size=param.grad.data.shape, device=device)
                    param.grad.data.add_(noise)

            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        avg_grad_norm = sum(batch_grad_norms) / len(batch_grad_norms)

        # Adaptive update: update clip_norm using a moving average to smooth the changes.
        measured_clip = alpha * avg_grad_norm
        new_clip_norm = (1 - tau) * clip_norm + tau * measured_clip

        # Evaluate on the validation set
        val_loss, val_acc = evaluate(model, val_loader, device)
        val_loss_history.append(val_loss)

        # Decay noise multiplier if validation loss decreases for three consecutive epochs
        if len(val_loss_history) >= 3 and val_loss_history[-3] > val_loss_history[-2] > val_loss_history[-1]:
            noise_multiplier = beta * noise_multiplier
            print("Validation loss decreased three epochs in a row. Reducing noise multiplier.")

        # Update clip_norm for the next epoch
        clip_norm = new_clip_norm

        # Evaluate on the test set for monitoring
        test_loss, test_acc = evaluate(model, test_loader, device)
        print(f"Epoch {epoch:02d}: Train Loss: {avg_loss:.4f}, Val Loss: {val_loss:.4f}, Test Acc: {test_acc*100:.2f}%")
        print(f"  (Adaptive clip norm: {clip_norm:.4f}, noise multiplier: {noise_multiplier:.4f})")

    print("Adaptive DP-SGD finished. (Privacy accounting was simulated.)")

if __name__ == '__main__':
    train_adp()


Epoch 01: Train Loss: 1.3210, Val Loss: 0.6808, Test Acc: 85.03%
  (Adaptive clip norm: 0.9184, noise multiplier: 0.1000)
Epoch 02: Train Loss: 0.5443, Val Loss: 0.4636, Test Acc: 88.60%
  (Adaptive clip norm: 0.8446, noise multiplier: 0.1000)
Validation loss decreased three epochs in a row. Reducing noise multiplier.
Epoch 03: Train Loss: 0.4272, Val Loss: 0.3985, Test Acc: 89.48%
  (Adaptive clip norm: 0.7783, noise multiplier: 0.0900)
Validation loss decreased three epochs in a row. Reducing noise multiplier.
Epoch 04: Train Loss: 0.3801, Val Loss: 0.3641, Test Acc: 90.25%
  (Adaptive clip norm: 0.7187, noise multiplier: 0.0810)
Validation loss decreased three epochs in a row. Reducing noise multiplier.
Epoch 05: Train Loss: 0.3540, Val Loss: 0.3465, Test Acc: 90.85%
  (Adaptive clip norm: 0.6655, noise multiplier: 0.0729)
Validation loss decreased three epochs in a row. Reducing noise multiplier.
Epoch 06: Train Loss: 0.3369, Val Loss: 0.3298, Test Acc: 91.14%
  (Adaptive clip norm

In [1]:
#!/usr/bin/env python3
"""
Split‐Learning DP MNIST Training Script
- Client‐side and Server‐side architectures are defined by user.
- Noise added only on the server side gradients (fixed or adaptive DP).
- Logs all experiment parameters and outcomes to a JSONL log file.
"""

import time
import json
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


def parse_layer_sizes(prompt):
    """Helper to parse comma‐separated layer sizes into a list of ints."""
    s = input(prompt).strip()
    sizes = [int(x) for x in s.split(',') if x.strip()]
    if len(sizes) < 2:
        raise ValueError("Please specify at least two layer sizes (input and output).")
    return sizes


# === 1. Interactive parameter input ===
print("=== Differential Privacy Split Learning MNIST ===")
dp_type = input("Select DP type ('fixed' or 'adaptive'): ").strip().lower()
if dp_type not in ('fixed', 'adaptive'):
    raise ValueError("Invalid DP type; choose 'fixed' or 'adaptive'.")

client_layers = parse_layer_sizes(
    "Enter client‐side layer sizes (e.g. 784,256,128): ")
server_layers = parse_layer_sizes(
    "Enter server‐side layer sizes (e.g. 128,64,10): ")

batch_size      = int(input("Batch size [64]: ") or 64)
epochs          = int(input("Epochs [10]: ") or 10)
learning_rate   = float(input("Learning rate [0.01]: ") or 0.01)

if dp_type == 'fixed':
    clip_norm       = float(input("Clip norm [1.0]: ") or 1.0)
    noise_multiplier= float(input("Noise multiplier σ [0.1]: ") or 0.1)
else:
    clip_norm        = float(input("Initial clip norm [1.0]: ") or 1.0)
    noise_multiplier = float(input("Initial noise multiplier σ [0.1]: ") or 0.1)
    alpha            = float(input("Adaptive α [0.2]: ") or 0.2)
    tau              = float(input("Adaptive τ [0.1]: ") or 0.1)
    beta             = float(input("Adaptive β [0.9]: ") or 0.9)
    # We'll track validation improvements to decay noise_multiplier
    best_val_loss = float('inf')
    val_improve_count = 0

# === 2. Define split models ===
class ClientNet(nn.Module):
    def __init__(self, layers):
        super().__init__()
        modules = []
        for i in range(len(layers)-1):
            modules.append(nn.Linear(layers[i], layers[i+1]))
            # no activation on last client layer
            if i < len(layers)-2:
                modules.append(nn.ReLU())
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.net(x)


class ServerNet(nn.Module):
    def __init__(self, layers):
        super().__init__()
        modules = []
        for i in range(len(layers)-1):
            modules.append(nn.Linear(layers[i], layers[i+1]))
            # no ReLU on last server layer
            if i < len(layers)-2:
                modules.append(nn.ReLU())
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)


# === 3. Data loaders ===
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_ds = datasets.MNIST('./data', train=True,  download=True, transform=transform)
test_ds  = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)

# === 4. Setup ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
client = ClientNet(client_layers).to(device)
server = ServerNet(server_layers).to(device)

opt_client = optim.SGD(client.parameters(), lr=learning_rate)
opt_server = optim.SGD(server.parameters(), lr=learning_rate)

criterion = nn.CrossEntropyLoss()

# === 5. Training loop with server-side DP noise ===
start_time = datetime.now()
print(f"Starting training at {start_time.isoformat()} on {device}")

for epoch in range(1, epochs+1):
    client.train(); server.train()
    running_loss = 0.0

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        # Forward pass: client ➔ server
        activation = client(data)
        output     = server(activation)
        loss       = criterion(output, target)

        # Zero grads
        opt_client.zero_grad()
        opt_server.zero_grad()

        # Backpropagate through both parts
        loss.backward(retain_graph=True)

        # === DP noise on server-side gradients only ===
        # Compute total L2 norm of server grads
        total_norm_sq = 0.0
        for p in server.parameters():
            if p.grad is not None:
                total_norm_sq += p.grad.data.norm(2).item() ** 2
        total_norm = total_norm_sq ** 0.5

        # Determine clip threshold
        current_clip = clip_norm

        # Clip server gradients if needed
        clip_coef = current_clip / (total_norm + 1e-6)
        if clip_coef < 1.0:
            for p in server.parameters():
                if p.grad is not None:
                    p.grad.data.mul_(clip_coef)

        # Add Gaussian noise
        for p in server.parameters():
            if p.grad is not None:
                noise = torch.randn_like(p.grad) * noise_multiplier * current_clip
                p.grad.data.add_(noise)

        # Adaptive update of clip_norm and noise_multiplier
        if dp_type == 'adaptive':
            # Update clip_norm via exponential moving average
            new_clip = tau * (alpha * total_norm) + (1 - tau) * clip_norm
            clip_norm = float(new_clip)
            # We'll do validation check at epoch end to decay noise_multiplier

        # Optimizer steps
        opt_server.step()
        opt_client.step()
        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)

    # Validation for adaptive DP
    val_loss = avg_loss
    if dp_type == 'adaptive':
        if val_loss < best_val_loss:
            val_improve_count += 1
            best_val_loss = val_loss
            if val_improve_count >= 3:
                noise_multiplier *= beta
                val_improve_count = 0
                print(f"[Epoch {epoch}] Validation improved 3x, decaying noise_multiplier to {noise_multiplier:.4f}")
        else:
            val_improve_count = 0

    print(f"Epoch {epoch}/{epochs} — Train Loss: {avg_loss:.4f} "
          f"— Clip Norm: {clip_norm:.4f} — Noise σ: {noise_multiplier:.4f}")

end_time = datetime.now()
duration = (end_time - start_time).total_seconds()
print(f"Finished at {end_time.isoformat()}, duration: {duration:.2f}s")

# === 6. Logging experiment details ===
log_entry = {
    "start_time":        start_time.isoformat(),
    "end_time":          end_time.isoformat(),
    "duration_seconds":  duration,
    "dataset":           "MNIST",
    "dp_type":           dp_type,
    "client_layers":     client_layers,
    "server_layers":     server_layers,
    "clip_norm":         clip_norm,
    "noise_multiplier":  noise_multiplier,
    "epochs":            epochs,
    "batch_size":        batch_size,
    "learning_rate":     learning_rate
}
with open("experiment_log.jsonl", "a") as lf:
    lf.write(json.dumps(log_entry) + "\n")

print("Experiment parameters and results written to experiment_log.jsonl")


=== Differential Privacy Split Learning MNIST ===
Select DP type ('fixed' or 'adaptive'): fixed
Enter client‐side layer sizes (e.g. 784,256,128): 784,128,128,32
Enter server‐side layer sizes (e.g. 128,64,10): 32,32,10
Batch size [64]: 64
Epochs [10]: 20
Learning rate [0.01]: 0.001
Clip norm [1.0]: 1.0
Noise multiplier σ [0.1]: 0.1


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 493kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.50MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.59MB/s]


Starting training at 2025-04-25T12:29:19.163011 on cuda
Epoch 1/20 — Train Loss: 2.2953 — Clip Norm: 1.0000 — Noise σ: 0.1000
Epoch 2/20 — Train Loss: 2.2661 — Clip Norm: 1.0000 — Noise σ: 0.1000
Epoch 3/20 — Train Loss: 2.2078 — Clip Norm: 1.0000 — Noise σ: 0.1000
Epoch 4/20 — Train Loss: 2.0706 — Clip Norm: 1.0000 — Noise σ: 0.1000
Epoch 5/20 — Train Loss: 1.8212 — Clip Norm: 1.0000 — Noise σ: 0.1000
Epoch 6/20 — Train Loss: 1.4920 — Clip Norm: 1.0000 — Noise σ: 0.1000
Epoch 7/20 — Train Loss: 1.1136 — Clip Norm: 1.0000 — Noise σ: 0.1000
Epoch 8/20 — Train Loss: 0.8155 — Clip Norm: 1.0000 — Noise σ: 0.1000
Epoch 9/20 — Train Loss: 0.6503 — Clip Norm: 1.0000 — Noise σ: 0.1000
Epoch 10/20 — Train Loss: 0.5625 — Clip Norm: 1.0000 — Noise σ: 0.1000
Epoch 11/20 — Train Loss: 0.5093 — Clip Norm: 1.0000 — Noise σ: 0.1000
Epoch 12/20 — Train Loss: 0.4730 — Clip Norm: 1.0000 — Noise σ: 0.1000
Epoch 13/20 — Train Loss: 0.4448 — Clip Norm: 1.0000 — Noise σ: 0.1000
Epoch 14/20 — Train Loss: 0.42

In [None]:
#!/usr/bin/env python3
"""
Split‐Learning DP MNIST Training Script
- Client‐side and Server‐side architectures are defined via CLI flags.
- Noise added only on the server side gradients (fixed or adaptive DP).
- Logs all experiment parameters and outcomes to a JSONL log file.
"""

import argparse
import json
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


def parse_layer_sizes(s: str):
    """Parse comma‐separated layer sizes into a list of ints."""
    sizes = [int(x) for x in s.split(',') if x.strip()]
    if len(sizes) < 2:
        raise argparse.ArgumentTypeError(
            "Must specify at least two layer sizes, e.g. 784,256,10"
        )
    return sizes


# === 1. Parse CLI arguments ===
parser = argparse.ArgumentParser(
    description="Split‐Learning DP MNIST Training")
parser.add_argument(
    "--dp-type", choices=("fixed", "adaptive"), required=True,
    help="Differential privacy mode: 'fixed' or 'adaptive'")
parser.add_argument(
    "--client-layers", type=parse_layer_sizes, required=True,
    help="Client‐side layer sizes, e.g. '784,256,128'")
parser.add_argument(
    "--server-layers", type=parse_layer_sizes, required=True,
    help="Server‐side layer sizes, e.g. '128,64,10'")
parser.add_argument(
    "--batch-size", type=int, default=64, help="Batch size (default: 64)")
parser.add_argument(
    "--epochs", type=int, default=10, help="Number of epochs (default: 10)")
parser.add_argument(
    "--learning-rate", type=float, default=0.01,
    help="SGD learning rate (default: 0.01)")
# DP parameters
parser.add_argument(
    "--clip-norm", type=float, default=1.0,
    help="Clipping norm (default: 1.0)")
parser.add_argument(
    "--noise-multiplier", type=float, default=0.1,
    help="Gaussian noise multiplier σ (default: 0.1)")
# Adaptive‐DP only
parser.add_argument(
    "--alpha", type=float, default=0.2,
    help="Adaptive DP α factor (default: 0.2)")
parser.add_argument(
    "--tau", type=float, default=0.1,
    help="Adaptive DP τ smoothing (default: 0.1)")
parser.add_argument(
    "--beta", type=float, default=0.9,
    help="Adaptive DP β decay (default: 0.9)")

args = parser.parse_args()

dp_type         = args.dp_type
client_layers   = args.client_layers
server_layers   = args.server_layers
batch_size      = args.batch_size
epochs          = args.epochs
learning_rate   = args.learning_rate
clip_norm       = args.clip_norm
noise_multiplier= args.noise_multiplier
alpha           = args.alpha
tau             = args.tau
beta            = args.beta

if dp_type == "adaptive":
    best_val_loss     = float('inf')
    val_improve_count = 0

# === 2. Define split models ===
class ClientNet(nn.Module):
    def __init__(self, layers):
        super().__init__()
        modules = []
        for i in range(len(layers)-1):
            modules.append(nn.Linear(layers[i], layers[i+1]))
            if i < len(layers)-2:
                modules.append(nn.ReLU())
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x.view(x.size(0), -1))


class ServerNet(nn.Module):
    def __init__(self, layers):
        super().__init__()
        modules = []
        for i in range(len(layers)-1):
            modules.append(nn.Linear(layers[i], layers[i+1]))
            if i < len(layers)-2:
                modules.append(nn.ReLU())
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)


# === 3. Data loaders ===
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_ds = datasets.MNIST('./data', train=True,  download=True, transform=transform)
test_ds  = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)

# === 4. Setup ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
client = ClientNet(client_layers).to(device)
server = ServerNet(server_layers).to(device)

opt_client = optim.SGD(client.parameters(), lr=learning_rate)
opt_server = optim.SGD(server.parameters(), lr=learning_rate)
criterion  = nn.CrossEntropyLoss()

# === 5. Training loop with server‐side DP noise ===
start_time = datetime.now()
print(f"[{start_time.isoformat()}] Starting {dp_type}-DP split training on {device}")

for epoch in range(1, epochs+1):
    client.train(); server.train()
    running_loss = 0.0

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        # Forward: client → server
        activation = client(data)
        output     = server(activation)
        loss       = criterion(output, target)

        opt_client.zero_grad()
        opt_server.zero_grad()
        loss.backward(retain_graph=True)

        # --- server‐side DP ---
        # 1. compute total grad norm
        norm_sq = 0.0
        for p in server.parameters():
            if p.grad is not None:
                norm_sq += p.grad.data.norm(2).item()**2
        total_norm = norm_sq**0.5

        # 2. clip
        clip_coef = clip_norm / (total_norm + 1e-6)
        if clip_coef < 1.0:
            for p in server.parameters():
                if p.grad is not None:
                    p.grad.data.mul_(clip_coef)

        # 3. add Gaussian noise
        for p in server.parameters():
            if p.grad is not None:
                noise = torch.randn_like(p.grad) * noise_multiplier * clip_norm
                p.grad.data.add_(noise)

        # 4. adaptive update
        if dp_type == "adaptive":
            new_clip = tau * (alpha * total_norm) + (1 - tau) * clip_norm
            clip_norm = float(new_clip)

        opt_server.step()
        opt_client.step()
        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader)

    # adaptive DP validation‐based decay
    if dp_type == "adaptive":
        if avg_loss < best_val_loss:
            val_improve_count += 1
            best_val_loss = avg_loss
            if val_improve_count >= 3:
                noise_multiplier *= beta
                val_improve_count = 0
                print(f"  [Epoch {epoch}] Decayed σ → {noise_multiplier:.4f}")
        else:
            val_improve_count = 0

    print(f"Epoch {epoch}/{epochs} — Loss: {avg_loss:.4f} "
          f"— Clip: {clip_norm:.4f} — σ: {noise_multiplier:.4f}")

end_time = datetime.now()
duration = (end_time - start_time).total_seconds()
print(f"[{end_time.isoformat()}] Finished; duration {duration:.2f}s")

# === 6. Log results ===
log = {
    "start_time":        start_time.isoformat(),
    "end_time":          end_time.isoformat(),
    "duration_s":        duration,
    "dataset":           "MNIST",
    "dp_type":           dp_type,
    "client_layers":     client_layers,
    "server_layers":     server_layers,
    "clip_norm":         clip_norm,
    "noise_multiplier":  noise_multiplier,
    "epochs":            epochs,
    "batch_size":        batch_size,
    "learning_rate":     learning_rate,
    "alpha":             alpha if dp_type=="adaptive" else None,
    "tau":               tau   if dp_type=="adaptive" else None,
    "beta":              beta  if dp_type=="adaptive" else None,
}

with open("experiment_log.jsonl", "a") as lf:
    lf.write(json.dumps(log) + "\n")

print("Logged experiment to experiment_log.jsonl")

"""
python dp_split_mnist.py \
  --dp-type adaptive \
  --client-layers 784,256,128 \
  --server-layers 128,64,10 \
  --batch-size 64 \
  --epochs 20 \
  --learning-rate 0.01 \
  --clip-norm 1.0 \
  --noise-multiplier 0.1 \
  --alpha 0.2 \
  --tau 0.1 \
  --beta 0.9

"""