## Cell 1: Imports and Setup

In [None]:
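# Install dependencies if you are in Colab or haven't installed them yet.
# Use %pip (not !pip) so packages are installed into the running kernel's environment.
# PyTorch and torchvision are not listed here; install them separately if missing.
# %pip install tifffile imagecodecs opencv-python matplotlib tqdm pillow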

import os
import torch
import cv2
import numpy as np
import tifffile
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import tqdm # Progress bar

# Pick the compute device once; later cells move the model and input tensors here.
# A CUDA GPU is used when PyTorch can see one, otherwise we fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"✅ Using device: {device}")

# Display settings for Jupyter.
# IPython magic: render matplotlib figures inline in the cell output (current
# Jupyter does this by default; kept here to make the intent explicit).
%matplotlib inline
# Session-wide default figure size for plots that don't pass their own figsize.
plt.rcParams['figure.figsize'] = [12, 6]

## Cell 2: Configuration

In [None]:
# --- USER CONFIGURATION ---
# All paths are relative to the notebook's working directory.
CHECKPOINT_PATH = './checkpoints/best_models/best_psnr.pt'  # Update this path!
INPUT_DIR = './real_microscope_images'                        # Update this path!
OUTPUT_DIR = './results_jupyter'                              # Where to save results

# NAFNet hyper-parameters. They are merged into the create_model() config,
# so they must be identical to the training run or the checkpoint weights
# will not fit the freshly built network.
MODEL_CONFIG = dict(
    img_channels=3,
    width=16,
    middle_blk_num=1,
    enc_blk_nums=[1, 1, 1, 14],
    dec_blk_nums=[1, 1, 1, 1],
)

# Create the results folder now; exist_ok keeps re-running this cell harmless.
os.makedirs(OUTPUT_DIR, exist_ok=True)

## Cell 3: Model Loader

In [None]:
def load_nafnet(checkpoint_path, config, device):
    """Build a NAFNet from ``config`` and load trained weights from a checkpoint.

    Args:
        checkpoint_path: Path to a file written with ``torch.save``. Accepted
            layouts: a dict holding the weights under ``'model_state_dict'`` or
            ``'model'``, or a bare state dict. Keys carrying a ``'module.'``
            prefix (as written by DataParallel / DistributedDataParallel
            wrappers) are accepted when the built model's keys lack it.
        config: Architecture kwargs; merged into ``{'name': 'NAFNet', ...}`` and
            passed to ``archs.create_model``. Must match the training config.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to.

    Returns:
        The model on ``device``, switched to eval mode.

    Raises:
        ImportError: if the local ``archs`` module cannot be imported (the
            original import error is chained as ``__cause__``).
        RuntimeError: propagated from the strict ``load_state_dict`` call when
            the weights do not fit the architecture.
    """
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    print(f"⏳ Loading model from: {checkpoint_path}")

    # The architecture lives in a local archs.py next to the notebook.
    try:
        from archs import create_model
    except ImportError as err:
        raise ImportError("Could not import 'archs'. Make sure archs.py is in the current directory!") from err

    # create_model returns a 3-tuple; only the model is needed for inference.
    # NOTE(review): assumes create_model selects the architecture via 'name' — confirm.
    model, _, _ = create_model({'name': 'NAFNet', **config},
                               local_rank=0, global_rank=0)

    # SECURITY: torch.load unpickles the file, which can run arbitrary code.
    # Only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Training scripts wrap the weights differently; accept the common layouts.
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    # Weights saved from a (Distributed)DataParallel-wrapped model have every key
    # prefixed with 'module.'. Strip it only when the whole checkpoint is prefixed
    # and the freshly built model is not, so matching checkpoints load untouched.
    prefix = 'module.'
    if (isinstance(state_dict, dict) and state_dict
            and all(isinstance(k, str) and k.startswith(prefix) for k in state_dict)
            and not any(k.startswith(prefix) for k in model.state_dict())):
        consume_prefix_in_state_dict_if_present(state_dict, prefix)

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    print("✅ Model loaded successfully!")
    return model

# Instantiate the network once; the inference cells below reuse this `model`.
model = load_nafnet(
    checkpoint_path=CHECKPOINT_PATH,
    config=MODEL_CONFIG,
    device=device,
)

## Cell 4: Inference Engine (The Core Logic)

In [None]:
def _tiff_to_uint8_rgb(img_data):
    """Convert a raw ``tifffile`` array to an (H, W, C) uint8 array for PIL.

    - uint8 data passes through unchanged.
    - bool masks map False/True to 0/255.
    - Other integer dtypes are scaled by their dtype maximum (e.g. 65535 for
      uint16); negatives in signed data are clipped to 0 instead of wrapping.
    - Float data whose maximum is <= 1 is treated as [0, 1] and scaled by 255;
      otherwise it is clipped to [0, 255]. NaNs become 0.
    - Channel-first arrays (<= 4 leading channels, last axis > 4) are moved to
      channel-last, and single-channel images are expanded to RGB.

    Raises:
        ValueError: if the result is not an (H, W, C) array with C <= 4.
    """
    dtype = img_data.dtype
    if dtype == np.uint8:
        img_np = img_data
    elif dtype == np.bool_:
        img_np = img_data.astype(np.uint8) * 255
    elif np.issubdtype(dtype, np.integer):
        # Clip before casting: out-of-range values (e.g. negatives in int16 data)
        # would otherwise wrap around and show up as bright pixels.
        max_val = np.iinfo(dtype).max
        img_np = np.clip(img_data / max_val * 255.0, 0, 255).astype(np.uint8)
    elif np.issubdtype(dtype, np.floating):
        # A plain astype(uint8) would turn [0, 1] float data into an all-black image.
        img_f = np.nan_to_num(img_data.astype(np.float64), nan=0.0)
        if img_f.size and img_f.max() <= 1.0:
            img_f = img_f * 255.0
        img_np = np.clip(img_f, 0, 255).astype(np.uint8)
    else:
        img_np = img_data.astype(np.uint8)

    # (C, H, W) -> (H, W, C); contiguous so OpenCV/PIL accept the buffer.
    if img_np.ndim == 3 and img_np.shape[0] <= 4 < img_np.shape[-1]:
        img_np = np.ascontiguousarray(np.moveaxis(img_np, 0, -1))

    # Ensure RGB for grayscale inputs.
    if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[2] == 1):
        img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)

    if img_np.ndim != 3 or img_np.shape[2] > 4:
        raise ValueError(f"Unsupported TIFF array shape {img_data.shape}")
    return img_np


def process_and_visualize(model, img_path, save_dir, device):
    """Restore one image with ``model`` and save the result as a PNG.

    Despite the name, nothing is plotted here; the caller receives what it
    needs to display a before/after comparison.

    Args:
        model: Restoration network, called as ``model(x)`` on a (1, 3, H, W)
            float tensor in [0, 1]. Assumed to return a (1, C, H, W) tensor in
            the same range (TODO confirm for the archs NAFNet).
        img_path: Input image path. ``.tif``/``.tiff`` files are read with
            tifffile and normalised to uint8; others are opened with PIL.
        save_dir: Existing directory for the ``deblurred_<stem>.png`` output.
        device: Device the input tensor is moved to.

    Returns:
        ``(img_pil, output_np, save_path)``: the RGB input as a PIL image, the
        restored image as an (H, W, C) uint8 RGB array, and the written path.

    Raises:
        ValueError: if a TIFF has an unsupported array layout.
        OSError: if OpenCV fails to write the output file.
    """
    filename = os.path.basename(img_path)

    # --- 1. Load image as RGB ---
    if img_path.lower().endswith(('.tif', '.tiff')):
        img_np = _tiff_to_uint8_rgb(tifffile.imread(img_path))
        img_pil = Image.fromarray(img_np).convert('RGB')
    else:
        # Context manager closes the file handle; convert() returns an independent copy.
        with Image.open(img_path) as src:
            img_pil = src.convert('RGB')

    # --- 2. Preprocessing: ToTensor maps uint8 HWC -> float CHW in [0, 1] ---
    # NOTE: We DO NOT use Normalize(0.5, 0.5), following NAFNet standard practice.
    img_tensor = transforms.ToTensor()(img_pil).unsqueeze(0).to(device)

    # --- 3. Inference ---
    with torch.no_grad():
        output_tensor = model(img_tensor)

    # --- 4. Post-processing ---
    # squeeze(0) drops only the batch axis; a bare squeeze() would also collapse
    # a height or width of 1 and break the transpose.
    output_np = output_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    # Round (rather than truncate) when quantising back to 8-bit.
    output_np = np.clip(output_np * 255.0, 0, 255).round().astype(np.uint8)

    # --- 5. Save result ---
    # splitext removes only the final extension, so "a.v1.tif" and "a.v2.tif"
    # no longer overwrite each other as "deblurred_a.png".
    stem = os.path.splitext(filename)[0]
    save_path = os.path.join(save_dir, f"deblurred_{stem}.png")
    # OpenCV expects BGR channel order; imwrite signals failure via its return value.
    if not cv2.imwrite(save_path, cv2.cvtColor(output_np, cv2.COLOR_RGB2BGR)):
        raise OSError(f"cv2.imwrite failed to write {save_path}")

    return img_pil, output_np, save_path

## Cell 5: Run and Display Results

In [None]:
# Collect candidate inputs. os.listdir returns entries in arbitrary order, so
# sort them to make the processing order (and the preview subset) reproducible.
image_extensions = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp')
all_images = sorted(f for f in os.listdir(INPUT_DIR) if f.lower().endswith(image_extensions))

print(f"Found {len(all_images)} images to process.")

# How many images to process for the demonstration; set to None to process all.
MAX_IMAGES = 5

for img_name in tqdm(all_images[:MAX_IMAGES]):
    img_path = os.path.join(INPUT_DIR, img_name)

    try:
        # Run inference (also writes the PNG to OUTPUT_DIR)
        original, result, save_loc = process_and_visualize(model, img_path, OUTPUT_DIR, device)

        # --- Visualization in Jupyter: input and restoration side by side ---
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))

        # Original
        axes[0].imshow(original)
        axes[0].set_title(f"Original: {img_name}")
        axes[0].axis('off')

        # Deblurred
        axes[1].imshow(result)
        axes[1].set_title("NAFNet Restoration")
        axes[1].axis('off')

        plt.tight_layout()
        plt.show()
        # Release the figure so memory does not grow with the number of images.
        plt.close(fig)
        print(f"Saved to: {save_loc}\n" + "-"*50)

    except Exception as e:
        # Best-effort batch: report the failure and continue with the next image.
        print(f"❌ Error processing {img_name}: {e}")