# In [None]:
import os
import wfdb
import math
import numpy as np
import pandas as pd
from scipy import interpolate
from scipy import signal
import neurokit2 as nk
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler

# In [ ]:
class MITBIHProcessor:
    """Turn MIT-BIH Arrhythmia Database records into fixed-size training arrays.

    The processor reads WFDB records/annotations from ``data_path``, resamples
    one lead to ``TARGET_RATE`` Hz, z-score normalizes it, cuts windows around
    annotated heartbeats (or consecutive fixed-length windows), and writes
    ``.npy``/``.csv`` outputs under ``output_path``.
    """

    # Rate (Hz) every extracted signal is resampled to before segmentation.
    TARGET_RATE = 250

    # Annotation codes that mark an actual heartbeat. The annotation stream also
    # carries rhythm changes ('+'), noise/artifact markers ('~', '|'), comments
    # ('"'), flutter markers ('[', '!', ']') etc., which must not become beats.
    BEAT_SYMBOLS = frozenset('NLRBAaJSVrFejnE/fQ?')

    # Beat symbol -> diagnosis class (matches the common AAMI EC57 grouping);
    # any symbol not listed maps to 'UNKNOWN'.
    _BEAT_TO_DIAGNOSIS = {
        **dict.fromkeys('NLRej', 'NORM'),
        **dict.fromkeys('AaJS', 'SVPB'),
        **dict.fromkeys('VE', 'VPC'),
        'F': 'FUSION',
    }

    # Diagnosis class -> integer label written to labels/label.npy.
    DIAGNOSIS_MAP = {'NORM': 0, 'SVPB': 1, 'VPC': 2, 'FUSION': 3, 'UNKNOWN': 4}

    def __init__(self, data_path, output_path):
        """Create a processor and make sure the output directory tree exists.

        Args:
            data_path: Directory holding the WFDB record files, read through
                ``wfdb.rdsamp`` / ``wfdb.rdann``.
            output_path: Directory that receives ``features/``, ``labels/`` and
                ``dataset_summary.csv``.
        """
        self.data_path = data_path
        self.output_path = output_path
        self.records = None
        self.annotations = None
        # Nominal MIT-BIH rate; the methods below use each record's own header fs.
        self.sampling_rate = 360

        # Create the subdirectories even when output_path already exists;
        # otherwise the later np.save calls fail on a pre-existing output folder.
        for sub in ('features', 'labels'):
            os.makedirs(os.path.join(output_path, sub), exist_ok=True)

    def _subdir(self, name):
        """Return ``output_path/name``, creating the directory if needed."""
        path = os.path.join(self.output_path, name)
        os.makedirs(path, exist_ok=True)
        return path

    def _get_records(self):
        """Return the record list, loading the default list on first use."""
        if self.records is None:
            self.load_record_list()
        return self.records

    def load_record_list(self):
        """Set and return the 48 MIT-BIH record names processed by this class."""
        self.records = [
            '100', '101', '102', '103', '104', '105', '106', '107', '108', '109',
            '111', '112', '113', '114', '115', '116', '117', '118', '119', '121',
            '122', '123', '124', '200', '201', '202', '203', '205', '207', '208',
            '209', '210', '212', '213', '214', '215', '217', '219', '220', '221',
            '222', '223', '228', '230', '231', '232', '233', '234'
        ]
        return self.records

    def load_annotations(self):
        """Summarize the beat annotations of every record.

        Returns:
            DataFrame with one row per readable record and the columns
            ``record``, ``beat_types`` (symbol -> count, beats only),
            ``total_beats``, ``sampling_rate`` and ``duration_minutes``. Records
            that cannot be read are reported and skipped.
        """
        annotations_info = []

        for record in self._get_records():
            record_path = os.path.join(self.data_path, record)
            try:
                annotation = wfdb.rdann(record_path, 'atr')
                # The header gives the true signal length, so the duration no
                # longer has to be guessed from the number of annotations.
                header = wfdb.rdheader(record_path)
            except Exception as e:
                print(f"Error loading annotation for record {record}: {e}")
                continue

            beat_counts = {}
            for symbol in annotation.symbol:
                if symbol in self.BEAT_SYMBOLS:
                    beat_counts[symbol] = beat_counts.get(symbol, 0) + 1

            annotations_info.append({
                'record': record,
                'beat_types': beat_counts,
                'total_beats': sum(beat_counts.values()),
                'sampling_rate': header.fs,
                'duration_minutes': header.sig_len / (header.fs * 60),
            })

        self.annotations = pd.DataFrame(annotations_info)
        return self.annotations

    def map_beat_to_diagnosis(self, beat_symbol):
        """Map a beat annotation symbol to its diagnosis class.

        Returns one of 'NORM', 'SVPB', 'VPC', 'FUSION' or 'UNKNOWN'. Paced ('/'),
        fusion-of-paced ('f'), unclassifiable ('Q') and unlisted symbols are
        'UNKNOWN'.
        """
        return self._BEAT_TO_DIAGNOSIS.get(beat_symbol, 'UNKNOWN')

    def resample_signal(self, signal, original_rate, target_rate=250):
        """Linearly resample a 1-D signal from ``original_rate`` to ``target_rate`` Hz.

        Sample ``i`` of the input is placed at time ``i / original_rate`` and
        output sample ``j`` at ``j / target_rate``. An annotation index ``k`` in
        the input therefore lands at ``k * target_rate / original_rate`` in the
        output. Target times past the last input sample reuse the last value.

        Returns:
            ndarray of ``int(len(signal) / original_rate * target_rate)`` samples.
        """
        signal = np.asarray(signal, dtype=float)
        duration = len(signal) / original_rate
        num_samples_target = int(duration * target_rate)
        if len(signal) == 0 or num_samples_target <= 0:
            return np.empty(0)

        t_original = np.arange(len(signal)) / original_rate
        t_target = np.arange(num_samples_target) / target_rate
        return np.interp(t_target, t_original, signal)

    def normalize_signal(self, signal):
        """Z-score normalize a 1-D signal (zero mean, unit variance) with StandardScaler."""
        scaler = StandardScaler()
        return scaler.fit_transform(np.asarray(signal, dtype=float).reshape(-1, 1)).flatten()

    def _segment_beats(self, ecg_signal, samples, symbols, source_rate, target_rate, segment_length):
        """Cut one window of ``segment_length`` samples around each beat annotation.

        Args:
            ecg_signal: 1-D signal already at ``target_rate`` Hz.
            samples: Annotation sample indices at ``source_rate`` Hz.
            symbols: Annotation symbols aligned with ``samples``.
            source_rate: Rate the annotation indices refer to.
            target_rate: Rate of ``ecg_signal``.
            segment_length: Window length in samples; odd lengths are allowed.

        Returns:
            ``(segments, labels)``. ``segments`` has shape
            ``(n_beats, segment_length)`` (also when empty) and ``labels`` holds
            diagnosis strings. Non-beat annotations and windows that would run
            past either end of the signal are skipped.
        """
        half = segment_length // 2
        n = len(ecg_signal)
        segments, labels = [], []

        for sample, symbol in zip(samples, symbols):
            if symbol not in self.BEAT_SYMBOLS:
                continue
            # Annotation indices refer to the original rate; rescale them so the
            # window is centred on the beat in the resampled signal.
            center = int(round(sample * target_rate / source_rate))
            start = center - half
            end = start + segment_length
            if start < 0 or end > n:
                continue
            segments.append(ecg_signal[start:end])
            labels.append(self.map_beat_to_diagnosis(symbol))

        segments_array = np.array(segments).reshape(len(segments), segment_length)
        return segments_array, np.array(labels)

    def extract_heartbeat_segments(self, record_name, segment_length=250, lead=0):
        """Extract normalized beat-centred windows from one record.

        Args:
            record_name: Record name, e.g. '100'.
            segment_length: Window length in samples at ``TARGET_RATE`` Hz.
            lead: Signal channel to use; falls back to channel 0 if out of range.

        Returns:
            ``(segments, labels)`` as produced by ``_segment_beats``, or
            ``(None, None)`` if the record cannot be read or processed.
        """
        try:
            # Read the signal and its reference annotations.
            record_path = os.path.join(self.data_path, record_name)
            signals, fields = wfdb.rdsamp(record_path)
            annotation = wfdb.rdann(record_path, 'atr')

            if lead >= signals.shape[1]:
                lead = 0
            ecg_signal = signals[:, lead]

            fs = fields['fs']
            if fs != self.TARGET_RATE:
                ecg_signal = self.resample_signal(ecg_signal, fs, self.TARGET_RATE)
            ecg_signal = self.normalize_signal(ecg_signal)

            segments_array, labels_array = self._segment_beats(
                ecg_signal, annotation.sample, annotation.symbol,
                fs, self.TARGET_RATE, segment_length)

            print(f"Record {record_name}: Extracted {len(segments_array)} valid heartbeat segments")

            return segments_array, labels_array

        except Exception as e:
            print(f"Error processing record {record_name}: {e}")
            return None, None

    @staticmethod
    def _balance_indices(n, target, rng):
        """Return row indices that resize ``n`` beats to exactly ``target`` rows.

        If ``n > target``, it draws a random subset without replacement. If
        ``n < target``, it takes whole repeats of ``0..n-1`` and then random
        distinct extras. If ``n == target``, it returns None (keep the rows as-is).
        """
        if n > target:
            return rng.choice(n, target, replace=False)
        if n < target:
            repeat_times, remainder = divmod(target, n)
            indices = np.tile(np.arange(n), repeat_times)
            if remainder > 0:
                indices = np.concatenate([indices, rng.choice(n, remainder, replace=False)])
            return indices
        return None

    def create_patient_wise_dataset(self, beats_per_patient=1000, random_state=None):
        """Build one fixed-size beat array per record and save it.

        Args:
            beats_per_patient: Number of beats kept per record. Records with more
                beats are randomly subsampled; records with fewer are repeated.
            random_state: Optional seed for reproducible sampling. ``None`` uses
                the global ``np.random`` state.

        Returns:
            ``(patient_data, patient_labels)`` dicts keyed by ``"MIT_<record>"``.
            Data arrays have shape ``(beats_per_patient, segment_length, 1)``.
            Each array is also saved as ``features/feature_<id>.npy``, and
            ``labels/label.npy`` plus ``label_info.csv`` are written.
        """
        print("Creating patient-wise dataset...")

        rng = np.random if random_state is None else np.random.RandomState(random_state)
        features_dir = self._subdir('features')

        patient_data = {}
        patient_labels = {}

        for record in self._get_records():
            print(f"Processing record: {record}")

            segments, labels = self.extract_heartbeat_segments(record, lead=0)
            if segments is None or len(segments) == 0:
                continue

            patient_id = f"MIT_{record}"

            indices = self._balance_indices(len(segments), beats_per_patient, rng)
            if indices is not None:
                segments = segments[indices]
                labels = labels[indices]

            # Add a trailing channel axis: (beats, samples, 1).
            segments = segments.reshape(-1, segments.shape[-1], 1)

            patient_data[patient_id] = segments
            patient_labels[patient_id] = labels

            np.save(os.path.join(features_dir, f'feature_{patient_id}.npy'), segments)

        self.create_label_file(patient_labels)

        return patient_data, patient_labels

    def create_label_file(self, patient_labels):
        """Write one label per patient: the most frequent beat diagnosis.

        Args:
            patient_labels: Mapping ``"MIT_<record>"`` -> array of diagnosis
                strings. Patients with no labels are skipped.

        Writes ``labels/label.npy`` (int array of ``[label, record_number]``
        rows, shape ``(n_patients, 2)``) and ``labels/label_info.csv``.
        """
        label_data = []

        for patient_id, labels in patient_labels.items():
            if len(labels) == 0:
                continue
            unique, counts = np.unique(labels, return_counts=True)
            # Ties go to the alphabetically first class (np.unique sorts).
            primary_diagnosis = unique[np.argmax(counts)]
            label_num = self.DIAGNOSIS_MAP.get(primary_diagnosis, self.DIAGNOSIS_MAP['UNKNOWN'])

            record_num = int(patient_id.split('_')[1])

            label_data.append([label_num, record_num])

        # reshape keeps a 2-column layout even when there are no patients.
        label_array = np.array(label_data, dtype=int).reshape(-1, 2)
        labels_dir = self._subdir('labels')
        np.save(os.path.join(labels_dir, 'label.npy'), label_array)

        label_df = pd.DataFrame(label_array, columns=['label', 'patient_id'])
        diagnosis_map_reverse = {v: k for k, v in self.DIAGNOSIS_MAP.items()}
        label_df['diagnosis'] = label_df['label'].map(diagnosis_map_reverse)
        label_df.to_csv(os.path.join(labels_dir, 'label_info.csv'), index=False)

        print(f"Label distribution:\n{label_df['diagnosis'].value_counts()}")

    def create_10s_segments_dataset(self, segment_length=2500):
        """Split lead 0 of every record into consecutive, non-overlapping windows.

        Args:
            segment_length: Window length in samples after resampling. The
                default of 2500 is 10 s at ``TARGET_RATE`` (250 Hz * 10 s).

        Returns:
            ``(segment_data, segment_labels)`` dicts keyed by
            ``"MIT_<record>_seg_<i>"``. Data arrays have shape
            ``(1, segment_length, 1)``; labels hold the source ``"MIT_<record>"``
            id. Each window is also saved under ``features/``. A trailing partial
            window is dropped.
        """
        print("Creating 10-second segments dataset...")

        features_dir = self._subdir('features')
        segment_data = {}
        segment_labels = {}
        segment_count = 0

        for record in self._get_records():
            print(f"Processing record: {record}")

            try:
                signals, fields = wfdb.rdsamp(os.path.join(self.data_path, record))
                if signals.shape[1] < 1:
                    continue

                ecg_signal = signals[:, 0]
                if fields['fs'] != self.TARGET_RATE:
                    ecg_signal = self.resample_signal(ecg_signal, fields['fs'], self.TARGET_RATE)
                ecg_signal = self.normalize_signal(ecg_signal)

                for i in range(len(ecg_signal) // segment_length):
                    segment = ecg_signal[i * segment_length:(i + 1) * segment_length].reshape(1, -1, 1)
                    segment_id_str = f"MIT_{record}_seg_{i}"

                    segment_data[segment_id_str] = segment
                    segment_labels[segment_id_str] = f"MIT_{record}"

                    np.save(os.path.join(features_dir, f'feature_{segment_id_str}.npy'), segment)
                    segment_count += 1

            except Exception as e:
                print(f"Error processing record {record} for 10s segments: {e}")

        print(f"Created {segment_count} 10-second segments")
        return segment_data, segment_labels

    def generate_dataset_summary(self):
        """Write ``dataset_summary.csv`` describing the processing settings and return them as a dict."""
        summary = {
            # Size of the record list, not the number of records processed successfully.
            'total_records': len(self._get_records()),
            'sampling_rate': self.TARGET_RATE,
            # Default window length of extract_heartbeat_segments.
            'segment_length': 250,
            # A single lead (channel 0) is used. NOTE(review): confirm it is MLII
            # for every record, e.g. via fields['sig_name'].
            'leads_used': 1,
            'diagnosis_categories': list(self.DIAGNOSIS_MAP),
        }

        os.makedirs(self.output_path, exist_ok=True)
        pd.DataFrame([summary]).to_csv(os.path.join(self.output_path, 'dataset_summary.csv'), index=False)

        return summary

# In [ ]:
def main(mit_bih_path=r'C:\Users\28199\Desktop\mit-bih-arrhythmia-database-1.0.0',
         output_path=r'C:\Users\28199\Desktop\mit_bih_processed'):
    """Run the MIT-BIH preprocessing pipeline end to end.

    Args:
        mit_bih_path: Directory holding the MIT-BIH WFDB record files. The
            default keeps the original hard-coded location; pass your own path
            when running elsewhere.
        output_path: Directory that receives the processed features, labels
            and dataset summary.

    The optional fixed-length variant, ``create_10s_segments_dataset``, is not
    run here.
    """
    processor = MITBIHProcessor(mit_bih_path, output_path)

    records = processor.load_record_list()
    print(f"Loaded {len(records)} MIT-BIH records")

    annotations = processor.load_annotations()
    print("Annotations summary:")
    print(annotations.head())

    processor.create_patient_wise_dataset(beats_per_patient=1000)

    summary = processor.generate_dataset_summary()
    print("Dataset summary:")
    print(summary)

    print(f"Processing completed! Data saved to: {output_path}")

if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()