# CNN Transfer Learning - PTB Dataset

Transfer learning from the best-performing CNN8 model, trained on the MIT-BIH dataset, to the PTB dataset for binary myocardial infarction (MI) detection.

**Approach:**
- Load pretrained CNN8 model (trained on MIT-BIH)
- Freeze convolutional layers (first 4 residual blocks)
- Unfreeze last residual block for fine-tuning
- Add new classifier layers adapted for binary classification

Results contribute to Tables 9 and 10 in Rendering 2.

In [None]:
import numpy as np
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
%matplotlib inline 

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Input, Conv1D, MaxPooling1D, Flatten, Add, ReLU, LSTM, Reshape, Concatenate, Activation
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Precision, Recall
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.regularizers import l1_l2

from sklearn.metrics import accuracy_score, classification_report, f1_score
import numpy as np

from pathlib import Path
import re 

import pickle
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
print(tf.config.list_physical_devices('GPU'))  # should show []
from contextlib import redirect_stdout
import json
from collections import Counter

from src.visualization.visualization import plot_training_history 
from src.visualization.confusion_matrix import plot_confusion_matrix

In [None]:
# --- Experiment configuration ---
SAMPLING_METHOD = "SMOTE"
REMOVE_OUTLIERS = False
OUTPUT_PATH = "models/PTB_04_02_dl_models/CNN8_TRANSFER/"
REPORTS_PATH = "reports/deep_learning/cnn8_transfer/"
results_csv = f"{REPORTS_PATH}09_DL_model_optimization.csv"
model_names = ["cnn8_sm"]
models = {name: {} for name in model_names}

# --- Load the PTB splits (plain train, resampled "_sm" train, validation, test) ---
X_ptb_train, y_ptb_train = (
    pd.read_csv('data/processed/ptb/X_ptb_train.csv'),
    pd.read_csv('data/processed/ptb/y_ptb_train.csv'),
)
X_ptb_train_sm, y_ptb_train_sm = (
    pd.read_csv('data/processed/ptb/X_ptb_train_sm.csv'),
    pd.read_csv('data/processed/ptb/y_ptb_train_sm.csv'),
)
X_ptb_val, y_ptb_val = (
    pd.read_csv('data/processed/ptb/X_ptb_val.csv'),
    pd.read_csv('data/processed/ptb/y_ptb_val.csv'),
)
X_ptb_test, y_ptb_test = (
    pd.read_csv('data/processed/ptb/X_ptb_test.csv'),
    pd.read_csv('data/processed/ptb/y_ptb_test.csv'),
)

# Show the shape of every loaded frame (features first, then labels, per split)
for split in (
    X_ptb_train, y_ptb_train,
    X_ptb_train_sm, y_ptb_train_sm,
    X_ptb_val, y_ptb_val,
    X_ptb_test, y_ptb_test,
):
    display(split.shape)

# Append a trailing axis to each feature table so it can be fed to the 1D CNN
X_ptb_train_cnn, X_ptb_train_sm_cnn, X_ptb_val_cnn, X_ptb_test_cnn = (
    np.expand_dims(features, axis=2)
    for features in (X_ptb_train, X_ptb_train_sm, X_ptb_val, X_ptb_test)
)

# Show the shapes again, now with the reshaped feature arrays
for split in (
    X_ptb_train_cnn, y_ptb_train,
    X_ptb_train_sm_cnn, y_ptb_train_sm,
    X_ptb_val_cnn, y_ptb_val,
    X_ptb_test_cnn, y_ptb_test,
):
    display(split.shape)


def parse_epoch_from_name(name, default_epochs=512):
    """Return the epoch number embedded in ``name``.

    The first ``epoch_<digits>`` occurrence in ``name`` is used and its digits
    are returned as an ``int``. When ``name`` contains no such pattern,
    ``default_epochs`` is returned unchanged.
    """
    match = re.search(r"epoch_(\d+)", name)
    if match is None:
        return default_epochs
    return int(match.group(1))

def parse_val_loss_from_name(name):
    """Return the validation loss embedded in ``name``.

    The first ``valloss_<digits>.<digits>`` occurrence in ``name`` is parsed
    as a ``float`` (e.g. ``..._valloss_0.1234.keras`` -> ``0.1234``). When
    ``name`` contains no such pattern, ``np.nan`` is returned.
    """
    match = re.search(r"valloss_([0-9]+\.[0-9]+)", name)
    if match is None:
        return np.nan
    return float(match.group(1))

