In [11]:
"""
Written by, 
Sriram Ravindran, sriram@ucsd.edu

Original paper - https://arxiv.org/abs/1611.08024

Please reach out to me if you spot an error.
"""

'\nWritten by, \nSriram Ravindran, sriram@ucsd.edu\n\nOriginal paper - https://arxiv.org/abs/1611.08024\n\nPlease reach out to me if you spot an error.\n'

In [12]:
# %pip install scikit-learn numpy scipy matplotlib seaborn pandas
# %pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
from sklearn.metrics import confusion_matrix, balanced_accuracy_score, roc_auc_score, precision_score, recall_score, accuracy_score
from torch.autograd import Variable
import copy
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd 
import pickle
import random
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib
matplotlib.use('Agg') # [중요] GUI 백엔드가 필요없는 환경에서 사용 (예: 서버)
import matplotlib.pyplot as plt
import gc # 가비지 컬렉션 모듈 추가

<p>Here's the description from the paper</p>
<img src="EEGNet.png" style="width: 700px; float:left;">

In [None]:


class EEGNet(nn.Module):
    """EEGNet-style CNN for EEG epoch classification (see arXiv:1611.08024).

    Args:
        n_channels: number of EEG electrodes (last input axis).
        n_timepoints: number of samples per epoch (third input axis).
        n_classes: number of output logits.

    Input:  float tensor of shape (batch, 1, n_timepoints, n_channels).
    Output: raw logits of shape (batch, n_classes); softmax is left to
            ``nn.CrossEntropyLoss`` or the caller.
    """

    def __init__(self, n_channels=53, n_timepoints=448, n_classes=6):
        super(EEGNet, self).__init__()
        self.T = n_timepoints
        self.C = n_channels
        self.n_classes = n_classes

        # NOTE: earlier versions used ``nn.BatchNorm2d(k, False)``. The second
        # positional argument of BatchNorm2d is ``eps``, so that set eps=0.0 and
        # risked division by zero for zero-variance channels (most of every
        # input in this notebook is zero-padded). The default eps is used now;
        # ``affine`` was and still is True, so state_dict keys are unchanged and
        # previously saved checkpoints still load.

        # [Layer 1] Spatial conv: mix all n_channels electrodes into 16 maps.
        # (B, 1, T, C) -> (B, 16, T, 1)
        self.conv1 = nn.Conv2d(1, 16, (1, n_channels), padding=0)
        self.batchnorm1 = nn.BatchNorm2d(16)

        # [Layer 2] After the permute in _features the 16 maps form the height
        # axis; the (2, 32) kernel spans 2 adjacent maps x 32 time samples.
        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))
        self.conv2 = nn.Conv2d(1, 4, (2, 32))
        self.batchnorm2 = nn.BatchNorm2d(4)
        self.pooling2 = nn.MaxPool2d(2, 4)  # kernel 2, stride 4

        # [Layer 3] Regular 4 -> 4 channel conv (groups=1, i.e. not depthwise).
        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))
        self.conv3 = nn.Conv2d(4, 4, (8, 4))
        self.batchnorm3 = nn.BatchNorm2d(4)
        self.pooling3 = nn.MaxPool2d((2, 4))

        # Infer the flattened feature size from a dry run through the real layers,
        # so it can never drift from the layer definitions above.
        self._calculate_fc_input_size(n_channels, n_timepoints)

        # Classifier head: one logit per class.
        self.fc1 = nn.Linear(self.fc_input_size, n_classes)

    def _calculate_fc_input_size(self, n_channels, n_timepoints):
        """Set ``self.fc_input_size`` by passing a zero input through ``_features``."""
        was_training = self.training
        # Eval mode: dropout off and BatchNorm running statistics left untouched.
        self.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, 1, n_timepoints, n_channels)
                self.fc_input_size = self._features(dummy).numel()
        finally:
            self.train(was_training)

    def _features(self, x):
        """Apply the three conv stages; returns the un-flattened feature map."""
        # Layer 1: spatial learning
        x = F.elu(self.conv1(x))
        x = self.batchnorm1(x)
        x = F.dropout(x, 0.25, training=self.training)
        x = x.permute(0, 3, 1, 2)  # (B, 16, T, 1) -> (B, 1, 16, T)

        # Layer 2: temporal learning
        x = self.padding1(x)
        x = F.elu(self.conv2(x))
        x = self.batchnorm2(x)
        x = F.dropout(x, 0.25, training=self.training)
        x = self.pooling2(x)

        # Layer 3: higher-level features
        x = self.padding2(x)
        x = F.elu(self.conv3(x))
        x = self.batchnorm3(x)
        x = F.dropout(x, 0.25, training=self.training)
        x = self.pooling3(x)
        return x

    def forward(self, x):
        """Return class logits of shape (batch, n_classes)."""
        x = self._features(x)
        x = x.reshape(x.size(0), -1)
        return self.fc1(x)  # softmax is applied by CrossEntropyLoss

