In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.stats
import torch
import tqdm
from torch.nn import CrossEntropyLoss

In [None]:
# Load logits and train indices
N_MODELS = 100
N_CLASSES = 10

# Model 0 is the teacher: its train and test logits (and labels) are stacked
# into one array each. The archive is closed as soon as the arrays are copied out.
with np.load("./saves/logits/logits_resnet_0.npz") as teacher_logits_zip:
    teacher_logits = np.concatenate([teacher_logits_zip['logits_arr_train'], teacher_logits_zip['logits_arr_test']])
    targets = np.concatenate([teacher_logits_zip['targets_arr_train'], teacher_logits_zip['targets_arr_test']])
teacher_train_indices = np.load("./indices/train_idx_0.npy")

# Models 1..N_MODELS are the students. Their logits are stored in the same
# layout as the teacher's, so the buffer is sized from the teacher array
# instead of repeating literal dimensions.
student_logits = np.empty((N_MODELS, *teacher_logits.shape), dtype=np.float32)
per_model_train_indices = []
for idx in range(1, N_MODELS+1):
    with np.load(f"./saves/logits/logits_resnet_{idx}.npz") as student_logits_zip:
        student_logits[idx-1] = np.concatenate([student_logits_zip['logits_arr_train'], student_logits_zip['logits_arr_test']])
    per_model_train_indices.append(np.load(f"./indices/train_idx_{idx}.npy"))

# One row of training-set indices per student. Stacking keeps the dtype stored
# in the .npy files rather than forcing the indices into a float array.
student_train_indices = np.stack(per_model_train_indices)

In [None]:
# Convert all logits to confidence scores
# `loss` is no longer needed by this cell; it stays defined in case another cell references it.
loss = CrossEntropyLoss(reduction='none')


def _per_class_losses(logits):
    """Return the cross-entropy loss of every sample against every possible class.

    Cross-entropy against class c is -log_softmax(logits)[c], so a single
    log-softmax over the class axis yields the losses for all classes at once.

    Args:
        logits: array whose last axis holds the N_CLASSES raw logits.

    Returns:
        float64 array with the same shape as `logits`.

    Raises:
        ValueError: if the last axis does not have N_CLASSES entries, e.g. when
            this cell is re-run after the logits arrays have been rebound to
            per-sample scores by a later cell.
    """
    if logits.shape[-1] != N_CLASSES:
        raise ValueError(
            f"expected raw logits with {N_CLASSES} classes on the last axis, got shape {logits.shape}"
        )
    # The computation stays in float32, the precision torch.Tensor(...) would
    # have converted to.
    log_probs = torch.log_softmax(torch.as_tensor(logits, dtype=torch.float32), dim=-1)
    return (-log_probs).numpy().astype(np.float64)


teacher_losses = _per_class_losses(teacher_logits)
# exp(-loss) turns each loss back into the softmax probability of that class.
teacher_cfs = np.exp(-1 * teacher_losses)

student_losses = _per_class_losses(student_logits)
student_cfs = np.exp(-1 * student_losses)

In [None]:
# Logit-scale the confidence scores: log(p_true) - log(sum of p over the wrong classes).
# The 1e-45 added inside each log keeps a zero probability from producing -inf.
# NOTE: this cell rebinds `teacher_logits` and `student_logits` from per-class
# logits to one scalar score per sample.
true_class = list(targets.astype(np.int64))

# True everywhere except at each sample's own target class; it depends only on
# `targets`, so one mask serves the teacher and every student.
wrong_class_mask = np.ones((60000, N_CLASSES), dtype=bool)
wrong_class_mask[range(60000), true_class] = False

teacher_wrong_cfs = teacher_cfs[wrong_class_mask].reshape(60000, N_CLASSES-1)
teacher_true_cf = np.choose(true_class, teacher_cfs.T)
teacher_logits = np.log(teacher_true_cf+1e-45) - np.log(teacher_wrong_cfs.sum(1)+1e-45)

