In [1]:
from src.models import Exponential_Model
from src.criterion import right_censored,RightCensorWrapper
from src.load_data import load_datasets,load_dataframe
from src.utils import train_robust,lower_bound
from src.visualizations import visualize_population_curves_attacked,visualize_individual_curves_attacked

from torch.optim import Adam
import torch
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import seaborn as sns

from lifelines import KaplanMeierFitter
from lifelines.utils import concordance_index
from lifelines import CoxPHFitter

from auto_LiRPA import BoundedModule, BoundedTensor

from tqdm import tqdm
import pandas as pd
import numpy as np
from copy import deepcopy

  from .autonotebook import tqdm as notebook_tqdm
  from pkg_resources import packaging  # type: ignore[attr-defined]


In [2]:
class ARGS(object):
    """Mutable attribute bag holding the experiment configuration.

    Attributes are normally attached after construction (``args.eps = 0.5``),
    but may also be supplied as keyword arguments, e.g.
    ``ARGS(eps=0.5, device="cpu")``.
    """

    def __init__(self, **kwargs):
        # Each keyword becomes an instance attribute; with no keywords the
        # object starts empty, exactly like the original bare class.
        self.__dict__.update(kwargs)

    def __repr__(self):
        # Show every configured field so the config is readable when displayed.
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"

In [3]:
args = ARGS()
args.verify = False
args.device = "cpu"

args.seed = 123
# Seed the global RNGs *before* the train/test split, model initialisation and
# DataLoader shuffling so that Restart & Run All reproduces the same results.
# (Previously args.seed was set but never applied to any RNG in this notebook.)
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Perturbation budget and norm passed to train_robust.
# NOTE(review): presumably an L-infinity ball of radius eps — confirm in src.utils.
args.eps = 0.5
args.norm = np.inf
args.bound_type = "CROWN-IBP"
args.num_epochs = 25
args.lr = 1e-3
args.batch_size = 128
# NOTE(review): presumably the eps ramp-up schedule (starts at epoch 5, ramps
# over 10 epochs) — consistent with eps=0 in the early-epoch log; confirm.
args.scheduler_name = "SmoothedScheduler"
args.scheduler_opts = "start=5,length=10"
args.hidden_dims = [15, 15]
# NOTE(review): empty string presumably means "do not save" — confirm.
args.save_model = ""
args.dataset = "Dialysis"


In [4]:
# Datasets that have given good results with this setup so far:
#   - TRACE
#   - divorce
#   - Dialysis
# test_size=0.2: presumably the fraction held out for the test split.
dataset_train, dataset_test = load_datasets(args.dataset, test_size=0.2)

In [5]:
# Network input width = number of feature columns in the training tensor;
# the network has a single output. NOTE(review): output_dim is not referenced
# in the visible cells.
input_dims, output_dim = dataset_train.tensors[0].shape[1], 1

In [6]:
dataloader_train = DataLoader(
    dataset_train,
    batch_size=args.batch_size,
    shuffle=True,   # reshuffled every epoch
)
dataloader_test = DataLoader(
    dataset_test,
    batch_size=args.batch_size,
    shuffle=False,  # fixed evaluation order
)

# Both loaders carry the *training* normalisation statistics; the test loader
# deliberately reuses them instead of its own.
# NOTE(review): presumably read by the src.visualizations helpers — confirm.
dataloader_train.mean = dataset_train.mean
dataloader_test.mean = dataset_train.mean
dataloader_train.std = dataset_train.std
dataloader_test.std = dataset_train.std

dataset_train.tensors[0].shape

torch.Size([5444, 74])

In [7]:
# Re-seed so that re-running just this cell yields the same initial weights,
# regardless of how much RNG state earlier cells consumed.
torch.manual_seed(args.seed)

# Two identical architectures; clf_fragile is given a copy of clf_robust's
# weights so both training runs start from the same initialisation.
clf_robust = Exponential_Model(input_dim=input_dims, hidden_layers=args.hidden_dims)
clf_fragile = Exponential_Model(input_dim=input_dims, hidden_layers=args.hidden_dims)
clf_fragile.load_state_dict(deepcopy(clf_robust.state_dict()))

