In [1]:
import os
import sys

# Resolve the project layout relative to the notebook's working directory.
# Assumes the notebook runs two directory levels below the project root.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '../..'))
src_dir = os.path.join(project_dir, 'src')
fig_dir = os.path.join(project_dir, 'fig')
data_dir = os.path.join(project_dir, 'data')
log_dir = os.path.join(project_dir, 'log')
os.makedirs(fig_dir, exist_ok=True)

# Make the project's `src` tree importable (needed by the `mech.*` import below).
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path.
if src_dir not in sys.path:
    sys.path.append(src_dir)

import mech.full_DPSGD as DPSGDModule

In [2]:
# Root directory for trained models / cached results (models are loaded from
# <results_dir>/model_folder/). Defaults to the original cluster scratch path;
# set the FDP_RESULTS_DIR environment variable to run on another machine.
results_dir = os.environ.get(
    "FDP_RESULTS_DIR", "/scratch/bell/wei402/fdp-estimation/results"
)

data_args = {
    "method": "default",
    "data_dir": data_dir,
    "internal_result_path": results_dir,
}

args = DPSGDModule.generate_params(data_args=data_args, log_dir=log_dir, model_type="CNN")
sampler = DPSGDModule.DPSGDSampler(args)

Files already downloaded and verified


05/15/2025 12:09:19:INFO:Initialized CNN_DPSGDSampler with parameters: batch_size=512, epochs=1, lr=0.10, sigma=1.00, max_grad_norm=1.00, device=cpu


In [5]:
# Score 6 model pairs. preprocess() reuses pairs already on disk and only
# generates the missing ones (see the "Need to generate N more" log line).
N_PREPROCESS_PAIRS = 6
score0, score1 = sampler.preprocess(num_samples=N_PREPROCESS_PAIRS)

05/15/2025 12:10:06:INFO:Found 5 existing model pairs. Need to generate 1 more.
05/15/2025 12:10:06:INFO:Model loaded from /scratch/bell/wei402/fdp-estimation/results/model_folder/CNN_model_x0_bf2ef00babcfc3fb.pt
05/15/2025 12:10:06:INFO:Model loaded from /scratch/bell/wei402/fdp-estimation/results/model_folder/CNN_model_x1_846663d5fa347774.pt
05/15/2025 12:10:06:INFO:Loaded and projected model pair 1/5
05/15/2025 12:10:06:INFO:Model loaded from /scratch/bell/wei402/fdp-estimation/results/model_folder/CNN_model_x0_b510dc3c365a9a30.pt
05/15/2025 12:10:06:INFO:Model loaded from /scratch/bell/wei402/fdp-estimation/results/model_folder/CNN_model_x1_3cee69a762c3495e.pt
05/15/2025 12:10:06:INFO:Loaded and projected model pair 2/5
05/15/2025 12:10:06:INFO:Model loaded from /scratch/bell/wei402/fdp-estimation/results/model_folder/CNN_model_x0_dc33417a6e37b103.pt
05/15/2025 12:10:06:INFO:Model loaded from /scratch/bell/wei402/fdp-estimation/results/model_folder/CNN_model_x1_efeef71487fae47a.pt


In [6]:
# Show both score arrays, one after the other.
print(score0, score1, sep="\n")

[[3.60289559e-05]
 [1.13770552e-03]
 [6.49817241e-03]
 [4.95460874e-04]
 [4.46706964e-03]
 [2.24081939e-03]]
[[0.00104067]
 [0.00075073]
 [0.00048683]
 [0.00081685]
 [0.00135131]
 [0.00531252]]


In [9]:
import multiprocessing


