##### Imports

In [12]:
import pandas as pd
import h5py

from tensorflow.keras.callbacks import (
    ModelCheckpoint, EarlyStopping, ReduceLROnPlateau)

from tqdm.keras import TqdmCallback

import wandb
from wandb.keras import WandbCallback

from data_io import get_data
from utils import save_predictions
from model_parts import get_attention_model, modify_model

In [13]:
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

In [14]:
# Automatically re-import edited modules (e.g. data_io, utils, model_parts)
# before each cell executes, so changes to them are picked up without a restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Initialize result file

In [15]:
# Empty table that collects one row of test metrics per trained model.
results = pd.DataFrame(
    {column: [] for column in ['model', 'acc', 'f1-weighted', 'f1-macro', 'bacc']})

# MITBIH Model

### Load data

In [16]:
# MITBIH split, unpacked as (X_train, y_train, X_test, y_test).
(X_train_mit, y_train_mit,
 X_test_mit, y_test_mit) = get_data(dataset='mitbih')

### Define model

#### W&B setup

In [17]:
# Start the W&B run for MITBIH pre-training; the config dict holds both the
# optimisation settings and the architecture hyper-parameters of the model.
run = wandb.init(
    project='ML4HC-project4',
    config=dict(
        learning_rate=2e-3,
        lr_decay_steps=256,
        lr_decay_rate=0.998,
        epochs=8,
        batch_size=32,
        num_filters=[32, 32, 256],
        num_blocks_list=[2, 2, 2],
        kernel_sizes=[5, 3, 3],
        loss_function='sparse_categorical_crossentropy',
        architecture='Attention',
        dataset='MITBIH',
        mode='training',
        ndim=256,
        nheads=3,
    ))

config_mit = wandb.config

