In [1]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import numpy as np
import random
import yaml
from models.Flexible_DANN import Flexible_DANN
from PKLDataset import PKLDataset
from utils.general_train_and_test import general_test_model
from models.get_no_label_dataloader import get_target_loader

In [2]:
def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_dataloaders(source_path, target_path, batch_size):
    source_dataset = PKLDataset(txt_path=source_path)
    source_loader = DataLoader(source_dataset, batch_size=batch_size, shuffle=True)
    target_loader = get_target_loader(target_path, batch_size=batch_size, shuffle=True)
    return source_loader, target_loader

def dann_lambda(epoch, num_epochs):
    """
    常用的 DANN λ 调度：从 0 平滑升到 1
    你也可以把 -10 调轻/重来改变上升速度
    """
    # p = epoch / float(num_epochs)
    # return 2. / (1. + np.exp(-10 * p)) - 1.
    if epoch < 10:
        return 0.8
    elif epoch < 20:
        return 0.5
    else:
        return 0.3

In [17]:
def train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=20, lambda_=0.1,scheduler = None):
    best_gap = 0.5
    best_model_state = None
    patience = 0
    for epoch in range(num_epochs):
        cls_loss_sum, dom_loss_sum, total_loss_sum = 0.0, 0.0, 0.0
        total_cls_samples, total_dom_samples = 0, 0
        dom_correct, dom_total = 0, 0
        model.train()

        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src = model(src_x)
            _, dom_out_tgt = model(tgt_x)

            loss_cls = criterion_cls(cls_out_src, src_y)

            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long).to(device)
            dom_label_tgt = torch.ones(tgt_x.size(0), dtype=torch.long).to(device)
            loss_dom = criterion_domain(dom_out_src, dom_label_src) + \
                       criterion_domain(dom_out_tgt, dom_label_tgt)

            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total += dom_label_src.size(0) + dom_label_tgt.size(0)

            loss = loss_cls + lambda_ * loss_dom

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            cls_loss_sum += loss_cls.item() * src_x.size(0)
            dom_loss_sum += loss_dom.item() * (src_x.size(0) + tgt_x.size(0))
            total_loss_sum += loss.item() * (src_x.size(0) + tgt_x.size(0))

            total_cls_samples += src_x.size(0)
            total_dom_samples += (src_x.size(0) + tgt_x.size(0))

        avg_cls_loss = cls_loss_sum / total_cls_samples
        avg_dom_loss = dom_loss_sum / total_dom_samples
        avg_total_loss = total_loss_sum / total_dom_samples

        # 域分类准确率（整轮）
        dom_acc = dom_correct / dom_total
        gap = abs(dom_acc - 0.5)

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch + 1}] Total Loss: {avg_total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {avg_dom_loss:.4f} | "
              f"DomAcc: {dom_acc:.4f}")


        if gap < 0.03 and avg_cls_loss < 0.5 and epoch > 10:
            patience +=1
            if gap < best_gap:
                best_gap = gap
                best_model_state = copy.deepcopy(model.state_dict())
            print(f"[INFO] patience {patience} / 3")
            if patience > 3:
                model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned and classifier converged.")
                break
        else:
            patience = 0
            best_gap = gap

    if best_model_state is not None:
        # torch.save(best_model_state, os.path.join(out_path, 'test_best_model.pth'))
        model.load_state_dict(best_model_state)


    return model






if __name__ == '__main__':
    set_seed(seed=42)
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/target/test/HC_T185_RP.txt'
    target_test_path = '../datasets/target/train/HC_T185_RP.txt'
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=0.5).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    model=train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=40, lambda_=1,scheduler=scheduler)

    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    general_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total Loss: 2.7007 | Cls: 1.3054 | Dom: 1.3953 | DomAcc: 0.5358
[Epoch 2] Total Loss: 1.6789 | Cls: 0.4881 | Dom: 1.1908 | DomAcc: 0.6557
[Epoch 3] Total Loss: 1.5989 | Cls: 0.3699 | Dom: 1.2290 | DomAcc: 0.6562
[Epoch 4] Total Loss: 1.4997 | Cls: 0.2143 | Dom: 1.2854 | DomAcc: 0.6522
[Epoch 5] Total Loss: 1.4682 | Cls: 0.1732 | Dom: 1.2950 | DomAcc: 0.6295
[Epoch 6] Total Loss: 1.4464 | Cls: 0.1505 | Dom: 1.2959 | DomAcc: 0.6195
[Epoch 7] Total Loss: 1.4164 | Cls: 0.1157 | Dom: 1.3007 | DomAcc: 0.6018
[Epoch 8] Total Loss: 1.5035 | Cls: 0.1407 | Dom: 1.3628 | DomAcc: 0.5696
[Epoch 9] Total Loss: 1.3922 | Cls: 0.0984 | Dom: 1.2938 | DomAcc: 0.6205
[Epoch 10] Total Loss: 1.3202 | Cls: 0.1204 | Dom: 1.1998 | DomAcc: 0.7026
[Epoch 11] Total Loss: 1.4701 | Cls: 0.0960 | Dom: 1.3740 | DomAcc: 0.5418
[Epoch 12] Total Loss: 1.5798 | Cls: 0.0841 | Dom: 1.4956 | DomAcc: 0.4350
[Epoch 13] Total Loss: 1.5396 | Cls: 0.0581 | Do

