In [2]:
import wfdb
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras import layers, models, callbacks, utils, Input

# --- PARAMETRI ---
RECORD_PATH = "C:\\Users\\Fefo\\Desktop\\TMR\\mit-bih-arrhythmia-database-1.0.0"
RECORDS = ['100', '101', '102', '103', '104', '105', '106', '107', '108', '109',
           '111', '112', '113', '114', '115', '116', '117', '118', '119', '121',
           '122', '123', '124', '200', '201', '202', '203', '205', '207', '208',
           '209', '210', '212', '213', '214', '215', '217', '219', '220', '221',
           '222', '223', '228', '230', '231', '232', '233', '234']
WINDOW_SIZE = 1080  # 3 secondi a 360 Hz
TARGET_CLASSES = {'N': 'Normale', 'A': 'AFIB', 'V': 'PVC', 'F': 'Altro'}

# --- ESTRAZIONE BATTITI ---
def extract_labeled_beats(record_list):
    segments, labels = [], []
    for rec in record_list:
        try:
            record = wfdb.rdrecord(f'{RECORD_PATH}/{rec}')
            ann = wfdb.rdann(f'{RECORD_PATH}/{rec}', 'atr')
            signal = record.p_signal[:, 0]
            for i, sym in enumerate(ann.symbol):
                if sym in TARGET_CLASSES:
                    idx = ann.sample[i]
                    if idx - WINDOW_SIZE//2 >= 0 and idx + WINDOW_SIZE//2 <= len(signal):
                        segment = signal[idx - WINDOW_SIZE//2: idx + WINDOW_SIZE//2]
                        segments.append(segment)
                        labels.append(TARGET_CLASSES[sym])
        except Exception as e:
            print(f"Errore nel record {rec}: {e}")
            continue
    return np.array(segments), np.array(labels)

# --- DATI ---
X, y = extract_labeled_beats(RECORDS)
print(f"\nTotale segmenti estratti: {len(X)}")
print(f"Distribuzione classi: {np.unique(y, return_counts=True)}")

# --- NORMALIZZAZIONE ---
X = (X - X.mean(axis=1, keepdims=True)) / (X.std(axis=1, keepdims=True) + 1e-8)
X = X[..., np.newaxis]

# --- ENCODING ---
encoder = LabelEncoder()
y_encoded = encoder.fit_transform(y)
y_categorical = utils.to_categorical(y_encoded)

# --- SPLIT ---
X_train, X_test, y_train, y_test, y_encoded_train, y_encoded_test = train_test_split(
    X, y_categorical, y_encoded, stratify=y_encoded, test_size=0.2, random_state=42
)

# --- CLASS WEIGHTS ---
weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_encoded_train), y=y_encoded_train)
class_weights = dict(enumerate(weights))

# --- MODELLO CNN + BiLSTM ---
model = models.Sequential([
    Input(shape=(WINDOW_SIZE, 1)),
    layers.Conv1D(64, kernel_size=7, activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling1D(2),

    layers.Conv1D(128, kernel_size=5, activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling1D(2),

    layers.Bidirectional(layers.LSTM(64)),
    layers.Dropout(0.5),
    layers.Dense(64, activation='relu'),
    layers.Dense(4, activation='softmax')
])

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# --- CALLBACKS ---
early_stop = callbacks.EarlyStopping(patience=5, restore_best_weights=True)

# --- TRAINING ---
history = model.fit(
    X_train, y_train,
    epochs=30,
    batch_size=64,
    validation_data=(X_test, y_test),
    class_weight=class_weights,
    callbacks=[early_stop],
    verbose=1
)

# --- VALUTAZIONE ---
y_pred_prob = model.predict(X_test)
y_pred = np.argmax(y_pred_prob, axis=1)
y_true = np.argmax(y_test, axis=1)

# --- CONFUSION MATRIX ---
cm = confusion_matrix(y_true, y_pred)
labels = encoder.classes_

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=labels, yticklabels=labels)
plt.xlabel('Predetto')
plt.ylabel('Reale')
plt.title('Confusion Matrix - CNN + BiLSTM (finestra 3s)')
plt.tight_layout()
plt.show()

# --- CLASSIFICATION REPORT ---
report = classification_report(y_true, y_pred, target_names=labels, digits=3)
print("\nREPORT DI CLASSIFICAZIONE:\n")
print(report)



Totale segmenti estratti: 85392
Distribuzione classi: (array(['AFIB', 'Altro', 'Normale', 'PVC'], dtype='<U7'), array([ 2544,   802, 74923,  7123]))
Epoch 1/30
[1m1068/1068[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 122ms/step - accuracy: 0.4141 - loss: 1.1149

KeyboardInterrupt: 