In [None]:
# # ===== FAST Fine-tune: EfficientNetB3 / Xception =====
# import os, time, random
# from pathlib import Path
# import numpy as np
# import torch
# import torch.nn as nn
# from torch.utils.data import DataLoader, Subset
# from torchvision import datasets, transforms
# from torchvision.transforms import InterpolationMode

# # ---------- Params ----------
# DATASET_ROOT   = "../Dataset"                 # Dataset/train, val, test
# MODEL_KEY      = "efficientnet_b3"         # {"efficientnet_b3", "xception"}
# INPUT_SIZE     = 224
# BATCH_SIZE     = 64                        # ↑ nếu đủ VRAM
# EPOCHS         = 10
# LR             = 1e-4
# WEIGHT_DECAY   = 1e-4
# NUM_WORKERS    = min(8, os.cpu_count() or 4)
# USE_AMP_WISH   = True
# CHANNELS_LAST  = True
# AUGMENT        = False                     # False = nhanh hơn (giữ Normalize)
# MAX_TRAIN_SAMPLES_PER_EPOCH = 20000        # None = dùng full dữ liệu mỗi epoch
# FREEZE_BACKBONE = True
# WARMUP_EPOCHS   = 2
# DROP_CONNECT   = 0.2
# DROPOUT        = 0.3
# DROP_RATE_XCP  = 0.2
# DROP_PATH_XCP  = 0.1

# # ---------- Setup ----------
# def set_seed(s=42):
#     random.seed(s); np.random.seed(s); torch.manual_seed(s)
#     torch.cuda.manual_seed_all(s); torch.backends.cudnn.benchmark = True
# set_seed(42)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# cc = torch.cuda.get_device_capability() if device.type=="cuda" else (0,0)
# USE_AMP = bool(USE_AMP_WISH and device.type=="cuda" and cc[0] >= 7)  # P100 (6.0) -> False
# DATASET_ALIAS = Path(DATASET_ROOT).name
# if device.type=="cuda":
#     try: torch.set_float32_matmul_precision("high")
#     except: pass
#     print(f"GPU: {torch.cuda.get_device_name(0)} | CC: {cc[0]}.{cc[1]} | AMP: {USE_AMP}")

# # ---------- Data ----------
# mean, std = [0.485,0.456,0.406], [0.229,0.224,0.225]
# eval_tfms = transforms.Compose([
#     transforms.Resize(int(INPUT_SIZE*1.15), interpolation=InterpolationMode.BILINEAR),
#     transforms.CenterCrop(INPUT_SIZE),
#     transforms.ToTensor(),
#     transforms.Normalize(mean, std),
# ])
# train_tfms = eval_tfms if not AUGMENT else transforms.Compose([
#     transforms.Resize(int(INPUT_SIZE*1.15), interpolation=InterpolationMode.BILINEAR),
#     transforms.RandomCrop(INPUT_SIZE),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(mean, std),
# ])

# def make_loader(ds, shuffle, batch_size=BATCH_SIZE):
#     kwargs = dict(batch_size=batch_size, shuffle=shuffle,
#                   num_workers=NUM_WORKERS, pin_memory=(device.type=="cuda"))
#     if NUM_WORKERS > 0: kwargs.update(dict(persistent_workers=True, prefetch_factor=4))
#     return DataLoader(ds, **kwargs)

# # tạo dataset 1 lần
# train_ds = datasets.ImageFolder(os.path.join(DATASET_ROOT, "train"), transform=train_tfms)
# val_ds   = datasets.ImageFolder(os.path.join(DATASET_ROOT, "val"),   transform=eval_tfms)
# test_ds  = datasets.ImageFolder(os.path.join(DATASET_ROOT, "test"),  transform=eval_tfms)
# print(f"[train] {len(train_ds)} | [val] {len(val_ds)} | [test] {len(test_ds)} | classes={train_ds.classes}")
# NUM_CLASSES = len(train_ds.classes)

# # ---------- Model ----------
# from my_efficientnet import EfficientNetB3
# from my_xception import Xception

