In [None]:
cd /tf/

In [None]:
# Colab-only: mount Google Drive so the /content/drive/... paths used below resolve.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import xml.etree.ElementTree as ET
import numpy as np

In [None]:
# NumPy array stored on Drive for this recording; its shape is printed below.
# NOTE(review): `ecg_sig` is not used by any later cell, and the XML-parsing
# cells need the recording's XML file, not this .npy.
ecg_path = '/content/drive/MyDrive/지정원_심전도/22619027.npy'
ecg_sig = np.load(ecg_path)

print(ecg_sig.shape)

In [None]:
def get_text(root, section_name, tag):
    """Return the text of ``<tag>`` inside the first ``<section_name>`` found under *root*.

    The section is searched at any depth; the tag must be a direct child of it.
    Returns None when either element is missing (or the tag has no text).
    """
    section = root.find(f".//{section_name}")
    if section is None:
        return None
    node = section.find(tag)
    return None if node is None else node.text

In [None]:
# Parse the recording's XML. This previously parsed `ecg_path`, which points at a
# binary .npy file that ET.parse cannot read.
# NOTE(review): assumes the XML sits next to the .npy with the same file stem —
# confirm the real location. Later cells also read `xml_path`.
xml_path = ecg_path.rsplit(".", 1)[0] + ".xml"
tree = ET.parse(xml_path)
root = tree.getroot()

PatientID          = get_text(root, "PatientDemographics", "PatientID")
gender             = get_text(root, "PatientDemographics", "Gender")
DateofBirth        = get_text(root, "PatientDemographics", "DateofBirth")
age                = get_text(root, "PatientDemographics", "PatientAge")

# NOTE(review): these prints write patient identifiers (ID, date of birth) into
# the saved notebook output — clear or redact outputs before sharing.
print("환자 정보")
print(f"- Patient ID            : {PatientID}")
print(f"- Gender                : {gender}")
print(f"- Date of Birth         : {DateofBirth}")
print(f"- Age                   : {age}")


In [None]:
# List every <Waveform>'s metadata, skipping the bulky <LeadData> children.
# The loop body used to be entirely commented out, which is a SyntaxError; the
# prints are restored so the header below is followed by the promised listing.
waveforms = root.findall("Waveform")
print("모든 Waveform 정보 (LeadData 제외):\n")

for idx, wf in enumerate(waveforms):
    print(f"[Waveform {idx+1}]")
    for child in wf:
        if child.tag != "LeadData":
            print(f"{child.tag}: {child.text}")
    print("-" * 40)

In [None]:
# Pick the waveform explicitly tagged "Rhythm" (as load_lead_II_waveform does)
# instead of trusting its position in the file.
rhythm_waveform = next(
    (wf for wf in waveforms if wf.findtext("WaveformType") == "Rhythm"),
    None,
)
if rhythm_waveform is None:
    # No tagged Rhythm waveform: fall back to the previous hard-coded second entry.
    rhythm_waveform = waveforms[1]

In [None]:
# LeadID of every <LeadData> child of the rhythm waveform, in document order.
lead_ids = [
    node.findtext("LeadID")
    for node in rhythm_waveform
    if node.tag == "LeadData"
]

In [None]:
# First <LeadData> of the rhythm waveform whose LeadID is "II" (None if absent).
lead_2 = next(
    (node for node in rhythm_waveform.findall("LeadData") if node.findtext("LeadID") == "II"),
    None,
)

In [None]:
# Show every field of the lead-II record except the base64 payload.
# The else-branch previously held only a comment, which is a SyntaxError.
if lead_2 is not None:
    print("II LeadData (WaveFormData 제외):")
    for elem in lead_2:
        if elem.tag != "WaveFormData":
            print(f"{elem.tag}: {elem.text}")
else:
    print("II 가 존재하지 않습니다.")

In [None]:
import base64
import numpy as np

In [None]:
# Raw base64 payload of lead II. Only a short preview is shown: echoing the whole
# string floods the notebook output. Guarded because lead_2 may be None.
encoded_data = lead_2.findtext("WaveFormData") if lead_2 is not None else None
if encoded_data is not None:
    print(f"WaveFormData: {len(encoded_data)} base64 chars")
    print(encoded_data[:80] + "...")

