In [1]:
# rock_hier_train_v2.py  ── ConvNeXtV2-Large + 双头层次分类 (改进版)
# ==============================================================
# 1. 标准库 & 基础设置
# --------------------------------------------------------------
import os, math, time, random, warnings, pathlib
import platform, psutil, csv
from collections import Counter
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch_ema import ExponentialMovingAverage
import timm

import csv
import glob
import math
import os
import pathlib
import random
import time
import warnings
from collections import defaultdict
import platform
import psutil
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.cuda.amp import GradScaler, autocast
from sklearn.metrics import confusion_matrix
from torch_ema import ExponentialMovingAverage
import matplotlib.pyplot as plt

# 数据处理
import numpy as np
import pandas as pd

# 第三方SOTA模型/工具包
import timm  # 顶级SOTA模型库，实际好像服务器翻不了墙

# PyTorch及生态
import torch
import torch.nn as nn
import torchvision.transforms as T

# 可视化
from PIL import Image
from sklearn.metrics import confusion_matrix

# 混合精度训练提速省显存grad防止半精度数值不稳定，滑动平均能提供泛化能力
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset
from torch_ema import ExponentialMovingAverage

# 进度条
from tqdm import tqdm
warnings.filterwarnings("ignore")
SEED = 114514
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic = True, False

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {DEVICE} | GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")



Device: cuda | GPU: NVIDIA L20


In [2]:
# ==============================================================
# 2. 数据路径 & 超参数
# --------------------------------------------------------------
DATA_ROOT   = "/shareddata/project/dataset"
TRAIN_CSV   = f"{DATA_ROOT}/train_val/train_labels.csv"
VAL_CSV     = f"{DATA_ROOT}/train_val/val_labels.csv"
IMG_DIR_T   = f"{DATA_ROOT}/train_val/train"
IMG_DIR_V   = f"{DATA_ROOT}/train_val/val"
TEST_IMGDIR = f"{DATA_ROOT}/test/test_images"
TEST_CSV    = f"{DATA_ROOT}/test/test_ids.csv"

NUM_PARENT  = 3
BATCH_SIZE  = 64 #占用显存43977MiB /  46068MiB
NUM_EPOCHS  = 60 #早停无所谓        
MIXUP_CUTMIX_STOP = int(NUM_EPOCHS * 0.8)   # 后 20% epoch 关闭 MixUp/CutMix
BASE_LR     = 2e-4 
WARM_EPOCHS = 5
WEIGHT_DECAY= 1e-4
MAX_GRAD_NORM = 1.0
EMA_DECAY  = 0.9999



In [3]:
# ==============================================================
# 3. 数据增强 & Dataset
# --------------------------------------------------------------
import torchvision.transforms as T
MEAN = [0.46798, 0.45764, 0.44035]
STD  = [0.18461, 0.18712, 0.19482]

tf_train = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.RandAugment(num_ops=2, magnitude=8),  # 略降幅度
    T.ToTensor(),
    T.Normalize(MEAN, STD),
])
tf_eval = T.Compose([
    T.Resize(256), T.CenterCrop(224),
    T.ToTensor(),  T.Normalize(MEAN, STD)
])

class RockDataset(Dataset):
    def __init__(self, csv_path, img_dir, transform):
        df = pd.read_csv(csv_path)
        self.paths = df["id"].values
        self.p_labels = df["label"].values.astype(np.int64)
        self.c_labels = pd.factorize(df["sublabel"])[0].astype(np.int64)
        self.transform = transform
        self.img_dir = img_dir
    def __len__(self): return len(self.paths)
    def __getitem__(self, idx):
        img = Image.open(os.path.join(self.img_dir, self.paths[idx])).convert("RGB")
        return self.transform(img), self.p_labels[idx], self.c_labels[idx]

train_ds, val_ds = RockDataset(TRAIN_CSV, IMG_DIR_T, tf_train), RockDataset(VAL_CSV, IMG_DIR_V, tf_eval)
NUM_CHILD = int(train_ds.c_labels.max()) + 1
print(f"Dataset: train={len(train_ds)}, val={len(val_ds)}, child={NUM_CHILD}")

