# In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import numpy as np
import wfdb
import pywt
from scipy.signal import butter, filtfilt

# ECG data loading and preprocessing
def load_and_preprocess_ecg_data(base_path, num_persons=90):
    """Load every WFDB record under *base_path* and return beat segments with labels.

    Records are looked up in ``Person_NN`` sub-directories (``Person_01`` up to
    ``Person_{num_persons:02}``); directories that do not exist are skipped.
    For each ``.hea`` header found, the record's first signal channel is
    baseline-corrected, band-pass filtered at the record's own sampling rate,
    normalised and cut into windows around the ``atr`` annotation positions.

    Args:
        base_path: Root directory containing the ``Person_NN`` folders.
        num_persons: Highest person index to scan (inclusive). The ECG-ID
            database has 90 subjects, so the default covers all of them.

    Returns:
        ``(segments, labels)`` as NumPy arrays. ``labels[i]`` is the
        zero-based person index (``person_id - 1``) of ``segments[i]``.
    """
    all_segments = []
    all_labels = []

    for person_id in range(1, num_persons + 1):
        person_path = os.path.join(base_path, f'Person_{person_id:02}')
        if not os.path.exists(person_path):
            print(f"Directory {person_path} does not exist. Skipping.")
            continue
        # Sorted so the sample order (and therefore the seeded train/test
        # split downstream) does not depend on filesystem listing order.
        for filename in sorted(os.listdir(person_path)):
            if not filename.endswith('.hea') or filename.startswith('.'):
                continue
            # splitext only strips the extension, keeping dots inside the name.
            record_name = os.path.splitext(filename)[0]
            record_path = os.path.join(person_path, record_name)
            try:
                segments = _load_record_segments(record_path)
            except Exception as e:
                # Best-effort loading: one unreadable record must not abort the run.
                print(f"Error processing {record_path}: {e}")
                continue
            all_segments.extend(segments)
            all_labels.extend([person_id - 1] * len(segments))

    return np.array(all_segments), np.array(all_labels)


def _load_record_segments(record_path):
    """Read one WFDB record plus its ``atr`` annotations and return preprocessed segments."""
    record = wfdb.rdrecord(record_path)
    annotation = wfdb.rdann(record_path, 'atr')
    ecg_signal = record.p_signal[:, 0]
    ecg_signal = remove_baseline_drift(ecg_signal)
    # Filter at the record's actual sampling rate; fall back to 500 Hz if absent.
    fs = record.fs if record.fs else 500
    ecg_signal = bandpass_filter(ecg_signal, fs=fs)
    return normalize_and_segment(ecg_signal, annotation.sample)

def remove_baseline_drift(signal, wavelet='db6', level=9):
    """Suppress slow baseline wander using a multilevel wavelet decomposition.

    The signal is decomposed ``level`` times with ``wavelet``. The coarsest
    approximation coefficients are zeroed, and the signal is reconstructed
    from the remaining detail coefficients.

    Args:
        signal: 1-D sequence of samples.
        wavelet: PyWavelets wavelet name (default ``'db6'``).
        level: Decomposition depth (default 9).

    Returns:
        The reconstructed signal, exactly ``len(signal)`` samples long.
    """
    coeffs = pywt.wavedec(signal, wavelet, level=level)
    coeffs[0] = np.zeros_like(coeffs[0])
    # waverec returns one extra sample for odd-length input; trim it so the
    # output stays index-aligned with the original (and with annotation sample
    # positions used for segmentation).
    return pywt.waverec(coeffs, wavelet)[:len(signal)]

def bandpass_filter(signal, low_freq=0.5, high_freq=40, fs=500, order=5):
    """Apply a zero-phase Butterworth band-pass filter.

    Args:
        signal: 1-D array of samples.
        low_freq: Lower cut-off frequency in Hz.
        high_freq: Upper cut-off frequency in Hz.
        fs: Sampling rate in Hz, used to normalise the cut-offs.
        order: Butterworth filter order.

    Returns:
        The filtered signal with the same length as the input. ``filtfilt``
        runs the filter forward and backward, so no phase shift is introduced.
    """
    nyq = fs / 2
    normalized_band = (low_freq / nyq, high_freq / nyq)
    num, den = butter(order, normalized_band, btype='band')
    return filtfilt(num, den, signal)

def normalize_and_segment(signal, r_peaks, window_size=256):
    """Scale *signal* to [-1, 1] and cut fixed-length windows around each peak.

    Args:
        signal: 1-D array of samples.
        r_peaks: Sample indices used as window anchors.
        window_size: Samples per window. Each window starts
            ``window_size // 2`` samples before its anchor, so odd sizes work too.

    Returns:
        Array of shape ``(num_windows, window_size)``. Anchors whose window
        would run past either end of the signal are dropped. When no window
        fits, an empty ``(0, window_size)`` array is returned.
    """
    scaler = MinMaxScaler(feature_range=(-1, 1))
    signal_normalized = scaler.fit_transform(np.asarray(signal).reshape(-1, 1)).flatten()
    half_window = window_size // 2
    n_samples = len(signal_normalized)

    segments = []
    for r_peak in r_peaks:
        start = int(r_peak) - half_window
        # Derive the end from window_size (not from +half_window) so odd sizes
        # still produce full-length windows.
        end = start + window_size
        if start >= 0 and end <= n_samples:
            segments.append(signal_normalized[start:end])

    if not segments:
        # Keep the result 2-D even when empty so callers can rely on its shape.
        return np.empty((0, window_size))
    return np.array(segments)

