## __Check first before starting__

In [1]:
import os

# Project root used as the notebook's working directory. This is an absolute,
# machine-specific path: edit it before running the notebook on another host.
Working_directory = os.path.normpath("/mnt/mydisk/Continual_Learning_JL/Continual_Learning/")
# Switch the process working directory to the project root and echo it for confirmation.
os.chdir(Working_directory)
print(f"Working directory: {os.getcwd()}")

Working directory: /mnt/mydisk/Continual_Learning_JL/Continual_Learning


## __All imports__

In [2]:
# Operating system and file management
import os
import shutil
import contextlib
import traceback
import gc
import copy
from collections import defaultdict
import subprocess
import time
import re, pickle
import scipy.io
from scipy.io import loadmat
from glob import glob
from math import ceil
from tqdm import tqdm

# Jupyter notebook widgets and display
import ipywidgets as widgets
from IPython.display import display

# Data manipulation and analysis
import pandas as pd
import numpy as np

# Plotting and visualization
import matplotlib.pyplot as plt
from mpl_interactions import zoom_factory, panhandler

# Machine learning and preprocessing
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import pickle
from ta import trend, momentum, volatility, volume

# Mathematical and scientific computing
import math
from scipy.ndimage import gaussian_filter1d

# Type hinting
from typing import Callable, Tuple

# Deep learning with PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision.models import resnet18
from sklearn.utils.class_weight import compute_class_weight

## __📁 Path Settings and Constants__
This cell defines essential paths and constants for the CPSC2018 ECG dataset processing:
- `BASE_DIR`: Root directory of the project.
- `save_dir`: Path to the preprocessed `.npy` files (one for each continual learning period).
- `ECG_PATH`: Directory containing original `.mat` and `.hea` files.
- `MAX_LEN`: Length of each ECG sample, fixed to 5000 time steps (i.e., 10 seconds at 500Hz).

In [3]:
# Root folder of the CPSC class-incremental experiment (absolute, machine-specific path).
BASE_DIR = "/mnt/mydisk/Continual_Learning_JL/Continual_Learning/Class_Incremental_CL/CPSC_CIL"
# "processed" sub-folder of BASE_DIR (described above as holding the preprocessed .npy files).
save_dir = os.path.join(BASE_DIR, "processed")
# "datas" sub-folder of BASE_DIR (described above as holding the original .mat/.hea files).
ECG_PATH = os.path.join(BASE_DIR, "datas")
# Fixed ECG sample length in time steps (described above as 10 seconds at 500 Hz).
MAX_LEN = 5000

## __🏷️ Label Mapping and Period Configuration__

This section defines:
- `snomed_map`: Mapping from SNOMED CT codes to readable class names for 9 major ECG conditions.
- `period_label_map`: Incremental learning task structure across four periods.  
  Class `1` is reserved for "OTHER" abnormalities in Periods 1–3; in Period 4 all 9 classes are explicitly categorized and label `1` is no longer used.
- `print_class_distribution()`: Helper function to show class-wise data distribution.


In [4]:
# SNOMED CT code (as a string) -> short readable class name
snomed_map = {
    "426783006": "NSR",    # Normal sinus rhythm
    "270492004": "I-AVB",  # First-degree atrioventricular block
    "164889003": "AF",     # Atrial fibrillation
    "164909002": "LBBB",   # Left bundle branch block
    "59118001":  "RBBB",   # Right bundle branch block
    "284470004": "PAC",    # Premature atrial contraction
    "164884008": "PVC",    # Premature ventricular contraction
    "429622005": "STD",    # ST-segment depression
    "164931005": "STE"     # ST-segment elevation
}

# Per-period mapping of class name -> integer label.
# Label 1 is always "OTHER" (other abnormalities) until it is dropped in period 4,
# so the remaining labels keep the same integer ID across periods.
period_label_map = {
    1: {"NSR": 0, "OTHER": 1},
    2: {"NSR": 0, "I-AVB": 2, "AF": 3, "OTHER": 1},
    3: {"NSR": 0, "I-AVB": 2, "AF": 3, "LBBB": 4, "RBBB": 5, "OTHER": 1},
    4: {"NSR": 0, "I-AVB": 2, "AF": 3, "LBBB": 4, "RBBB": 5, "PAC": 6, "PVC": 7, "STD": 8, "STE": 9}
}

def print_class_distribution(y, label_map):
    """
    Print how many samples of each mapped class appear in ``y``.

    Args:
        y: Array-like of integer class IDs; it is flattened before counting.
        label_map: Dict mapping class name -> integer class ID (for example one
            entry of ``period_label_map``).

    One line is printed per value in ``label_map`` (sorted by ID) showing the ID,
    the first class name mapped to it, the sample count and its share of ``len(y)``.
    An empty ``y`` reports 0.00% for every class instead of dividing by zero.
    """
    y = np.array(y).flatten()
    total = len(y)

    # Build the ID -> name lookup once; setdefault keeps the first name for an ID,
    # which is what scanning label_map in order would pick.
    names_by_id = {}
    for class_name, class_id in label_map.items():
        names_by_id.setdefault(class_id, class_name)

    print("\n📊 Class Distribution")
    for lbl in sorted(label_map.values()):
        count = np.sum(y == lbl)
        name = names_by_id.get(lbl, str(lbl))
        # Guard the percentage so an empty label array prints 0.00% rather than nan.
        percent = (count / total) * 100 if total else 0.0
        print(f"  ├─ Label {lbl:<2} ({name:<10}) → {count:>5} samples ({percent:5.2f}%)")

def ensure_folder(folder_path: str) -> None:
    """Create ``folder_path`` (and any missing parents) unless it is already a directory."""
    if os.path.isdir(folder_path):
        return
    os.makedirs(folder_path, exist_ok=True)


## 📦 EX. Load Example (Period 4) Data and View Format

This example demonstrates how to load preprocessed `.npy` data for **Period 4**, and inspect the dataset shapes and label distribution.  
Use this format as a reference when loading data in other methods (e.g., EWC, PNN, DynEx-CLoRA).

Each ECG sample:
- Has shape `(5000, 12)` → represents 10 seconds (at 500Hz) across 12-lead channels.
- Corresponding label is an integer ID (e.g., 0–9) defined by `period_label_map[4]`.

In [5]:
# # Example: load period 4
# save_dir = os.path.join(BASE_DIR, "processed")
# X_train = np.load(os.path.join(save_dir, "X_train_p4.npy"))
# y_train = np.load(os.path.join(save_dir, "y_train_p4.npy"))
# X_test = np.load(os.path.join(save_dir, "X_test_p4.npy"))
# y_test = np.load(os.path.join(save_dir, "y_test_p4.npy"))

# print("✅ Loaded")
# print("X_train shape:", X_train.shape)
# print("y_train shape:", y_train.shape)
# print("X_test shape:", X_test.shape)
# print("y_test shape:", y_test.shape)
# print_class_distribution(y_train, period_label_map[4])
# print_class_distribution(y_test, period_label_map[4])

# del X_train, y_train, X_test, y_test


## __Check GPU, CUDA, Pytorch__

In [6]:
!nvidia-smi

Thu May  8 21:01:35 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.133.07             Driver Version: 570.133.07     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A6000               Off |   00000000:2A:00.0 Off |                  Off |
| 46%   62C    P2             88W /  300W |    2876MiB /  49140MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A6000               Off |   00

### CUDA Details

In [7]:
def check_gpu_config():
    """
    Print a formatted report about GPU availability.

    The PyTorch version and a Yes/No availability flag are always printed. When CUDA
    is available the report also lists device 0's name, the GPU count, the current
    device index, device 0's compute capability, an estimated CUDA-core count and
    total/allocated/reserved memory in GB.
    """
    width = 50
    heavy_rule = "=" * width
    light_rule = "-" * width

    def emit(label, value):
        # All data rows share one 25-character, left-aligned label column.
        print(f"{label:<25}: {value}")

    has_gpu = torch.cuda.is_available()

    # Banner
    print(heavy_rule)
    print("GPU Configuration Check".center(width))
    print(heavy_rule)

    emit('PyTorch Version', torch.__version__)
    emit('GPU Available', 'Yes' if has_gpu else 'No')

    if not has_gpu:
        print(light_rule)
        print("No GPU detected. Running on CPU.".center(width))
        print(light_rule)
        print(heavy_rule)
        return

    print(light_rule)
    print("GPU Details".center(width))
    print(light_rule)

    # Device identity
    emit('Device Name', torch.cuda.get_device_name(0))
    emit('Number of GPUs', torch.cuda.device_count())
    emit('Current Device Index', torch.cuda.current_device())

    # Compute capability and a core estimate (multiprocessor count times 128 cores each)
    props = torch.cuda.get_device_properties(0)
    emit('Compute Capability', f"{props.major}.{props.minor}")
    emit('Total CUDA Cores', props.multi_processor_count * 128)

    # Memory figures, converted from bytes to GB
    bytes_per_gb = 1024 ** 3
    total_memory = props.total_memory / bytes_per_gb
    memory_allocated = torch.cuda.memory_allocated(0) / bytes_per_gb
    memory_reserved = torch.cuda.memory_reserved(0) / bytes_per_gb
    emit('Total Memory (GB)', f"{total_memory:.2f}")
    emit('Allocated Memory (GB)', f"{memory_allocated:.2f}")
    emit('Reserved Memory (GB)', f"{memory_reserved:.2f}")

    print(heavy_rule)

