In [5]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import numpy as np
import random
import yaml
from models.Flexible_DANN_MMD import Flexible_DANN
from PKLDataset import PKLDataset
from utils.general_train_and_test import general_test_model
from models.get_no_label_dataloader import get_target_loader
from models.MMD import *
from collections import deque



def set_seed(seed=42):
    """Seed the PyTorch, NumPy and Python ``random`` generators with ``seed``.

    When CUDA is available the seed is also applied through
    ``torch.cuda.manual_seed_all``.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def get_dataloaders(source_path, target_path, batch_size):
    """Build the shuffled training loaders for the source and target data.

    Args:
        source_path: Passed to ``PKLDataset`` as ``txt_path``.
        target_path: Passed to ``get_target_loader``.
        batch_size: Batch size used for both loaders.

    Returns:
        Tuple ``(source_loader, target_loader)``.
    """
    src_loader = DataLoader(
        PKLDataset(txt_path=source_path),
        batch_size=batch_size,
        shuffle=True,
    )
    tgt_loader = get_target_loader(target_path, batch_size=batch_size, shuffle=True)
    return src_loader, tgt_loader

def dann_lambda(epoch, num_epochs):
    """Return the DANN lambda for ``epoch`` on the common smooth schedule.

    Computes ``2 / (1 + exp(-10 * p)) - 1`` with ``p = epoch / num_epochs``,
    which is 0 at epoch 0 and rises smoothly towards 1. Change the ``-10``
    factor to make the ramp gentler or steeper.
    """
    progress = epoch / float(num_epochs)
    denominator = 1. + np.exp(-10 * progress)
    return 2. / denominator - 1.

def mmd_lambda(epoch, num_epochs, max_lambda=1e-1):
    """Return the warm-up weight of the MMD loss for ``epoch``.

    The weight is ``max_lambda * sigmoid(10 * (p - 0.5))`` with
    ``p = epoch / max(1, num_epochs - 1)``: an S-shaped ramp from roughly
    ``0.0067 * max_lambda`` at epoch 0 to roughly ``0.9933 * max_lambda`` at
    the last epoch. The result is a plain Python ``float``.
    """
    last_epoch = max(1, num_epochs - 1)
    progress = epoch / last_epoch  # in [0, 1] for 0 <= epoch <= num_epochs - 1
    exponent = torch.tensor(-10.0 * (progress - 0.5))
    sigmoid = 1.0 / (1.0 + torch.exp(exponent))  # in (0, 1)
    return float(max_lambda * sigmoid)

def train_dann_with_mmd(model, source_loader, target_loader,
                        optimizer, criterion_cls, criterion_domain,
                        device, num_epochs=20,
                        lambda_dann=0.1,           # weight of the domain-classifier loss
                        lambda_mmd_max=1e-1,       # maximum weight of the MMD loss
                        use_mk=False,               # use the multi-kernel MMD
                        scheduler=None):
    """Train a DANN model with an additional MMD feature-alignment loss.

    Every step minimises
    ``loss_cls + lambda_dann * loss_dom + lambda_mmd * loss_mmd`` where the
    MMD term is computed on L2-normalised features. After each epoch the
    model is evaluated on the source test set.

    Args:
        model: Module whose forward pass returns
            ``(class_logits, domain_logits, features)``.
        source_loader: Iterable of ``(inputs, labels)`` source batches.
        target_loader: Iterable of unlabeled target input batches. Batches
            are paired with ``zip``, so the shorter loader ends the epoch.
        optimizer: Optimizer stepped once per batch.
        criterion_cls: Loss applied to the source class logits.
        criterion_domain: Loss applied to the domain logits (source samples
            are labelled 0, target samples 1).
        device: Device the batches are moved to.
        num_epochs: Maximum number of epochs.
        lambda_dann: Fixed weight of the domain loss, or a callable
            ``f(epoch, num_epochs)`` returning the weight for that epoch
            (e.g. ``dann_lambda``).
        lambda_mmd_max: ``max_lambda`` forwarded to ``mmd_lambda``.
        use_mk: If true use ``mmd_mk_biased``, otherwise ``mmd_rbf_biased``.
        scheduler: Optional LR scheduler, stepped once per epoch.

    Returns:
        ``model``. If early stopping triggers, the last snapshot taken when
        any tracked metric improved is loaded into it first.
    """
    # Local import so the function does not rely on ``F`` leaking into the
    # module namespace through ``from models.MMD import *``.
    import torch.nn.functional as F

    PATIENCE = 3
    MIN_EPOCH = 10

    best_gap = 0.5
    best_cls = float('inf')
    best_mmd = float('inf')
    best_model_state = None
    patience = 0

    MMD_THRESH = 3e-2  # threshold below which MMD^2 counts as small; task dependent (0.02~0.05 is common)
    MMD_PLATEAU_EPS = 5e-3  # maximum fluctuation still counted as a plateau
    mmd_hist = deque(maxlen=5)  # the plateau test looks at the last 5 epochs

    mmd_fn = (lambda x, y: mmd_mk_biased(x, y, gammas=(0.5,1,2,4,8))) if use_mk \
             else (lambda x, y: mmd_rbf_biased(x, y, gamma=None))

    # Per-epoch source-test evaluation. The batch size is taken from the
    # training loader instead of a module-level global; if the loader does not
    # expose one, fall back to 1 (slower, still a valid evaluation).
    source_test_path = '../datasets/source/test/DC_T197_RP.txt'
    eval_batch_size = getattr(source_loader, 'batch_size', None) or 1
    s_loader = None  # built on first use and reused across epochs

    for epoch in range(num_epochs):
        cls_loss_sum, dom_loss_sum, mmd_loss_sum, total_loss_sum = 0.0, 0.0, 0.0, 0.0
        total_cls_samples, total_dom_samples = 0, 0
        dom_correct, dom_total = 0, 0
        model.train()

        # Loss weights depend only on the epoch, so compute them once here.
        # This also keeps them defined for the epoch log when the loaders
        # yield no batch. A callable ``lambda_dann`` is honoured as the
        # schedule; the MMD weight is warmed up to avoid flattening the
        # decision structure at the very start of training.
        lambda_dann_now = lambda_dann(epoch, num_epochs) if callable(lambda_dann) else lambda_dann
        lambda_mmd_now = float(mmd_lambda(epoch, num_epochs, max_lambda=lambda_mmd_max))

        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src, feat_src = model(src_x)
            _,            dom_out_tgt, feat_tgt = model(tgt_x)

            # 1) Classification loss (source domain only)
            loss_cls = criterion_cls(cls_out_src, src_y)

            # 2) Domain classification loss (DANN)
            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long, device=device)
            dom_label_tgt = torch.ones(tgt_x.size(0),  dtype=torch.long, device=device)
            loss_dom = criterion_domain(dom_out_src, dom_label_src) + \
                       criterion_domain(dom_out_tgt, dom_label_tgt)

            # 3) MMD feature alignment on L2-normalised features (for stability)
            feat_src_n = F.normalize(feat_src, dim=1)
            feat_tgt_n = F.normalize(feat_tgt, dim=1)
            loss_mmd = mmd_fn(feat_src_n, feat_tgt_n)

            # 4) Combined loss
            loss = loss_cls + lambda_dann_now * loss_dom + lambda_mmd_now * loss_mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Sample-weighted running sums for the epoch averages
            n_src, n_both = src_x.size(0), src_x.size(0) + tgt_x.size(0)
            cls_loss_sum  += loss_cls.item() * n_src
            dom_loss_sum  += loss_dom.item() * n_both
            mmd_loss_sum  += loss_mmd.item() * n_both
            total_loss_sum += loss.item() * n_both

            total_cls_samples += n_src
            total_dom_samples += n_both

            # Domain classification accuracy
            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total   += dom_label_src.size(0) + dom_label_tgt.size(0)

        # -- Epoch-level logging --
        avg_cls_loss  = cls_loss_sum  / max(1, total_cls_samples)
        avg_dom_loss  = dom_loss_sum  / max(1, total_dom_samples)
        avg_mmd_loss  = mmd_loss_sum  / max(1, total_dom_samples)
        avg_total_loss= total_loss_sum/ max(1, total_dom_samples)
        dom_acc = dom_correct / max(1, dom_total)
        gap = abs(dom_acc - 0.5)  # distance of the domain accuracy from chance level

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch + 1}] Total: {avg_total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {avg_dom_loss:.4f} | "
              f"MMD: {avg_mmd_loss:.4f} | DomAcc: {dom_acc:.4f} | "
              f"λ_dann: {lambda_dann_now:.4f} | λ_mmd: {lambda_mmd_now:.4f}")

        mmd_hist.append(avg_mmd_loss)
        mmd_plateau = (len(mmd_hist) == mmd_hist.maxlen) and (max(mmd_hist) - min(mmd_hist) < MMD_PLATEAU_EPS)

        # Early-stopping trigger conditions
        cond_align = (gap < 0.05)
        cond_cls = (avg_cls_loss < 0.5)
        cond_mmd_small = (avg_mmd_loss < MMD_THRESH)
        cond_mmd_plateau = mmd_plateau

        # Did any tracked metric set a new best this epoch?
        improved = False
        if gap < best_gap - 1e-4:
            best_gap = gap
            improved = True
        if avg_cls_loss < best_cls - 1e-4:
            best_cls = avg_cls_loss
            improved = True
        if avg_mmd_loss < best_mmd - 1e-5:
            best_mmd = avg_mmd_loss
            improved = True
        if improved:
            # One snapshot per epoch is enough, whichever metric improved.
            best_model_state = copy.deepcopy(model.state_dict())

        # Early stopping: aligned + classifier converged + (MMD small or on a plateau).
        # ``epoch`` is 0-based, so the first check happens at epoch index MIN_EPOCH + 1.
        if epoch > MIN_EPOCH and cond_align and cond_cls and (cond_mmd_small or cond_mmd_plateau):
            if not improved:
                patience += 1
            else:
                patience = 0
            print(f"[INFO] patience {patience} / {PATIENCE} | MMD_small={cond_mmd_small} plateau={cond_mmd_plateau}")
            if patience >= PATIENCE:
                if best_model_state is not None:
                    model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned, classifier converged, and MMD stabilized.")
                break
        else:
            patience = 0

        print("[INFO] Evaluating on source test set...")
        if s_loader is None:
            s_dataset = PKLDataset(source_test_path)
            s_loader = DataLoader(s_dataset, batch_size=eval_batch_size, shuffle=False)
        general_test_model(model, criterion_cls, s_loader, device)

    return model


# Script entry point: trains the DANN+MMD model on the DC_T197 source /
# HC_T185 target pair and reports the result on the target test split.
if __name__ == '__main__':
    set_seed(seed=42)
    # Hyper-parameters come from the 'baseline' section of the YAML config.
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    # In this block ``num_epochs`` is only used as the scheduler's T_max below.
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/target/train/HC_T185_RP.txt'
    target_test_path = '../datasets/target/test/HC_T185_RP.txt'
    # The output directory is created, but nothing in this block writes to it.
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): 'leakrelu' is passed verbatim; confirm this is the
    # spelling Flexible_DANN expects.
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=1).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Cosine decay from learning_rate down to 10% of it over T_max steps.
    # NOTE(review): T_max uses the config's num_epochs while training below is
    # called with a hard-coded num_epochs=20; confirm the two are meant to match.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    # Fixed domain-loss weight of 0.5; use_mk is left at the function's default.
    model = train_dann_with_mmd(model, source_loader, target_loader,
                                optimizer, criterion_cls, criterion_domain,
                                device, num_epochs=20, lambda_dann=0.5, scheduler=scheduler)

    # Final evaluation on the target-domain test split.
    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    general_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total: 1.2440 | Cls: 0.6058 | Dom: 1.2747 | MMD: 0.0368 | DomAcc: 0.6091 | λ_dann: 0.5000 | λ_mmd: 0.0007
[INFO] Evaluating on source test set...
- test Loss: 0.117786, test Acc: 0.9665
[Epoch 2] Total: 0.7993 | Cls: 0.1340 | Dom: 1.3301 | MMD: 0.0678 | DomAcc: 0.5857 | λ_dann: 0.5000 | λ_mmd: 0.0011
[INFO] Evaluating on source test set...
- test Loss: 0.097509, test Acc: 0.9685
[Epoch 3] Total: 0.7864 | Cls: 0.1028 | Dom: 1.3667 | MMD: 0.0560 | DomAcc: 0.5511 | λ_dann: 0.5000 | λ_mmd: 0.0019
[INFO] Evaluating on source test set...
- test Loss: 0.082509, test Acc: 0.9783
[Epoch 4] Total: 0.7638 | Cls: 0.0781 | Dom: 1.3709 | MMD: 0.0516 | DomAcc: 0.5596 | λ_dann: 0.5000 | λ_mmd: 0.0032
[INFO] Evaluating on source test set...
- test Loss: 0.043204, test Acc: 0.9823
[Epoch 5] Total: 0.7463 | Cls: 0.0771 | Dom: 1.3376 | MMD: 0.0532 | DomAcc: 0.5982 | λ_dann: 0.5000 | λ_mmd: 0.0052
[INFO] Evaluating on source test set...

In [9]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import numpy as np
import random
import yaml
from models.Flexible_DANN_MMD import Flexible_DANN
from PKLDataset import PKLDataset
from utils.general_train_and_test import general_test_model
from models.get_no_label_dataloader import get_target_loader
from models.MMD import *
from collections import deque



def set_seed(seed=42):
    """Seed the PyTorch, NumPy and Python ``random`` generators with ``seed``.

    When CUDA is available the seed is also applied through
    ``torch.cuda.manual_seed_all``.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def get_dataloaders(source_path, target_path, batch_size):
    """Build the shuffled training loaders for the source and target data.

    Args:
        source_path: Passed to ``PKLDataset`` as ``txt_path``.
        target_path: Passed to ``get_target_loader``.
        batch_size: Batch size used for both loaders.

    Returns:
        Tuple ``(source_loader, target_loader)``.
    """
    src_loader = DataLoader(
        PKLDataset(txt_path=source_path),
        batch_size=batch_size,
        shuffle=True,
    )
    tgt_loader = get_target_loader(target_path, batch_size=batch_size, shuffle=True)
    return src_loader, tgt_loader