In [None]:
# Decode lead II into an amplitude array. The else-branch previously held only a
# comment (SyntaxError); the diagnostic prints are restored.
if lead_2 is not None:
    encoded_data = lead_2.findtext("WaveFormData")
    decoded_bytes = base64.b64decode(encoded_data)

    # Restore the int16 sample array from the decoded bytes.
    waveform_raw = np.frombuffer(decoded_bytes, dtype=np.int16)

    # Convert to μV using the per-bit amplitude scale stored with the lead.
    scale = float(lead_2.findtext("LeadAmplitudeUnitsPerBit"))
    waveform_uv = waveform_raw * scale

    print(f"복원된 샘플 수: {waveform_uv.shape[0]}")
    print(f"μV 단위 신호 (앞 10개): {waveform_uv[:10]}")
else:
    print("Rhythm → LeadID='II' 데이터가 없습니다.")

In [None]:
import matplotlib.pyplot as plt

In [None]:

fs = 500  # sampling rate in Hz — hard-coded; NOTE(review): confirm against the XML
# Sample k sits at k / fs seconds. np.linspace(0, n / fs, n) spaced samples
# (n / fs) / (n - 1) apart, which stretched the axis relative to markers placed
# at index / fs.
time = np.arange(len(waveform_uv)) / fs

fig, ax = plt.subplots(figsize=(15, 4))
ax.plot(time, waveform_uv, label="Lead II", color="black")
ax.set_title("ECG Waveform (Lead II, Rhythm)")
ax.set_xlabel("Time (seconds)")
ax.set_ylabel("Amplitude (μV)")
ax.grid(True)
ax.legend()
fig.tight_layout()
plt.show()


In [None]:
# Direct child <QRSTimesTypes> of the document root (None if absent).
qrs_times = root.find("./QRSTimesTypes")

In [None]:
# Collect the QRS annotations (number, type, time) stored in <QRSTimesTypes>.
# The final print loop previously had a comment-only body (SyntaxError), and the
# loop variable `time` overwrote the plotting time-axis array; it is now `qrs_time`.
if qrs_times is not None:
    qrs_list = []
    for qrs in qrs_times.findall("QRS"):
        number = int(qrs.findtext("Number"))
        qrs_type = int(qrs.findtext("Type"))
        qrs_time = int(qrs.findtext("Time"))
        qrs_list.append({
            "Number": number,
            "Type": qrs_type,
            "Time": qrs_time
        })

    global_rr = qrs_times.findtext("GlobalRR")
    qtrggr = qrs_times.findtext("QTRGGR")

    print("QRS 리스트:")
    for entry in qrs_list:
        print(entry)
    print(f"\nGlobalRR: {global_rr}")
    print(f"QTRGGR: {qtrggr}")

In [None]:
def load_lead_II_waveform(xml_file_path):
    """Decode the lead-II rhythm strip from an ECG XML file.

    Uses the first ``<Waveform>`` whose ``<WaveformType>`` is exactly ``"Rhythm"``,
    then the first ``<LeadData>`` in it with ``<LeadID>`` ``"II"``. Its
    ``<WaveFormData>`` is base64-decoded as int16 samples and multiplied by
    ``<LeadAmplitudeUnitsPerBit>`` (μV per the original comment).

    NOTE: a later cell redefines this function (tolerant parser, fallbacks);
    once that cell runs, it shadows this definition.

    Returns:
        (waveform_uv, {"samples": number_of_samples, "scale": scale})

    Raises:
        ValueError: if no Rhythm waveform or no lead II is present.
    """
    doc_root = ET.parse(xml_file_path).getroot()

    rhythm = next(
        (wf for wf in doc_root.findall("Waveform") if wf.findtext("WaveformType") == "Rhythm"),
        None,
    )
    if rhythm is None:
        raise ValueError("Rhythm 타입 Waveform을 찾을 수 없음.")

    lead = next(
        (ld for ld in rhythm.findall("LeadData") if ld.findtext("LeadID") == "II"),
        None,
    )
    if lead is None:
        raise ValueError("Rhythm → LeadID='II' 데이터를 찾을 수 없음.")

    samples = np.frombuffer(base64.b64decode(lead.findtext("WaveFormData")), dtype=np.int16)
    scale = float(lead.findtext("LeadAmplitudeUnitsPerBit"))
    waveform_uv = samples * scale
    return waveform_uv, {"samples": waveform_uv.shape[0], "scale": scale}

In [None]:
# Load lead II for this recording. `ecg_data` was otherwise never defined in the
# notebook (it only existed as leftover kernel state).
# NOTE(review): assumes `xml_path` points at this recording's XML (later cells use it too).
ecg_data, ecg_meta = load_lead_II_waveform(xml_path)

