In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf

from model import get_model
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler, ReduceLROnPlateau

In [None]:
# Load the PTB normal/abnormal CSVs (no header row) and stack them into one frame.
df_1 = pd.read_csv("data/ptbdb_normal.csv", header=None)
df_2 = pd.read_csv("data/ptbdb_abnormal.csv", header=None)
# ignore_index=True gives the combined frame unique row labels instead of two
# overlapping 0..n-1 ranges (one per source file).
df = pd.concat([df_1, df_2], ignore_index=True)

# The class label lives in the last column; every column before it is a signal value.
LABEL_COL = len(df.columns) - 1

# Stratify on the label so train and test keep the same class ratio.
df_train, df_test = train_test_split(df, test_size = 0.2, random_state = 42, stratify = df[LABEL_COL])


def split_features_labels(frame, label_col):
    """Split a frame into (features, labels) numpy arrays for the model.

    Args:
        frame: DataFrame whose `label_col` column holds the class label and whose
            remaining columns hold the per-row signal values.
        label_col: label of the class column.

    Returns:
        features: array of shape (n_rows, n_signal_columns, 1) — a trailing
            channel axis is added for the model input.
        labels: 1-D int8 array (presumably 0/1 — it feeds binary_crossentropy
            and a binary f1_score; TODO confirm).
    """
    features = frame.drop(columns=label_col).to_numpy()[..., np.newaxis]
    labels = frame[label_col].to_numpy().astype(np.int8)
    return features, labels


# X / Y (train) and X_test / Y_test are what the training, plotting and
# evaluation cells further down consume.
X, Y = split_features_labels(df_train, LABEL_COL)
X_test, Y_test = split_features_labels(df_test, LABEL_COL)

df_train.shape, df_test.shape

In [None]:
# Peek at the first five rows of the combined frame.
df.head(n=5)

In [None]:
# Summary statistics for the numeric columns of the combined frame.
df.describe(include=None)

In [None]:
# (number of columns, number of rows) of the training split.
df_train.shape[::-1]

In [None]:
# Class balance: count of each value in the last (label) column.
df.iloc[:, -1].value_counts()

In [None]:
# All three callbacks track "val_acc" (corresponding to the 'acc' metric passed
# to compile) and treat larger values as better.

# Write the model to disk only when validation accuracy improves.
chkpt = ModelCheckpoint(
    filepath="recurrent_model_initial.h5",
    monitor="val_acc",
    mode="max",
    save_best_only=True,
    verbose=1,
)

# Stop training after 5 epochs without improvement.
early = EarlyStopping(monitor="val_acc", mode="max", patience=5, verbose=1)

# Lower the learning rate after 3 epochs without improvement.
redonplat = ReduceLROnPlateau(monitor="val_acc", mode="max", patience=3, verbose=2)

callbacks_list = [chkpt, early, redonplat]

In [None]:
# Create the distribution strategy whose scope the model is built under below,
# and report how many replicas it synchronises.
strategy = tf.distribute.experimental.CentralStorageStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

In [None]:
# Build and compile the model inside the strategy's scope.
with strategy.scope():
    # NOTE(review): get_model lives in model.py (not visible here). Its first
    # argument is assumed to be the per-example input length. The original passed
    # df_train.shape[0], which is the number of training *rows*, not the number of
    # signal values per row. Confirm against get_model's signature.
    n_timesteps = df_train.shape[1] - 1  # all columns except the trailing label
    model = get_model(n_timesteps, 1, 1)

    opt = tf.keras.optimizers.Adam(0.001)

    # The metric is named 'acc', so the callbacks must monitor "val_acc".
    model.compile(optimizer = opt, loss = tf.keras.losses.binary_crossentropy, metrics = ['acc'])

In [None]:
# Print the layer-by-layer architecture and parameter counts.
model.summary(line_length=None)

In [None]:
# Train on the training arrays. validation_split=0.1 reserves 10% of them to
# produce the val_* metrics that the callbacks monitor.
model.fit(
    X,
    Y,
    epochs=10,
    batch_size=128,
    validation_split=0.1,
    callbacks=callbacks_list,
    verbose=1,
)

In [None]:
# Sanity-check the inputs by plotting one randomly chosen training example.
# A fixed-seed Generator makes the chosen example reproducible across runs.
i = np.random.default_rng(42).integers(X.shape[0])
# pandas' plotting accessor draws the line plot and returns the Axes, so this cell
# needs only the numpy/pandas imports from the top cell.
ax = pd.Series(X[i].flatten()).plot(title="Heartbeat Class: " + str(Y[i]))
ax.set(xlabel="Sample index", ylabel="Signal value");

In [None]:
# Turn the model's scores (presumably sigmoid outputs) into hard 0/1 labels with a
# 0.5 threshold, then score them against the held-out test labels.
pred_test = (model.predict(X_test) > 0.5).astype(np.int8)

f1 = f1_score(Y_test, pred_test)
print("Test f1 score : %s " % f1)

acc = accuracy_score(Y_test, pred_test)
print("Test accuracy score : %s " % acc)

In [None]:
# Persist the model as it stands after the last training epoch.
model.save(filepath="ptbdb.h5")