def dann_lambda(epoch, num_epochs):
    """Return the DANN lambda for ``epoch`` on the common smooth schedule.

    Computes ``2 / (1 + exp(-10 * p)) - 1`` with ``p = epoch / num_epochs``,
    which is 0 at epoch 0 and rises smoothly towards 1. Change the ``-10``
    factor to make the ramp gentler or steeper.
    """
    progress = epoch / float(num_epochs)
    denominator = 1. + np.exp(-10 * progress)
    return 2. / denominator - 1.

def mmd_lambda(epoch, num_epochs, max_lambda=1e-1):
    """Return the warm-up weight of the MMD loss for ``epoch``.

    The weight is ``max_lambda * sigmoid(10 * (p - 0.5))`` with
    ``p = epoch / max(1, num_epochs - 1)``: an S-shaped ramp from roughly
    ``0.0067 * max_lambda`` at epoch 0 to roughly ``0.9933 * max_lambda`` at
    the last epoch. The result is a plain Python ``float``.
    """
    last_epoch = max(1, num_epochs - 1)
    progress = epoch / last_epoch  # in [0, 1] for 0 <= epoch <= num_epochs - 1
    exponent = torch.tensor(-10.0 * (progress - 0.5))
    sigmoid = 1.0 / (1.0 + torch.exp(exponent))  # in (0, 1)
    return float(max_lambda * sigmoid)

