In [1]:
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import numpy as np
import random
import yaml
from models.Flexible_DANN import Flexible_DANN
from PKLDataset import PKLDataset
from utils.pseudo_train_and_test import pseudo_test_model
from models.get_no_label_dataloader import get_target_loader

In [2]:
def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_dataloaders(source_path, target_path, batch_size):
    source_dataset = PKLDataset(txt_path=source_path)
    source_loader = DataLoader(source_dataset, batch_size=batch_size, shuffle=True)
    target_loader = get_target_loader(target_path, batch_size=batch_size, shuffle=True)
    return source_loader, target_loader

def dann_lambda(epoch, num_epochs):
    """
    常用的 DANN λ 调度：从 0 平滑升到 1
    你也可以把 -10 调轻/重来改变上升速度
    """
    p = epoch / float(num_epochs)
    return 2. / (1. + np.exp(-10 * p)) - 1.

In [4]:
def train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=20, lambda_=0.1,scheduler = None):
    best_gap = 0.5
    best_model_state = None
    patience = 0
    for epoch in range(num_epochs):
        total_loss, total_cls_loss, total_dom_loss = 0.0, 0.0, 0.0
        dom_correct, dom_total = 0, 0
        model.train()
        num_batches = 0
        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            num_batches += 1
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src = model(src_x)
            _, dom_out_tgt = model(tgt_x)

            loss_cls = criterion_cls(cls_out_src, src_y)

            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long).to(device)
            dom_label_tgt = torch.ones(tgt_x.size(0), dtype=torch.long).to(device)
            loss_dom = criterion_domain(dom_out_src, dom_label_src) + \
                       criterion_domain(dom_out_tgt, dom_label_tgt)

            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total += dom_label_src.size(0) + dom_label_tgt.size(0)

            loss = loss_cls + lambda_ * loss_dom

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            total_loss += loss.item()
            total_cls_loss += loss_cls.item()
            total_dom_loss += loss_dom.item()

        dom_acc = dom_correct / dom_total
        avg_cls_loss = total_cls_loss / num_batches
        gap = abs(dom_acc - 0.5)

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch+1}] Total Loss: {total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {total_dom_loss:.4f} | "
              f"DomAcc: {dom_acc:.4f}")

        # print("[INFO] Evaluating on target test set...")
        # target_test_path = '../datasets/HC_T185_RP.txt'
        # test_dataset = PKLDataset(target_test_path)
        # test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        # pseudo_test_model(model, criterion_cls, test_loader, device)



        if gap < 0.005 and avg_cls_loss < 0.05 and epoch > 10:
            patience +=1
            if gap < best_gap:
                best_gap = gap
                best_model_state = copy.deepcopy(model.state_dict())
            print(f"[INFO] patience {patience} / 3")
            if patience > 3:
                model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned and classifier converged.")
                break
        else:
            patience = 0
            best_gap = gap


    return model






if __name__ == '__main__':
    set_seed(seed=42)
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/HC_T185_RP.txt'
    target_test_path = '../datasets/HC_T185_RP.txt'
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=0.5).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    model=train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=30, lambda_=0.5,scheduler=scheduler)

    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    pseudo_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total Loss: 266.3570 | Cls: 0.5783 | Dom: 243.5677 | DomAcc: 0.7468
[Epoch 2] Total Loss: 165.3113 | Cls: 0.1503 | Dom: 255.4878 | DomAcc: 0.7335
[Epoch 3] Total Loss: 163.9266 | Cls: 0.1208 | Dom: 267.4597 | DomAcc: 0.7175
[Epoch 4] Total Loss: 172.1174 | Cls: 0.0968 | Dom: 295.8527 | DomAcc: 0.6724
[Epoch 5] Total Loss: 170.2584 | Cls: 0.0913 | Dom: 294.8573 | DomAcc: 0.6853
[Epoch 6] Total Loss: 167.0112 | Cls: 0.0723 | Dom: 297.8749 | DomAcc: 0.6864
[Epoch 7] Total Loss: 174.1572 | Cls: 0.0752 | Dom: 310.6978 | DomAcc: 0.6551
[Epoch 8] Total Loss: 174.1581 | Cls: 0.0627 | Dom: 316.9883 | DomAcc: 0.6460
[Epoch 9] Total Loss: 174.9486 | Cls: 0.0472 | Dom: 326.2841 | DomAcc: 0.6196
[Epoch 10] Total Loss: 175.5995 | Cls: 0.0569 | Dom: 322.7432 | DomAcc: 0.6314
[Epoch 11] Total Loss: 173.8257 | Cls: 0.0440 | Dom: 325.6300 | DomAcc: 0.6222
[Epoch 12] Total Loss: 169.2174 | Cls: 0.0321 | Dom: 322.3864 | DomAcc: 0.6229


