In [1]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
#import matplotlib.pyplot as plt
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import classification_report
import pickle
from sklearn.model_selection import train_test_split
import torch.optim as optim

from helper_functions import train_or_load


### Load shadow model and dataset

In [2]:
#load the shadow model trained in the other python script
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # sets to gpu if you have one

mobilenet_shadow =  models.mobilenet_v2(pretrained=False,num_classes=10)
mobilenet_shadow.load_state_dict(torch.load("shadow_models/mobilenet_shadow_cifar_overtrained.pth",map_location=device))



<All keys matched successfully>

In [3]:
DATA_PATH = 'pickle/cifar10/mobilenetv2/shadow.p'
# Change the DATA_PATH to your local pickle file path

with open(DATA_PATH, "rb") as f:
    dataset = pickle.load(f)


#splitting
#only use train set here
train_data, val_data = train_test_split(dataset, test_size=(1-0.5),shuffle=False)
  
train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, num_workers=2)
testloader =  torch.utils.data.DataLoader(val_data, batch_size=1, shuffle=True, num_workers=2)

for batch_idx, (img, label) in enumerate(train_loader):
    img = img.to(device)


### Generate Attack model Dataset

In [4]:
#generate dataset for attack model
mobilenet_shadow.eval()
dataset_attack = []
# NON Members
with torch.no_grad():
    for images, labels in testloader: #need only one
            # Move images and labels to the appropriate device
        images, labels = images.to(device), labels.to(device)
            # Forward pass
        logits = mobilenet_shadow(images)
        #take the 3 biggest logist
        top_values = torch.topk(logits, k=3).values
        top_values, indices = torch.sort(top_values, dim=1, descending=True)
        dataset_attack.append([top_values,0])
# MEMBERS
with torch.no_grad():
    for images, labels in train_loader: #need only one
            # Move images and labels to the appropriate device
        images, labels = images.to(device), labels.to(device)
            # Forward pass
        logits = mobilenet_shadow(images)
        #take the 3 biggest logist
        top_values = torch.topk(logits, k=3).values
        top_values, indices = torch.sort(top_values, dim=1, descending=True)
        dataset_attack.append([top_values,1])
        



In [72]:

# # Convert all tensors to the same dtype first
tensors = [data[0].float() for data in dataset_attack]  # Ensure all tensors are Float type
all_data = torch.cat(tensors, dim=0)  # Concatenate all tensors

# # Calculate mean and std
mean = all_data.mean(dim=0)
std = all_data.std(dim=0)

# # Standardize data in the list
standardized_data_list = [( (data[0] - mean) / std, data[1] ) for data in dataset_attack]


In [73]:
attack_dataloader = torch.utils.data.DataLoader(standardized_data_list, batch_size=64, shuffle=True, num_workers=2) #shuffled training data

In [74]:
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(3, 32)
        self.fc2 = nn.Linear(32, 1)
       

    def forward(self, x):
        
        x = torch.sigmoid(self.fc2(self.fc1(x)))
        return x


    

In [87]:
attack_model = SimpleNN()
criterion = nn.BCELoss()
optimizer = optim.Adam(attack_model.parameters(), lr=0.01)
attack_model.train()

SimpleNN(
  (fc1): Linear(in_features=3, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=1, bias=True)
)

In [76]:
# for inputs in attack_dataloader:
#     inputs = inputs[0]
#     print(inputs.shape)  # Check input shape consistenc
#     print(inputs)
#     print(inputs)
#     outputs = attack_model(inputs.squeeze(dim = 1))

In [89]:
train_or_load(attack_model,attack_dataloader,optimizer,criterion,epochs=15)

Epoch: 1
Loss: 0.6933600902557373
Epoch: 2
Loss: 0.6951708793640137
Epoch: 3
Loss: 0.6962069869041443
Epoch: 4
Loss: 0.6939361691474915
Epoch: 5
Loss: 0.6908290386199951
Epoch: 6
Loss: 0.6923039555549622
Epoch: 7
Loss: 0.6970987915992737
Epoch: 8
Loss: 0.6903989315032959
Epoch: 9
Loss: 0.6886561512947083
Epoch: 10
Loss: 0.69254070520401
Epoch: 11
Loss: 0.7180622220039368
Epoch: 12
Loss: 0.6902437210083008
Epoch: 13
Loss: 0.6961742043495178
Epoch: 14
Loss: 0.6933605074882507
Epoch: 15
Loss: 0.6925833225250244


In [91]:
torch.save(attack_model.state_dict(), 'attack_models/attack_mobilenet_cifar.pth')

<h1>Evaluation</h1>

In [45]:
DATA_PATH = "pickle/cifar10/mobilenetv2/eval.p"

device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with open(DATA_PATH, "rb") as f:
    dataset = pickle.load(f)
# Convert all tensors to the same dtype first

eval_data_loader = torch.utils.data.DataLoader(dataset, batch_size=1 , shuffle=False, num_workers=2)
#splitting
for batch_idx, (img, label, membership) in enumerate(eval_data_loader):
    img = img.to(device)

In [46]:
resnet_target_model =  models.mobilenet_v2(weights=None,num_classes=10)
check = torch.load("models/mobilenetv2_cifar10.pth", map_location=device)
# check["net"]["classifier.0.weight"]=check["net"].pop("classifier.1.weight")
# check["net"]["classifier.0.bias"]=check["net"].pop("classifier.1.bias")

resnet_target_model.load_state_dict(check["net"])
resnet_target_model.eval()

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [47]:
#get posteriors from target
dataset_eval = []
with torch.no_grad():
    for images,_, member in eval_data_loader: # take only image for query and member status (ignore cifar label)
            # Move images and labels to the appropriate device
        images, labels = images.to(device), labels.to(device)
            
            # Forward pass
        logits = resnet_target_model(images)
        
        #take the 3 biggest logist
        top_values = torch.topk(logits, k=3).values #order poseri
        sorted_tensor, indices = torch.sort(top_values, dim=1,descending=True)
        dataset_eval.append([sorted_tensor, member.item()])
        


In [48]:
# # Convert all tensors to the same dtype first
tensors = [data[0].float() for data in dataset_eval]  # Ensure all tensors are Float type
all_data = torch.cat(tensors, dim=0)  # Concatenate all tensors

# # Calculate mean and std
mean = all_data.mean(dim=0)
std = all_data.std(dim=0)

# # Standardize data in the list
dataset = [( (data[0] - mean) / std, data[1] ) for data in dataset_eval]

In [64]:
dataloader_eval = torch.utils.data.DataLoader(dataset, batch_size=1 , shuffle=False, num_workers=1)

In [90]:
def evaluate_attack_model(model, train_loader):
    model.eval()  # Set the model to evaluation mode
    correct = 0
    total = 0

    with torch.no_grad():  # Disable gradient computation
        for sorted_logits, members in train_loader:
            sorted_logits = sorted_logits.float()
            #labels = torch.tensor(labels[0])
            optimizer.zero_grad()

            outputs = model(sorted_logits.squeeze(dim=1))
            predicted = torch.round(outputs)  # Round the outputs to 0 or 1
            total += members.size(0)  # Increment the total count by batch size
            correct += (predicted == members).sum().item()  # Count correct predictions

    accuracy = correct / total
    return accuracy


accuracy = evaluate_attack_model(attack_model, dataloader_eval)#


print(f'Accuracy: {accuracy:.2f}')

Accuracy: 0.65
