In [None]:
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import train_test_split
from sklearn import metrics
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
import numpy as np

def label_bcva(row):
    """Bucket a row's ``target_va`` into one of four ordinal outcome classes.

    The cut-points are 40, 60 and 80:

    * ``0`` -- ``target_va`` below 40
    * ``1`` -- from 40 up to (not including) 60
    * ``2`` -- from 60 up to (not including) 80
    * ``3`` -- anything else

    Note: a NaN ``target_va`` fails every ``<`` comparison and therefore
    also ends up in class ``3``.

    Parameters
    ----------
    row : object exposing a ``target_va`` attribute (e.g. a DataFrame row).

    Returns
    -------
    int
        The class index, 0 through 3.
    """
    acuity = row.target_va
    # Scan the upper bounds in ascending order: the first bound the value
    # falls below decides the class.
    for label, upper_bound in enumerate((40, 60, 80)):
        if acuity < upper_bound:
            return label
    return 3

# Load the 3-year dataset and derive the 4-class `outcome` column from
# `target_va`, one row at a time.
df = pd.read_csv("~/Documents/Github/paper/input/df_3_years.csv")
df['outcome'] = df.apply(label_bcva, axis=1)

# Features are every column except the raw target and its derived label.
X = df.drop(columns=['target_va', 'outcome'])
y = df.outcome.values

# 80/20 train/validation split with a fixed seed.
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.2, random_state=4
)

def score(model, X, y, cv=5, scoring='accuracy'):
    """Cross-validate ``model`` on ``(X, y)`` and summarise the per-fold scores.

    Parameters
    ----------
    model, X, y, cv, scoring
        Passed straight through to ``cross_val_score``.

    Returns
    -------
    tuple
        ``(mean, standard deviation)`` of the scores returned by
        ``cross_val_score``.
    """
    fold_scores = cross_val_score(model, X, y, cv=cv, scoring=scoring)
    mean_score = np.mean(fold_scores)
    score_spread = np.std(fold_scores)
    return mean_score, score_spread

In [None]:
from pytorch_tabnet.tab_model import TabNetClassifier

import torch
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
# Seed NumPy's global random number generator with a fixed value.
np.random.seed(0)

from matplotlib import pyplot as plt
%matplotlib inline

# Redo the split, this time on the underlying array (`.values`) of the
# feature columns instead of the DataFrame. Same test fraction and seed as
# the earlier split.
X = df.drop(columns=['target_va', 'outcome']).values
y = df.outcome.values
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.2, random_state=4
)

In [None]:
# Standalone TabNet classifier; the hyperparameter values are the same ones
# `return_tabnet()` uses further down in this notebook.
clf = TabNetClassifier(
    n_d=64,
    n_a=64,
    n_steps=5,
    gamma=1.5,
    n_independent=2,
    n_shared=2,
    lambda_sparse=1e-4,
    momentum=0.3,
    clip_value=2.,
    epsilon=1e-15,
    # Optimiser: Adam with lr = 0.02.
    optimizer_fn=torch.optim.Adam,
    optimizer_params={"lr": 2e-2},
    # Scheduler: StepLR configured with gamma = 0.95 and step_size = 20.
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    scheduler_params=dict(gamma=0.95, step_size=20),
    verbose=0,
)

In [None]:
# Fit on the training split. Both splits are also passed as evaluation sets,
# named 'train' and 'valid' respectively.
clf.fit(
    X_train=X_train,
    y_train=y_train,
    eval_set=[(X_train, y_train), (X_valid, y_valid)],
    eval_name=['train', 'valid'],
    max_epochs=1000,
    patience=100,
    batch_size=64,
)

In [4]:
import os

# Switch the working directory to the project's results folder so that files
# written with relative names (e.g. the figure saved by `plt.savefig`) land
# there. The path is expressed relative to the current user's home directory,
# like the input CSV paths in this notebook, instead of hard-coding one
# specific user's home folder.
os.chdir(os.path.expanduser('~/Documents/GitHub/paper/results'))