def train_dann_with_mmd(model, source_loader, target_loader,
                        optimizer, criterion_cls, criterion_domain,
                        device, num_epochs=20,
                        lambda_dann=0.1,           # weight of the domain-classifier loss
                        lambda_mmd_max=1e-1,       # maximum weight of the MMD loss
                        use_mk=False,               # use the multi-kernel MMD
                        scheduler=None):
    """Train a DANN model with an additional MMD feature-alignment loss.

    Every step minimises
    ``loss_cls + lambda_dann * loss_dom + lambda_mmd * loss_mmd`` where the
    MMD term is computed on L2-normalised features. After each epoch the
    model is evaluated on the source test set.

    Args:
        model: Module whose forward pass returns
            ``(class_logits, domain_logits, features)``.
        source_loader: Iterable of ``(inputs, labels)`` source batches.
        target_loader: Iterable of unlabeled target input batches. Batches
            are paired with ``zip``, so the shorter loader ends the epoch.
        optimizer: Optimizer stepped once per batch.
        criterion_cls: Loss applied to the source class logits.
        criterion_domain: Loss applied to the domain logits (source samples
            are labelled 0, target samples 1).
        device: Device the batches are moved to.
        num_epochs: Maximum number of epochs.
        lambda_dann: Fixed weight of the domain loss, or a callable
            ``f(epoch, num_epochs)`` returning the weight for that epoch
            (e.g. ``dann_lambda``).
        lambda_mmd_max: ``max_lambda`` forwarded to ``mmd_lambda``.
        use_mk: If true use ``mmd_mk_biased``, otherwise ``mmd_rbf_biased``.
        scheduler: Optional LR scheduler, stepped once per epoch.

    Returns:
        ``model``. If early stopping triggers, the last snapshot taken when
        any tracked metric improved is loaded into it first.
    """
    # Local import so the function does not rely on ``F`` leaking into the
    # module namespace through ``from models.MMD import *``.
    import torch.nn.functional as F

    PATIENCE = 3
    MIN_EPOCH = 10

    best_gap = 0.5
    best_cls = float('inf')
    best_mmd = float('inf')
    best_model_state = None
    patience = 0

    MMD_THRESH = 3e-2  # threshold below which MMD^2 counts as small; task dependent (0.02~0.05 is common)
    MMD_PLATEAU_EPS = 5e-3  # maximum fluctuation still counted as a plateau
    mmd_hist = deque(maxlen=5)  # the plateau test looks at the last 5 epochs

    mmd_fn = (lambda x, y: mmd_mk_biased(x, y, gammas=(0.5,1,2,4,8))) if use_mk \
             else (lambda x, y: mmd_rbf_biased(x, y, gamma=None))

    # Per-epoch source-test evaluation. The batch size is taken from the
    # training loader instead of a module-level global; if the loader does not
    # expose one, fall back to 1 (slower, still a valid evaluation).
    source_test_path = '../datasets/source/test/DC_T197_RP.txt'
    eval_batch_size = getattr(source_loader, 'batch_size', None) or 1
    s_loader = None  # built on first use and reused across epochs

    for epoch in range(num_epochs):
        cls_loss_sum, dom_loss_sum, mmd_loss_sum, total_loss_sum = 0.0, 0.0, 0.0, 0.0
        total_cls_samples, total_dom_samples = 0, 0
        dom_correct, dom_total = 0, 0
        model.train()

        # Loss weights depend only on the epoch, so compute them once here.
        # This also keeps them defined for the epoch log when the loaders
        # yield no batch. A callable ``lambda_dann`` is honoured as the
        # schedule; the MMD weight is warmed up to avoid flattening the
        # decision structure at the very start of training.
        lambda_dann_now = lambda_dann(epoch, num_epochs) if callable(lambda_dann) else lambda_dann
        lambda_mmd_now = float(mmd_lambda(epoch, num_epochs, max_lambda=lambda_mmd_max))

        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src, feat_src = model(src_x)
            _,            dom_out_tgt, feat_tgt = model(tgt_x)

            # 1) Classification loss (source domain only)
            loss_cls = criterion_cls(cls_out_src, src_y)

            # 2) Domain classification loss (DANN)
            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long, device=device)
            dom_label_tgt = torch.ones(tgt_x.size(0),  dtype=torch.long, device=device)
            loss_dom = criterion_domain(dom_out_src, dom_label_src) + \
                       criterion_domain(dom_out_tgt, dom_label_tgt)

            # 3) MMD feature alignment on L2-normalised features (for stability)
            feat_src_n = F.normalize(feat_src, dim=1)
            feat_tgt_n = F.normalize(feat_tgt, dim=1)
            loss_mmd = mmd_fn(feat_src_n, feat_tgt_n)

            # 4) Combined loss
            loss = loss_cls + lambda_dann_now * loss_dom + lambda_mmd_now * loss_mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Sample-weighted running sums for the epoch averages
            n_src, n_both = src_x.size(0), src_x.size(0) + tgt_x.size(0)
            cls_loss_sum  += loss_cls.item() * n_src
            dom_loss_sum  += loss_dom.item() * n_both
            mmd_loss_sum  += loss_mmd.item() * n_both
            total_loss_sum += loss.item() * n_both

            total_cls_samples += n_src
            total_dom_samples += n_both

            # Domain classification accuracy
            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total   += dom_label_src.size(0) + dom_label_tgt.size(0)

        # -- Epoch-level logging --
        avg_cls_loss  = cls_loss_sum  / max(1, total_cls_samples)
        avg_dom_loss  = dom_loss_sum  / max(1, total_dom_samples)
        avg_mmd_loss  = mmd_loss_sum  / max(1, total_dom_samples)
        avg_total_loss= total_loss_sum/ max(1, total_dom_samples)
        dom_acc = dom_correct / max(1, dom_total)
        gap = abs(dom_acc - 0.5)  # distance of the domain accuracy from chance level

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch + 1}] Total: {avg_total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {avg_dom_loss:.4f} | "
              f"MMD: {avg_mmd_loss:.4f} | DomAcc: {dom_acc:.4f} | "
              f"λ_dann: {lambda_dann_now:.4f} | λ_mmd: {lambda_mmd_now:.4f}")

        mmd_hist.append(avg_mmd_loss)
        mmd_plateau = (len(mmd_hist) == mmd_hist.maxlen) and (max(mmd_hist) - min(mmd_hist) < MMD_PLATEAU_EPS)

        # Early-stopping trigger conditions
        cond_align = (gap < 0.05)
        cond_cls = (avg_cls_loss < 0.5)
        cond_mmd_small = (avg_mmd_loss < MMD_THRESH)
        cond_mmd_plateau = mmd_plateau

        # Did any tracked metric set a new best this epoch?
        improved = False
        if gap < best_gap - 1e-4:
            best_gap = gap
            improved = True
        if avg_cls_loss < best_cls - 1e-4:
            best_cls = avg_cls_loss
            improved = True
        if avg_mmd_loss < best_mmd - 1e-5:
            best_mmd = avg_mmd_loss
            improved = True
        if improved:
            # One snapshot per epoch is enough, whichever metric improved.
            best_model_state = copy.deepcopy(model.state_dict())

        # Early stopping: aligned + classifier converged + (MMD small or on a plateau).
        # ``epoch`` is 0-based, so the first check happens at epoch index MIN_EPOCH + 1.
        if epoch > MIN_EPOCH and cond_align and cond_cls and (cond_mmd_small or cond_mmd_plateau):
            if not improved:
                patience += 1
            else:
                patience = 0
            print(f"[INFO] patience {patience} / {PATIENCE} | MMD_small={cond_mmd_small} plateau={cond_mmd_plateau}")
            if patience >= PATIENCE:
                if best_model_state is not None:
                    model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned, classifier converged, and MMD stabilized.")
                break
        else:
            patience = 0

        print("[INFO] Evaluating on source test set...")
        if s_loader is None:
            s_dataset = PKLDataset(source_test_path)
            s_loader = DataLoader(s_dataset, batch_size=eval_batch_size, shuffle=False)
        general_test_model(model, criterion_cls, s_loader, device)

    return model


