In [1]:
# Custom Imports
from typing import Union
import sys
sys.path.append("../..")
sys.path.append("..")
import data_utils
import GradCertModule
import XAIArchitectures
# Deep Learning Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import models, transforms
import pytorch_lightning as pl
# Standard Lib Imports
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

SEED = 0
import numpy as np
import random
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

dataset = "GERMAN"

if(dataset == "GERMAN"):
    negative_cls = 0
    sensitive_features = [] 
    sens = ['status_sex_A91', 'status_sex_A92', 'status_sex_A93', 'status_sex_A94']
    drop_columns = []
    train_ds, test_ds = data_utils.get_german_data(sensitive_features, drop_columns=drop_columns)

elif(dataset == "CREDIT"):
    negative_cls = 1
    sensitive_features = [] 
    sens = ['x2_1.0', 'x2_2.0']
    drop_columns = []
    train_ds, test_ds = data_utils.get_credit_data(sensitive_features, drop_columns=drop_columns)
    
elif(dataset == "ADULT"):
    negative_cls = 1
    sensitive_features = [] 
    sens = ['sex_Female', 'sex_Male', 'race_Amer-Indian-Eskimo', 
            'race_Asian-Pac-Islander', 'race_Black', 'race_Other', 'race_White',]
    drop_columns = ['native-country'] #, 'education']
    train_ds, test_ds = data_utils.get_adult_data(sensitive_features, drop_columns=drop_columns)
    
elif(dataset == "CRIME"):
    negative_cls = 1
    CRIME_DROP_COLUMNS = [
    'HispPerCap', 'LandArea', 'LemasPctOfficDrugUn', 'MalePctNevMarr',
    'MedOwnCostPctInc', 'MedOwnCostPctIncNoMtg', 'MedRent',
    'MedYrHousBuilt', 'OwnOccHiQuart', 'OwnOccLowQuart',
    'OwnOccMedVal', 'PctBornSameState', 'PctEmplManu',
    'PctEmplProfServ', 'PctEmploy', 'PctForeignBorn', 'PctImmigRec5',
    'PctImmigRec8', 'PctImmigRecent', 'PctRecImmig10', 'PctRecImmig5',
    'PctRecImmig8', 'PctRecentImmig', 'PctSameCity85',
    'PctSameState85', 'PctSpeakEnglOnly', 'PctUsePubTrans',
    'PctVacMore6Mos', 'PctWorkMom', 'PctWorkMomYoungKids',
    'PersPerFam', 'PersPerOccupHous', 'PersPerOwnOccHous',
    'PersPerRentOccHous', 'RentHighQ', 'RentLowQ', 'Unnamed: 0',
    'agePct12t21', 'agePct65up', 'householdsize', 'indianPerCap',
    'pctUrban', 'pctWFarmSelf', 'pctWRetire', 'pctWSocSec', 'pctWWage',
    'whitePerCap'
    ]
    sensitive_features = []
    sens = ['racepctblack', 'racePctWhite', 'racePctAsian', 'racePctHisp']
    train_ds, test_ds = data_utils.get_crime_data(sensitive_features, drop_columns=CRIME_DROP_COLUMNS)

In [11]:
#print(train_ds.X_df.columns.tolist())
cols = train_ds.X_df.columns.tolist()
print(cols)
sens_inds = []
#sens += ['age', 'dependants_num']
for i in sens:
    sens_inds.append(cols.index(i))
#print(sens_inds)