# Transformer encoder module
class TransformerEncoder(nn.Module):
    """A stack of ``nn.TransformerEncoderLayer`` blocks with batch-first input.

    Args:
        feature_size: Model width (``d_model``). It must be divisible by ``num_heads``.
        num_heads: Number of attention heads per layer.
        num_layers: Number of stacked encoder layers.
        dropout_rate: Dropout probability inside each layer.

    Input and output tensors are ``(batch, seq_len, feature_size)`` because
    ``batch_first=True``.
    """

    def __init__(self, feature_size, num_heads, num_layers, dropout_rate):
        super().__init__()
        layer_template = nn.TransformerEncoderLayer(
            d_model=feature_size,
            nhead=num_heads,
            dropout=dropout_rate,
            batch_first=True,  # batch dimension comes first
        )
        # nn.TransformerEncoder clones the template num_layers times.
        self.transformer_encoder = nn.TransformerEncoder(
            layer_template, num_layers=num_layers
        )

    def forward(self, src):
        """Encode *src* and return a tensor of the same shape."""
        return self.transformer_encoder(src)

# Hybrid CNN-GRU-Transformer model definition
class HybridCNNGRUTransformer(nn.Module):
    """Conv1d feature extractor, then a bidirectional GRU, then a Transformer encoder and a linear head.

    Expects input of shape ``(batch, 1, length)`` (single-channel ``Conv1d``)
    and returns logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Construction order determines how parameters consume the RNG at
        # init time; keep it stable for reproducibility.
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=64, kernel_size=5, stride=1, padding=2)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.gru = nn.GRU(
            input_size=64,
            hidden_size=100,
            num_layers=2,
            batch_first=True,
            dropout=0.5,
            bidirectional=True,
        )
        # Bidirectional GRU output width is 2 * 100 = 200, hence feature_size=200.
        self.transformer = TransformerEncoder(
            feature_size=200, num_heads=4, num_layers=2, dropout_rate=0.1
        )
        self.fc = nn.Linear(in_features=200, out_features=num_classes)

    def forward(self, x):
        features = self.pool(F.relu(self.conv1(x)))
        # (batch, channels, time) -> (batch, time, channels) for the batch-first GRU
        sequence = features.transpose(1, 2)
        sequence, _hidden = self.gru(sequence)
        encoded = self.transformer(sequence)
        # NOTE(review): the encoder is non-causal, so the last position is not
        # a natural sequence summary; mean pooling may be worth evaluating.
        last_step = encoded[:, -1, :]
        return self.fc(last_step)

# Data preparation and DataLoader setup
base_path = '/kaggle/input/ecgid-database'
segments, labels = load_and_preprocess_ecg_data(base_path)
if len(segments) == 0 or len(labels) == 0:
    raise ValueError("No data found. Please check the data path and files.")

segments_train, segments_test, labels_train, labels_test = train_test_split(
    segments, labels, test_size=0.2, random_state=42
)

# unsqueeze(1) adds the single input channel expected by the model's Conv1d.
segments_train_tensor = torch.tensor(segments_train, dtype=torch.float32).unsqueeze(1)
segments_test_tensor = torch.tensor(segments_test, dtype=torch.float32).unsqueeze(1)
labels_train_tensor = torch.tensor(labels_train, dtype=torch.long)
labels_test_tensor = torch.tensor(labels_test, dtype=torch.long)

train_dataset = TensorDataset(segments_train_tensor, labels_train_tensor)
test_dataset = TensorDataset(segments_test_tensor, labels_test_tensor)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Training parameters
# Labels are zero-based person indices (person_id - 1). If some person yields no
# segments, the indices are not contiguous, and counting distinct labels would
# make the output layer too small (CrossEntropyLoss would then see out-of-range
# targets). Size it by the largest label instead.
num_classes = int(labels.max()) + 1
model = HybridCNNGRUTransformer(num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# Training and evaluation loop
def train_and_evaluate(model, train_loader, test_loader, optimizer, criterion, scheduler, num_epochs=100):
    """Train *model* for ``num_epochs`` epochs, evaluating on *test_loader* after each one.

    Each epoch makes one optimisation pass over *train_loader*, steps
    *scheduler* once, and prints the mean training loss. It then prints the
    arg-max classification accuracy on *test_loader*.

    Batches are moved to the device of the model's first parameter (CPU if
    the model has none), so the model may be placed on a GPU beforehand.

    Returns:
        Test accuracy in percent from the final epoch, or ``None`` if no
        epoch ran or the test loader yielded no samples.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    accuracy = None

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for data, targets in train_loader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        scheduler.step()

        # Avoid ZeroDivisionError when the training loader is empty.
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        print(f'Epoch {epoch+1}, Average Training Loss: {avg_loss}')

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, targets in test_loader:
                data, targets = data.to(device), targets.to(device)
                predicted = model(data).argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
        if total:
            accuracy = 100 * correct / total
            print(f'Accuracy on Test Set: {accuracy}%')
        else:
            accuracy = None
            print('Accuracy on Test Set: no test samples')

    return accuracy

# Run training, evaluating on the held-out split after every epoch
train_and_evaluate(
    model,
    train_loader,
    test_loader,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    num_epochs=100,
)