In [7]:
def train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=20, lambda_=0.1,scheduler = None):
    best_gap = 0.5
    best_model_state = None
    patience = 0
    for epoch in range(num_epochs):
        total_loss, total_cls_loss, total_dom_loss = 0.0, 0.0, 0.0
        dom_correct, dom_total = 0, 0
        model.train()
        num_batches = 0
        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            num_batches += 1
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src = model(src_x)
            _, dom_out_tgt = model(tgt_x)

            loss_cls = criterion_cls(cls_out_src, src_y)

            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long).to(device)
            dom_label_tgt = torch.ones(tgt_x.size(0), dtype=torch.long).to(device)
            loss_dom = criterion_domain(dom_out_src, dom_label_src) + \
                       criterion_domain(dom_out_tgt, dom_label_tgt)

            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total += dom_label_src.size(0) + dom_label_tgt.size(0)

            loss = loss_cls + lambda_ * loss_dom

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            total_loss += loss.item()
            total_cls_loss += loss_cls.item()
            total_dom_loss += loss_dom.item()

        dom_acc = dom_correct / dom_total
        avg_cls_loss = total_cls_loss / num_batches
        gap = abs(dom_acc - 0.5)

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch+1}] Total Loss: {total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {total_dom_loss:.4f} | "
              f"DomAcc: {dom_acc:.4f}")

        # print("[INFO] Evaluating on target test set...")
        # target_test_path = '../datasets/HC_T185_RP.txt'
        # test_dataset = PKLDataset(target_test_path)
        # test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        # pseudo_test_model(model, criterion_cls, test_loader, device)



        if gap < 0.005 and avg_cls_loss < 0.05 and epoch > 10:
            patience +=1
            if gap < best_gap:
                best_gap = gap
                best_model_state = copy.deepcopy(model.state_dict())
            print(f"[INFO] patience {patience} / 3")
            if patience > 3:
                model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned and classifier converged.")
                break
        else:
            patience = 0
            best_gap = gap
    if best_model_state is not None:

        model.load_state_dict(best_model_state)


    return model






if __name__ == '__main__':
    set_seed(seed=44)
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/HC_T188_RP.txt'
    target_test_path = '../datasets/HC_T188_RP.txt'
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=0.5).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    model=train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=40, lambda_=0.5,scheduler=scheduler)

    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    general_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total Loss: 312.2282 | Cls: 0.7016 | Dom: 273.6700 | DomAcc: 0.6900
[Epoch 2] Total Loss: 173.5931 | Cls: 0.1267 | Dom: 283.8536 | DomAcc: 0.6893
[Epoch 3] Total Loss: 192.4783 | Cls: 0.0823 | Dom: 343.7952 | DomAcc: 0.5965
[Epoch 4] Total Loss: 191.1798 | Cls: 0.0665 | Dom: 349.0959 | DomAcc: 0.5167
[Epoch 5] Total Loss: 183.9482 | Cls: 0.0480 | Dom: 343.8882 | DomAcc: 0.5465
[Epoch 6] Total Loss: 177.7356 | Cls: 0.0400 | Dom: 335.4661 | DomAcc: 0.5980
[Epoch 7] Total Loss: 171.7944 | Cls: 0.0369 | Dom: 325.1553 | DomAcc: 0.6084
[Epoch 8] Total Loss: 172.1987 | Cls: 0.0373 | Dom: 325.7460 | DomAcc: 0.6006
[Epoch 9] Total Loss: 176.5720 | Cls: 0.0323 | Dom: 336.9959 | DomAcc: 0.5458
[Epoch 10] Total Loss: 171.5147 | Cls: 0.0395 | Dom: 323.2586 | DomAcc: 0.6296
[Epoch 11] Total Loss: 173.0376 | Cls: 0.0349 | Dom: 328.6305 | DomAcc: 0.6055
[Epoch 12] Total Loss: 175.7335 | Cls: 0.0274 | Dom: 337.7629 | DomAcc: 0.5706