# Build a model (53 channels, 448 timepoints, 6 output classes) and check shapes.
net = EEGNet(n_channels=53, n_timepoints=448, n_classes=6).cuda(0)
print(f"FC input size: {net.fc_input_size}")

# Smoke test in eval mode under no_grad, so the dummy pass neither builds an
# autograd graph nor updates BatchNorm running statistics; restore train mode.
test_input = torch.randn(1, 1, 448, 53).cuda(0)
net.eval()
with torch.no_grad():
    test_output = net(test_input)
net.train()
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {test_output.shape}")
assert test_output.shape == (1, net.n_classes), "unexpected EEGNet output shape"

# Loss: CrossEntropyLoss over the 6 output classes (expects raw logits + int labels)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters())

FC input size: 224
Input shape: torch.Size([1, 1, 448, 53])
Output shape: torch.Size([1, 5])


#### Evaluate function returns values of different criteria like accuracy, precision etc. 
In case you face memory overflow issues, lower the DataLoader's `batch_size` to control how many samples are evaluated at one time. The DataLoader also yields the final partial batch (`drop_last=False` by default), so no samples are skipped regardless of the batch size.

In [15]:
def _ovr_macro_auc(y_true, y_prob):
    """Macro one-vs-rest ROC AUC over the classes that occur in ``y_true``.

    ``roc_auc_score(..., multi_class='ovr')`` raises ValueError whenever the
    number of distinct labels differs from the number of probability columns
    (e.g. a 6-output model evaluated on data lacking one class). When every
    class is present this gives the same value: the unweighted mean of
    per-class binary AUCs. Classes whose samples are all positive are skipped,
    because AUC is undefined for them. Returns NaN if no class qualifies.

    Raises:
        ValueError: if a label is not a valid column index of ``y_prob``.
    """
    n_cols = y_prob.shape[1] if y_prob.ndim == 2 else 0
    class_aucs = []
    for cls in np.unique(y_true):
        if not 0 <= cls < n_cols:
            raise ValueError(f"label {cls} has no probability column")
        is_pos = (y_true == cls)
        if is_pos.all():
            continue
        class_aucs.append(roc_auc_score(is_pos.astype(int), y_prob[:, cls]))
    return float(np.mean(class_aucs)) if class_aucs else float("nan")


def evaluate(model, data_loader, params=("acc",)):
    """Evaluate a multi-class classifier over every batch of ``data_loader``.

    Args:
        model: module mapping an input batch to logits of shape (batch, n_classes).
        data_loader: iterable of ``(inputs, labels, mask)`` batches; ``mask`` is ignored.
        params: metric names, computed in order. Supported: "acc", "auc",
            "recall", "precision", "fmeasure"; unknown names are ignored.

    Returns:
        ``(results, all_preds, all_labels)``. ``results`` has one float per
        recognised name in ``params``. ``all_preds`` and ``all_labels`` are 1-D
        numpy arrays over all evaluated samples. "auc" is NaN when undefined
        (this used to be a silent 0.0).

    The model runs in eval mode and its previous train/eval mode is restored.
    """
    was_training = model.training
    model.eval()
    # Send inputs to wherever the model lives instead of hard-coding cuda:0.
    device = next((p.device for p in model.parameters()), torch.device("cpu"))

    pred_batches, label_batches, prob_batches = [], [], []
    try:
        with torch.no_grad():
            for inputs, labels, _mask in data_loader:
                outputs = model(inputs.to(device))
                prob_batches.append(F.softmax(outputs, dim=1).cpu().numpy())
                pred_batches.append(torch.argmax(outputs, dim=1).cpu().numpy())
                label_batches.append(np.asarray(labels))
    finally:
        model.train(was_training)

    all_preds = np.concatenate(pred_batches) if pred_batches else np.array([], dtype=np.int64)
    all_labels = np.concatenate(label_batches) if label_batches else np.array([], dtype=np.int64)
    all_probs = np.concatenate(prob_batches) if prob_batches else np.empty((0, 0))

    # zero_division=0 yields the same value sklearn's default "warn" does, minus the warning.
    macro = dict(average="macro", zero_division=0)
    results = []
    for param in params:
        if param == "acc":
            results.append(accuracy_score(all_labels, all_preds))
        elif param == "auc":
            try:
                results.append(_ovr_macro_auc(all_labels, all_probs))
            except ValueError:
                results.append(float("nan"))
        elif param == "recall":
            results.append(recall_score(all_labels, all_preds, **macro))
        elif param == "precision":
            results.append(precision_score(all_labels, all_preds, **macro))
        elif param == "fmeasure":
            # Harmonic mean of macro precision and macro recall
            # (not the macro average of per-class F1 scores).
            precision = precision_score(all_labels, all_preds, **macro)
            recall = recall_score(all_labels, all_preds, **macro)
            if precision + recall > 0:
                results.append(2 * precision * recall / (precision + recall))
            else:
                results.append(0.0)

    return results, all_preds, all_labels