In [7]:
def train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=20, lambda_=0.1,scheduler = None):
    best_gap = 0.5
    best_model_state = None
    patience = 0
    for epoch in range(num_epochs):
        total_loss, total_cls_loss, total_dom_loss = 0.0, 0.0, 0.0
        dom_correct, dom_total = 0, 0
        model.train()
        num_batches = 0
        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            num_batches += 1
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src = model(src_x)
            _, dom_out_tgt = model(tgt_x)

            loss_cls = criterion_cls(cls_out_src, src_y)

            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long).to(device)
            dom_label_tgt = torch.ones(tgt_x.size(0), dtype=torch.long).to(device)
            loss_dom = criterion_domain(dom_out_src, dom_label_src) + \
                       criterion_domain(dom_out_tgt, dom_label_tgt)

            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total += dom_label_src.size(0) + dom_label_tgt.size(0)

            loss = loss_cls + lambda_ * loss_dom

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            total_loss += loss.item()
            total_cls_loss += loss_cls.item()
            total_dom_loss += loss_dom.item()

        dom_acc = dom_correct / dom_total
        avg_cls_loss = total_cls_loss / num_batches
        gap = abs(dom_acc - 0.5)

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch+1}] Total Loss: {total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {total_dom_loss:.4f} | "
              f"DomAcc: {dom_acc:.4f}")

        # print("[INFO] Evaluating on target test set...")
        # target_test_path = '../datasets/HC_T185_RP.txt'
        # test_dataset = PKLDataset(target_test_path)
        # test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        # pseudo_test_model(model, criterion_cls, test_loader, device)



        if gap < 0.005 and avg_cls_loss < 0.05 and epoch > 10:
            patience +=1
            if gap < best_gap:
                best_gap = gap
                best_model_state = copy.deepcopy(model.state_dict())
            print(f"[INFO] patience {patience} / 3")
            if patience > 3:
                model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned and classifier converged.")
                break
        else:
            patience = 0
            best_gap = gap
        if best_model_state is not None:
            # torch.save(best_model_state, os.path.join(out_path, 'test_best_model.pth'))
            model.load_state_dict(best_model_state)


    return model






if __name__ == '__main__':
    set_seed(seed=44)
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/HC_T188_RP.txt'
    target_test_path = '../datasets/HC_T188_RP.txt'
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=0.5).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    model=train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=40, lambda_=0.5,scheduler=scheduler)

    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    pseudo_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total Loss: 312.2282 | Cls: 0.7016 | Dom: 273.6700 | DomAcc: 0.6900
[Epoch 2] Total Loss: 173.5931 | Cls: 0.1267 | Dom: 283.8536 | DomAcc: 0.6893
[Epoch 3] Total Loss: 192.4783 | Cls: 0.0823 | Dom: 343.7952 | DomAcc: 0.5965
[Epoch 4] Total Loss: 191.1798 | Cls: 0.0665 | Dom: 349.0959 | DomAcc: 0.5167
[Epoch 5] Total Loss: 183.9482 | Cls: 0.0480 | Dom: 343.8882 | DomAcc: 0.5465
[Epoch 6] Total Loss: 177.7356 | Cls: 0.0400 | Dom: 335.4661 | DomAcc: 0.5980
[Epoch 7] Total Loss: 171.7944 | Cls: 0.0369 | Dom: 325.1553 | DomAcc: 0.6084
[Epoch 8] Total Loss: 172.1987 | Cls: 0.0373 | Dom: 325.7460 | DomAcc: 0.6006
[Epoch 9] Total Loss: 176.5720 | Cls: 0.0323 | Dom: 336.9959 | DomAcc: 0.5458
[Epoch 10] Total Loss: 171.5147 | Cls: 0.0395 | Dom: 323.2586 | DomAcc: 0.6296
[Epoch 11] Total Loss: 173.0376 | Cls: 0.0349 | Dom: 328.6305 | DomAcc: 0.6055
[Epoch 12] Total Loss: 175.7335 | Cls: 0.0274 | Dom: 337.7629 | DomAcc: 0.5706