In [4]:
def train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=20, lambda_=0.1,scheduler = None):
    best_gap = 0.5
    best_model_state = None
    patience = 0
    for epoch in range(num_epochs):
        total_loss, total_cls_loss, total_dom_loss = 0.0, 0.0, 0.0
        dom_correct, dom_total = 0, 0
        model.train()
        num_batches = 0
        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            num_batches += 1
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src = model(src_x)
            _, dom_out_tgt = model(tgt_x)

            loss_cls = criterion_cls(cls_out_src, src_y)

            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long).to(device)
            dom_label_tgt = torch.ones(tgt_x.size(0), dtype=torch.long).to(device)
            loss_dom = criterion_domain(dom_out_src, dom_label_src) + \
                       criterion_domain(dom_out_tgt, dom_label_tgt)

            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total += dom_label_src.size(0) + dom_label_tgt.size(0)

            loss = loss_cls + lambda_ * loss_dom

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            total_loss += loss.item()
            total_cls_loss += loss_cls.item()
            total_dom_loss += loss_dom.item()

        dom_acc = dom_correct / dom_total
        avg_cls_loss = total_cls_loss / num_batches
        gap = abs(dom_acc - 0.5)

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch+1}] Total Loss: {total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {total_dom_loss:.4f} | "
              f"DomAcc: {dom_acc:.4f}")


        if gap < 0.03 and avg_cls_loss < 0.05 and epoch > 10:
            patience +=1
            if gap < best_gap:
                best_gap = gap
                best_model_state = copy.deepcopy(model.state_dict())
            print(f"[INFO] patience {patience} / 3")
            if patience > 3:
                model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned and classifier converged.")
                break
        else:
            patience = 0
            best_gap = gap
    if best_model_state is not None:

        model.load_state_dict(best_model_state)


    return model






if __name__ == '__main__':
    set_seed(seed=188)
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/target/train/HC_T188_RP.txt'
    target_test_path = '../datasets/target/test/HC_T188_RP.txt'
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=0.5).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    model=train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=40, lambda_=0.5,scheduler=scheduler)

    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    general_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total Loss: 262.9256 | Cls: 0.5452 | Dom: 253.2420 | DomAcc: 0.7581
[Epoch 2] Total Loss: 183.4174 | Cls: 0.1369 | Dom: 298.4002 | DomAcc: 0.6896
[Epoch 3] Total Loss: 168.9419 | Cls: 0.0898 | Dom: 293.0012 | DomAcc: 0.6950
[Epoch 4] Total Loss: 175.1027 | Cls: 0.0783 | Dom: 311.0342 | DomAcc: 0.6667
[Epoch 5] Total Loss: 174.7883 | Cls: 0.0521 | Dom: 323.5196 | DomAcc: 0.6151
[Epoch 6] Total Loss: 179.4945 | Cls: 0.0465 | Dom: 335.7604 | DomAcc: 0.5862
[Epoch 7] Total Loss: 169.3155 | Cls: 0.0487 | Dom: 314.2756 | DomAcc: 0.6539
[Epoch 8] Total Loss: 163.3561 | Cls: 0.0370 | Dom: 308.2311 | DomAcc: 0.6749
[Epoch 9] Total Loss: 173.5950 | Cls: 0.0430 | Dom: 325.7055 | DomAcc: 0.6066
[Epoch 10] Total Loss: 169.5951 | Cls: 0.0224 | Dom: 328.0061 | DomAcc: 0.6178
[Epoch 11] Total Loss: 174.5310 | Cls: 0.0343 | Dom: 331.8966 | DomAcc: 0.5922
[Epoch 12] Total Loss: 174.0767 | Cls: 0.0216 | Dom: 337.3530 | DomAcc: 0.5893


In [10]:
def train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=20, lambda_=0.1,scheduler = None):
    best_gap = 0.5
    best_model_state = None
    patience = 0
    for epoch in range(num_epochs):
        total_loss, total_cls_loss, total_dom_loss = 0.0, 0.0, 0.0
        dom_correct, dom_total = 0, 0
        model.train()
        num_batches = 0
        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            num_batches += 1
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src = model(src_x)
            _, dom_out_tgt = model(tgt_x)

            loss_cls = criterion_cls(cls_out_src, src_y)

            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long).to(device)
            dom_label_tgt = torch.ones(tgt_x.size(0), dtype=torch.long).to(device)
            loss_dom = criterion_domain(dom_out_src, dom_label_src) + \
                       criterion_domain(dom_out_tgt, dom_label_tgt)

            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total += dom_label_src.size(0) + dom_label_tgt.size(0)

            loss = loss_cls + lambda_ * loss_dom

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            total_loss += loss.item()
            total_cls_loss += loss_cls.item()
            total_dom_loss += loss_dom.item()

        dom_acc = dom_correct / dom_total
        avg_cls_loss = total_cls_loss / num_batches
        gap = abs(dom_acc - 0.5)

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch+1}] Total Loss: {total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {total_dom_loss:.4f} | "
              f"DomAcc: {dom_acc:.4f}")


        if gap < 0.015 and avg_cls_loss < 0.05 and epoch > 10:
            patience +=1
            if gap < best_gap:
                best_gap = gap
                best_model_state = copy.deepcopy(model.state_dict())
            print(f"[INFO] patience {patience} / 3")
            if patience > 3:
                model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned and classifier converged.")
                break
        else:
            patience = 0
    if best_model_state is not None:

        model.load_state_dict(best_model_state)


    return model