# Wrap each model for auto_LiRPA; dataset_train.tensors is passed as the
# example input that BoundedModule needs at construction time.
# NOTE(review): later cells evaluate clf_robust / clf_fragile directly after
# training the wrappers, which assumes BoundedModule shares (rather than
# copies) their parameters — confirm.
model_robust_wrap = BoundedModule(RightCensorWrapper(clf_robust), dataset_train.tensors)
model_fragile_wrap = BoundedModule(RightCensorWrapper(clf_fragile), dataset_train.tensors)

In [None]:
# Train the robust model (method="robust"; presumably uses args.bound_type /
# args.eps for bound-based training — see src.utils.train_robust).
train_robust(
    model_robust_wrap,
    dataloader_train,
    dataloader_test,
    method="robust",
    args=args,
)



Epoch 1, learning rate [0.001]
[ 1:   0]: eps=0.00000000 Loss=1850.1527 Time=0.0020
[ 1:  10]: eps=0.00000000 Loss=1868.9953 Time=0.0023
[ 1:  20]: eps=0.00000000 Loss=1794.7013 Time=0.0022
[ 1:  30]: eps=0.00000000 Loss=1713.6873 Time=0.0022
[ 1:  40]: eps=0.00000000 Loss=1631.3091 Time=0.0022
[ 1:  42]: eps=0.00000000 Loss=1608.4964 Time=0.0021
Epoch time: 0.1135, Total time: 0.1135
Evaluating...
[ 1:  10]: eps=0.00000000 Loss=1209.5594 Time=0.0010
Epoch 2, learning rate [0.001]
[ 2:   0]: eps=0.00000000 Loss=1068.3817 Time=0.0022
[ 2:  10]: eps=0.00000000 Loss=1144.3256 Time=0.0021
[ 2:  20]: eps=0.00000000 Loss=1043.7619 Time=0.0020
[ 2:  30]: eps=0.00000000 Loss=949.4128 Time=0.0019
[ 2:  40]: eps=0.00000000 Loss=856.8435 Time=0.0020
[ 2:  42]: eps=0.00000000 Loss=838.6754 Time=0.0020
Epoch time: 0.1073, Total time: 0.2208
Evaluating...
[ 2:  10]: eps=0.00000000 Loss=443.8868 Time=0.0012
Epoch 3, learning rate [0.001]
[ 3:   0]: eps=0.00000000 Loss=436.5890 Time=0.0021
[ 3:  10]: 

In [None]:
# Train the baseline ("fragile") model with method="natural" for comparison.
train_robust(
    model_fragile_wrap,
    dataloader_train,
    dataloader_test,
    method="natural",
    args=args,
)

In [None]:
X_train, T_train, E_train = dataloader_train.dataset.tensors
# Time grid from 0 to the largest observed training time.
t = torch.linspace(0, T_train.max(), 10000)

# Inference only: no_grad avoids building an autograd graph over the large
# per-individual survival matrix (presumably n_train x len(t)).
with torch.no_grad():
    St_robust_x = clf_robust.survival_qdf(X_train, t).detach()
    St_fragile_x = clf_fragile.survival_qdf(X_train, t).detach()

# Non-parametric reference curve on the same training data.
kmf = KaplanMeierFitter()
kmf.fit(durations=T_train, event_observed=E_train)
St_kmf = kmf.predict(times=t.ravel().numpy())

# Population curve = mean of the individual curves over the training set.
fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(t, St_kmf, label="Kaplan Meier Numerical")
ax.plot(t, St_fragile_x.mean(0), label="Neural Network Normal")
ax.plot(t, St_robust_x.mean(0), label="Neural Network Robust")
ax.set(xlabel="Time", ylabel="S(t)", title="Train Population Survival Curves")
ax.legend()
plt.show()

In [None]:
# Perturbation sizes to sweep, from largest to smallest.
epsilons = [1, 0.8, 0.7, 0.6, 0.5, 0.1, 0.07, 0.05]
visualize_population_curves_attacked(
    clf_fragile, clf_robust, dataloader_train, epsilons=epsilons
)

