In [1]:
import wfdb
import numpy as np
from matplotlib import pyplot as plt
from helper_code import *
from utilities import *
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler


# Folder holding the Challenge records (header + signal files) analysed below.
data_folder = "data/micro_code_sami/"

print('Finding the Challenge data...')
records = find_records(data_folder)
num_records = len(records)

# Stop right away when the folder yields no records: every later cell
# iterates over `records` and would otherwise silently do nothing.
if num_records < 1:
    raise FileNotFoundError('No data were provided.')

Finding the Challenge data...


In [2]:
utilObject = UtilityFunctions("cpu")
sampling_rate = 400
# Row index of each lead inside the equalized recording.
leads_idxs = {'I': 0, 'II': 1, 'III':2, 'aVR': 3, 'aVL':4, 'aVF':5, 'V1':6, 'V2':7, 'V3':8, 'V4':9, 'V5':10, 'V6':11}
# Fiducial points whose raw amplitudes are collected for every lead.
wave_points = ['ECG_P_Peaks', 'ECG_Q_Peaks', 'ECG_R_Peaks', 'ECG_S_Peaks', 'ECG_T_Peaks']

# One accumulator per (source, label) pair. Each element appended below is a
# list with one dict per lead, mapping a fiducial-point name to the array of
# signal amplitudes found at those points.
code15_peaks_values_negative = []
code15_peaks_values_positive = []

sami_peaks_values_negative = []
sami_peaks_values_positive = []

other_peaks_values_negative = []
other_peaks_values_positive = []

# Iterate over the records.
for i in range(num_records):
    print(f"{i+1}/{num_records}")
    recording_file = os.path.join(data_folder, records[i])
    header_file = os.path.join(data_folder, get_header_file(records[i]))
    header = load_header(recording_file)
    try:
        current_label = get_label(header)
    except Exception as e:
        # Best effort: a record without a usable label is treated as negative.
        print("Failed to load label, assigning 0")
        current_label = 0

    try:
        signal, fields = load_signals(recording_file)
    except Exception as e:
        print(f"Skipping {header_file} and associated recording  because of {e}")
        continue

    # NOTE(review): the indexing below assumes this returns an array indexed
    # as [lead, sample] at `sampling_rate` Hz -- confirm in UtilityFunctions.
    recording_full = utilObject.load_and_equalize_recording(signal, fields, header_file, sampling_rate)

    leads_peaks_values = []
    for lead_name, idx in leads_idxs.items():
        info = None
        try:
            rpeaks = nk.ecg_findpeaks(recording_full[idx], sampling_rate, method="pantompkins1985")
            # Only the `info` dict is needed; the delineation signal frame is
            # discarded (and no longer overwrites the loaded `signal`).
            _, info = nk.ecg_delineate(recording_full[idx], rpeaks=rpeaks, sampling_rate=sampling_rate, method='dwt')
            # Merge the R-peak entries so all five points live in one dict.
            info.update(rpeaks)
        except Exception as e:
            # Delineation failed for this lead: keep going with empty arrays
            # so the per-record structure still has one entry per lead.
            print(f"Exception in {header_file}")
            info = None

        points = {}
        for point in wave_points:
            located = info.get(point) if info is not None else None
            if located is not None:
                # Drop the NaN placeholders used for undetected waves and force
                # an integer dtype: a float-typed index list cannot be used to
                # index the signal array.
                indices = np.asarray(
                    [indice for indice in located if not np.isnan(indice)],
                    dtype=int,
                )
                points[point] = recording_full[idx][indices]  # Original signal, not cleaned
            else:
                points[point] = np.array([])

        leads_peaks_values.append(points)

    # Route the record to the accumulator matching its data source and label.
    # Records whose header has no "Source" line are not stored anywhere.
    splitted = header.split("\n")
    source_info = [x for x in splitted if "Source" in x]
    if len(source_info) > 0:
        if "SaMi-Trop" in source_info[0]:
            negatives, positives = sami_peaks_values_negative, sami_peaks_values_positive
        elif "CODE" in source_info[0]:
            negatives, positives = code15_peaks_values_negative, code15_peaks_values_positive
        else:
            negatives, positives = other_peaks_values_negative, other_peaks_values_positive

        # Picking the pair once and appending through it avoids spelling each
        # list name per branch (the previous per-branch spelling referenced
        # undefined `..._negtive` names and raised NameError when reached).
        if current_label == 0:
            negatives.append(leads_peaks_values)
        else:
            positives.append(leads_peaks_values)
    

