In [1]:
import pandas as pd
import torch.backends.cudnn as cudnn
from torch.utils.data import ConcatDataset, DataLoader, Subset
from base_model import BaseModel
import pickle
import numpy as np
import random
import ast
import torch
from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import torch.nn.functional as F
from dataset import get_dataset,get_imagenet
from collections import Counter
import torch.optim as optim

In [2]:
torch.manual_seed(7)
random.seed(7)
np.random.seed(7)
device = f"cuda:{0}"
cudnn.benchmark = True
print(f"attack_device:{device}")
input_dim = 3

attack_device:cuda:0


In [3]:
dataset_name = "mini_imagenet"
model_name = "vgg16"
num_cls = 100
shadow_num = 5
compress_name = "l1unstructure"
batch_size = 128

In [4]:
save_folder_original = f"results/{dataset_name}_{model_name}"
save_folder_compress = f"results_compress/{dataset_name}_{model_name}_{compress_name}"

In [5]:
data_path = f"{save_folder_original}/data_index.pkl"
with open(data_path, 'rb') as f:
    victim_train_list, victim_test_list, attack_split_list = pickle.load(f)

In [10]:
trainset, testset = get_dataset(dataset_name)
total_dataset = ConcatDataset([trainset, testset])
total_size = len(total_dataset)
data_path = f"{save_folder_compress}/inference_data_index.pkl"
    
with open(data_path, 'rb') as f:
    inference_victim_train_list, inference_victim_test_list, inference_attack_split_list = pickle.load(f)

victim_train_dataset = Subset(total_dataset, inference_victim_train_list)
victim_test_dataset = Subset(total_dataset, inference_victim_test_list)

victim_train_loader = DataLoader(victim_train_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=False)
victim_test_loader = DataLoader(victim_test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=False)


Total Data Size: 60000, Victim Train Size: 8000, Victim Test Size: 8000


In [11]:
 # Load victim model
victim_model_save_folder = save_folder_original + "/victim_model"
victim_model_path = f"{victim_model_save_folder}/best.pth"
victim_model = BaseModel(dataset_name, model_name, num_cls=num_cls, input_dim=input_dim, device=device)
victim_model.load(victim_model_path)

victim_original_in_loss, victim_in_target, victim_original_in_predicts, victim_original_in_state = victim_model.predict_target_loss(victim_train_loader)
victim_original_out_loss, victim_out_target, victim_original_out_predicts, victim_original_out_state = victim_model.predict_target_loss(victim_test_loader)
sort_victim_original_in = torch.argsort(victim_original_in_predicts, dim=1, descending=False)
sort_victim_original_out = torch.argsort(victim_original_out_predicts, dim=1, descending=False)
victim_original_in_predicts = torch.gather(victim_original_in_predicts, 1, sort_victim_original_in)
victim_original_out_predicts = torch.gather(victim_original_out_predicts, 1, sort_victim_original_out)
victim_original_in_loss = victim_original_in_loss.unsqueeze(1)
victim_original_out_loss = victim_original_out_loss.unsqueeze(1)

v-------------


In [12]:
victim_compress_model_list = []
for i in [0.6,0.7,0.8,0.9]:
    compress_victim_model_save_folder = f"{save_folder_compress}_{i}/victim_model"
    print(f"Load Compress Model from {compress_victim_model_save_folder}")
    victim_compress_model = BaseModel(dataset_name, model_name, num_cls=num_cls, input_dim=input_dim, save_folder=compress_victim_model_save_folder, device=device)
    victim_compress_model.model.load_state_dict(torch.load(f"{compress_victim_model_save_folder}/best.pth"))
    victim_compress_model.test(victim_train_loader, "Victim Compress Model Train")
    victim_compress_model.test(victim_test_loader, "Victim Compress Model Test")
    victim_compress_model_list.append(victim_compress_model)

