In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import wfdb
import pywt
from scipy.signal import butter, filtfilt
from torch.optim.lr_scheduler import StepLR

# Data loading and preprocessing functions
def load_all_ecg_data(base_path):
    """Load, clean and segment every ECG record found under ``base_path``.

    The directory is expected to contain sub-folders ``Person_01`` ..
    ``Person_89``.  For each ``.hea`` header in a folder, the matching record
    and its ``atr`` annotation are read with ``wfdb``; channel 0 is passed
    through ``remove_baseline_drift`` and ``bandpass_filter`` and then cut
    into windows around the annotated samples by ``normalize_and_segment``.

    Args:
        base_path: Root directory holding the ``Person_XX`` folders.

    Returns:
        A ``(segments, labels)`` tuple of NumPy arrays.  ``labels[i]`` is the
        zero-based subject index (``person_id - 1``) of ``segments[i]``.

    Raises:
        FileNotFoundError: If a ``Person_XX`` folder itself is missing
            (raised by ``os.listdir``).  Missing record or annotation files
            are reported on stdout and skipped instead.
    """
    all_segments = []
    all_labels = []

    for person_id in range(1, 90):  # Person_01 through Person_89
        person_path = os.path.join(base_path, f'Person_{person_id:02}')
        # os.listdir order is filesystem-dependent; sorting keeps the sample
        # order (and any seeded split the caller performs) reproducible.
        for filename in sorted(os.listdir(person_path)):
            if not filename.endswith('.hea') or filename.startswith('.'):
                continue

            # splitext only strips the final extension, so record names that
            # themselves contain dots are not truncated.
            record_name = os.path.splitext(filename)[0]
            record_path = os.path.join(person_path, record_name)

            # Only the two reads can raise FileNotFoundError; keep the try
            # narrow so unrelated failures in the processing are not masked.
            try:
                record = wfdb.rdrecord(record_path)
                annotation = wfdb.rdann(record_path, 'atr')
            except FileNotFoundError:
                print(f"File not found: {record_path}")
                continue

            ecg_signal = record.p_signal[:, 0]
            ecg_signal = remove_baseline_drift(ecg_signal)
            # Use the sampling rate stored in the record header rather than
            # the filter's default, so the cut-offs are correct in Hz.
            ecg_signal = bandpass_filter(ecg_signal, fs=record.fs)
            segments = normalize_and_segment(ecg_signal, annotation.sample)
            all_segments.extend(segments)
            all_labels.extend([person_id - 1] * len(segments))

    return np.array(all_segments), np.array(all_labels)

def remove_baseline_drift(signal):
    """Suppress baseline wander by zeroing the coarsest wavelet approximation.

    The signal is decomposed with a 9-level ``db6`` wavelet transform, the
    level-9 approximation coefficients are replaced by zeros, and the signal
    is reconstructed from the remaining detail coefficients.

    Args:
        signal: 1-D array-like of samples.

    Returns:
        The reconstructed signal, with the same number of samples as
        ``signal``.
    """
    coeffs = pywt.wavedec(signal, 'db6', level=9)
    coeffs[0] = np.zeros_like(coeffs[0])
    reconstructed = pywt.waverec(coeffs, 'db6')
    # waverec yields one extra trailing sample for odd-length input; trim it
    # so the output stays aligned sample-for-sample with the input.
    return reconstructed[:len(signal)]

def bandpass_filter(signal, low_freq=0.5, high_freq=40, fs=500, order=5):
    """Apply a zero-phase Butterworth band-pass filter to ``signal``.

    Args:
        signal: 1-D array-like of samples.
        low_freq: Lower cut-off frequency in Hz.
        high_freq: Upper cut-off frequency in Hz.
        fs: Sampling frequency of ``signal`` in Hz.
        order: Order of the Butterworth prototype.

    Returns:
        The filtered signal as a NumPy array with the same shape as
        ``signal``.

    Raises:
        ValueError: If the normalized cut-offs are invalid for ``fs`` or the
            signal is too short for the forward-backward edge padding.
    """
    # Imported locally because the file-level import only brings in
    # ``butter`` and ``filtfilt``.
    from scipy.signal import sosfiltfilt

    nyquist = 0.5 * fs
    low = low_freq / nyquist
    high = high_freq / nyquist
    # Second-order sections avoid the precision loss of the (b, a) polynomial
    # form, which matters here because the low edge sits very close to DC.
    sos = butter(order, [low, high], btype='band', output='sos')
    return sosfiltfilt(sos, signal)