if __name__ == "__main__":
    check_gpu_config()

             GPU Configuration Check              
PyTorch Version          : 2.5.1
GPU Available            : Yes
--------------------------------------------------
                   GPU Details                    
--------------------------------------------------
Device Name              : NVIDIA RTX A6000
Number of GPUs           : 3
Current Device Index     : 0
Compute Capability       : 8.6
Total CUDA Cores         : 10752
Total Memory (GB)        : 47.41
Allocated Memory (GB)    : 0.00
Reserved Memory (GB)     : 0.00


### PyTorch Details

In [8]:
def print_torch_config():
    """
    Print a formatted PyTorch/CUDA summary and seed PyTorch's RNG.

    Side effect: calls ``torch.manual_seed(42)`` before reporting the seed.
    """
    heavy_rule = "=" * 50

    def row(label, value):
        # Left-align the label in a 25-character column, as in the GPU report.
        print(f"{label:<25}: {value}")

    print(heavy_rule)
    print("PyTorch Configuration".center(50))
    print(heavy_rule)

    # Version and availability summary
    row('PyTorch Version', torch.__version__)
    row('CUDA Compiled Version', torch.version.cuda)
    row('CUDA Available', 'Yes' if torch.cuda.is_available() else 'No')
    row('Number of GPUs', torch.cuda.device_count())

    # Name of device 0, only when CUDA can be used
    if torch.cuda.is_available():
        row('GPU Name', torch.cuda.get_device_name(0))

    print("-" * 50)

    # Fix the global PyTorch seed
    torch.manual_seed(42)
    row('Random Seed', "42 (Seeding successful!)")

    print(heavy_rule)

if __name__ == "__main__":
    print_torch_config()

              PyTorch Configuration               
PyTorch Version          : 2.5.1
CUDA Compiled Version    : 12.1
CUDA Available           : Yes
Number of GPUs           : 3
GPU Name                 : NVIDIA RTX A6000
--------------------------------------------------
Random Seed              : 42 (Seeding successful!)


## __⚙️ GPU Selection — Auto-select the least loaded GPU__
This code automatically scans available GPUs and selects the one with the lowest current memory usage.


In [9]:
def auto_select_cuda_device(verbose=True):
    """
    Pick the visible CUDA GPU with the least memory in use, as reported by nvidia-smi.

    Args:
        verbose: When True, print the chosen device index, its used memory and name.

    Returns:
        torch.device: ``cuda:<idx>`` for the least-loaded visible GPU, ``cpu`` when CUDA
        is unavailable, or ``cuda:0`` when the nvidia-smi query or its parsing fails.

    nvidia-smi lists every physical GPU, while PyTorch only numbers the GPUs that are
    visible to the process. The physical readings are therefore mapped onto PyTorch's
    indices: a purely numeric ``CUDA_VISIBLE_DEVICES`` list is honoured, otherwise the
    first ``torch.cuda.device_count()`` readings are used. The returned index is always
    a valid PyTorch device index.
    NOTE(review): the mapping assumes nvidia-smi and CUDA enumerate GPUs in the same
    order; setting ``CUDA_DEVICE_ORDER=PCI_BUS_ID`` is the documented way to force that.
    """
    if not torch.cuda.is_available():
        print("🚫 No CUDA GPU available. Using CPU.")
        return torch.device("cpu")

    try:
        # Ask nvidia-smi for the used memory (MiB) of every physical GPU, one per line.
        smi_output = subprocess.check_output(
            ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'],
            encoding='utf-8'
        )
        physical_used = [int(line) for line in smi_output.splitlines() if line.strip()]

        # Work out which physical GPU backs each PyTorch index.
        visible_count = torch.cuda.device_count()
        visible_spec = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        tokens = [tok.strip() for tok in visible_spec.split(",") if tok.strip()]
        if tokens and all(tok.isdigit() for tok in tokens):
            physical_ids = [int(tok) for tok in tokens][:visible_count]
        else:
            # Unset, or UUID/MIG style identifiers: fall back to positional matching.
            physical_ids = list(range(min(visible_count, len(physical_used))))

        # memory_used[i] is the reading for PyTorch device i; an out-of-range physical
        # id or an empty list raises here and is handled by the fallback below.
        memory_used = [physical_used[idx] for idx in physical_ids]
        best_gpu = int(np.argmin(memory_used))

        if verbose:
            print("🎯 Automatically selected GPU:")
            print(f"    - CUDA Device ID : {best_gpu}")
            print(f"    - Memory Used    : {memory_used[best_gpu]} MiB")
            print(f"    - Device Name    : {torch.cuda.get_device_name(best_gpu)}")
        return torch.device(f"cuda:{best_gpu}")
    except Exception as e:
        # Deliberate best effort: any failure in detection degrades to the first GPU.
        print(f"⚠️ Failed to auto-detect GPU. Falling back to cuda:0. ({e})")
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Execute and assign
device = auto_select_cuda_device()

🎯 Automatically selected GPU:
    - CUDA Device ID : 1
    - Memory Used    : 283 MiB
    - Device Name    : NVIDIA RTX A6000


## __Model__

### ResNet 18 - 1D (ResNet18_1D_big_inplane)

In [10]:
class LoRAConv1d(nn.Module):
    """
    Low-rank (LoRA) weight delta attached to an existing ``nn.Conv1d``.

    The adapter owns two trainable factors, ``lora_A`` (out_channels x rank) and
    ``lora_B`` (rank x in_channels_per_group * kernel_size). ``get_delta`` multiplies
    them and reshapes the product to the wrapped conv's weight shape so it can be
    added to that weight by the caller.

    Args:
        conv_layer: The convolution this adapter is associated with. It is stored as
            ``self.conv`` (and therefore registered as a sub-module); its weights are
            not modified here.
        rank: Rank of the low-rank factorisation.
    """
    def __init__(self, conv_layer: nn.Conv1d, rank: int):
        super(LoRAConv1d, self).__init__()
        self.conv = conv_layer
        self.rank = rank

        # A Conv1d weight has shape (out_channels, in_channels // groups, kernel_size),
        # so the factors are sized per group to match it for grouped convs as well.
        in_per_group = conv_layer.in_channels // conv_layer.groups

        # Create the A and B matrices of the low-rank decomposition.
        self.lora_A = nn.Parameter(torch.zeros(conv_layer.out_channels, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, in_per_group * conv_layer.kernel_size[0]))

        # A is drawn from a normal distribution and B starts at zero, so A @ B is zero
        # and the adapter has no effect when training starts.
        nn.init.normal_(self.lora_A, mean=0.0, std=0.01)
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        """Run the wrapped convolution unchanged; the LoRA delta is NOT applied here."""
        return self.conv(x)

    def get_delta(self):
        """Return ``lora_A @ lora_B`` reshaped to the wrapped conv's weight shape."""
        lora_weight = torch.matmul(self.lora_A, self.lora_B).view(
            self.conv.out_channels,
            self.conv.in_channels // self.conv.groups,
            self.conv.kernel_size[0]
        )
        return lora_weight

    def parameters(self, recurse=True):
        """
        Return only the LoRA factors, excluding the wrapped conv's parameters.

        Unlike ``nn.Module.parameters`` this returns a list and ignores ``recurse``.
        """
        return [self.lora_A, self.lora_B]

class BasicBlock1d_LoRA(nn.Module):
    """
    1-D residual basic block whose second convolution can carry LoRA adapters.

    The block computes conv1 -> bn1 -> relu -> conv2 -> bn2, adds the (optionally
    down-sampled) input and applies a final relu. When ``lora_adapters`` is non-empty,
    conv2 is evaluated with ``conv2.weight`` plus the sum of every adapter's delta.

    Args:
        inplanes: Input channel count.
        planes: Output channel count of both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input on the skip path.
        lora_rank: Rank passed to each adapter created by ``add_lora_adapter``.
    """
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None, lora_rank=None):
        super(BasicBlock1d_LoRA, self).__init__()
        self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.downsample = downsample
        self.stride = stride

        self.lora_rank = lora_rank
        self.lora_adapters = nn.ModuleList()  # ModuleList holding any number of LoRA adapters

    def add_lora_adapter(self):
        """Append a new LoRA adapter for conv2, move it to this block's device, return it."""
        new_lora = LoRAConv1d(self.conv2, self.lora_rank)
        device = next(self.parameters()).device
        new_lora.to(device)
        self.lora_adapters.append(new_lora)
        return new_lora

    def forward(self, x):
        """Apply the residual block to ``x``, folding all LoRA deltas into conv2's weight."""
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        if len(self.lora_adapters) > 0:
            # Sum the weight deltas of all adapters and add them to the base weight,
            # then run a single convolution with the adapted weight. The plain conv2
            # pass is not computed in this branch because its result would be unused.
            lora_weight_delta = sum(adapter.get_delta() for adapter in self.lora_adapters)
            adapted_weight = self.conv2.weight + lora_weight_delta
            out = F.conv1d(out, adapted_weight, bias=self.conv2.bias,
                           stride=self.conv2.stride, padding=self.conv2.padding,
                           dilation=self.conv2.dilation, groups=self.conv2.groups)
        else:
            # No adapters: use the original conv2 as is.
            out = self.conv2(out)

        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out