Load Compress Model from results_pruned/1_mini_imagenet_vgg16_l1unstructure_0.6/victim_model
v-------------
Victim Compress Model Train: Accuracy 91.250, Loss 0.328
Victim Compress Model Test: Accuracy 73.838, Loss 1.131
Load Compress Model from results_pruned/1_mini_imagenet_vgg16_l1unstructure_0.7/victim_model
v-------------
Victim Compress Model Train: Accuracy 91.275, Loss 0.337
Victim Compress Model Test: Accuracy 73.862, Loss 1.115
Load Compress Model from results_pruned/1_mini_imagenet_vgg16_l1unstructure_0.8/victim_model
v-------------
Victim Compress Model Train: Accuracy 89.888, Loss 0.388
Victim Compress Model Test: Accuracy 73.612, Loss 1.083
Load Compress Model from results_pruned/1_mini_imagenet_vgg16_l1unstructure_0.9/victim_model
v-------------
Victim Compress Model Train: Accuracy 86.625, Loss 0.517
Victim Compress Model Test: Accuracy 73.213, Loss 1.049


In [13]:
victim_compress_in_predicts_list = []
victim_compress_out_predicts_list = []
victim_compress_in_loss_list = []
victim_compress_out_loss_list = []
victim_compress_in_state_list = []
victim_compress_out_state_list = []
for victim_compress_model in victim_compress_model_list:
    victim_compress_in_loss,victim_in_target, victim_compress_in_predicts,victim_compress_in_state = victim_compress_model.predict_target_loss(victim_train_loader)
    victim_compress_out_loss, victim_out_target, victim_compress_out_predicts,victim_compress_out_state = victim_compress_model.predict_target_loss(victim_test_loader)
    
    victim_compress_in_predicts =  victim_compress_in_predicts.gather(1, sort_victim_original_in)
    victim_compress_in_predicts = victim_original_in_predicts - victim_compress_in_predicts
    victim_compress_out_predicts =  victim_compress_out_predicts.gather(1, sort_victim_original_out)
    victim_compress_out_predicts = victim_original_out_predicts - victim_compress_out_predicts
    
    victim_compress_in_predicts_list.append(victim_compress_in_predicts)
    victim_compress_out_predicts_list.append(victim_compress_out_predicts)
    victim_compress_in_loss_list.append(victim_compress_in_loss)
    victim_compress_out_loss_list.append(victim_compress_out_loss)
    victim_compress_in_state_list.append(victim_compress_in_state)
    victim_compress_out_state_list.append(victim_compress_out_state)
victim_compress_in_loss_list = [x.unsqueeze(1) for x in victim_compress_in_loss_list]  
victim_compress_out_loss_list = [x.unsqueeze(1) for x in victim_compress_out_loss_list]  
victim_compress_in_state_list = [x.unsqueeze(1) for x in victim_compress_in_state_list]  
victim_compress_out_state_list = [x.unsqueeze(1) for x in victim_compress_out_state_list]  
victim_compress_in_predicts = torch.cat(victim_compress_in_predicts_list, dim=1)
victim_compress_out_predicts = torch.cat(victim_compress_out_predicts_list, dim=1)
victim_compress_in_loss = torch.cat(victim_compress_in_loss_list, dim=1)
victim_compress_out_loss = torch.cat(victim_compress_out_loss_list, dim=1)
victim_compress_in_state = torch.cat(victim_compress_in_state_list, dim=1)
victim_compress_out_state = torch.cat(victim_compress_out_state_list, dim=1)

In [15]:
shadow_model_list, shadow_train_loader_list, shadow_test_loader_list, shadow_prune_model_group_list = [], [], [], []