fs = 500
# Sample k sits at k / fs seconds (np.linspace mis-spaced them by n / (n - 1)).
time = np.arange(len(ecg_data)) / fs

# Plot the trace.
fig, ax = plt.subplots(figsize=(15, 4))
ax.plot(time, ecg_data, label="Lead II", color="black")
ax.set_title("ECG Waveform (Lead II, Rhythm)")
ax.set_xlabel("Time (seconds)")
ax.set_ylabel("Amplitude (μV)")
ax.grid(True)
ax.legend()
fig.tight_layout()
plt.show()


In [None]:
def extract_qrs_sample_locations(xml_file_path, scale_divisor=2):
    """Read the ``<QRSTimesTypes>`` annotations and map them to sample indices.

    Each ``<QRS>``'s integer ``<Time>`` is divided by ``scale_divisor`` and truncated
    to int. NOTE(review): the default of 2 presumably converts the annotation time
    base to the waveform's sample grid — confirm against the XML.

    Args:
        xml_file_path: path to the ECG XML file.
        scale_divisor: divisor applied to every ``<Time>`` value (default 2, the
            previously hard-coded value).

    Returns:
        (qrs_sample_locs, qrs_types): parallel lists. Types are the raw ``<Type>``
        text (str or None). Annotations whose ``<Time>`` is missing or not an
        integer are skipped.
    """
    root = ET.parse(xml_file_path).getroot()

    qrs_sample_locs = []
    qrs_types = []

    qrs_times = root.find("QRSTimesTypes")
    if qrs_times is None:
        return qrs_sample_locs, qrs_types

    for qrs in qrs_times.findall("QRS"):
        try:
            sample_idx = int(int(qrs.findtext("Time")) / scale_divisor)
        except (TypeError, ValueError):
            # Missing (None) or malformed <Time>: skip this annotation.
            continue
        qrs_sample_locs.append(sample_idx)
        qrs_types.append(qrs.findtext("Type"))

    return qrs_sample_locs, qrs_types

In [None]:
# QRS annotation positions (sample indices) and their types for this recording.
qrs_locations, qrs_types = extract_qrs_sample_locations(xml_file_path=xml_path)

print(f"검출된 QRS 샘플 인덱스 수: {len(qrs_locations)}")
print(f"샘플 인덱스: {qrs_locations}")


In [None]:

# Overlay the XML's QRS annotation positions on the lead-II trace.
# The time axis is rebuilt locally (the global `time` is reassigned by other
# cells), and annotation indices outside the decoded signal are dropped instead of
# raising IndexError / wrapping around.
t_axis = np.arange(len(ecg_data)) / fs
qrs_idx = np.asarray(qrs_locations, dtype=int)
qrs_idx = qrs_idx[(qrs_idx >= 0) & (qrs_idx < len(ecg_data))]

fig, ax = plt.subplots(figsize=(15, 4))
ax.plot(t_axis, ecg_data, label="ECG (Lead II)", color='black')
ax.scatter(
    qrs_idx / fs,
    ecg_data[qrs_idx],
    color='green', marker='o', label="QRSTimesTypes S-point", zorder=5
)

ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude (μV)")
ax.set_title("ECG Signal")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()


In [None]:
import neurokit2 as nk

In [None]:
fs = 500
# nk.ecg_peaks returns (signals, info); only the info dict with the R-peak indices is kept.
rpeaks = nk.ecg_peaks(ecg_data, sampling_rate=fs)[1]
rpeak = rpeaks['ECG_R_Peaks']

In [None]:

# Compare NeuroKit2's R-peak detections (red) with the XML's QRS annotation
# positions (green) on the same trace. The red series was previously labelled
# "QRSTimesTypes R-peaks", although it comes from nk.ecg_peaks, not the XML.
n_samples = len(ecg_data)
t_axis = np.arange(n_samples) / fs

nk_idx = np.asarray(rpeak, dtype=int)
nk_idx = nk_idx[(nk_idx >= 0) & (nk_idx < n_samples)]
xml_idx = np.asarray(qrs_locations, dtype=int)
xml_idx = xml_idx[(xml_idx >= 0) & (xml_idx < n_samples)]

fig, ax = plt.subplots(figsize=(15, 4))
ax.plot(t_axis, ecg_data, label="ECG (Lead II)", color='black')