if(dataset == "ADULT"):
    AGE = [cols.index('age')]
    RACE = [cols.index(i) for i in ['race_Amer-Indian-Eskimo', 'race_Asian-Pac-Islander', 'race_Black', 'race_Other', 'race_White']]
    GENDER = [cols.index(i) for i in ['sex_Female', 'sex_Male']]
    FINANCES = [cols.index(i) for i in ['capital-gain', 'capital-loss']]
    EDUCATION = [cols.index(i) for i in ['education_10th', 'education_11th', 'education_12th', 'education_1st-4th', 'education_5th-6th', 'education_7th-8th', 'education_9th', 
                                         'education_Assoc-acdm', 'education_Assoc-voc', 'education_Bachelors', 'education_Doctorate', 'education_HS-grad', 'education_Masters',
                                         'education_Preschool', 'education_Prof-school', 'education_Some-college']]     
    EMPLOYMENT = [cols.index(i) for i in ['hours-per-week', 'occupation_Adm-clerical', 'occupation_Armed-Forces', 'occupation_Craft-repair', 'occupation_Exec-managerial', 'occupation_Farming-fishing', 
                                          'occupation_Handlers-cleaners', 'occupation_Machine-op-inspct', 'occupation_Other-service', 'occupation_Priv-house-serv', 'occupation_Prof-specialty', 
                                          'occupation_Protective-serv', 'occupation_Sales', 'occupation_Tech-support', 'occupation_Transport-moving', 'workclass_Local-gov', 'workclass_Private',
                                          'workclass_Self-emp-inc', 'workclass_Self-emp-not-inc', 'workclass_State-gov', 'workclass_Without-pay']]          
    PERSONAL = [cols.index(i) for i in ['hours-per-week', 'relationship_Husband', 'relationship_Not-in-family', 'relationship_Other-relative', 'relationship_Own-child', 'relationship_Unmarried', 'relationship_Wife', 'marital-status_Divorced',
                                        'marital-status_Married-AF-spouse', 'marital-status_Married-civ-spouse', 'marital-status_Married-spouse-absent', 'marital-status_Never-married', 'marital-status_Separated', 'marital-status_Widowed']]
                           
elif(dataset == "CREDIT"):
    AMOUNT = [cols.index('x1')]
    AGE = [cols.index(i) for i in [ 'x5_21.0', 'x5_22.0', 'x5_23.0', 'x5_24.0', 'x5_25.0', 'x5_26.0', 'x5_27.0', 'x5_28.0', 'x5_29.0', 'x5_30.0', 'x5_31.0', 
                                   'x5_32.0', 'x5_33.0', 'x5_34.0', 'x5_35.0', 'x5_36.0', 'x5_37.0', 'x5_38.0', 'x5_39.0', 'x5_40.0', 'x5_41.0', 'x5_42.0', 
                                   'x5_43.0', 'x5_44.0', 'x5_45.0', 'x5_46.0', 'x5_47.0', 'x5_48.0', 'x5_49.0', 'x5_50.0', 'x5_51.0', 'x5_52.0', 'x5_53.0', 
                                   'x5_54.0', 'x5_55.0', 'x5_56.0', 'x5_57.0', 'x5_58.0', 'x5_59.0', 'x5_60.0', 'x5_61.0', 'x5_62.0', 'x5_63.0', 'x5_64.0', 
                                   'x5_65.0', 'x5_66.0', 'x5_67.0', 'x5_68.0', 'x5_69.0', 'x5_70.0', 'x5_71.0', 'x5_72.0', 'x5_73.0', 'x5_74.0', 'x5_75.0', 
                                   'x5_79.0']]
    GENDER = [cols.index(i) for i in ['x2_1.0', 'x2_2.0']]
    EDUCATION = [cols.index(i) for i in ['x2_1.0', 'x2_2.0']]
    PERSONAL = [cols.index(i) for i in ['x4_0.0', 'x4_1.0', 'x4_2.0', 'x4_3.0']]
    BILLS = [cols.index(i) for i in ['x12', 'x13', 'x14', 'x15', 'x16', 'x17']]
    PAYMENTS = [cols.index(i) for i in ['x6_-1.0', 'x6_-2.0', 'x6_0.0', 'x6_1.0', 'x6_2.0', 'x6_3.0', 'x6_4.0', 'x6_5.0', 'x6_6.0', 'x6_7.0', 'x6_8.0', 
                                        'x7_-1.0', 'x7_-2.0', 'x7_0.0', 'x7_1.0', 'x7_2.0', 'x7_3.0', 'x7_4.0', 'x7_5.0', 'x7_6.0', 'x7_7.0', 'x7_8.0', 
                                        'x8_-1.0', 'x8_-2.0', 'x8_0.0', 'x8_1.0', 'x8_2.0', 'x8_3.0', 'x8_4.0', 'x8_5.0', 'x8_6.0', 'x8_7.0', 'x8_8.0', 
                                        'x9_-1.0', 'x9_-2.0', 'x9_0.0', 'x9_1.0', 'x9_2.0', 'x9_3.0', 'x9_4.0', 'x9_5.0', 'x9_6.0', 'x9_7.0', 'x9_8.0', 
                                        'x10_-1.0', 'x10_-2.0', 'x10_0.0', 'x10_2.0', 'x10_3.0', 'x10_4.0', 'x10_5.0', 'x10_6.0', 'x10_7.0', 'x10_8.0', 
                                        'x11_-1.0', 'x11_-2.0', 'x11_0.0', 'x11_2.0', 'x11_3.0', 'x11_4.0', 'x11_5.0', 'x11_6.0', 'x11_7.0', 'x11_8.0',
                                        'x18', 'x19', 'x20', 'x21', 'x22', 'x23']]
                           