for shadow_ind in range(shadow_num):
    attack_train_list, attack_test_list = inference_attack_split_list[shadow_ind]
    print(f"Victim Train Size: {len(attack_train_list)}, "
        f"Victim Test Size: {len(attack_test_list)}")
    shadow_train_dataset = Subset(total_dataset, attack_train_list)
    shadow_test_dataset = Subset(total_dataset, attack_test_list)
    shadow_train_loader = DataLoader(shadow_train_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=False)
    shadow_test_loader = DataLoader(shadow_test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=False)
    shadow_model_path = f"{save_folder_original}/shadow_model_{shadow_ind}/best.pth"
    shadow_model = BaseModel(dataset_name,model_name, num_cls=num_cls, input_dim=input_dim, device=device)
    shadow_model.load(shadow_model_path)
    shadow_model_list.append(shadow_model)
    shadow_train_loader_list.append(shadow_train_loader)
    shadow_test_loader_list.append(shadow_test_loader)
    shadow_prune_model_list = []
    for i in [0.6,0.7,0.8,0.9]:
        pruned_shadow_model_save_folder = f"{save_folder_compress}_{i}/shadow_model_{shadow_ind}"
        print(f"Load Pruned Shadow Model From {pruned_shadow_model_save_folder}")
        shadow_pruned_model = BaseModel(dataset_name,model_name, num_cls=num_cls, input_dim=input_dim,save_folder=pruned_shadow_model_save_folder, device=device)
        #shadow_pruned_model.load(f"{pruned_shadow_model_save_folder}/best.pth")
        shadow_pruned_model.model.load_state_dict(torch.load(f"{pruned_shadow_model_save_folder}/best.pth"))
        shadow_pruned_model.test(shadow_train_loader, "Shadow Pruned Model Train")
        shadow_pruned_model.test(shadow_test_loader, "Shadow Pruned Model Test")
        shadow_prune_model_list.append(shadow_pruned_model)
    shadow_prune_model_group_list.append(shadow_prune_model_list)

Victim Train Size: 8000, Victim Test Size: 8000
v-------------
Load Pruned Shadow Model From results_pruned/1_mini_imagenet_vgg16_l1unstructure_0.6/shadow_model_0
v-------------
Shadow Pruned Model Train: Accuracy 92.263, Loss 0.293
Shadow Pruned Model Test: Accuracy 74.162, Loss 1.115
Load Pruned Shadow Model From results_pruned/1_mini_imagenet_vgg16_l1unstructure_0.7/shadow_model_0
v-------------
Shadow Pruned Model Train: Accuracy 90.487, Loss 0.379
Shadow Pruned Model Test: Accuracy 73.500, Loss 1.070
Load Pruned Shadow Model From results_pruned/1_mini_imagenet_vgg16_l1unstructure_0.8/shadow_model_0
v-------------
Shadow Pruned Model Train: Accuracy 91.175, Loss 0.337
Shadow Pruned Model Test: Accuracy 73.900, Loss 1.063
Load Pruned Shadow Model From results_pruned/1_mini_imagenet_vgg16_l1unstructure_0.9/shadow_model_0
v-------------
Shadow Pruned Model Train: Accuracy 88.487, Loss 0.456
Shadow Pruned Model Test: Accuracy 73.362, Loss 1.038
Victim Train Size: 8000, Victim Test Size

In [16]:
attack_original_in_predicts_list, attack_original_out_predicts_list = [], []
attack_pruned_in_predicts_list, attack_pruned_out_predicts_list = [], []
attack_original_in_loss_list, attack_original_out_loss_list = [], []
attack_pruned_in_loss_list, attack_pruned_out_loss_list = [], []
attack_in_targets_list = []
attack_out_targets_list= []

n = 0