In [5]:
def return_tabnet():
    """Construct a new ``TabNetClassifier`` with this notebook's fixed hyperparameters.

    Every call builds and returns a separate instance; nothing is cached or
    shared between calls.
    """
    hyperparams = dict(
        n_d=64,
        n_a=64,
        n_steps=5,
        gamma=1.5,
        n_independent=2,
        n_shared=2,
        lambda_sparse=1e-4,
        momentum=0.3,
        clip_value=2.,
        epsilon=1e-15,
        # Optimiser: Adam with lr = 0.02.
        optimizer_fn=torch.optim.Adam,
        optimizer_params={"lr": 2e-2},
        # Scheduler: StepLR configured with gamma = 0.95 and step_size = 20.
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        scheduler_params={"gamma": 0.95, "step_size": 20},
        verbose=0,
    )
    return TabNetClassifier(**hyperparams)

In [6]:
def train_tabnet(year, *, test_size=0.2, random_state=4,
                 max_epochs=1000, patience=100, batch_size=64):
    """Train a TabNet classifier on the ``year``-year dataset and return it.

    Reads ``~/Documents/Github/paper/input/df_{year}_years.csv``, derives the
    ``outcome`` class from ``target_va`` via ``label_bcva``, splits the rows
    into train/validation subsets and fits a fresh model obtained from
    ``return_tabnet()``. Both subsets are passed to ``fit`` as evaluation
    sets named ``'train'`` and ``'valid'``.

    Parameters
    ----------
    year
        Value interpolated into the CSV file name (the notebook passes
        1, 2 and 3).
    test_size, random_state
        Forwarded to ``train_test_split``. The defaults (0.2 and 4) match
        the split used in the rest of this notebook.
    max_epochs, patience, batch_size
        Forwarded to the classifier's ``fit``. The defaults (1000, 100, 64)
        are the values this function previously hard-coded.

    Returns
    -------
    The fitted classifier. The validation split is not returned.
    """
    # Use a local name that does not shadow the notebook-level `df`.
    data = pd.read_csv(f"~/Documents/Github/paper/input/df_{year}_years.csv")
    data['outcome'] = data.apply(label_bcva, axis=1)

    # Features: every column except the raw target and its derived label.
    X = data.drop(columns=['target_va', 'outcome']).values
    y = data.outcome.values
    X_train, X_valid, y_train, y_valid = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    clf = return_tabnet()
    clf.fit(
        X_train=X_train,
        y_train=y_train,
        eval_set=[(X_train, y_train), (X_valid, y_valid)],
        eval_name=['train', 'valid'],
        max_epochs=max_epochs,
        patience=patience,
        batch_size=batch_size,
    )
    return clf

In [None]:
# Train one model for each of the 1-, 2- and 3-year datasets, in that order,
# and keep both the list and the individual names.
model_lst = [train_tabnet(year) for year in (1, 2, 3)]
clf_1, clf_2, clf_3 = model_lst

In [None]:
def bcva_prob_dist(model, sample, year, ground_truth):
    """Plot the predicted class probabilities for one sample, save and show the figure.

    Parameters
    ----------
    model
        Object with a ``predict_proba`` method; only the first row of its
        output is plotted.
    sample
        Input handed to ``model.predict_proba``.
    year
        Used in the x-axis label and in the output file name
        ``year{year}_patient_distribution.png``.
    ground_truth
        x-position of the vertical "True vision" marker line.

    Returns
    -------
    None. The figure is written to the current working directory at 300 dpi
    and then displayed.
    """
    x_positions = [0, 30, 50, 70, 90, 100]

    # Take the first prediction row and pad it with a zero on each side so
    # the curve starts and ends on the x-axis. With four class probabilities
    # this yields the six y-values matching `x_positions`.
    class_probs = model.predict_proba(sample).tolist()[0]
    curve = [0.0, *class_probs, 0]

    plt.plot(x_positions, curve, color="orange")
    plt.fill_between(x_positions, curve, color="navy")
    plt.axvline(x=ground_truth, color='r', linestyle='-', label="True vision")

    plt.xlabel(f"Vision at end of Year {year} (logMAR letters)")
    plt.ylabel("Probability of BCVA")
    plt.legend()

    plt.savefig(f"year{year}_patient_distribution.png", dpi=300)
    plt.show()

In [None]:
# Slice out the validation row at index 11 (a slice, so `sample` is still a
# one-row batch) and plot it with `clf` -- the model fitted in the standalone
# cell above, not one of clf_1..clf_3. The year (3) and the ground-truth
# marker position (40) are hard-coded here.
sample = X_valid[11:12]
bcva_prob_dist(model=clf, sample=sample, year=3, ground_truth=40)