['Unnamed: 0', 'age', 'credit_amount', 'credits_num', 'dependants_num', 'duration', 'installment_rate_pct', 'residence_since', 'account_status_A11', 'account_status_A12', 'account_status_A13', 'account_status_A14', 'credit_history_A30', 'credit_history_A31', 'credit_history_A32', 'credit_history_A33', 'credit_history_A34', 'debtors_guarantors_A101', 'debtors_guarantors_A102', 'debtors_guarantors_A103', 'employment_A71', 'employment_A72', 'employment_A73', 'employment_A74', 'employment_A75', 'foreign_A201', 'foreign_A202', 'housing_A151', 'housing_A152', 'housing_A153', 'job_A171', 'job_A172', 'job_A173', 'job_A174', 'other_installment_plans_A141', 'other_installment_plans_A142', 'other_installment_plans_A143', 'property_A121', 'property_A122', 'property_A123', 'property_A124', 'purpose_A40', 'purpose_A41', 'purpose_A410', 'purpose_A42', 'purpose_A43', 'purpose_A44', 'purpose_A45', 'purpose_A46', 'purpose_A48', 'purpose_A49', 'savings_account_A61', 'savings_account_A62', 'savings_accoun

In [3]:
X_train = train_ds.X_df.to_numpy()
y_train = torch.squeeze(torch.Tensor(train_ds.y_df.to_numpy()).to(torch.int64))

X_test = test_ds.X_df.to_numpy()
y_test = torch.squeeze(torch.Tensor(test_ds.y_df.to_numpy()).to(torch.int64))

In [4]:

class custDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.Tensor(X).float()
        self.y = y
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return self.X.shape[0]
        
    def __getitem__(self, index):
        return self.X[index], self.y[index]
    

CustTrain = custDataset(X_train, y_train)    
CustTest = custDataset(X_test, y_test)

class CustomDataModule(pl.LightningDataModule):
    def __init__(self, train, val, test, batch_size=32):
        super().__init__()
        self.train_data = train
        self.val_data = val
        self.test_data = test
        self.batch_size = batch_size
        
    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_data, batch_size=self.batch_size)
    
dm = CustomDataModule(CustTrain, CustTest, CustTest)

In [5]:
ALPHA = 1.0            # Regularization Parameter (Weights the Reg. Term)
EPSILON = 0.0          # Input Peturbation Budget at Training Time
GAMMA = 0.0            # Model Peturbation Budget at Training Time 
                       #(Changed to proportional budget rather than absolute)
    
LEARN_RATE = 0.0005     # Learning Rate Hyperparameter
HIDDEN_DIM = 256       # Hidden Neurons Hyperparameter
HIDDEN_LAY = 2         # Hidden Layers Hyperparameter
MAX_EPOCHS = 25

EPSILON_LINEAR = True   # Put Epsilon on a Linear Schedule?
GAMMA_LINEAR = True     # Put Gamma on a Linear Schedule?

In [6]:

model = XAIArchitectures.FullyConnected(hidden_dim=HIDDEN_DIM, hidden_lay=HIDDEN_LAY, dataset=dataset)
model.set_params(alpha=ALPHA, epsilon=EPSILON, gamma=GAMMA, 
                learn_rate=LEARN_RATE, max_epochs=MAX_EPOCHS,
                epsilon_linear=EPSILON_LINEAR, gamma_linear=GAMMA_LINEAR)


SET MODE TO:  GRAD


In [7]:

SCHEDULED = EPSILON_LINEAR or GAMMA_LINEAR    
#MODEL_ID = "FCN_e=%s_g=%s_h=%s_l=%s_s=%s"%(EPSILON, GAMMA, HIDDEN_DIM, HIDDEN_LAY, SCHEDULED)  
MODEL_ID = "%s_FCN_e=%s_g=%s_a=%s_l=%s_h=%s_s=%s"%(dataset, EPSILON, GAMMA, ALPHA, HIDDEN_LAY, HIDDEN_DIM, SCHEDULED)
ckpt = torch.load("Models/%s.ckpt"%(MODEL_ID))
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load("Models/%s.ckpt"%(MODEL_ID))
for key in checkpoint:
    print(key)
