In [4]:
# ===== Train EfficientNetB3 / Xception (pretrained) =====
import os, time, random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import InterpolationMode

# ======== Hyper-parameters (edit here) ========
# --- Data ---
DATASET_ROOT = "Dataset"       # expects Dataset/train, Dataset/val, Dataset/test
INPUT_SIZE = 224               # final crop size fed to the network
BATCH_SIZE = 32
NUM_WORKERS = 4                # DataLoader worker processes

# --- Model ---
MODEL_KEY = "efficientnet_b3"  # one of {"efficientnet_b3", "xception"}
DROP_CONNECT = 0.2             # EfficientNetB3 only
DROPOUT = 0.3                  # EfficientNetB3 only
DROP_RATE_XCP = 0.2            # Xception only
DROP_PATH_XCP = 0.1            # Xception only

# --- Optimisation ---
EPOCHS = 10
LR = 1e-4
WEIGHT_DECAY = 1e-4
USE_AMP = True                 # mixed precision; only meant to be active on CUDA

# --- Backbone freezing ---
FREEZE_BACKBONE = False        # True: train only the head during the first epochs
WARMUP_EPOCHS = 0              # with FREEZE_BACKBONE: number of frozen epochs before unfreezing

# -------- Setup --------
def set_seed(s=42, *, deterministic=False):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU + all CUDA devices).

    Args:
        s: Seed value.
        deterministic: When False (default, the historical behaviour) cuDNN
            autotuning is switched on (``cudnn.benchmark = True``). This is faster,
            but cuDNN may pick different kernels from run to run, so GPU results
            are not bit-reproducible despite the seeding. When True, autotuning
            is disabled and cuDNN is asked for deterministic kernels instead.
    """
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True


set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Dataset directory name, used to tag checkpoint files. Take the name of the
# absolute path so roots like "." still give a non-empty alias
# (Path(".").name is ""); for ordinary roots such as "Dataset" or
# "data/Dataset/" the result is the same last path component as before.
DATASET_ALIAS = Path(os.path.abspath(DATASET_ROOT)).name

# -------- Data --------
# ImageNet channel statistics used to normalise the input tensors.
# NOTE(review): these are the standard ImageNet mean/std; some pretrained
# Xception weights were trained with mean = std = 0.5 — confirm which
# normalisation the weights loaded by my_xception expect.
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Shorter side is resized to ~15% above the crop size before cropping.
_RESIZE_TO = int(INPUT_SIZE * 1.15)


def _image_pipeline(*crop_and_augment):
    """Resize, apply the given crop/augmentation steps, then tensorise and normalise."""
    return transforms.Compose([
        transforms.Resize(_RESIZE_TO, interpolation=InterpolationMode.BILINEAR),
        *crop_and_augment,
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


# Training: random crop + horizontal flip as augmentation.
train_tfms = _image_pipeline(transforms.RandomCrop(INPUT_SIZE),
                             transforms.RandomHorizontalFlip())
# Evaluation (val/test): deterministic centre crop.
eval_tfms = _image_pipeline(transforms.CenterCrop(INPUT_SIZE))

def make_loader(split, tfm, shuffle):
    """Build the ImageFolder dataset under DATASET_ROOT/<split> and its DataLoader.

    Prints the discovered class names and label mapping, then returns
    ``(dataset, loader)``. Memory pinning is enabled only on CUDA, and worker
    processes are kept alive across epochs whenever NUM_WORKERS > 0.
    """
    folder = os.path.join(DATASET_ROOT, split)
    dataset = datasets.ImageFolder(folder, transform=tfm)
    print(f"[{split}] classes ->", dataset.classes, dataset.class_to_idx)

    loader_kwargs = {
        "batch_size": BATCH_SIZE,
        "shuffle": shuffle,
        "num_workers": NUM_WORKERS,
        "pin_memory": device.type == "cuda",
    }
    # These options only apply with worker processes; DataLoader rejects them
    # when num_workers == 0.
    if NUM_WORKERS > 0:
        loader_kwargs["persistent_workers"] = True
        loader_kwargs["prefetch_factor"] = 2
    return dataset, DataLoader(dataset, **loader_kwargs)

train_ds, train_loader = make_loader("train", train_tfms, True)
val_ds,   val_loader   = make_loader("val",   eval_tfms,   False)
test_ds,  test_loader  = make_loader("test",  eval_tfms,   False)
NUM_CLASSES = len(train_ds.classes)

# ImageFolder assigns label indices independently per split, from that split's
# own sorted sub-directory names. If val/test is missing a class folder (or has
# an extra one), its indices silently shift and the reported accuracy becomes
# meaningless, so fail fast instead.
for _split_name, _split_ds in (("val", val_ds), ("test", test_ds)):
    if _split_ds.class_to_idx != train_ds.class_to_idx:
        raise ValueError(
            f"Class mapping of '{_split_name}' {_split_ds.class_to_idx} "
            f"does not match 'train' {train_ds.class_to_idx}"
        )

# -------- Model --------
from my_efficientnet import EfficientNetB3
from my_xception import Xception

def build_model(key, num_classes):
    """Instantiate a pretrained backbone with a ``num_classes``-way head.

    Regularisation and freezing settings come from the module-level
    hyper-parameters DROP_CONNECT, DROPOUT, DROP_RATE_XCP, DROP_PATH_XCP and
    FREEZE_BACKBONE.

    Args:
        key: Model identifier (case-insensitive): "efficientnet_b3" or "xception".
        num_classes: Number of output classes.

    Returns:
        ``(model, name)``, where ``name`` is the canonical lowercase key
        (used to tag checkpoint files).

    Raises:
        ValueError: if ``key`` is not a supported model.
    """
    name = key.lower()
    if name == "efficientnet_b3":
        model = EfficientNetB3(num_classes=num_classes,
                               drop_connect_rate=DROP_CONNECT,
                               dropout=DROPOUT,
                               pretrained=True,
                               freeze_backbone=FREEZE_BACKBONE)
    elif name == "xception":
        model = Xception(num_classes=num_classes,
                         drop_rate=DROP_RATE_XCP,
                         drop_path_rate=DROP_PATH_XCP,
                         pretrained=True,
                         freeze_backbone=FREEZE_BACKBONE)
    else:
        raise ValueError(
            f"Unknown model key {name!r}; expected one of "
            "'efficientnet_b3', 'xception'"
        )
    return model, name

model, model_name = build_model(MODEL_KEY, NUM_CLASSES)
model = model.to(device)

# Plain cross-entropy on the class logits.
criterion = nn.CrossEntropyLoss()

# AdamW over all parameters, with a cosine LR decay stepped once per epoch.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR,
                              weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=EPOCHS)

# Mixed precision only when requested AND running on CUDA; the training code
# consults scaler.is_enabled() to decide whether to autocast.
_amp_active = USE_AMP and device.type == "cuda"
scaler = torch.amp.GradScaler('cuda', enabled=_amp_active)

# -------- Train/Eval --------
def run_epoch(loader, train_mode=True):
    model.train(train_mode)
    total, correct, loss_sum = 0, 0, 0.0
    for x,y in loader:
        x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
        with torch.set_grad_enabled(train_mode):
            with torch.amp.autocast('cuda', enabled=scaler.is_enabled()):
                logits = model(x); loss = criterion(logits, y)
        if train_mode:
            optimizer.zero_grad(set_to_none=True)
            if scaler.is_enabled():
                scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
            else:
                loss.backward(); optimizer.step()
        loss_sum += loss.item()*x.size(0)
        correct += (logits.argmax(1)==y).sum().item()
        total += x.size(0)
    return loss_sum/total, correct/total

# -------- Loop --------
best_val = -1.0
ckpt_path = f"{model_name}_{DATASET_ALIAS}_best.pth"
saved_any = False

for ep in range(1, EPOCHS+1):
    # Backbone warm-up: with FREEZE_BACKBONE the backbone is unfrozen at the
    # start of epoch WARMUP_EPOCHS + 1 (WARMUP_EPOCHS = 0 unfreezes before
    # epoch 1).
    if FREEZE_BACKBONE and ep == WARMUP_EPOCHS+1:
        if hasattr(model, "unfreeze"):
            model.unfreeze()
        else:
            # Don't silently train with a frozen backbone for the whole run.
            print(f"WARNING: {type(model).__name__} has no unfreeze(); "
                  "backbone stays frozen for the whole run")
    t0 = time.time()
    tr_loss, tr_acc = run_epoch(train_loader, True)
    val_loss, val_acc = run_epoch(val_loader, False)
    scheduler.step()  # one LR-schedule step per epoch
    if val_acc > best_val:
        best_val = val_acc
        torch.save({"model": model.state_dict(),
                    "epoch": ep,
                    "val_acc": best_val,
                    "model_name": model_name,
                    "dataset_alias": DATASET_ALIAS,
                    "input_size": INPUT_SIZE,
                    # label mapping needed to interpret the logits later
                    "class_to_idx": train_ds.class_to_idx}, ckpt_path)
        saved_any = True
    print(f"Epoch {ep:02d} | train {tr_loss:.4f}/{tr_acc:.4f} | "
          f"val {val_loss:.4f}/{val_acc:.4f} | best {best_val:.4f} | "
          f"time {time.time()-t0:.1f}s")

if saved_any:
    print(f"\nSaved: {ckpt_path}")
else:
    print(f"\nNo checkpoint saved (EPOCHS={EPOCHS}).")

# -------- Test --------
# Evaluate the best-on-validation weights when the checkpoint exists;
# otherwise fall back to the weights currently in memory, and say so.
if os.path.isfile(ckpt_path):
    # The checkpoint holds only tensors and plain Python values, so load it
    # with weights_only=True and refuse arbitrary pickled objects.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
    model.load_state_dict(ckpt["model"])
    print(f"Loaded checkpoint {ckpt_path} "
          f"(epoch {ckpt.get('epoch')}, val_acc {ckpt.get('val_acc')})")
else:
    print(f"WARNING: {ckpt_path} not found; testing current in-memory weights")
test_loss, test_acc = run_epoch(test_loader, False)
print(f"TEST: loss {test_loss:.4f} | acc {test_acc:.4f}")


[train] classes -> ['fake', 'real'] {'fake': 0, 'real': 1}
[val] classes -> ['fake', 'real'] {'fake': 0, 'real': 1}
[test] classes -> ['fake', 'real'] {'fake': 0, 'real': 1}
[EfficientNetB3] Đang tạo model 'tf_efficientnet_b3' với pretrained=True
[EfficientNetB3] ✅ Tạo thành công!
[EfficientNetB3] ✅ Model validation passed


KeyboardInterrupt: 