# In [None]:
import torch
import timm
from torchvision import transforms
from PIL import Image
import os
# ==== Config ====
pth_path = "vit_plants_final_state_dict.pth"
num_classes = 10  # change to the real number of classes
backbone_name = "vit_base_patch16_224"

# Build the timm model with a classifier head sized for num_classes.
# pretrained=False: every weight comes from pth_path below.
model = timm.create_model(backbone_name, pretrained=False, num_classes=num_classes)

# Load the checkpoint onto the CPU so this also works on machines without a GPU.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints. If the file
# was saved with torch.save(model) (a whole module instead of a state_dict), the
# module that defined the model class must be importable, otherwise unpickling
# fails (e.g. "ModuleNotFoundError: No module named ...").
state_dict = torch.load(pth_path, map_location="cpu")
# Some checkpoints (e.g. Lightning .ckpt files) nest the weights under "state_dict".
if isinstance(state_dict, dict) and "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]


def _strip_wrapper_prefix(key):
    """Return *key* without one leading 'model.' or 'module.' prefix."""
    for prefix in ("model.", "module."):
        if key.startswith(prefix):
            return key[len(prefix):]
    return key


# Weights saved from a wrapper object (e.g. a module holding the network as
# `self.model`, or nn.DataParallel) carry a key prefix the bare timm model lacks.
state_dict = {_strip_wrapper_prefix(k): v for k, v in state_dict.items()}
# strict=True (the default) so a wrong backbone or key layout fails loudly
# instead of silently predicting with randomly initialised layers.
model.load_state_dict(state_dict)
model.eval()

# Same transform as used during training
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Predict a single image
img_path = "/kaggle/working/identify-plant-1/test/class_name/sample.jpg"  # change this path
# The context manager closes the file; convert() has already decoded the pixels.
with Image.open(img_path) as img:
    image = img.convert("RGB")
image = transform(image).unsqueeze(0)  # add the batch dimension
with torch.no_grad():
    outputs = model(image)
    pred = torch.argmax(outputs, dim=1).item()
print("Predicted class index:", pred)





# 1:52
# test_model.py
import argparse
import os
import re
from collections import defaultdict
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
# Thử import timm (recommended)
try:
    import timm
    _HAS_TIMM = True
except Exception:
    timm = None
    _HAS_TIMM = False
# ---------- If PlantDataset/PlantDataModule already exist in your environment, import them ----------
try:
    from your_dataset_module import PlantDataModule, PlantDataset  # change to the module where you defined them
except ImportError:
    # Not available: fall back to a minimal PlantDataset (based on the code you provided).
    from torch.utils.data import Dataset

    class PlantDataset(Dataset):
        """Minimal image-folder dataset laid out as <root_dir>/<split>/<class_name>/**/<image>.

        Class indices are assigned in sorted order of the class-folder names.
        Images are collected recursively below each class folder and filtered by
        file extension (case-insensitive). Items are ``(image, label)`` pairs where
        ``image`` is the RGB PIL image, passed through ``transform`` if given.

        Raises:
            ValueError: if the split folder is missing or has no class subfolders.
        """

        def __init__(self, root_dir, split='test', transform=None, extensions=('.jpg', '.jpeg', '.png', '.bmp', '.tiff')):
            self.root_dir = root_dir
            self.split = split
            self.transform = transform
            # Lower-cased so "IMG.JPG" matches ".jpg" too.
            self.extensions = tuple(e.lower() for e in extensions)
            self.images = []
            self.labels = []

            split_dir = os.path.join(self.root_dir, self.split)
            if not os.path.isdir(split_dir):
                raise ValueError(f"Split folder not found: {split_dir}")
            classes = sorted(
                d for d in os.listdir(split_dir) if os.path.isdir(os.path.join(split_dir, d))
            )
            if not classes:
                raise ValueError(f"No class subfolders found in {split_dir}")
            self.classes = classes
            self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
            self.idx_to_class = {idx: cls_name for cls_name, idx in self.class_to_idx.items()}

            for cls_name in self.classes:
                label = self.class_to_idx[cls_name]
                cls_dir = os.path.join(split_dir, cls_name)
                for root, dirs, files in os.walk(cls_dir):
                    # os.walk order depends on the filesystem; sort (in place for
                    # dirs, so the walk follows it) for a reproducible sample order.
                    dirs.sort()
                    for fname in sorted(files):
                        if fname.lower().endswith(self.extensions):
                            self.images.append(os.path.join(root, fname))
                            self.labels.append(label)

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            img_path = self.images[idx]
            label = self.labels[idx]
            # Decode inside the `with` so the file handle is closed right away.
            with open(img_path, 'rb') as f:
                image = Image.open(f).convert('RGB')
            if self.transform is not None:
                image = self.transform(image)
            return image, label