ax.scatter(
    nk_idx / fs,
    ecg_data[nk_idx],
    color='red', marker='o', label="NeuroKit2 R-peaks", zorder=5
)

ax.scatter(
    xml_idx / fs,
    ecg_data[xml_idx],
    color='green', marker='o', label="QRSTimesTypes S-point", zorder=5
)

ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude (μV)")
ax.set_title("ECG Signal")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()


In [None]:
def extract_beat_windows(ecg, locs, window_size=250):
    """Cut a window of ``2 * window_size`` samples centred on each location.

    Args:
        ecg: 1-D signal (supports ``len`` and slicing).
        locs: sample indices of the window centres.
        window_size: samples kept on each side of the centre.

    Returns:
        (beats, beat_indices): ``beats`` has shape ``(n_kept, 2 * window_size)``;
        ``beat_indices`` are the positions in ``locs`` whose window fits entirely
        inside ``ecg``.
    """
    beats = []
    beat_indices = []
    for i, p in enumerate(locs):
        start = p - window_size
        end = p + window_size
        # ecg[start:end] is complete whenever end <= len(ecg); the previous `<`
        # dropped a window that ended exactly on the last sample.
        if start >= 0 and end <= len(ecg):
            beats.append(ecg[start:end])
            beat_indices.append(i)
    if not beats:
        # Keep a 2-D result so callers can still concatenate/index along axis 1.
        return np.empty((0, 2 * window_size), dtype=np.asarray(ecg).dtype), beat_indices
    return np.array(beats), beat_indices

In [None]:
# Windows centred on the XML QRS annotation positions (250 samples either side).
beats, beat_indices = extract_beat_windows(ecg_data, qrs_locations, window_size=250)

In [None]:
# Plot the first extracted beat against its offset from the window centre.
# NOTE(review): `beats` were cut around `qrs_locations` (XML QRS annotations), which
# an earlier plot labels "S-point" — confirm whether "R-peak" in the title is right.
window_size = 250
beat = beats[0]
x = np.arange(-window_size, window_size)

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(x, beat)
ax.set_title("ECG Beat (Centered on R-peak)")
ax.set_xlabel("Sample (relative to R-peak)")
ax.set_ylabel("Amplitude (μV)")
ax.grid(True)
fig.tight_layout()
plt.show()


In [None]:

# QRS type of each kept beat (beat_indices point back into qrs_types).
beat_labels = [qrs_types[kept] for kept in beat_indices]
beat_labels


In [None]:
import os
import json
import numpy as np
import base64
import neurokit2 as nk
import xml.etree.ElementTree as ET

In [None]:
def safe_parse_xml(xml_path):
    """Parse an XML file, escaping stray ``&`` characters first.

    Some files contain bare ``&`` characters, which make a strict parse fail.
    Only an ``&`` that does NOT already start one of XML's predefined entities
    (``&amp; &lt; &gt; &quot; &apos;``) or a numeric character reference is
    escaped. The previous blanket ``replace('&', '&amp;')`` also broke valid
    references, e.g. ``&amp;`` became the literal text "&amp;".

    The file is read as bytes so the parser honours the encoding declared in the
    XML prolog instead of assuming UTF-8.

    Returns:
        xml.etree.ElementTree.ElementTree wrapping the parsed root element.

    Raises:
        xml.etree.ElementTree.ParseError: if the document is still malformed.
    """
    import re

    with open(xml_path, "rb") as f:
        raw = f.read()

    raw = re.sub(
        rb"&(?!(?:amp|lt|gt|quot|apos|#[0-9]+|#x[0-9A-Fa-f]+);)",
        b"&amp;",
        raw,
    )
    return ET.ElementTree(ET.fromstring(raw))

