<a href="https://colab.research.google.com/github/karimadadda/Deep_learning_project/blob/main/medical_diagnosis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from loading_data import *
from model import *
from config import config
import sys

import tensorflow as tf
from sklearn.model_selection import KFold
from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import warnings
warnings.filterwarnings('ignore')


import numpy as np  # used directly below; previously only reachable through the star imports

SEED = 2019
DEFAULT_DATA_FILE = './sample_data.xlsx'
N_SPLITS = 5
THRESHOLD = 0.5  # probability cut-off used for accuracy / F1


def resolve_data_file(argv, default=DEFAULT_DATA_FILE):
    """Return the path of the data file to load.

    The first command-line argument (``argv[1]``) is used when present,
    otherwise *default*.
    """
    return argv[1] if len(argv) > 1 else default


def batch_bounds(n_samples, batch_size):
    """Return ``(start, end)`` index pairs for the mini-batches of one epoch.

    Only full batches are produced; the trailing ``n_samples % batch_size``
    samples are skipped (the training data is reshuffled every epoch, so a
    different tail is skipped each time). If ``n_samples`` is positive but
    smaller than ``batch_size``, a single batch holding every sample is
    returned so that training still happens. Zero samples yield no batches.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    if n_samples <= 0:
        return []
    n_batches = max(1, n_samples // batch_size)
    return [(i * batch_size, min((i + 1) * batch_size, n_samples))
            for i in range(n_batches)]


def should_stop_early(history, tolerance):
    """Return True once *tolerance* epochs have passed without beating the best score.

    *history* holds one validation score per epoch (higher is better). Ties
    resolve to the earliest epoch, as with ``np.argmax``, so matching the best
    score does not count as an improvement.
    """
    if not history:
        return False
    return len(history) - int(np.argmax(history)) > tolerance


def train_step(model, optimizer, criterion, loss_metric, x_bin, t):
    """Apply one gradient update on the mini-batch ``(x_bin, t)``.

    ``model`` is called as ``model(x_bin)`` and must return a 3-tuple whose
    first element is the prediction; the other two outputs are ignored.
    The batch loss is accumulated into ``loss_metric``. Returns the
    predictions.
    """
    # NOTE(review): model is called without `training=`; whether dropout is
    # active here depends on how Graph handles that argument — confirm.
    with tf.GradientTape() as tape:
        pred, _, _ = model(x_bin)
        loss = criterion(t, pred)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    loss_metric(loss)

    return pred


def evaluate_step(model, criterion, loss_metric, x_bin, t):
    """Run ``model`` on ``(x_bin, t)`` without updating it and return the predictions.

    The loss is accumulated into ``loss_metric``.
    """
    pred, _, _ = model(x_bin)
    loss_metric(criterion(t, pred))
    return pred


if __name__ == '__main__':

    np.random.seed(SEED)
    tf.random.set_seed(SEED)

    config = config()  # rebinds the imported `config` class to its instance
    data_file = resolve_data_file(sys.argv)

    x_bin_features, feats, tokens, feat_max, y = data_preparation(data_file)

    # NOTE(review): for binary labels StratifiedKFold would guarantee both
    # classes in every test fold (roc_auc_score raises otherwise) — consider it.
    kf = KFold(n_splits=N_SPLITS, random_state=SEED, shuffle=True)

    accuracy = []
    f1 = []
    auc = []

    for fold, (train_index, test_index) in enumerate(kf.split(x_bin_features), start=1):

        x_bf_train, x_bf_test = x_bin_features[train_index], x_bin_features[test_index]
        y_train, y_test = y[train_index], y[test_index]

        # Fresh model / optimizer per fold so folds do not share weights or state.
        model = Graph(tokens, config.embedding, feat_max, config.num_heads, config.dropout_rate)
        optimizer = tf.keras.optimizers.Adam(learning_rate=config.learning_rate, beta_1=config.beta_1,
                                             beta_2=config.beta_2, epsilon=config.epsilon)
        criterion = tf.keras.losses.BinaryCrossentropy()

        # Batches are cut from the *training* split; sizing them from the whole
        # dataset produced trailing empty batches.
        bounds = batch_bounds(x_bf_train.shape[0], config.batch_size)

        # One generator per fold: every epoch gets a new permutation while the
        # run as a whole stays reproducible.
        rng = np.random.RandomState(SEED)

        es = []          # validation AUC per epoch
        preds_temp = []  # validation predictions per epoch

        for epoch in range(config.epochs):
            train_loss = tf.keras.metrics.Mean()
            test_loss = tf.keras.metrics.Mean()

            _x_bf_train, _y_train = shuffle(x_bf_train, y_train, random_state=rng)
            for start, end in bounds:
                train_step(model, optimizer, criterion, train_loss,
                           _x_bf_train[start:end], _y_train[start:end])

            testpreds = evaluate_step(model, criterion, test_loss, x_bf_test, y_test)
            score = roc_auc_score(y_test, testpreds)
            es.append(score)
            preds_temp.append(testpreds)

            print(' epoch:', epoch, ' auc:', score,
                  ' train_loss:', float(train_loss.result()),
                  ' test_loss:', float(test_loss.result()))

            # Early stopping on validation AUC.
            if should_stop_early(es, config.tolerance):
                break

        # Report the fold at its best epoch (by validation AUC).
        num = int(np.argmax(es))
        print('fold:', fold, ' epoch:', num)

        best_preds = np.asarray(preds_temp[num])
        pred_temp_thres = np.int32(best_preds > THRESHOLD)

        acc_temp = accuracy_score(y_test, pred_temp_thres)
        accuracy.append(acc_temp)
        print('fold:', fold, ' accuracy:', acc_temp)

        f1_temp = f1_score(y_test, pred_temp_thres)
        f1.append(f1_temp)
        print('fold:', fold, ' f1_score:', f1_temp)

        auc_temp = roc_auc_score(y_test, best_preds)
        auc.append(auc_temp)
        print('fold:', fold, ' auc:', auc_temp)

    print('###################################################')
    print('auc:', np.mean(auc))
    print('f1 score:', np.mean(f1))
    print('accuracy:', np.mean(accuracy))
    print('\n')


 epoch: 0  auc: 0.40541252965468766
 epoch: 1  auc: 0.41477023108689914
 epoch: 2  auc: 0.41696687461558735
 epoch: 3  auc: 0.42263421491960285
 epoch: 4  auc: 0.42636850891837275
 epoch: 5  auc: 0.4330463052455848
 epoch: 6  auc: 0.43950443721992793
 epoch: 7  auc: 0.44367805992443543
 epoch: 8  auc: 0.4464897636411563
 epoch: 9  auc: 0.4500043932870574
 epoch: 10  auc: 0.45184957385115543
 epoch: 11  auc: 0.45523240488533523
 epoch: 12  auc: 0.4622177313065635
 epoch: 13  auc: 0.4718390299622177
 epoch: 14  auc: 0.48110886565328176
 epoch: 15  auc: 0.4939372638608206
 epoch: 16  auc: 0.5080836481855724
 epoch: 17  auc: 0.5228450926983569
 epoch: 18  auc: 0.5406379052807312
 epoch: 19  auc: 0.5587821808276953
 epoch: 20  auc: 0.5773218522098235
 epoch: 21  auc: 0.5977506370266233
 epoch: 22  auc: 0.6188384149020297
 epoch: 23  auc: 0.6395307969422722
 epoch: 24  auc: 0.6609920042175556
 epoch: 25  auc: 0.6806519637993146
 epoch: 26  auc: 0.6977857833230824
 epoch: 27  auc: 0.713689482

KeyboardInterrupt: ignored