# ---------- Utility helpers ----------
def clean_state_dict(sd: dict):
    """Return a new dict whose keys have one leading 'model.' or 'module.' removed.

    Keys without either prefix are kept as-is; values are reused, not copied.
    Only a single prefix is stripped ('model.module.x' becomes 'module.x').
    """
    leading_wrapper = r'^(model\.|module\.)'
    return {re.sub(leading_wrapper, '', key): value for key, value in sd.items()}
def guess_num_classes_from_state_dict(sd: dict):
    """Infer the number of classes from a classifier weight in *sd*.

    A candidate must be a tensor with at least 2 dimensions; the class count is
    its first dimension (``shape[0]``, the out_features of a Linear weight).
    Candidates are searched in this order, first hit wins:

      1. keys with a ``head``/``heads``/``classifier`` component followed by
         ``.`` or ``_`` and ending in ``weight``;
      2. the same for an ``fc`` component;
      3. any key containing ``head`` or ``classifier`` and ending in ``weight``
         (e.g. ``lm_head.weight``).

    Returns:
        The inferred class count, or None if no candidate is found.
    """
    def _is_matrix(value):
        # 1-D weights (e.g. norm layers) hold a feature width, not a class count.
        return isinstance(value, torch.Tensor) and value.dim() >= 2

    patterns = (
        re.compile(r'(^|\.)(head|heads|classifier)(\.|_).*weight$'),
        re.compile(r'(^|\.)(fc)(\.|_).*weight$'),
    )
    # Scan all keys for head/classifier before considering fc, so an intermediate
    # "...fc.weight" stored earlier in the dict cannot shadow the real head.
    for pattern in patterns:
        for k, v in sd.items():
            if pattern.search(k) and _is_matrix(v):
                return v.shape[0]
    # Looser name match (keys are expected to be cleaned of 'model.' prefixes).
    for k, v in sd.items():
        if ('head' in k or 'classifier' in k) and k.endswith('weight') and _is_matrix(v):
            return v.shape[0]
    return None
def build_timm_model(backbone_name: str, num_classes: int, pretrained=False):
    """Create the timm model *backbone_name* with a *num_classes*-way classifier.

    Raises:
        RuntimeError: if the optional timm import at the top of the file failed.
    """
    if _HAS_TIMM:
        return timm.create_model(backbone_name, pretrained=pretrained, num_classes=num_classes)
    raise RuntimeError("timm not installed. Install timm (`pip install timm`) to use ViT backbones.")
# ---------- Loading weights (pth or ckpt) ----------
def load_model_from_pth(pth_path, backbone_name, num_classes, device='cpu'):
    """Build a timm model and load its weights from a ``.pth`` file.

    Args:
        pth_path: file written by ``torch.save``. Accepted contents are a plain
            state_dict, a dict holding one under ``'state_dict'``, or a pickled
            ``nn.Module`` (whose ``state_dict()`` is used).
        backbone_name: timm model name forwarded to ``build_timm_model``.
        num_classes: classifier size; when None it is inferred from the weights.
        device: device the returned model is moved to.

    Returns:
        The model with the loaded weights, moved to *device*.

    Raises:
        ValueError: if num_classes is None and cannot be inferred.
        RuntimeError: if none of the model's parameters/buffers are in the file.
    """
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints. A file
    # saved with torch.save(model) also needs the module that defined the model
    # class to be importable, otherwise unpickling fails
    # (e.g. "ModuleNotFoundError: No module named ...").
    sd = torch.load(pth_path, map_location='cpu')
    if isinstance(sd, torch.nn.Module):
        # The whole model was pickled; keep only its weights.
        sd = sd.state_dict()
    elif isinstance(sd, dict) and 'state_dict' in sd:
        # The weights are nested under 'state_dict'.
        sd = sd['state_dict']
    sd = clean_state_dict(sd)
    if num_classes is None:
        num_classes = guess_num_classes_from_state_dict(sd)
        if num_classes is None:
            raise ValueError("Không thể đoán num_classes từ state_dict; hãy cung cấp --num_classes.")
    model = build_timm_model(backbone_name, num_classes, pretrained=False)
    # strict=False tolerates missing/extra entries, but report them: otherwise a
    # mismatched backbone silently evaluates with randomly initialised layers.
    result = model.load_state_dict(sd, strict=False)
    if len(result.missing_keys) == len(model.state_dict()):
        raise RuntimeError(
            f"No weights from {pth_path} match backbone '{backbone_name}'; "
            "check --backbone and the checkpoint key names."
        )
    if result.missing_keys or result.unexpected_keys:
        print(f"[warn] {len(result.missing_keys)} missing keys (e.g. {result.missing_keys[:5]}), "
              f"{len(result.unexpected_keys)} unexpected keys (e.g. {result.unexpected_keys[:5]})")
    return model.to(device)