In [None]:
def load_lead_II_waveform(xml_file_path):
    """Decode lead II from an ECG XML file into a scaled float array.

    Waveform selection, in order of preference:
      1. the first ``<Waveform>`` whose ``<WaveformType>`` is "rhythm"
         (case-insensitive, surrounding whitespace ignored);
      2. otherwise the first ``<Waveform>`` that has no ``<WaveformType>``
         (single-strip files);
      3. otherwise the first ``<Waveform>``.
    The previous loop stopped at the first untyped waveform even when a later one
    was explicitly "Rhythm", contradicting the "Rhythm first, else fall back" intent.

    Returns:
        (waveform_uv, {"samples": number_of_samples, "scale": scale})

    Raises:
        ValueError: if there is no ``<Waveform>``, no lead II, or no ``<WaveFormData>``.
    """
    root = safe_parse_xml(xml_file_path).getroot()

    # 1. All Waveform elements.
    waveforms = root.findall("Waveform")
    if not waveforms:
        raise ValueError("Waveform 요소를 찾을 수 없습니다.")

    # 2. Choose the waveform (preference order in the docstring).
    types = [wf.findtext("WaveformType") for wf in waveforms]
    rhythm_waveform = next(
        (wf for wf, t in zip(waveforms, types)
         if t is not None and t.strip().lower() == "rhythm"),
        None,
    )
    if rhythm_waveform is None:
        rhythm_waveform = next(
            (wf for wf, t in zip(waveforms, types) if t is None),
            waveforms[0],
        )

    # 3. First LeadData with LeadID "II".
    lead_2 = next(
        (ld for ld in rhythm_waveform.findall("LeadData") if ld.findtext("LeadID") == "II"),
        None,
    )
    if lead_2 is None:
        raise ValueError("LeadID='II' 데이터를 찾을 수 없습니다.")

    # 4. Decode base64 -> int16 samples.
    encoded_data = lead_2.findtext("WaveFormData")
    if encoded_data is None:
        raise ValueError("WaveFormData가 없습니다.")
    waveform_raw = np.frombuffer(base64.b64decode(encoded_data), dtype=np.int16)

    # 5. Scale to μV. A missing OR empty <LeadAmplitudeUnitsPerBit> falls back to 1;
    #    previously an empty element raised ValueError from float("").
    scale_text = (lead_2.findtext("LeadAmplitudeUnitsPerBit") or "").strip()
    scale = float(scale_text) if scale_text else 1.0
    waveform_uv = waveform_raw * scale

    return waveform_uv, {
        "samples": waveform_uv.shape[0],
        "scale": scale
    }

In [None]:
def extract_qrs_sample_locations(xml_file_path, scale_divisor=2):
    """Read the ``<QRSTimesTypes>`` annotations and map them to sample indices.

    The file is parsed with ``safe_parse_xml``, the same tolerant parser used by
    ``load_lead_II_waveform`` in this section. Previously a strict ``ET.parse`` was
    used here, so a file whose waveform loaded fine could still fail at this step
    and be skipped entirely by the folder iterator.

    Each ``<QRS>``'s integer ``<Time>`` is divided by ``scale_divisor`` and truncated
    to int. NOTE(review): the default of 2 presumably converts the annotation time
    base to the waveform's sample grid — confirm against the XML.

    Args:
        xml_file_path: path to the ECG XML file.
        scale_divisor: divisor applied to every ``<Time>`` value (default 2, the
            previously hard-coded value).

    Returns:
        (qrs_sample_locs, qrs_types): parallel lists. Types are the raw ``<Type>``
        text. Annotations whose ``<Time>`` is missing or not an integer are skipped.
    """
    root = safe_parse_xml(xml_file_path).getroot()

    qrs_sample_locs = []
    qrs_types = []

    qrs_times = root.find("QRSTimesTypes")
    if qrs_times is None:
        return qrs_sample_locs, qrs_types

    for qrs in qrs_times.findall("QRS"):
        try:
            sample_idx = int(int(qrs.findtext("Time")) / scale_divisor)
        except (TypeError, ValueError):
            # Missing (None) or malformed <Time>: skip this annotation.
            continue
        qrs_sample_locs.append(sample_idx)
        qrs_types.append(qrs.findtext("Type"))

    return qrs_sample_locs, qrs_types

In [None]:
def extract_beat_windows(ecg, locs, window_size=250):
    """Cut a window of ``2 * window_size`` samples centred on each location.

    Args:
        ecg: 1-D signal (supports ``len`` and slicing).
        locs: sample indices of the window centres.
        window_size: samples kept on each side of the centre.

    Returns:
        (beats, beat_indices): ``beats`` has shape ``(n_kept, 2 * window_size)``;
        ``beat_indices`` are the positions in ``locs`` whose window fits entirely
        inside ``ecg``.
    """
    beats = []
    beat_indices = []
    for i, p in enumerate(locs):
        start = p - window_size
        end = p + window_size
        # ecg[start:end] is complete whenever end <= len(ecg); the previous `<`
        # dropped a window that ended exactly on the last sample.
        if start >= 0 and end <= len(ecg):
            beats.append(ecg[start:end])
            beat_indices.append(i)
    if not beats:
        # Keep a 2-D result so batches of beats can still be stacked/concatenated.
        return np.empty((0, 2 * window_size), dtype=np.asarray(ecg).dtype), beat_indices
    return np.array(beats), beat_indices