# Script entry point: trains the DANN+MMD model (multi-kernel MMD) on the
# DC_T197 source / HC_T188 target pair and reports the result on the target
# test split.
if __name__ == '__main__':
    set_seed(seed=18)
    # Hyper-parameters come from the 'baseline' section of the YAML config.
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    # In this block ``num_epochs`` is only used as the scheduler's T_max below.
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/target/train/HC_T188_RP.txt'
    target_test_path = '../datasets/target/test/HC_T188_RP.txt'
    # The output directory is created, but nothing in this block writes to it.
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): 'leakrelu' is passed verbatim; confirm this is the
    # spelling Flexible_DANN expects.
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=1).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Cosine decay from learning_rate down to 10% of it over T_max steps.
    # NOTE(review): T_max uses the config's num_epochs while training below is
    # called with a hard-coded num_epochs=30; confirm the two are meant to match.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    # Fixed domain-loss weight of 0.5, multi-kernel MMD enabled.
    model = train_dann_with_mmd(model, source_loader, target_loader,
                                optimizer, criterion_cls, criterion_domain,
                                device, num_epochs=30, lambda_dann=0.5,use_mk=True,scheduler=scheduler)

    # Final evaluation on the target-domain test split.
    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    general_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total: 1.1895 | Cls: 0.5764 | Dom: 1.2251 | MMD: 0.0556 | DomAcc: 0.6446 | λ_dann: 0.5000 | λ_mmd: 0.0007
[INFO] Evaluating on source test set...
- test Loss: 0.149571, test Acc: 0.9547
[Epoch 2] Total: 0.7854 | Cls: 0.1362 | Dom: 1.2980 | MMD: 0.0733 | DomAcc: 0.6359 | λ_dann: 0.5000 | λ_mmd: 0.0009
[INFO] Evaluating on source test set...
- test Loss: 0.112532, test Acc: 0.9724
[Epoch 3] Total: 0.7367 | Cls: 0.0965 | Dom: 1.2804 | MMD: 0.0701 | DomAcc: 0.6248 | λ_dann: 0.5000 | λ_mmd: 0.0013
[INFO] Evaluating on source test set...
- test Loss: 0.118564, test Acc: 0.9547
[Epoch 4] Total: 0.7170 | Cls: 0.0644 | Dom: 1.3049 | MMD: 0.0660 | DomAcc: 0.6256 | λ_dann: 0.5000 | λ_mmd: 0.0019
[INFO] Evaluating on source test set...
- test Loss: 0.071798, test Acc: 0.9764
[Epoch 5] Total: 0.7360 | Cls: 0.0455 | Dom: 1.3806 | MMD: 0.0805 | DomAcc: 0.5358 | λ_dann: 0.5000 | λ_mmd: 0.0026
[INFO] Evaluating on source test set...

In [10]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import numpy as np
import random
import yaml
from models.Flexible_DANN_MMD import Flexible_DANN
from PKLDataset import PKLDataset
from utils.general_train_and_test import general_test_model
from models.get_no_label_dataloader import get_target_loader
from models.MMD import *
from collections import deque



