# Advanced Restoration Pipeline Test

This notebook tests the full advanced restoration pipeline in a logical sequence:
1. **Find the Correct Deconvolution Direction:** Use `ClippedInverseFilter` to test all 5 candidate convolution directions and visually identify the most likely "correct" one for the sample image.
2. **Run PnP Restoration:** Feed the direction found in step 1 into the `PnP_Restoration` algorithm to see how it performs when the blur direction is right.
3. **Tune PnP Hyperparameters:** Sweep `rho` and `denoiser_noise_level` to find a better combination for this image.
4. **(Optional) Run DiffPIR:** As a comparison, run the `DiffPIR_Pipeline` using the same "correct" direction.


In [None]:
import itertools
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# --- Environment Setup ---
# Detect Colab by trying to import its package; any other kernel is treated as local.
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    print("Running in Google Colab. Mounting Google Drive...")
    from google.colab import drive
    drive.mount('/content/drive')
    # Project checkout on the author's Google Drive; adjust this path for other accounts.
    ROOT = Path('/content/drive/MyDrive/Data Scientist/Project/Week5/week5') 
    # IPython magics: switch the working directory to the project root, then update
    # the checkout. NOTE(review): pulling at run time means results can change between
    # runs as upstream code changes; consider pinning a commit for reproducibility.
    %cd {ROOT}
    !git pull origin main
else:
    print("Running in local environment.")
    # Locally the project root is taken to be the current working directory.
    # Presumably the notebook is launched from the directory that contains
    # `code_denoising/`. TODO confirm.
    ROOT = Path.cwd()

# Make project packages (e.g. `code_denoising`) importable from ROOT.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
print(f"Project Root set to: {ROOT}")

# --- Module Imports ---
# These project imports must come after ROOT has been added to sys.path above.
from code_denoising.classical_methods.deconvolution import ClippedInverseFilter
from code_denoising.diffusion_methods.hf_denoiser import HuggingFace_Denoiser
from code_denoising.diffusion_methods.hf_diffpir import DiffPIR_Pipeline
from code_denoising.pnp_restoration import PnP_Restoration
from diffusers import DDPMScheduler, UNet2DModel

print("Imports successful!")


In [None]:
# --- 1. Load Sample Data & Define Constants ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the RNGs up front. The diffusion-based denoiser used below may draw random
# noise, and a fixed seed keeps "Restart & Run All" reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load the sample degraded image. Look inside the project root first, then next to it:
# in the Colab layout the dataset is a sibling folder of the project.
SAMPLE_RELPATH = Path("dataset/test_y/L1_000d090392623f8046ebe84a1b345bf7.npy")
sample_candidates = [ROOT / SAMPLE_RELPATH, ROOT.parent / SAMPLE_RELPATH]
sample_path = next((candidate for candidate in sample_candidates if candidate.exists()), None)
if sample_path is None:
    # Report every location that was tried, not just the last one.
    raise FileNotFoundError(
        "Sample image not found; looked in: "
        + ", ".join(str(candidate) for candidate in sample_candidates)
    )
sample_image_np = np.load(sample_path)

# Add batch and channel axes. Assumes the stored array is a 2-D (H, W) image,
# giving (1, 1, H, W). TODO confirm.
sample_image_torch = torch.from_numpy(sample_image_np).unsqueeze(0).unsqueeze(0).float().to(DEVICE)

# The 5 known convolution directions: unit vectors at 0°, ±72° and ±144°.
B0_DIRS = [(-0.809, -0.5878), (-0.809, 0.5878), (0.309, -0.9511), (0.309, 0.9511), (1.0, 0.0)]

def plot_image(tensor, title=""):
    """Draw a single-channel image in grayscale on the current matplotlib axes.

    Args:
        tensor: A torch tensor (on any device, possibly requiring grad) or a NumPy
            array. All singleton dimensions are squeezed away, so inputs such as
            (1, 1, H, W), (1, H, W) and (H, W) all render as an (H, W) image.
        title: Title shown above the image.
    """
    if isinstance(tensor, torch.Tensor):
        # detach() lets outputs that still carry autograd history convert to NumPy.
        image = tensor.detach().squeeze().cpu().numpy()
    else:
        image = np.squeeze(np.asarray(tensor))
    plt.imshow(image, cmap='gray')
    plt.title(title)
    plt.axis('off')

# Preview the raw degraded input on its own figure before any restoration is attempted.
input_preview_fig = plt.figure()
plot_image(sample_image_torch, title="Original Degraded Image")
plt.show()


## Step 1: Find the Correct Deconvolution Direction
We run `ClippedInverseFilter` for all 5 directions. The output that looks the most structured and least like noise or artifacts most likely comes from the "correct" deconvolution kernel. The choice below is made by eye, so revisit it whenever the input image changes.


In [None]:
# Apply the clipped inverse filter once per candidate direction.
deconv_filter = ClippedInverseFilter()
# Materialize as a list so the count can be checked; zip() below would otherwise
# silently drop panels if fewer images than directions came back.
restored_images_deconv = list(deconv_filter.run_on_all_directions(sample_image_torch, B0_DIRS))
assert len(restored_images_deconv) == len(B0_DIRS), (
    f"Expected {len(B0_DIRS)} deconvolved images, got {len(restored_images_deconv)}"
)

# One panel for the input plus one per direction, sized from B0_DIRS rather than hard-coded.
n_panels = len(B0_DIRS) + 1
plt.figure(figsize=(4 * n_panels, 4))
plt.subplot(1, n_panels, 1)
plot_image(sample_image_torch, title="Input")
for dir_number, (img, b0_dir) in enumerate(zip(restored_images_deconv, B0_DIRS), start=1):
    plt.subplot(1, n_panels, dir_number + 1)
    # Show the direction vector too, so the manual choice below can be read off the plot.
    plot_image(img, title=f"Deconv Dir {dir_number}\n{b0_dir}")