In [None]:
# Restore the saved CNN8 checkpoint and print its architecture
model_trained = load_model('models/MIT_02_03_dl_models/CNN/cnn8_sm_BS512_best.keras')
model_trained.summary()

# Scan the layers back to front: the first MaxPooling1D encountered is the
# last pooling layer of the network (None when the model has none).
last_pool_layer = next(
    (
        candidate
        for candidate in reversed(model_trained.layers)
        if isinstance(candidate, MaxPooling1D)
    ),
    None,
)

if last_pool_layer is None:
    raise ValueError("Model has no MaxPooling1D layer!")

print("Using pooling layer:", last_pool_layer.name)

# Feature extractor = everything from the original input up to that pooling layer
feature_extractor = Model(inputs=model_trained.input, outputs=last_pool_layer.output)

In [None]:
# Mark every layer of the feature extractor as non-trainable so that only the
# newly added classifier layers receive weight updates.
for base_layer in feature_extractor.layers:
    base_layer.trainable = False

In [None]:
# Symbolic input of the pretrained model and output of the feature extractor
input_layer, output = model_trained.input, feature_extractor.output

In [None]:
def _build_transfer_head(features, dropout_rate, use_batch_norm=False):
    """Stack a fresh 2-class softmax classifier on top of ``features``.

    Two head layouts are supported:

    * plain (``use_batch_norm=False``):
      Flatten -> Dense(32, relu) -> Dense(32, relu) -> Dropout -> Dense(2, softmax)
    * batch-norm (``use_batch_norm=True``):
      Flatten -> 2 x [Dense(32) -> BatchNormalization -> ReLU -> Dropout]
      -> Dense(2, softmax)

    Args:
        features: Keras tensor the head is attached to.
        dropout_rate: Rate passed to every ``Dropout`` layer of the head.
        use_batch_norm: Selects the batch-norm layout instead of the plain one.

    Returns:
        The output tensor of the final ``Dense(2, activation='softmax')`` layer.
    """
    x = Flatten()(features)
    if use_batch_norm:
        for _ in range(2):
            x = Dense(32)(x)
            x = BatchNormalization()(x)
            x = Activation('relu')(x)
            x = Dropout(dropout_rate)(x)
    else:
        x = Dense(32, activation='relu')(x)
        x = Dense(32, activation='relu')(x)
        x = Dropout(dropout_rate)(x)
    return Dense(2, activation='softmax')(x)


# All variants below are wired to the same `input_layer` / `output` tensors and
# differ only in the classifier head (layout and dropout rate).

# transfer 2: plain head, dropout 0.1
output_layer_2 = _build_transfer_head(output, dropout_rate=0.1)
transfer_model_2 = Model(inputs=input_layer, outputs=output_layer_2)

# transfer 3: plain head, dropout 0.4
output_layer_3 = _build_transfer_head(output, dropout_rate=0.4)
transfer_model_3 = Model(inputs=input_layer, outputs=output_layer_3)

# transfer 4: batch-norm head, dropout 0.3
output_layer_4 = _build_transfer_head(output, dropout_rate=0.3, use_batch_norm=True)
transfer_model_4 = Model(inputs=input_layer, outputs=output_layer_4)

# transfer 5: batch-norm head, dropout 0.4
output_layer_5 = _build_transfer_head(output, dropout_rate=0.4, use_batch_norm=True)
transfer_model_5 = Model(inputs=input_layer, outputs=output_layer_5)

# transfer 6: plain head, dropout 0.3
output_layer_6 = _build_transfer_head(output, dropout_rate=0.3)
transfer_model_6 = Model(inputs=input_layer, outputs=output_layer_6)

# transfer 7: batch-norm head, dropout 0.1
output_layer_7 = _build_transfer_head(output, dropout_rate=0.1, use_batch_norm=True)
transfer_model_7 = Model(inputs=input_layer, outputs=output_layer_7)

# transfer 8: batch-norm head, dropout 0.2
output_layer_8 = _build_transfer_head(output, dropout_rate=0.2, use_batch_norm=True)
transfer_model_8 = Model(inputs=input_layer, outputs=output_layer_8)