train_dl = DataLoader(train_ds, BATCH_SIZE, shuffle=True,  num_workers=8, pin_memory=True)
val_dl   = DataLoader(val_ds,   BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True)



Dataset: train=102213, val=15000, child=19160


In [4]:
# ==============================================================
# 4. 模型定义 (ConvNeXtV2-Large 主干 + 双头)
# --------------------------------------------------------------
CKPT_PATH  = "convnextv2_large.fcmae_ft_in22k_in1k_384.bin"
MODEL_NAME = "convnextv2_large.fcmae_ft_in22k_in1k_384"

state_dict = torch.load(CKPT_PATH, map_location="cpu")
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("head.")}
backbone   = timm.create_model(MODEL_NAME, pretrained=False, num_classes=0)
backbone.load_state_dict(state_dict, strict=False)
feat_dim   = backbone.num_features

head_parent = nn.Linear(feat_dim, NUM_PARENT)
head_child  = nn.Linear(feat_dim, NUM_CHILD)

backbone, head_parent, head_child = backbone.to(DEVICE), head_parent.to(DEVICE), head_child.to(DEVICE)

# 封装 forward
def forward_features(x):
    return backbone(x)



In [5]:
# ==============================================================
# 5. 损失函数 (类别加权 CE) & 优化器 & 调度器
# --------------------------------------------------------------
# 5.1 类别权重 (父类)
cnt_parent = Counter(train_ds.p_labels.tolist())
tot_samples = len(train_ds)
w_parent = torch.tensor([tot_samples / cnt_parent[i] for i in range(NUM_PARENT)], dtype=torch.float).to(DEVICE)
criterion_parent = nn.CrossEntropyLoss(weight=w_parent, label_smoothing=0.1)

# 5.2 子类可选加权 (实际没使用)
criterion_child  = nn.CrossEntropyLoss(label_smoothing=0.1)

# 5.3 Optimizer (分组 lr)
optimizer = torch.optim.AdamW([
    {"params": backbone.parameters(),   "lr": BASE_LR * 0.2},
    {"params": head_parent.parameters(),"lr": BASE_LR},
    {"params": head_child.parameters(), "lr": BASE_LR}
], weight_decay=WEIGHT_DECAY)

# 5.4 Scheduler: Linear Warmup -> Cosine
def lr_lambda(cur_epoch):
    if cur_epoch < WARM_EPOCHS:
        return float(cur_epoch + 1) / WARM_EPOCHS
    return 0.5 * (1 + math.cos(math.pi * (cur_epoch - WARM_EPOCHS) / (NUM_EPOCHS - WARM_EPOCHS)))
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

scaler = GradScaler()
ema    = ExponentialMovingAverage([*backbone.parameters(), *head_parent.parameters(), *head_child.parameters()], decay=EMA_DECAY)