for shadow_model, shadow_train_loader, shadow_test_loader in zip(shadow_model_list, shadow_train_loader_list, shadow_test_loader_list):

    attack_original_in_loss,attack_in_target, attack_original_in_predicts,attack_original_in_state = shadow_model.predict_target_loss(shadow_train_loader)
    attack_original_out_loss, attack_out_target, attack_original_out_predicts,attack_original_out_state = shadow_model.predict_target_loss(shadow_test_loader)

    sort_attack_original_in = torch.argsort(attack_original_in_predicts, dim=1, descending=False)
    sort_attack_original_out = torch.argsort(attack_original_out_predicts, dim=1, descending=False)
    attack_original_in_predicts = torch.gather(attack_original_in_predicts, 1, sort_attack_original_in)
    attack_original_out_predicts = torch.gather(attack_original_out_predicts, 1, sort_attack_original_out)
    
    attack_original_in_loss = attack_original_in_loss.unsqueeze(1)
    attack_original_out_loss = attack_original_out_loss.unsqueeze(1)
    
    attack_original_in_loss_list.append(attack_original_in_loss)
    attack_original_out_loss_list.append(attack_original_out_loss)
    
    attack_original_in_predicts_list.append(attack_original_in_predicts)
    attack_original_out_predicts_list.append(attack_original_out_predicts) 
    attack_in_targets_list.append(attack_in_target)
    attack_out_targets_list.append(attack_out_target)  

    each_attack_pruned_in_predicts_list = []
    each_attack_pruned_out_predicts_list = []
    each_attack_pruned_in_loss_list = []
    each_attack_pruned_out_loss_list = []

    shadow_prune_models_for_n = shadow_prune_model_group_list[n]
    n = n+1
    for i, shadow_prune_model in enumerate(shadow_prune_models_for_n):
        print(f"Shadow prune model {i} for shadow {n}: {shadow_prune_model}")
        attack_pruned_out_loss, attack_out_targets, attack_pruned_out_predicts,attack_pruned_out_state = shadow_pruned_model.predict_target_loss(shadow_test_loader)
        attack_pruned_in_loss, attack_in_targets, attack_pruned_in_predicts,attack_pruned_in_state = shadow_pruned_model.predict_target_loss(shadow_train_loader)

        attack_pruned_in_predicts = attack_pruned_in_predicts.gather(1, sort_attack_original_in)
        attack_pruned_out_predicts =  attack_pruned_out_predicts.gather(1, sort_attack_original_out)
        
        attack_pruned_in_predicts = attack_original_in_predicts - attack_pruned_in_predicts
        attack_pruned_out_predicts = attack_original_out_predicts - attack_pruned_out_predicts
        
        each_attack_pruned_out_predicts_list.append(attack_pruned_out_predicts)
        each_attack_pruned_in_predicts_list.append(attack_pruned_in_predicts)

        each_attack_pruned_in_loss_list.append(attack_pruned_in_loss)
        each_attack_pruned_out_loss_list.append(attack_pruned_out_loss)

        
    each_attack_pruned_in_loss_list = [x.unsqueeze(1) for x in each_attack_pruned_in_loss_list]  # 增加一个维度
    each_attack_pruned_out_loss_list = [x.unsqueeze(1) for x in each_attack_pruned_out_loss_list]  # 增加一个维度

    each_attack_pruned_in_predicts = torch.cat(each_attack_pruned_in_predicts_list, dim=1)
    each_attack_pruned_out_predicts = torch.cat(each_attack_pruned_out_predicts_list, dim=1)
    attack_pruned_in_predicts_list.append(each_attack_pruned_in_predicts)
    attack_pruned_out_predicts_list.append(each_attack_pruned_out_predicts)

    each_attack_pruned_in_loss = torch.cat(each_attack_pruned_in_loss_list, dim=1)
    each_attack_pruned_out_loss = torch.cat(each_attack_pruned_out_loss_list, dim=1)
    attack_pruned_in_loss_list.append(each_attack_pruned_in_loss)
    attack_pruned_out_loss_list.append(each_attack_pruned_out_loss)




attack_original_in_predicts = torch.cat(attack_original_in_predicts_list, dim=0)
attack_original_out_predicts = torch.cat(attack_original_out_predicts_list, dim=0)
attack_pruned_in_predicts = torch.cat(attack_pruned_in_predicts_list, dim=0)
attack_pruned_out_predicts = torch.cat(attack_pruned_out_predicts_list, dim=0)
attack_in_targets = torch.cat(attack_in_targets_list, dim=0)
attack_out_targets = torch.cat(attack_out_targets_list, dim=0)

attack_original_in_loss = torch.cat(attack_original_in_loss_list, dim=0)
attack_original_out_loss = torch.cat(attack_original_out_loss_list, dim=0)

attack_pruned_in_loss = torch.cat(attack_pruned_in_loss_list, dim=0)
attack_pruned_out_loss = torch.cat(attack_pruned_out_loss_list, dim=0)


