In [1]:
# Only needed when running inside the project's Docker container: uncomment to cd into this
# notebook's directory (presumably so that `from src import crtnet_models` resolves — confirm).
#%cd /app/notebooks/enhance_crt_net/

### CRT-Net MIT-BIH Simple (200 samples, 2 leads, 5 classes)

In [4]:
from importlib import reload

import tensorflow as tf

from src import crtnet_models

# Re-import the model module so edits to src/crtnet_models.py are picked up without a kernel restart.
reload(crtnet_models)

# MIT-BIH "simple" CRT-Net: 200-sample windows, 2 leads, 5 classes, a single VGG block.
mitbih_simple_kwargs = dict(
    n_classes=5,
    input_shape=(200, 2),
    n_vgg_blocks=1,
    # binary=True is for multilabel output (disables softmax and categorical cross entropy);
    # MIT-BIH is not multilabel, so it stays off.
    binary=False,
    # Focal cross entropy, to address the significant class imbalance.
    use_focal=True,
    # F1 may be the better metric to monitor if using early stopping.
    metrics=['accuracy', 'f1_score'],
    # Default feature dim size (d_ffn set to 2*d_model).
    d_model=128,
)

tf.keras.backend.clear_session()  # reset Keras global state before building
model = crtnet_models.crt_net_original(**mitbih_simple_kwargs)
model.summary()
del model, mitbih_simple_kwargs


Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 200, 2)]          0         
                                                                 
 vgg_net (VGGNet)            (None, 100, 128)          50176     
                                                                 
 bi_gru (BiGRU)              (None, 100, 256)          198144    
                                                                 
 stacked_transformer_encode  (None, 100, 256)          1583104   
 r (StackedTransformerEncod                                      
 er)                                                             
                                                                 
 global_average_pooling1d (  (None, 256)               0         
 GlobalAveragePooling1D)                                         
                                                             

### CRT-Net MIT-BIH Alternate (200 samples, 2 leads, 5 classes)

The provided CRT-Net `models.py` has some alterations, which may be the result of tuning the model:
- Leaky ReLU (alpha=0.3) activation instead of ReLU.
- Dropout (rate=0.2) after every VGG block and after the BiGRU layer.
- Sinusoidal position encoding uses a maximum position encoding of 2048 instead of the default 10000.
- Additional dropout between the transformer encoders and global pooling.
- Additional dense layer before the output (units=4*n_classes, SELU activation).

In [45]:
import tensorflow as tf
from importlib import reload
from src import crtnet_models

reload(crtnet_models)  # pick up edits to src/crtnet_models.py without restarting the kernel