def set_seed(seed=42):
    """Seed the PyTorch, NumPy and Python ``random`` generators with ``seed``.

    When CUDA is available the seed is also applied through
    ``torch.cuda.manual_seed_all``.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def get_dataloaders(source_path, target_path, batch_size):
    """Build the shuffled training loaders for the source and target data.

    Args:
        source_path: Passed to ``PKLDataset`` as ``txt_path``.
        target_path: Passed to ``get_target_loader``.
        batch_size: Batch size used for both loaders.

    Returns:
        Tuple ``(source_loader, target_loader)``.
    """
    src_loader = DataLoader(
        PKLDataset(txt_path=source_path),
        batch_size=batch_size,
        shuffle=True,
    )
    tgt_loader = get_target_loader(target_path, batch_size=batch_size, shuffle=True)
    return src_loader, tgt_loader

def dann_lambda(epoch, num_epochs):
    """Return the DANN lambda for ``epoch`` on the common smooth schedule.

    Computes ``2 / (1 + exp(-10 * p)) - 1`` with ``p = epoch / num_epochs``,
    which is 0 at epoch 0 and rises smoothly towards 1. Change the ``-10``
    factor to make the ramp gentler or steeper.
    """
    progress = epoch / float(num_epochs)
    denominator = 1. + np.exp(-10 * progress)
    return 2. / denominator - 1.

def mmd_lambda(epoch, num_epochs, max_lambda=1e-1):
    """Return the warm-up weight of the MMD loss for ``epoch``.

    The weight is ``max_lambda * sigmoid(10 * (p - 0.5))`` with
    ``p = epoch / max(1, num_epochs - 1)``: an S-shaped ramp from roughly
    ``0.0067 * max_lambda`` at epoch 0 to roughly ``0.9933 * max_lambda`` at
    the last epoch. The result is a plain Python ``float``.
    """
    last_epoch = max(1, num_epochs - 1)
    progress = epoch / last_epoch  # in [0, 1] for 0 <= epoch <= num_epochs - 1
    exponent = torch.tensor(-10.0 * (progress - 0.5))
    sigmoid = 1.0 / (1.0 + torch.exp(exponent))  # in (0, 1)
    return float(max_lambda * sigmoid)

def set_bn_eval(m):
    """Put ``m`` into eval mode if it is a BatchNorm1d/2d/3d layer.

    Modules of any other type are left untouched, so the function can be
    passed to ``Module.apply``.
    """
    batch_norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    if not isinstance(m, batch_norm_types):
        return
    m.eval()


def train_dann_with_mmd(model, source_loader, target_loader,
                        optimizer, criterion_cls, criterion_domain,
                        device, num_epochs=20,
                        lambda_dann=0.1,           # weight of the domain-classifier loss
                        lambda_mmd_max=1e-1,       # maximum weight of the MMD loss
                        use_mk=False,               # use the multi-kernel MMD
                        scheduler=None):
    """Train a DANN model with an additional MMD feature-alignment loss.

    Every step minimises
    ``loss_cls + lambda_dann * loss_dom + lambda_mmd * loss_mmd`` where the
    MMD term is computed on L2-normalised features.

    Args:
        model: Module whose forward pass returns
            ``(class_logits, domain_logits, features)``.
        source_loader: Iterable of ``(inputs, labels)`` source batches.
        target_loader: Iterable of unlabeled target input batches. Batches
            are paired with ``zip``, so the shorter loader ends the epoch.
        optimizer: Optimizer stepped once per batch.
        criterion_cls: Loss applied to the source class logits.
        criterion_domain: Loss applied to the domain logits (source samples
            are labelled 0, target samples 1).
        device: Device the batches are moved to.
        num_epochs: Maximum number of epochs.
        lambda_dann: Fixed weight of the domain loss, or a callable
            ``f(epoch, num_epochs)`` returning the weight for that epoch
            (e.g. ``dann_lambda``).
        lambda_mmd_max: ``max_lambda`` forwarded to ``mmd_lambda``.
        use_mk: If true use ``mmd_mk_biased``, otherwise ``mmd_rbf_biased``.
        scheduler: Optional LR scheduler, stepped once per epoch.

    Returns:
        ``model``. If early stopping triggers, the last snapshot taken when
        any tracked metric improved is loaded into it first.
    """
    # Local import so the function does not rely on ``F`` leaking into the
    # module namespace through ``from models.MMD import *``.
    import torch.nn.functional as F

    PATIENCE = 3
    MIN_EPOCH = 10

    best_gap = 0.5
    best_cls = float('inf')
    best_mmd = float('inf')
    best_model_state = None
    patience = 0

    MMD_THRESH = 3e-2  # threshold below which MMD^2 counts as small; task dependent (0.02~0.05 is common)
    MMD_PLATEAU_EPS = 5e-3  # maximum fluctuation still counted as a plateau
    mmd_hist = deque(maxlen=5)  # the plateau test looks at the last 5 epochs

    mmd_fn = (lambda x, y: mmd_mk_biased(x, y, gammas=(0.5,1,2,4,8))) if use_mk \
             else (lambda x, y: mmd_rbf_biased(x, y, gamma=None))

    for epoch in range(num_epochs):
        cls_loss_sum, dom_loss_sum, mmd_loss_sum, total_loss_sum = 0.0, 0.0, 0.0, 0.0
        total_cls_samples, total_dom_samples = 0, 0
        dom_correct, dom_total = 0, 0
        model.train()

        # Loss weights depend only on the epoch, so compute them once here.
        # This also keeps them defined for the epoch log when the loaders
        # yield no batch. A callable ``lambda_dann`` is honoured as the
        # schedule; the MMD weight is warmed up to avoid flattening the
        # decision structure at the very start of training.
        lambda_dann_now = lambda_dann(epoch, num_epochs) if callable(lambda_dann) else lambda_dann
        lambda_mmd_now = float(mmd_lambda(epoch, num_epochs, max_lambda=lambda_mmd_max))

        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src, feat_src = model(src_x)
            _,            dom_out_tgt, feat_tgt = model(tgt_x)

            # 1) Classification loss (source domain only)
            loss_cls = criterion_cls(cls_out_src, src_y)

            # 2) Domain classification loss (DANN)
            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long, device=device)
            dom_label_tgt = torch.ones(tgt_x.size(0),  dtype=torch.long, device=device)
            loss_dom = criterion_domain(dom_out_src, dom_label_src) + \
                       criterion_domain(dom_out_tgt, dom_label_tgt)

            # 3) MMD feature alignment on L2-normalised features (for stability)
            feat_src_n = F.normalize(feat_src, dim=1)
            feat_tgt_n = F.normalize(feat_tgt, dim=1)
            loss_mmd = mmd_fn(feat_src_n, feat_tgt_n)

            # 4) Combined loss
            loss = loss_cls + lambda_dann_now * loss_dom + lambda_mmd_now * loss_mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Sample-weighted running sums for the epoch averages
            n_src, n_both = src_x.size(0), src_x.size(0) + tgt_x.size(0)
            cls_loss_sum  += loss_cls.item() * n_src
            dom_loss_sum  += loss_dom.item() * n_both
            mmd_loss_sum  += loss_mmd.item() * n_both
            total_loss_sum += loss.item() * n_both

            total_cls_samples += n_src
            total_dom_samples += n_both

            # Domain classification accuracy
            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total   += dom_label_src.size(0) + dom_label_tgt.size(0)

        # -- Epoch-level logging --
        avg_cls_loss  = cls_loss_sum  / max(1, total_cls_samples)
        avg_dom_loss  = dom_loss_sum  / max(1, total_dom_samples)
        avg_mmd_loss  = mmd_loss_sum  / max(1, total_dom_samples)
        avg_total_loss= total_loss_sum/ max(1, total_dom_samples)
        dom_acc = dom_correct / max(1, dom_total)
        gap = abs(dom_acc - 0.5)  # distance of the domain accuracy from chance level

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch + 1}] Total: {avg_total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {avg_dom_loss:.4f} | "
              f"MMD: {avg_mmd_loss:.4f} | DomAcc: {dom_acc:.4f} | "
              f"λ_dann: {lambda_dann_now:.4f} | λ_mmd: {lambda_mmd_now:.4f}")

        mmd_hist.append(avg_mmd_loss)
        mmd_plateau = (len(mmd_hist) == mmd_hist.maxlen) and (max(mmd_hist) - min(mmd_hist) < MMD_PLATEAU_EPS)

        # Early-stopping trigger conditions
        cond_align = (gap < 0.05)
        cond_cls = (avg_cls_loss < 0.5)
        cond_mmd_small = (avg_mmd_loss < MMD_THRESH)
        cond_mmd_plateau = mmd_plateau

        # Did any tracked metric set a new best this epoch?
        improved = False
        if gap < best_gap - 1e-4:
            best_gap = gap
            improved = True
        if avg_cls_loss < best_cls - 1e-4:
            best_cls = avg_cls_loss
            improved = True
        if avg_mmd_loss < best_mmd - 1e-5:
            best_mmd = avg_mmd_loss
            improved = True
        if improved:
            # One snapshot per epoch is enough, whichever metric improved.
            best_model_state = copy.deepcopy(model.state_dict())

        # Early stopping: aligned + classifier converged + (MMD small or on a plateau).
        # ``epoch`` is 0-based, so the first check happens at epoch index MIN_EPOCH + 1.
        if epoch > MIN_EPOCH and cond_align and cond_cls and (cond_mmd_small or cond_mmd_plateau):
            if not improved:
                patience += 1
            else:
                patience = 0
            print(f"[INFO] patience {patience} / {PATIENCE} | MMD_small={cond_mmd_small} plateau={cond_mmd_plateau}")
            if patience >= PATIENCE:
                if best_model_state is not None:
                    model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned, classifier converged, and MMD stabilized.")
                break
        else:
            patience = 0

        # NOTE(review): ``model.train()`` at the top of the next epoch puts
        # the BatchNorm layers back into train mode, and the early-stopping
        # ``break`` skips this line, so this only sticks after the final
        # epoch of a full run. Confirm that this is the intended behaviour.
        model.apply(set_bn_eval)

    return model


# Script entry point: trains the DANN+MMD model on the DC_T197 source /
# HC_T185 target pair and reports the result on the target test split.
if __name__ == '__main__':
    set_seed(seed=42)
    # Hyper-parameters come from the 'baseline' section of the YAML config.
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    # In this block ``num_epochs`` is only used as the scheduler's T_max below.
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/target/train/HC_T185_RP.txt'
    target_test_path = '../datasets/target/test/HC_T185_RP.txt'
    # The output directory is created, but nothing in this block writes to it.
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): 'leakrelu' is passed verbatim; confirm this is the
    # spelling Flexible_DANN expects.
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=1).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Cosine decay from learning_rate down to 10% of it over T_max steps.
    # NOTE(review): T_max uses the config's num_epochs while training below is
    # called with a hard-coded num_epochs=20; confirm the two are meant to match.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    # Fixed domain-loss weight of 0.5; use_mk is left at the function's default.
    model = train_dann_with_mmd(model, source_loader, target_loader,
                                optimizer, criterion_cls, criterion_domain,
                                device, num_epochs=20, lambda_dann=0.5, scheduler=scheduler)

    # Final evaluation on the target-domain test split.
    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    general_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total: 1.2440 | Cls: 0.6058 | Dom: 1.2747 | MMD: 0.0368 | DomAcc: 0.6091 | λ_dann: 0.5000 | λ_mmd: 0.0007
[Epoch 2] Total: 0.7820 | Cls: 0.1334 | Dom: 1.2975 | MMD: 0.0681 | DomAcc: 0.6119 | λ_dann: 0.5000 | λ_mmd: 0.0011
[Epoch 3] Total: 0.7743 | Cls: 0.1089 | Dom: 1.3309 | MMD: 0.0512 | DomAcc: 0.5886 | λ_dann: 0.5000 | λ_mmd: 0.0019
[Epoch 4] Total: 0.7548 | Cls: 0.0903 | Dom: 1.3287 | MMD: 0.0429 | DomAcc: 0.5896 | λ_dann: 0.5000 | λ_mmd: 0.0032
[Epoch 5] Total: 0.7447 | Cls: 0.0931 | Dom: 1.3025 | MMD: 0.0456 | DomAcc: 0.6101 | λ_dann: 0.5000 | λ_mmd: 0.0052
[Epoch 6] Total: 0.6982 | Cls: 0.0678 | Dom: 1.2601 | MMD: 0.0449 | DomAcc: 0.6440 | λ_dann: 0.5000 | λ_mmd: 0.0086
[Epoch 7] Total: 0.7289 | Cls: 0.0625 | Dom: 1.3315 | MMD: 0.0455 | DomAcc: 0.5826 | λ_dann: 0.5000 | λ_mmd: 0.0137
[Epoch 8] Total: 0.7222 | Cls: 0.0657 | Dom: 1.3108 | MMD: 0.0485 | DomAcc: 0.6229 | λ_dann: 0.5000 | λ_mmd: 0.0212
[Epoch 9] T

In [16]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import numpy as np
import random
import yaml
from models.Flexible_DANN_MMD import Flexible_DANN
from PKLDataset import PKLDataset
from utils.general_train_and_test import general_test_model
from models.get_no_label_dataloader import get_target_loader
from models.MMD import *
from collections import deque
import torch.nn.functional as F



def set_seed(seed=42):
    """Seed the PyTorch, NumPy and Python ``random`` generators with ``seed``.

    When CUDA is available the seed is also applied through
    ``torch.cuda.manual_seed_all``.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def get_dataloaders(source_path, target_path, batch_size):
    """Build the shuffled training loaders for the source and target data.

    Args:
        source_path: Passed to ``PKLDataset`` as ``txt_path``.
        target_path: Passed to ``get_target_loader``.
        batch_size: Batch size used for both loaders.

    Returns:
        Tuple ``(source_loader, target_loader)``.
    """
    src_loader = DataLoader(
        PKLDataset(txt_path=source_path),
        batch_size=batch_size,
        shuffle=True,
    )
    tgt_loader = get_target_loader(target_path, batch_size=batch_size, shuffle=True)
    return src_loader, tgt_loader