Shadow prune model 0 for shadow 1: <base_model.BaseModel object at 0x7f94b0369bd0>
Shadow prune model 1 for shadow 1: <base_model.BaseModel object at 0x7f94b302e610>
Shadow prune model 2 for shadow 1: <base_model.BaseModel object at 0x7f94b036dc10>
Shadow prune model 3 for shadow 1: <base_model.BaseModel object at 0x7f94b02685d0>
Shadow prune model 0 for shadow 2: <base_model.BaseModel object at 0x7f94b036b310>
Shadow prune model 1 for shadow 2: <base_model.BaseModel object at 0x7f94b298dd50>
Shadow prune model 2 for shadow 2: <base_model.BaseModel object at 0x7f95e4153810>
Shadow prune model 3 for shadow 2: <base_model.BaseModel object at 0x7f94af319990>
Shadow prune model 0 for shadow 3: <base_model.BaseModel object at 0x7f94aed62b90>
Shadow prune model 1 for shadow 3: <base_model.BaseModel object at 0x7f94aedd1110>
Shadow prune model 2 for shadow 3: <base_model.BaseModel object at 0x7f94aedd8550>
Shadow prune model 3 for shadow 3: <base_model.BaseModel object at 0x7f94af7b7a10>
Shad

In [18]:
nu = 4

In [19]:
train_pre = np.concatenate((attack_pruned_in_predicts, attack_pruned_out_predicts), axis=0)
test_pre = np.concatenate((victim_compress_in_predicts, victim_compress_out_predicts), axis=0)

In [20]:
attack_in_targets = F.one_hot(attack_in_targets, num_classes=num_cls).float() 
attack_out_targets = F.one_hot(attack_out_targets, num_classes=num_cls).float()
train_target = np.concatenate((attack_in_targets.cpu(), attack_out_targets.cpu()), axis=0)

victim_in_targets = F.one_hot(victim_in_target, num_classes=num_cls).float()
victim_out_targets = F.one_hot(victim_out_target, num_classes=num_cls).float()
test_target = np.concatenate((victim_in_targets.cpu(), victim_out_targets.cpu()), axis=0)

In [21]:
train_loss = np.concatenate((attack_pruned_in_loss, attack_pruned_out_loss), axis=0)
test_loss = np.concatenate((victim_compress_in_loss, victim_compress_out_loss), axis=0)

In [22]:
num = len(attack_in_targets)
total_num = num*shadow_num 
ones = torch.ones(total_num)
zeros = torch.zeros(total_num)
train_labels = torch.cat((ones, zeros), dim=0)
print(train_labels.shape)

ones = torch.ones(num)
zeros = torch.zeros(num)
test_labels = torch.cat((ones, zeros), dim=0)
print(test_labels.shape)

torch.Size([80000])
torch.Size([16000])


In [None]:
attack_model_name = "RF"
method_name = "compleak_SR2"

In [None]:
results2 = pd.read_csv(f"results_compress/{dataset_name}_{model_name}_l1unstructure_0.6/prob_results.csv")
condition2 = (results2['method'] == method_name) & (results2['attack_model_name'] == attack_model_name)
prob2 = results2[condition2]
train_prob2 = prob2['train_prob'].values[0] 
test_prob2 = prob2['test_prob'].values[0] 
train_prob2  = ast.literal_eval(train_prob2)
test_prob2  = ast.literal_eval(test_prob2)
train_prob2  = np.array(train_prob2).reshape(-1, 2)
test_prob2  = np.array(test_prob2).reshape(-1, 2)

In [None]:
results3 = pd.read_csv(f"results_compress/{dataset_name}_{model_name}_l1unstructure_0.7/prob_results.csv")
condition3 = (results3['method'] == method_name) & (results3['attack_model_name'] == attack_model_name)
prob3 = results3[condition3]
train_prob3 = prob3['train_prob'].values[0] 
test_prob3 = prob3['test_prob'].values[0] 
train_prob3  = ast.literal_eval(train_prob3)
test_prob3  = ast.literal_eval(test_prob3)
train_prob3  = np.array(train_prob3).reshape(-1, 2)
test_prob3  = np.array(test_prob3).reshape(-1, 2)

In [None]:
results4 = pd.read_csv(f"results_compress/{dataset_name}_{model_name}_l1unstructure_0.8/prob_results.csv")
condition4 = (results4['method'] == method_name) & (results4['attack_model_name'] == attack_model_name)