In [None]:
# Same perturbation sweep as for the train split, now on the test split.
epsilons = [1, 0.8, 0.7, 0.6, 0.5, 0.1, 0.07, 0.05]
visualize_population_curves_attacked(
    clf_fragile, clf_robust, dataloader_test, epsilons=epsilons
)

In [None]:
# Robust model, train split, epsilon=0.3, order="ascending".
visualize_individual_curves_attacked(
    clf_robust, dataloader_train, epsilon=0.3, order="ascending"
)

In [None]:
# Robust model, train split, epsilon=0.3, order="descending".
visualize_individual_curves_attacked(
    clf_robust, dataloader_train, epsilon=0.3, order="descending"
)

In [None]:
# Robust model, test split, epsilon=0.3, order="ascending".
visualize_individual_curves_attacked(
    clf_robust, dataloader_test, epsilon=0.3, order="ascending"
)

In [None]:
# Robust model, test split, epsilon=0.3, order="descending".
visualize_individual_curves_attacked(
    clf_robust, dataloader_test, epsilon=0.3, order="descending"
)

In [None]:
# Bound the robust model's output over an eps=0.1 perturbation of the
# *training* inputs. The plotting cells below compare St_lb row-by-row with
# clf_robust's curves on X_train, so both must be computed on the same rows.
lb, ub = lower_bound(clf_robust, X_train, 0.1)

# With S(t) = exp(-rate * t), the largest admissible rate (ub) gives the
# smallest survival, i.e. a pointwise lower bound on S(t).
# NOTE(review): assumes (lb, ub) bound the predicted rate, one value per
# individual — confirm in src.utils.lower_bound. reshape(-1, 1) makes each
# individual's bound broadcast against the time grid (one row per individual).
St_lb = torch.exp(-ub.reshape(-1, 1) * t).detach()

In [None]:
# Individuals whose robust survival curve lies *closest* to its lower bound
# St_lb (both evaluated on X_train over the time grid t).
with torch.no_grad():
    St_given_x = clf_robust.survival_qdf(X_train, t).detach()

test_cases = 30
# One distinct colour per plotted individual, plus two spare named colours.
colors = list(plt.cm.brg(np.linspace(0, 1, test_cases))) + ["crimson", "indigo"]

# Euclidean distance between each individual's curve and its lower bound.
gap = torch.linalg.norm(St_lb - St_given_x, dim=1)
cases = torch.argsort(gap)[:test_cases]
print(gap[cases])

fig, ax = plt.subplots(figsize=(10, 10))
for color, case in zip(colors, cases):
    ax.plot(t, St_given_x[case], color=color)          # model curve
    ax.plot(t, St_lb[case], "--", color=color)         # its lower bound
ax.set(
    xlabel="Time",
    ylabel="S(t)",
    title="Individual Survival Curves Train (closest to lower bound)",
)
plt.show()

In [None]:
# Individuals whose robust survival curve lies *farthest* from its lower bound
# St_lb. Self-contained: recomputes the curves and colours it needs rather
# than relying on names leaked from another cell.
with torch.no_grad():
    St_given_x = clf_robust.survival_qdf(X_train, t).detach()

test_cases = 30
# One distinct colour per plotted individual, plus two spare named colours.
colors = list(plt.cm.brg(np.linspace(0, 1, test_cases))) + ["crimson", "indigo"]

# Euclidean distance between each individual's curve and its lower bound,
# largest first.
gap = torch.linalg.norm(St_lb - St_given_x, dim=1)
cases = torch.argsort(gap, descending=True)[:test_cases]
print(gap[cases])

fig, ax = plt.subplots(figsize=(10, 10))
for color, case in zip(colors, cases):
    ax.plot(t, St_given_x[case], color=color)          # model curve
    ax.plot(t, St_lb[case], "--", color=color)         # its lower bound
ax.set(
    xlabel="Time",
    ylabel="S(t)",
    title="Individual Survival Curves Train (farthest from lower bound)",
)
plt.show()