def dann_lambda(epoch, num_epochs):
    """Return the DANN lambda for ``epoch`` on the common smooth schedule.

    Computes ``2 / (1 + exp(-10 * p)) - 1`` with ``p = epoch / num_epochs``,
    which is 0 at epoch 0 and rises smoothly towards 1. Change the ``-10``
    factor to make the ramp gentler or steeper.
    """
    progress = epoch / float(num_epochs)
    denominator = 1. + np.exp(-10 * progress)
    return 2. / denominator - 1.

def mmd_lambda(epoch, num_epochs, max_lambda=1e-1):
    """Sigmoid warm-up for the MMD loss weight.

    Training progress ``p = epoch / max(1, num_epochs - 1)`` is pushed through
    a logistic curve centred at ``p = 0.5``, so over a full run the weight
    climbs from roughly ``0.0067 * max_lambda`` to ``0.9933 * max_lambda``.

    Returns:
        The weight as a plain Python ``float``.
    """
    last_index = max(1, num_epochs - 1)
    progress = epoch / last_index                     # in [0, 1] for a valid epoch
    logit = torch.tensor(-10.0 * (progress - 0.5))
    gate = 1.0 / (1.0 + torch.exp(logit))             # in (0, 1)
    scaled = max_lambda * gate
    return float(scaled)

def train_dann_with_mmd(model, source_loader, target_loader,
                        optimizer, criterion_cls, criterion_domain,
                        device, num_epochs=20,
                        lambda_dann=0.1,           # weight of the domain-classifier loss
                        lambda_mmd_max=1e-1,       # maximum weight of the MMD loss
                        use_mk=True,               # use the multi-kernel MMD variant
                        scheduler=None):
    """Train a DANN model with an additional MMD feature-alignment term.

    Per batch the loss is
    ``cls + lambda_dann_now * domain + lambda_mmd_now * mmd`` where the
    classification loss uses source samples only, the domain loss uses both
    domains (source = 0, target = 1) and the MMD is computed on features that
    are L2-normalised along dim 1.

    Args:
        model: Called as ``model(x)``; must return a 3-tuple
            ``(class_logits, domain_logits, features)``.
        source_loader: Iterable yielding ``(inputs, labels)`` batches.
        target_loader: Iterable yielding unlabelled input batches. The two
            loaders are zipped, so an epoch ends with the shorter one.
        optimizer: Optimiser stepped once per batch.
        criterion_cls: Loss applied to the source class logits.
        criterion_domain: Loss applied to the domain logits.
        device: Device the batches are moved to.
        num_epochs: Maximum number of epochs.
        lambda_dann: Fixed domain-loss weight. If a callable is passed, the
            ``dann_lambda(epoch, num_epochs)`` schedule is used instead (the
            callable itself is not invoked).
        lambda_mmd_max: ``max_lambda`` passed to the ``mmd_lambda`` warm-up.
        use_mk: Select ``mmd_mk_biased`` (True) or ``mmd_rbf_biased`` (False).
        scheduler: Optional LR scheduler, stepped once per epoch.

    Returns:
        ``model`` with the most recent "best" snapshot loaded. A snapshot is
        taken whenever the domain-accuracy gap, the classification loss or the
        MMD loss improves on its best value so far.
    """
    PATIENCE = 3
    MIN_EPOCH = 10

    best_gap = 0.5
    best_cls = float('inf')
    best_mmd = float('inf')
    best_model_state = None
    patience = 0

    MMD_THRESH = 4e-2  # MMD^2 counts as "small" below this; task dependent (0.02~0.05 is typical)
    MMD_PLATEAU_EPS = 5e-2  # max-min spread under which the MMD counts as a plateau
    mmd_hist = deque(maxlen=5)  # the plateau check looks at the last 5 epochs

    if use_mk:
        mmd_fn = lambda x, y: mmd_mk_biased(x, y, gammas=(0.5, 1, 2, 4, 8))
    else:
        mmd_fn = lambda x, y: mmd_rbf_biased(x, y, gamma=None)

    for epoch in range(num_epochs):
        cls_loss_sum, dom_loss_sum, mmd_loss_sum, total_loss_sum = 0.0, 0.0, 0.0, 0.0
        total_cls_samples, total_dom_samples = 0, 0
        dom_correct = 0
        model.train()

        # Both weights depend only on the epoch. Computing them here avoids
        # per-batch recomputation and keeps them defined for the log line even
        # if the loaders yield no batch.
        lambda_dann_now = dann_lambda(epoch, num_epochs) if callable(lambda_dann) else lambda_dann
        lambda_mmd_now = float(mmd_lambda(epoch, num_epochs, max_lambda=lambda_mmd_max))

        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)
            n_src = src_x.size(0)
            n_both = n_src + tgt_x.size(0)

            cls_out_src, dom_out_src, feat_src = model(src_x)
            _, dom_out_tgt, feat_tgt = model(tgt_x)

            # 1) Classification loss (source domain only).
            loss_cls = criterion_cls(cls_out_src, src_y)

            # 2) Domain-classification loss (DANN): source = 0, target = 1.
            dom_label_src = torch.zeros(n_src, dtype=torch.long, device=device)
            dom_label_tgt = torch.ones(tgt_x.size(0), dtype=torch.long, device=device)
            loss_dom = criterion_domain(dom_out_src, dom_label_src) + \
                       criterion_domain(dom_out_tgt, dom_label_tgt)

            # 3) MMD feature alignment on L2-normalised features (for stability).
            feat_src_n = F.normalize(feat_src, dim=1)
            feat_tgt_n = F.normalize(feat_tgt, dim=1)
            loss_mmd = mmd_fn(feat_src_n, feat_tgt_n)

            # 4) Weighted sum; the MMD weight is warmed up so that alignment
            #    does not flatten the decision structure early on.
            loss = loss_cls + lambda_dann_now * loss_dom + lambda_mmd_now * loss_mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running sums, weighted by the number of samples they cover.
            cls_loss_sum += loss_cls.item() * n_src
            dom_loss_sum += loss_dom.item() * n_both
            mmd_loss_sum += loss_mmd.item() * n_both
            total_loss_sum += loss.item() * n_both

            total_cls_samples += n_src
            total_dom_samples += n_both

            # Domain-classifier accuracy.
            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()

        # ---- epoch-level metrics ----
        avg_cls_loss = cls_loss_sum / max(1, total_cls_samples)
        avg_dom_loss = dom_loss_sum / max(1, total_dom_samples)
        avg_mmd_loss = mmd_loss_sum / max(1, total_dom_samples)
        avg_total_loss = total_loss_sum / max(1, total_dom_samples)
        dom_acc = dom_correct / max(1, total_dom_samples)
        gap = abs(dom_acc - 0.5)  # distance of the domain classifier from chance level

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch + 1}] Total: {avg_total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {avg_dom_loss:.4f} | "
              f"MMD: {avg_mmd_loss:.4f} | DomAcc: {dom_acc:.4f} | "
              f"λ_dann: {lambda_dann_now:.4f} | λ_mmd: {lambda_mmd_now:.4f}")

        mmd_hist.append(avg_mmd_loss)
        mmd_plateau = (len(mmd_hist) == mmd_hist.maxlen) and (max(mmd_hist) - min(mmd_hist) < MMD_PLATEAU_EPS)

        # Early-stopping trigger conditions.
        cond_align = (gap < 0.05)
        cond_cls = (avg_cls_loss < 0.5)
        cond_mmd_small = (avg_mmd_loss < MMD_THRESH)
        cond_mmd_plateau = mmd_plateau

        # Did any tracked metric set a new best this epoch?
        improved = False
        if gap < best_gap - 1e-4:
            best_gap = gap
            improved = True
        if avg_cls_loss < best_cls - 1e-4:
            best_cls = avg_cls_loss
            improved = True
        if avg_mmd_loss < best_mmd - 1e-5:
            best_mmd = avg_mmd_loss
            improved = True
        if improved:
            # One snapshot per epoch is enough: the weights are the same no
            # matter which metric improved.
            best_model_state = copy.deepcopy(model.state_dict())

        # Early stopping: aligned + classifier converged + (MMD small OR MMD plateau).
        if epoch > MIN_EPOCH and cond_align and cond_cls and (cond_mmd_small or cond_mmd_plateau):
            patience = 0 if improved else patience + 1
            print(f"[INFO] patience {patience} / {PATIENCE} | MMD_small={cond_mmd_small} plateau={cond_mmd_plateau}")
            if patience >= PATIENCE:
                print("[INFO] Early stopping: domain aligned, classifier converged, and MMD stabilized.")
                break
        else:
            patience = 0

    # Whether training stopped early or ran to the last epoch, hand back the
    # best snapshot (if any epoch produced one).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model