def load_model_from_ckpt(ckpt_path, backbone_name, num_classes, device='cpu'):
    """Build a timm model and load its weights from a ``.ckpt`` checkpoint.

    Args:
        ckpt_path: checkpoint path. Weights are read from the ``'state_dict'``
            entry when present, otherwise the loaded object itself is used.
        backbone_name: timm model name forwarded to ``build_timm_model``.
        num_classes: classifier size. When None, the first of
            num_classes/n_classes/num_labels/classes found in the checkpoint's
            hyper-parameter dicts is used (a list/tuple value counts its
            length); otherwise it is inferred from the weights.
        device: device the returned model is moved to.

    Returns:
        The model with the loaded weights, moved to *device*.

    Raises:
        ValueError: if num_classes cannot be determined.
        RuntimeError: if none of the model's parameters/buffers are in the checkpoint.
    """
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    # Lightning usually stores the weights under 'state_dict'
    if isinstance(ckpt, dict) and 'state_dict' in ckpt:
        sd = ckpt['state_dict']
    else:
        sd = ckpt
    sd = clean_state_dict(sd)
    if num_classes is None:
        # Try the hyper-parameters saved in the checkpoint first.
        candidates = []
        if isinstance(ckpt, dict):
            for hp_key in ('hyper_parameters', 'hparams', 'hyper_parameters_on_save'):
                hparams = ckpt.get(hp_key)
                if isinstance(hparams, dict):
                    for name in ('num_classes', 'n_classes', 'num_labels', 'classes'):
                        value = hparams.get(name)
                        if value is not None:
                            candidates.append(value)
        if candidates:
            first = candidates[0]
            # 'classes' may hold the list of class names rather than a count.
            num_classes = len(first) if isinstance(first, (list, tuple)) else int(first)
        else:
            # Fall back to the classifier weight shape.
            num_classes = guess_num_classes_from_state_dict(sd)
    if num_classes is None:
        raise ValueError("Không thể xác định num_classes từ checkpoint; hãy cung cấp --num_classes.")
    model = build_timm_model(backbone_name, num_classes, pretrained=False)
    # strict=False tolerates missing/extra entries, but report them: otherwise a
    # mismatched backbone silently evaluates with randomly initialised layers.
    result = model.load_state_dict(sd, strict=False)
    if len(result.missing_keys) == len(model.state_dict()):
        raise RuntimeError(
            f"No weights from {ckpt_path} match backbone '{backbone_name}'; "
            "check --backbone and the checkpoint key names."
        )
    if result.missing_keys or result.unexpected_keys:
        print(f"[warn] {len(result.missing_keys)} missing keys (e.g. {result.missing_keys[:5]}), "
              f"{len(result.unexpected_keys)} unexpected keys (e.g. {result.unexpected_keys[:5]})")
    return model.to(device)
# ---------- Evaluation ----------
def evaluate_model(model, dataloader, device):
    """Run *model* over *dataloader* and collect top-1 accuracy statistics.

    Args:
        model: module whose output's argmax over dim 1 is the predicted class;
            it is switched to eval mode and should already live on *device*.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device each image batch is moved to before the forward pass.

    Returns:
        ``(acc, per_class_counts, per_class_correct, all_preds, all_labels)``:
        acc is correct/total (0.0 for an empty loader); the two dicts are
        ``defaultdict(int)`` keyed by true label index; the lists hold the
        predicted and true label of every sample in loader order.
    """
    model.eval()
    per_class_counts = defaultdict(int)
    per_class_correct = defaultdict(int)
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in dataloader:
            outputs = model(images.to(device))
            # Copy predictions to the host once per batch; labels never need the device.
            batch_preds = outputs.argmax(dim=1).cpu().tolist()
            batch_labels = labels.cpu().tolist()
            for p, t in zip(batch_preds, batch_labels):
                per_class_counts[t] += 1
                if p == t:
                    per_class_correct[t] += 1
            all_preds.extend(batch_preds)
            all_labels.extend(batch_labels)
    total = len(all_labels)
    correct = sum(per_class_correct.values())
    acc = correct / total if total > 0 else 0.0
    return acc, per_class_counts, per_class_correct, all_preds, all_labels