# def build_model(key, num_classes):
#     k = key.lower()
#     if k == "efficientnet_b3":
#         m = EfficientNetB3(num_classes=num_classes,
#                            drop_connect_rate=DROP_CONNECT,
#                            dropout=DROPOUT,
#                            pretrained=True,
#                            freeze_backbone=FREEZE_BACKBONE)
#         name = "efficientnet_b3"
#     elif k == "xception":
#         m = Xception(num_classes=num_classes,
#                      drop_rate=DROP_RATE_XCP,
#                      drop_path_rate=DROP_PATH_XCP,
#                      pretrained=True,
#                      freeze_backbone=FREEZE_BACKBONE)
#         name = "xception"
#     else:
#         raise ValueError(k)
#     return m, name

# model, model_name = build_model(MODEL_KEY, NUM_CLASSES)
# model = model.to(device, memory_format=torch.channels_last if CHANNELS_LAST else torch.contiguous_format)

# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

# # ---------- Train/Eval ----------
# def run_epoch(loader, train_mode=True):
#     model.train(train_mode)
#     tot, correct, loss_sum = 0, 0, 0.0
#     for x,y in loader:
#         x = x.to(device, non_blocking=True)
#         if CHANNELS_LAST: x = x.to(memory_format=torch.channels_last)
#         y = y.to(device, non_blocking=True)
#         with torch.set_grad_enabled(train_mode):
#             with torch.amp.autocast('cuda', enabled=scaler.is_enabled()):
#                 logits = model(x); loss = criterion(logits, y)
#         if train_mode:
#             optimizer.zero_grad(set_to_none=True)
#             if scaler.is_enabled():
#                 scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
#             else:
#                 loss.backward(); optimizer.step()
#         loss_sum += loss.item()*x.size(0)
#         correct += (logits.argmax(1)==y).sum().item()
#         tot += x.size(0)
#     return loss_sum/tot, correct/tot

# # ---------- Loop ----------
# best_val = -1.0
# ckpt_path = f"{model_name}_{DATASET_ALIAS}_best.pth"
# val_loader = make_loader(val_ds, shuffle=False)
# test_loader = make_loader(test_ds, shuffle=False)

# for ep in range(1, EPOCHS+1):
#     if FREEZE_BACKBONE and ep == WARMUP_EPOCHS+1 and hasattr(model, "unfreeze"):
#         model.unfreeze()

#     # chọn ngẫu nhiên N mẫu/epoch (nếu set)
#     if MAX_TRAIN_SAMPLES_PER_EPOCH and MAX_TRAIN_SAMPLES_PER_EPOCH < len(train_ds):
#         idx = np.random.permutation(len(train_ds))[:MAX_TRAIN_SAMPLES_PER_EPOCH]
#         train_sub = Subset(train_ds, idx)
#         train_loader = make_loader(train_sub, shuffle=True)
#         epoch_info = f"{len(idx)}/{len(train_ds)} imgs"
#     else:
#         train_loader = make_loader(train_ds, shuffle=True)
#         epoch_info = f"{len(train_ds)} imgs"

#     t0 = time.time()
#     tr_loss, tr_acc = run_epoch(train_loader, True)
#     val_loss, val_acc = run_epoch(val_loader, False)
#     scheduler.step()

#     if val_acc > best_val:
#         best_val = val_acc
#         torch.save({"model": model.state_dict(),
#                     "epoch": ep,
#                     "val_acc": best_val,
#                     "model_name": model_name,
#                     "dataset_alias": DATASET_ALIAS,
#                     "input_size": INPUT_SIZE}, ckpt_path)

#     print(f"Epoch {ep:02d} [{epoch_info}] | "
#           f"train {tr_loss:.4f}/{tr_acc:.4f} | "
#           f"val {val_loss:.4f}/{val_acc:.4f} | "
#           f"best {best_val:.4f} | time {time.time()-t0:.1f}s")

# print(f"\nSaved: {ckpt_path}")

# # ---------- Test ----------
# sd = torch.load(ckpt_path, map_location=device)["model"]
# model.load_state_dict(sd)
# test_loss, test_acc = run_epoch(test_loader, False)
# print(f"TEST: loss {test_loss:.4f} | acc {test_acc:.4f}")


