In [None]:
# -*- coding: utf-8 -*-
"""
Batch Inference Script for DeepJSCC (Multi SNR and Ratio)
"""
import os
import glob
import torch
from PIL import Image
from torchvision import transforms

from model2 import DeepJSCC, ratio2filtersize
from utils import image_normalization

# -------- CONFIGURATION --------
# Images are read recursively from BASE_INPUT_DIR; reconstructions are written
# to one subfolder of BASE_OUTPUT_DIR per (SNR, ratio) configuration.
BASE_INPUT_DIR = '/home/MATLAB_DATA/TiNguyen/Sentry_Data'
BASE_OUTPUT_DIR = '/home/MATLAB_DATA/TiNguyen/SentryJSCC/Rayleigh100time'

# Use the first GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
else:
    DEVICE = 'cpu'

DATASET = 'imagenet'        # upper-cased into the checkpoint directory prefix
CHANNEL_TYPE = 'Rayleigh'   # passed to DeepJSCC and part of the checkpoint name
IMAGE_SIZE = (64, 64)       # target size handed to transforms.Resize

# Every SNR is paired with every ratio; extend either list to add runs.
SNR_LIST = [19.0, 11.0]
RATIO_LIST = [1 / 6, 1 / 12]
Nsamples = 100              # last fixed field of the checkpoint directory name

# -------- Auto Checkpoint Finder --------
def auto_find_checkpoint(dataset, c, snr, ratio, channel, Nsamples, base_dir='./out/checkpoints'):
    """Return the most recent checkpoint file for one training configuration.

    Checkpoint directories under *base_dir* are matched by the name prefix
    ``{DATASET}_{c}_{snr}_{ratio:.2f}_{channel}_{Nsamples}``, either exactly or
    followed by ``_<suffix>`` (e.g. a run timestamp). Among the matching
    directories the one with the newest modification time is chosen, and the
    ``epoch_*.pth`` file in it with the newest modification time is returned.

    Args:
        dataset: dataset name; upper-cased in the prefix.
        c: inner channel count of the model.
        snr: training SNR, formatted with ``str`` (e.g. ``19.0``).
        ratio: compression ratio, formatted with two decimals.
        channel: channel type name (e.g. ``'AWGN'``, ``'Rayleigh'``).
        Nsamples: the final fixed field of the directory name.
        base_dir: directory that holds one subdirectory per training run.

    Returns:
        Path of the selected ``epoch_*.pth`` file.

    Raises:
        FileNotFoundError: if *base_dir* does not exist, no directory matches
            the configuration, or the matching directory has no checkpoint.
    """
    prefix = f"{dataset.upper()}_{c}_{snr}_{ratio:.2f}_{channel}_{Nsamples}"
    # Match only at a field boundary: a bare startswith(prefix) would let
    # e.g. Nsamples=10 also pick up directories trained with Nsamples=100.
    candidates = [
        os.path.join(base_dir, name)
        for name in os.listdir(base_dir)
        if (name == prefix or name.startswith(prefix + '_'))
        and os.path.isdir(os.path.join(base_dir, name))
    ]
    if not candidates:
        raise FileNotFoundError(f"No checkpoint directories found with prefix: {prefix}")
    latest_dir = max(candidates, key=os.path.getmtime)
    # Escape the directory part so characters like '[' are taken literally.
    ckpts = glob.glob(os.path.join(glob.escape(latest_dir), 'epoch_*.pth'))
    if not ckpts:
        raise FileNotFoundError(f"No checkpoint files in: {latest_dir}")
    latest_ckpt = max(ckpts, key=os.path.getmtime)
    print(f"Found checkpoint: {latest_ckpt}")
    return latest_ckpt


