In [None]:
# -*- coding: utf-8 -*-
"""
Batch Inference Script for DeepJSCC
"""
import os
import glob
import torch
from PIL import Image
from torchvision import transforms

from model import DeepJSCC, ratio2filtersize
from utils import image_normalization

# -------- CONFIGURATION --------
# Input images are discovered recursively; outputs mirror the input folder layout.
INPUT_DIR = '/home/MATLAB_DATA/TiNguyen/Sentry_Data/test'
OUTPUT_DIR = '/home/MATLAB_DATA/TiNguyen/Sentry_Data/test_snr13_ratio112'

# Prefer the first GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

# These must match the training run: DATASET, SNR, RATIO (via the derived inner
# channel count c) and CHANNEL_TYPE together form the checkpoint directory prefix.
DATASET = 'imagenet'
SNR = 13.0             # channel SNR passed to DeepJSCC (presumably in dB — confirm)
RATIO = 1 / 12         # passed to ratio2filtersize to derive the inner channel count c
CHANNEL_TYPE = 'AWGN'
IMAGE_SIZE = (64, 64)  # (H, W) target size used for resizing inputs


# -------- Auto Checkpoint Finder --------
def _epoch_sort_key(path):
    """Sort key for ``epoch_<N>.pth`` files: epoch number first, then mtime.

    Files whose ``<N>`` is not a plain integer sort before every numbered epoch
    and are ordered among themselves by modification time.
    """
    stem = os.path.basename(path)[len('epoch_'):-len('.pth')]
    epoch = int(stem) if stem.isdecimal() else -1
    return (epoch, os.path.getmtime(path))


def auto_find_checkpoint(dataset, c, snr, ratio, channel, base_dir='./out/checkpoints'):
    """Locate the latest checkpoint for a given training configuration.

    Run directories under ``base_dir`` are matched by the prefix
    ``{DATASET}_{c}_{snr}_{ratio:.2f}_{channel}_``. Among matching directories
    the most recently modified one is used. Inside it, the ``epoch_*.pth`` file
    with the highest epoch number is returned.

    Args:
        dataset: Dataset name; upper-cased to build the prefix.
        c: Inner channel count of the model.
        snr: Training SNR, formatted with ``str()`` (e.g. ``13.0``).
        ratio: Compression ratio, formatted with two decimals.
        channel: Channel type name (e.g. ``'AWGN'``).
        base_dir: Folder containing one sub-directory per training run.

    Returns:
        Path to the selected ``.pth`` checkpoint file.

    Raises:
        FileNotFoundError: If ``base_dir`` does not exist, no run directory
            matches the prefix, or the chosen directory has no checkpoints.
    """
    prefix = f"{dataset.upper()}_{c}_{snr}_{ratio:.2f}_{channel}_"
    candidates = [
        os.path.join(base_dir, d)
        for d in os.listdir(base_dir)
        if d.startswith(prefix) and os.path.isdir(os.path.join(base_dir, d))
    ]
    if not candidates:
        raise FileNotFoundError(f"No checkpoint directories found with prefix: {prefix}")

    # Run names end in a human-readable timestamp (e.g. "19h20m35s_on_Jul_15_2025")
    # that does not sort chronologically, so directory mtime is used instead.
    latest_dir = max(candidates, key=os.path.getmtime)

    # Escape the directory so glob metacharacters in its path are taken literally.
    ckpts = glob.glob(os.path.join(glob.escape(latest_dir), 'epoch_*.pth'))
    if not ckpts:
        raise FileNotFoundError(f"No checkpoint files in: {latest_dir}")

    # Choose by epoch number rather than mtime: mtimes are unreliable once files
    # have been copied, and a lexicographic sort would rank epoch_99 above epoch_499.
    latest_ckpt = max(ckpts, key=_epoch_sort_key)
    print(f"✅ Found checkpoint: {latest_ckpt}")
    return latest_ckpt