In [8]:
def train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=20, lambda_=0.1,scheduler = None):
    best_gap = 0.5
    best_model_state = None
    patience = 0
    for epoch in range(num_epochs):
        total_loss, total_cls_loss, total_dom_loss = 0.0, 0.0, 0.0
        dom_correct, dom_total = 0, 0
        model.train()
        num_batches = 0
        for (src_x, src_y), tgt_x in zip(source_loader, target_loader):
            num_batches += 1
            src_x, src_y = src_x.to(device), src_y.to(device)
            tgt_x = tgt_x.to(device)

            cls_out_src, dom_out_src = model(src_x)
            _, dom_out_tgt = model(tgt_x)

            loss_cls = criterion_cls(cls_out_src, src_y)

            dom_label_src = torch.zeros(src_x.size(0), dtype=torch.long).to(device)
            dom_label_tgt = torch.ones(tgt_x.size(0), dtype=torch.long).to(device)
            loss_dom = criterion_domain(dom_out_src, dom_label_src) + \
                       criterion_domain(dom_out_tgt, dom_label_tgt)

            dom_preds_src = torch.argmax(dom_out_src, dim=1)
            dom_preds_tgt = torch.argmax(dom_out_tgt, dim=1)
            dom_correct += (dom_preds_src == dom_label_src).sum().item()
            dom_correct += (dom_preds_tgt == dom_label_tgt).sum().item()
            dom_total += dom_label_src.size(0) + dom_label_tgt.size(0)

            loss = loss_cls + lambda_ * loss_dom

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            total_loss += loss.item()
            total_cls_loss += loss_cls.item()
            total_dom_loss += loss_dom.item()

        dom_acc = dom_correct / dom_total
        avg_cls_loss = total_cls_loss / num_batches
        gap = abs(dom_acc - 0.5)

        if scheduler is not None:
            scheduler.step()

        print(f"[Epoch {epoch+1}] Total Loss: {total_loss:.4f} | "
              f"Cls: {avg_cls_loss:.4f} | Dom: {total_dom_loss:.4f} | "
              f"DomAcc: {dom_acc:.4f}")


        if gap < 0.02 and avg_cls_loss < 0.05 and epoch > 10:
            patience +=1
            if gap < best_gap:
                best_gap = gap
                best_model_state = copy.deepcopy(model.state_dict())
            print(f"[INFO] patience {patience} / 3")
            if patience > 3:
                model.load_state_dict(best_model_state)
                print("[INFO] Early stopping: domain aligned and classifier converged.")
                break
        else:
            patience = 0
            best_gap = gap
        if best_model_state is not None:
            # torch.save(best_model_state, os.path.join(out_path, 'test_best_model.pth'))
            model.load_state_dict(best_model_state)


    return model






if __name__ == '__main__':
    set_seed(seed=44)
    with open("../configs/default.yaml", 'r') as f:
        config = yaml.safe_load(f)['baseline']
    batch_size = config['batch_size']
    learning_rate = config['learning_rate']
    weight_decay = config['weight_decay']
    num_layers = config['num_layers']
    kernel_size = config['kernel_size']
    start_channels = config['start_channels']
    num_epochs = config['num_epochs']

    source_path = '../datasets/source/train/DC_T197_RP.txt'
    target_path = '../datasets/target/train/HC_T188_RP.txt'
    target_test_path = '../datasets/target/test/HC_T188_RP.txt'
    out_path = 'model'
    os.makedirs(out_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Flexible_DANN(num_layers=num_layers,
                          start_channels=start_channels,
                          kernel_size=kernel_size,
                          cnn_act='leakrelu',
                          num_classes=10,
                          lambda_=0.5).to(device)

    source_loader, target_loader = get_dataloaders(source_path, target_path, batch_size)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1
    )
    criterion_cls = nn.CrossEntropyLoss()
    criterion_domain = nn.CrossEntropyLoss()

    print("[INFO] Starting standard DANN training (no pseudo labels)...")
    model=train_dann(model, source_loader, target_loader,
               optimizer, criterion_cls, criterion_domain,
               device, num_epochs=40, lambda_=0.5,scheduler=scheduler)

    print("[INFO] Evaluating on target test set...")
    test_dataset = PKLDataset(target_test_path)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    pseudo_test_model(model, criterion_cls, test_loader, device)


[INFO] Starting standard DANN training (no pseudo labels)...
[Epoch 1] Total Loss: 301.3290 | Cls: 0.7057 | Dom: 249.8221 | DomAcc: 0.7611
[Epoch 2] Total Loss: 170.2509 | Cls: 0.1536 | Dom: 263.6899 | DomAcc: 0.7250
[Epoch 3] Total Loss: 166.8857 | Cls: 0.0880 | Dom: 289.7780 | DomAcc: 0.6844
[Epoch 4] Total Loss: 173.9117 | Cls: 0.0730 | Dom: 311.3126 | DomAcc: 0.6421
[Epoch 5] Total Loss: 172.0922 | Cls: 0.0737 | Dom: 307.3454 | DomAcc: 0.6677
[Epoch 6] Total Loss: 171.1056 | Cls: 0.0514 | Dom: 316.5218 | DomAcc: 0.6331
[Epoch 7] Total Loss: 164.3007 | Cls: 0.0284 | Dom: 314.4232 | DomAcc: 0.6486
[Epoch 8] Total Loss: 164.9267 | Cls: 0.0443 | Dom: 307.6916 | DomAcc: 0.6545
[Epoch 9] Total Loss: 169.1464 | Cls: 0.0382 | Dom: 319.1860 | DomAcc: 0.6306
[Epoch 10] Total Loss: 165.3398 | Cls: 0.0345 | Dom: 313.4244 | DomAcc: 0.6603
[Epoch 11] Total Loss: 172.6444 | Cls: 0.0324 | Dom: 329.0802 | DomAcc: 0.5953
[Epoch 12] Total Loss: 174.2674 | Cls: 0.0225 | Dom: 337.2850 | DomAcc: 0.5726