class ResNet18_1D_LoRA(nn.Module):
    """
    1-D ResNet-18 style classifier built from ``BasicBlock1d_LoRA`` blocks.

    The network is a convolutional stem, four residual stages of two blocks each,
    concatenated adaptive average and max pooling, dropout and a linear classifier.
    LoRA adapters can be added to every block's conv2 with ``add_lora_adapter``.

    Args:
        input_channels: Channel count of the input signal.
        output_size: Number of output logits.
        inplanes: Channel count produced by the stem.
        lora_rank: Rank handed to every block for the adapters it creates.
    """
    def __init__(self, input_channels=12, output_size=9, inplanes=64, lora_rank=4):
        super(ResNet18_1D_LoRA, self).__init__()
        self.inplanes = inplanes
        self.lora_rank = lora_rank

        # Stem convolution
        self.conv1 = nn.Conv1d(input_channels, inplanes, kernel_size=15, stride=2, padding=7, bias=False)
        self.bn1 = nn.BatchNorm1d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Residual stages
        self.layer1 = self._make_layer(BasicBlock1d_LoRA, 64, 2)
        self.layer2 = self._make_layer(BasicBlock1d_LoRA, 128, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock1d_LoRA, 256, 2, stride=2)
        self.layer4 = self._make_layer(BasicBlock1d_LoRA, 512, 2, stride=2)

        # Adaptive pooling (average and max)
        self.adaptiveavgpool = nn.AdaptiveAvgPool1d(1)
        self.adaptivemaxpool = nn.AdaptiveMaxPool1d(1)

        # Dropout and fully connected classifier
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(512 * 2, output_size)  # *2 because avg and max pooling are concatenated

        # Initialise weights
        self.init_weights()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks and update ``self.inplanes``."""
        downsample = None
        # A 1x1 projection is needed on the skip path whenever the shape changes.
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv1d(self.inplanes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(planes),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.lora_rank))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, lora_rank=self.lora_rank))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for ``x`` shaped (batch_size, time_steps, channels)."""
        # Expected input shape: (batch_size, time_steps, channels)
        x = x.permute(0, 2, 1)  # -> (batch_size, channels, time_steps)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Apply average and max pooling
        x1 = self.adaptiveavgpool(x)
        x2 = self.adaptivemaxpool(x)

        # Concatenate the pooled results along the channel axis
        x = torch.cat((x1, x2), dim=1)

        # Flatten
        x = x.view(x.size(0), -1)

        # Apply dropout
        x = self.dropout(x)

        # Final classification
        x = self.fc(x)

        return x

    def init_weights(self):
        """Initialise Linear/Conv1d weights with Kaiming normal and BatchNorm1d to (1, 0)."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv1d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def add_lora_adapter(self):
        """
        Add one new LoRA adapter to the conv2 layer of every BasicBlock.

        Returns:
            list: The newly created adapters, one per block, in module-traversal order.
        """
        # Snapshot the blocks first: add_lora_adapter() appends to each block's
        # ModuleList, and self.modules() walks the module tree lazily, so mutating
        # the tree during that walk is avoided.
        blocks = [module for module in self.modules() if isinstance(module, BasicBlock1d_LoRA)]
        added_loras = [block.add_lora_adapter() for block in blocks]

        print(f"✅ Added new LoRA adapters to {len(added_loras)} BasicBlocks")
        return added_loras

    def get_trainable_parameters(self):
        """
        Collect every LoRA factor plus the fc parameters (for an optimizer) and print statistics.

        This only gathers and reports parameters; it does not change ``requires_grad``.
        The "Frozen parameters" figure is the total parameter count minus the
        collected ones.

        Returns:
            list: LoRA parameters followed by the fc layer's parameters.
        """
        lora_params = []
        lora_names = []
        fc_params = []
        fc_names = []

        # Total number of parameters in the model
        total_params = sum(p.numel() for p in self.parameters())

        # Collect all LoRA parameters
        for name, module in self.named_modules():
            if isinstance(module, LoRAConv1d):
                lora_params.append(module.lora_A)
                lora_names.append(f"{name}.lora_A")
                lora_params.append(module.lora_B)
                lora_names.append(f"{name}.lora_B")

        # Add the fc layer parameters
        for name, param in self.fc.named_parameters():
            fc_params.append(param)
            fc_names.append(f"fc.{name}")

        # Compute statistics
        trainable_params = lora_params + fc_params
        lora_param_count = sum(p.numel() for p in lora_params)
        fc_param_count = sum(p.numel() for p in fc_params)
        trainable_param_count = lora_param_count + fc_param_count
        frozen_params = total_params - trainable_param_count

        # Print statistics
        print("📊 Parameter Statistics:")
        print(f"  - Total parameters: {total_params:,}")
        print(f"  - Trainable parameters: {trainable_param_count:,} ({trainable_param_count/total_params*100:.2f}%)")
        print(f"    - LoRA parameters: {lora_param_count:,} ({lora_param_count/total_params*100:.2f}%)")
        print(f"    - FC parameters: {fc_param_count:,} ({fc_param_count/total_params*100:.2f}%)")
        print(f"  - Frozen parameters: {frozen_params:,} ({frozen_params/total_params*100:.2f}%)")

        print("🧠 Trainable parameter names:")
        for name in lora_names:
            print(f"  ✅ {name} (LoRA)")
        for name in fc_names:
            print(f"  ✅ {name} (FC)")

        return trainable_params

    def count_lora_adapters(self):
        """Print and return the total number of LoRA adapters across all BasicBlocks."""
        total_adapters = 0
        blocks_with_lora = 0

        for module in self.modules():
            if isinstance(module, BasicBlock1d_LoRA):
                if len(module.lora_adapters) > 0:
                    blocks_with_lora += 1
                    total_adapters += len(module.lora_adapters)

        print("📈 LoRA Adapter Statistics:")
        print(f"  - Total LoRA adapters: {total_adapters}")
        print(f"  - BasicBlocks with adapters: {blocks_with_lora}")

        return total_adapters

    def count_lora_groups(self):
        """Return the adapter count of the first BasicBlock (0 when there are no blocks)."""
        blocks = [m for m in self.modules() if isinstance(m, BasicBlock1d_LoRA)]
        if not blocks:
            return 0
        return len(blocks[0].lora_adapters)  # every block is expected to hold the same number of groups


## __Training and validation function__

### Extra Function

In [11]:
def compute_classwise_accuracy(student_logits_flat, y_batch, class_correct, class_total):
    """
    Accumulate per-class hit and sample counts for one batch, in place.

    Args:
        student_logits_flat (torch.Tensor): Logits whose last dimension indexes the classes.
        y_batch (torch.Tensor): Ground-truth class IDs aligned with the logits.
        class_correct (dict): Running count of correct predictions per class ID (updated).
        class_total (dict): Running count of samples per class ID (updated).

    Raises:
        ValueError: If the logits and the labels live on different devices.
    """
    # Both tensors have to be on one device for the comparisons below
    if student_logits_flat.device != y_batch.device:
        raise ValueError("student_logits_flat and y_batch must be on the same device")

    # Predicted class per sample, and where it agrees with the ground truth
    predicted = student_logits_flat.argmax(dim=-1)
    hits = predicted.eq(y_batch)

    # Visit every class ID that occurs in this batch (as plain Python scalars)
    for cls in torch.unique(y_batch).tolist():
        if cls not in class_total:
            class_total[cls] = 0
            class_correct[cls] = 0

        members = y_batch.eq(cls)
        class_total[cls] += members.sum().item()
        class_correct[cls] += (members & hits).sum().item()

In [12]:
def get_model_parameter_info(model):
    """
    Count a model's parameters and the memory they occupy.

    Args:
        model: Any ``torch.nn.Module``.

    Returns:
        tuple: ``(total_params, param_size_MB)`` where ``total_params`` is the number of
        parameter elements and ``param_size_MB`` their storage size in MiB (bytes / 1024**2).

    The size uses each parameter's real element size, so half, bfloat16 or double
    precision models are reported correctly instead of assuming 4 bytes per element.
    """
    total_params = 0
    param_size_bytes = 0
    for p in model.parameters():
        n_elements = p.numel()
        total_params += n_elements
        param_size_bytes += n_elements * p.element_size()
    param_size_MB = param_size_bytes / (1024**2)
    return total_params, param_size_MB

In [13]:
def compute_class_weights(y: np.ndarray, num_classes: int, exclude_classes: list = None) -> torch.Tensor:
    """
    Compute normalised inverse-frequency class weights to counter class imbalance.

    Args:
        y: Integer class labels; any shape, flattened before counting.
        num_classes: Length of the returned weight vector.
        exclude_classes: Class IDs whose weight is forced to 0 (e.g. classes that do
            not exist in the current period).

    Returns:
        torch.Tensor: float32 tensor of length ``num_classes``. Weights of the
        non-excluded classes sum to 1 whenever at least one of them has samples.

    A class that is not excluded but has no samples in ``y`` also gets weight 0.
    Without that, its weight would be ``len(y) / 1e-6`` and after normalisation it
    would take essentially all of the weight mass away from the classes that exist.
    """
    exclude_classes = set(exclude_classes or [])
    labels = np.asarray(y).ravel()  # np.bincount only accepts 1-D input
    class_sample_counts = np.bincount(labels, minlength=num_classes)
    total_samples = len(labels)

    weights = np.zeros(num_classes, dtype=np.float32)
    absent_classes = set()

    for cls in range(num_classes):
        if cls in exclude_classes:
            continue  # stays 0.0
        count = class_sample_counts[cls]
        if count == 0:
            absent_classes.add(cls)  # stays 0.0
            continue
        # Epsilon retained so weights match earlier runs for classes that have samples.
        weights[cls] = total_samples / (count + 1e-6)

    # Normalize only non-excluded weights
    valid_mask = np.array([cls not in exclude_classes for cls in range(num_classes)], dtype=bool)
    norm_sum = weights[valid_mask].sum()
    if norm_sum > 0:
        weights[valid_mask] /= norm_sum

    print("\n📊 Class Weights (normalized):")
    for i, w in enumerate(weights):
        if i in exclude_classes:
            status = " (excluded)"
        elif i in absent_classes:
            status = " (no samples)"
        else:
            status = ""
        print(f"  - Class {i}: {w:.4f}{status}")

    return torch.tensor(weights, dtype=torch.float32)


### Training Function

#### v10 pure optimizer, all freeze, no consider base

In [14]:
def train_with_dynex_clora_ecg(model, teacher_model, output_size, criterion, optimizer,
                           X_train, y_train, X_val, y_val,
                           num_epochs, batch_size, alpha,
                           model_saving_folder, model_name,
                           stop_signal_file=None, scheduler=None,
                           period=None, stable_classes=None,
                           similarity_threshold=0.0,
                           class_features_dict=None, related_labels=None, device=None):
    
    print(f"\n🚀 'train_with_dynex_clora_ecg' started for Period {period}\n")
    
    start_time = time.time()
    
    model_name = model_name or 'dynex_clora_model'
    model_saving_folder = model_saving_folder or './saved_models'
    
    if os.path.exists(model_saving_folder):
        shutil.rmtree(model_saving_folder)
        print(f"✅ Removed existing folder: {model_saving_folder}")
    os.makedirs(model_saving_folder, exist_ok=True)
    
    device = device or auto_select_cuda_device()
    model.to(device)
    
    if teacher_model:
        teacher_model.to(device)
        teacher_model.eval()
    
    X_train = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train = torch.tensor(y_train, dtype=torch.long).to(device)
    X_val = torch.tensor(X_val, dtype=torch.float32).to(device)
    y_val = torch.tensor(y_val, dtype=torch.long).to(device)
    
    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size, shuffle=False)
    
    print("\n✅ Data Overview:")
    print(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
    print(f"X_val: {X_val.shape}, y_val: {y_val.shape}")
    
    best_results = []
    
    # === Class Feature Extraction ===
    model.eval()
    new_class_features = {}
    
    # Extract features for current classes
    with torch.no_grad():
        for xb, yb in train_loader:
            # Get feature representations
            features = extract_features(model, xb)
            for cls in torch.unique(yb):
                cls_mask = (yb == cls)
                cls_feat = features[cls_mask]
                if cls.item() not in new_class_features:
                    new_class_features[cls.item()] = []
                new_class_features[cls.item()].append(cls_feat)
    
    # Average features per class
    for cls in new_class_features:
        new_class_features[cls] = torch.cat(new_class_features[cls], dim=0).mean(dim=0)
    
    # Initialize related_labels if not provided
    if related_labels is None:
        related_labels = {}
    
    # === Similarity Computation (only for Period > 1) ===
    if period > 1 and class_features_dict:
        cosine_sim = torch.nn.CosineSimilarity(dim=0)
        similarity_scores = {}
        
        # Calculate similarities between new and old classes
        for new_label, new_feat in new_class_features.items():
            similarity_scores[new_label] = {}
            for old_label, old_feat in class_features_dict.items():
                sim = cosine_sim(new_feat.to(device), old_feat.to(device)).item()
                similarity_scores[new_label][old_label] = sim
        
        # Print similarity information
        print("\n🔎 Similarity Analysis:")
        print(f"  Similarity threshold: {similarity_threshold:.4f}")
        print(f"  Existing classes: {sorted(list(class_features_dict.keys()))}")
        print(f"  Current classes: {sorted(list(new_class_features.keys()))}")
        
        # Calculate new classes (classes not in previous periods)
        existing_classes = set(class_features_dict.keys())
        current_classes = set(new_class_features.keys())
        new_classes = current_classes - existing_classes
        print(f"  New classes: {sorted(list(new_classes))}")
        
        # Print similarity scores
        print("\n📊 Similarity Scores:")
        for new_label, scores in similarity_scores.items():
            if new_label in new_classes:  # Only show for new classes
                print(f"  New Class {new_label}:")
                if scores:
                    for old_label, score in sorted(scores.items(), key=lambda x: x[1], reverse=True):
                        print(f"    - Existing Class {old_label}: {score:.4f}")
                else:
                    print("    - No existing classes to compare")
        
        # Calculate average similarity statistics
        all_similarities = [s for scores in similarity_scores.values() for s in scores.values()]
        if all_similarities:
            avg_similarity = np.mean(all_similarities)
            std_similarity = np.std(all_similarities)
            print(f"\n  Average similarity: {avg_similarity:.4f}, Std: {std_similarity:.4f}")
    
    # === Period-specific LoRA Management Logic ===
    to_unfreeze = set()
    
    # Special handling for period 1 - no LoRA adapters yet
    if period == 1:
        if not related_labels:
            # Initialize related_labels for first period
            # The base conv layers (index 'base') are associated with initial classes
            initial_classes = list(new_class_features.keys())
            related_labels['base'] = initial_classes
            print(f"\n🔄 Initializing related_labels for first period: {related_labels}")
        
        # For first period, all base model parameters are trainable
        print("\n🔓 First period: All model parameters are trainable")
    
    # For periods > 1, manage LoRA adapters
    elif period > 1 and class_features_dict:
        new_lora_indices = []
        cosine_sim = torch.nn.CosineSimilarity(dim=0)
        
        existing_classes = set(class_features_dict.keys())
        current_classes = set(new_class_features.keys())
        new_classes = current_classes - existing_classes
        
        print(f"\n🧩 Managing LoRA adapters for {len(new_classes)} new classes...")
        
        # Process each new class
        for new_cls in new_classes:
            new_feat = new_class_features[new_cls]
            
            # Calculate similarities to all existing classes
            sims = [(old_cls, cosine_sim(new_feat.to(device), class_features_dict[old_cls].to(device)).item())
                   for old_cls in class_features_dict]
            
            # Sort by similarity (highest first)
            sims.sort(key=lambda x: x[1], reverse=True)
            
            # Check if new class is similar to any existing class
            matched = False
            for old_cls, sim in sims:
                if sim >= similarity_threshold:
                    matched = True
                    # Find which adapter/network is associated with this old class
                    for adapter_idx, related_cls_list in related_labels.items():
                        if old_cls in related_cls_list:
                            # Add this new class to the same adapter's related classes
                            if new_cls not in related_labels[adapter_idx]:
                                related_labels[adapter_idx].append(new_cls)
                                print(f"🔄 New Class {new_cls} is similar to Class {old_cls} (sim={sim:.4f}) → Added to adapter '{adapter_idx}'")
                            
                            # Mark this adapter for unfreezing during training
                            to_unfreeze.add(adapter_idx)
                            break
            
            # If no match found, create new LoRA adapter
            if not matched:
                # === 加入一整組 LoRA adapters（每層一個） ===
                model.add_lora_adapter()

                # 記錄這是第幾組（用 group index）
                group_idx = max([k for k in related_labels.keys() if isinstance(k, int)], default=-1) + 1

                related_labels[group_idx] = [new_cls]
                new_lora_indices.append(group_idx)
                print(f"➕ New Class {new_cls} is not similar to any existing class → Created new adapter group #{group_idx}")
        
        # Check stability of existing classes
        print("\n🔍 Checking stability of existing classes...")
        for old_cls in existing_classes & current_classes:  # Intersection - classes that exist in both periods
            if old_cls in new_class_features:
                sim_self = cosine_sim(new_class_features[old_cls].to(device), 
                                      class_features_dict[old_cls].to(device)).item()
                print(f"  Class {old_cls} similarity with itself: {sim_self:.4f}")
                
                # If class representation has drifted too much
                if sim_self < similarity_threshold:
                    for adapter_idx, related_cls_list in related_labels.items():
                        if old_cls in related_cls_list:
                            to_unfreeze.add(adapter_idx)
                            print(f"⚠️ Class {old_cls} has drifted (self-sim={sim_self:.4f}) → Unfreezing adapter '{adapter_idx}'")
        
        # # Freeze all LoRA adapters and conv2 weights
        # print("\n🔒 Default: Freezing all LoRA adapters and base conv2 weights")
        # for module in model.modules():
        #     if isinstance(module, BasicBlock1d_LoRA):
        #         # freeze all conv2
        #         for param in module.conv2.parameters():
        #             param.requires_grad = False
        #         # freeze all LoRA inside
        #         for adapter in module.lora_adapters:
        #             for param in adapter.parameters():
        #                 param.requires_grad = False

        # 🔒 Freeze ALL model parameters first
        print("\n🔒 Default: Freezing ALL model parameters")
        for name, param in model.named_parameters():
            param.requires_grad = False
        
        ####### Check frozen parameters #######
        print("\n🔍 Frozen Parameters:")
        for name, param in model.named_parameters():
            if not param.requires_grad:
                print(f"  ❌ {name:<50} | shape={list(param.shape)}")
        ####### Check frozen parameters #######

        # 🔓 Unfreeze specific adapter groups
        print("\n🔓 Unfreezing selected adapters (by group):")
        for adapter_group_idx in to_unfreeze:
            if isinstance(adapter_group_idx, int):
                for module in model.modules():
                    if isinstance(module, BasicBlock1d_LoRA):
                        if adapter_group_idx < len(module.lora_adapters):
                            for param in module.lora_adapters[adapter_group_idx].parameters():
                                param.requires_grad = True
                print(f"  - Adapter Group #{adapter_group_idx} (all blocks) (classes: {related_labels.get(adapter_group_idx, [])})")
            elif adapter_group_idx == 'base':
                print(f"  ⛔ Base layers (classes: {related_labels.get('base', [])}) are frozen and will NOT be updated.")
                # for module in model.modules():
                #     if isinstance(module, BasicBlock1d_LoRA):
                #         for p in module.conv2.parameters():
                #             p.requires_grad = True
                # print(f"  - Base layers (classes: {related_labels.get('base', [])})")

        ####### Check frozen parameters #######
        print("\n🔍 Frozen Parameters: (After Unfreeze specific adapter groups)")
        for name, param in model.named_parameters():
            if not param.requires_grad:
                print(f"  ❌ {name:<50} | shape={list(param.shape)}")
        ####### Check frozen parameters #######

        # 🔓 Unfreeze newly added adapter groups
        for group_idx in new_lora_indices:
            block_counter = 0
            for module in model.modules():
                if isinstance(module, BasicBlock1d_LoRA):
                    if group_idx < len(module.lora_adapters):
                        adapter = module.lora_adapters[group_idx]
                        for param in adapter.parameters():
                            param.requires_grad = True
                    block_counter += 1
            print(f"  - New adapter group #{group_idx} (classes: {related_labels[group_idx]})")
            
        ####### Check frozen parameters #######
        print("\n🔍 Frozen Parameters (After Unfreeze newly added adapter groups):")
        for name, param in model.named_parameters():
            if not param.requires_grad:
                print(f"  ❌ {name:<50} | shape={list(param.shape)}")
        ####### Check frozen parameters #######

        # Always unfreeze the final layer for all periods
        for param in model.fc.parameters():
            param.requires_grad = True

        ####### Check frozen parameters #######
        print("\n🔍 Frozen Parameters (After unfreeze the final layer):")
        for name, param in model.named_parameters():
            if not param.requires_grad:
                print(f"  ❌ {name:<50} | shape={list(param.shape)}")
        ####### Check frozen parameters #######

    # Print summary of related_labels
    print(f"\n📋 Related Labels Summary:")
    for adapter_idx, classes in related_labels.items():
        print(f"  - {'Base network' if adapter_idx == 'base' else f'Adapter #{adapter_idx}'}: Classes {classes}")
    
    # Print trainable parameters
    print("\n🔧 Trainable Parameter Status:")
    total_params = 0
    trainable_params = 0
    trainable_count = 0
    
    for name, param in model.named_parameters():
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
            trainable_count += 1
            print(f"  ✅ {name:<50} | shape={list(param.shape)}")
    print(f"trainable_count: {trainable_count}")
    
    frozen_params = total_params - trainable_params
    print(f"\n📊 Parameter Statistics:")
    print(f"  - Total parameters: {total_params:,}")
    print(f"  - Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")
    print(f"  - Frozen parameters: {frozen_params:,} ({frozen_params/total_params*100:.2f}%)")
    
    # Update optimizer with only trainable parameters
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=optimizer.param_groups[0]['lr'],
        weight_decay=optimizer.param_groups[0].get('weight_decay', 0)
    )

    # 重設所有參數的 requires_grad = True
    for param in model.parameters():
        param.requires_grad = True

    ####### Check frozen parameters #######
    print("\n🔍 Frozen Parameters:")
    for name, param in model.named_parameters():
        if not param.requires_grad:
            print(f"  ❌ {name:<50} | shape={list(param.shape)}")
    ####### Check frozen parameters #######

    # 檢查最終 optimizer 控制了哪些參數
    optimizer_param_ids = set(id(p) for group in optimizer.param_groups for p in group['params'])

    named_params = list(model.named_parameters())
    print(f"\n🧠 Parameters currently controlled by the optimizer: ({len(named_params)})")
    for name, param in named_params:
        if id(param) in optimizer_param_ids:
            print(f"  ✅ {name}")
        else:
            print(f"  ⛔ {name} (NOT included in optimizer)")

    print("\n" + "=" * 80)
    print("Starting training...\n")
    
    # Training loop
    for epoch in range(num_epochs):
        if stop_signal_file and os.path.exists(stop_signal_file):
            print("\n🛑 Stop signal detected. Exiting training loop.")
            break
        
        # Training phase
        model.train()
        epoch_loss = 0.0
        class_correct, class_total = {}, {}
        
        for xb, yb in train_loader:
            optimizer.zero_grad()
            logits = model(xb)
            ce_loss = criterion(logits, yb)
            
            # Apply distillation if teacher model is provided
            if teacher_model and stable_classes:
                with torch.no_grad():
                    teacher_logits = teacher_model(xb)
                student_stable = logits[:, stable_classes]
                teacher_stable = teacher_logits[:, stable_classes]
                distill_loss = F.mse_loss(student_stable, teacher_stable)
                total_loss = alpha * distill_loss + (1 - alpha) * ce_loss
            else:
                total_loss = ce_loss
            
            total_loss.backward()
            optimizer.step()
            
            epoch_loss += total_loss.item() * xb.size(0)
            compute_classwise_accuracy(logits, yb, class_correct, class_total)
        
        train_loss = epoch_loss / len(train_loader.dataset)
        train_acc = {int(k): f"{(class_correct[k] / class_total[k]) * 100:.2f}%" 
                    if class_total[k] > 0 else "0.00%" for k in class_total}
        
        # Validation phase
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        val_class_correct, val_class_total = {}, {}
        
        with torch.no_grad():
            for xb, yb in val_loader:
                outputs = model(xb)
                val_loss += criterion(outputs, yb).item() * xb.size(0)
                preds = torch.argmax(outputs, dim=1)
                val_correct += (preds == yb).sum().item()
                val_total += yb.size(0)
                compute_classwise_accuracy(outputs, yb, val_class_correct, val_class_total)
        
        val_loss /= len(val_loader.dataset)
        val_acc = val_correct / val_total
        val_acc_cls = {int(k): f"{(val_class_correct[k]/val_class_total[k])*100:.2f}%" 
                      if val_class_total[k] > 0 else "0.00%" for k in val_class_total}
        
        # Print epoch results
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.6f}, Train-Class-Acc: {train_acc}")
        print(f"Val Loss: {val_loss:.6f}, Val Acc: {val_acc*100:.2f}%, Val-Class-Acc: {val_acc_cls}, LR: {optimizer.param_groups[0]['lr']:.6f}")
        
        # Save model checkpoint
        model_path = os.path.join(model_saving_folder, f"{model_name}_epoch_{epoch+1}.pth")
        current = {
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_accuracy': val_acc,
            'train_classwise_accuracy': train_acc,
            'val_classwise_accuracy': val_acc_cls,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'learning_rate': optimizer.param_groups[0]['lr'],
            'model_path': model_path,
            'num_lora_groups': model.count_lora_groups(),
            'related_labels': related_labels
        }
        
        # Keep top 5 best models
        if len(best_results) < 5 or val_acc > best_results[-1]['val_accuracy']:
            if len(best_results) == 5:
                to_remove = best_results.pop()
                if os.path.exists(to_remove['model_path']):
                    os.remove(to_remove['model_path'])
                    print(f"🗑 Removed: {to_remove['model_path']}")
            best_results.append(current)
            best_results.sort(key=lambda x: (x['val_accuracy'], x['epoch']), reverse=True)
            torch.save(current, model_path)
            print(f"✅ Saved model: {model_path}")
        
        # Update learning rate scheduler if provided
        if scheduler:
            scheduler.step(val_loss)
    
    # End of training
    elapsed_time = time.time() - start_time
    total_params, param_size_MB = get_model_parameter_info(model)
    
    # Save best model
    if best_results:
        best = best_results[0]
        best_model_path = os.path.join(model_saving_folder, f"{model_name}_best.pth")
        torch.save(best, best_model_path)
        print(f"\n🏆 Best model saved as: {best_model_path} (Val Accuracy: {best['val_accuracy'] * 100:.2f}%)")
    
    # Save final model
    final_model_path = os.path.join(model_saving_folder, f"{model_name}_final.pth")
    torch.save(current, final_model_path)
    print(f"\n📌 Final model saved as: {final_model_path}")
    
    # Print top 5 models
    print("\n🎯 Top 5 Best Models:")
    for res in best_results:
        print(f"Epoch {res['epoch']}, Train Loss: {res['train_loss']:.6f}, Train-Acc: {res['train_classwise_accuracy']},\n"
              f"Val Loss: {res['val_loss']:.6f}, Val Acc: {res['val_accuracy']*100:.2f}%, Val-Acc: {res['val_classwise_accuracy']}, "
              f"Model Path: {res['model_path']}")
    
    # Print model summary
    print(f"\n🧠 Model Summary:")
    print(f"Total Parameters: {total_params:,}")
    print(f"Model Size (float32): {param_size_MB:.2f} MB")
    print(f"Number of LoRA adapters: {model.count_lora_adapters()}")
    print(f"Number of LoRA groups: {model.count_lora_groups()}")
    print(f"Total Training Time: {elapsed_time:.2f} seconds")
    
    # Extract period number from folder name
    match = re.search(r'Period_(\d+)', model_saving_folder)
    period_label = match.group(1) if match else str(period)
    model_name_str = model.__class__.__name__
    
    # Print markdown summary
    best_model = max(best_results, key=lambda x: x['val_accuracy'])
    print(f"""