if __name__ == '__main__':
    # Entry point: seed the RNGs, read hyper-parameters, build the model and
    # loaders, train with DANN + MMD, then evaluate on the target test file.
    set_seed(seed=42)
    # Hyper-parameters come from the 'baseline' section of the YAML config.
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    num_epochs = config['num_epochs']

    # NOTE(review): target_path and target_test_path point to the same file,
    # so the unlabelled training data and the evaluation data coincide here
    # - confirm this is intended.
    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/HC_T185_RP.txt'
    target_test_path = '../datasets/HC_T185_RP.txt'
    # The output directory is created, but nothing in this block writes to it.
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=0.5).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Cosine decay from learning_rate down to 10% of it over T_max scheduler steps.
    # NOTE(review): T_max uses the config's num_epochs while training below
    # runs with a hard-coded num_epochs=20 - confirm the two are meant to match.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    # Fixed DANN weight of 0.5, multi-kernel MMD enabled.
    model = train_dann_with_mmd(model, source_loader, target_loader,
                                optimizer, criterion_cls, criterion_domain,
                                device, num_epochs=20, lambda_dann=0.5, use_mk=True,scheduler=scheduler)

    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    general_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total: 1.0782 | Cls: 0.5802 | Dom: 0.9961 | MMD: 0.1656 | DomAcc: 0.7460 | λ_dann: 0.5000 | λ_mmd: 0.0007
[Epoch 2] Total: 0.6515 | Cls: 0.1473 | Dom: 1.0077 | MMD: 0.1466 | DomAcc: 0.7404 | λ_dann: 0.5000 | λ_mmd: 0.0011
[Epoch 3] Total: 0.6933 | Cls: 0.1279 | Dom: 1.1303 | MMD: 0.1383 | DomAcc: 0.6893 | λ_dann: 0.5000 | λ_mmd: 0.0019
[Epoch 4] Total: 0.6753 | Cls: 0.0984 | Dom: 1.1529 | MMD: 0.1339 | DomAcc: 0.6866 | λ_dann: 0.5000 | λ_mmd: 0.0032
[Epoch 5] Total: 0.6815 | Cls: 0.0905 | Dom: 1.1805 | MMD: 0.1312 | DomAcc: 0.6775 | λ_dann: 0.5000 | λ_mmd: 0.0052
[Epoch 6] Total: 0.6827 | Cls: 0.0724 | Dom: 1.2187 | MMD: 0.1180 | DomAcc: 0.6472 | λ_dann: 0.5000 | λ_mmd: 0.0086
[Epoch 7] Total: 0.6870 | Cls: 0.0542 | Dom: 1.2622 | MMD: 0.1237 | DomAcc: 0.6511 | λ_dann: 0.5000 | λ_mmd: 0.0137
[Epoch 8] Total: 0.7050 | Cls: 0.0549 | Dom: 1.2948 | MMD: 0.1267 | DomAcc: 0.6182 | λ_dann: 0.5000 | λ_mmd: 0.0212
[Epoch 9] T

In [24]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import numpy as np
import random
import yaml
from models.Flexible_DANN_MMD import Flexible_DANN
from PKLDataset import PKLDataset
from utils.general_train_and_test import general_test_model
from models.get_no_label_dataloader import get_target_loader
from models.MMD import *
from collections import deque
import torch.nn.functional as F



def set_seed(seed=42):
    """Seed the torch, NumPy and Python RNGs (and all CUDA devices, if any).

    Args:
        seed: Integer seed applied to every generator.
    """
    # Same seeding order as before: torch first, then NumPy, then `random`.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

def get_dataloaders(source_path, target_path, batch_size):
    """Build the shuffled training loaders for the source and target domains.

    Args:
        source_path: Index file handed to ``PKLDataset`` as ``txt_path``.
        target_path: Path forwarded to ``get_target_loader``.
        batch_size: Batch size used by both loaders.

    Returns:
        A ``(source_loader, target_loader)`` tuple.
    """
    labelled_source = PKLDataset(txt_path=source_path)
    return (
        DataLoader(labelled_source, batch_size=batch_size, shuffle=True),
        get_target_loader(target_path, batch_size=batch_size, shuffle=True),
    )

def dann_lambda(epoch, num_epochs):
    """
    Common DANN lambda schedule: rises smoothly from 0 towards 1.

    The factor -10 controls how fast the weight ramps up; make it smaller or
    larger in magnitude to slow down or speed up the rise.
    """
    progress = epoch / float(num_epochs)
    denominator = 1. + np.exp(-10 * progress)
    return 2. / denominator - 1.

def mmd_lambda(epoch, num_epochs, max_lambda=1e-1):
    """Sigmoid warm-up for the MMD loss weight.

    Training progress ``p = epoch / max(1, num_epochs - 1)`` is pushed through
    a logistic curve centred at ``p = 0.5``, so over a full run the weight
    climbs from roughly ``0.0067 * max_lambda`` to ``0.9933 * max_lambda``.

    Returns:
        The weight as a plain Python ``float``.
    """
    last_index = max(1, num_epochs - 1)
    progress = epoch / last_index                     # in [0, 1] for a valid epoch
    logit = torch.tensor(-10.0 * (progress - 0.5))
    gate = 1.0 / (1.0 + torch.exp(logit))             # in (0, 1)
    scaled = max_lambda * gate
    return float(scaled)

