In [1]:
import torch
import torch.nn as nn
from itertools import cycle
from typing import Tuple

# --- imports from your project ---
from PrepareData.DataLoader import (
    source_train_loader, source_val_loader,
    target_train_loader, target_val_loader
)
from PrepareData.SignalSegments import LABEL_TO_IDX
from Backbone.CNN1D import CNN1D
from Backbone.CNN2D import CNN2D
from Untils.untils import domain_loss_from_batch

# ========================
# CONFIG (tune these)
# ========================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FOURIER = "STFT"             # "FFT" or "STFT"

EPOCHS = 20
LR = 1e-4
LAMBDA_DOMAIN = 0.1          # weight for domain loss
BATCH_CLIP_NORM = 5.0
SAVE_PATH = "best_model.pth"

NUM_CLASSES = len(LABEL_TO_IDX)
# ========================
# Model / Optim / Loss
# ========================
model = (CNN2D(NUM_CLASSES) if FOURIER == "STFT" else CNN1D(NUM_CLASSES)).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# ========================
# Evaluation utility
# ========================
def evaluate(loader) -> Tuple[float, float]:
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y, d in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            logits, _ = model(x)
            loss = criterion(logits, y)
            loss_sum += loss.item() * y.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return (loss_sum / total) if total > 0 else 0.0, (correct / total) if total > 0 else 0.0

# ========================
# Training loop
# ========================
best_score = -1.0  # dùng điểm tổng hợp 50/50

for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss = 0.0
    src_correct = src_total = 0
    tgt_correct = tgt_total = 0

    src_iter = iter(source_train_loader)
    tgt_cycle = cycle(target_train_loader)

    for xs, ys, ds in src_iter:
        xt, yt, dt = next(tgt_cycle)

        xs, ys = xs.to(DEVICE), ys.to(DEVICE)
        xt, yt = xt.to(DEVICE), yt.to(DEVICE)
        ds, dt = ds.to(DEVICE), dt.to(DEVICE)

        # combine batches
        x = torch.cat([xs, xt], dim=0)
        domain_ids = torch.cat([ds, dt], dim=0)

        optimizer.zero_grad()
        logits, feats = model(x)
        bs = xs.size(0)

        # classification on source
        loss_cls = criterion(logits[:bs], ys)

        # domain loss on whole batch
        loss_domain = domain_loss_from_batch(feats, domain_ids)

        loss = loss_cls + LAMBDA_DOMAIN * loss_domain
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), BATCH_CLIP_NORM)
        optimizer.step()

        running_loss += loss.item() * (xs.size(0) + xt.size(0))/ (len(source_train_loader) * source_train_loader.batch_size)


        # metrics
        src_correct += (logits[:bs].argmax(1) == ys).sum().item()
        src_total   += ys.size(0)
        tgt_correct += (logits[bs:].argmax(1) == yt).sum().item()
        tgt_total   += yt.size(0)

    scheduler.step()

    src_train_acc = src_correct / src_total
    tgt_train_acc = tgt_correct / tgt_total

    # validation
    _, src_val_acc = evaluate(source_val_loader)
    _, tgt_val_acc = evaluate(target_val_loader)

    # ===============================
    #  SAVE BEST MODEL (50% SRC + 50% TGT)
    # ===============================
    score = 0.5 * src_val_acc + 0.5 * tgt_val_acc


    print(
        f"Epoch {epoch:02d} | "
        f"S_train: {src_train_acc:.4f} T_train: {tgt_train_acc:.4f} | "
        f"S_val: {src_val_acc:.4f} T_val: {tgt_val_acc:.4f} | "
        f"avg_loss: {running_loss:.4f}"
    )

    if score > best_score:
        best_score = score
        torch.save(model.state_dict(), SAVE_PATH)
        print(f"Saved best model (epoch {epoch}) | combined = {score:.4f}")

GLOBAL_LABELS: ['Bpfi', 'Bpfo', 'Misalign', 'Normal', 'Unbalance']
LABEL_TO_IDX: {'Bpfi': 0, 'Bpfo': 1, 'Misalign': 2, 'Normal': 3, 'Unbalance': 4}
SOURCE: 17850 3825 3825

TARGET SPLIT:
Train: 4800
Val: 3150
Test: 13050


  seg = torch.tensor(seg, dtype=torch.float32)


Epoch 01 | S_train: 0.5908 T_train: 0.5654 | S_val: 0.8076 T_val: 0.5581 | avg_loss: 2.1340
Saved best model (epoch 1) | combined = 0.6828
Epoch 02 | S_train: 0.8524 T_train: 0.7290 | S_val: 0.9281 T_val: 0.5997 | avg_loss: 0.8207
Saved best model (epoch 2) | combined = 0.7639
Epoch 03 | S_train: 0.9267 T_train: 0.7557 | S_val: 0.9600 T_val: 0.5917 | avg_loss: 0.4924
Saved best model (epoch 3) | combined = 0.7759
Epoch 04 | S_train: 0.9568 T_train: 0.7578 | S_val: 0.9739 T_val: 0.6425 | avg_loss: 0.3222
Saved best model (epoch 4) | combined = 0.8082
Epoch 05 | S_train: 0.9714 T_train: 0.7642 | S_val: 0.9707 T_val: 0.6514 | avg_loss: 0.2397
Saved best model (epoch 5) | combined = 0.8111
Epoch 06 | S_train: 0.9788 T_train: 0.7662 | S_val: 0.9786 T_val: 0.6152 | avg_loss: 0.1838
Epoch 07 | S_train: 0.9824 T_train: 0.7736 | S_val: 0.9880 T_val: 0.6721 | avg_loss: 0.1539
Saved best model (epoch 7) | combined = 0.8300
Epoch 08 | S_train: 0.9841 T_train: 0.7764 | S_val: 0.9898 T_val: 0.6921 |