student_logits = np.empty((N_MODELS, 60000))
for model_pos in range(N_MODELS):
    model_cfs = student_cfs[model_pos, :, :]
    model_wrong_cfs = model_cfs[wrong_class_mask].reshape(60000, N_CLASSES-1)
    model_true_cf = np.choose(true_class, model_cfs.T)
    student_logits[model_pos, :] = np.log(model_true_cf+1e-45) - np.log(model_wrong_cfs.sum(1)+1e-45)



In [None]:
# Plotting normal dist for teacher: IN vs. OUT score histograms with a fitted normal curve each
import seaborn as sns
plt.style.use("ggplot")

# Boolean mask over all samples: True where the sample index is one of the
# teacher's training indices. np.isin does this in one vectorised pass instead
# of scanning the index array once per sample.
idx_in_train_dataset = np.isin(np.arange(len(teacher_logits)), teacher_train_indices)
in_scores = teacher_logits[idx_in_train_dataset]
out_scores = teacher_logits[~idx_in_train_dataset]

# Common range for the histogram bins (width 1) and the curve grid (step 0.01).
minn = min(in_scores.min(), out_scores.min())
maxx = max(in_scores.max(), out_scores.max())
bins = np.arange(minn, maxx, 1)
a = sns.histplot([in_scores, out_scores], bins=bins)

# Normal distributions parameterised by the empirical mean / std of each group.
s_in = np.std(in_scores)
m_in = np.mean(in_scores)
s_out = np.std(out_scores)
m_out = np.mean(out_scores)
x_ticks = np.arange(minn, maxx, 0.01)
norm_in = scipy.stats.norm.pdf(x_ticks, m_in, s_in)
norm_out = scipy.stats.norm.pdf(x_ticks, m_out, s_out)

# Rescale each density so that its peak equals the tallest bar of a histogram.
# NOTE(review): containers[1] is paired with IN and containers[0] with OUT,
# presumably because seaborn draws the groups in reverse order; confirm.
coef_in = max(bar.get_height() for bar in a.containers[1]) / norm_in.max()
coef_out = max(bar.get_height() for bar in a.containers[0]) / norm_out.max()

plt.plot(x_ticks, coef_in*norm_in)
plt.plot(x_ticks, coef_out*norm_out)

plt.legend(['IN scores', 'OUT scores'])

In [None]:
# Plot, for a single sample, how its score is distributed across the student
# models, split by whether the sample appears in a model's training indices,
# and overlay a normal curve fitted to each group.
import seaborn as sns
sample_nr = 30878

# One flag per student model: True when `sample_nr` is among that model's training indices.
idx_in_train_dataset = (student_train_indices == sample_nr).any(axis=1)
in_scores = student_logits[idx_in_train_dataset, sample_nr]
out_scores = student_logits[~idx_in_train_dataset, sample_nr]

# Shared range for the unit-width histogram bins and the 0.01-step curve grid.
minn = min(map(min, (in_scores, out_scores)))
maxx = max(map(max, (in_scores, out_scores)))
bins = np.arange(minn, maxx, 1)
a = sns.histplot([in_scores, out_scores], bins=bins)

# Empirical mean / std of each group parameterise the fitted normals.
m_in, s_in = np.mean(in_scores), np.std(in_scores)
m_out, s_out = np.mean(out_scores), np.std(out_scores)
x_ticks = np.arange(minn, maxx, 0.01)
norm_in = scipy.stats.norm.pdf(x_ticks, loc=m_in, scale=s_in)
norm_out = scipy.stats.norm.pdf(x_ticks, loc=m_out, scale=s_out)

# Scale each curve so that its peak matches the tallest bar of a histogram container.
coef_in = max(bar.get_height() for bar in a.containers[1]) / max(norm_in)
coef_out = max(bar.get_height() for bar in a.containers[0]) / max(norm_out)