if __name__ == '__main__':
    set_seed(seed=191)
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/target/test/HC_T191_RP.txt'
    target_test_path = '../datasets/target/train/HC_T191_RP.txt'
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=0.5).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    model=train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=40, lambda_=0.5,scheduler=scheduler)

    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    general_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total Loss: 107.4720 | Cls: 1.2015 | Dom: 61.1553 | DomAcc: 0.7852
[Epoch 2] Total Loss: 42.0315 | Cls: 0.3317 | Dom: 41.6113 | DomAcc: 0.9031
[Epoch 3] Total Loss: 35.3511 | Cls: 0.1812 | Dom: 47.5135 | DomAcc: 0.8437
[Epoch 4] Total Loss: 35.0832 | Cls: 0.1315 | Dom: 53.3284 | DomAcc: 0.7684
[Epoch 5] Total Loss: 35.8496 | Cls: 0.1410 | Dom: 53.6561 | DomAcc: 0.8269
[Epoch 6] Total Loss: 32.5553 | Cls: 0.1231 | Dom: 49.3595 | DomAcc: 0.8491
[Epoch 7] Total Loss: 33.1143 | Cls: 0.0969 | Dom: 53.8270 | DomAcc: 0.8368
[Epoch 8] Total Loss: 38.0941 | Cls: 0.1343 | Dom: 58.9997 | DomAcc: 0.8024
[Epoch 9] Total Loss: 38.6681 | Cls: 0.0865 | Dom: 66.2631 | DomAcc: 0.7522
[Epoch 10] Total Loss: 35.6480 | Cls: 0.0938 | Dom: 59.2918 | DomAcc: 0.7979
[Epoch 11] Total Loss: 38.1515 | Cls: 0.1236 | Dom: 60.4816 | DomAcc: 0.7817
[Epoch 12] Total Loss: 37.1170 | Cls: 0.0730 | Dom: 64.8887 | DomAcc: 0.7557
[Epoch 13] Total Loss: 

In [15]:
def train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=20, lambda_=0.1,scheduler = None):
    best_gap = 0.5
    best_model_state = None
    patience = 0
    for epoch in range(num_epochs):
        cls_loss_sum, dom_loss_sum, total_loss_sum = 0.0, 0.0, 0.0
        total_cls_samples, total_dom_samples = 0, 0
        dom_correct, dom_total = 0, 0
        model.train()

        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src = model(src_x)
            _, dom_out_tgt = model(tgt_x)

            loss_cls = criterion_cls(cls_out_src, src_y)

            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long).to(device)
            dom_label_tgt = torch.ones(tgt_x.size(0), dtype=torch.long).to(device)
            loss_dom = criterion_domain(dom_out_src, dom_label_src) + \
                       criterion_domain(dom_out_tgt, dom_label_tgt)

            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total += dom_label_src.size(0) + dom_label_tgt.size(0)
            lambda_ = dann_lambda(epoch, num_epochs)

            loss = loss_cls + lambda_ * loss_dom

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            cls_loss_sum += loss_cls.item() * src_x.size(0)
            dom_loss_sum += loss_dom.item() * (src_x.size(0) + tgt_x.size(0))
            total_loss_sum += loss.item() * (src_x.size(0) + tgt_x.size(0))

            total_cls_samples += src_x.size(0)
            total_dom_samples += (src_x.size(0) + tgt_x.size(0))

        avg_cls_loss = cls_loss_sum / total_cls_samples
        avg_dom_loss = dom_loss_sum / total_dom_samples
        avg_total_loss = total_loss_sum / total_dom_samples

        # 域分类准确率（整轮）
        dom_acc = dom_correct / dom_total
        gap = abs(dom_acc - 0.5)

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch + 1}] Total Loss: {avg_total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {avg_dom_loss:.4f} | "
              f"DomAcc: {dom_acc:.4f}")


        if gap < 0.03 and avg_cls_loss < 0.5 and epoch > 10:
            patience +=1
            if gap < best_gap:
                best_gap = gap
                best_model_state = copy.deepcopy(model.state_dict())
            print(f"[INFO] patience {patience} / 3")
            if patience > 3:
                model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned and classifier converged.")
                break
        else:
            patience = 0
            best_gap = gap

    if best_model_state is not None:
        # torch.save(best_model_state, os.path.join(out_path, 'test_best_model.pth'))
        model.load_state_dict(best_model_state)



    return model






if __name__ == '__main__':
    set_seed(seed=194)
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/target/train/HC_T194_RP.txt'
    target_test_path = '../datasets/target/test/HC_T194_RP.txt'
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=0.5).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    model=train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=40, lambda_=1,scheduler=scheduler)

    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    general_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total Loss: 1.3830 | Cls: 0.4813 | Dom: 1.1275 | DomAcc: 0.6980