def normalize_and_segment(signal, r_peaks, window_size=180):
    """Scale ``signal`` to [-1, 1] and cut fixed-length windows around peaks.

    The whole signal is min-max scaled first; then, for every index in
    ``r_peaks``, a window of exactly ``window_size`` samples starting
    ``window_size // 2`` samples before the peak is extracted.  Windows that
    would extend past either end of the signal are skipped.

    Args:
        signal: 1-D NumPy array of samples.
        r_peaks: Iterable of integer sample indices to centre windows on.
        window_size: Number of samples per window (even or odd).

    Returns:
        NumPy array of the extracted windows, one per row; an empty array
        when no window fits.
    """
    scaler = MinMaxScaler(feature_range=(-1, 1))
    signal_normalized = scaler.fit_transform(signal.reshape(-1, 1)).flatten()

    half_window = window_size // 2
    n_samples = len(signal_normalized)
    segments = []
    for r_peak in r_peaks:
        start = r_peak - half_window
        # Derive the end from the start so odd window sizes still produce
        # windows of exactly ``window_size`` samples instead of being dropped.
        end = start + window_size
        if start >= 0 and end <= n_samples:
            segments.append(signal_normalized[start:end])
    return np.array(segments)

# Complex CNN model definition
class ComplexCNNModel(nn.Module):
    """1-D CNN classifier for single-channel, fixed-length segments.

    Four convolution + ReLU + max-pool stages (32, 64, 128, 256 channels;
    each pool halves the length) feed a 256-unit fully connected layer with
    dropout, followed by a linear layer producing ``num_classes`` logits.

    Args:
        num_classes: Number of output logits.
        input_length: Number of samples per input segment.  The default of
            180 gives 11 time steps after the four pooling stages.

    Raises:
        ValueError: If ``input_length`` is too short to survive four
            halvings (fewer than 16 samples).
    """

    def __init__(self, num_classes, input_length=180):
        super(ComplexCNNModel, self).__init__()
        # The convolutions preserve length (stride 1, "same" padding), so only
        # the four stride-2 pools shrink it: length // 2 four times == // 16.
        pooled_length = input_length // 16
        if pooled_length < 1:
            raise ValueError(
                f"input_length must be at least 16, got {input_length}"
            )
        self.flat_features = 256 * pooled_length

        self.conv1 = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(self.flat_features, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return class logits for a batch shaped (batch, 1, input_length)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.pool(F.relu(self.conv4(x)))
        # Flatten per sample (keep the batch dimension fixed) so an input of
        # the wrong length fails in fc1 instead of silently being folded into
        # extra batch rows.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Load the dataset and prepare the DataLoaders
base_path = '/kaggle/input/ecgid-database'
segments, labels = load_all_ecg_data(base_path)

# Split the data and convert it to PyTorch tensors
segments_train, segments_test, labels_train, labels_test = train_test_split(segments, labels, test_size=0.2, random_state=42)
# unsqueeze(1) adds the single-channel dimension expected by Conv1d.
segments_train_tensor = torch.tensor(segments_train, dtype=torch.float32).unsqueeze(1)
segments_test_tensor = torch.tensor(segments_test, dtype=torch.float32).unsqueeze(1)
labels_train_tensor = torch.tensor(labels_train, dtype=torch.long)
labels_test_tensor = torch.tensor(labels_test, dtype=torch.long)

train_dataset = TensorDataset(segments_train_tensor, labels_train_tensor)
test_dataset = TensorDataset(segments_test_tensor, labels_test_tensor)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Training hyper-parameters
# Labels are zero-based subject indices (person_id - 1), so the output layer
# must cover the largest index even when some subject contributed no
# segments; counting unique labels would then leave targets out of range
# for CrossEntropyLoss.
num_classes = int(labels.max()) + 1
model = ComplexCNNModel(num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# Training and evaluation loop
def train_and_evaluate(model, train_loader, test_loader, optimizer, criterion, scheduler, num_epochs=50):
    """Train ``model`` and report test accuracy after every epoch.

    Each epoch runs one optimisation pass over ``train_loader``, steps the
    learning-rate scheduler, prints the mean batch loss, then evaluates on
    ``test_loader`` and prints the accuracy.  Training stops early once the
    test accuracy reaches 90%.

    Args:
        model: The ``nn.Module`` to train (updated in place).
        train_loader: Iterable of ``(data, labels)`` training batches.
        test_loader: Iterable of ``(data, labels)`` evaluation batches.
        optimizer: Optimizer bound to ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, labels)``.
        scheduler: Learning-rate scheduler stepped once per epoch.
        num_epochs: Maximum number of epochs to run.
    """
    for epoch in range(num_epochs):
        # Switch back to training mode every epoch: the evaluation pass below
        # leaves the model in eval mode, which would otherwise keep dropout
        # disabled for every epoch after the first.
        model.train()
        total_loss = 0
        for data, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        print(f'Epoch {epoch+1}, Average Training Loss: {total_loss / len(train_loader)}')

        # Evaluation phase
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, labels in test_loader:
                predicted = model(data).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print(f'Accuracy on Test Set: {accuracy}%')
        if accuracy >= 90:
            print("Reached target accuracy. Stopping training.")
            break

# Train the model, evaluating on the test set after every epoch
train_and_evaluate(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
)