for coef, curve in ((coef_in, norm_in), (coef_out, norm_out)):
    plt.plot(x_ticks, coef*curve)

plt.legend(['IN scores', 'OUT scores'])

In [None]:
from sklearn import metrics

In [None]:
# Likelihood-ratio membership attack on the teacher, calibrated on the student
# models, compared against a simple loss-threshold attack. One ROC curve is
# drawn per number of student models used.
plt.style.use("ggplot")
plt.figure()

n_samples = student_logits.shape[1]
sample_ids = np.arange(n_samples)

# Membership lookups are computed once up front instead of scanning the index
# arrays again for every sample.
# teacher_is_member[i]      -> sample i is in the teacher's training indices
# student_is_member[k, i]   -> sample i is in student k's training indices
teacher_is_member = np.isin(sample_ids, teacher_train_indices)
student_is_member = np.zeros((student_train_indices.shape[0], n_samples), dtype=bool)
student_is_member[
    np.arange(student_train_indices.shape[0])[:, None],
    student_train_indices.astype(np.int64),  # works whether the indices are stored as float or int
] = True

for num_models in [20,40,60,80,100]:
    shadow_logits = student_logits[:num_models]
    shadow_is_member = student_is_member[:num_models]

    # Skip samples whose IN/OUT split is too lopsided (more than 65% of the
    # models on one side) to calibrate both distributions.
    in_counts = shadow_is_member.sum(axis=0)
    out_counts = shadow_is_member.shape[0] - in_counts
    unbalanced = (in_counts / num_models > 0.65) | (out_counts / num_models > 0.65)
    skipped = int(unbalanced.sum())
    kept_ids = np.flatnonzero(~unbalanced)

    # Per-sample normal parameters for the IN and OUT student scores.
    m_in = np.empty(kept_ids.size)
    s_in = np.empty(kept_ids.size)
    m_out = np.empty(kept_ids.size)
    s_out = np.empty(kept_ids.size)
    for pos, train_idx in enumerate(tqdm.tqdm(kept_ids)):
        in_mask = shadow_is_member[:, train_idx]
        sample_logits = shadow_logits[:, train_idx]
        in_logits = sample_logits[in_mask]
        out_logits = sample_logits[~in_mask]

        m_in[pos] = np.mean(in_logits)
        s_in[pos] = np.std(in_logits)
        m_out[pos] = np.mean(out_logits)
        s_out[pos] = np.std(out_logits)

    # Score = likelihood of the teacher's score under the IN normal divided by
    # its likelihood under the OUT normal; the 1e-40 avoids a zero denominator.
    # NOTE(review): a standard deviation of exactly zero is not handled and
    # would presumably give NaN scores; confirm it cannot occur with this data.
    teacher_scores = teacher_logits[kept_ids]
    member_scores = scipy.stats.norm.pdf(teacher_scores, m_in, s_in) / (scipy.stats.norm.pdf(teacher_scores, m_out, s_out) + 1e-40)
    member_indicators = teacher_is_member[kept_ids].astype(int)

    print(f"Skipped: {skipped}/{n_samples} for {num_models} models")

    fpr, tpr, _ = metrics.roc_curve(member_indicators,  member_scores)
    auc = metrics.roc_auc_score(member_indicators,  member_scores)
    plt.loglog(fpr, tpr, label=f'{num_models} models, AUC={auc:.4f}')


# Comparison to simple loss attack: the score is the negated teacher loss on
# each sample's own target class, evaluated on every sample.
member_indicators = teacher_is_member.astype(int)
member_scores = -1 * teacher_losses[sample_ids, targets.astype(np.int64)]

fpr, tpr, _ = metrics.roc_curve(member_indicators,  member_scores)
auc = metrics.roc_auc_score(member_indicators,  member_scores)
plt.loglog(fpr, tpr, label=f'Simple loss attack, AUC={auc:.4f}')

plt.xlim(10**-5,  1)
plt.ylim(10**-5,  1)
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.legend()
plt.show()
