In [None]:
import torch,os
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import pandas as pd
from pathlib import Path
from tqdm import tqdm 
import matplotlib.pyplot as plt 
import cv2 
def load_model(weight_path, num_classes=2):
    """Build a ResNet-18 classifier and load fine-tuned weights from disk.

    Args:
        weight_path: Path to a ``state_dict`` saved with ``torch.save``.
        num_classes: Number of outputs of the replacement ``fc`` head. It must
            match the checkpoint. Defaults to 2, the live/photo setup.

    Returns:
        The model, with weights mapped to CPU, switched to evaluation mode.

    Raises:
        RuntimeError: If the checkpoint's keys or shapes do not match the
            architecture (``load_state_dict`` is strict by default).
    """
    # No pretrained weights are requested. Every parameter is overwritten by
    # the checkpoint below.
    model = models.resnet18()
    # Replace the stock classification head so the parameter shapes line up
    # with the fine-tuned checkpoint.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load trusted checkpoints. On torch >= 1.13,
    # consider passing weights_only=True.
    state_dict = torch.load(weight_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # Evaluation mode makes BatchNorm layers use their running statistics.
    model.eval()
    return model

def infer_folder(model, folder, threshold=0.50, class_names=("live", "photo")):
    """Classify every JPEG/PNG image in *folder* and tabulate the predictions.

    Args:
        model: Classifier returning logits of shape ``(1, num_classes)`` for a
            ``(1, 3, 224, 224)`` input, e.g. the result of ``load_model``.
        folder: Directory scanned non-recursively for regular files ending in
            ``.jpg``, ``.jpeg`` or ``.png``. The check is case-insensitive.
        threshold: Minimum top-class softmax probability for a row to be
            flagged as ``True`` in ``above_threshold``.
        class_names: Label for each model output. ``class_names[i]`` names
            logit ``i``, and there must be an entry for every output index.

    Returns:
        A DataFrame with one row per image, ordered by path, and columns
        ``path``, ``image``, ``class``, ``confidence``, ``above_threshold``.
        The columns are present even when the folder holds no images.

    Raises:
        FileNotFoundError: If *folder* does not exist.
    """
    folder = Path(folder)
    columns = ["path", "image", "class", "confidence", "above_threshold"]
    suffixes = {".jpg", ".jpeg", ".png"}

    # Select the images up front so the progress-bar total counts only the
    # files that are actually classified. Sorting makes the row order
    # reproducible across filesystems.
    image_paths = sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in suffixes
    )

    # Resize straight to 224x224 and convert to a float tensor.
    # NOTE(review): there is no mean/std Normalize step. Presumably this
    # matches the training preprocessing; confirm before changing.
    t = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    rows = []
    with torch.no_grad():
        for img_path in tqdm(image_paths):
            # The context manager releases the file handle PIL holds open.
            with Image.open(img_path) as img:
                x = t(img.convert("RGB")).unsqueeze(0)  # add batch dimension

            prob = torch.softmax(model(x), dim=1)[0]
            conf, cls_idx = torch.max(prob, dim=0)
            conf = conf.item()
            cls_idx = int(cls_idx)

            rows.append({
                "path": str(img_path),
                "image": img_path.name,
                "class": class_names[cls_idx],
                "confidence": conf,
                # With a two-output model the top probability is always
                # >= 0.5, so any threshold <= 0.5 flags every row.
                "above_threshold": conf >= threshold,
            })

    # Passing explicit columns keeps the schema stable for an empty folder.
    return pd.DataFrame(rows, columns=columns)


# Usage: build the live/photo classifier from its fine-tuned checkpoint.
model = load_model(weight_path="resnet18_live_photo.pth")
print("loaded . . ")


In [None]:
# Score every image in our own "class_live" folder. Rows whose top-class
# confidence reaches 0.60 are marked in `above_threshold`.
df = infer_folder(
    model,
    folder='/Users/Chandraprakash.Patra/Downloads/codes/photo_of_photo/version_3.0/mixed_data/class_live',  # own dataset
    threshold=0.60,
)
print(df)