Official repository for *An Efficient Membership Inference Attack for the Diffusion Model by Proximal Initialization* (ICLR 2024).
Follow `gradtts/train/README.md`.
We provide the dataset splits in `DDPM/CIFAR10_train_ratio0.5.npz` and `DDPM/TINY-IN_train_ratio0.5.npz`. To train the DDPM, put the CIFAR-10 dataset into `DDPM/data/pytorch` and Tiny-ImageNet into `DDPM/data/tiny-imagenet-200`. You can also change these directories by modifying the paths in the `main.get_dataset` function and in `dataset_utils.load_member_data`. To change the log directory, modify `FLAGS.logdir` in `main.py`; to select the dataset, set `FLAGS.dataset`.
Then, to train the DDPM, run:

```
cd DDPM
python main.py
```
To attack the DDPM, run:

```
cd DDPM
python attack.py --checkpoint your_checkpoint --dataset your_dataset --attacker_name attacker_name --attack_num attack_num --interval interval
```
The meaning of those parameters:

- `--checkpoint`: the checkpoint you saved.
- `--dataset`: the dataset to attack; it can be `cifar10` or `TINY-IN`.
- `--attacker_name`: the attack method: `naive` for NA in our paper, `SecMI` for the SecMI attack, `PIA` for PIA, and `PIAN` for PIAN.
- `--attack_num`: the number of timesteps to attack.
- `--interval`: the interval between attacked timesteps. For example, with `attack_num=5` and `interval=20`, the attack runs at timesteps [20, 40, 60, 80, 100].
Finally, the program prints the AUC and TPR @ 1% FPR at each of these timesteps (here, [20, 40, 60, 80, 100]).
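For example, to run PIA against a CIFAR-10 checkpoint at those five timesteps (the checkpoint path below is a placeholder; substitute your own):

```
cd DDPM
python attack.py --checkpoint logs/DDPM_CIFAR10/ckpt.pt --dataset cifar10 --attacker_name PIA --attack_num 5 --interval 20
```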
To attack your own model, inherit a subclass from `components.EpsGetter` in `attack.py`, implement its `__call__` method, and return the predicted noise for the given `noise_level[t]`.
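As a rough illustration, such a subclass might look like the sketch below. The constructor and `__call__` signature here are assumptions; match them to the actual `components.EpsGetter` interface in `attack.py`.

```python
import torch
from components import EpsGetter  # defined in this repo; see attack.py


class MyModelEpsGetter(EpsGetter):
    """Hypothetical adapter exposing a custom model's noise prediction."""

    def __init__(self, model):
        self.model = model  # your trained diffusion model

    def __call__(self, xt, condition=None, noise_level=None, t=None):
        # xt is the noised sample at timestep t; return the model's
        # predicted noise (epsilon). This signature is an assumption --
        # check the base class for the exact arguments it passes.
        t_batch = torch.full((xt.shape[0],), t, device=xt.device, dtype=torch.long)
        return self.model(xt, t_batch)
```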
The code in `gradtts/train` is the official GradTTS code. We provide the dataset split in `gradtts/train/split`. Put LJSpeech and LibriTTS into `gradtts/train/datasets`. For LibriTTS, to use our pretrained checkpoint, you need to resample the audio to 22050 Hz. Run the commands below to resample:

```
cd gradtts/train
python resample.py
```
To train the model, you can change the parameters in `gradtts/train/params.py` for LJSpeech and `gradtts/train/params_libritts.py` for LibriTTS, especially `train_filelist_path`, `valid_filelist_path`, and `log_dir`. By default, without any changes, the code works properly.
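For reference, a hypothetical edit of these parameters might look as follows; the file names below are illustrative only, so point them at the splits actually shipped in `gradtts/train/split`:

```python
# gradtts/train/params.py -- illustrative values only; the actual
# defaults and split file names in the repo may differ.
train_filelist_path = 'split/ljspeech_train.txt'  # hypothetical file name
valid_filelist_path = 'split/ljspeech_valid.txt'  # hypothetical file name
log_dir = 'logs/gradtts_ljspeech'
```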
```
cd gradtts/train
python train.py                         # for LJSpeech
python train_multi_speaker_libritts.py  # for LibriTTS
```
To attack Stable Diffusion, run:

```
cd stable_diffusion
python attack.py --checkpoint your_checkpoint --dataset your_dataset --attacker_name attacker_name --attack_num attack_num --interval interval
```
The meaning of those parameters:

- `--checkpoint`: the checkpoint you saved.
- `--dataset`: the dataset to attack; it can be `laion5`, `laion5_none` (no ground-truth text), or `laion5_blip` (BLIP-generated text).
- `--attacker_name`: the attack method: `naive` for NA in our paper, `SecMI` for the SecMI attack, `PIA` for PIA, and `PIAN` for PIAN.
- `--attack_num`: the number of timesteps to attack.
- `--interval`: the interval between attacked timesteps. For example, with `attack_num=5` and `interval=20`, the attack runs at timesteps [20, 40, 60, 80, 100].
Finally, the program prints the AUC and TPR @ 1% FPR at each of these timesteps (here, [20, 40, 60, 80, 100]).
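For example, to attack with BLIP-generated captions (the checkpoint path is a placeholder):

```
cd stable_diffusion
python attack.py --checkpoint path/to/sd_checkpoint --dataset laion5_blip --attacker_name PIA --attack_num 5 --interval 20
```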
To attack Stable Diffusion, you need to install `diffusers==0.18.0` (e.g. `pip install diffusers==0.18.0`).
We also provide the images evaluated in our paper; they can be downloaded from MIA_efficient. If you download the data, modify `stable_diffusion/dataset.py` to set `coco_dataset_root`, `coco_dataset_anno`, and `stable_diffusion_data`. You also need to download the COCO dataset yourself.
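A hypothetical configuration, assuming these are plain module-level variables in `dataset.py` (all three paths are placeholders for your local setup):

```python
# stable_diffusion/dataset.py -- illustrative values only
coco_dataset_root = '/data/coco/train2017'            # COCO images (placeholder)
coco_dataset_anno = '/data/coco/annotations'          # COCO annotations (placeholder)
stable_diffusion_data = '/data/MIA_efficient_images'  # images downloaded from MIA_efficient (placeholder)
```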
Change the values of `train_filelist_path` and `valid_filelist_path` in `gradtts/attack/gradtts/params_ljspeech.py` and `gradtts/attack/gradtts/params_libritts.py` to load the dataset.
```
cd gradtts/attack
python attack.py --checkpoint your_checkpoint --dataset your_dataset --attacker_name attacker_name --attack_num attack_num --interval interval
```
The meaning of the parameters is the same as for DDPM, except:

- `--dataset`: it can be `ljspeech` or `libritts`.
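For example (the checkpoint path is a placeholder):

```
cd gradtts/attack
python attack.py --checkpoint path/to/gradtts_ckpt.pt --dataset ljspeech --attacker_name PIAN --attack_num 5 --interval 20
```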
You can download the pretrained checkpoints from MIA_efficient.
```
@article{kong2023efficient,
  title={An Efficient Membership Inference Attack for the Diffusion Model by Proximal Initialization},
  author={Kong, Fei and Duan, Jinhao and Ma, RuiPeng and Shen, Hengtao and Zhu, Xiaofeng and Shi, Xiaoshuang and Xu, Kaidi},
  journal={The International Conference on Learning Representations},
  year={2024}
}
```