def plot_confusion_matrix(y_true, y_pred, class_names, save_path, title, labels=None):
    """Plot a row-normalised confusion matrix (per-class recall) and save it.

    Args:
        y_true, y_pred: 1-D label arrays.
        class_names: tick labels, one per entry of ``labels`` (same order).
        save_path: output image path.
        title: figure title.
        labels: label values defining the matrix rows/columns. Defaults to the
            sorted union of ``y_true`` and ``y_pred`` (sklearn's default). If the
            number of labels differs from ``len(class_names)``, the label values
            are used as tick labels instead of mislabelling the axes or raising
            IndexError in the count loop.

    Returns:
        The row-normalised matrix; rows with no true samples are all zeros.
    """
    if labels is None:
        labels = np.unique(np.concatenate([np.asarray(y_true), np.asarray(y_pred)]))
    labels = list(labels)
    n_labels = len(labels)

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    tick_names = list(class_names) if len(class_names) == n_labels else [str(l) for l in labels]

    # Row-wise normalisation: fraction of each true class predicted as each class.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(cm.astype(float), row_sums,
                              out=np.zeros(cm.shape, dtype=float), where=row_sums > 0)

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
                xticklabels=tick_names, yticklabels=tick_names,
                ax=ax, vmin=0, vmax=1, cbar_kws={'label': 'Accuracy'})

    # Raw counts, drawn just below each percentage annotation.
    for i in range(n_labels):
        for j in range(n_labels):
            ax.text(j + 0.5, i + 0.75, f'({cm[i, j]})',
                    ha='center', va='center', fontsize=8, color='gray')

    ax.set_xlabel('Predicted Label', fontsize=12)
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_title(title, fontsize=14)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)  # close this figure only, not every open figure

    return cm_normalized


def get_patch_time_ms(patch_idx, time_bin, sampling_rate, start_time_ms=-200):
    """Return the onset time in milliseconds of patch ``patch_idx``.

    Each patch covers ``time_bin`` samples at ``sampling_rate`` Hz. Patch 0
    begins at ``start_time_ms`` (-200 ms unless overridden).
    """
    ms_per_patch = time_bin / sampling_rate * 1000
    return start_time_ms + patch_idx * ms_per_patch

# DATA LOAD (load my files)

In [16]:
# Data/model settings shared by the cells below.
MODEL_INPUT_LEN = 448  # samples per EEGNet input window (zero-padded outside the patch)
time_bin = 32          # samples per patch -> MODEL_INPUT_LEN // time_bin patches
batch_size = 64        # DataLoader batch size


#### Generate random data (optional demo, commented out; the real data is loaded in the Run section)