[INFO] Evaluating on target test set...
- test Loss: 2.379489, test Acc: 0.3473
[Epoch 2] Total Loss: 1.1783 | Cls: 0.1138 | Dom: 1.3307 | DomAcc: 0.5889
[INFO] Evaluating on target test set...
- test Loss: 1.943707, test Acc: 0.3565
[Epoch 3] Total Loss: 1.0821 | Cls: 0.0630 | Dom: 1.2737 | DomAcc: 0.6454
[INFO] Evaluating on target test set...
- test Loss: 2.992497, test Acc: 0.4413
[Epoch 4] Total Loss: 1.1115 | Cls: 0.0695 | Dom: 1.3025 | DomAcc: 0.6242
[INFO] Evaluating on target test set...
- test Loss: 2.874673, test Acc: 0.4382
[Epoch 5] Total Loss: 1.0684 | Cls: 0.0725 | Dom: 1.2448 | DomAcc: 0.6642
[INFO] Evaluating on target test set...
- test Loss: 3.328156, test Acc: 0.4331
[Epoch 6] Total Loss: 1.0800 | Cls: 0.0498 | Dom: 1.2878 | DomAcc: 0.6309
[INFO] Evaluating on target test set...
- test Loss: 3.347125, test Acc: 0.3115
[Epoch 7] Total

In [3]:
def train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=20, lambda_=0.1,scheduler = None):
    best_gap = 0.5
    best_model_state = None
    patience = 0
    for epoch in range(num_epochs):
        cls_loss_sum, dom_loss_sum, total_loss_sum = 0.0, 0.0, 0.0
        total_cls_samples, total_dom_samples = 0, 0
        dom_correct, dom_total = 0, 0
        model.train()

        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src = model(src_x)
            _, dom_out_tgt = model(tgt_x)

            loss_cls = criterion_cls(cls_out_src, src_y)

            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long).to(device)
            dom_label_tgt = torch.ones(tgt_x.size(0), dtype=torch.long).to(device)
            bs_src, bs_tgt = src_x.size(0), tgt_x.size(0)
            loss_dom_src = criterion_domain(dom_out_src, dom_label_src)
            loss_dom_tgt = criterion_domain(dom_out_tgt, dom_label_tgt)

            # 样本数加权的“单个域损失均值”
            loss_dom = (loss_dom_src * bs_src + loss_dom_tgt * bs_tgt) / (bs_src + bs_tgt)

            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total += dom_label_src.size(0) + dom_label_tgt.size(0)

            loss = loss_cls + lambda_ * loss_dom

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            cls_loss_sum += loss_cls.item() * src_x.size(0)
            dom_loss_sum += loss_dom.item() * (src_x.size(0) + tgt_x.size(0))
            total_loss_sum += loss.item() * (src_x.size(0) + tgt_x.size(0))

            total_cls_samples += src_x.size(0)
            total_dom_samples += (src_x.size(0) + tgt_x.size(0))

        avg_cls_loss = cls_loss_sum / total_cls_samples
        avg_dom_loss = dom_loss_sum / total_dom_samples
        avg_total_loss = total_loss_sum / total_dom_samples

        # 域分类准确率（整轮）
        dom_acc = dom_correct / dom_total
        gap = abs(dom_acc - 0.5)

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch + 1}] Total Loss: {avg_total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {avg_dom_loss:.4f} | "
              f"DomAcc: {dom_acc:.4f}")


        if gap < 0.03 and avg_cls_loss < 0.5 and epoch > 10:
            patience +=1
            if gap < best_gap:
                best_gap = gap
                best_model_state = copy.deepcopy(model.state_dict())
            print(f"[INFO] patience {patience} / 3")
            if patience > 3:
                model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned and classifier converged.")
                break
        else:
            patience = 0
            best_gap = gap

    if best_model_state is not None:

        model.load_state_dict(best_model_state)



    return model






if __name__ == '__main__':
    set_seed(seed=94)
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/target/train/HC_T194_RP.txt'
    target_test_path = '../datasets/target/test/HC_T194_RP.txt'
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=1).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    model=train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=40, lambda_=0.5,scheduler=scheduler)

    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    general_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total Loss: 0.6434 | Cls: 0.4293 | Dom: 0.4273 | DomAcc: 0.8061
[Epoch 2] Total Loss: 0.3447 | Cls: 0.0964 | Dom: 0.4963 | DomAcc: 0.7478
[Epoch 3] Total Loss: 0.3646 | Cls: 0.0772 | Dom: 0.5748 | DomAcc: 0.6913
[Epoch 4] Total Loss: 0.3461 | Cls: 0.0578 | Dom: 0.5765 | DomAcc: 0.7081
[Epoch 5] Total Loss: 0.3559 | Cls: 0.0550 | Dom: 0.6017 | DomAcc: 0.6849
[Epoch 6] Total Loss: 0.3525 | Cls: 0.0476 | Dom: 0.6096 | DomAcc: 0.6699
[Epoch 7] Total Loss: 0.3397 | Cls: 0.0334 | Dom: 0.6124 | DomAcc: 0.6695
[Epoch 8] Total Loss: 0.3448 | Cls: 0.0378 | Dom: 0.6140 | DomAcc: 0.6650
[Epoch 9] Total Loss: 0.3503 | Cls: 0.0371 | Dom: 0.6263 | DomAcc: 0.6551
[Epoch 10] Total Loss: 0.3501 | Cls: 0.0275 | Dom: 0.6452 | DomAcc: 0.6216
[Epoch 11] Total Loss: 0.3472 | Cls: 0.0229 | Dom: 0.6485 | DomAcc: 0.6230
[Epoch 12] Total Loss: 0.3577 | Cls: 0.0248 | Dom: 0.6657 | DomAcc: 0.5986
[Epoch 13] Total Loss: 0.3550 | Cls: 0.0236 | Do