plt.suptitle(f"ClippedInverseFilter Results for all {len(B0_DIRS)} Directions", fontsize=16)
plt.show()

# Manual decision from visual inspection of the panels above: Dir 2 (-0.809, 0.5878)
# looked the most promising. Re-check this whenever the sample image or B0_DIRS changes.
CORRECT_DIR_INDEX = 1  # 0-based index into B0_DIRS, i.e. "Dir 2" in the plot titles
# Guard against negative or out-of-range indices, which Python would not reject on its own.
assert 0 <= CORRECT_DIR_INDEX < len(B0_DIRS), "CORRECT_DIR_INDEX out of range"
correct_b0_dir = B0_DIRS[CORRECT_DIR_INDEX]
print(f"Visually identified correct direction: Dir {CORRECT_DIR_INDEX+1} -> {correct_b0_dir}")


## Step 2: Run PnP Restoration with the Correct Direction
Now we use the `correct_b0_dir` identified above to run the PnP algorithm. With the right blur direction, this should yield a much better result than running PnP with an arbitrary (possibly wrong) direction.


In [None]:
# Load the pre-trained denoiser, caching a copy under hf_models/ so later runs load it locally.
HF_MODEL_ID = "google/ddpm-celebahq-256"
model_save_path = ROOT / "hf_models/ddpm-celebahq-256"
if model_save_path.exists():
    print("Loading denoiser model from local path.")
    denoiser = HuggingFace_Denoiser(model_name=str(model_save_path), device=DEVICE)
else:
    # First run: download from the Hub, save a local copy, and keep using this
    # instance. Reloading the saved copy would leave a second set of weights on DEVICE.
    print("Downloading and saving denoiser model for the first time...")
    denoiser = HuggingFace_Denoiser(model_name=HF_MODEL_ID, device=DEVICE)
    denoiser.model.save_pretrained(model_save_path)
    denoiser.scheduler.save_pretrained(model_save_path)
    print("Model saved.")

pnp_restorer = PnP_Restoration(denoiser=denoiser)

# Untuned starting point; Step 3 sweeps rho and the noise level.
PNP_MAX_ITER = 10
PNP_RHO = 1.0          # key hyperparameter: balances deconvolution and denoising
PNP_NOISE_LEVEL = 50   # key hyperparameter: denoiser strength (0-255 scale)

print(f"Running PnP Restoration with direction: {correct_b0_dir}")
restored_image_pnp = pnp_restorer.run(
    degraded_image=sample_image_torch,
    B0_dir=correct_b0_dir,
    max_iter=PNP_MAX_ITER,
    rho=PNP_RHO,
    denoiser_noise_level=PNP_NOISE_LEVEL,
)

# Side-by-side comparison of the input and the PnP output.
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plot_image(sample_image_torch, title="Input")
plt.subplot(1, 2, 2)
plot_image(restored_image_pnp, title="PnP Restored Output (Corrected)")
plt.suptitle("PnP_Restoration Result", fontsize=16)
plt.show()


## Step 3: Hyperparameter Tuning for PnP
The result from Step 2 was poor, which is expected with untuned hyperparameters (`rho=1.0`, `denoiser_noise_level=50`). Next, we test a grid of `rho` and `denoiser_noise_level` values to find a better combination for this specific image.

- `rho`: Balances deconvolution and denoising.
- `denoiser_noise_level`: Tells the denoiser how aggressively to remove noise. The original noise sigmas were ~0.07–0.13 on a 0–1 intensity scale, i.e. ~18–33 on a 0–255 scale, so we test values spanning this range.


In [None]:
# Grid of hyperparameters to sweep. The noise levels bracket the estimated noise
# sigma (~18-33 on a 0-255 scale).
rho_values = [0.1, 1.0, 5.0]
noise_level_values = [15, 25, 35]
TUNING_MAX_ITER = 10  # same iteration budget as the single run in Step 2

# itertools.product is rho-major, so result k belongs in row k // len(noise_level_values)
# and column k % len(noise_level_values) of the figure below.
param_combinations = list(itertools.product(rho_values, noise_level_values))
pnp_tuning_results = []

print(f"Testing {len(param_combinations)} hyperparameter combinations...")
for rho_val, noise_level_val in param_combinations:
    print(f"  Testing rho = {rho_val}, noise_level = {noise_level_val}...")
    restored_image = pnp_restorer.run(
        degraded_image=sample_image_torch,
        B0_dir=correct_b0_dir,
        max_iter=TUNING_MAX_ITER,
        rho=rho_val,
        denoiser_noise_level=noise_level_val,
    )
    pnp_tuning_results.append({
        'rho': rho_val,
        'noise_level': noise_level_val,
        'image': restored_image,
    })
print("Tuning finished.")

# One row per rho value, one column per noise level.
num_rows, num_cols = len(rho_values), len(noise_level_values)
# squeeze=False keeps `axes` 2-D even when either sweep list has a single value.
fig, axes = plt.subplots(
    num_rows, num_cols, figsize=(num_cols * 4, num_rows * 4), squeeze=False
)
fig.suptitle('PnP Hyperparameter Tuning Results', fontsize=20)

# axes.flat walks row-major, matching the rho-major order of the results.
for ax, result in zip(axes.flat, pnp_tuning_results):
    # detach() in case the restored image still carries autograd history.
    ax.imshow(result['image'].detach().squeeze().cpu().numpy(), cmap='gray')
    ax.set_title(f"rho={result['rho']}, noise={result['noise_level']}")
    ax.axis('off')

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()