In [None]:
# Stand-alone 2-class model: a new Input feeds the feature extractor (called
# with training=False), followed by the same head layout as "transfer 6".
input_layer = Input(shape=(187, 1))  # same shape as the original model input
x = feature_extractor(input_layer, training=False)

# Classifier head: Flatten -> Dense(32) -> Dense(32) -> Dropout(0.3)
for head_layer in (
    Flatten(),
    Dense(32, activation='relu'),
    Dense(32, activation='relu'),
    Dropout(0.3),
):
    x = head_layer(x)

output_layer_2class = Dense(2, activation='softmax')(x)
transfer_model_2class = Model(inputs=input_layer, outputs=output_layer_2class)

In [None]:
# Registry of all candidate models, keyed by the name used for saved files and
# result rows; iteration order is the training order.
models = dict(
    transfer_model_2=transfer_model_2,
    transfer_model_3=transfer_model_3,
    transfer_model_4=transfer_model_4,
    transfer_model_5=transfer_model_5,
    transfer_model_6=transfer_model_6,
    transfer_model_7=transfer_model_7,
    transfer_model_8=transfer_model_8,
    transfer_model_2class=transfer_model_2class,
)

In [None]:
def append_metrics_to_csv(metrics: dict, csv_path: str):
    """Insert or replace one result row in the CSV at ``csv_path``.

    A row is identified by the key
    (model_name, batch_size, training_size, lr_start, lr_schedule).
    Existing rows whose key equals the key of ``metrics`` are dropped and the
    new row is appended; all other rows are kept. When the file does not exist
    yet it is created, including any missing parent directories.

    Args:
        metrics: Flat mapping of column name -> value for the new row.
        csv_path: Path of the CSV file to create or update.

    Raises:
        KeyError: If the file exists and a key column is present neither in
            the file nor in ``metrics``.
    """
    # identifier columns for uniqueness
    key_cols = ["model_name", "batch_size", "training_size", "lr_start", "lr_schedule"]

    df_new = pd.DataFrame([metrics])

    # Create the target directory up front: to_csv does not do it, and failing
    # here would discard the metrics of a run that has already finished.
    parent_dir = os.path.dirname(csv_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    if not os.path.exists(csv_path):
        # first result: create a new CSV
        df_new.to_csv(csv_path, index=False)
        return

    df_existing = pd.read_csv(csv_path)

    # align both frames on the union of their columns (alphabetical order)
    all_cols = sorted(set(df_existing.columns).union(df_new.columns))
    df_existing = df_existing.reindex(columns=all_cols)
    df_new = df_new.reindex(columns=all_cols)

    # a row matches only if it agrees with the new row on every key column
    new_row = df_new.iloc[0]
    mask_match = np.ones(len(df_existing), dtype=bool)
    for col in key_cols:
        mask_match &= (df_existing[col] == new_row[col]).to_numpy()

    # drop the matching row(s), append the new one and write everything back
    df_combined = pd.concat([df_existing[~mask_match], df_new], ignore_index=True)
    df_combined.to_csv(csv_path, index=False)


In [None]:
# Train every registered model on the resampled PTB training set, evaluate it
# on the PTB test set, and persist metrics, training history and plots.

# Hyper-parameters shared by all runs
BATCH_SIZE = 512
EPOCHS = 500
initial_learning_rate = 1e-3

# Create the output locations before any (long) training run starts; the
# pickle and figure writes below do not create directories themselves.
os.makedirs(OUTPUT_PATH, exist_ok=True)
os.makedirs(REPORTS_PATH, exist_ok=True)

for model_name, model in models.items():
    print("*"*80)
    print("*"*5,'\t',model_name,"*"*5)
    print("*"*80)

    # A fresh schedule/optimizer per model so no optimizer state is shared
    lr_schedule = ExponentialDecay(
        initial_learning_rate,
        decay_steps=1000,
        decay_rate=0.96)

    # Early stopping
    early_stop = EarlyStopping(
        monitor='val_loss',        # what to monitor
        patience=20,               # how many epochs with no improvement before stopping
        restore_best_weights=True,
        min_delta=0.001            # changes smaller than this do not count as improvement
    )

    # Compile with the exponentially decaying learning rate
    model.compile(
        optimizer=Adam(learning_rate=lr_schedule),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])

    # Save the best model (lowest val_loss); file name records model and batch size
    checkpoint = ModelCheckpoint(
        filepath=f'{OUTPUT_PATH}{model_name}_BS{BATCH_SIZE}_best.keras',
        monitor='val_loss',            # metric to monitor
        mode='min',                    # minimize loss
        save_best_only=True,
        verbose=1                      # print message when a model is saved
    )

    # Training
    history = model.fit(
        X_ptb_train_sm_cnn,  # PTB data
        y_ptb_train_sm,      # PTB data
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_data=(X_ptb_val_cnn, y_ptb_val),  # PTB data
        callbacks=[checkpoint, early_stop]
    )

    # ----------------------------
    # 1) Basic training info
    # ----------------------------
    hist = history.history
    n_epochs = len(hist["loss"])

    # best epoch according to val_loss (0-based index into the history lists)
    best_epoch = int(np.argmin(hist["val_loss"]))
    last_epoch = n_epochs - 1

    # ----------------------------
    # 2) Predictions for test
    # ----------------------------

    # !!! make sure to set restore_best_weights=True in early stopping
    # or load best model before predicting
    # Use the reshaped test array so the input matches what the model was
    # trained and validated on.
    y_prob = model.predict(X_ptb_test_cnn)
    y_pred = np.argmax(y_prob, axis=1)

    # ----------------------------
    # 3) F1-macro and per-class F1 on TEST
    # ----------------------------
    f1_macro = f1_score(y_ptb_test, y_pred, average="macro")
    f1_per_class = f1_score(y_ptb_test, y_pred, average=None)  # one score per class
    acc = accuracy_score(y_ptb_test, y_pred)

    # ----------------------------
    # 4) Build output dictionary
    # ----------------------------
    # flatten F1-per-class into separate columns
    f1_class_columns = {
        f"test_f1_class_{i}": float(score)
        for i, score in enumerate(f1_per_class)
    }

    metrics = {
        "model_name": model_name,
        "batch_size": BATCH_SIZE,
        "training_size": X_ptb_train_sm_cnn.shape[0],
        "lr_start": initial_learning_rate,
        "lr_schedule": 'EXP_DECAY',
        "best_epoch": best_epoch,
        "last_epoch": last_epoch,

        # best epoch values (from history)
        "train_loss_best": float(hist["loss"][best_epoch]),
        "val_loss_best": float(hist["val_loss"][best_epoch]),
        "train_acc_best": float(hist["accuracy"][best_epoch]),
        "val_acc_best": float(hist["val_accuracy"][best_epoch]),

        # last epoch values
        "train_loss_last": float(hist["loss"][last_epoch]),
        "val_loss_last": float(hist["val_loss"][last_epoch]),
        "train_acc_last": float(hist["accuracy"][last_epoch]),
        "val_acc_last": float(hist["val_accuracy"][last_epoch]),

        # TEST metrics
        "test_f1_macro": float(f1_macro),
        "test_accuracy": float(acc)
    }

    # merge F1-per-class columns
    metrics.update(f1_class_columns)

    append_metrics_to_csv(metrics, csv_path=results_csv)

    # Persist the raw training history next to the saved model
    with open(f"{OUTPUT_PATH}{model_name}_BS{BATCH_SIZE}_{initial_learning_rate}_full.pkl", "wb") as f: #change for model
        pickle.dump(history.history, f)

    fig_cm, ax_cm = plot_confusion_matrix(
        y_true=y_ptb_test,
        y_pred=y_pred,
        normalize=True,
        class_names=["1","2"],
        title=f"Confusion Matrix — {model_name}"
    )

    fig_cm.savefig(
        f"{REPORTS_PATH}/{model_name}_BS{BATCH_SIZE}_{initial_learning_rate}_confusion_matrix.png",
        dpi=300,
        bbox_inches="tight"
    )
    plt.close(fig_cm)  # release the figure before the next model is trained

    plot_training_history(
        history=history,                     # raw history
        save_dir=REPORTS_PATH,               # where plots go
        prefix=f"{model_name}_BS{BATCH_SIZE}_{initial_learning_rate}_training_history"  # prefix for filenames
    )
    