# Setup

In [2]:
# Mount Google Drive into the Colab VM; the dataset and output directories
# configured below are resolved under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install tensorflow

In [4]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf

from datetime import datetime
from pathlib import Path
from keras.callbacks import EarlyStopping, TensorBoard, ModelCheckpoint
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score
from sklearn.model_selection import StratifiedKFold
from tensorflow import keras
from tensorflow.keras.layers import LSTM, Dense

In [46]:
# --- Experiment parameters -------------------------------------------------
NUM_READ = 150   # leading rows taken from every CSV file
NUM_SPLIT = 5    # number of cross-validation folds
NUM_SEED = 8     # seed shared by the RNGs and the fold shuffling
STANDARD_DEVIATION = 1.0  # NOTE(review): not referenced by the cells shown here

# --- Class labels and input columns ----------------------------------------
# Directory name of each state -> integer class id.
STATE = dict(zip(("healthy", "sick group A", "sick group B"), range(3)))

# CSV columns fed to the network, one movement distance per tracked body part.
FEATURES = [f"movement_distance_{part}" for part in ("head", "tail", "center")]

# --- Locations on the mounted Google Drive ---------------------------------
DATA_DIR = 'TMP_DIR/lstm'
LSTM_DATA_DIR = os.path.join('/content/drive/MyDrive', DATA_DIR)
OUTPUT_DIR = os.path.join(LSTM_DATA_DIR, 'output')