#model.load_state_dict(checkpoint['model_state_dict'])
model.load_state_dict(torch.load('Models/%s.pt'%(MODEL_ID)))


epoch
global_step
pytorch-lightning_version
state_dict
loops
callbacks
optimizer_states
lr_schedulers
hparams_name
hyper_parameters


<All keys matched successfully>

In [8]:
# Measure Test Set Accuracy
def get_test_acc(MODEL_ID):
    correct = 0
    for INDEX in range(len(X_test)):
        data = torch.Tensor([X_test[INDEX]])
        out, cls = model.classify(data)
        if(cls == y_test[INDEX]):
            correct += 1 
    correct /= len(X_test)
    #print("Test set accuracy: ", correct)
    return correct
get_test_acc(MODEL_ID)

  """


0.71

In [14]:
# Measure Input Attack Robustness
from tqdm import trange
def get_input_attk(MODEL_ID, EPS=0.5, N=1):
    model.load_state_dict(torch.load('Models/%s.pt'%(MODEL_ID)))
    model.inputfooling_OFF()
    fooled = 0
    for INDEX in trange(N):
        success, x_adv, grad_adv = GradCertModule.run_tabular_attack(model, torch.Tensor(X_test[INDEX]), iterations=100,
                                                      target=sens_inds, epsilon=EPS, lr=0.01, idx = 4)
        fooled += int(success)
    model.inputfooling_OFF()
    return fooled/N
    #print("Input Attack Fooling Rate: ", fooled/5)
    
get_input_attk(MODEL_ID, N=10, EPS=0.0)



 10%|████▍                                       | 1/10 [00:00<00:03,  2.68it/s]

0.9357174038887024
[42, 26, 1, 49] [56, 57, 58, 59, 1, 1, 4]


 20%|████████▊                                   | 2/10 [00:00<00:03,  2.47it/s]

35.9613151550293
[2, 48, 12, 5] [56, 57, 58, 59, 1, 1, 4]


 30%|█████████████▏                              | 3/10 [00:01<00:02,  2.40it/s]

0.057210903614759445
[41, 5, 48, 3] [56, 57, 58, 59, 1, 1, 4]


 40%|█████████████████▌                          | 4/10 [00:01<00:02,  2.44it/s]

7.1347856521606445
[2, 5, 12, 35] [56, 57, 58, 59, 1, 1, 4]


 50%|██████████████████████                      | 5/10 [00:02<00:02,  2.36it/s]

5.795168354525912e-14
[42, 26, 49, 1] [56, 57, 58, 59, 1, 1, 4]


 60%|██████████████████████████▍                 | 6/10 [00:02<00:01,  2.31it/s]

0.001161572989076376
[26, 42, 54, 49] [56, 57, 58, 59, 1, 1, 4]


 70%|██████████████████████████████▊             | 7/10 [00:02<00:01,  2.31it/s]

4.448511836674385e-10
[42, 26, 49, 1] [56, 57, 58, 59, 1, 1, 4]


 80%|███████████████████████████████████▏        | 8/10 [00:03<00:00,  2.32it/s]

8.912625162338372e-06
[5, 12, 2, 35] [56, 57, 58, 59, 1, 1, 4]


 90%|███████████████████████████████████████▌    | 9/10 [00:03<00:00,  2.36it/s]

8.484565250885041e-18
[42, 26, 1, 49] [56, 57, 58, 59, 1, 1, 4]


100%|███████████████████████████████████████████| 10/10 [00:04<00:00,  2.39it/s]

3.971414863634948e-10
[42, 49, 26, 30] [56, 57, 58, 59, 1, 1, 4]





0.4

In [10]:
asf = safd

NameError: name 'safd' is not defined

In [None]:
# Measure Model Attack Robustness
def get_model_attk(MODEL_ID, GAM=0.5, N=10):
    model.inputfooling_ON()
    fooled = 0
    for INDEX in trange(N):
        model.load_state_dict(torch.load('Models/%s.pt'%(MODEL_ID)))
        model.inputfooling_ON()
        success, grad_orig, grad_adv = GradCertModule.run_tabular_model_attack_FGSM(model, torch.Tensor(X_test[INDEX]), iterations=50,
                                                      target=sens_inds, gamma=GAM, lr=0.01*GAM, idx=10) #min(GAM/25, 0.01))
        
        
        #print(grad_orig)
        #print(grad_adv)
        fooled += int(success)
        #model.load_state_dict(torch.load('Models/%s.pt'%(MODEL_ID)))
        #model.inputfooling_ON()
        #print(fooled)
    model.inputfooling_OFF()  
    return fooled/N
    #print("Model Attack Fooling Rate: ", fooled/5)

print(get_model_attk(MODEL_ID, GAM=0.05, N=10))    
    
#print(get_model_attk(MODEL_ID, GAM=0.1, N=10))

#print(get_model_attk(MODEL_ID, GAM=0.2, N=10))


In [None]:
#asdf = asdf

In [None]:
# Measure Input Attack Certification
def get_input_cert(MODEL_ID, EPS=0.2):
    model.load_state_dict(torch.load('Models/%s.pt'%(MODEL_ID)))
    import copy
    certified = 0
    for INDEX in trange(200):
        lower, upper = GradCertModule.GradCertBounds(model, torch.Tensor(X_test[INDEX][None, :]),
                                                     y_test[INDEX], eps=EPS, gam=0.00, nclasses=2)

        upper = np.squeeze(upper.detach().numpy())
        lower = np.squeeze(lower.detach().numpy())
        #print(upper[sens_inds])
        #print(lower[sens_inds])
        temp = copy.deepcopy(lower)
        for i in sens_inds:
            temp[i] = upper[i]
        #print(temp[sens_inds])
        top_idx = np.squeeze(np.argsort(temp))
        top_idx = list(reversed(top_idx))
        #print(set(top_idx[0:10]))
        #print( set(sens_inds))
        cert = not bool(set(top_idx[0:5]) & set(sens_inds))
        certified += int(cert)
        #break
    #print("Input Attack Certified: ", certified/200)
    return certified/200

get_input_cert(MODEL_ID)

In [None]:
# Measure Input Attack Certification
def get_model_cert(MODEL_ID, GAM=0.2):
    model.load_state_dict(torch.load('Models/%s.pt'%(MODEL_ID)))
    import copy
    certified = 0
    for INDEX in trange(200):
        lower, upper = GradCertModule.GradCertBounds(model, torch.Tensor(X_test[INDEX][None, :]),
                                                     y_test[INDEX], eps=0.00, gam=GAM, nclasses=2)

        upper = np.squeeze(upper.detach().numpy())
        lower = np.squeeze(lower.detach().numpy())
        #print(upper[sens_inds])
        #print(lower[sens_inds])
        temp = copy.deepcopy(lower)
        for i in sens_inds:
            temp[i] = upper[i]
        #print(temp[sens_inds])
        top_idx = np.squeeze(np.argsort(temp))
        top_idx = list(reversed(top_idx))
        #print(set(top_idx[0:10]))
        #print( set(sens_inds))
        cert = not bool(set(top_idx[0:5]) & set(sens_inds))
        certified += int(cert)
        #break
    return certified/200
    #print("Model Attack Certified: ", certified/200)
get_model_cert(MODEL_ID)

In [None]:
def gen_model_id(GAM_T=0.0, EPS_T=0.0):

    ALPHA = 1.0            # Regularization Parameter (Weights the Reg. Term)
    EPSILON = EPS_T         # Input Peturbation Budget at Training Time

    LEARN_RATE = 0.0005     # Learning Rate Hyperparameter
    HIDDEN_DIM = 256       # Hidden Neurons Hyperparameter
    HIDDEN_LAY = 2         # Hidden Layers Hyperparameter
    MAX_EPOCHS = 25

    EPSILON_LINEAR = True   # Put Epsilon on a Linear Schedule?
    GAMMA_LINEAR = True     # Put Gamma on a Linear Schedule?
    
    MODEL_ID = "%s_FCN_e=%s_g=%s_a=%s_l=%s_h=%s_s=%s"%(dataset, EPSILON, GAM_T, ALPHA, HIDDEN_LAY, HIDDEN_DIM, SCHEDULED)     
    print(MODEL_ID)
    return MODEL_ID


    

In [None]:
# Benchmark each baseline along with our method



In [None]:


CERT_VALS = []
eps_vals = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05]
for eps in eps_vals:
    M_ID = gen_model_id(EPS_T=eps)
    val = []
    for e_test in np.linspace(0, 0.2, 20): #[0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]:
        certified = get_input_cert(M_ID, EPS=e_test)
        val.append(certified)
    #print("*****")
    #print(gam, val)
    #print("*****")
    CERT_VALS.append(val)
    
    


In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
#sns.set_style('darkgrid')
sns.set_context('poster')
plt.figure(figsize=(12, 8), dpi=100)


eps_vals = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05]
g_test = [0.0, 0.01, 0.02, 0.03, 0.05, 0.06]

pal = sns.cubehelix_palette(n_colors=len(gamma_vals), start=.5, rot=-.75)
print(pal.as_hex())

for i in range(len(eps_vals)):
    print(CERT_VALS[i])
    plt.plot(CERT_VALS[i], label=gamma_vals[i], 
             color=pal[i], linewidth=10)
plt.legend()
ax = plt.gca()
ax.set_xticks([0,4,9,14,19])
labs = [round(i, 2) for i in np.linspace(0, 0.2, 5)]
ax.set_xticklabels(labs)
#ax.set_xticklabels(g_test)
plt.title("%s"%(dataset))
plt.ylabel("Input Certified Robustness")
plt.xlabel(r"Magnitude of $\epsilon$")
ax.get_legend().set_title(r"$\epsilon_t$")
plt.show()


In [None]:

CERT_VALS = []
gam_vals = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05]
for gam in gam_vals:
    M_ID = gen_model_id(gam)
    val = []
    for g_test in np.linspace(0, 0.2, 20): #[0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]:
        certified = get_model_cert(M_ID, GAM=g_test)
        val.append(certified)
    print("*****")
    print(gam, val)
    print("*****")
    CERT_VALS.append(val)
    

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
#sns.set_style('darkgrid')
sns.set_context('poster')
plt.figure(figsize=(12, 8), dpi=100)


gamma_vals = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05]
g_test = [0.0, 0.01, 0.02, 0.03, 0.05, 0.06]

pal = sns.cubehelix_palette(n_colors=len(gamma_vals), start=.5, rot=-.75)
print(pal.as_hex())

for i in range(len(gam_vals)):
    print(CERT_VALS[i])
    plt.plot(CERT_VALS[i], label=gamma_vals[i], 
             color=pal[i], linewidth=10)
plt.legend()
ax = plt.gca()
ax.set_xticks([0,4,9,14,19])
labs = [round(i, 2) for i in np.linspace(0, 0.2, 5)]
ax.set_xticklabels(labs)
#ax.set_xticklabels(g_test)
plt.title("%s"%(dataset))
plt.ylabel("Model Certified Robustness")
plt.xlabel(r"Magnitude of $\gamma$")
ax.get_legend().set_title(r"$\gamma_t$")
plt.show()

In [None]:
"""
ATTK_VALS = []
gam_vals = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05]
for gam in gam_vals:
    M_ID = gen_model_id(gam)
    val = []
    for g_test in [0.01, 0.05, 0.075, 0.1, 0.125, 0.175]:
        certified = get_model_attk(M_ID, GAM=g_test, N=50)
        print(certified)
        val.append(certified)
    print("*****")
    print(gam, val)
    print("*****")
    ATTK_VALS.append(val)