[train] classes -> ['fake', 'real'] {'fake': 0, 'real': 1}
[val] classes -> ['fake', 'real'] {'fake': 0, 'real': 1}
[test] classes -> ['fake', 'real'] {'fake': 0, 'real': 1}
[EfficientNetB3] Đang tạo model 'tf_efficientnet_b3' với pretrained=True
[EfficientNetB3] ✅ Tạo thành công!
[EfficientNetB3] ✅ Model validation passed


KeyboardInterrupt: 

In [None]:
# ===== FAST Fine-tune: EfficientNetB3 / Xception với Augmentation khác nhau =====
import os, time, random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.transforms import InterpolationMode
from sklearn.metrics import precision_score, recall_score, f1_score

# ---------- Per-model config (only the augmentation policy differs) ----------
# Both models share the same optimizer hyper-parameters; only "augment" varies.
CONFIG = {
    "xception": dict(
        augment="strong",   # crop + flip + RandAugment + RandomErasing
        lr=1e-4,
        weight_decay=1e-4,
    ),
    "efficientnet_b3": dict(
        augment="basic",    # random crop + horizontal flip only
        lr=1e-4,
        weight_decay=1e-4,
    ),
}

# ---------- Run parameters ----------
DATASET_ROOT = "../Dataset"          # must contain train/, val/ and test/ sub-folders
MODEL_KEY = "xception"               # one of CONFIG's keys: "efficientnet_b3" or "xception"
INPUT_SIZE = 224                     # side length of the square crop fed to the network
BATCH_SIZE = 64
EPOCHS = 10
NUM_WORKERS = min(8, os.cpu_count() or 4)
USE_AMP_WISH = True                  # AMP is only actually enabled on CUDA with compute capability >= 7
CHANNELS_LAST = True
MAX_TRAIN_SAMPLES_PER_EPOCH = 20000  # falsy (None/0) = use the full train set every epoch
FREEZE_BACKBONE = True
WARMUP_EPOCHS = 2                    # epochs before model.unfreeze() is called (if the model has it)
DROP_CONNECT = 0.2                   # EfficientNetB3 only
DROPOUT = 0.3                        # EfficientNetB3 only
DROP_RATE_XCP = 0.2                  # Xception only
DROP_PATH_XCP = 0.1                  # Xception only

# Resolve the active model's hyper-parameters.
cfg = CONFIG[MODEL_KEY.lower()]
LR, WEIGHT_DECAY = cfg["lr"], cfg["weight_decay"]