In [34]:
def set_seeds(seed=8):
    """Seed the random sources used by this notebook.

    Sets the ``PYTHONHASHSEED`` environment variable, seeds the ``random``
    module, NumPy and TensorFlow, and installs a TF1-compat Keras session
    configured with one intra-op and one inter-op thread.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
        seed_fn(seed)

    # NOTE(review): this goes through the tf.compat.v1 session API; confirm it
    # still has the intended effect on the installed TensorFlow/Keras version.
    single_thread_config = tf.compat.v1.ConfigProto(
        inter_op_parallelism_threads=1,
        intra_op_parallelism_threads=1,
    )
    default_graph = tf.compat.v1.get_default_graph()
    tf.compat.v1.keras.backend.set_session(
        tf.compat.v1.Session(graph=default_graph, config=single_thread_config)
    )


set_seeds(NUM_SEED)

In [35]:
def has_null(df):
    """Validate that ``df`` has no missing values.

    Despite the name, nothing is returned: the function either passes
    silently or raises.

    Args:
        df: pandas DataFrame to check.

    Raises:
        ValueError: If at least one cell of ``df`` is null.
    """
    contains_missing = df.isnull().any().any()
    if contains_missing:
        raise ValueError("Your data has null field.")

In [36]:
def get_current_datetime_str() -> str:
    """Return the current local time as a compact ``YYYYMMDDHHMMSS`` string."""
    return datetime.now().strftime('%Y%m%d%H%M%S')

In [37]:
X = []  # per CSV file: list of per-row feature tuples
Y = []  # per CSV file: [class id, CSV file name]
y = []  # per CSV file: class id only (used to stratify the folds)

# Visit directories and files in sorted order. Glob results come back in an
# arbitrary, filesystem-dependent order, which would change the sample order
# (and with it the seeded k-fold assignment) from one run to the next.
for state_dir in sorted(Path(LSTM_DATA_DIR).glob("*")):
    for csv_path in sorted(state_dir.glob("*.csv")):
        # A directory holding CSVs must be named after a key of STATE;
        # anything else raises KeyError here.
        label = STATE[state_dir.name]

        # Keep only the feature columns of the first NUM_READ rows.
        df = pd.read_csv(csv_path).loc[:, FEATURES].iloc[:NUM_READ]
        has_null(df)

        frames = list(df.itertuples(index=False, name=None))
        # All sequences must share one length so they can be stacked into a
        # single 3-D array; fail here with the offending file name instead of
        # later during array conversion or training.
        if X and len(frames) != len(X[0]):
            raise ValueError(
                f"{csv_path} has {len(frames)} rows, expected {len(X[0])}."
            )

        # Record the sample only after it passed validation so X, Y and y
        # always stay the same length.
        X.append(frames)
        Y.append([label, csv_path.name])
        y.append(label)

In [38]:
# Convert the collected Python lists into NumPy arrays. Each row of `Y` mixes
# an integer label with a file name, so NumPy stores that array as strings.
X, Y, y = (np.array(collected) for collected in (X, Y, y))

In [None]:
# Display the dimensions of the stacked input array
# (expected: CSV files x rows per file x feature columns).
X.shape

# Training using k-fold cross validation

In [40]:
# Stratified splitter driven by the integer labels; shuffling uses the shared
# seed so the fold assignment is repeatable for a fixed sample order.
kfold = StratifiedKFold(
    random_state=NUM_SEED,
    shuffle=True,
    n_splits=NUM_SPLIT,
)

In [None]:
# Train one LSTM classifier per fold and store its predictions on the held-out
# part as <OUTPUT_DIR>/result/cv<k>/result.csv.
for idx, (idx_train, idx_test) in enumerate(kfold.split(X, y)):
    X_train, X_test = X[idx_train], X[idx_test]

    # `y` already holds the integer labels and column 1 of `Y` the CSV file
    # names, so index them directly instead of parsing the labels back out of
    # the string-typed `Y` array.
    y_train = y[idx_train].astype('uint8')
    y_test = y[idx_test].astype('uint8')
    csvfilename_test = Y[idx_test, 1]

    # Single LSTM layer followed by a softmax over the known states.
    model = keras.Sequential()
    model.add(LSTM(128, input_shape=(X_train.shape[1], X_train.shape[2])))
    model.add(Dense(len(STATE), activation='softmax'))
    # summary() prints by itself and returns None; wrapping it in print()
    # would emit a stray "None" line.
    model.summary()

    idx_dir = f"cv{idx + 1}"

    model_dir = os.path.join(OUTPUT_DIR, f'models/{idx_dir}')
    os.makedirs(model_dir, exist_ok=True)

    log_dir = os.path.join(OUTPUT_DIR, f'tensorlog/{idx_dir}')
    os.makedirs(log_dir, exist_ok=True)

    # NOTE(review): newer Keras releases may reject checkpoint paths that do
    # not end in `.keras`/`.h5` -- confirm against the installed version.
    model_path = os.path.join(model_dir, f'{get_current_datetime_str()}.hdf5')
    cp_cb = ModelCheckpoint(
        filepath=model_path,
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
        mode='auto',
    )
    tb_cb = TensorBoard(log_dir=log_dir, histogram_freq=1)

    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        optimizer="adam",
        metrics=["accuracy"],
    )

    # NOTE(review): the held-out fold doubles as validation data, so the
    # checkpoint with the best val_loss is selected on the test samples.
    history = model.fit(
        X_train,
        y_train,
        validation_data=(X_test, y_test),
        batch_size=2,
        epochs=30,
        callbacks=[cp_cb, tb_cb],
    )

    # Predictions come from the in-memory model after the last epoch, not from
    # the checkpoint file written above.
    y_pred = np.argmax(model.predict(X_test), axis=-1)

    # Build the frame column by column so the labels keep their integer dtype
    # instead of being coerced to strings alongside the file names.
    df_result = pd.DataFrame({
        'actual': y_test,
        'pred': y_pred,
        'filename': csvfilename_test,
    })
    result_dir = os.path.join(OUTPUT_DIR, f'result/{idx_dir}')
    os.makedirs(result_dir, exist_ok=True)
    df_result.to_csv(os.path.join(result_dir, 'result.csv'), index=False)

In [None]:
# Report a confusion matrix (counts and share of all samples) plus per-class
# metrics for every fold, read back from the result.csv written during training.
class_names = ["Healthy", "Sick group A", "Sick group B"]
# Integer class ids in the same order as `class_names`. Passing them explicitly
# keeps every matrix/metric at one entry per class even when a fold happens to
# contain no sample or prediction of some class.
class_ids = list(range(len(class_names)))
ordinal_suffix = {1: "st", 2: "nd", 3: "rd"}

for idx in range(NUM_SPLIT):
    idx_dir = "cv" + str(idx+1)
    result_path = os.path.join(OUTPUT_DIR, f'result/{idx_dir}/result.csv')
    df = pd.read_csv(result_path)

    actual = df['actual']
    pred = df['pred']

    cm = pd.DataFrame(
        data=confusion_matrix(actual, pred, labels=class_ids),
        index=class_names,
        columns=class_names,
    )
    # Share of all samples in the fold per cell. Computed as one vectorised
    # division: assigning into rows yielded by iterrows() is not guaranteed to
    # write back to the frame.
    cm_ratio = cm / len(actual)

    fig, ax = plt.subplots(1, 2, figsize=(10, 6))
    cv = f"{idx + 1}{ordinal_suffix.get(idx + 1, 'th')} CV"
    fig.suptitle(f'Confusion matrix ({cv})', fontsize=16)

    sns.heatmap(cm, square=True, cbar=True, annot=True, cmap='Blues', fmt='d', ax=ax[0])
    ax[0].set_title("About Count")
    ax[0].set_xlabel("Prediction", fontsize=11, rotation=0)
    ax[0].set_ylabel("Ground truth", fontsize=11)

    sns.heatmap(cm_ratio, square=True, cbar=True, annot=True, cmap='Blues', fmt=".3f", ax=ax[1])
    ax[1].set_title("About Ratio")
    ax[1].set_xlabel("Prediction", fontsize=11, rotation=0)
    ax[1].set_ylabel("Ground truth", fontsize=11)

    print(f"--------------------Cross Validation Round {idx+1}--------------------")
    plt.show()
    # Release the figure so repeated folds do not accumulate open figures.
    plt.close(fig)

    print("Evaluation Index: [Healthy, Sick Group A, Sick Group B]")
    print(f'Recall: {recall_score(actual, pred, labels=class_ids, average=None)}')
    print(f'Precision: {precision_score(actual, pred, labels=class_ids, average=None)}')
    print(f'F1 Score: {f1_score(actual, pred, labels=class_ids, average=None)}')
    print(f'Accuracy: {accuracy_score(actual, pred)}')
    print()