In [6]:
# ==============================================================
# 6. MixUp / CutMix 工具
# --------------------------------------------------------------
def rand_bbox(size, lam):
    H, W = size[2], size[3]
    cut_rat = math.sqrt(1. - lam)
    cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)
    cx, cy = np.random.randint(W), np.random.randint(H)
    x1, y1 = np.clip(cx - cut_w // 2, 0, W), np.clip(cy - cut_h // 2, 0, H)
    x2, y2 = np.clip(cx + cut_w // 2, 0, W), np.clip(cy + cut_h // 2, 0, H)
    return x1, y1, x2, y2

def mixup_cutmix(images, labels, alpha=0.2):
    lam = np.random.beta(alpha, alpha)
    bs   = images.size(0)
    idx  = torch.randperm(bs, device=images.device)
    if random.random() < 0.5:  # MixUp
        mixed = lam * images + (1 - lam) * images[idx]
    else:                      # CutMix
        x1,y1,x2,y2 = rand_bbox(images.size(), lam)
        mixed = images.clone()
        mixed[:, :, y1:y2, x1:x2] = images[idx, :, y1:y2, x1:x2]
        lam = 1 - ((x2-x1)*(y2-y1)/(images.size(-1)*images.size(-2)))
    return mixed, labels, labels[idx], lam



In [7]:
# ==============================================================
# 7. 评估函数 (使用 EMA 权重)
# --------------------------------------------------------------
@torch.no_grad()
def evaluate():
    backbone.eval(); head_parent.eval(); head_child.eval()
    totals, corrects, loss_sum = 0, 0, 0.0
    for imgs, p_lbl, _ in val_dl:
        imgs, p_lbl = imgs.to(DEVICE), p_lbl.to(DEVICE)
        with autocast():
            feats = forward_features(imgs)
            logits = head_parent(feats)
            loss   = criterion_parent(logits, p_lbl)
        loss_sum += loss.item() * imgs.size(0)
        corrects += (logits.argmax(1) == p_lbl).sum().item()
        totals   += imgs.size(0)
    return loss_sum / totals, corrects / totals



In [None]:
# ==== 5. Train / Eval with Logging, EMA, EarlyStopping, Resume ======================
best_p_acc, best_c_acc = 0.0, 0.0
patience, non_improve = 5, 0

CKPT_DIR = "checkpoint"
os.makedirs(CKPT_DIR, exist_ok=True)
start_epoch = 1

# === Resume from checkpoint ===
ckpt_list = sorted(glob.glob(os.path.join(CKPT_DIR, "best_parent_*.pth")), key=os.path.getmtime)
if ckpt_list:
    latest = ckpt_list[-1]
    print(f"🔄 Resumed from {latest}")
    ckpt = torch.load(latest, map_location=DEVICE)
    backbone.load_state_dict(ckpt["backbone"])
    head_parent.load_state_dict(ckpt["head_p"])
    if "head_c" in ckpt: head_child.load_state_dict(ckpt["head_c"])
    if "opt" in ckpt: optimizer.load_state_dict(ckpt["opt"])
    if "sched" in ckpt: scheduler.load_state_dict(ckpt["sched"])
    if "scaler" in ckpt: scaler.load_state_dict(ckpt["scaler"])
    if "ema" in ckpt: ema.load_state_dict(ckpt["ema"])
    best_p_acc = ckpt.get("p_acc", best_p_acc)
    best_c_acc = ckpt.get("c_acc", best_c_acc)
    start_epoch = ckpt.get("epoch", start_epoch) + 1
    print(f"  → start_epoch={start_epoch} | best_p_acc={best_p_acc:.4f} | best_c_acc={best_c_acc:.4f}")

# === CSV Logger ===
log_file = pathlib.Path("train_log.csv")
with log_file.open("a", newline="") as log_fh:
    logger = csv.writer(log_fh)
    if log_file.stat().st_size == 0:
        logger.writerow(["epoch","lr","train_loss","p_loss","p_acc","c_loss","c_acc","time_sec"])

    def evaluate():
        backbone.eval(); head_parent.eval(); head_child.eval()
        p_correct = c_correct = total = 0
        p_loss_sum = c_loss_sum = 0.0
        all_p_preds, all_p_gts = [], []

        with torch.no_grad():
            for imgs, y_p, y_c in val_dl:
                imgs, y_p, y_c = imgs.to(DEVICE), y_p.to(DEVICE), y_c.to(DEVICE)
                with autocast():
                    feats = forward_features(imgs)
                    logits_p = head_parent(feats)
                    logits_c = head_child(feats)
                    loss_p = criterion_parent(logits_p, y_p)
                    loss_c = criterion_child(logits_c, y_c)

                p_loss_sum += loss_p.item() * y_p.size(0)
                c_loss_sum += loss_c.item() * y_c.size(0)
                pred_p = logits_p.argmax(1)
                p_correct += (pred_p == y_p).sum().item()
                c_correct += (logits_c.argmax(1) == y_c).sum().item()
                total += y_p.size(0)

                all_p_preds.extend(pred_p.cpu().tolist())
                all_p_gts.extend(y_p.cpu().tolist())

        from sklearn.metrics import confusion_matrix
        cm = confusion_matrix(all_p_gts, all_p_preds, labels=[0,1,2])
        return p_loss_sum/total, p_correct/total, c_loss_sum/total, c_correct/total, cm

    # === Training Loop ===
    for epoch in range(start_epoch, NUM_EPOCHS + 1):
        t0 = time.time()
        backbone.train(); head_parent.train(); head_child.train()
        run_loss, tot = 0.0, 0

        use_mix = epoch <= MIXUP_CUTMIX_STOP
        for imgs, y_p, y_c in tqdm(train_dl, desc=f"Epoch[{epoch}/{NUM_EPOCHS}]", ncols=120):
            imgs, y_p, y_c = imgs.to(DEVICE), y_p.to(DEVICE), y_c.to(DEVICE)
            optimizer.zero_grad()

            if use_mix:
                imgs, y_p_a, y_p_b, lam = mixup_cutmix(imgs, y_p)

            with autocast():
                feats = forward_features(imgs)
                logits_p = head_parent(feats)
                logits_c = head_child(feats)

                loss_p = (lam * criterion_parent(logits_p, y_p_a) + (1 - lam) * criterion_parent(logits_p, y_p_b)) if use_mix else criterion_parent(logits_p, y_p)
                child_weight = max(0.0, (epoch - WARM_EPOCHS) / (NUM_EPOCHS - WARM_EPOCHS))
                loss_c = criterion_child(logits_c, y_c)
                loss = (1 - child_weight) * loss_p + child_weight * loss_c

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(backbone.parameters(), MAX_GRAD_NORM)
            scaler.step(optimizer)
            scaler.update()
            ema.update()

            run_loss += loss_p.item() * imgs.size(0)
            tot += imgs.size(0)

        if epoch > WARM_EPOCHS:
            scheduler.step()

        # === EMA 验证 ===
        with ema.average_parameters():
            p_val_loss, p_val_acc, c_val_loss, c_val_acc, cm = evaluate()

        train_loss = run_loss / tot
        elapsed = time.time() - t0
        cur_lr = optimizer.param_groups[0]["lr"]

        print(f"[E{epoch:02d}] lr={cur_lr:.2e} | train={train_loss:.4f} | p[loss={p_val_loss:.4f}, acc={p_val_acc*100:.2f}%] | c[loss={c_val_loss:.4f}, acc={c_val_acc*100:.2f}%] | time={elapsed:.1f}s")
        print("↳ Parent Confusion Matrix:\n", cm)

        logger.writerow([epoch, f"{cur_lr:.2e}", f"{train_loss:.4f}", f"{p_val_loss:.4f}", f"{p_val_acc:.4f}", f"{c_val_loss:.4f}", f"{c_val_acc:.4f}", f"{elapsed:.1f}"])
        log_fh.flush()

        improved = False
        if p_val_acc > best_p_acc:
            best_p_acc, improved = p_val_acc, True
            torch.save({
                "backbone": backbone.state_dict(),
                "head_p": head_parent.state_dict(),
                "head_c": head_child.state_dict(),
                "opt": optimizer.state_dict(),
                "sched": scheduler.state_dict(),
                "scaler": scaler.state_dict(),
                "ema": ema.state_dict(),
                "p_acc": best_p_acc,
                "c_acc": c_val_acc,
                "epoch": epoch
            }, os.path.join(CKPT_DIR, f"best_parent_{best_p_acc:.4f}.pth"))
            print("  ✅ 最强父类模型已更新")
        if c_val_acc > best_c_acc:
            best_c_acc = c_val_acc
            torch.save({
                "backbone": backbone.state_dict(),
                "head_c": head_child.state_dict(),
                "c_acc": best_c_acc
            }, os.path.join(CKPT_DIR, f"best_child_{best_c_acc:.4f}.pth"))
            print("  ✅ 最强子类模型已更新")

        non_improve = 0 if improved else non_improve + 1
        if non_improve >= patience:
            print(f"⏹️ Early stopping at epoch {epoch} after {patience} non-improving epochs.")
            break

print(f"✅ Training Finished | Best Parent Acc={best_p_acc:.4f} | Best Child Acc={best_c_acc:.4f}")


Epoch[1/60]: 100%|██████████████████████████████████████████████████████████████████| 1598/1598 [17:57<00:00,  1.48it/s]


[E01] lr=8.00e-06 | train=0.9540 | p[loss=0.8023, acc=69.30%] | c[loss=10.0139, acc=0.01%] | time=1164.7s
↳ Parent Confusion Matrix:
 [[3504  643  853]
 [ 917 3565  518]
 [1201  473 3326]]
  ✅ 最强父类模型已更新
  ✅ 最强子类模型已更新


Epoch[2/60]: 100%|██████████████████████████████████████████████████████████████████| 1598/1598 [17:47<00:00,  1.50it/s]


[E02] lr=8.00e-06 | train=0.8902 | p[loss=0.7573, acc=72.65%] | c[loss=10.0196, acc=0.02%] | time=1153.5s
↳ Parent Confusion Matrix:
 [[3576  659  765]
 [ 769 3825  406]
 [1003  500 3497]]
  ✅ 最强父类模型已更新
  ✅ 最强子类模型已更新


Epoch[3/60]: 100%|██████████████████████████████████████████████████████████████████| 1598/1598 [17:48<00:00,  1.50it/s]


[E03] lr=8.00e-06 | train=0.8672 | p[loss=0.7311, acc=74.37%] | c[loss=10.0224, acc=0.01%] | time=1154.0s
↳ Parent Confusion Matrix:
 [[3631  614  755]
 [ 693 3910  397]
 [ 929  457 3614]]
  ✅ 最强父类模型已更新


Epoch[4/60]: 100%|██████████████████████████████████████████████████████████████████| 1598/1598 [17:48<00:00,  1.50it/s]


[E04] lr=8.00e-06 | train=0.8491 | p[loss=0.7112, acc=75.57%] | c[loss=10.0247, acc=0.01%] | time=1154.1s
↳ Parent Confusion Matrix:
 [[3697  600  703]
 [ 648 3965  387]
 [ 910  416 3674]]
  ✅ 最强父类模型已更新


Epoch[5/60]: 100%|██████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E05] lr=8.00e-06 | train=0.8263 | p[loss=0.6974, acc=76.46%] | c[loss=10.0251, acc=0.01%] | time=1148.9s
↳ Parent Confusion Matrix:
 [[3740  591  669]
 [ 615 4030  355]
 [ 902  399 3699]]
  ✅ 最强父类模型已更新


Epoch[6/60]: 100%|██████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E06] lr=1.60e-05 | train=0.8176 | p[loss=0.6884, acc=77.30%] | c[loss=10.4327, acc=0.09%] | time=1149.4s
↳ Parent Confusion Matrix:
 [[3775  557  668]
 [ 607 4040  353]
 [ 833  387 3780]]
  ✅ 最强父类模型已更新
  ✅ 最强子类模型已更新