"""   

In [None]:

ATTK_VALS = [[0.0, 0.04, 0.04, 0.1, 0.18, 0.66], 
            [0.02, 0.08, 0.02, 0.04, 0.0, 0.0], 
            [0.02, 0.02, 0.0, 0.02, 0.0, 0.0], 
            [0.04, 0.04, 0.02, 0.02, 0.0, 0.02], 
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 
            [0.0, 0.0, 0.08, 0.34, 0.4, 0.6]]



import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
#sns.set_style('darkgrid')
sns.set_context('poster')
plt.figure(figsize=(12, 8), dpi=100)


gam_vals = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05]
g_test_vals = [0.01, 0.05, 0.075, 0.1, 0.125, 0.25]
    
    
pal = sns.cubehelix_palette()
pal = pal.as_hex()
pal = [i for i in reversed(pal)]
print(pal)
#print(pal.as_hex())

for i in range(len(gam_vals)):
    plt.plot(1-np.asarray(ATTK_VALS[i]), label=gam_vals[i], color=pal[i], linewidth=10)
plt.legend()
ax = plt.gca()

ax.set_xticks(range(len(g_test_vals)))
ax.set_xticklabels(g_test_vals)

plt.title("%s"%(dataset))
plt.ylabel("Model Attack Robustness")
plt.xlabel(r"Magnitude of $\gamma$")
ax.get_legend().set_title(r"$\gamma_t$")

plt.show()