# MIT-BIH "alternate" CRT-Net: same data/loss setup as the simple variant, but built with
# crt_net_original_alt (the tuned architecture described in the markdown above).
tf.keras.backend.clear_session()
model = crtnet_models.crt_net_original_alt(
    # --- data shape ---
    input_shape=(200, 2),   # 200 samples x 2 leads
    n_classes=5,
    # --- architecture ---
    n_vgg_blocks=1,
    d_model=128,            # default feature dim size (d_ffn set to 2*d_model)
    # --- loss / output head ---
    binary=False,           # True only for multilabel output (disables softmax + categorical CE); MIT-BIH is not multilabel
    use_focal=True,         # focal cross entropy for the significant class imbalance
    # --- evaluation ---
    metrics=['accuracy', 'f1_score'],  # F1 may be the better signal if using early stopping
)
model.summary()
del model

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 200, 2)]          0         
                                                                 
 vgg_net (VGGNet)            (None, 100, 128)          443520    
                                                                 
 bi_gru (BiGRU)              (None, 100, 256)          198144    
                                                                 
 stacked_transformer_encode  (None, 100, 256)          1583104   
 r (StackedTransformerEncod                                      
 er)                                                             
                                                                 
 dropout_2 (Dropout)         (None, 100, 256)          0         
                                                                 
 global_average_pooling1d (  (None, 256)               0     

### CRT-Net CPSC Simple (3000 samples, 12 leads, 9 classes)

In [46]:
import tensorflow as tf
from importlib import reload
from src import crtnet_models

# Reload so that changes to src/crtnet_models.py take effect in this running kernel.
reload(crtnet_models)

tf.keras.backend.clear_session()

# CPSC "simple" CRT-Net: 3000-sample, 12-lead records with 9 classes that can co-occur.
cpsc_simple_kwargs = {
    'n_classes': 9,
    'input_shape': (3000, 12),
    # Longer signal, so more CNN blocks to downsample: 3000 / 2**5 -> 94 time steps.
    'n_vgg_blocks': 5,
    # CPSC can be multilabel: binary=True disables softmax and categorical cross entropy.
    'binary': True,
    # Focal cross entropy addresses the significant class imbalance.
    'use_focal': True,
    # F1 may be the better metric to track if using early stopping.
    'metrics': ['accuracy', 'f1_score'],
    # Default feature dim size (d_ffn set to 2*d_model).
    'd_model': 128,
}
model = crtnet_models.crt_net_original(**cpsc_simple_kwargs)
model.summary()
del model, cpsc_simple_kwargs


Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 3000, 12)]        0         
                                                                 
 vgg_net (VGGNet)            (None, 94, 128)           448256    
                                                                 
 bi_gru (BiGRU)              (None, 94, 256)           198144    
                                                                 
 stacked_transformer_encode  (None, 94, 256)           1583104   
 r (StackedTransformerEncod                                      
 er)                                                             
                                                                 
 global_average_pooling1d (  (None, 256)               0         
 GlobalAveragePooling1D)                                         
                                                             

### CRT-Net CPSC Alternate (3000 samples, 12 leads, 9 classes)

The provided CRT-Net `models.py` has some alterations, which may be the result of tuning the model:
- Leaky ReLU (alpha=0.3) activation instead of ReLU.
- Dropout (rate=0.2) after every VGG block and after the BiGRU layer.
- Sinusoidal position encoding uses a maximum position encoding of 2048 instead of the default 10000.
- Additional dropout between the transformer encoders and global pooling.
- Additional dense layer before the output (units=4*n_classes, SELU activation).

In [49]:
from importlib import reload

import tensorflow as tf
from src import crtnet_models

reload(crtnet_models)  # re-import so local edits to the model module apply

tf.keras.backend.clear_session()