Epoch[7/60]: 100%|██████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E07] lr=2.40e-05 | train=0.8178 | p[loss=0.6787, acc=77.87%] | c[loss=10.6438, acc=0.09%] | time=1149.0s
↳ Parent Confusion Matrix:
 [[3837  514  649]
 [ 609 4040  351]
 [ 854  343 3803]]
  ✅ 最强父类模型已更新


Epoch[8/60]: 100%|██████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E08] lr=3.20e-05 | train=0.8133 | p[loss=0.6687, acc=78.59%] | c[loss=10.7988, acc=0.05%] | time=1148.2s
↳ Parent Confusion Matrix:
 [[3890  467  643]
 [ 599 4039  362]
 [ 823  318 3859]]
  ✅ 最强父类模型已更新


Epoch[9/60]: 100%|██████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E09] lr=4.00e-05 | train=0.8036 | p[loss=0.6586, acc=79.52%] | c[loss=10.9657, acc=0.05%] | time=1148.5s
↳ Parent Confusion Matrix:
 [[3904  489  607]
 [ 564 4098  338]
 [ 768  306 3926]]
  ✅ 最强父类模型已更新


Epoch[10/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E10] lr=4.00e-05 | train=0.7966 | p[loss=0.6525, acc=79.93%] | c[loss=11.1043, acc=0.01%] | time=1148.4s
↳ Parent Confusion Matrix:
 [[3942  456  602]
 [ 561 4098  341]
 [ 761  290 3949]]
  ✅ 最强父类模型已更新