# ---------- Setup ----------
def set_seed(s=42):
    """Seed Python's, NumPy's and torch's (CPU + every CUDA device) RNGs with ``s``.

    Note: this also switches ``torch.backends.cudnn.benchmark`` on, which lets
    cuDNN auto-tune convolution algorithms for speed and may therefore reduce
    run-to-run reproducibility.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(s)
    torch.backends.cudnn.benchmark = True


set_seed(42)

# Pick the compute device and derive the AMP / precision settings from it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cc = torch.cuda.get_device_capability() if device.type == "cuda" else (0, 0)
# Gate AMP on compute capability >= 7; e.g. a P100 (CC 6.0) runs in plain fp32.
USE_AMP = bool(USE_AMP_WISH and device.type == "cuda" and cc[0] >= 7)
# Last path component of the dataset root (e.g. "../Dataset" -> "Dataset"),
# used to name the checkpoint file.
DATASET_ALIAS = Path(DATASET_ROOT).name
if device.type == "cuda":
    try:
        # Let float32 matmuls use faster reduced-precision internals where available.
        torch.set_float32_matmul_precision("high")
    except (AttributeError, RuntimeError):
        # Best-effort speed-up only (e.g. the API is missing on old torch builds);
        # unlike a bare `except:`, this no longer swallows KeyboardInterrupt.
        pass
    print(f"GPU: {torch.cuda.get_device_name(0)} | CC: {cc[0]}.{cc[1]} | AMP: {USE_AMP}")

print(f"\n[Model: {MODEL_KEY}]")
print(f"Config: {cfg}")

# ---------- Data với Augmentation theo config ----------
# ImageNet channel statistics used by both the eval and train pipelines.
# NOTE(review): some Xception weights (e.g. timm's legacy xception) were trained
# with mean = std = 0.5 — confirm what my_xception's pretrained weights expect.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Deterministic evaluation pipeline: resize the shorter side to ~1.15x the crop
# size, take a centred INPUT_SIZE x INPUT_SIZE crop, then tensorize and normalize.
eval_tfms = transforms.Compose(
    [
        transforms.Resize(
            int(INPUT_SIZE * 1.15),
            interpolation=InterpolationMode.BILINEAR,
        ),
        transforms.CenterCrop(INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

def build_train_transform(augment_type):
    """Build the training-time transform pipeline for the given augmentation policy.

    Every pipeline first resizes the shorter side to ``int(INPUT_SIZE * 1.15)``
    and ends with ``ToTensor`` + ``Normalize(mean, std)``. Supported policies:

    * ``"basic"``  - random ``INPUT_SIZE`` crop + random horizontal flip.
    * ``"strong"`` - ``"basic"`` plus RandAugment(num_ops=2, magnitude=9) on the
      PIL image and RandomErasing(p=0.25) on the normalized tensor.
    * ``"none"``   - no randomness: centre crop only (the same steps as the
      evaluation pipeline), for fast runs without augmentation.

    Args:
        augment_type: policy name, matched case-insensitively.

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a normalized tensor.

    Raises:
        ValueError: if ``augment_type`` is not one of the policies above, so a
            typo in the config fails loudly instead of silently using "basic".
    """
    policy = str(augment_type).lower()
    resize = transforms.Resize(int(INPUT_SIZE * 1.15), interpolation=InterpolationMode.BILINEAR)
    to_normalized_tensor = [transforms.ToTensor(), transforms.Normalize(mean, std)]

    if policy == "none":
        return transforms.Compose([resize, transforms.CenterCrop(INPUT_SIZE), *to_normalized_tensor])

    if policy == "basic":
        return transforms.Compose([
            resize,
            transforms.RandomCrop(INPUT_SIZE),
            transforms.RandomHorizontalFlip(),
            *to_normalized_tensor,
        ])

    if policy == "strong":
        return transforms.Compose([
            resize,
            transforms.RandomCrop(INPUT_SIZE),
            # RandAugment needs a PIL image (or uint8 tensor), so it runs before ToTensor.
            transforms.RandAugment(num_ops=2, magnitude=9),
            transforms.RandomHorizontalFlip(),
            *to_normalized_tensor,
            # RandomErasing only works on tensors, so it runs after Normalize.
            transforms.RandomErasing(p=0.25, scale=(0.02, 0.2)),
        ])

    raise ValueError(
        f"Unknown augment type {augment_type!r}; expected 'basic', 'strong' or 'none'"
    )

# Training pipeline chosen by the active model's "augment" policy.
train_tfms = build_train_transform(augment_type=cfg["augment"])

def make_loader(ds, shuffle, batch_size=BATCH_SIZE):
    """Wrap ``ds`` in a DataLoader using the global worker and pinning settings.

    Pinned host memory is requested only when running on CUDA. The
    worker-only options (``persistent_workers``, ``prefetch_factor``) are
    passed only when ``NUM_WORKERS > 0``, because DataLoader rejects them for
    single-process loading.
    """
    worker_opts = (
        {"persistent_workers": True, "prefetch_factor": 4} if NUM_WORKERS > 0 else {}
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=device.type == "cuda",
        **worker_opts,
    )

# Build the three splits. ImageFolder assigns labels from each split's own
# sorted sub-folder names, so every split must have exactly the same class
# folders; otherwise val/test labels would silently disagree with train's.
train_ds = datasets.ImageFolder(os.path.join(DATASET_ROOT, "train"), transform=train_tfms)
val_ds   = datasets.ImageFolder(os.path.join(DATASET_ROOT, "val"),   transform=eval_tfms)
test_ds  = datasets.ImageFolder(os.path.join(DATASET_ROOT, "test"),  transform=eval_tfms)
for _split_name, _split_ds in (("val", val_ds), ("test", test_ds)):
    if _split_ds.class_to_idx != train_ds.class_to_idx:
        raise ValueError(
            f"[{_split_name}] class mapping {_split_ds.class_to_idx} does not match "
            f"[train] class mapping {train_ds.class_to_idx}"
        )
print(f"[train] {len(train_ds)} | [val] {len(val_ds)} | [test] {len(test_ds)} | classes={train_ds.classes}")
NUM_CLASSES = len(train_ds.classes)

# ---------- Model ----------
from my_efficientnet import EfficientNetB3
from my_xception import Xception

def build_model(key, num_classes):
    """Instantiate the backbone named by ``key`` with a ``num_classes``-way head.

    Both supported models are created with ``pretrained=True`` and the global
    ``FREEZE_BACKBONE`` flag. Regularisation comes from the model-specific
    globals: ``DROP_CONNECT``/``DROPOUT`` for EfficientNetB3 and
    ``DROP_RATE_XCP``/``DROP_PATH_XCP`` for Xception.

    Returns:
        ``(model, canonical_name)``, where ``canonical_name`` is the lower-cased key.

    Raises:
        ValueError: for an unsupported key (the message is the lower-cased key).
    """
    name = key.lower()

    if name == "xception":
        net = Xception(num_classes=num_classes,
                       drop_rate=DROP_RATE_XCP,
                       drop_path_rate=DROP_PATH_XCP,
                       pretrained=True,
                       freeze_backbone=FREEZE_BACKBONE)
        return net, name

    if name == "efficientnet_b3":
        net = EfficientNetB3(num_classes=num_classes,
                             drop_connect_rate=DROP_CONNECT,
                             dropout=DROPOUT,
                             pretrained=True,
                             freeze_backbone=FREEZE_BACKBONE)
        return net, name

    raise ValueError(name)

model, model_name = build_model(MODEL_KEY, NUM_CLASSES)
_memory_format = torch.channels_last if CHANNELS_LAST else torch.contiguous_format
model = model.to(device, memory_format=_memory_format)

# ---------- Optimizer & scheduler (identical for both models) ----------
# Every parameter is registered with AdamW up front, including currently frozen
# ones. Parameters whose .grad is None are skipped by the optimizer, so they only
# start updating once unfrozen.
# (Presumably freezing sets requires_grad=False — confirm in my_efficientnet / my_xception.)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Cosine decay over the whole run; the training loop steps it once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# Loss scaling becomes a pass-through when AMP is disabled.
scaler = torch.amp.GradScaler("cuda", enabled=USE_AMP)

print(f"Optimizer: AdamW | Scheduler: CosineAnnealingLR | Augment: {cfg['augment']}")

# ---------- Train/Eval ----------
def run_epoch(loader, train_mode=True, compute_metrics=False):
    """Run one pass over ``loader``, training or evaluating the global ``model``.

    Relies on the module-level ``model``, ``criterion``, ``optimizer``,
    ``scaler``, ``device`` and ``CHANNELS_LAST``. In train mode every batch does
    forward, backward and an optimizer step (through ``scaler`` when AMP is
    enabled). In eval mode gradients are disabled and weights are untouched.

    Args:
        loader: iterable of ``(images, labels)`` batches.
        train_mode: ``True`` to update weights, ``False`` to only evaluate.
        compute_metrics: also return sklearn weighted precision/recall/F1.

    Returns:
        ``(mean_loss, accuracy)``, or ``(mean_loss, accuracy, precision,
        recall, f1)`` when ``compute_metrics`` is true. If the loader yields no
        samples, every value is ``nan`` (instead of raising ZeroDivisionError).
    """
    model.train(train_mode)
    seen, n_correct, loss_total = 0, 0, 0.0
    y_true, y_pred = [], []

    for x, y in loader:
        x = x.to(device, non_blocking=True)
        if CHANNELS_LAST:
            x = x.to(memory_format=torch.channels_last)
        y = y.to(device, non_blocking=True)

        with torch.set_grad_enabled(train_mode):
            # Autocast for the real device type; it is inactive unless AMP is on.
            with torch.amp.autocast(device.type, enabled=scaler.is_enabled()):
                logits = model(x)
                loss = criterion(logits, y)

        if train_mode:
            optimizer.zero_grad(set_to_none=True)
            if scaler.is_enabled():
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

        preds = logits.argmax(1)
        batch = x.size(0)
        loss_total += loss.item() * batch  # criterion returns the batch mean; re-weight by size
        n_correct += (preds == y).sum().item()
        seen += batch

        if compute_metrics:
            y_true.extend(y.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    if seen == 0:
        nan = float("nan")
        return (nan,) * 5 if compute_metrics else (nan, nan)

    mean_loss, acc = loss_total / seen, n_correct / seen
    if not compute_metrics:
        return mean_loss, acc

    # zero_division=0 yields the same value sklearn's default uses for a class
    # that is never predicted, but without emitting UndefinedMetricWarning.
    prec = precision_score(y_true, y_pred, average="weighted", zero_division=0)
    rec = recall_score(y_true, y_pred, average="weighted", zero_division=0)
    f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0)
    return mean_loss, acc, prec, rec, f1

# ---------- Loop ----------
best_val = -1.0
ckpt_path = f"{model_name}_{DATASET_ALIAS}_best.pth"
val_loader = make_loader(val_ds, shuffle=False)
test_loader = make_loader(test_ds, shuffle=False)

for ep in range(1, EPOCHS + 1):
    # After WARMUP_EPOCHS epochs with a frozen backbone, let the model unfreeze
    # (only if it exposes an unfreeze() method).
    if FREEZE_BACKBONE and ep == WARMUP_EPOCHS + 1 and hasattr(model, "unfreeze"):
        model.unfreeze()

    # Optionally train on a fresh random subset of N images each epoch.
    # NOTE(review): this builds a new DataLoader (and a new worker pool) every
    # epoch, so make_loader's persistent_workers buys nothing here.
    if MAX_TRAIN_SAMPLES_PER_EPOCH and MAX_TRAIN_SAMPLES_PER_EPOCH < len(train_ds):
        idx = np.random.permutation(len(train_ds))[:MAX_TRAIN_SAMPLES_PER_EPOCH]
        train_sub = Subset(train_ds, idx)
        train_loader = make_loader(train_sub, shuffle=True)
        epoch_info = f"{len(idx)}/{len(train_ds)} imgs"
    else:
        train_loader = make_loader(train_ds, shuffle=True)
        epoch_info = f"{len(train_ds)} imgs"

    t0 = time.perf_counter()
    # Capture the LR actually used this epoch *before* the scheduler advances it.
    # Reading it after scheduler.step() logs the next epoch's LR (0 on the last epoch).
    curr_lr = optimizer.param_groups[0]["lr"]
    tr_loss, tr_acc = run_epoch(train_loader, True)
    val_loss, val_acc = run_epoch(val_loader, False)
    scheduler.step()

    # Keep only the checkpoint with the best validation accuracy so far.
    if val_acc > best_val:
        best_val = val_acc
        torch.save({"model": model.state_dict(),
                    "epoch": ep,
                    "val_acc": best_val,
                    "model_name": model_name,
                    "dataset_alias": DATASET_ALIAS,
                    "input_size": INPUT_SIZE,
                    "config": cfg}, ckpt_path)

    print(f"Epoch {ep:02d} [{epoch_info}] | "
          f"train {tr_loss:.4f}/{tr_acc:.4f} | "
          f"val {val_loss:.4f}/{val_acc:.4f} | "
          f"best {best_val:.4f} | lr {curr_lr:.2e} | time {time.perf_counter()-t0:.1f}s")

print(f"\nSaved: {ckpt_path}")

# ---------- Test with the saved best checkpoint ----------
print("\n📂 Loading best checkpoint for testing...")
# The checkpoint holds only tensors and plain Python values (ints, floats,
# strings, dicts), so the restricted weights-only unpickler is sufficient.
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint["model"])
print("✅ Model loaded!")

test_loss, test_acc, prec, rec, f1 = run_epoch(test_loader, False, compute_metrics=True)
print(f"TEST: loss {test_loss:.4f} | acc {test_acc:.4f} | prec {prec:.4f} | rec {rec:.4f} | f1 {f1:.4f}")

# Store all test results next to the best weights. The sklearn metrics are numpy
# scalars; cast them to float so the file remains loadable with weights_only=True.
checkpoint.update(
    test_acc=float(test_acc),
    test_loss=float(test_loss),
    test_precision=float(prec),
    test_recall=float(rec),
    test_f1=float(f1),
)
torch.save(checkpoint, ckpt_path)
print(f"✅ Updated {ckpt_path} with test_acc: {test_acc:.4f}")
