In [1]:
import numpy as np
import os
import sys
import time
import logging
import matplotlib.pyplot as plt

# Resolve the project root two levels above the current working directory.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '../..'))
src_dir = os.path.join(project_dir, 'src')
fig_dir = os.path.join(project_dir, 'fig')
data_dir = os.path.join(project_dir, 'data')
# Create the figure directory if it does not exist yet.
os.makedirs(fig_dir, exist_ok=True)

# Make the src directory importable. The membership check keeps the cell
# idempotent: re-running it no longer appends a duplicate sys.path entry.
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [2]:
import argparse
import logging
import os
import sys
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from torchvision.datasets import CIFAR10
from tqdm import tqdm
from torch.utils.data import Subset

In [3]:
# Restrict PyTorch to a single intra-op CPU thread.
torch.set_num_threads(1)

# Route root-logger records to stdout with a timestamped
# "time:level:message" layout.
logging.basicConfig(
    stream=sys.stdout,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s:%(levelname)s:%(message)s",
)

# Named logger for this notebook; INFO and above are emitted.
logger = logging.getLogger("dp_model_export")
logger.setLevel(logging.INFO)

In [4]:
def convnet(num_classes):
    """Build the small CIFAR-style convolutional classifier.

    The network is four 3x3 conv + ReLU stages (3->32->64->64->128 channels).
    The first three stages are each followed by a 2x2 average pool; the last
    one is followed by global average pooling, a flatten and a linear head.

    Args:
        num_classes: number of output logits produced by the final layer.

    Returns:
        nn.Sequential: the assembled model, with layers in forward order.
    """

    def conv_relu(in_channels, out_channels):
        # 3x3 convolution that preserves spatial size, followed by ReLU.
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        ]

    layers = []
    # Down-sampling stages: conv + ReLU, then halve the spatial resolution.
    for in_channels, out_channels in ((3, 32), (32, 64), (64, 64)):
        layers.extend(conv_relu(in_channels, out_channels))
        layers.append(nn.AvgPool2d(kernel_size=2, stride=2))

    # Final stage collapses the feature map to 1x1 before the classifier.
    layers.extend(conv_relu(64, 128))
    layers.append(nn.AdaptiveAvgPool2d((1, 1)))
    layers.append(nn.Flatten(start_dim=1, end_dim=-1))
    layers.append(nn.Linear(128, num_classes, bias=True))
    return nn.Sequential(*layers)

def train(args, model, train_loader, optimizer, privacy_engine, epoch, device):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``device``, scored with cross-entropy, and used
    for a single optimizer step. Progress is shown with a tqdm bar.

    ``args``, ``privacy_engine`` and ``epoch`` are accepted but are not read
    by this function.
    """
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for batch_inputs, batch_labels in tqdm(train_loader):
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(batch_inputs), batch_labels)
        batch_loss.backward()
        optimizer.step()

def parse_args(args=None):
    """Parse the training hyperparameters.

    Args:
        args: list of command-line tokens to parse; ``None`` makes argparse
            read ``sys.argv``.

    Returns:
        argparse.Namespace with ``batch_size``, ``epochs``, ``lr``, ``sigma``,
        ``max_grad_norm`` and ``device``.
    """
    parser = argparse.ArgumentParser(description="DP ConvNet CIFAR10")

    # (flag, value type, default) for every supported option, in help order.
    option_table = (
        ("--batch-size", int, 512),
        ("--epochs", int, 1),
        ("--lr", float, 0.1),
        ("--sigma", float, 1.0),
        ("--max-grad-norm", float, 1.0),
        ("--device", str, "cpu"),
    )
    for flag, value_type, default_value in option_table:
        parser.add_argument(flag, default=default_value, type=value_type)

    return parser.parse_args(args)

def evaluate_loss(model, dataloader, device):
    """Return the sample-weighted mean cross-entropy loss over ``dataloader``.

    Each batch's mean loss is multiplied by its batch size before summing, so
    a smaller final batch does not skew the average.

    Args:
        model: module mapping a batch of inputs to class logits.
        dataloader: iterable yielding ``(inputs, targets)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        float: total loss divided by the number of evaluated samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples at all.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for X, y in dataloader:
            batch_size = X.size(0)
            if batch_size == 0:
                # Mean-reduced cross-entropy over zero samples is NaN, which
                # would poison the running total; an empty batch adds nothing.
                continue
            X, y = X.to(device), y.to(device)
            logits = model(X)
            loss = loss_fn(logits, y)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
    if total_samples == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError("evaluate_loss: dataloader yielded no samples")
    return total_loss / total_samples

In [5]:
# Group label from the GROUP environment variable ("d" when unset); it selects
# which synthetic image is inserted into the dataset later in the notebook.
group = os.environ.get("GROUP", "d")

# Parse an empty argument list so every hyperparameter takes its default.
args = parse_args([])
device = torch.device(args.device)

# Convert images to tensors, then normalize each channel with fixed mean/std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2023, 0.1994, 0.2010),
        ),
    ]
)

# CIFAR-10 training split stored under data_dir (downloaded if needed).
full_dataset = CIFAR10(
    root=data_dir,
    train=True,
    transform=transform,
    download=True,
)

Files already downloaded and verified


In [12]:
# Synthetic 32x32 RGB image: all zeros for group "d", all 255 for any other group.
if group == "d":
    synthetic_image = np.zeros((32, 32, 3), dtype=np.uint8)
else:
    synthetic_image = np.full((32, 32, 3), 255, dtype=np.uint8)

# Overwrite the first raw training example in place with the synthetic image
# and an arbitrary label (class 0, "airplane").
full_dataset.targets[0] = 0
full_dataset.data[0] = synthetic_image

# Train on the whole (modified) CIFAR-10 training set.
train_dataset = full_dataset
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
)

# Fresh 10-class model and a momentum-SGD optimizer over its parameters.
model = convnet(num_classes=10).to(device)
optimizer = optim.SGD(params=model.parameters(), lr=args.lr, momentum=0.9)

In [13]:
 # Enable differential privacy
privacy_engine = PrivacyEngine()

# Replace model, optimizer and loader with the versions returned by Opacus,
# configured with the noise multiplier and clipping norm from args.
model, optimizer, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    max_grad_norm=args.max_grad_norm,
    noise_multiplier=args.sigma,
)

# One call to train() per epoch.
for epoch in range(args.epochs):
    train(
        args=args,
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        privacy_engine=privacy_engine,
        epoch=epoch,
        device=device,
    )

  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
100%|██████████| 98/98 [01:40<00:00,  1.03s/it]


In [14]:
# Score the trained model on the full training set, iterating in fixed order.
eval_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    shuffle=False,
    batch_size=args.batch_size,
)
final_loss = evaluate_loss(model=model, dataloader=eval_loader, device=device)
print(f"Final loss on train set ({group}): {final_loss:.4f}")

Final loss on train set (d): 2.0781


In [15]:
# Show the unrounded loss value as this cell's output.
final_loss

2.078145762863159