# -------- Image Processing --------
def load_image(image_path, image_size):
    """Load an image file as a batched float tensor in [0, 1].

    Args:
        image_path: Path to an image readable by PIL; converted to RGB.
        image_size: Size forwarded to ``transforms.Resize``. A ``(h, w)`` tuple
            yields exactly that size; an int resizes the shorter side.

    Returns:
        torch.Tensor of shape ``(1, 3, H, W)`` with values in [0, 1].
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    # Open via the context manager so the file handle is released
    # deterministically; convert() returns an independent in-memory image.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0)


# -------- Model Loader --------
def load_model(checkpoint_path, snr, ratio, channel_type, image_size, device):
    """Build a DeepJSCC model, load its weights and put it in eval mode.

    Args:
        checkpoint_path: Path to a file containing the model's state_dict.
        snr: SNR passed to the DeepJSCC constructor.
        ratio: Compression ratio; used with ``image_size`` to derive the inner
            channel count ``c`` via ``ratio2filtersize``.
        channel_type: Channel type passed to the DeepJSCC constructor.
        image_size: ``(H, W)`` used to build a dummy ``(3, H, W)`` tensor for
            ``ratio2filtersize``.
        device: Device the model and weights are placed on.

    Returns:
        The loaded DeepJSCC model on ``device``, in eval mode.
    """
    import inspect  # local: only used to probe torch.load's signature

    dummy_img = torch.randn(3, *image_size)
    c = ratio2filtersize(dummy_img, ratio)
    print(f"Loading model with inner channel c={c}")

    model = DeepJSCC(c=c, snr=snr, channel_type=channel_type)
    # Move the model first so the weights (mapped straight onto `device`) are
    # copied device-to-device instead of round-tripping through the CPU.
    model.to(device)

    load_kwargs = {'map_location': device}
    # The loaded object is fed directly to load_state_dict, so restricting
    # unpickling to tensors/containers is sufficient. It also silences torch's
    # weights_only FutureWarning and avoids executing arbitrary pickled code.
    # NOTE(review): assumes the checkpoint holds only a state_dict — confirm
    # against the training script. Torch versions without the flag use the
    # default loader.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = True
    state_dict = torch.load(checkpoint_path, **load_kwargs)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# -------- Inference Function --------
def run_inference(model, image_tensor, device):
    """Run ``model`` on ``image_tensor`` without gradients and denormalize the result.

    Args:
        model: Network to evaluate (expected to already be on ``device``).
        image_tensor: Batched input tensor; moved to ``device`` here.
        device: Device on which the forward pass runs.

    Returns:
        The denormalized output on the CPU, with the leading batch dimension
        squeezed away when it has size 1.
    """
    inputs = image_tensor.to(device)
    with torch.no_grad():
        raw = model(inputs)
        # NOTE(review): the caller divides this result by 255, which suggests
        # 'denormalization' maps onto a 0-255 scale — confirm in utils.
        restored = image_normalization('denormalization')(raw)
    return restored.cpu().squeeze(0)


# -------- Batch Evaluation --------
def process_folder(input_dir, output_dir, model, image_size, device,
                   extensions=('.jpg', '.jpeg', '.png')):
    """Run the model over every image under ``input_dir`` and save the outputs.

    Images are found recursively. Each output is written under ``output_dir``
    at the same relative path (same file name and format) as its input.
    Per-image failures are reported and skipped, so one bad file does not
    abort the batch.

    Args:
        input_dir: Root folder searched recursively for images.
        output_dir: Root folder for outputs; created if missing.
        model: Model passed to ``run_inference``.
        image_size: Size passed to ``load_image``.
        device: Device passed to ``run_inference``.
        extensions: File suffixes treated as images (case-insensitive).
    """
    os.makedirs(output_dir, exist_ok=True)

    suffixes = tuple(ext.lower() for ext in extensions)
    # Escape the root so folder names containing glob metacharacters
    # ('[', '*', '?') match literally; only '**' should act as a wildcard.
    pattern = os.path.join(glob.escape(input_dir), '**', '*')
    image_paths = sorted(
        p for p in glob.glob(pattern, recursive=True)
        if p.lower().endswith(suffixes) and os.path.isfile(p)
    )

    if not image_paths:
        print(f"No image files found in {input_dir}")
        return

    n_failed = 0
    for img_path in image_paths:
        try:
            img_tensor = load_image(img_path, image_size)
            output_tensor = run_inference(model, img_tensor, device)
            # Rescale to [0, 1] for ToPILImage. Assumes the denormalized output
            # is on a 0-255 scale — TODO confirm against image_normalization.
            output_tensor = output_tensor / 255

            # Mirror the input's sub-folder structure in the output directory.
            relative_path = os.path.relpath(img_path, input_dir)
            save_path = os.path.join(output_dir, relative_path)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)

            # NOTE(review): keeping the input's extension means .jpg/.jpeg outputs
            # are re-encoded as lossy JPEG at PIL's default quality, adding
            # artifacts on top of the channel distortion. Consider PNG if these
            # files feed quality metrics.
            out_img = transforms.ToPILImage()(output_tensor.clamp(0, 1))
            out_img.save(save_path)
            print(f"✓ Processed: {relative_path}")
        except Exception as e:
            # Deliberate best-effort batch: report the failure and move on.
            n_failed += 1
            print(f"⚠️ Failed on {img_path}: {e}")

    if n_failed:
        print(f"⚠️ {n_failed} of {len(image_paths)} images failed")




# -------- Main --------
def main():
    """Run batch inference using the module-level configuration constants.

    Derives the inner channel count ``c`` from ``IMAGE_SIZE`` and ``RATIO``,
    uses it to locate the matching checkpoint, loads the model, and processes
    every image under ``INPUT_DIR`` into ``OUTPUT_DIR``.
    """
    print("🚀 Starting batch inference...")

    # c is part of the checkpoint directory name, so it must be computed
    # before the checkpoint can be located.
    inner_channels = ratio2filtersize(torch.randn(3, *IMAGE_SIZE), RATIO)
    ckpt_path = auto_find_checkpoint(DATASET, inner_channels, SNR, RATIO, CHANNEL_TYPE)

    net = load_model(ckpt_path, SNR, RATIO, CHANNEL_TYPE, IMAGE_SIZE, DEVICE)
    process_folder(INPUT_DIR, OUTPUT_DIR, net, IMAGE_SIZE, DEVICE)

    print("✅ All images processed.")


# Execute only when run as a script, not when imported.
if __name__ == '__main__':
    main()


🚀 Starting batch inference...
✅ Found checkpoint: ./out/checkpoints/IMAGENET_4_13.0_0.08_AWGN_19h20m35s_on_Jul_15_2025/epoch_499.pth
Loading model with inner channel c=4
✓ Processed: AnnualCrop/AnnualCrop_2111.jpg
✓ Processed: AnnualCrop/AnnualCrop_2112.jpg
✓ Processed: AnnualCrop/AnnualCrop_2113.jpg
✓ Processed: AnnualCrop/AnnualCrop_2114.jpg
✓ Processed: AnnualCrop/AnnualCrop_2115.jpg
✓ Processed: AnnualCrop/AnnualCrop_2116.jpg
✓ Processed: AnnualCrop/AnnualCrop_2117.jpg
✓ Processed: AnnualCrop/AnnualCrop_2118.jpg
✓ Processed: AnnualCrop/AnnualCrop_2119.jpg
✓ Processed: AnnualCrop/AnnualCrop_2120.jpg
✓ Processed: AnnualCrop/AnnualCrop_2121.jpg
✓ Processed: AnnualCrop/AnnualCrop_2122.jpg
✓ Processed: AnnualCrop/AnnualCrop_2123.jpg
✓ Processed: AnnualCrop/AnnualCrop_2124.jpg
✓ Processed: AnnualCrop/AnnualCrop_2125.jpg
✓ Processed: AnnualCrop/AnnualCrop_2126.jpg
✓ Processed: AnnualCrop/AnnualCrop_2127.jpg
✓ Processed: AnnualCrop/AnnualCrop_2128.jpg
✓ Processed: AnnualCrop/AnnualCrop_212

  model.load_state_dict(torch.load(checkpoint_path, map_location=device))


✓ Processed: AnnualCrop/AnnualCrop_2188.jpg
✓ Processed: AnnualCrop/AnnualCrop_2189.jpg
✓ Processed: AnnualCrop/AnnualCrop_2190.jpg
✓ Processed: AnnualCrop/AnnualCrop_2191.jpg
✓ Processed: AnnualCrop/AnnualCrop_2192.jpg
✓ Processed: AnnualCrop/AnnualCrop_2193.jpg
✓ Processed: AnnualCrop/AnnualCrop_2194.jpg
✓ Processed: AnnualCrop/AnnualCrop_2195.jpg
✓ Processed: AnnualCrop/AnnualCrop_2196.jpg
✓ Processed: AnnualCrop/AnnualCrop_2197.jpg
✓ Processed: AnnualCrop/AnnualCrop_2198.jpg
✓ Processed: AnnualCrop/AnnualCrop_2199.jpg
✓ Processed: AnnualCrop/AnnualCrop_2200.jpg
✓ Processed: AnnualCrop/AnnualCrop_2201.jpg
✓ Processed: AnnualCrop/AnnualCrop_2202.jpg
✓ Processed: AnnualCrop/AnnualCrop_2203.jpg
✓ Processed: AnnualCrop/AnnualCrop_2204.jpg
✓ Processed: AnnualCrop/AnnualCrop_2205.jpg
✓ Processed: AnnualCrop/AnnualCrop_2206.jpg
✓ Processed: AnnualCrop/AnnualCrop_2207.jpg
✓ Processed: AnnualCrop/AnnualCrop_2208.jpg
✓ Processed: AnnualCrop/AnnualCrop_2209.jpg
✓ Processed: AnnualCrop/AnnualCr