# CPSC "alternate" CRT-Net: tuned architecture (crt_net_original_alt) on 3000-sample, 12-lead input.
model = crtnet_models.crt_net_original_alt(
    input_shape=(3000, 12),
    n_classes=9,
    n_vgg_blocks=5,        # longer signal -> more CNN downsampling blocks (3000 / 2**5 -> 94)
    d_model=128,           # default feature dim size (d_ffn set to 2*d_model)
    binary=True,           # CPSC can be multilabel; disables softmax and categorical cross entropy
    use_focal=True,        # focal cross entropy for the significant class imbalance
    metrics=['accuracy', 'f1_score'],  # F1 may be better to monitor if using early stopping
)
model.summary()
del model


Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 3000, 12)]        0         
                                                                 
 vgg_net (VGGNet)            (None, 94, 128)           2414976   
                                                                 
 bi_gru (BiGRU)              (None, 94, 256)           198144    
                                                                 
 stacked_transformer_encode  (None, 94, 256)           1583104   
 r (StackedTransformerEncod                                      
 er)                                                             
                                                                 
 dropout_6 (Dropout)         (None, 94, 256)           0         
                                                                 
 global_average_pooling1d (  (None, 256)               0     

### CRT-Net Modular (Ablation & Experimental Architecture)

Notes for ablation:

- It may be useful to evaluate the model with VGGNet, VGGNet+BiGRU, Transformer, and VGGNet+Transformer, in addition to the standard VGGNet+BiGRU+Transformer.
  - Testing BiGRU without VGGNet to downsample may be very slow for long signals, so we might skip this.

Notes for experimental architecture:
- Options for cnn (squeezenet, cnnsvm) and att (rwkv) may be implemented and evaluated in comparison to the original architecture. (RWKV is already wired up; see the RWKV example at the end of this section.)

In [5]:
import tensorflow as tf
from importlib import reload
from src import crtnet_models
reload(crtnet_models)

# crt_net_modular lets each stage be swapped out, or disabled with None, for ablation.
#
# cnn_type options:
#  None
#  'vggnet'
#  'squeezenet'  (not impl)
#  'cnnsvm'      (not impl)
#
# rnn_type options:
#  None
#  'bigru'
#
# att_type options:
#  None
#  'transformer'
#  'rwkv'        (see the RWKV example below; stack depth via rkwv_stack_multiplier)

# Ablation example: VGGNet only (no BiGRU, no attention stage) on CPSC-shaped input.
tf.keras.backend.clear_session()
model = crtnet_models.crt_net_modular(
    n_classes=9,
    input_shape=(3000,12),
    n_vgg_blocks=5,

    cnn_type='vggnet',
    rnn_type=None, # disabled rnn
    att_type=None, # disabled attention stage (no transformer / rwkv)

    alternate_arch=True, # alternate modifications like leaky relu enabled

    binary=True,    # multilabel output (CPSC can be multilabel)
    use_focal=True, # focal cross entropy for class imbalance
    # metrics and d_model are not passed here, so the builder's defaults apply.
)
model.summary()
del model


Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 3000, 12)]        0         
                                                                 
 vgg_net (VGGNet)            (None, 94, 128)           2414976   
                                                                 
 global_average_pooling1d (  (None, 128)               0         
 GlobalAveragePooling1D)                                         
                                                                 
 dense (Dense)               (None, 18)                2322      
                                                                 
 dense_1 (Dense)             (None, 9)                 171       
                                                                 
Total params: 2417469 (9.22 MB)
Trainable params: 2417469 (9.22 MB)
Non-trainable params: 0 (0.00 Byte)
_______________________

With RWKV (this takes a while to compile: about 1 minute in my Docker container):

In [5]:
import tensorflow as tf
from importlib import reload
from src import crtnet_models

reload(crtnet_models)  # pick up any edits to src/crtnet_models.py

# Full modular pipeline with RWKV in place of the stacked transformer:
# VGGNet -> BiGRU -> RWKV, with the alternate modifications enabled.
rwkv_kwargs = dict(
    n_classes=9,
    input_shape=(3000, 12),
    n_vgg_blocks=5,
    cnn_type='vggnet',
    rnn_type='bigru',
    att_type='rwkv',  # RWKV instead of the stacked transformer
    # RWKV allows deeper stacks. If this runs well, try 4 (the original note estimates
    # 4x4=16 total RWKV layers and ~16M params).
    # NOTE(review): "rkwv" looks like a transposition of "rwkv"; the keyword must match the
    # builder's parameter name exactly, so confirm in crtnet_models before renaming it.
    rkwv_stack_multiplier=1,
    alternate_arch=True,  # alternate modifications like leaky relu enabled
    binary=True,
    use_focal=True,
)

tf.keras.backend.clear_session()
model = crtnet_models.crt_net_modular(**rwkv_kwargs)
model.summary()
del model, rwkv_kwargs


Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 3000, 12)]        0         
                                                                 
 vgg_net (VGGNet)            (None, 94, 128)           2414976   
                                                                 
 bi_gru (BiGRU)              (None, 94, 256)           198144    
                                                                 
 stacked_rwkv (StackedRWKV)  (None, 94, 256)           3420160   
                                                                 
 global_average_pooling1d (  (None, 256)               0         
 GlobalAveragePooling1D)                                         
                                                                 
 dense (Dense)               (None, 18)                4626      
                                                             