# -------- Image Processing --------
def load_image(image_path, image_size):
    """Load an image file as a batched float tensor ready for the model.

    The image is converted to RGB, resized with ``transforms.Resize(image_size)``
    and converted with ``transforms.ToTensor()``; a leading batch dimension of
    size 1 is then added. When *image_size* is an ``(H, W)`` tuple the result
    has shape ``(1, 3, H, W)``.

    Args:
        image_path: path of the image to read.
        image_size: size argument forwarded to ``transforms.Resize``.

    Returns:
        The image tensor with a leading batch dimension.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    # Image.open is lazy and holds the file handle open until the image is
    # garbage-collected; close it as soon as the RGB copy has been made.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    return transform(img).unsqueeze(0)


# -------- Model Loader --------
def load_model(checkpoint_path, snr, ratio, channel_type, image_size, device):
    """Build a DeepJSCC model for one configuration and load its weights.

    The inner channel count ``c`` is derived from *ratio* with
    ``ratio2filtersize`` on a random ``3 x H x W`` tensor of *image_size*.
    This is the same computation ``main`` uses to locate the checkpoint.

    Args:
        checkpoint_path: file whose contents are passed to
            ``model.load_state_dict``, i.e. a saved state dict.
        snr: SNR forwarded to the ``DeepJSCC`` constructor.
        ratio: compression ratio used to derive ``c``.
        channel_type: channel name forwarded to ``DeepJSCC``.
        image_size: ``(H, W)`` used for the dummy tensor.
        device: device the weights are mapped to and the model is moved to.

    Returns:
        The model on *device*, switched to eval mode.
    """
    dummy_img = torch.randn(3, *image_size)
    c = ratio2filtersize(dummy_img, ratio)
    print(f"Loading model with inner channel c={c}")

    model = DeepJSCC(c=c, snr=snr, channel_type=channel_type)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state dict needs. Recent torch releases emit a
    # FutureWarning when this argument is omitted.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


# -------- Inference Function --------
def run_inference(model, image_tensor, device):
    """Reconstruct one batched image and return it as an unbatched CPU tensor.

    The forward pass and the ``image_normalization('denormalization')`` step
    both run under ``torch.no_grad()``. The leading batch dimension is then
    squeezed away. The caller divides the result by 255, which suggests
    denormalization yields a 0-255 pixel range — NOTE(review): confirm
    against ``utils.image_normalization``.
    """
    on_device = image_tensor.to(device)
    with torch.no_grad():
        prediction = model(on_device)
        denormalize = image_normalization('denormalization')
        prediction = denormalize(prediction)
    unbatched = prediction.squeeze(0)
    return unbatched.cpu()


# -------- Batch Evaluation --------
def process_folder(input_dir, output_dir, model, image_size, device):
    """Reconstruct every JPEG/PNG found under *input_dir* into *output_dir*.

    Files are discovered recursively and filtered by a case-insensitive
    ``.jpg``/``.jpeg``/``.png`` extension. Each reconstruction is saved under
    *output_dir* at the same relative path as its source image. A failure on
    one image is printed and the loop moves on to the next file.
    """
    os.makedirs(output_dir, exist_ok=True)

    allowed_suffixes = ('.jpg', '.jpeg', '.png')
    search_pattern = os.path.join(input_dir, '**', '*.*')
    found = [
        path
        for path in sorted(glob.glob(search_pattern, recursive=True))
        if path.lower().endswith(allowed_suffixes)
    ]

    if not found:
        print(f"No image files found in {input_dir}")
        return

    for source_path in found:
        try:
            batch = load_image(source_path, image_size)
            # Rescale to [0, 1] for ToPILImage; clamp guards out-of-range values.
            pixels = run_inference(model, batch, device) / 255

            rel = os.path.relpath(source_path, input_dir)
            destination = os.path.join(output_dir, rel)
            os.makedirs(os.path.dirname(destination), exist_ok=True)

            transforms.ToPILImage()(pixels.clamp(0, 1)).save(destination)
            print(f"Processed: {rel}")
        except Exception as err:
            print(f"Failed on {source_path}: {err}")


# -------- Main Loop --------
def main():
    """Run batch inference for every (SNR, ratio) pair in SNR_LIST x RATIO_LIST.

    For each pair the inner channel count ``c`` is derived with
    ``ratio2filtersize``. The newest matching checkpoint is then located and
    the model loaded. Every image under BASE_INPUT_DIR is reconstructed into
    a per-configuration subfolder of BASE_OUTPUT_DIR. A missing checkpoint
    raises FileNotFoundError, which is not caught here and stops the
    remaining configurations.
    """
    print("Starting batch inference...")

    for snr in SNR_LIST:
        for ratio in RATIO_LIST:
            print(f"\n🔧 SNR: {snr}, Ratio: {ratio:.4f}")

            # Random 3 x H x W tensor; presumably only its shape matters to
            # ratio2filtersize — NOTE(review): confirm in model2.
            dummy_img = torch.randn(3, *IMAGE_SIZE)
            c = ratio2filtersize(dummy_img, ratio)

            # Auto checkpoint
            checkpoint_path = auto_find_checkpoint(DATASET, c, snr, ratio, CHANNEL_TYPE, Nsamples)

            # Load model
            model = load_model(checkpoint_path, snr, ratio, CHANNEL_TYPE, IMAGE_SIZE, DEVICE)

            # Build input/output paths. round() rather than int(): 1/ratio is
            # a float, and int() truncates a reciprocal that lands just below
            # an integer (e.g. 11.999...) to the wrong folder label.
            input_dir = BASE_INPUT_DIR
            output_dir = os.path.join(BASE_OUTPUT_DIR, f"Sentry_Data_snr{snr}_ratio{round(1 / ratio)}")

            print(f"Input: {input_dir}")
            print(f"Output: {output_dir}")

            # Process
            process_folder(input_dir, output_dir, model, IMAGE_SIZE, DEVICE)

            # Drop this model before the next one is built so two models are
            # never alive on the device at the same time.
            del model

    print("\nAll configurations processed.")


if __name__ == "__main__":
    main()


Starting batch inference...

🔧 SNR: 19.0, Ratio: 0.1667
Found checkpoint: ./out/checkpoints/IMAGENET_8_19.0_0.17_AWGN_10_21h24m09s_on_Jul_18_2025/epoch_502.pth
Loading model with inner channel c=8
Input: /home/MATLAB_DATA/TiNguyen/Sentry_Data
Output: /home/MATLAB_DATA/TiNguyen/SentryJSCC/AWGN10time/Sentry_Data_snr19.0_ratio6


  model.load_state_dict(torch.load(checkpoint_path, map_location=device))


Processed: test/AnnualCrop/AnnualCrop_2111.jpg
Processed: test/AnnualCrop/AnnualCrop_2112.jpg
Processed: test/AnnualCrop/AnnualCrop_2113.jpg
Processed: test/AnnualCrop/AnnualCrop_2114.jpg
Processed: test/AnnualCrop/AnnualCrop_2115.jpg
Processed: test/AnnualCrop/AnnualCrop_2116.jpg
Processed: test/AnnualCrop/AnnualCrop_2117.jpg
Processed: test/AnnualCrop/AnnualCrop_2118.jpg
Processed: test/AnnualCrop/AnnualCrop_2119.jpg
Processed: test/AnnualCrop/AnnualCrop_2120.jpg
Processed: test/AnnualCrop/AnnualCrop_2121.jpg
Processed: test/AnnualCrop/AnnualCrop_2122.jpg
Processed: test/AnnualCrop/AnnualCrop_2123.jpg
Processed: test/AnnualCrop/AnnualCrop_2124.jpg
Processed: test/AnnualCrop/AnnualCrop_2125.jpg
Processed: test/AnnualCrop/AnnualCrop_2126.jpg
Processed: test/AnnualCrop/AnnualCrop_2127.jpg
Processed: test/AnnualCrop/AnnualCrop_2128.jpg
Processed: test/AnnualCrop/AnnualCrop_2129.jpg
Processed: test/AnnualCrop/AnnualCrop_2130.jpg
Processed: test/AnnualCrop/AnnualCrop_2131.jpg
Processed: te



Processed: test/Industrial/Industrial_1766.jpg
Processed: test/Industrial/Industrial_1767.jpg
Processed: test/Industrial/Industrial_1768.jpg
Processed: test/Industrial/Industrial_1769.jpg
Processed: test/Industrial/Industrial_1770.jpg
Processed: test/Industrial/Industrial_1771.jpg
Processed: test/Industrial/Industrial_1772.jpg
Processed: test/Industrial/Industrial_1773.jpg
Processed: test/Industrial/Industrial_1774.jpg
Processed: test/Industrial/Industrial_1775.jpg
Processed: test/Industrial/Industrial_1776.jpg
Processed: test/Industrial/Industrial_1777.jpg
Processed: test/Industrial/Industrial_1778.jpg
Processed: test/Industrial/Industrial_1779.jpg
Processed: test/Industrial/Industrial_1780.jpg
Processed: test/Industrial/Industrial_1781.jpg
Processed: test/Industrial/Industrial_1782.jpg
Processed: test/Industrial/Industrial_1783.jpg
Processed: test/Industrial/Industrial_1784.jpg
Processed: test/Industrial/Industrial_1785.jpg
Processed: test/Industrial/Industrial_1786.jpg
Processed: te