1/131
2/131
3/131
4/131
5/131
6/131
7/131
8/131
9/131
10/131
11/131
12/131
13/131
14/131
15/131
16/131
17/131
18/131
19/131
20/131
21/131
22/131
23/131
24/131
25/131
26/131
27/131
28/131
29/131
30/131
31/131
32/131
33/131
34/131
35/131
36/131
37/131
38/131
39/131
40/131
41/131
42/131
43/131
44/131
45/131
46/131
47/131
48/131
49/131
50/131
51/131
52/131
53/131
54/131
55/131
56/131


  warn(


Exception in data/micro_code_sami/562979.hea
57/131
58/131
59/131
60/131
61/131
62/131
63/131
64/131
65/131
66/131
67/131
68/131
69/131
70/131
71/131
72/131
73/131
74/131
75/131
76/131
77/131
78/131
79/131
80/131
81/131
82/131
83/131
84/131
85/131
86/131
87/131
88/131
89/131
90/131
91/131
92/131
93/131
94/131
95/131
96/131
97/131
98/131
99/131
100/131
101/131
102/131
103/131
104/131
105/131


  warn(


Exception in data/micro_code_sami/568225.hea
106/131
107/131
108/131
109/131
110/131
111/131
112/131
113/131
114/131
115/131
116/131
117/131
118/131
119/131
120/131
121/131
122/131
123/131
124/131
125/131
126/131
127/131
128/131
129/131
130/131
131/131


In [17]:
# ======================================
# Step 1: Organize data into one big DataFrame
# ======================================

# Folder that every figure produced in this cell is saved under; each plot
# joins it with its own file name through os.path.join before plt.savefig.
save_path = "plots/data_analysis/"


def create_dataframe(all_peaks_values, label):
    """Flatten nested per-sample peak amplitudes into a long-format DataFrame.

    Parameters
    ----------
    all_peaks_values : iterable
        One element per sample; each sample is an iterable of per-lead dicts
        mapping a wave key (e.g. ``'ECG_P_Peaks'``) to an iterable of
        amplitude values.
    label :
        Value stored unchanged in the ``Class`` column of every row.

    Returns
    -------
    pandas.DataFrame
        One row per amplitude value with columns ``Sample`` (position of the
        sample in ``all_peaks_values``), ``Lead`` (position of the dict inside
        the sample), ``Wave`` (the key with ``'ECG_'`` and ``'_Peaks'``
        stripped), ``Amplitude`` and ``Class``. When there are no amplitude
        values at all the frame is empty and has no columns.
    """
    records = [
        {
            "Sample": sample_pos,
            "Lead": lead_pos,
            "Wave": wave_key.replace('ECG_', '').replace('_Peaks', ''),
            "Amplitude": amplitude,
            "Class": label,
        }
        for sample_pos, per_lead in enumerate(all_peaks_values)
        for lead_pos, waves in enumerate(per_lead)
        for wave_key, amplitudes in waves.items()
        for amplitude in amplitudes
    ]
    return pd.DataFrame(records)

# Build DataFrames for both classes: SaMi-Trop positives vs CODE-15 negatives.
df_pos = create_dataframe(sami_peaks_values_positive, label=True)
df_neg = create_dataframe(code15_peaks_values_negative, label=False)

# Concatenate into one DataFrame
df = pd.concat([df_pos, df_neg], ignore_index=True)

print("Dataframe head:")
print(df.head())

# ======================================
# Step 2: Descriptive statistics
# ======================================
print("\n=== Descriptive statistics by Wave and Class ===")
desc_stats_wave_class = df.groupby(["Wave", "Class"])["Amplitude"].describe()
print(desc_stats_wave_class)

print("\n=== Descriptive statistics by Lead, Wave and Class ===")
desc_stats_lead_wave_class = df.groupby(["Lead", "Wave", "Class"])["Amplitude"].describe()
print(desc_stats_lead_wave_class)

# ======================================
# Step 3: Percentile analysis
# ======================================
percentiles = [1, 5, 25, 50, 75, 95, 99]

print("\n=== Percentiles by Wave and Class ===")
percentile_wave_class = df.groupby(["Wave", "Class"])["Amplitude"].quantile(np.array(percentiles)/100).unstack()
print(percentile_wave_class)

print("\n=== Percentiles by Lead, Wave and Class ===")
percentile_lead_wave_class = df.groupby(["Lead", "Wave", "Class"])["Amplitude"].quantile(np.array(percentiles)/100).unstack()
print(percentile_lead_wave_class)

# ======================================
# Plot preparation
# ======================================

# savefig does not create missing directories, so make sure the target exists.
os.makedirs(save_path, exist_ok=True)

# Standardize amplitudes once over the whole frame (used by the strip plots).
scaler = StandardScaler()
df['Amplitude_Standardized'] = scaler.fit_transform(df[['Amplitude']])

# Plot from a copy whose Class column carries readable labels. With an explicit
# hue order seaborn builds the legend itself, so the legend text can never get
# out of sync with the colours (overriding labels via plt.legend can).
class_labels = {False: "Negative", True: "Positive"}
class_order = ["Negative", "Positive"]
plot_df = df.assign(Class=df["Class"].map(class_labels))

# ======================================
# Step 4: EDA Global - Compare Classes
# ======================================

# 4a. Histograms by Class
fig, ax = plt.subplots(figsize=(15, 8))
sns.histplot(
    data=plot_df, x="Amplitude", hue="Class", hue_order=class_order,
    element="step", kde=True, stat="density", common_norm=False,
    palette="Set1", ax=ax,
)
ax.set_title("Global Amplitude Distribution - Positive vs Negative Classes")
ax.set_xlabel("Amplitude")
ax.set_ylabel("Density")
ax.grid(True)
fig.savefig(os.path.join(save_path, "global_histogram_classes.png"), bbox_inches="tight")
plt.close(fig)

# 4b. Boxplot by Class
fig, ax = plt.subplots(figsize=(15, 6))
sns.boxplot(
    data=plot_df, x="Wave", y="Amplitude", hue="Class", hue_order=class_order,
    palette="Set2", ax=ax,
)
ax.set_title("Boxplot of Amplitudes by ECG Wave - Positive vs Negative Classes")
ax.grid(True)
fig.savefig(os.path.join(save_path, "boxplot_wave_classes.png"), bbox_inches="tight")
plt.close(fig)

# 4c. Violin Plot by Class
fig, ax = plt.subplots(figsize=(15, 6))
sns.violinplot(
    data=plot_df, x="Wave", y="Amplitude", hue="Class", hue_order=class_order,
    split=True, inner="quartile", palette="Set2", ax=ax,
)
ax.set_title("Violin Plot of Amplitudes by ECG Wave - Positive vs Negative Classes")
ax.grid(True)
fig.savefig(os.path.join(save_path, "violinplot_wave_classes.png"), bbox_inches="tight")
plt.close(fig)

# 4d. Standardized Strip Plot
fig, ax = plt.subplots(figsize=(15, 6))
sns.stripplot(
    data=plot_df, x="Wave", y="Amplitude_Standardized", hue="Class",
    hue_order=class_order, dodge=True, jitter=True, alpha=0.5,
    palette="Set1", ax=ax,
)
ax.set_title("Standardized Amplitudes - Stripplot by ECG Wave and Class")
ax.grid(True)
fig.savefig(os.path.join(save_path, "stripplot_wave_classes.png"), bbox_inches="tight")
plt.close(fig)

# ======================================
# Step 5: Per-Lead EDA - Compare Classes
# ======================================

# 5a. Per-Lead Boxplots: one facet per lead. (Previously this repeated the
# global boxplot of 4b without ever splitting by lead.)
g_box = sns.catplot(
    data=plot_df, kind="box", x="Wave", y="Amplitude", hue="Class",
    hue_order=class_order, col="Lead", col_wrap=4, height=4, aspect=1,
    palette="Set3", sharey=False,
)
g_box.fig.suptitle("Boxplot of Amplitudes by ECG Wave and Lead - Class Comparison", y=1.02)
for facet_ax in g_box.axes.flat:
    facet_ax.grid(True)
g_box.savefig(os.path.join(save_path, "boxplot_wave_lead_classes.png"), bbox_inches="tight")
plt.close(g_box.fig)

# 5b. Per-Lead Violin plots
g = sns.catplot(
    data=plot_df, kind="violin", x="Wave", y="Amplitude", hue="Class",
    hue_order=class_order, split=True,
    col="Lead", col_wrap=4, height=4, aspect=1, palette="Set2", sharey=False
)
g.fig.suptitle("Violin Plots by Lead - Positive vs Negative Classes", y=1.02)
# plt.grid would only touch the last facet, so enable the grid on each one.
for facet_ax in g.axes.flat:
    facet_ax.grid(True)
g.savefig(os.path.join(save_path, "violinplot_per_lead_classes.png"), bbox_inches="tight")
plt.close(g.fig)

# 5c. Histograms per Lead
for lead in sorted(df['Lead'].unique()):
    fig, ax = plt.subplots(figsize=(15, 6))
    sns.histplot(
        data=plot_df[plot_df['Lead'] == lead], x="Amplitude", hue="Class",
        hue_order=class_order, element="step",
        kde=True, stat="density", common_norm=False, palette="Set1", ax=ax,
    )
    ax.set_title(f"Amplitude Distribution for Each Class - Lead {lead}")
    ax.set_xlabel("Amplitude")
    ax.set_ylabel("Density")
    ax.grid(True)
    fig.savefig(os.path.join(save_path, f"histogram_lead_{lead}.png"), bbox_inches="tight")
    plt.close(fig)

# 5d. Standardized Scatter by Lead
for lead in sorted(df['Lead'].unique()):
    fig, ax = plt.subplots(figsize=(15, 6))
    sns.stripplot(
        data=plot_df[plot_df['Lead'] == lead],
        x="Wave", y="Amplitude_Standardized", hue="Class", hue_order=class_order,
        dodge=True, jitter=True, alpha=0.5, palette="Set1", ax=ax,
    )
    ax.set_title(f"Standardized Amplitudes - Scatter by Wave and Class - Lead {lead}")
    ax.grid(True)
    fig.savefig(os.path.join(save_path, f"stripplot_lead_{lead}.png"), bbox_inches="tight")
    plt.close(fig)

Dataframe head:
   Sample  Lead Wave  Amplitude  Class
0       0     0    P      0.322   True
1       0     0    P     -0.029   True
2       0     0    P      0.000   True
3       0     0    P      0.046   True
4       0     0    P      0.020   True

=== Descriptive statistics by Wave and Class ===
              count      mean       std        min      25%     50%      75%  \
Wave Class                                                                     
P    False  14151.0 -0.019337  1.591953 -15.184000 -0.29800  0.0300  0.23800   
     True    1289.0  0.182652  0.406772  -0.561000  0.01900  0.0850  0.19500   
Q    False  13223.0 -0.639975  1.781590 -17.410000 -1.16700 -0.2640 -0.01200   
     True    1186.0 -0.353951  0.662415  -5.443000 -0.36075 -0.1190 -0.02825   
R    False  14985.0 -0.132406  1.739601 -16.827000 -0.52800 -0.0610  0.24500   
     True    1331.0 -0.066716  0.582950  -5.287000 -0.23650 -0.0200  0.08300   
S    False  13213.0 -0.424673  1.692272 -20.436001 -0.83700 

-20.436