def _train_model_worker(args):
    """Train one DP-SGD model inside a pool worker and return its saved path.

    Args:
        args: ``(sampler_kwargs, positive)`` or ``(sampler_kwargs, positive, seed)``.
            ``sampler_kwargs`` is the parameter object produced by
            ``DPSGDModule.generate_params``; ``positive`` is forwarded to
            ``DPSGDSampler.train_model``. The optional ``seed`` (int) reseeds
            the Python, NumPy and torch global RNGs before training.

    Returns:
        The ``model_path`` returned by ``train_model``. The trained model object
        itself is dropped so it does not have to be pickled back to the parent.
    """
    sampler_kwargs, positive, *rest = args
    seed = rest[0] if rest else None

    import random
    import torch

    # One intra-op thread per process: parallelism comes from the pool, and
    # N processes x M torch threads each would oversubscribe the CPUs.
    torch.set_num_threads(1)

    if seed is not None:
        # With the 'fork' start method every worker starts from an identical copy
        # of the parent's RNG state; if training draws from these global RNGs,
        # all workers would otherwise produce the same random stream.
        import numpy as np
        random.seed(seed)
        np.random.seed(seed % (2 ** 32))
        torch.manual_seed(seed)

    sampler = DPSGDModule.DPSGDSampler(sampler_kwargs)
    _model, model_path = sampler.train_model(positive=positive)
    return model_path


def parallel_train_models(sampler_kwargs, num_generating_samples=32, processes=None, seed=None):
    """Train ``num_generating_samples`` model pairs in parallel, one CPU thread per worker.

    One task is submitted per model: ``num_generating_samples`` tasks with
    ``positive=False`` followed by ``num_generating_samples`` tasks with
    ``positive=True``, i.e. ``2 * num_generating_samples`` models in total.
    Every task receives a distinct RNG seed so forked workers do not train
    statistically identical models.

    Args:
        sampler_kwargs: parameters from ``DPSGDModule.generate_params``.
        num_generating_samples: number of (positive=False, positive=True) pairs.
            Values <= 0 train nothing and return an empty list.
        processes: pool size; defaults to ``num_generating_samples`` (the
            original behaviour). Lower it on machines with fewer cores.
        seed: base seed used to derive the per-task seeds. ``None`` draws fresh
            OS entropy; an int makes the per-task seeds reproducible.

    Returns:
        List of model paths in submission order: all ``positive=False`` paths
        first, then all ``positive=True`` paths.
    """
    import random

    if num_generating_samples <= 0:
        # multiprocessing.Pool rejects processes < 1; there is nothing to train.
        return []
    if processes is None:
        processes = num_generating_samples

    flags = [False] * num_generating_samples + [True] * num_generating_samples
    # sample() without replacement guarantees every task gets a different seed.
    task_seeds = random.Random(seed).sample(range(2 ** 32), len(flags))
    tasks = [(sampler_kwargs, flag, task_seed) for flag, task_seed in zip(flags, task_seeds)]

    with multiprocessing.Pool(processes=processes) as pool:
        return pool.map(_train_model_worker, tasks)

# Train NUM_PARALLEL_PAIRS (positive=False, positive=True) model pairs in
# parallel and keep the list of saved model paths.
NUM_PARALLEL_PAIRS = 32
sampler_kwargs = DPSGDModule.generate_params(
    data_args=data_args,
    log_dir=log_dir,
    model_type="CNN",
)
model_paths = parallel_train_models(
    sampler_kwargs,
    num_generating_samples=NUM_PARALLEL_PAIRS,
)

Files already downloaded and verified


05/15/2025 12:32:47:INFO:Initialized CNN_DPSGDSampler with parameters: batch_size=512, epochs=1, lr=0.10, sigma=1.00, max_grad_norm=1.00, device=cpu
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  1%|          | 1/98 [00:01<01:46,  1.10s/it]05/15/2025 12:32:48:INFO:Initialized CNN_DPSGDSampler with parameters: batch_size=512, epochs=1, lr=0.10, sigma=1.00, max_grad_norm=1.00, device=cpu
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  3%|▎         | 3/98 [00:03<01:46,  1.12s/it]05/15/2025 12:32:50:INFO:Initialized CNN_DPSGDSampler with parameters: batch_size=512, epochs=1, lr=0.10, sigma=1.00, max_grad_norm=1.00, device=cpu
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  1%|          | 1/98 [00:01<01:55,  1.19s/it]05/15/2025 12:32:52:INFO:Initialized CNN_DPSGDSampler with parameters: batch_size=512, epochs=1, lr=0.10, sigma=1.00, max_grad_norm=1.00, device=cpu
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
  6%|▌