In [None]:
def iter_ecg_beats_from_folder_0(top_folder, window_size=250):
    """Yield ``(beat, label)`` pairs for every ``.xml`` file under *top_folder*.

    Beats are lead-II windows of ``2 * window_size`` samples cut around the file's
    own ``<QRSTimesTypes>`` annotations. Each label is that annotation's ``<Type>`` text.

    Directories and files are visited in sorted order, so repeated runs produce the
    same beat sequence and therefore the same saved batches. Plain ``os.walk``
    order depends on the filesystem. Files that fail to load are reported and skipped.
    """
    for dirpath, dirnames, filenames in os.walk(top_folder):
        dirnames.sort()  # sorting in place makes os.walk descend in sorted order
        for file in sorted(filenames):
            if not file.endswith(".xml"):
                continue
            xml_path = os.path.join(dirpath, file)
            try:
                ecg, _ = load_lead_II_waveform(xml_path)
                r_locs, qrs_types = extract_qrs_sample_locations(xml_path)
                beats, beat_indices = extract_beat_windows(ecg, r_locs, window_size=window_size)
                # beat_indices map each kept window back to its annotation.
                labeled = [(beats[j], qrs_types[i]) for j, i in enumerate(beat_indices)]
            except Exception as e:
                # Best-effort batch run: report the file and move on.
                print(f"오류 발생 [{xml_path}]: {e}")
                continue
            yield from labeled

In [None]:
def iter_ecg_beats_from_folder_2(xml_folder, window_size=250, sampling_rate=500):
    """Yield ``(beat, label)`` pairs for every ``.xml`` file under *xml_folder*.

    Beats are lead-II windows of ``2 * window_size`` samples around R-peaks found
    by ``nk.ecg_peaks`` at ``sampling_rate`` Hz (default 500, as before). Every beat
    of a file gets the same label: the first ``label_id`` in the matching JSON
    file, whose path swaps the 원천데이터 directory for 라벨링데이터.

    Files without a label (missing JSON or empty ``labels``) are reported and
    skipped. Previously they yielded ``label=None``, which later breaks
    ``labels.astype(np.int64)``. Traversal is in sorted order for reproducible batches.
    """
    for dirpath, dirnames, filenames in os.walk(xml_folder):
        dirnames.sort()  # sorting in place makes os.walk descend in sorted order
        for file in sorted(filenames):
            if not file.endswith(".xml"):
                continue
            xml_path = os.path.join(dirpath, file)
            # Only the extension is swapped, so directory names containing ".xml"
            # stay intact (str.replace(".xml", ...) rewrote every occurrence).
            json_path = os.path.splitext(
                xml_path.replace(
                    f"{os.sep}원천데이터{os.sep}",
                    f"{os.sep}라벨링데이터{os.sep}"
                )
            )[0] + ".json"

            try:
                # File-level label from the JSON.
                label_type = None
                if os.path.exists(json_path):
                    with open(json_path, "r", encoding="utf-8") as f:
                        data = json.load(f)
                    label_ids = [item.get("label_id") for item in data.get("labels", [])]
                    label_type = label_ids[0] if label_ids else None
                if label_type is None:
                    print(f"라벨을 찾을 수 없어 건너뜀 [{json_path}]")
                    continue

                ecg, _ = load_lead_II_waveform(xml_path)

                # R-peaks come from NeuroKit2 here, not from the XML annotations.
                _, info = nk.ecg_peaks(ecg, sampling_rate=sampling_rate)
                beats, _ = extract_beat_windows(ecg, info["ECG_R_Peaks"], window_size=window_size)
            except Exception as e:
                # Best-effort batch run: report the file and move on.
                print(f"오류 발생 [{xml_path}]: {e}")
                continue

            for beat in beats:
                yield beat, label_type