def train_dann_with_mmd(model, source_loader, target_loader,
                        optimizer, criterion_cls, criterion_domain,
                        device, num_epochs=20,
                        lambda_dann=0.1,           # weight of the domain-classifier loss
                        lambda_mmd_max=15e-2,       # maximum weight of the MMD loss
                        use_mk=True,               # use the multi-kernel MMD variant
                        scheduler=None):
    """Train a DANN model with an additional MMD feature-alignment term.

    Per batch the loss is
    ``cls + lambda_dann_now * domain + lambda_mmd_now * mmd`` where the
    classification loss uses source samples only, the domain loss uses both
    domains (source = 0, target = 1) and the MMD is computed on features that
    are L2-normalised along dim 1.

    Args:
        model: Called as ``model(x)``; must return a 3-tuple
            ``(class_logits, domain_logits, features)``.
        source_loader: Iterable yielding ``(inputs, labels)`` batches.
        target_loader: Iterable yielding unlabelled input batches. The two
            loaders are zipped, so an epoch ends with the shorter one.
        optimizer: Optimiser stepped once per batch.
        criterion_cls: Loss applied to the source class logits.
        criterion_domain: Loss applied to the domain logits.
        device: Device the batches are moved to.
        num_epochs: Maximum number of epochs.
        lambda_dann: Fixed domain-loss weight. If a callable is passed, the
            ``dann_lambda(epoch, num_epochs)`` schedule is used instead (the
            callable itself is not invoked).
        lambda_mmd_max: ``max_lambda`` passed to the ``mmd_lambda`` warm-up.
        use_mk: Select ``mmd_mk_biased`` (True) or ``mmd_rbf_biased`` (False).
        scheduler: Optional LR scheduler, stepped once per epoch.

    Returns:
        ``model`` with the snapshot from the epoch whose domain accuracy was
        closest to chance (smallest ``|acc - 0.5|``) loaded.
    """
    PATIENCE = 3
    MIN_EPOCH = 10

    best_gap = 0.5
    best_model_state = None
    patience = 0

    MMD_THRESH = 4e-2  # MMD^2 counts as "small" below this; task dependent (0.02~0.05 is typical)
    MMD_PLATEAU_EPS = 5e-2  # max-min spread under which the MMD counts as a plateau
    mmd_hist = deque(maxlen=5)  # the plateau check looks at the last 5 epochs

    if use_mk:
        mmd_fn = lambda x, y: mmd_mk_biased(x, y, gammas=(0.5, 1, 2, 4, 8))
    else:
        mmd_fn = lambda x, y: mmd_rbf_biased(x, y, gamma=None)

    for epoch in range(num_epochs):
        cls_loss_sum, dom_loss_sum, mmd_loss_sum, total_loss_sum = 0.0, 0.0, 0.0, 0.0
        total_cls_samples, total_dom_samples = 0, 0
        dom_correct = 0
        model.train()

        # Both weights depend only on the epoch. Computing them here avoids
        # per-batch recomputation and keeps them defined for the log line even
        # if the loaders yield no batch.
        lambda_dann_now = dann_lambda(epoch, num_epochs) if callable(lambda_dann) else lambda_dann
        lambda_mmd_now = float(mmd_lambda(epoch, num_epochs, max_lambda=lambda_mmd_max))

        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)
            n_src = src_x.size(0)
            n_both = n_src + tgt_x.size(0)

            cls_out_src, dom_out_src, feat_src = model(src_x)
            _, dom_out_tgt, feat_tgt = model(tgt_x)

            # 1) Classification loss (source domain only).
            loss_cls = criterion_cls(cls_out_src, src_y)

            # 2) Domain-classification loss (DANN): source = 0, target = 1.
            dom_label_src = torch.zeros(n_src, dtype=torch.long, device=device)
            dom_label_tgt = torch.ones(tgt_x.size(0), dtype=torch.long, device=device)
            loss_dom = criterion_domain(dom_out_src, dom_label_src) + \
                       criterion_domain(dom_out_tgt, dom_label_tgt)

            # 3) MMD feature alignment on L2-normalised features (for stability).
            feat_src_n = F.normalize(feat_src, dim=1)
            feat_tgt_n = F.normalize(feat_tgt, dim=1)
            loss_mmd = mmd_fn(feat_src_n, feat_tgt_n)

            # 4) Weighted sum; the MMD weight is warmed up so that alignment
            #    does not flatten the decision structure early on.
            loss = loss_cls + lambda_dann_now * loss_dom + lambda_mmd_now * loss_mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running sums, weighted by the number of samples they cover.
            cls_loss_sum += loss_cls.item() * n_src
            dom_loss_sum += loss_dom.item() * n_both
            mmd_loss_sum += loss_mmd.item() * n_both
            total_loss_sum += loss.item() * n_both

            total_cls_samples += n_src
            total_dom_samples += n_both

            # Domain-classifier accuracy.
            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()

        # ---- epoch-level metrics ----
        avg_cls_loss = cls_loss_sum / max(1, total_cls_samples)
        avg_dom_loss = dom_loss_sum / max(1, total_dom_samples)
        avg_mmd_loss = mmd_loss_sum / max(1, total_dom_samples)
        avg_total_loss = total_loss_sum / max(1, total_dom_samples)
        dom_acc = dom_correct / max(1, total_dom_samples)
        gap = abs(dom_acc - 0.5)  # distance of the domain classifier from chance level

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch + 1}] Total: {avg_total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {avg_dom_loss:.4f} | "
              f"MMD: {avg_mmd_loss:.4f} | DomAcc: {dom_acc:.4f} | "
              f"λ_dann: {lambda_dann_now:.4f} | λ_mmd: {lambda_mmd_now:.4f}")

        mmd_hist.append(avg_mmd_loss)
        mmd_plateau = (len(mmd_hist) == mmd_hist.maxlen) and (max(mmd_hist) - min(mmd_hist) < MMD_PLATEAU_EPS)

        # Early-stopping trigger conditions.
        cond_align = (gap < 0.05)
        cond_cls = (avg_cls_loss < 0.5)
        cond_mmd_small = (avg_mmd_loss < MMD_THRESH)
        cond_mmd_plateau = mmd_plateau

        # Only the domain-accuracy gap decides which snapshot is "best".
        improved = gap < best_gap - 1e-4
        if improved:
            best_gap = gap
            best_model_state = copy.deepcopy(model.state_dict())

        # Early stopping: aligned + classifier converged + MMD small AND on a plateau.
        if epoch > MIN_EPOCH and cond_align and cond_cls and (cond_mmd_small and cond_mmd_plateau):
            patience = 0 if improved else patience + 1
            print(f"[INFO] patience {patience} / {PATIENCE} | MMD_small={cond_mmd_small} plateau={cond_mmd_plateau}")
            if patience >= PATIENCE:
                print("[INFO] Early stopping: domain aligned, classifier converged, and MMD stabilized.")
                break
        else:
            patience = 0

    # Whether training stopped early or ran to the last epoch, hand back the
    # best snapshot (if any epoch produced one).
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model


if __name__ == '__main__':

    # Entry point: read hyper-parameters, build the model and loaders, train
    # with DANN + MMD, then evaluate on the target test split.
    # NOTE(review): set_seed() is not called in this block, so the run is not
    # seeded here - confirm this is intended.
    # Hyper-parameters come from the 'baseline' section of the YAML config.
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    num_epochs = config['num_epochs']

    # Separate target files are used for unlabelled training and for testing.
    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/target/train/HC_T188_RP.txt'
    target_test_path = '../datasets/target/test/HC_T188_RP.txt'
    # The output directory is created, but nothing in this block writes to it.
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=0.5).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # Cosine decay from learning_rate down to 10% of it over T_max scheduler steps.
    # NOTE(review): T_max uses the config's num_epochs while training below
    # runs with a hard-coded num_epochs=30 - confirm the two are meant to match.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    # Fixed DANN weight of 0.5, multi-kernel MMD enabled.
    model = train_dann_with_mmd(model, source_loader, target_loader,
                                optimizer, criterion_cls, criterion_domain,
                                device, num_epochs=30, lambda_dann=0.5, use_mk=True,scheduler=scheduler)

    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    general_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total: 0.9870 | Cls: 0.5132 | Dom: 0.9463 | MMD: 0.1543 | DomAcc: 0.7726 | λ_dann: 0.5000 | λ_mmd: 0.0010
[Epoch 2] Total: 0.6831 | Cls: 0.1454 | Dom: 1.0747 | MMD: 0.1511 | DomAcc: 0.7375 | λ_dann: 0.5000 | λ_mmd: 0.0014
[Epoch 3] Total: 0.7013 | Cls: 0.0918 | Dom: 1.2183 | MMD: 0.1638 | DomAcc: 0.6536 | λ_dann: 0.5000 | λ_mmd: 0.0020
[Epoch 4] Total: 0.6790 | Cls: 0.0974 | Dom: 1.1627 | MMD: 0.1552 | DomAcc: 0.6857 | λ_dann: 0.5000 | λ_mmd: 0.0028
[Epoch 5] Total: 0.7018 | Cls: 0.0732 | Dom: 1.2558 | MMD: 0.1663 | DomAcc: 0.6365 | λ_dann: 0.5000 | λ_mmd: 0.0039
[Epoch 6] Total: 0.7201 | Cls: 0.0595 | Dom: 1.3192 | MMD: 0.1746 | DomAcc: 0.6017 | λ_dann: 0.5000 | λ_mmd: 0.0055
[Epoch 7] Total: 0.7186 | Cls: 0.0466 | Dom: 1.3412 | MMD: 0.1763 | DomAcc: 0.5612 | λ_dann: 0.5000 | λ_mmd: 0.0076
[Epoch 8] Total: 0.7218 | Cls: 0.0579 | Dom: 1.3240 | MMD: 0.1765 | DomAcc: 0.5884 | λ_dann: 0.5000 | λ_mmd: 0.0105
[Epoch 9] T