# ---------- Main ----------
def _build_test_loader(data_dir, transform, batch_size, num_workers):
    """Return ``(test_loader, class_names)`` for the 'test' split of *data_dir*.

    Uses PlantDataModule from your_dataset_module when it is importable (so the
    class order matches training). If that module is missing, or its setup
    raises, falls back to the local PlantDataset with *transform*.
    class_names is None when the datamodule exposes no ``test_dataset``.
    """
    try:
        # Replace with the module that defines your PlantDataModule, if you have one.
        from your_dataset_module import PlantDataModule
    except ImportError:
        PlantDataModule = None
    if PlantDataModule is not None:
        try:
            dm = PlantDataModule(root_dir=data_dir, image_size=224, batch_size=batch_size, num_workers=num_workers)
            dm.setup('test')
            test_loader = dm.test_dataloader()
            dm_test_dataset = getattr(dm, 'test_dataset', None)
            class_names = dm_test_dataset.classes if dm_test_dataset is not None else None
            return test_loader, class_names
        except Exception as err:  # best-effort: report, then use the fallback dataset
            print(f"PlantDataModule failed ({err!r}); falling back to PlantDataset.")
    test_dataset = PlantDataset(root_dir=data_dir, split='test', transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return test_loader, test_dataset.classes


def _print_per_class(class_names, per_class_counts, per_class_correct):
    """Print ``correct/total = accuracy%`` for every class."""
    def pct(correct, total):
        return 100 * correct / total if total > 0 else 0

    if class_names is None:
        # No names known: report every label index seen in the test set.
        for idx in sorted(per_class_counts.keys()):
            correct = per_class_correct.get(idx, 0)
            total = per_class_counts[idx]
            print(f"  class {idx:3d}: {correct}/{total} = {pct(correct, total):.2f}%")
    else:
        for idx, name in enumerate(class_names):
            total = per_class_counts.get(idx, 0)
            correct = per_class_correct.get(idx, 0)
            print(f"  {name:20s} -> {correct}/{total} = {pct(correct, total):.2f}%")


def main():
    """CLI entry point: load weights, evaluate on <data_dir>/test, print accuracies."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights', type=str, required=True, help='.pth or .ckpt path')
    parser.add_argument('--data_dir', type=str, required=True, help='root dataset directory (contains train/valid/test)')
    parser.add_argument('--backbone', type=str, default='vit_base_patch16_224', help='timm backbone name')
    parser.add_argument('--num_classes', type=int, default=None, help='force num_classes if cannot infer')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_args()

    # Transforms (same as training val transform)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    test_loader, class_names = _build_test_loader(args.data_dir, transform, args.batch_size, args.num_workers)
    print(f"Found {len(class_names) if class_names is not None else '??'} classes.")

    # Load model depending on extension
    ext = os.path.splitext(args.weights)[1].lower()
    if ext == '.pth':
        model = load_model_from_pth(args.weights, args.backbone, args.num_classes, device=args.device)
    elif ext == '.ckpt':
        model = load_model_from_ckpt(args.weights, args.backbone, args.num_classes, device=args.device)
    else:
        raise ValueError("Unknown weights extension. Use .pth or .ckpt")
    print("Model loaded. Eval on test set...")

    # If the classifier size differs from the dataset's class count, label
    # indices cannot line up and the per-class numbers are meaningless.
    model_classes = getattr(model, 'num_classes', None)
    if class_names is not None and model_classes is not None and model_classes != len(class_names):
        print(f"[warn] model has {model_classes} outputs but the test set has "
              f"{len(class_names)} classes; results may be misaligned.")

    acc, per_class_counts, per_class_correct, _preds, _labels = evaluate_model(model, test_loader, args.device)
    print(f"Test Top-1 Accuracy: {acc*100:.2f}%")
    print("Per-class results:")
    _print_per_class(class_names, per_class_counts, per_class_correct)


if __name__ == '__main__':
    main()

# Observed error when running: ModuleNotFoundError: No module named 'model'