Epoch[11/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E11] lr=4.00e-05 | train=0.7835 | p[loss=0.6458, acc=80.49%] | c[loss=11.2092, acc=0.03%] | time=1148.6s
↳ Parent Confusion Matrix:
 [[3941  460  599]
 [ 535 4132  333]
 [ 713  287 4000]]
  ✅ 最强父类模型已更新


Epoch[12/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E12] lr=3.99e-05 | train=0.7553 | p[loss=0.6408, acc=80.71%] | c[loss=11.3061, acc=0.03%] | time=1148.2s
↳ Parent Confusion Matrix:
 [[3940  462  598]
 [ 522 4143  335]
 [ 699  278 4023]]
  ✅ 最强父类模型已更新


Epoch[13/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E13] lr=3.97e-05 | train=0.7320 | p[loss=0.6374, acc=81.05%] | c[loss=11.3469, acc=0.03%] | time=1149.1s
↳ Parent Confusion Matrix:
 [[3919  476  605]
 [ 488 4196  316]
 [ 675  282 4043]]
  ✅ 最强父类模型已更新


Epoch[14/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E14] lr=3.95e-05 | train=0.7093 | p[loss=0.6384, acc=81.12%] | c[loss=11.3650, acc=0.01%] | time=1148.8s
↳ Parent Confusion Matrix:
 [[3912  456  632]
 [ 494 4188  318]
 [ 663  269 4068]]
  ✅ 最强父类模型已更新