prob4 = results4[condition4]
train_prob4 = prob4['train_prob'].values[0] 
test_prob4 = prob4['test_prob'].values[0] 
train_prob4  = ast.literal_eval(train_prob4)
test_prob4  = ast.literal_eval(test_prob4)
train_prob4  = np.array(train_prob4).reshape(-1, 2)
test_prob4  = np.array(test_prob4).reshape(-1, 2)

In [None]:
results5 = pd.read_csv(f"results_compress/{dataset_name}_{model_name}_l1unstructure_0.9/prob_results.csv")
condition5 = (results5['method'] == method_name) & (results5['attack_model_name'] == attack_model_name)
prob5 = results5[condition5]
train_prob5 = prob5['train_prob'].values[0] 
test_prob5 = prob5['test_prob'].values[0] 

train_prob5  = ast.literal_eval(train_prob5)
test_prob5 = ast.literal_eval(test_prob5)

train_prob5  = np.array(train_prob5).reshape(-1, 2)
test_prob5  = np.array(test_prob5).reshape(-1, 2)


In [None]:
train_combined = np.concatenate((train_prob2, train_prob3,train_prob4,train_prob5), axis=1) 
test_combined = np.concatenate((test_prob2, test_prob3,test_prob4,test_prob5), axis=1)
traindata = torch.tensor(train_combined, dtype=torch.float32)
testdata = torch.tensor(test_combined, dtype=torch.float32)
print(testdata.shape)

In [30]:
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
    def __init__(self, data1, data2,labels):
        self.data1 = data1
        self.data2 = data2
        self.labels = labels
    def __len__(self):
        return len(self.data1)
    def __getitem__(self, idx):
        data1 = self.data1[idx]
        data2 = self.data2[idx]
        label = self.labels[idx]
        
        return data1, data2, label

In [21]:
train_dataset = CustomDataset(traindata, train_loss, train_labels)
test_dataset = CustomDataset(testdata,test_loss,test_labels)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [39]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
class AttackModel(nn.Module):
    def __init__(self, class_num):
        super(AttackModel, self).__init__()

        # self.dropout = nn.Dropout(p=0.2)
        self.output_component = nn.Sequential(
            nn.Linear(2*4, 128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, 64),
        )

        self.loss_component = nn.Sequential(
            nn.Linear(1*4, 128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, 64),
        )

        
        self.encoder_component = nn.Sequential(
           nn.Dropout(p=0.5),
            nn.Linear(64*2 , 256),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(64, 2),
        )


    def forward(self, output, loss):
        output_component_result = self.output_component(output)
        loss_component_result = self.loss_component(loss)
        final_input = torch.cat((output_component_result, loss_component_result), 1)
        final_result = self.encoder_component(final_input)
        return final_result, final_input