##### Data format:
Datatype - float32 for X; integer class indices for Y (required by `nn.CrossEntropyLoss`) <br>
X.shape - (#samples, 1, #timepoints,  #channels) <br>
Y.shape - (#samples)

In [17]:
# X_train = np.random.rand(100, 1, 120, 64).astype('float32') # np.random.rand generates between [0, 1)
# y_train = np.round(np.random.rand(100).astype('float32')) # binary data, so we round it to 0 or 1.

# X_val = np.random.rand(100, 1, 120, 64).astype('float32')
# y_val = np.round(np.random.rand(100).astype('float32'))

# X_test = np.random.rand(100, 1, 120, 64).astype('float32')
# y_test = np.round(np.random.rand(100).astype('float32'))

#### Run

In [None]:


class COMBLoader(torch.utils.data.Dataset):
    """Dataset of pickled EEG epochs that returns a zero-masked window of each signal.

    Each file in ``root`` is a pickle of a dict with keys ``"signal"`` (array of
    shape (n_channels, data_len)) and ``"label"`` (number; ``label - 1`` is
    returned as the class index).

    SECURITY NOTE: ``pickle.load`` can execute arbitrary code; only point this
    at trusted, locally produced files.

    Args:
        root: directory containing the pickle files.
        files: file names relative to ``root``.
        sampling_rate: only used to derive the default ``time_bin`` (50 ms).
        is_train: stored; not used by this class.
        is_augment: if True, draw a random window of whole patches per sample.
        patch_idx: if given (and not augmenting), keep only this patch.
        time_bin: samples per patch; defaults to ``int(sampling_rate * 0.05)``.
        model_input_len: length of the returned time axis; defaults to the
            notebook-level ``MODEL_INPUT_LEN``.

    ``__getitem__`` returns ``(x, y, mask)``:
        - ``x``: float32 of shape (1, model_input_len, n_channels), zero outside
          the selected window.
        - ``y``: int class index.
        - ``mask``: float32 of shape (model_input_len,), 1.0 where data was copied.
    """

    def __init__(self, root, files, sampling_rate=256, is_train=True, is_augment=False,
                 patch_idx=None, time_bin=None, model_input_len=None):
        self.root = root
        self.files = files
        self.sampling_rate = sampling_rate
        self.is_train = is_train
        self.is_augment = is_augment
        self.patch_idx = patch_idx

        # The global is only looked up when no explicit length is given.
        self.model_input_len = MODEL_INPUT_LEN if model_input_len is None else model_input_len
        if time_bin is None:
            self.time_bin = int(self.sampling_rate * 0.05)
        else:
            self.time_bin = time_bin

        self.total_blocks = self.model_input_len // self.time_bin

    def __len__(self):
        return len(self.files)

    def _load(self, index):
        """Unpickle sample ``index``; the ``with`` block closes the file handle."""
        with open(os.path.join(self.root, self.files[index]), "rb") as fh:
            return pickle.load(fh)

    def _choose_window(self, data_len):
        """Return ``(start_patch_idx, end_patch_idx)`` with the end exclusive."""
        time_bin = self.time_bin
        if self.is_augment and data_len > time_bin:
            # Only sample patches that fit in both the signal and the model window,
            # so an augmented window is never cut off entirely.
            n_usable = max(1, min(data_len, self.model_input_len) // time_bin)
            if random.random() < 0.5:
                # Cumulative window: always starts at patch 0.
                return 0, random.randint(1, n_usable)
            # Random window; start < end guarantees at least one patch of data
            # (randint is inclusive, so the old randint(0, end) could yield an empty window).
            end_patch_idx = random.randint(1, n_usable)
            return random.randint(0, end_patch_idx - 1), end_patch_idx
        if self.patch_idx is not None:
            # Single requested patch (end index is exclusive).
            return self.patch_idx, self.patch_idx + 1
        # Full signal (clipped to the model window below).
        return 0, data_len // time_bin

    def __getitem__(self, index):
        sample = self._load(index)
        X = sample["signal"]  # (n_channels, data_len)
        # Class index = stored label - 1 (original note: [2..6] -> [0..4]).
        # NOTE(review): EEGNet here has 6 outputs and later cells drop class 0;
        # confirm the stored label range.
        Y = int(sample["label"] - 1)

        n_channels = X.shape[0]
        data_len = X.shape[-1]
        model_input_len = self.model_input_len
        time_bin = self.time_bin

        # Per-channel z-score over the whole recording.
        mean = np.mean(X, axis=-1, keepdims=True)
        std = np.std(X, axis=-1, keepdims=True) + 1e-6
        X = (X - mean) / std

        start_patch_idx, end_patch_idx = self._choose_window(data_len)

        # Everything outside the selected window stays zero.
        input_tensor = torch.zeros((n_channels, model_input_len), dtype=torch.float32)
        mask = torch.zeros(model_input_len, dtype=torch.float32)

        start_t_index = start_patch_idx * time_bin
        end_t_index = min(end_patch_idx * time_bin, data_len, model_input_len)
        if end_t_index > start_t_index:
            window = slice(start_t_index, end_t_index)
            input_tensor[:, window] = torch.from_numpy(X[:, window].astype(np.float32))
            mask[window] = 1.0

        # (C, T) -> (1, T, C): the per-sample layout EEGNet expects.
        input_tensor = input_tensor.transpose(0, 1).unsqueeze(0)

        return input_tensor, Y, mask

In [None]:
root_dir = "/local_raid3/03_user/myyu/EEGPT/downstream_combine3/PreprocessedEEG"
train_root = os.path.join(root_dir, "processed_train")
test_root = os.path.join(root_dir, "processed_test")
# os.listdir returns entries in arbitrary order; sorting makes the (unshuffled)
# test order and therefore the saved predictions reproducible.
train_files = sorted(os.listdir(train_root))
test_files = sorted(os.listdir(test_root))

print("Loading train dataset...")
# Training dataset: full window, augmentation disabled (is_augment=False)
train_dataset = COMBLoader(
    train_root, train_files,
    is_train=True, is_augment=False, time_bin=time_bin
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, num_workers=8, shuffle=True
)

print("\nLoading test dataset...")
# Test dataset: full window (patch_idx=None)
test_dataset = COMBLoader(
    test_root, test_files,
    is_train=False, time_bin=time_bin
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, num_workers=8, shuffle=False
)
# The training-loop cell selects checkpoints on the full-window test loader
# under this name.
test_loader_all = test_loader

print(f"\nTrain samples: {len(train_dataset)}, Test samples: {len(test_dataset)}")

# Check the shape and label of one sample
sample_input, sample_label, sample_mask = train_dataset[0]
print(f"Sample input shape: {sample_input.shape}")  # (1, 448, 53) expected
print(f"Sample label (model input, 0-4): {sample_label}")
print(f"Sample label (display, 1-5): {sample_label + 1}")
print(f"Sample mask shape: {sample_mask.shape}")

In [None]:
root_dir = "/local_raid3/03_user/myyu/EEGPT/downstream_combine3/PreprocessedEEG"
train_files = os.listdir(os.path.join(root_dir, "processed_train"))
test_files = os.listdir(os.path.join(root_dir, "processed_test"))

# ---------------- Experiment configuration ----------------
num_epochs = 100
num_patches = MODEL_INPUT_LEN // time_bin  # 448 // 32 = 14 patches
sampling_rate = 500  # Hz; only used to convert patch indices to milliseconds
first_train_patch = 6  # training patches [first_train_patch, num_patches) are run

patience = 10  # early stopping: epochs without test-accuracy improvement

# Output directory
save_dir = "/local_raid3/03_user/myyu/EEGNet/logs"
os.makedirs(save_dir, exist_ok=True)

# Class names for display (labels 0-5)
class_names = ['0', '1', '2', '3', '4', '5']

# Full-window test loader used for per-epoch checkpoint selection. It is built
# in this cell so the cell also works on a fresh kernel.
# NOTE(review): checkpoint selection and early stopping use the TEST set, which
# leaks test information into model selection; consider a validation split.
test_loader_all = torch.utils.data.DataLoader(
    COMBLoader(os.path.join(root_dir, "processed_test"), test_files,
               is_train=False, time_bin=time_bin),
    batch_size=batch_size, num_workers=8, shuffle=False
)

# Train one fresh model per training patch, then evaluate it on every test patch.
for patch_idx in range(first_train_patch, num_patches):
    # Results for this training patch (one row per test patch); re-created each iteration.
    results = {
        'train_patch_idx': [],
        'time_ms': [],
        'train_loss': [],
        'train_acc': [],
        'test_patch_idx': [],
        'test_acc': [],
        'test_balanced_acc': [],
        'test_auc': [],
        'test_fmeasure': []
    }
    # Onset time of the training patch (ms)
    patch_time_ms = get_patch_time_ms(patch_idx, time_bin, sampling_rate, start_time_ms=-200)

    print(f"\n{'='*60}")
    print(f"Training with Patch {patch_idx} (Time: {patch_time_ms:.0f}ms, samples {patch_idx*time_bin}-{(patch_idx+1)*time_bin})")
    print(f"{'='*60}")

    train_dir = os.path.join(save_dir, f"patch_{patch_idx}")
    cm_dir = os.path.join(train_dir, "cm")
    os.makedirs(cm_dir, exist_ok=True)  # also creates train_dir

    # Training data restricted to this patch
    train_dataset_patch = COMBLoader(
        os.path.join(root_dir, "processed_train"), train_files,
        is_train=True, is_augment=False, time_bin=time_bin, patch_idx=patch_idx
    )
    train_loader_patch = torch.utils.data.DataLoader(
        train_dataset_patch, batch_size=batch_size, num_workers=8, shuffle=True, persistent_workers=False
    )

    # Fresh model for every training patch
    net = EEGNet(n_channels=53, n_timepoints=448, n_classes=6).cuda(0)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters())

    best_acc = 0.0
    best_model_path = None
    # Reset per patch: previously uninitialised, so it either raised NameError
    # (no improvement in epoch 1) or carried over from the previous patch.
    early_stop_counter = 0

    # Training
    for epoch in range(num_epochs):
        net.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels, mask in train_loader_patch:
            inputs = inputs.cuda(0)
            labels = labels.cuda(0)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.detach(), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_acc = 100 * correct / total
        avg_loss = running_loss / len(train_loader_patch)

        # Evaluate on the full-window test set and keep the best checkpoint
        test_results, _, _ = evaluate(net, test_loader_all, ["acc"])
        current_acc = test_results[0] * 100

        if current_acc > best_acc:
            best_acc = current_acc
            early_stop_counter = 0  # improvement: reset the early-stopping counter

            file_name = f"EEGNet_F1_P{patch_idx}_CAll-epoch={epoch+1}-valid_acc={current_acc:.2f}.pth"
            best_model_path = os.path.join(save_dir, file_name)
            torch.save(net.state_dict(), best_model_path)

            print(f"  [Epoch {epoch+1}] Best Acc Updated: {current_acc:.2f}% (Loss: {avg_loss:.4f}) - Saved")

        else:
            early_stop_counter += 1
            if (epoch + 1) % 10 == 0:
                print(f"  Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}, Valid Acc(1-5): {current_acc:.2f}%")

            # Stop once accuracy has not improved for `patience` consecutive epochs
            if early_stop_counter >= patience:
                print(f"\n[Early Stopping] Triggered at epoch {epoch+1}. No improvement for {patience} epochs.")
                break
    # -----------------------------------------------------
    # After training: load the best checkpoint and evaluate every test patch (save CMs)
    # -----------------------------------------------------
    if best_model_path:
        torch.cuda.empty_cache()
        net.load_state_dict(torch.load(best_model_path))
        print(f"Loaded Best Model: {os.path.basename(best_model_path)}")

    for test_patch_idx in range(num_patches):
        print(f"  Evaluating on Test Patch {test_patch_idx}...")
        # Test data restricted to this patch
        test_dataset_patch = COMBLoader(
            os.path.join(root_dir, "processed_test"), test_files,
            is_train=False, time_bin=time_bin, patch_idx=test_patch_idx
        )
        test_loader_patch = torch.utils.data.DataLoader(
            test_dataset_patch, batch_size=batch_size, num_workers=0, shuffle=False
        )

        # Final metrics on this test patch
        test_metrics, all_preds, all_labels = evaluate(net, test_loader_patch, ["acc", "auc", "fmeasure"])

        # Exclude true class 0 from the accuracy figures and the confusion matrix
        np_preds = np.array(all_preds)
        np_labels = np.array(all_labels)
        valid_indices = np.where(np_labels != 0)[0]

        filtered_preds = np_preds[valid_indices]
        filtered_labels = np_labels[valid_indices]

        if len(valid_indices) > 0:
            final_acc = 100 * (filtered_preds == filtered_labels).sum() / len(valid_indices)
            final_balanced_acc = 100 * balanced_accuracy_score(filtered_labels, filtered_preds)

            # Save the CM into the per-patch "cm" directory created above
            cm_filename = f"P{patch_idx}_testP{test_patch_idx}_Acc{final_acc:.2f}_Bal{final_balanced_acc:.2f}_cm.png"
            cm_save_path = os.path.join(cm_dir, cm_filename)
            cm_title = f"Train Patch {patch_idx}({patch_time_ms:.0f}ms)/Test Patch {test_patch_idx} \nAcc: {final_acc:.2f}% | Bal Acc: {final_balanced_acc:.2f}%"
            display_class_names = ['1', '2', '3', '4', '5']

            plot_confusion_matrix(filtered_labels, filtered_preds, display_class_names, cm_save_path, cm_title)
            print(f"  Confusion matrix saved: {cm_save_path}")
            plt.close('all')  # close any remaining figures
        else:
            # No samples with a non-zero label: the metrics are undefined
            final_acc = final_balanced_acc = float("nan")
            print(f"  No samples with label != 0 for test patch {test_patch_idx}; confusion matrix skipped.")

        # Record results
        results['train_patch_idx'].append(patch_idx)
        results['time_ms'].append(patch_time_ms)
        results['train_loss'].append(avg_loss)  # loss of the last epoch run
        results['train_acc'].append(train_acc)  # train acc of the last epoch run
        results['test_patch_idx'].append(test_patch_idx)
        results['test_acc'].append(final_acc)
        results['test_balanced_acc'].append(final_balanced_acc)
        results['test_auc'].append(test_metrics[1])
        results['test_fmeasure'].append(test_metrics[2])

    # ==========================================
    # 4. Save results and plot accuracy per test patch
    # ==========================================
    print(f"\n{'='*60}")
    print("Saving Results & Graphs...")
    print(f"{'='*60}")

    # 1) DataFrame -> CSV (file name carries the LAST test patch's metrics)
    df_results = pd.DataFrame(results)
    csv_filename = f"P{patch_idx}_testP{test_patch_idx}_Acc{final_acc:.2f}_Bal{final_balanced_acc:.2f}_results.csv"
    csv_path = os.path.join(train_dir, csv_filename)
    df_results.to_csv(csv_path, index=False)
    print(f"Results saved to: {csv_path}")

    # 2) Test accuracy per test patch
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(df_results['test_patch_idx'], df_results['test_acc'],
            marker='o', label='Test Accuracy', color='blue', linewidth=2)
    # Balanced accuracy as a dashed line
    ax.plot(df_results['test_patch_idx'], df_results['test_balanced_acc'],
            marker='s', label='Balanced Accuracy', color='red', linestyle='--', linewidth=2)

    ax.set_title("Test Accuracy per Time Patch", fontsize=15)
    ax.set_xlabel("Test Patch Index", fontsize=12)  # x values are patch indices, not ms
    ax.set_ylabel("Accuracy (%)", fontsize=12)
    ax.set_ylim(0, 100)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.set_xticks(df_results['test_patch_idx'])
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()

    graph_filename = f"P{patch_idx}_testP{test_patch_idx}_Acc{final_acc:.2f}_Bal{final_balanced_acc:.2f}_acc.png"
    graph_path = os.path.join(save_dir, graph_filename)
    fig.savefig(graph_path)
    plt.close(fig)

    print(f"Graph saved to: {graph_path}")
    print("All tasks finished successfully.")
    del net, optimizer, criterion, train_loader_patch
    if 'test_loader_patch' in locals(): del test_loader_patch
    torch.cuda.empty_cache()  # release cached GPU memory
    gc.collect()  # free Python-side memory

    print(f"Finished Patch {patch_idx}. Moving to next...")

print(f"\n{'='*60}")
print("All Patches Training Complete!")
print(f"{'='*60}")

In [None]:


# Summarise `results` from the loop cell. It is re-created for every training
# patch, so it holds the per-test-patch rows of the LAST training patch run.
df_results = pd.DataFrame(results)
print(df_results.to_string(index=False))

# The loop cell stores 'train_patch_idx'/'test_patch_idx'; older runs used 'patch_idx'.
# Plotting against a missing 'patch_idx' column raised KeyError.
x_col = 'patch_idx' if 'patch_idx' in df_results.columns else 'test_patch_idx'
x_kind = 'Training' if x_col == 'patch_idx' else 'Test'

# X-axis ticks every 5 patches, labelled with patch index and onset time.
# Times come from the patch index, not from row position, so they stay correct
# when rows do not start at patch 0; 'time_ms' holds the training-patch time.
tick_interval = 5
tick_indices = list(range(0, num_patches, tick_interval))
if (num_patches - 1) not in tick_indices:
    tick_indices.append(num_patches - 1)
tick_labels = [f"P{i}\n({get_patch_time_ms(i, time_bin, sampling_rate, start_time_ms=-200):.0f}ms)"
               for i in tick_indices]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Test accuracy per patch
axes[0, 0].plot(df_results[x_col], df_results['test_acc'], 'o-', color='steelblue', linewidth=2, markersize=8)
axes[0, 0].set_xlabel('Patch Index (Time in ms)', fontsize=11)
axes[0, 0].set_ylabel('Test Accuracy (%)', fontsize=11)
axes[0, 0].set_title(f'Test Accuracy by {x_kind} Patch', fontsize=13)
axes[0, 0].set_xticks(tick_indices)
axes[0, 0].set_xticklabels(tick_labels, fontsize=9)
axes[0, 0].axhline(y=df_results['test_acc'].mean(), color='r', linestyle='--',
                   label=f"Avg: {df_results['test_acc'].mean():.2f}%")
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# 2. Test AUC per patch
axes[0, 1].plot(df_results[x_col], df_results['test_auc'], 's-', color='coral', linewidth=2, markersize=8)
axes[0, 1].set_xlabel('Patch Index (Time in ms)', fontsize=11)
axes[0, 1].set_ylabel('Test AUC', fontsize=11)
axes[0, 1].set_title(f'Test AUC by {x_kind} Patch', fontsize=13)
axes[0, 1].set_xticks(tick_indices)
axes[0, 1].set_xticklabels(tick_labels, fontsize=9)
axes[0, 1].axhline(y=df_results['test_auc'].mean(), color='r', linestyle='--',
                   label=f"Avg: {df_results['test_auc'].mean():.4f}")
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# 3. Train accuracy per patch
axes[1, 0].plot(df_results[x_col], df_results['train_acc'], '^-', color='green', linewidth=2, markersize=8)
axes[1, 0].set_xlabel('Patch Index (Time in ms)', fontsize=11)
axes[1, 0].set_ylabel('Train Accuracy (%)', fontsize=11)
axes[1, 0].set_title('Final Train Accuracy by Patch', fontsize=13)
axes[1, 0].set_xticks(tick_indices)
axes[1, 0].set_xticklabels(tick_labels, fontsize=9)
axes[1, 0].grid(True, alpha=0.3)

# 4. Train vs test accuracy
axes[1, 1].plot(df_results[x_col], df_results['train_acc'], '^-', color='green',
                linewidth=2, markersize=7, label='Train Acc', alpha=0.8)
axes[1, 1].plot(df_results[x_col], df_results['test_acc'], 'o-', color='steelblue',
                linewidth=2, markersize=7, label='Test Acc', alpha=0.8)
axes[1, 1].set_xlabel('Patch Index (Time in ms)', fontsize=11)
axes[1, 1].set_ylabel('Accuracy (%)', fontsize=11)
axes[1, 1].set_title('Train vs Test Accuracy by Patch', fontsize=13)
axes[1, 1].set_xticks(tick_indices)
axes[1, 1].set_xticklabels(tick_labels, fontsize=9)
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

fig.tight_layout()

# Save the accuracy figure. The Agg backend selected in the imports cell cannot
# display figures, so close it instead of calling plt.show().
acc_graph_path = os.path.join(save_dir, 'patch_training_accuracy.png')
fig.savefig(acc_graph_path, dpi=150, bbox_inches='tight')
print(f"\nAccuracy graph saved: {acc_graph_path}")
plt.close(fig)

# Final summary
print("\n" + "="*60)
print("Final Results Summary")
print("="*60)
best_idx = df_results['test_acc'].idxmax()
worst_idx = df_results['test_acc'].idxmin()
best_patch = int(df_results.loc[best_idx, x_col])
worst_patch = int(df_results.loc[worst_idx, x_col])
print(f"\nBest Patch: P{best_patch} (Time: {get_patch_time_ms(best_patch, time_bin, sampling_rate):.0f}ms) "
      f"- Test Acc: {df_results['test_acc'].max():.2f}%")
print(f"Worst Patch: P{worst_patch} (Time: {get_patch_time_ms(worst_patch, time_bin, sampling_rate):.0f}ms) "
      f"- Test Acc: {df_results['test_acc'].min():.2f}%")
print(f"\nAverage Test Acc: {df_results['test_acc'].mean():.2f}% (std: {df_results['test_acc'].std():.2f}%)")
print(f"Average Test AUC: {df_results['test_auc'].mean():.4f} (std: {df_results['test_auc'].std():.4f})")
print(f"Average Test F1:  {df_results['test_fmeasure'].mean():.4f} (std: {df_results['test_fmeasure'].std():.4f})")

# Save results as CSV
csv_path = os.path.join(save_dir, 'patch_training_results.csv')
df_results.to_csv(csv_path, index=False)
print(f"\nResults saved to: {csv_path}")