In [4]:
def train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=20, lambda_=0.1,scheduler = None):
    best_gap = 0.5
    best_model_state = None
    patience = 0
    for epoch in range(num_epochs):
        cls_loss_sum, dom_loss_sum, total_loss_sum = 0.0, 0.0, 0.0
        total_cls_samples, total_dom_samples = 0, 0
        dom_correct, dom_total = 0, 0
        model.train()

        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src = model(src_x)
            _, dom_out_tgt = model(tgt_x)

            loss_cls = criterion_cls(cls_out_src, src_y)

            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long).to(device)
            dom_label_tgt = torch.ones(tgt_x.size(0), dtype=torch.long).to(device)
            bs_src, bs_tgt = src_x.size(0), tgt_x.size(0)
            loss_dom_src = criterion_domain(dom_out_src, dom_label_src)
            loss_dom_tgt = criterion_domain(dom_out_tgt, dom_label_tgt)

            # 样本数加权的“单个域损失均值”
            loss_dom = (loss_dom_src * bs_src + loss_dom_tgt * bs_tgt) / (bs_src + bs_tgt)

            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total += dom_label_src.size(0) + dom_label_tgt.size(0)

            loss = loss_cls + lambda_ * loss_dom

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            cls_loss_sum += loss_cls.item() * src_x.size(0)
            dom_loss_sum += loss_dom.item() * (src_x.size(0) + tgt_x.size(0))
            total_loss_sum += loss.item() * (src_x.size(0) + tgt_x.size(0))

            total_cls_samples += src_x.size(0)
            total_dom_samples += (src_x.size(0) + tgt_x.size(0))

        avg_cls_loss = cls_loss_sum / total_cls_samples
        avg_dom_loss = dom_loss_sum / total_dom_samples
        avg_total_loss = total_loss_sum / total_dom_samples

        # 域分类准确率（整轮）
        dom_acc = dom_correct / dom_total
        gap = abs(dom_acc - 0.5)

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch + 1}] Total Loss: {avg_total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {avg_dom_loss:.4f} | "
              f"DomAcc: {dom_acc:.4f}")


        if gap < 0.03 and avg_cls_loss < 0.5 and epoch > 10:
            patience +=1
            if gap < best_gap:
                best_gap = gap
                best_model_state = copy.deepcopy(model.state_dict())
            print(f"[INFO] patience {patience} / 3")
            if patience > 3:
                model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned and classifier converged.")
                break
        else:
            patience = 0
            best_gap = gap

    if best_model_state is not None:
        # torch.save(best_model_state, os.path.join(out_path, 'test_best_model.pth'))
        model.load_state_dict(best_model_state)



    return model






if __name__ == '__main__':
    set_seed(seed=97)
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/target/train/HC_T197_RP.txt'
    target_test_path = '../datasets/target/test/HC_T197_RP.txt'
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=1).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    model=train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=40, lambda_=0.5,scheduler=scheduler)

    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    general_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total Loss: 0.9912 | Cls: 0.7145 | Dom: 0.5539 | DomAcc: 0.6969
[Epoch 2] Total Loss: 0.4013 | Cls: 0.1274 | Dom: 0.5478 | DomAcc: 0.6991
[Epoch 3] Total Loss: 0.3422 | Cls: 0.0700 | Dom: 0.5444 | DomAcc: 0.7088
[Epoch 4] Total Loss: 0.3666 | Cls: 0.0772 | Dom: 0.5787 | DomAcc: 0.6749
[Epoch 5] Total Loss: 0.3608 | Cls: 0.0582 | Dom: 0.6052 | DomAcc: 0.6564
[Epoch 6] Total Loss: 0.3631 | Cls: 0.0549 | Dom: 0.6164 | DomAcc: 0.6455
[Epoch 7] Total Loss: 0.3532 | Cls: 0.0450 | Dom: 0.6165 | DomAcc: 0.6389
[Epoch 8] Total Loss: 0.3536 | Cls: 0.0415 | Dom: 0.6242 | DomAcc: 0.6320
[Epoch 9] Total Loss: 0.3472 | Cls: 0.0373 | Dom: 0.6198 | DomAcc: 0.6329
[Epoch 10] Total Loss: 0.3488 | Cls: 0.0395 | Dom: 0.6187 | DomAcc: 0.6391
[Epoch 11] Total Loss: 0.3452 | Cls: 0.0315 | Dom: 0.6273 | DomAcc: 0.6315
[Epoch 12] Total Loss: 0.3381 | Cls: 0.0230 | Dom: 0.6301 | DomAcc: 0.6310
[Epoch 13] Total Loss: 0.3513 | Cls: 0.0270 | Do