Epoch[15/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:45<00:00,  1.50it/s]


[E15] lr=3.92e-05 | train=0.7020 | p[loss=0.6401, acc=81.47%] | c[loss=11.4073, acc=0.02%] | time=1147.8s
↳ Parent Confusion Matrix:
 [[3910  443  647]
 [ 468 4222  310]
 [ 638  274 4088]]
  ✅ 最强父类模型已更新


Epoch[16/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E16] lr=3.88e-05 | train=0.6798 | p[loss=0.6430, acc=81.71%] | c[loss=11.4206, acc=0.03%] | time=1148.3s
↳ Parent Confusion Matrix:
 [[3906  441  653]
 [ 466 4233  301]
 [ 616  266 4118]]
  ✅ 最强父类模型已更新


Epoch[17/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E17] lr=3.84e-05 | train=0.6692 | p[loss=0.6463, acc=81.53%] | c[loss=11.4367, acc=0.03%] | time=1148.6s
↳ Parent Confusion Matrix:
 [[3882  461  657]
 [ 461 4224  315]
 [ 614  262 4124]]


Epoch[18/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E18] lr=3.79e-05 | train=0.6578 | p[loss=0.6483, acc=81.72%] | c[loss=11.4240, acc=0.03%] | time=1148.7s
↳ Parent Confusion Matrix:
 [[3886  439  675]
 [ 445 4238  317]
 [ 603  263 4134]]
  ✅ 最强父类模型已更新


Epoch[19/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E19] lr=3.74e-05 | train=0.6484 | p[loss=0.6516, acc=81.55%] | c[loss=11.4160, acc=0.03%] | time=1149.3s
↳ Parent Confusion Matrix:
 [[3857  452  691]
 [ 446 4233  321]
 [ 589  268 4143]]


Epoch[20/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:45<00:00,  1.50it/s]