---
### Period {period_label} (alpha = {alpha}, similarity_threshold = {similarity_threshold})
+ ##### Total training time: {elapsed_time:.2f} seconds
+ ##### Model: {model_name_str}
+ ##### Training and saving in *'{model_saving_folder}'*
+ ##### Best Epoch: {best_model['epoch']}
#### __Val Accuracy: {best_model['val_accuracy'] * 100:.2f}%__
#### __Val-Class-Acc: {best_model['val_classwise_accuracy']}__
#### __Total Parameters: {total_params:,}__
#### __Model Size (float32): {param_size_MB:.2f} MB__
#### __Number of LoRA adapters: {model.count_lora_adapters()}__
#### __Number of LoRA groups: {model.count_lora_groups()}__
""".strip())
    
    # Save class features for next period
    if class_features_dict is None:
        class_features_dict = {}
    class_features_dict.update(new_class_features)
    with open(os.path.join(model_saving_folder, "class_features.pkl"), 'wb') as f:
        pickle.dump(class_features_dict, f)
    print(f"\nSaved class features to: {os.path.join(model_saving_folder, 'class_features.pkl')}")
    
    # Clean up memory
    torch.cuda.empty_cache()
    gc.collect()


def extract_features(model, x):
    """Return the pooled, flattened activations that feed the classifier head.

    The input is permuted from (batch_size, time_steps, channels) to
    (batch_size, channels, time_steps) and passed through the stem
    (conv1, bn1, relu, maxpool) followed by layer1..layer4 of ``model``.
    The adaptive average pooling and adaptive max pooling of the last
    feature map are concatenated along dim 1 and flattened to one row
    per sample.

    The forward pass runs under ``torch.no_grad()``. The model's train/eval
    mode is not changed here, so callers that need fixed BatchNorm behaviour
    should call ``model.eval()`` beforehand.
    """
    # Convert to (batch_size, channels, time_steps)
    hidden = x.permute(0, 2, 1)

    # Sub-modules applied in order, up to (but excluding) the fc layer.
    backbone_stages = (
        "conv1", "bn1", "relu", "maxpool",
        "layer1", "layer2", "layer3", "layer4",
    )

    with torch.no_grad():
        for stage_name in backbone_stages:
            hidden = getattr(model, stage_name)(hidden)

        # Average-pooled and max-pooled summaries, joined on the channel axis.
        pooled = torch.cat(
            (model.adaptiveavgpool(hidden), model.adaptivemaxpool(hidden)),
            dim=1,
        )

        # One flat feature vector per sample.
        features = pooled.view(pooled.size(0), -1)

    return features

## __Training__

### Period 1 Summary

In [15]:
def display_model_summary_with_params(model_folder, model_filename="ResNet18_big_inplane_1D_best.pth", input_channels=12, output_size=10):
    """Load a saved checkpoint and print the model architecture and its metrics.

    Args:
        model_folder: Directory that contains the checkpoint file.
        model_filename: Checkpoint file name inside ``model_folder``.
        input_channels: ``input_channels`` passed to ``ResNet18_1D_LoRA``.
        output_size: ``output_size`` passed to ``ResNet18_1D_LoRA``; it must
            match the checkpoint because the state dict is loaded strictly.

    Returns:
        None. Everything is printed. If the checkpoint file does not exist a
        message is printed and the function returns early.
    """
    model_path = os.path.join(model_folder, model_filename)

    if not os.path.exists(model_path):
        print(f"❌ File not found: {model_path}")
        return

    # NOTE: torch.load unpickles the file; only use it on checkpoints you trust.
    checkpoint = torch.load(model_path, map_location='cpu')

    # === Rebuild the model and load its weights ===
    model = ResNet18_1D_LoRA(input_channels=input_channels, output_size=output_size)
    # Checkpoints that record 'num_lora_groups' contain LoRA adapter weights.
    # Recreate the same number of adapter groups before the strict load below,
    # otherwise those keys would be unexpected. Checkpoints without the key
    # get zero adapters, which keeps the previous behaviour.
    for _ in range(checkpoint.get("num_lora_groups", 0)):
        model.add_lora_adapter()
    model.load_state_dict(checkpoint["model_state_dict"])
    total_params, param_size_MB = get_model_parameter_info(model)

    # === Show the summary ===
    # Missing entries fall back to "?" / {} so partially filled checkpoints still print.
    epoch = checkpoint.get("epoch", "?")
    train_loss = checkpoint.get("train_loss", "?")
    val_loss = checkpoint.get("val_loss", "?")
    val_acc = checkpoint.get("val_accuracy", "?")
    train_acc_dict = checkpoint.get("train_classwise_accuracy", {})
    val_acc_dict = checkpoint.get("val_classwise_accuracy", {})
    lr = checkpoint.get("learning_rate", "?")
    stored_path = checkpoint.get("model_path", "N/A")
    print(f"Model Architecture:")
    print(model)
    print(f"\n📦 Model Summary from: {model_path}")
    print(f"📌 Epoch: {epoch}")
    # Numeric formatting is applied only when the value really is a float.
    print(f"🧮 Train Loss: {train_loss:.6f}" if isinstance(train_loss, float) else f"🧮 Train Loss: {train_loss}")
    print(f"🎯 Val Loss: {val_loss:.6f}" if isinstance(val_loss, float) else f"🎯 Val Loss: {val_loss}")
    print(f"✅ Val Accuracy: {val_acc*100:.2f}%" if isinstance(val_acc, float) else f"✅ Val Accuracy: {val_acc}")
    print(f"📎 Learning Rate: {lr}")
    print(f"📁 Stored Model Path: {stored_path}")
    print(f"🧠 Total Parameters: {total_params:,}")
    print(f"📏 Model Size (float32): {param_size_MB:.2f} MB")

    print("\n📊 Train Class-wise Accuracy:")
    for c, acc in train_acc_dict.items():
        print(f"  └─ Class {c:<2}: {acc}")

    print("\n📊 Val Class-wise Accuracy:")
    for c, acc in val_acc_dict.items():
        print(f"  └─ Class {c:<2}: {acc}")

    # Same information again as a markdown snippet that can be pasted into a cell.
    print("\n---\n### Period 1 Summary (Markdown Format)")
    print(f"+ **Epoch:** {epoch}")
    print(f"+ **Train Loss:** {train_loss}")
    print(f"+ **Val Loss:** {val_loss}")
    print(f"+ **Val Accuracy:** {val_acc*100:.2f}%" if isinstance(val_acc, float) else f"+ **Val Accuracy:** {val_acc}")
    print(f"+ **Learning Rate:** {lr}")
    print(f"+ **Stored Model Path:** `{stored_path}`")
    print(f"+ **Total Parameters:** {total_params:,}")
    print(f"+ **Model Size (float32):** {param_size_MB:.2f} MB")
    print(f"+ **Train-Class-Acc:** {train_acc_dict}")
    print(f"+ **Val-Class-Acc:** {val_acc_dict}")
    print("---")

# Example call: summarise the checkpoint stored in the ResNet18_big_inplane_v1
# folder, rebuilt with 12 input channels and 2 output classes.
display_model_summary_with_params(
    os.path.join("Class_Incremental_CL", "CPSC_CIL", "ResNet18_Selection", "ResNet18_big_inplane_v1"),
    input_channels=12,
    output_size=2,
)

Model Architecture:
ResNet18_1D_LoRA(
  (conv1): Conv1d(12, 64, kernel_size=(15,), stride=(2,), padding=(7,), bias=False)
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock1d_LoRA(
      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (lora_adapters): ModuleList()
    )
    (1): BasicBlock1d_LoRA(
      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, tra

  checkpoint = torch.load(model_path, map_location='cpu')


### Generate class features (period1)

In [18]:
def generate_class_features_period1(
    model_path: str,
    save_path: str,
    X_train: np.ndarray,
    y_train: np.ndarray,
    input_channels: int = 12,
    output_size: int = 2,
    batch_size: int = 64
) -> dict:
    """Compute one mean feature vector per class and pickle the result.

    A ``ResNet18_1D_LoRA`` is rebuilt from the checkpoint at ``model_path``,
    switched to eval mode, and used via ``extract_features`` to embed every
    sample of ``X_train``. Features are grouped by label and averaged.

    Args:
        model_path: Checkpoint file containing a ``"model_state_dict"`` entry.
        save_path: Destination of the pickled ``{class_id: mean_feature}`` dict.
            Its parent directory is created if it is missing.
        X_train: Training samples; converted to a float32 tensor.
        y_train: Integer labels aligned with ``X_train``; converted to long.
        input_channels: ``input_channels`` passed to ``ResNet18_1D_LoRA``.
        output_size: ``output_size`` passed to ``ResNet18_1D_LoRA``.
        batch_size: Batch size used for feature extraction.

    Returns:
        dict mapping each class id (Python int) to its mean feature tensor (on CPU).
    """
    # === Load the model ===
    device = auto_select_cuda_device()
    model = ResNet18_1D_LoRA(input_channels=input_channels, output_size=output_size)
    # NOTE: torch.load unpickles the file; only use it on checkpoints you trust.
    checkpoint = torch.load(model_path, map_location='cpu')
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    # === Build the DataLoader ===
    X_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_tensor = torch.tensor(y_train, dtype=torch.long)
    dataloader = DataLoader(TensorDataset(X_tensor, y_tensor), batch_size=batch_size, shuffle=False)

    # === Extract features ===
    class_features_dict = {}
    with torch.no_grad():
        for xb, yb in tqdm(dataloader, desc="Extracting Features"):
            xb, yb = xb.to(device), yb.to(device)
            features = extract_features(model, xb)  # shape: [B, F]
            for cls in torch.unique(yb):
                cls_id = cls.item()
                # Keep per-batch chunks on the CPU so GPU memory is not held.
                class_features_dict.setdefault(cls_id, []).append(features[yb == cls].cpu())

    # === Average the feature vectors of each class ===
    for cls_id, chunks in class_features_dict.items():
        class_features_dict[cls_id] = torch.cat(chunks, dim=0).mean(dim=0)

    # === Save as .pkl so Period 2 can load it ===
    # dirname() is '' for a bare file name, and os.makedirs('') raises, so the
    # directory is only created when save_path actually has a directory part.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(save_path, 'wb') as f:
        pickle.dump(class_features_dict, f)

    print(f"✅ Saved class_features_dict to: {save_path}")
    return class_features_dict


# Period 1 checkpoint and the class-feature file that is written next to it.
model_path, save_path = (
    os.path.join(BASE_DIR, "ResNet18_Selection", "ResNet18_big_inplane_v1", artefact_name)
    for artefact_name in ("ResNet18_big_inplane_1D_best.pth", "class_features.pkl")
)

# Period 1 training split, loaded from the preprocessed .npy files.
X_train, y_train = (
    np.load(os.path.join(save_dir, npy_name))
    for npy_name in (f"X_train_p{1}.npy", f"y_train_p{1}.npy")
)

class_features_dict = generate_class_features_period1(
    model_path,
    save_path,
    X_train,
    y_train,
    input_channels=12,
    output_size=2,
    batch_size=64,
)

🎯 Automatically selected GPU:
    - CUDA Device ID : 1
    - Memory Used    : 18 MiB
    - Device Name    : NVIDIA RTX A6000


  checkpoint = torch.load(model_path, map_location='cpu')
Extracting Features: 100%|██████████| 23/23 [00:00<00:00, 59.05it/s]


✅ Saved class_features_dict to: /mnt/mydisk/Continual_Learning_JL/Continual_Learning/Class_Incremental_CL/CPSC_CIL/ResNet18_Selection/ResNet18_big_inplane_v1/class_features.pkl


### Period 2

#### v10: no distillation, all parameters frozen by default, similarity threshold = 0.99

In [15]:
# ================================
# 📌 Period 2: DynEx-CLoRA Training (ECG)
# ================================
# Loads the Period 1 checkpoint and class features, builds a student model sized
# to this period's label count, and trains it with train_with_dynex_clora_ecg.
period = 2

# ==== Paths ====
stop_signal_file = os.path.join(BASE_DIR, "stop_training.txt")
model_saving_folder = os.path.join(BASE_DIR, "Trained_models", "DynEx_CLoRA_CIL_v10", f"Period_{period}")
ensure_folder(model_saving_folder)

# ==== Load Period 2 Data ====
# Validation arrays are read from the X_test_p{period} / y_test_p{period} files.
X_train = np.load(os.path.join(save_dir, f"X_train_p{period}.npy"))
y_train = np.load(os.path.join(save_dir, f"y_train_p{period}.npy"))
X_val   = np.load(os.path.join(save_dir, f"X_test_p{period}.npy"))
y_val   = np.load(os.path.join(save_dir, f"y_test_p{period}.npy"))

# ==== Device ====
device = auto_select_cuda_device()

# ==== Load the Period 1 class features ====
prev_folder = os.path.join(BASE_DIR, "ResNet18_Selection", "ResNet18_big_inplane_v1")
class_features_path = os.path.join(prev_folder, "class_features.pkl")
# NOTE: pickle.load executes arbitrary code from the file; only load trusted files.
with open(class_features_path, "rb") as f:
    class_features_dict = pickle.load(f)
print(f"✅ Loaded class features from: {class_features_path}")

# ==== Load the Period 1 pretrained model ====
prev_model_path = os.path.join(prev_folder, "ResNet18_big_inplane_1D_best.pth")
checkpoint = torch.load(prev_model_path, map_location=device)

# Channel count is taken from the last axis of X_train; the classifier width is
# the number of distinct labels present in this period's training labels.
input_channels = X_train.shape[2]
output_size = len(np.unique(y_train))

# ==== Build the teacher model (output_size minus the number of new classes) ====
# The "- 2" hard-codes two new classes for this period.
# NOTE(review): the teacher is built and loaded here, but teacher_model=None is
# passed to the training call below, so it is not used for distillation. Removing
# it may change the student's random initialisation because both models are
# constructed in sequence — confirm before deleting.
teacher_model = ResNet18_1D_LoRA(input_channels=input_channels, output_size=output_size - 2).to(device)
teacher_model.load_state_dict(checkpoint["model_state_dict"], strict=False)
teacher_model.eval()

# ==== Build the student model ====
model = ResNet18_1D_LoRA(input_channels=input_channels, output_size=output_size).to(device)

# ==== Sync the adapter structure with the number of LoRA groups in the checkpoint ====
# Both lookups fall back to defaults (0 groups, base -> classes [0, 1]) when the
# checkpoint does not carry these keys.
num_lora_groups = checkpoint.get("num_lora_groups", 0)
print(f"🔄 Number of LoRA groups: {num_lora_groups}")
related_labels = checkpoint.get("related_labels", {"base": [0, 1]})
for _ in range(num_lora_groups):
    model.add_lora_adapter()

# ==== Copy shared weights (every key whose name and shape match, except fc) ====
# Only fc.weight / fc.bias are excluded by name; any other key, including LoRA
# adapter weights, is copied when it exists in the student with the same shape.
model_dict = model.state_dict()
prev_state_dict = checkpoint["model_state_dict"]
filtered_dict = {
    k: v for k, v in prev_state_dict.items()
    if k in model_dict and model_dict[k].shape == v.shape and k not in ["fc.weight", "fc.bias"]
}
model.load_state_dict(filtered_dict, strict=False)
# Report every student key that kept its fresh initialisation.
for k in model_dict:
    if k not in filtered_dict:
        print(f"🔍 Not loaded: {k}, shape={model_dict[k].shape}")
print("✅ Loaded shared weights from Period 1 (excluding FC only)")

# ==== Training parameters ====
learning_rate = 1e-3
weight_decay = 1e-5
num_epochs = 200
batch_size = 64
alpha = 0.0 # No distillation
similarity_threshold = 0.99
# NOTE(review): in the visible part of the training loop, stable_classes is only
# read when a teacher model is supplied; with teacher_model=None it appears unused.
stable_classes = [0]
criterion = nn.CrossEntropyLoss()
# NOTE(review): train_with_dynex_clora_ecg creates its own Adam optimizer from
# this one's lr / weight_decay, while the scheduler below stays attached to this
# original optimizer — confirm that LR reductions reach the optimizer that is
# actually stepped.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=10)

# ==== Start training ====
train_with_dynex_clora_ecg(
    model=model,
    teacher_model=None,
    output_size=output_size,
    criterion=criterion,
    optimizer=optimizer,
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val,
    num_epochs=num_epochs,
    batch_size=batch_size,
    alpha=alpha,
    model_saving_folder=model_saving_folder,
    model_name="ResNet18_1D_LoRA",
    stop_signal_file=stop_signal_file,
    scheduler=scheduler,
    period=period,
    stable_classes=stable_classes,
    similarity_threshold=similarity_threshold,
    class_features_dict=class_features_dict,
    related_labels=related_labels,
    device=device
)

# ==== Cleanup ====
# Drop the large arrays and models so the kernel and GPU memory can be reclaimed.
del X_train, y_train, X_val, y_val
del model, teacher_model, checkpoint, model_dict, prev_state_dict, filtered_dict
gc.collect()
torch.cuda.empty_cache()

🎯 Automatically selected GPU:
    - CUDA Device ID : 1
    - Memory Used    : 283 MiB
    - Device Name    : NVIDIA RTX A6000
✅ Loaded class features from: /mnt/mydisk/Continual_Learning_JL/Continual_Learning/Class_Incremental_CL/CPSC_CIL/ResNet18_Selection/ResNet18_big_inplane_v1/class_features.pkl


  checkpoint = torch.load(prev_model_path, map_location=device)


🔄 Number of LoRA groups: 0
🔍 Not loaded: fc.weight, shape=torch.Size([4, 1024])
🔍 Not loaded: fc.bias, shape=torch.Size([4])
✅ Loaded shared weights from Period 1 (excluding FC only)

🚀 'train_with_dynex_clora_ecg' started for Period 2

✅ Removed existing folder: /mnt/mydisk/Continual_Learning_JL/Continual_Learning/Class_Incremental_CL/CPSC_CIL/Trained_models/DynEx_CLoRA_CIL_v10/Period_2

✅ Data Overview:
X_train: torch.Size([3263, 5000, 12]), y_train: torch.Size([3263])
X_val: torch.Size([816, 5000, 12]), y_val: torch.Size([816])

🔎 Similarity Analysis:
  Similarity threshold: 0.9900
  Existing classes: [0, 1]
  Current classes: [0, 1, 2, 3]
  New classes: [2, 3]

📊 Similarity Scores:
  New Class 2:
    - Existing Class 1: 0.9665
    - Existing Class 0: 0.8415
  New Class 3:
    - Existing Class 1: 0.9912
    - Existing Class 0: 0.8097

  Average similarity: 0.9055, Std: 0.0837

🧩 Managing LoRA adapters for 2 new classes...
✅ Added new LoRA adapters to 8 BasicBlocks
➕ New Class 2 is n

### Period 3

#### v10: no distillation, all parameters frozen by default, similarity threshold = 0.85

In [None]:
# ================================
# 📌 Period 3: DynEx-CLoRA Training (ECG)
# ================================
# Loads the Period 2 checkpoint and class features, builds a student model sized
# to this period's label count, and trains it with train_with_dynex_clora_ecg.
period = 3

# ==== Paths ====
stop_signal_file = os.path.join(BASE_DIR, "stop_training.txt")
model_saving_folder = os.path.join(BASE_DIR, "Trained_models", "DynEx_CLoRA_CIL_v10", f"Period_{period}")
ensure_folder(model_saving_folder)

# ==== Load Period 3 Data ====
# Validation arrays are read from the X_test_p{period} / y_test_p{period} files.
X_train = np.load(os.path.join(save_dir, f"X_train_p{period}.npy"))
y_train = np.load(os.path.join(save_dir, f"y_train_p{period}.npy"))
X_val   = np.load(os.path.join(save_dir, f"X_test_p{period}.npy"))
y_val   = np.load(os.path.join(save_dir, f"y_test_p{period}.npy"))

# ==== Device ====
device = auto_select_cuda_device()

# ==== Load Period 2 Features ====
prev_folder = os.path.join(BASE_DIR, "Trained_models", "DynEx_CLoRA_CIL_v10", "Period_2")
class_features_path = os.path.join(prev_folder, "class_features.pkl")
# NOTE: pickle.load executes arbitrary code from the file; only load trusted files.
with open(class_features_path, "rb") as f:
    class_features_dict = pickle.load(f)
print(f"✅ Loaded class features from: {class_features_path}")

# ==== Load Period 2 Checkpoint ====
prev_model_path = os.path.join(prev_folder, "ResNet18_1D_LoRA_best.pth")
checkpoint = torch.load(prev_model_path, map_location=device)

# Channel count is taken from the last axis of X_train; the classifier width is
# the number of distinct labels present in this period's training labels.
input_channels = X_train.shape[2]
output_size = len(np.unique(y_train))

# ==== Build Teacher Model ====
# The "- 2" hard-codes two new classes for this period.
# NOTE(review): no LoRA adapters are added to the teacher, so with strict=False
# any adapter weights stored in the Period 2 checkpoint would be skipped rather
# than loaded. The teacher is also not used below (teacher_model=None is passed).
# Removing it may change the student's random initialisation because both models
# are constructed in sequence — confirm before deleting.
teacher_model = ResNet18_1D_LoRA(input_channels=input_channels, output_size=output_size - 2).to(device)
teacher_model.load_state_dict(checkpoint["model_state_dict"], strict=False)
teacher_model.eval()

# ==== Build Student Model ====
model = ResNet18_1D_LoRA(input_channels=input_channels, output_size=output_size).to(device)

# ==== Sync Adapter Structure ====
# Recreate as many LoRA adapter groups as the checkpoint recorded so that the
# adapter weights can be matched by name and shape below.
num_lora_groups = checkpoint.get("num_lora_groups", 0)
print(f"🔄 Number of LoRA groups: {num_lora_groups}")
related_labels = checkpoint.get("related_labels", {"base": [0, 1]})
for _ in range(num_lora_groups):
    model.add_lora_adapter()

# ==== Load Previous Weights (every key whose name and shape match, except FC) ====
# Only fc.weight / fc.bias are excluded by name; LoRA adapter weights are copied
# too when they exist in the student with the same shape.
model_dict = model.state_dict()
prev_state_dict = checkpoint["model_state_dict"]
filtered_dict = {
    k: v for k, v in prev_state_dict.items()
    if k in model_dict and model_dict[k].shape == v.shape and k not in ["fc.weight", "fc.bias"]
}
model.load_state_dict(filtered_dict, strict=False)
# Report every student key that kept its fresh initialisation.
for k in model_dict:
    if k not in filtered_dict:
        print(f"🔍 Not loaded: {k}, shape={model_dict[k].shape}")
print("✅ Loaded shared weights from Period 2 (excluding FC only)")

# ==== Training Configuration ====
learning_rate = 1e-3
weight_decay = 1e-5
num_epochs = 200
batch_size = 64
alpha = 0.0 # No distillation
similarity_threshold = 0.85
# NOTE(review): in the visible part of the training loop, stable_classes is only
# read when a teacher model is supplied; with teacher_model=None it appears unused.
stable_classes = [0, 2, 3]  # based on the Period 2 class results
criterion = nn.CrossEntropyLoss()
# NOTE(review): train_with_dynex_clora_ecg creates its own Adam optimizer from
# this one's lr / weight_decay, while the scheduler below stays attached to this
# original optimizer — confirm that LR reductions reach the optimizer that is
# actually stepped.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=10)

# ==== Start Training ====
train_with_dynex_clora_ecg(
    model=model,
    teacher_model=None,
    output_size=output_size,
    criterion=criterion,
    optimizer=optimizer,
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val,
    num_epochs=num_epochs,
    batch_size=batch_size,
    alpha=alpha,
    model_saving_folder=model_saving_folder,
    model_name="ResNet18_1D_LoRA",
    stop_signal_file=stop_signal_file,
    scheduler=scheduler,
    period=period,
    stable_classes=stable_classes,
    similarity_threshold=similarity_threshold,
    class_features_dict=class_features_dict,
    related_labels=related_labels,
    device=device
)

# ==== Cleanup ====
# Drop the large arrays and models so the kernel and GPU memory can be reclaimed.
del X_train, y_train, X_val, y_val
del model, teacher_model, checkpoint, model_dict, prev_state_dict, filtered_dict
gc.collect()
torch.cuda.empty_cache()


🎯 Automatically selected GPU:
    - CUDA Device ID : 1
    - Memory Used    : 759 MiB
    - Device Name    : NVIDIA RTX A6000
✅ Loaded class features from: /mnt/mydisk/Continual_Learning_JL/Continual_Learning/Class_Incremental_CL/CPSC_CIL/Trained_models/DynEx_CLoRA_CIL_v10/Period_2/class_features.pkl


  checkpoint = torch.load(prev_model_path, map_location=device)


🔄 Number of LoRA groups: 1
✅ Added new LoRA adapters to 8 BasicBlocks
🔍 Not loaded: fc.weight, shape=torch.Size([6, 1024])
🔍 Not loaded: fc.bias, shape=torch.Size([6])
✅ Loaded shared weights from Period 2 (excluding FC only)

🚀 'train_with_dynex_clora_ecg' started for Period 3

✅ Removed existing folder: /mnt/mydisk/Continual_Learning_JL/Continual_Learning/Class_Incremental_CL/CPSC_CIL/Trained_models/DynEx_CLoRA_CIL_v10/Period_3

✅ Data Overview:
X_train: torch.Size([5120, 5000, 12]), y_train: torch.Size([5120])
X_val: torch.Size([1281, 5000, 12]), y_val: torch.Size([1281])

🔎 Similarity Analysis:
  Similarity threshold: 0.8500
  Existing classes: [0, 1, 2, 3]
  Current classes: [0, 1, 2, 3, 4, 5]
  New classes: [4, 5]

📊 Similarity Scores:
  New Class 4:
    - Existing Class 2: 0.8920
    - Existing Class 0: 0.8823
    - Existing Class 3: 0.8816
    - Existing Class 1: 0.8797
  New Class 5:
    - Existing Class 1: 0.8797
    - Existing Class 3: 0.8523
    - Existing Class 2: 0.8347
 

: 