In [2]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import numpy as np
import random
import yaml
from models.Flexible_DANN_MMD import Flexible_DANN
from PKLDataset import PKLDataset
from utils.general_train_and_test import general_test_model
from models.get_no_label_dataloader import get_target_loader
from models.MMD import *
from collections import deque



def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_dataloaders(source_path, target_path, batch_size):
    source_dataset = PKLDataset(txt_path=source_path)
    source_loader = DataLoader(source_dataset, batch_size=batch_size, shuffle=True)
    target_loader = get_target_loader(target_path, batch_size=batch_size, shuffle=True)
    return source_loader, target_loader

def dann_lambda(epoch, num_epochs):
    """
    常用的 DANN λ 调度：从 0 平滑升到 1
    你也可以把 -10 调轻/重来改变上升速度
    """
    p = epoch / float(num_epochs)
    return 2. / (1. + np.exp(-10 * p)) - 1.

def mmd_lambda(epoch, num_epochs, max_lambda=1e-1):
    # 0 → max_lambda，S 型上升
    p = epoch / max(1, num_epochs - 1)         # p ∈ [0,1]
    s = 1.0 / (1.0 + torch.exp(torch.tensor(-10.0*(p - 0.5))))  # ∈ (0,1)
    return float(max_lambda * s)

def train_dann_with_mmd(model, source_loader, target_loader,
                        optimizer, criterion_cls, criterion_domain,
                        device, num_epochs=20,
                        lambda_dann=0.1,           # 域分类器的权重
                        lambda_mmd_max=1e-1,       # MMD 的最大权重
                        use_mk=False,               # 是否用多核
                        scheduler=None):
    PATIENCE = 3
    MIN_EPOCH = 10

    best_gap = 0.5
    best_cls = float('inf')
    best_mmd = float('inf')
    best_model_state = None
    patience = 0

    MMD_THRESH = 3e-2  # MMD²足够小的阈值，按任务可调（0.02~0.05常见）
    MMD_PLATEAU_EPS = 5e-3  # 平台期判定的波动阈值
    mmd_hist = deque(maxlen=5)  # 用最近5个epoch判断是否进入平台期

    mmd_fn = (lambda x, y: mmd_mk_biased(x, y, gammas=(0.5,1,2,4,8))) if use_mk \
             else (lambda x, y: mmd_rbf_biased(x, y, gamma=None))

    for epoch in range(num_epochs):
        cls_loss_sum, dom_loss_sum, mmd_loss_sum, total_loss_sum = 0.0, 0.0, 0.0, 0.0
        total_cls_samples, total_dom_samples = 0, 0
        dom_correct, dom_total = 0, 0
        model.train()


        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src, feat_src = model(src_x)
            _,            dom_out_tgt, feat_tgt = model(tgt_x)

            # 1) 分类损失（仅源域）
            loss_cls = criterion_cls(cls_out_src, src_y)

            # 2) 域分类损失（DANN）
            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long, device=device)
            dom_label_tgt = torch.ones(tgt_x.size(0),  dtype=torch.long, device=device)
            bs_src, bs_tgt = src_x.size(0), tgt_x.size(0)
            loss_dom_src = criterion_domain(dom_out_src, dom_label_src)
            loss_dom_tgt = criterion_domain(dom_out_tgt, dom_label_tgt)

            # 样本数加权的“单个域损失均值”
            loss_dom = (loss_dom_src * bs_src + loss_dom_tgt * bs_tgt) / (bs_src + bs_tgt)
            # 3) RBF‑MMD（特征对齐）
            # 建议先做 L2 归一化，提升稳定性
            feat_src_n = F.normalize(feat_src, dim=1)
            feat_tgt_n = F.normalize(feat_tgt, dim=1)
            loss_mmd = mmd_fn(feat_src_n, feat_tgt_n)

            # 4) 组合总损失
            #    - DANN 的 lambda 可继续用你已有的动态 dann_lambda
            #    - MMD 的权重做 warm‑up（避免一开始就把决策结构抹平）
            lambda_dann_now = dann_lambda(epoch, num_epochs) if callable(lambda_dann) else lambda_dann
            lambda_mmd_now  = float(mmd_lambda(epoch, num_epochs, max_lambda=lambda_mmd_max))

            loss = loss_cls + lambda_dann_now * loss_dom + lambda_mmd_now * loss_mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # 记录指标
            cls_loss_sum  += loss_cls.item() * src_x.size(0)
            dom_loss_sum  += loss_dom.item() * (src_x.size(0) + tgt_x.size(0))
            mmd_loss_sum  += loss_mmd.item() * (src_x.size(0) + tgt_x.size(0))
            total_loss_sum += loss.item() * (src_x.size(0) + tgt_x.size(0))

            total_cls_samples += src_x.size(0)
            total_dom_samples += (src_x.size(0) + tgt_x.size(0))

            # 域分类准确率
            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total   += dom_label_src.size(0) + dom_label_tgt.size(0)

        # ——Epoch 级日志——
        avg_cls_loss  = cls_loss_sum  / max(1, total_cls_samples)
        avg_dom_loss  = dom_loss_sum  / max(1, total_dom_samples)
        avg_mmd_loss  = mmd_loss_sum  / max(1, total_dom_samples)
        avg_total_loss= total_loss_sum/ max(1, total_dom_samples)
        dom_acc = dom_correct / max(1, dom_total)
        gap = abs(dom_acc - 0.5)

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch + 1}] Total: {avg_total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {avg_dom_loss:.4f} | "
              f"MMD: {avg_mmd_loss:.4f} | DomAcc: {dom_acc:.4f} | "
              f"λ_dann: {lambda_dann_now:.4f} | λ_mmd: {lambda_mmd_now:.4f}")

        mmd_hist.append(avg_mmd_loss)
        mmd_plateau = (len(mmd_hist) == mmd_hist.maxlen) and (max(mmd_hist) - min(mmd_hist) < MMD_PLATEAU_EPS)

        # 触发条件
        cond_align = (gap < 0.05)
        cond_cls = (avg_cls_loss < 0.5)
        cond_mmd_small = (avg_mmd_loss < MMD_THRESH)
        cond_mmd_plateau = mmd_plateau

        # 是否有任何指标刷新“最好”
        improved = False
        if gap < best_gap - 1e-4:
            best_gap = gap
            best_model_state = copy.deepcopy(model.state_dict())
            improved = True
        if avg_cls_loss < best_cls - 1e-4:
            best_cls = avg_cls_loss
            best_model_state = copy.deepcopy(model.state_dict())
            improved = True
        if avg_mmd_loss < best_mmd - 1e-5:
            best_mmd = avg_mmd_loss
            best_model_state = copy.deepcopy(model.state_dict())
            improved = True

        # ——Early stopping：对齐 + 分类收敛 + （MMD小 或 MMD平台期）——
        if epoch > MIN_EPOCH and cond_align and cond_cls and (cond_mmd_small or cond_mmd_plateau):
            if not improved:
                patience += 1
            else:
                patience = 0
            print(f"[INFO] patience {patience} / {PATIENCE} | MMD_small={cond_mmd_small} plateau={cond_mmd_plateau}")
            if patience >= PATIENCE:
                if best_model_state is not None:
                    model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned, classifier converged, and MMD stabilized.")
                break
        else:

            patience = 0
        print("[INFO] Evaluating on target test set...")
        test_path = '../datasets/target/test/HC_T191_RP.txt'
        dataset = PKLDataset(test_path)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        general_test_model(model, criterion_cls, loader, device)

    if best_model_state is not None:
       model.load_state_dict(best_model_state)



    return model