In [None]:
def save_batch(beats, labels, batch_idx, save_dir="./78_data"):
    """Write one batch as ``beats_batch_<idx>.npy`` and ``labels_batch_<idx>.npy``.

    *save_dir* is created if it does not exist. NOTE(review): the loading cell
    further down reads from ``export_data/78_data`` — confirm the files end up there.
    """
    os.makedirs(save_dir, exist_ok=True)

    beats_file = os.path.join(save_dir, f"beats_batch_{batch_idx}.npy")
    labels_file = os.path.join(save_dir, f"labels_batch_{batch_idx}.npy")
    np.save(beats_file, np.array(beats))
    np.save(labels_file, np.array(labels))

In [None]:
# Batch bookkeeping shared by the two extraction cells below.
# NOTE(review): numbering starts at 225, presumably to continue after batches
# written by an earlier run — confirm before re-running from scratch.
batch_size = 1000
batch_idx = 225
buffer_beats, buffer_labels = [], []

In [None]:
# Stream beats from folder "0" and write a batch every `batch_size` beats.
# Any remainder stays in the buffers and carries over into the next cell.
for beat, label in iter_ecg_beats_from_folder_0("KYUH_data/78/064.심장질환 진단을 위한 심전도 데이터/01.데이터/1.Training/원천데이터/0"):
    buffer_beats.append(beat)
    buffer_labels.append(label)
    if len(buffer_beats) < batch_size:
        continue
    save_batch(buffer_beats, buffer_labels, batch_idx)
    buffer_beats.clear()
    buffer_labels.clear()
    batch_idx += 1

In [None]:
# Stream beats from folder "2" into the same buffers (continuing the previous cell).
for beat, label in iter_ecg_beats_from_folder_2("KYUH_data/78/064.심장질환 진단을 위한 심전도 데이터/01.데이터/1.Training/원천데이터/2"):
    buffer_beats.append(beat)
    buffer_labels.append(label)

    if len(buffer_beats) >= batch_size:
        save_batch(buffer_beats, buffer_labels, batch_idx)
        buffer_beats.clear()
        buffer_labels.clear()
        batch_idx += 1

# Flush the final partial batch. Without this, up to batch_size - 1 beats
# (including any carried over from the previous cell) were never written.
if buffer_beats:
    save_batch(buffer_beats, buffer_labels, batch_idx)
    buffer_beats.clear()
    buffer_labels.clear()
    batch_idx += 1

In [None]:
# `glob` was used here without ever being imported.
import glob

# NOTE(review): save_batch() writes to ./78_data by default, but this cell reads
# export_data/78_data — confirm the batches were moved there.
BATCH_DIR = "export_data/78_data"
beats_files = sorted(glob.glob(f"{BATCH_DIR}/beats_batch_*.npy"))
labels_files = sorted(glob.glob(f"{BATCH_DIR}/labels_batch_*.npy"))

# Every beats file needs the labels file with the same batch suffix; otherwise
# the concatenated arrays would be silently misaligned.
assert beats_files, f"no batch files found under {BATCH_DIR}"
beat_suffixes = [os.path.basename(f)[len("beats_"):] for f in beats_files]
label_suffixes = [os.path.basename(f)[len("labels_"):] for f in labels_files]
assert beat_suffixes == label_suffixes, "beats/labels batch files do not pair up"

beats = np.concatenate([np.load(f) for f in beats_files], axis=0)
# Labels may be saved as object arrays, which np.load only reads with allow_pickle.
labels = np.concatenate([np.load(f, allow_pickle=True) for f in labels_files], axis=0)

In [None]:
# LSTM input: float32 beats with a trailing channel axis -> (n_beats, window_length, 1).
X = np.expand_dims(np.array(beats, dtype=np.float32), axis=-1)

In [None]:
# Integer class id per beat, and how many beats fall into each class.
labels = np.array(labels).astype(np.int64)

unique, counts = np.unique(labels, return_counts=True)
class_distribution = dict(zip(unique, counts))
print("클래스 분포:", class_distribution)

In [None]:
# Binary target: class 0 stays 0, every other class becomes 1.
y = np.where(labels == 0, 0, 1)

In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout

def build_lstm_model(input_shape, num_classes):
    """Build and compile a single-layer LSTM classifier.

    Args:
        input_shape: per-sample shape, e.g. ``(timesteps, 1)``.
        num_classes: 1 for binary classification, otherwise the number of classes.

    With ``num_classes == 1`` the head is one sigmoid unit trained with binary
    cross-entropy (0/1 labels, scores later thresholded at 0.5). The previous code
    used softmax here, and softmax over a single unit always outputs 1.0, so the
    model could not learn and every prediction was positive. For ``num_classes > 1``
    the softmax + categorical cross-entropy head (one-hot labels) is unchanged.
    """
    model = Sequential()
    model.add(LSTM(64, input_shape=input_shape, return_sequences=False))
    model.add(Dropout(0.5))
    model.add(Dense(64, activation='relu'))

    if num_classes == 1:
        model.add(Dense(1, activation='sigmoid'))
        loss = 'binary_crossentropy'
    else:
        model.add(Dense(num_classes, activation='softmax'))
        loss = 'categorical_crossentropy'

    model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])
    return model

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% for validation. Stratifying on y keeps the same class ratio in
# both splits, even when the classes are imbalanced.
# NOTE(review): splitting per beat lets beats of one recording land in both train
# and validation, which can inflate validation scores. A per-recording (group) split
# would need recording ids, which the saved batch files do not store.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

In [None]:
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

In [None]:
# Stop once val_loss has not improved for 5 epochs (restoring the best weights),
# and keep the lowest-val_loss model on disk for the evaluation cell.
early_stop = EarlyStopping(
    monitor='val_loss', patience=5, restore_best_weights=True, verbose=1
)
best_checkpoint = ModelCheckpoint(
    filepath='best_lstm_model.h5', monitor='val_loss', save_best_only=True, verbose=1
)
callbacks = [early_stop, best_checkpoint]

In [None]:
# One output unit for the binary target; the input is one beat of shape (window_length, 1).
model = build_lstm_model(input_shape=X.shape[1:], num_classes=1)
model.summary()

In [None]:
# Train with the early-stopping / checkpoint callbacks defined above.
fit_kwargs = dict(
    validation_data=(X_val, y_val),
    epochs=10,
    batch_size=64,
    callbacks=callbacks,
)
model.fit(X_train, y_train, **fit_kwargs)

In [None]:
from tensorflow.keras.models import load_model

In [None]:
# Reload the lowest-val_loss checkpoint and report its validation accuracy.
best_model = load_model('best_lstm_model.h5')
eval_result = best_model.evaluate(X_val, y_val)
loss, acc = eval_result
print(f"Best Model Accuracy: {acc:.4f}")

In [None]:
# Predict with the reloaded best checkpoint, so the report and ROC below describe
# the same model whose accuracy was printed above (previously the in-memory
# `model` was used, which need not hold the best weights).
y_pred_prob = best_model.predict(X_val).ravel()
y_pred = (y_pred_prob >= 0.5).astype(int)

In [None]:
from sklearn.metrics import classification_report, confusion_matrix

# Fix the class ids to [0, 1] so the two target names always line up. Without it,
# classification_report raises when only one class occurs in y_val and y_pred.
class_ids = [0, 1]
print("Classification Report:")
print(classification_report(y_val, y_pred, labels=class_ids, target_names=["Normal", "Arrhythmia"]))

print("Confusion Matrix:")
print(confusion_matrix(y_val, y_pred, labels=class_ids))

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(y_true, y_pred, labels=("Normal", "Arrhythmia")):
    """Draw the confusion matrix as an annotated seaborn heatmap.

    Args:
        y_true, y_pred: integer class ids, assumed to be ``0 .. len(labels) - 1``.
        labels: display names for those ids, in id order. The default is a tuple
            rather than a shared mutable list.

    The matrix is computed over every id in ``range(len(labels))``, so it is always
    ``len(labels) x len(labels)``. Previously, when a class appeared in neither
    ``y_true`` nor ``y_pred``, the matrix shrank and no longer matched the tick labels.
    """
    labels = list(labels)
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels, ax=ax)

    ax.set_title("Confusion Matrix")
    ax.set_xlabel("Predicted Label")
    ax.set_ylabel("True Label")
    fig.tight_layout()
    plt.show()

In [None]:
# Confusion-matrix heatmap for the validation predictions.
plot_confusion_matrix(y_true=y_val, y_pred=y_pred, labels=["Normal", "Arrhythmia"])

In [None]:
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

# ROC curve of the validation scores. The scores are flattened to 1-D because a
# one-unit Dense head yields an (n, 1) prediction array.
roc_scores = np.ravel(y_pred_prob)
fpr, tpr, _ = roc_curve(y_val, roc_scores)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots()
ax.plot(fpr, tpr, label=f'ROC curve (AUC = {roc_auc:.4f})')
ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("Receiver Operating Characteristic (ROC)")
ax.legend(loc="lower right")
ax.grid(True)
plt.show()