In [43]:
from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
def train_mia_attack_model(model, attack_train_loader, optimizer, loss_fn, device):
    model.train()
    train_loss = 0
    correct = 0
    for batch_idx, (input1,input2, member_status) in enumerate(attack_train_loader):
        input1 = input1.to(device)
        
        #model_loss = loss.view(-1, 1)
        input2 = input2.to(device)

        output,_ = model(input1,input2)        
        member_status = member_status.to(device)
        loss = loss_fn(output, member_status.long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(member_status.view_as(pred)).sum().item()
        
    train_loss /= len(attack_train_loader.dataset)
    accuracy = 100. * correct / len(attack_train_loader.dataset)
    return train_loss, accuracy / 100

def test_mia_attack_model(model, attack_test_loader, loss_fn, device):                 
    model.eval()
    test_loss = 0
    correct = 0
    all_ground_truth = []
    all_pred_probs = []
    final_inputs = []
    
    
    with torch.no_grad():
        for batch_idx, (input1, input2, member_status) in enumerate(attack_test_loader):
            input1 = input1.to(device)
            input2 = input2.to(device)
            #input3 = input3.to(device)
            
            # 模型输出
            output,final_input = model(input1, input2) 
            final_inputs.append(final_input.detach().cpu())
            member_status = member_status.to(device)

            
            test_loss += loss_fn(output, member_status.long()).item()
            
            probs = torch.softmax(output, dim=1)[:, 1]  # 提取正类的概率
            all_pred_probs.extend(probs.cpu().numpy())  # 保存概率
            all_ground_truth.extend(member_status.cpu().numpy())  # 保存真实标签
            
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(member_status.view_as(pred)).sum().item()
    
    # 计算平均 loss 和 accuracy
    test_loss /= len(attack_test_loader.dataset)
    accuracy = 100. * correct / len(attack_test_loader.dataset)
    
    # 计算 AUC
    auc1 = roc_auc_score(all_ground_truth, all_pred_probs)
    
    # 计算 TPR @ 0.1% FPR
    fpr, tpr, thresholds = roc_curve(all_ground_truth, all_pred_probs)

    fpr_target = 0.001
    interp_func = interp1d(fpr, tpr)
    tpr_0 = interp_func(fpr_target)
    final_inputs = torch.cat(final_inputs, dim=0)
    return test_loss, accuracy / 100., auc1, tpr_0,final_inputs

In [None]:
import torch
import numpy as np
import torch.optim as optim
from sklearn.metrics import roc_auc_score
print(f'-------------------mia------------------')
attack_model = AttackModel(num_cls)
epoch = 100
attack_optimizer = torch.optim.SGD(attack_model.parameters(), 1e-1, momentum=0.9, weight_decay=5e-4)
attack_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(attack_optimizer, T_max=epoch)
attack_model = attack_model.to(device)
loss_fn = nn.CrossEntropyLoss()
best_prec1 = 0.0
best_auc = 0.0
for epoch in range(epoch):
    train_loss, train_prec1 = train_mia_attack_model(attack_model, train_loader, attack_optimizer, loss_fn, device)
    val_loss, val_prec1,auc,tpr_0, final_inputs = test_mia_attack_model(attack_model, test_loader, loss_fn, device)
    attack_scheduler.step()
    is_best_prec1 = val_prec1 > best_prec1
    if is_best_prec1:
        best_prec1 = val_prec1
        best_auc = auc
    print(('epoch:{} \t tpr:{:.4f} \t test_prec1:{:.4f} \t best_prec1:{:.4f} \t best_auc:{:.4f} \t auc:{:.4f}')
            .format(epoch,tpr_0, val_prec1, best_prec1, best_auc,auc))

-------------------mia------------------
epoch:0 	 tpr:0.5210 	 test_prec1:0.9472 	 best_prec1:0.9472 	 best_auc:0.9881 	 auc:0.9881
epoch:1 	 tpr:0.3614 	 test_prec1:0.9397 	 best_prec1:0.9472 	 best_auc:0.9881 	 auc:0.9866
epoch:2 	 tpr:0.2392 	 test_prec1:0.9297 	 best_prec1:0.9472 	 best_auc:0.9881 	 auc:0.9762
epoch:3 	 tpr:0.1566 	 test_prec1:0.9346 	 best_prec1:0.9472 	 best_auc:0.9881 	 auc:0.9784
epoch:4 	 tpr:0.3910 	 test_prec1:0.9436 	 best_prec1:0.9472 	 best_auc:0.9881 	 auc:0.9870
epoch:5 	 tpr:0.3947 	 test_prec1:0.9463 	 best_prec1:0.9472 	 best_auc:0.9881 	 auc:0.9872
epoch:6 	 tpr:0.2640 	 test_prec1:0.9400 	 best_prec1:0.9472 	 best_auc:0.9881 	 auc:0.9832
epoch:7 	 tpr:0.3125 	 test_prec1:0.9426 	 best_prec1:0.9472 	 best_auc:0.9881 	 auc:0.9840
epoch:8 	 tpr:0.4093 	 test_prec1:0.9489 	 best_prec1:0.9489 	 best_auc:0.9884 	 auc:0.9884
epoch:9 	 tpr:0.6292 	 test_prec1:0.9498 	 best_prec1:0.9498 	 best_auc:0.9896 	 auc:0.9896
epoch:10 	 tpr:0.5603 	 test_prec1:0.94