[E20] lr=3.68e-05 | train=0.6377 | p[loss=0.6534, acc=81.61%] | c[loss=11.4204, acc=0.03%] | time=1148.1s
↳ Parent Confusion Matrix:
 [[3844  445  711]
 [ 430 4241  329]
 [ 588  256 4156]]


Epoch[21/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E21] lr=3.62e-05 | train=0.6245 | p[loss=0.6549, acc=81.77%] | c[loss=11.4018, acc=0.03%] | time=1148.5s
↳ Parent Confusion Matrix:
 [[3828  449  723]
 [ 421 4246  333]
 [ 560  248 4192]]
  ✅ 最强父类模型已更新


Epoch[22/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E22] lr=3.55e-05 | train=0.6249 | p[loss=0.6567, acc=81.85%] | c[loss=11.3920, acc=0.03%] | time=1148.5s
↳ Parent Confusion Matrix:
 [[3835  444  721]
 [ 419 4243  338]
 [ 550  251 4199]]
  ✅ 最强父类模型已更新


Epoch[23/60]: 100%|█████████████████████████████████████████████████████████████████| 1598/1598 [17:46<00:00,  1.50it/s]


[E23] lr=3.47e-05 | train=0.6192 | p[loss=0.6580, acc=81.79%] | c[loss=11.3926, acc=0.03%] | time=1147.9s
↳ Parent Confusion Matrix:
 [[3832  445  723]
 [ 425 4235  340]
 [ 548  250 4202]]


Epoch[24/60]:  20%|████████████▉                                                     | 313/1598 [03:30<14:15,  1.50it/s]

In [15]:
# ==============================================================
# 9. 推理生成 submission.csv (EMA 权重 + 多尺度水平翻转 TTA)
# --------------------------------------------------------------
print("\n==> Generating submission.csv")
CKPT_DIR = "checkpoint"
os.makedirs(CKPT_DIR, exist_ok=True)
#state = torch.load("best_parent_0.8185.pth", map_location=DEVICE)
best_ckpt = sorted(os.listdir(CKPT_DIR))[-1]
state = torch.load(os.path.join(CKPT_DIR, best_ckpt), map_location=DEVICE)

backbone.load_state_dict(state["backbone"]); head_parent.load_state_dict(state["head_p"])
ema.load_state_dict(state["ema"])
backbone.eval(); head_parent.eval()

# TTA 变换 (3 尺度 × 左右翻转)
SCALES = [224, 256, 288]
tta_trans = []
for s in SCALES:
    for flip in [False, True]:
        t = [T.Resize(s), T.CenterCrop(224)]
        if flip: t.append(T.RandomHorizontalFlip(p=1.0))
        t += [T.ToTensor(), T.Normalize(MEAN, STD)]
        tta_trans.append(T.Compose(t))
TTA_N = len(tta_trans)

df_test = pd.read_csv(TEST_CSV)
ids = df_test["id"].values
out_ids, out_labels = [], []

with torch.no_grad(), ema.average_parameters(), autocast():
    for i in tqdm(range(0, len(ids), BATCH_SIZE), ncols=120, desc="Inference"):
        batch_ids = ids[i:i+BATCH_SIZE]
        imgs_all = []
        for fname in batch_ids:
            img = Image.open(os.path.join(TEST_IMGDIR, fname)).convert("RGB")
            imgs_all.extend([tf(img) for tf in tta_trans])
        imgs_all = torch.stack(imgs_all).to(DEVICE)

        feats = forward_features(imgs_all)
        logits = head_parent(feats).view(len(batch_ids), TTA_N, NUM_PARENT)
        logits = logits.mean(1)          # TTA 平均
        preds  = logits.argmax(1).cpu().numpy()

        out_ids.extend(batch_ids)
        out_labels.extend(preds.tolist())

pd.DataFrame({"id": out_ids, "label": out_labels}).to_csv("submission.csv", index=False)
print("Saved submission.csv")



==> Generating submission.csv


Inference: 100%|██████████████████████████████████████████████████████████████████████| 235/235 [20:49<00:00,  5.32s/it]

Saved submission.csv