In [10]:
# Root directory for trained models / cached results (models are loaded from
# <results_dir>/model_folder/). Defaults to the original cluster scratch path;
# set the FDP_RESULTS_DIR environment variable to run on another machine.
results_dir = os.environ.get(
    "FDP_RESULTS_DIR", "/scratch/bell/wei402/fdp-estimation/results"
)

data_args = {
    "method": "default",
    "data_dir": data_dir,
    "internal_result_path": results_dir,
}

args = DPSGDModule.generate_params(data_args=data_args, log_dir=log_dir, model_type="CNN")
# NOTE(review): the earlier setup cell uses DPSGDModule.DPSGDSampler; both log
# "Initialized CNN_DPSGDSampler". Confirm they are the same class and use one name.
sampler = DPSGDModule.CNN_DPSGDSampler(args)
score0, score1 = sampler.preprocess(num_samples=20)

# NOTE(review): in the saved output most entries of BOTH score0 and score1 are
# the identical value ~4.0467e-04, which suggests the models trained by the
# multiprocessing pool are not independent (shared RNG state across forked
# workers). Re-train those pairs with per-worker seeding before trusting them.
print(score0)
print(score1)

Files already downloaded and verified


05/15/2025 12:37:29:INFO:Initialized CNN_DPSGDSampler with parameters: batch_size=512, epochs=1, lr=0.10, sigma=1.00, max_grad_norm=1.00, device=cpu
05/15/2025 12:37:29:INFO:Found 22 existing model pairs. Need to generate 0 more.
  state_dict = torch.load(model_path, map_location=device)
05/15/2025 12:37:29:INFO:Model loaded from /scratch/bell/wei402/fdp-estimation/results/model_folder/CNN_model_x0_bf2ef00babcfc3fb.pt
05/15/2025 12:37:29:INFO:Model loaded from /scratch/bell/wei402/fdp-estimation/results/model_folder/CNN_model_x1_3a92b9bbb19e5401.pt
05/15/2025 12:37:29:INFO:Loaded and projected model pair 1/22
05/15/2025 12:37:29:INFO:Model loaded from /scratch/bell/wei402/fdp-estimation/results/model_folder/CNN_model_x0_9ac5dd1c023adf60.pt
05/15/2025 12:37:29:INFO:Model loaded from /scratch/bell/wei402/fdp-estimation/results/model_folder/CNN_model_x1_53dbe3318292a702.pt
05/15/2025 12:37:29:INFO:Loaded and projected model pair 2/22
05/15/2025 12:37:29:INFO:Model loaded from /scratch/bel

[[3.60289559e-05]
 [4.04671620e-04]
 [4.04671620e-04]
 [1.13770552e-03]
 [6.49817241e-03]
 [4.95460874e-04]
 [4.04671620e-04]
 [4.04671620e-04]
 [4.04671620e-04]
 [4.46706964e-03]
 [4.04671620e-04]
 [4.04671620e-04]
 [4.04671620e-04]
 [4.04671620e-04]
 [4.04671620e-04]
 [4.04671620e-04]
 [3.20070656e-03]
 [4.04671620e-04]
 [1.99789787e-03]
 [4.04671620e-04]]
[[0.00040467]
 [0.00040467]
 [0.00040467]
 [0.00040467]
 [0.00104067]
 [0.00040467]
 [0.00040467]
 [0.00531252]
 [0.00075073]
 [0.00048683]
 [0.00081685]
 [0.00040467]
 [0.00040467]
 [0.00040467]
 [0.00040467]
 [0.00040467]
 [0.00040467]
 [0.00040467]
 [0.00040467]
 [0.00040467]]