if __name__ == '__main__':
    set_seed(seed=91)
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/target/train/HC_T191_RP.txt'
    target_test_path = '../datasets/target/test/HC_T191_RP.txt'
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=1).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    model = train_dann_with_mmd(model, source_loader, target_loader,
                                optimizer, criterion_cls, criterion_domain,
                                device, num_epochs=30, lambda_dann=0.5, use_mk=True,scheduler=scheduler)

    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    general_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total: 0.6300 | Cls: 0.3753 | Dom: 0.5094 | MMD: 0.1848 | DomAcc: 0.7435 | λ_dann: 0.5000 | λ_mmd: 0.0007
[INFO] Evaluating on target test set...
- test Loss: 2.631674, test Acc: 0.3990
[Epoch 2] Total: 0.3158 | Cls: 0.0853 | Dom: 0.4608 | MMD: 0.1471 | DomAcc: 0.7887 | λ_dann: 0.5000 | λ_mmd: 0.0009
[INFO] Evaluating on target test set...
- test Loss: 2.982124, test Acc: 0.3139
[Epoch 3] Total: 0.3364 | Cls: 0.0615 | Dom: 0.5494 | MMD: 0.1451 | DomAcc: 0.7184 | λ_dann: 0.5000 | λ_mmd: 0.0013
[INFO] Evaluating on target test set...
- test Loss: 3.924983, test Acc: 0.4050
[Epoch 4] Total: 0.3546 | Cls: 0.0592 | Dom: 0.5904 | MMD: 0.1306 | DomAcc: 0.6851 | λ_dann: 0.5000 | λ_mmd: 0.0019
[INFO] Evaluating on target test set...
- test Loss: 5.874240, test Acc: 0.2208
[Epoch 5] Total: 0.3726 | Cls: 0.0582 | Dom: 0.6282 | MMD: 0.1258 | DomAcc: 0.6469 | λ_dann: 0.5000 | λ_mmd: 0.0026
[INFO] Evaluating on target test set...