[34m[1mwandb[0m: wandb version 0.10.31 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


#### Set-up model

In [18]:
# Attention model with a 5-way output for the MITBIH classes.
original_model = get_attention_model(config=config_mit, nclass=5)

Model: "functional_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 187, 1)]     0                                            
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 187, 1)       4           input_2[0][0]                    
__________________________________________________________________________________________________
conv1d_7 (Conv1D)               (None, 187, 32)      192         batch_normalization_14[0][0]     
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 187, 32)      128         conv1d_7[0][0]                   
_______________________________________________________________________________________

In [19]:
file_path_mit = './models/attention/attention_mitbih.h5'

# Keep the model with the best validation macro-F1 on disk; this file is
# later loaded as the starting point for the PTBDB transfer-learning runs.
checkpoint = ModelCheckpoint(
    filepath=file_path_mit,
    monitor='val_f1-macro',
    mode='max',
    save_best_only=True,
    verbose=1,
)
# Stop once validation macro-F1 has not improved for 5 epochs.
early = EarlyStopping(
    monitor='val_f1-macro',
    mode='max',
    patience=5,
    verbose=1,
)
# Reduce the learning rate after 3 epochs without validation-accuracy gains.
redonplat = ReduceLROnPlateau(
    monitor='val_acc',
    mode='max',
    patience=3,
    verbose=2,
)

### Training

In [20]:
# TqdmCallback draws the progress bar, so Keras' own logging is off
# (verbose=0); 10% of the training data is held out to produce the val_*
# metrics the callbacks monitor.
history = original_model.fit(
    x=X_train_mit,
    y=y_train_mit,
    batch_size=config_mit.batch_size,
    epochs=config_mit.epochs,
    validation_split=0.1,
    callbacks=[
        checkpoint,
        early,
        redonplat,
        WandbCallback(),
        TqdmCallback(verbose=1),
    ],
    verbose=0,
)



Epoch 00004: val_f1-macro did not improve from 0.90123

Epoch 00005: val_f1-macro improved from 0.90123 to 0.91734, saving model to ./models/attention/attention_mitbih.h5

Epoch 00006: val_f1-macro did not improve from 0.91734

Epoch 00007: val_f1-macro did not improve from 0.91734

Epoch 00008: val_f1-macro improved from 0.91734 to 0.92624, saving model to ./models/attention/attention_mitbih.h5



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…


Epoch 00001: val_f1-macro improved from -inf to 0.82220, saving model to ./models/attention/attention_mitbih.h5

Epoch 00002: val_f1-macro did not improve from 0.82220


KeyboardInterrupt: 

### Evaluation

In [None]:
# Evaluate the best checkpoint (highest val macro-F1) instead of the
# last-epoch weights left in memory, so the reported test scores belong to the
# same weights that were saved and are reused for transfer learning.
original_model.load_weights(file_path_mit)

loss, acc, f1_weighted, f1_macro, b_acc = original_model.evaluate(
    X_test_mit, y_test_mit, batch_size=512, verbose=2)

# One wandb.log call: every call advances the W&B step, so separate calls
# would scatter the four test metrics across four steps.
wandb.log({
    'Test Accuracy': acc,
    'Test F1 Weighted': f1_weighted,
    'Test F1 Macro': f1_macro,
    'Test Balanced Accuracy': b_acc,
})

43/43 - 7s - loss: 0.0616 - acc: 0.9827 - f1-weighted: 0.9822 - f1-macro: 0.9008 - balanced_accuracy: 0.8917


In [None]:
# Close the MITBIH run (upload remaining data) before starting the next one;
# run.finish() is the current API, run.join() is its deprecated alias.
run.finish()

VBox(children=(Label(value=' 17.40MB of 17.40MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
epoch,7.0
loss,0.05383
acc,0.98386
f1-weighted,0.98232
f1-macro,0.91783
balanced_accuracy,0.93733
val_loss,0.05224
val_acc,0.98401
val_f1-weighted,0.98196
val_f1-macro,0.92624


0,1
epoch,▁▂▃▄▅▆▇█
loss,█▄▃▂▂▂▁▁
acc,▁▅▆▇▇███
f1-weighted,▁▅▆▇▇███
f1-macro,▁▅▆▆▇███
balanced_accuracy,▁▅▆▇▇███
val_loss,█▇▃▃▂▄▃▁
val_acc,▁▂▇▆▇▅▆█
val_f1-weighted,▁▃▆▆▇▅▆█
val_f1-macro,▁▂▆▅▇▅▅█


In [None]:
# Append the MITBIH test metrics as a new row; ignore_index rebuilds a 0..n-1 index.
results = pd.concat(
    [
        results,
        pd.DataFrame(
            [['MITBIH', acc, f1_weighted, f1_macro, b_acc]],
            columns=results.columns,
        ),
    ],
    ignore_index=True,
)

### Predictions on the Test and Training Sets

In [None]:
# Persist the MITBIH model's predictions on both splits under the 'attention_rnn' tag.
for split_X, split_tag in ((X_test_mit, 'mit_test'), (X_train_mit, 'mit_train')):
    save_predictions(original_model, split_X, 'attention_rnn', split_tag)

# PTBDB Transfer Learning

### Load data

In [None]:
# PTBDB split, unpacked as (X_train, y_train, X_test, y_test).
(X_train_pt, y_train_pt,
 X_test_pt, y_test_pt) = get_data(dataset='ptbdb')

## Frozen

### Define model

#### W&B setup

In [None]:
# Architecture hyper-parameters (num_filters, num_blocks_list, kernel_sizes,
# ndim, nheads) MUST equal those of the MITBIH run: the model built from this
# config is initialised with load_weights(file_path_mit), which requires the
# same layers and weight shapes. Only the training settings may differ.
run = wandb.init(
    project='ML4HC-project4',
    config={
        'learning_rate': 2e-3,
        'lr_decay_steps': 256,
        'lr_decay_rate': 0.96,
        'num_filters': [32, 32, 256],
        'num_blocks_list': [2, 2, 2],
        'kernel_sizes': [5, 3, 3],
        'epochs': 4,
        'batch_size': 32,
        'loss_function': 'sparse_categorical_crossentropy',
        'architecture': 'Attention',
        'dataset': 'PTBDB',
        'mode': 'frozen',
        'ndim': 256,
        'nheads': 3
    })

config_pt_frozen = wandb.config

#### Set-up model

In [None]:
# Rebuild the MITBIH-shaped network (5-class head) so the MITBIH weights can be
# loaded into it; prepare_model=False is passed as in the original call
# (its exact effect lives in model_parts — TODO confirm).
pretrained_model = get_attention_model(
    config_pt_frozen, nclass=5, prepare_model=False)

In [None]:
file_path_pt_frozen = './models/attention/attention_pt_frozen.h5'

# Save the model with the best validation macro-F1 of the frozen run.
checkpoint = ModelCheckpoint(
    filepath=file_path_pt_frozen,
    monitor='val_f1-macro',
    mode='max',
    save_best_only=True,
    verbose=1,
)
# Stop once validation macro-F1 has not improved for 5 epochs.
early = EarlyStopping(
    monitor='val_f1-macro',
    mode='max',
    patience=5,
    verbose=1,
)
# Reduce the learning rate after 3 epochs without validation-accuracy gains.
redonplat = ReduceLROnPlateau(
    monitor='val_acc',
    mode='max',
    patience=3,
    verbose=2,
)

#### Initialize model

In [None]:
# Initialise from the best MITBIH checkpoint; the network built above must
# match the MITBIH model layer for layer for this to succeed.
pretrained_model.load_weights(filepath=file_path_mit)

#### Modify the model

In [None]:
# Adapt the MITBIH-pretrained network for PTBDB. The behaviour presumably
# depends on config_pt_frozen.mode ('frozen') — see model_parts.modify_model.
fine_tuned_model = modify_model(pretrained_model, config_pt_frozen)

#### Train model

In [None]:
# Train the frozen-mode model on PTBDB; TqdmCallback shows progress, 10% of the
# training data is held out for the val_* metrics. The History is not kept.
fine_tuned_model.fit(
    x=X_train_pt,
    y=y_train_pt,
    batch_size=config_pt_frozen.batch_size,
    epochs=config_pt_frozen.epochs,
    validation_split=0.1,
    callbacks=[
        checkpoint,
        early,
        redonplat,
        WandbCallback(),
        TqdmCallback(verbose=1),
    ],
    verbose=0,
)

### Evaluation

In [None]:
# Evaluate the best frozen-run checkpoint (highest val macro-F1) rather than
# whatever weights the final epoch left in memory.
fine_tuned_model.load_weights(file_path_pt_frozen)

loss, acc, f1_weighted, f1_macro, b_acc = fine_tuned_model.evaluate(
    X_test_pt, y_test_pt, batch_size=512, verbose=2)

# One wandb.log call: every call advances the W&B step, so separate calls
# would scatter the four test metrics across four steps.
wandb.log({
    'Test Accuracy': acc,
    'Test F1 Weighted': f1_weighted,
    'Test F1 Macro': f1_macro,
    'Test Balanced Accuracy': b_acc,
})

In [None]:
# Close the frozen PTBDB run; run.finish() supersedes the deprecated run.join().
run.finish()

In [None]:
# Append the frozen-PTBDB test metrics as a new row; ignore_index rebuilds a 0..n-1 index.
results = pd.concat(
    [
        results,
        pd.DataFrame(
            [['PTBDB-frozen', acc, f1_weighted, f1_macro, b_acc]],
            columns=results.columns,
        ),
    ],
    ignore_index=True,
)

### Predictions on the Test and Training Sets

In [None]:
# Predict on the PTBDB splits: this model was fine-tuned and evaluated on
# PTBDB, so feeding it the MITBIH arrays (as before) produced meaningless output.
save_predictions(fine_tuned_model, X_test_pt, 'attention_rnn', 'ptbdb_frozen_test')
save_predictions(fine_tuned_model, X_train_pt, 'attention_rnn', 'ptbdb_frozen_train')

## Whole Model

### Define model

#### W&B setup

In [None]:
# Architecture hyper-parameters (num_filters, num_blocks_list, kernel_sizes,
# ndim, nheads) MUST equal those of the MITBIH run: the model built from this
# config is initialised with load_weights(file_path_mit), which requires the
# same layers and weight shapes. Only the training settings may differ.
run = wandb.init(
    project='ML4HC-project4',
    config={
        'learning_rate': 2e-3,
        'lr_decay_steps': 256,
        'lr_decay_rate': 0.96,
        'num_filters': [32, 32, 256],
        'num_blocks_list': [2, 2, 2],
        'kernel_sizes': [5, 3, 3],
        'epochs': 4,
        'batch_size': 32,
        'loss_function': 'sparse_categorical_crossentropy',
        'architecture': 'Attention',
        'dataset': 'PTBDB',
        'mode': 'whole',
        'ndim': 256,
        'nheads': 3
    })

config_pt_whole = wandb.config

#### Set-up model

In [None]:
# Rebuild the MITBIH-shaped network (5-class head) so the MITBIH weights can be
# loaded into it; prepare_model=False is passed as in the original call
# (its exact effect lives in model_parts — TODO confirm).
pretrained_model = get_attention_model(
    config_pt_whole, nclass=5, prepare_model=False)

In [None]:
file_path_pt_whole = './models/attention/attention_pt_whole.h5'

# Save the model with the best validation macro-F1 of the whole-model run.
checkpoint = ModelCheckpoint(
    filepath=file_path_pt_whole,
    monitor='val_f1-macro',
    mode='max',
    save_best_only=True,
    verbose=1,
)
# Stop once validation macro-F1 has not improved for 5 epochs.
early = EarlyStopping(
    monitor='val_f1-macro',
    mode='max',
    patience=5,
    verbose=1,
)
# Reduce the learning rate after 3 epochs without validation-accuracy gains.
redonplat = ReduceLROnPlateau(
    monitor='val_acc',
    mode='max',
    patience=3,
    verbose=2,
)

#### Initialize model

In [None]:
# Initialise from the best MITBIH checkpoint; the network built above must
# match the MITBIH model layer for layer for this to succeed.
pretrained_model.load_weights(filepath=file_path_mit)

#### Modify the model

In [None]:
# Adapt the MITBIH-pretrained network for PTBDB. The behaviour presumably
# depends on config_pt_whole.mode ('whole') — see model_parts.modify_model.
fine_tuned_model = modify_model(pretrained_model, config_pt_whole)

#### Train model

In [None]:
# Train the whole-mode model on PTBDB; TqdmCallback shows progress, 10% of the
# training data is held out for the val_* metrics. The History is not kept.
fine_tuned_model.fit(
    x=X_train_pt,
    y=y_train_pt,
    batch_size=config_pt_whole.batch_size,
    epochs=config_pt_whole.epochs,
    validation_split=0.1,
    callbacks=[
        checkpoint,
        early,
        redonplat,
        WandbCallback(),
        TqdmCallback(verbose=1),
    ],
    verbose=0,
)

### Evaluation

In [None]:
# Evaluate the best whole-model checkpoint (highest val macro-F1) rather than
# whatever weights the final epoch left in memory.
fine_tuned_model.load_weights(file_path_pt_whole)

loss, acc, f1_weighted, f1_macro, b_acc = fine_tuned_model.evaluate(
    X_test_pt, y_test_pt, batch_size=512, verbose=2)

# One wandb.log call: every call advances the W&B step, so separate calls
# would scatter the four test metrics across four steps.
wandb.log({
    'Test Accuracy': acc,
    'Test F1 Weighted': f1_weighted,
    'Test F1 Macro': f1_macro,
    'Test Balanced Accuracy': b_acc,
})

In [None]:
# Close the whole-model PTBDB run; run.finish() supersedes the deprecated run.join().
run.finish()

In [None]:
# Append the whole-model PTBDB test metrics as a new row; ignore_index rebuilds a 0..n-1 index.
results = pd.concat(
    [
        results,
        pd.DataFrame(
            [['PTBDB-whole', acc, f1_weighted, f1_macro, b_acc]],
            columns=results.columns,
        ),
    ],
    ignore_index=True,
)

# Save All Results

In [None]:
# The file name records the epochs / batch size of the frozen PTBDB run
# (the MITBIH run used its own settings from config_mit).
results.to_csv(
    f'./output/attention_results_epochs{config_pt_frozen.epochs}'
    f'_batch{config_pt_frozen.batch_size}.csv')

### Predictions on the Test and Training Sets

In [None]:
# Predict on the PTBDB splits: this model was fine-tuned and evaluated on
# PTBDB, so feeding it the MITBIH arrays (as before) produced meaningless output.
save_predictions(fine_tuned_model, X_test_pt, 'attention_rnn', 'ptbdb_whole_test')
save_predictions(fine_tuned_model, X_train_pt, 'attention_rnn', 'ptbdb_whole_train')