In [18]:
import os
import random
import csv
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

# =====================
# CONFIG
# =====================
# Root folder of the evaluation images; scanned as <IMG_DIR>/<label>/<image>.
IMG_DIR = "MauTest"
# Upper bound on how many images are randomly drawn for the report.
NUM_IMAGES = 8

# Checkpoint files for the two models being compared.
RESNET_PATH = "resnet50_final.pt"
CNN_PATH = "cnn_best.pt"

# Side length (pixels) every image is resized to before inference.
IMG_SIZE = 224

# Run on the GPU when one is available, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# =====================
# CNN ARCH
# =====================
class SimpleSkinCNN(nn.Module):
    """Small VGG-style convolutional classifier.

    Five convolutional stages (3 -> 32 -> 64 -> 128 -> 256 -> 512 channels),
    each halving the spatial resolution, followed by global average pooling
    and a two-layer fully connected head.

    Args:
        num_classes: Number of output logits produced by the final layer.
    """

    # Channel count entering the first stage and leaving each of the five.
    _CHANNELS = (3, 32, 64, 128, 256, 512)

    def __init__(self, num_classes=4):
        super().__init__()

        # Stages are created in order so sub-module indices (and therefore
        # state_dict keys such as "features.0.0.weight") stay stable.
        stages = [
            self._conv_stage(c_in, c_out)
            for c_in, c_out in zip(self._CHANNELS, self._CHANNELS[1:])
        ]
        self.features = nn.Sequential(*stages)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    @staticmethod
    def _conv_stage(in_ch, out_ch):
        """Build (Conv3x3 -> BatchNorm -> ReLU) twice, then a 2x2 max-pool."""
        layers = []
        for src_ch in (in_ch, out_ch):
            layers.append(nn.Conv2d(src_ch, out_ch, 3, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        pooled = self.gap(self.features(x))
        return self.classifier(pooled)

# =====================
# TRANSFORM
# =====================
# Evaluation-time preprocessing: fixed-size resize, tensor conversion, then
# per-channel normalisation with the ImageNet mean/std that torchvision's
# pretrained models use.
val_tf = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# =====================
# LOAD RESNET50
# =====================
# The ResNet checkpoint is indexed as a mapping with two entries:
# "class_names" (label list) and "model_state_dict" (weights).
res_ckpt = torch.load(RESNET_PATH, map_location=device)
class_names = res_ckpt["class_names"]
num_classes = len(class_names)

# Start from an uninitialised ResNet50, swap the head so it emits one logit
# per class, then restore the trained weights from the checkpoint.
resnet = models.resnet50(weights=None)
resnet.fc = nn.Linear(
    in_features=resnet.fc.in_features,
    out_features=num_classes,
)
resnet.load_state_dict(res_ckpt["model_state_dict"])
resnet.to(device)
resnet.eval()

# =====================
# LOAD CNN (state_dict)
# =====================
# CNN_PATH is loaded and handed straight to load_state_dict, i.e. the file
# holds a bare state_dict rather than a wrapping checkpoint dict.
cnn = SimpleSkinCNN(num_classes=num_classes)
cnn.load_state_dict(torch.load(CNN_PATH, map_location=device))
cnn.to(device)
cnn.eval()

# =====================
# HELPERS
# =====================
def predict(model, img_path):
    """Classify a single image file and return the predicted class name.

    The image is converted to RGB, passed through the module-level ``val_tf``
    preprocessing, given a leading batch dimension and moved to ``device``
    before being fed to ``model`` with gradient tracking disabled.

    Args:
        model: Callable network returning one row of class scores per input.
        img_path: Path of the image file to classify.

    Returns:
        The entry of ``class_names`` at the arg-max index along dim 1.
    """
    # Open through a context manager so the file handle is released
    # deterministically, even if decoding fails; ``convert`` returns an
    # independent in-memory copy that stays valid after the file is closed.
    with Image.open(img_path) as img:
        rgb = img.convert("RGB")

    batch = val_tf(rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        scores = model(batch)
    best = scores.argmax(dim=1).item()
    return class_names[best]

def collect_images(root_dir):
    """Collect labelled image paths from a ``root_dir/<label>/<file>`` tree.

    Every immediate sub-directory of ``root_dir`` is treated as a label; the
    regular files directly inside it whose name ends (case-insensitively) in
    .jpg, .jpeg, .png, .bmp or .webp are collected. Nested directories are
    not descended into.

    Args:
        root_dir: Directory containing one sub-directory per label.

    Returns:
        A list of ``(image_path, label)`` tuples, ordered by label and then
        by file name.

    Raises:
        OSError: If ``root_dir`` cannot be listed (e.g. it does not exist).
    """
    image_exts = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
    items = []
    # os.listdir returns entries in arbitrary order; sorting makes the result
    # (and any seeded random sampling done on it) reproducible across runs
    # and filesystems.
    for label in sorted(os.listdir(root_dir)):
        label_path = os.path.join(root_dir, label)
        if not os.path.isdir(label_path):
            continue
        for fname in sorted(os.listdir(label_path)):
            if not fname.lower().endswith(image_exts):
                continue
            path = os.path.join(label_path, fname)
            # Skip directories that merely happen to be named like an image.
            if os.path.isfile(path):
                items.append((path, label))
    return items

def sort_key(item):
    """Sort key for ``(image_path, true_label)`` pairs.

    Orders first by the label's position in the module-level ``class_names``
    (labels not listed there get rank 999 and therefore sort last), then by
    the lower-cased file name.
    """
    path, label = item
    rank_by_label = dict(zip(class_names, range(len(class_names))))
    rank = rank_by_label.get(label, 999)
    return rank, os.path.basename(path).lower()

def format_table(headers, rows):
    """Render ``headers`` and ``rows`` as a bordered plain-text table.

    Each column is as wide as its longest cell (header included) and cells
    are left-aligned. The header is separated from the body by a rule of
    ``=`` characters; the top and bottom borders use ``-``.

    Args:
        headers: Column titles; their count defines the number of columns.
        rows: Iterable of indexable rows. Cells are converted with ``str``.
            Cells beyond the number of headers are ignored. ``rows`` may be
            a one-shot iterable (e.g. a generator).

    Returns:
        The table as a single newline-joined string (no trailing newline).

    Raises:
        IndexError: If a row has fewer cells than there are headers.
    """
    header_cells = [str(h) for h in headers]
    n_cols = len(header_cells)

    # Stringify every cell exactly once and materialise the rows up front,
    # so that a generator passed as ``rows`` is not exhausted by the width
    # computation before the body is rendered.
    body = [[str(r[i]) for i in range(n_cols)] for r in rows]

    # Column width = longest cell in that column, header included.
    widths = [max(map(len, column)) for column in zip(header_cells, *body)]

    def rule(ch="-"):
        return "+" + "+".join(ch * (w + 2) for w in widths) + "+"

    def render(cells):
        padded = (f" {cell.ljust(w)} " for cell, w in zip(cells, widths))
        return "|" + "|".join(padded) + "|"

    lines = [
        rule("-"),
        render(header_cells),
        rule("="),
        *map(render, body),
        rule("-"),
    ]
    return "\n".join(lines)

def accuracy(correct_flags):
    """Summarise a sequence of per-sample correctness flags.

    Args:
        correct_flags: Sized sequence of truthy/falsy (or 0/1) values.

    Returns:
        A ``(num_correct, num_total, ratio)`` tuple; ``ratio`` is ``0.0``
        when the sequence is empty.
    """
    hits = sum(correct_flags)
    total = len(correct_flags)
    if total:
        ratio = hits / total
    else:
        ratio = 0.0
    return hits, total, ratio

# =====================
# SAMPLE + SORT
# =====================
# Gather every labelled image; abort early when the folder yields nothing.
all_images = collect_images(IMG_DIR)
if not all_images:
    raise RuntimeError(f"Không tìm thấy ảnh trong {IMG_DIR}. Cần dạng {IMG_DIR}/<label>/*.jpg")

# Draw at most NUM_IMAGES images without replacement, then order them by
# class and file name so the report reads top-to-bottom per label.
samples = random.sample(all_images, min(NUM_IMAGES, len(all_images)))
samples.sort(key=sort_key)

# =====================
# BUILD REPORT ROWS
# =====================
# One table row per sampled image, plus per-model correctness flags that
# feed the accuracy summary.
rows, res_ok, cnn_ok = [], [], []

for img_path, true_label in samples:
    res_pred = predict(resnet, img_path)
    cnn_pred = predict(cnn, img_path)

    r_ok = res_pred == true_label
    c_ok = cnn_pred == true_label
    res_ok.append(r_ok)
    cnn_ok.append(c_ok)

    # Column order must match the header list used when printing the table.
    rows.append(
        [
            os.path.basename(img_path),
            true_label,
            res_pred,
            "Đúng" if r_ok else "Sai",
            cnn_pred,
            "Đúng" if c_ok else "Sai",
        ]
    )

# =====================
# PRINT NICE TABLE
# =====================
# Per-image comparison table.
headers = [
    "Ảnh",
    "Nhãn thật",
    "ResNet dự đoán",
    "ResNet",
    "CNN dự đoán",
    "CNN",
]
print(format_table(headers, rows))

# Overall accuracy of each model on the sampled images.
(rc, rn, racc), (cc, cn, cacc) = accuracy(res_ok), accuracy(cnn_ok)

print(f"\nKết quả: ResNet50 đúng {rc}/{rn} = {racc*100:.2f}% | CNN đúng {cc}/{cn} = {cacc*100:.2f}%")



+----------------+-----------+----------------+--------+-------------+------+
| Ảnh            | Nhãn thật | ResNet dự đoán | ResNet | CNN dự đoán | CNN  |
+================+===========+================+========+=============+======+
| dauden_1.png   | dauden    | dauden         | Đúng   | dauden      | Đúng |
| dauden_2.png   | dauden    | dauden         | Đúng   | dauden      | Đúng |
| dautrang_1.png | dautrang  | mun            | Sai    | mun         | Sai  |
| dautrang_3.png | dautrang  | dautrang       | Đúng   | dautrang    | Đúng |
| mun_1.jpg      | mun       | seo            | Sai    | seo         | Sai  |
| mun_2.png      | mun       | mun            | Đúng   | mun         | Đúng |
| seo_1.jpeg     | seo       | seo            | Đúng   | dautrang    | Sai  |
| seo_2.jpeg     | seo       | seo            | Đúng   | mun         | Sai  |
+----------------+-----------+----------------+--------+-------------+------+

Kết quả: ResNet50 đúng 6/8 = 75.00% | CNN đúng 4/8 = 50.00%
