<a href="https://colab.research.google.com/github/monikayyy/crowd-enVent-modeling/blob/master/DeBlur_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab-only: mount Google Drive at /content/drive. Later cells read and write
# under /content/drive/MyDrive/REDS (archives, loss sources, checkpoints).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install opencv-python
!pip install natsort
!pip install torchmetrics
!pip install lmdb
!pip install timm

In [None]:
import os

# Configure PyTorch's CUDA caching allocator through its environment variable.
# The same value is set again in the parameters cell further down.
os.environ.update({'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True'})

# Load TransformerPerceptualLoss module

Reference: Image Deblurring by Exploring In-depth Properties of Transformer

- changes in double_mae_model: remove qv_scale
- changes in deblurloss: add a check for dict extraction from checkpoint

In [None]:
# Fetch the TransformerPerceptualLoss sources into the current working directory.
# NOTE(review): `git clone` refuses to clone into an existing non-empty directory,
# so re-running this cell in the same session errors out.
!git clone https://github.com/erfect2020/TransformerPerceptualLoss

In [None]:
# Copy the clone to Drive; a later cell adds this Drive copy (and its models/loss/utils
# sub-folders) to sys.path.
# NOTE(review): if /content/drive/MyDrive/REDS/TransformerPerceptualLoss already exists,
# `cp -r` places the new copy inside it instead of replacing it — confirm before re-running.
!cp -r /content/TransformerPerceptualLoss /content/drive/MyDrive/REDS/

# Fetch data
- REDS dataset
  - train
  - val
  - test

In [None]:
from huggingface_hub import hf_hub_download


# Hugging Face Hub source and the archives to pull from it.
repo_id = "snah/REDS"
repo_type = "dataset"

files_to_download = [
    "train_blur.zip",
    "train_sharp.zip",
    "val_blur.zip",
    "val_sharp.zip",
    "test_blur.zip"
]

# Arguments shared by every download request.
download_kwargs = dict(
    repo_id=repo_id,
    repo_type=repo_type,
    # token=True, # Use if you logged in via notebook_login() or HF_TOKEN secret
    local_dir='.',                  # download into the current directory instead of the HF cache
    local_dir_use_symlinks=False,   # store real files rather than symlinks into the cache
)

# Best-effort loop: a failed download is reported and the remaining files are still attempted.
downloaded_files = []
for file_path in files_to_download:
    print(f"Downloading: {file_path} from {repo_id}")
    try:
        local_path = hf_hub_download(filename=file_path, **download_kwargs)
    except Exception as e:
        print(f"Error downloading {file_path}: {e}")
    else:
        downloaded_files.append(local_path)
        print(f"Downloaded to: {local_path}")

print("\nFinished download attempts.")
print("Downloaded file paths:", downloaded_files)

In [None]:
# Back up the downloaded archives to Drive; the *_ZIP_PATH constants defined in the
# parameters cell point at these Drive copies.
# NOTE(review): assumes the download cell ran with /content as the working directory,
# since it downloads into '.' while these commands read from /content.
!cp -r /content/train_blur.zip /content/drive/MyDrive/REDS/
!cp -r /content/train_sharp.zip /content/drive/MyDrive/REDS/
!cp -r /content/val_blur.zip /content/drive/MyDrive/REDS/
!cp -r /content/val_sharp.zip /content/drive/MyDrive/REDS/
!cp -r /content/test_blur.zip /content/drive/MyDrive/REDS/

# Dataset Preparation
- select pairs of blur and sharp image files from train and val
- 24000 pairs in train
- 3000 pairs in val

In [None]:
import zipfile
import random
from io import BytesIO
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import os
import glob

class PairedZipDataset(Dataset):
    """Dataset of (blur, sharp) image pairs read directly from two zip archives.

    Files are grouped by the name of their immediate parent folder; inside every
    folder present in both archives, files whose base name (without extension)
    matches are paired. Pairs are stored in ``self.paired_files`` as
    ``(blur_member, sharp_member)`` tuples of archive member names.

    Args:
        blur_zip_path: Path to the archive holding the blurred images.
        sharp_zip_path: Path to the archive holding the sharp images.
        transform: Optional callable applied separately to the blur and the sharp
            PIL image. Because it is called twice, a transform with random
            components would not be applied identically to both images.

    The archive handles are re-openable through ``_open_zips`` so that every
    DataLoader worker process can own private file handles instead of sharing
    the ones opened in the parent process.
    """

    def __init__(self, blur_zip_path, sharp_zip_path, transform=None):
        self.blur_zip_path = blur_zip_path
        self.sharp_zip_path = sharp_zip_path
        self.transform = transform
        # Define the handle attributes before anything can fail so that
        # close()/__del__ are safe even on a partially constructed object.
        self.blur_zip = None
        self.sharp_zip = None
        self._open_zips()

        # Directory entries end with '/', keep real files only.
        self.blur_files = sorted(f for f in self.blur_zip.namelist() if not f.endswith('/'))
        self.sharp_files = sorted(f for f in self.sharp_zip.namelist() if not f.endswith('/'))
        self.blur_folders = self.group_by_IMG_folder(self.blur_files)
        self.sharp_folders = self.group_by_IMG_folder(self.sharp_files)
        self.paired_files = []

        for folder, blur_imgs in self.blur_folders.items():
            sharp_imgs = self.sharp_folders.get(folder)
            # Skip folders that are missing (or empty) on either side.
            if not blur_imgs or not sharp_imgs:
                continue

            blur_map = {os.path.splitext(os.path.basename(p))[0]: p for p in blur_imgs}
            sharp_map = {os.path.splitext(os.path.basename(p))[0]: p for p in sharp_imgs}

            # Only base names present in both archives form a pair.
            for fname in sorted(blur_map.keys() & sharp_map.keys()):
                self.paired_files.append((blur_map[fname], sharp_map[fname]))

        print(f"Total paired files added: {len(self.paired_files)}")
        if not self.paired_files:
            print(f"Warning: No paired files found after matching filenames. Check filenames and structure in zip files.")

    def _open_zips(self):
        """(Re)open both archives so the calling process owns fresh file handles.

        Called from ``__init__`` and, through the DataLoader ``worker_init_fn``,
        once inside every worker process.
        """
        self.close()
        self.blur_zip = zipfile.ZipFile(self.blur_zip_path, 'r')
        self.sharp_zip = zipfile.ZipFile(self.sharp_zip_path, 'r')

    def _ensure_open(self):
        """Reopen the archives if they were closed or dropped (e.g. after unpickling)."""
        if self.blur_zip is None or self.sharp_zip is None:
            self._open_zips()

    def __getstate__(self):
        # Open ZipFile objects cannot be pickled; drop them and reopen lazily
        # in the process that receives the copy.
        state = self.__dict__.copy()
        state['blur_zip'] = None
        state['sharp_zip'] = None
        return state

    def group_by_IMG_folder(self, file_paths):
        """Return a dict mapping each immediate parent folder name to its member paths."""
        folder_dict = {}
        for file_path in file_paths:
            folder = os.path.basename(os.path.dirname(file_path))
            folder_dict.setdefault(folder, []).append(file_path)
        return folder_dict

    def __len__(self):
        """Number of (blur, sharp) pairs."""
        return len(self.paired_files)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as ``(blur, sharp)``, transformed if a transform is set."""
        blur_img, sharp_img = self.load_image_from_idx(idx)

        if self.transform:
            blur_img = self.transform(blur_img)
            sharp_img = self.transform(sharp_img)

        return blur_img, sharp_img

    def load_image_from_zip(self, zip_file, file):
        """Read archive member ``file`` from ``zip_file`` and return it as an RGB PIL image."""
        with zip_file.open(file) as fh:
            img_data = fh.read()
        # The bytes are fully read above, so the image no longer depends on the zip member.
        return Image.open(BytesIO(img_data)).convert('RGB')

    def load_image_from_idx(self, idx):
        """Return the untransformed ``(blur, sharp)`` PIL images of pair ``idx``."""
        blur_file, sharp_file = self.paired_files[idx]
        self._ensure_open()
        blur = self.load_image_from_zip(self.blur_zip, blur_file)
        sharp = self.load_image_from_zip(self.sharp_zip, sharp_file)
        return blur, sharp

    def close(self):
        """Close both archives; safe to call repeatedly or on a half-built object."""
        for attr, label in (('blur_zip', 'blur'), ('sharp_zip', 'sharp')):
            handle = getattr(self, attr, None)
            if handle is None:
                continue
            try:
                handle.close()
            except Exception as e:
                print(f"Error closing {label} zip: {e}")
            setattr(self, attr, None)

    def __del__(self):
        self.close()

# Parameters

In [None]:
import os
import torch

# ---- Data and output locations (Google Drive) ----
BLUR_ZIP_PATH = '/content/drive/MyDrive/REDS/train_blur.zip'
SHARP_ZIP_PATH = '/content/drive/MyDrive/REDS/train_sharp.zip'
VAL_BLUR_ZIP_PATH = '/content/drive/MyDrive/REDS/val_blur.zip'
VAL_SHARP_ZIP_PATH = '/content/drive/MyDrive/REDS/val_sharp.zip'
TEST_BLUR_ZIP_PATH = '/content/drive/MyDrive/REDS/test_blur.zip'
OUTPUT_DIR = '/content/drive/MyDrive/REDS/mae_vit_output'

# Pretrained MAE vision transformer weights, used by the perceptual loss for feature extraction.
PRETRAINED_WEIGHTS_PATH = '/content/drive/MyDrive/REDS/pytorch_model.bin'

# ---- Data loading ----
IMG_SIZE = 224
VAL_BATCH_SIZE = 15
BATCH_SIZE = 32
NUM_WORKERS = 1

# ---- Optimisation ----
EPOCHS = 20
# NOTE(review): LR_DECAY_EPOCHS (40) is larger than EPOCHS (20), so the
# "halve every LR_DECAY_EPOCHS epochs" schedule never triggers in a single run — confirm.
LR_DECAY_EPOCHS = 40
LR_FINAL = 2e-5
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.05

# ---- Runtime ----
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [None]:
# Make the Drive copy of TransformerPerceptualLoss importable (feature extraction and
# perceptual-loss code). The sub-folders are added too, so their modules can be
# imported by bare name.
import sys

for _tpl_path in (
    '/content/drive/MyDrive/REDS/TransformerPerceptualLoss',
    '/content/drive/MyDrive/REDS/TransformerPerceptualLoss/models',
    '/content/drive/MyDrive/REDS/TransformerPerceptualLoss/loss',
    '/content/drive/MyDrive/REDS/TransformerPerceptualLoss/utils',
):
    sys.path.append(_tpl_path)
print(sys.path)

In [None]:
from torch.utils.data import Dataset, DataLoader

# Create the output directory up front; exist_ok=True makes re-running this cell harmless.
os.makedirs(OUTPUT_DIR, exist_ok=True)


def worker_init_fn(worker_id):
    """DataLoader hook executed once inside every worker process.

    Looks up the worker's own copy of the dataset and, when that dataset exposes
    an ``_open_zips`` method, calls it. Datasets without such a method are left
    untouched. ``worker_id`` is accepted for the DataLoader API but not used.
    """
    worker_dataset = torch.utils.data.get_worker_info().dataset
    if not hasattr(worker_dataset, '_open_zips'):
        return
    worker_dataset._open_zips()

# Preprocessing shared by blur and sharp images: resize to the square input size
# configured in IMG_SIZE and convert to a float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor()
])

dataset = PairedZipDataset(
    blur_zip_path=BLUR_ZIP_PATH,
    sharp_zip_path=SHARP_ZIP_PATH,
    transform=transform
)

val_dataset = PairedZipDataset(
    blur_zip_path=VAL_BLUR_ZIP_PATH,
    sharp_zip_path=VAL_SHARP_ZIP_PATH,
    transform=transform
)

# Settings common to both loaders.
_loader_kwargs = dict(
    num_workers=NUM_WORKERS,
    pin_memory=(DEVICE == "cuda"),           # pinned host memory only helps when copying to a GPU
    worker_init_fn=worker_init_fn,
    persistent_workers=NUM_WORKERS > 0,      # keep workers alive between epochs
)

# Training: shuffled, and the last incomplete batch is dropped so every step sees BATCH_SIZE samples.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    **_loader_kwargs
)

# Validation: fixed order and no dropped samples, so the metric covers the whole
# validation set and is comparable from one epoch to the next.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=VAL_BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    **_loader_kwargs
)

print("DataLoader created.")

# Visualize blur, sharp, and restored images
Visualize a few images from the first validation batch of each epoch to monitor training progress.

In [None]:
import torchvision.transforms.functional as TF
import numpy as np

import matplotlib.pyplot as plt

def visualize_validation_batch(blur_batch, pred_batch, sharp_batch, num_images=4, title_prefix=""):
    """Plot input / prediction / ground-truth triplets from one batch.

    Args:
        blur_batch: 4D tensor (B, C, H, W) with the blurry inputs.
        pred_batch: 4D tensor (B, C, H, W) with the model outputs.
        sharp_batch: 4D tensor (B, C, H, W) with the sharp targets.
        num_images: Maximum number of rows (samples) to draw.
        title_prefix: Text used as the figure title.

    Returns:
        None. Prints a warning and returns early when an input is not a tensor,
        is not 4D, or when there is nothing to show.
    """
    batches = (blur_batch, pred_batch, sharp_batch)

    if any(not isinstance(t, torch.Tensor) for t in batches):
        print("Warning: All input batches must be PyTorch Tensors.")
        return

    if any(t.ndim != 4 for t in batches):
        print("Warning: All input batches must be 4D tensors (B, C, H, W).")
        return

    # Never index past the smallest batch.
    num_to_show = min(num_images, *(t.shape[0] for t in batches))
    if num_to_show == 0:
        print("Warning: No images to show (batch size might be 0, num_images=0, or mismatched batch sizes).")
        return

    # Take the subset, cut it from the autograd graph and clamp to the displayable [0, 1] range
    # (model outputs are not guaranteed to stay inside it).
    blur_imgs_t, pred_imgs_t, sharp_imgs_t = (
        torch.clamp(t[:num_to_show].detach(), 0, 1) for t in batches
    )

    # One row per sample, three columns; squeeze=False keeps `axes` 2D even for a single row.
    fig, axes = plt.subplots(num_to_show, 3, figsize=(12, num_to_show * 4), squeeze=False)

    columns = (
        ("Input Blurry", blur_imgs_t),
        ("Predicted Sharp", pred_imgs_t),
        ("Ground Truth Sharp", sharp_imgs_t),
    )
    for row in range(num_to_show):
        for col, (label, images) in enumerate(columns):
            ax = axes[row, col]
            # matplotlib expects (H, W, C) arrays on the CPU.
            ax.imshow(images[row].cpu().permute(1, 2, 0).numpy())
            ax.set_title(f"{label} {row+1}")
            ax.axis('off')

    plt.suptitle(title_prefix, fontsize=14, y=1.0)
    plt.tight_layout(rect=[0, 0, 1, 0.97])  # leave room at the top for the suptitle
    plt.show()


# Validation Loss
Functions that compute the validation loss after every training epoch.



- restormer
- deepdeblur

In [None]:
#Validation Logic for Restormer


from tqdm.notebook import tqdm
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
import gc

@torch.no_grad()
def validate_epoch(model, dataloader, device, criterion, visualize=False, num_images_to_show=4, epoch_num=None):
    """Run one validation pass for a model that returns a single restored image tensor.

    Args:
        model: Network called as ``model(blur_imgs)``; switched to eval mode for the
            pass and back to train mode before returning.
        dataloader: Iterable of ``(blur_imgs, sharp_imgs)`` batches.
        device: Device the batches are moved to.
        criterion: Callable returning a dict with a ``"total_loss"`` tensor.
        visualize: When true, the first batch is plotted with ``visualize_validation_batch``.
        num_images_to_show: Number of samples in that plot.
        epoch_num: Epoch label used in the progress bar and plot title.

    Returns:
        The mean of the per-batch ``"total_loss"`` values, or 0.0 for an empty dataloader.

    NOTE(review): a later cell of this notebook defines another function with the same name.
    """
    model.eval()
    epoch_label = epoch_num if epoch_num is not None else 'N/A'

    # Nothing to evaluate.
    if not dataloader:
        print("Validation dataloader is empty.")
        model.train()
        return 0.0

    running_loss = 0.0
    pending_visualization = visualize  # plot at most once per call
    progress_bar = tqdm(dataloader, desc=f"Validating Epoch {epoch_label}", leave=False)

    for batch_idx, (blur_imgs, sharp_imgs) in enumerate(progress_bar):
        blur_imgs = blur_imgs.to(device)
        sharp_imgs = sharp_imgs.to(device)

        recover_img = model(blur_imgs)

        if pending_visualization:
            print(f"\nVisualizing Validation Batch {batch_idx} (Epoch {epoch_label})...")
            title = "Validation" if epoch_num is None else f"Validation - Epoch {epoch_num}"
            visualize_validation_batch(blur_imgs.cpu(),
                                       recover_img.cpu(),
                                       sharp_imgs.cpu(),
                                       num_images=num_images_to_show,
                                       title_prefix=title)
            pending_visualization = False

        # Uses ReconstructLoss's forward
        batch_loss = criterion(recover_img, sharp_imgs)["total_loss"].item()
        running_loss += batch_loss

        # Current batch loss and running mean over the batches seen so far.
        progress_bar.set_postfix(
            BatchLoss=f"{batch_loss:.4f}",
            AvgEpochLoss=f"{running_loss / (batch_idx + 1):.4f}"
        )

        del recover_img

    num_batches = len(dataloader)
    avg_val_loss = running_loss / num_batches if num_batches > 0 else 0.0

    # Release the last batch before handing memory back to the allocator.
    try:
        del blur_imgs, sharp_imgs
    except NameError:
        pass
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()

    model.train()

    return avg_val_loss

In [None]:
#Validation Logic for DeepDeblur


import torch
from tqdm.notebook import tqdm
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
import gc

@torch.no_grad()
def validate_epoch(model, dataloader, device, criterion, visualize=False, num_images_to_show=4, epoch_num=None):
    """Run one validation pass for single-output and multi-scale models alike.

    This definition is the one left in scope after all cells have run, so it must
    serve both training loops of the notebook:

    * If ``model(blur_imgs)`` returns a tuple/list it is unpacked as
      ``(fine_out, mid_out, coarse_out)`` (multi-scale model). The loss is the sum of
      the fine, mid and coarse ``"total_loss"`` terms; the mid and coarse outputs and
      their downsampled targets are resized to the full input resolution before
      being passed to ``criterion``.
    * Otherwise the output is treated as a single restored image and the loss is
      ``criterion(output, sharp_imgs)["total_loss"]``.

    Args:
        model: Network to evaluate; put in eval mode for the pass and back in train
            mode before returning.
        dataloader: Iterable of ``(blur_imgs, sharp_imgs)`` batches.
        device: Device the batches are moved to (string or ``torch.device``).
        criterion: Callable returning a dict with a ``"total_loss"`` tensor.
        visualize: When true, the first batch is plotted with ``visualize_validation_batch``.
        num_images_to_show: Number of samples in that plot.
        epoch_num: Epoch label used in the progress bar and plot title.

    Returns:
        The mean of the per-batch losses, or 0.0 for an empty dataloader.
    """
    # Local import: the cell defining this function does not import F itself.
    import torch.nn.functional as F

    model.eval()
    epoch_label = epoch_num if epoch_num is not None else 'N/A'

    # Handle empty dataloader
    if not dataloader:
        print("Validation dataloader is empty.")
        model.train()
        return 0.0

    def _to_size(tensor, size):
        return F.interpolate(tensor, size=size, mode='bilinear', align_corners=False)

    def _halve(tensor):
        return F.interpolate(tensor, scale_factor=0.5, mode='bilinear', align_corners=False)

    val_loss = 0.0
    pending_visualization = visualize  # plot at most once per call
    progress_bar = tqdm(dataloader, desc=f"Validating Epoch {epoch_label}", leave=False)

    for batch_idx, (blur_imgs, sharp_imgs) in enumerate(progress_bar):
        blur_imgs = blur_imgs.to(device)
        sharp_imgs = sharp_imgs.to(device)

        outputs = model(blur_imgs)

        if isinstance(outputs, (tuple, list)):
            # Multi-scale model: supervise every scale at the full input resolution.
            fine_out, mid_out, coarse_out = outputs
            full_size = tuple(sharp_imgs.shape[-2:])

            sharp_gt_half = _halve(sharp_imgs)
            sharp_gt_quarter = _halve(sharp_gt_half)

            # Uses ReconstructLoss's forward
            loss_fine = criterion(fine_out, sharp_imgs)
            loss_mid = criterion(_to_size(mid_out, full_size), _to_size(sharp_gt_half, full_size))
            loss_coarse = criterion(_to_size(coarse_out, full_size), _to_size(sharp_gt_quarter, full_size))
            # Combining the fine, mid and coarse losses
            batch_loss = loss_fine["total_loss"] + loss_mid["total_loss"] + loss_coarse["total_loss"]

            del mid_out, coarse_out, sharp_gt_half, sharp_gt_quarter, loss_fine, loss_mid, loss_coarse
        else:
            # Single-output model: the output is the restored full-resolution image.
            fine_out = outputs
            batch_loss = criterion(fine_out, sharp_imgs)["total_loss"]

        if pending_visualization:
            print(f"\nVisualizing Validation Batch {batch_idx} (Epoch {epoch_label})...")
            title = f"Validation - Epoch {epoch_num}" if epoch_num is not None else "Validation"
            visualize_validation_batch(blur_imgs.cpu(),
                                       fine_out.cpu(),
                                       sharp_imgs.cpu(),
                                       num_images=num_images_to_show,
                                       title_prefix=title)
            pending_visualization = False

        current_batch_loss = batch_loss.item()
        val_loss += current_batch_loss

        # Display current batch loss and running average epoch loss
        progress_bar.set_postfix(
            BatchLoss=f"{current_batch_loss:.4f}",
            AvgEpochLoss=f"{val_loss / (batch_idx + 1):.4f}"
        )

        del outputs, fine_out, batch_loss

    num_batches = len(dataloader)
    avg_val_loss = val_loss / num_batches if num_batches > 0 else 0.0

    # Release the last batch before handing memory back to the allocator.
    try:
        del blur_imgs, sharp_imgs
    except NameError:
        pass
    gc.collect()
    # str() makes the check work for 'cuda', 'cuda:0' and torch.device objects.
    if str(device).startswith('cuda'):
        torch.cuda.empty_cache()

    model.train()

    return avg_val_loss

# Models

## Setting up Restormer model

In [None]:
# NOTE(review): einops is presumably a dependency of the Restormer code cloned below — confirm.
!pip install einops

# Start from a clean checkout: remove a clone left in the working directory by an earlier run.
if os.path.isdir('Restormer'):
  !rm -r Restormer

# Clone Restormer
!git clone https://github.com/swz30/Restormer.git
# Changes the kernel's working directory for ALL later cells; the relative checkpoint
# paths used by the download cell and get_weights_and_parameters depend on it.
# NOTE(review): re-running this cell while already inside Restormer/ clones a nested copy.
%cd Restormer

In [None]:
import sys
# Expose the cloned repository so `basicsr.models.archs.restormer_arch` can be imported below.
sys.path.append('/content/Restormer')

In [None]:
# task = 'Real_Denoising'
task = 'Single_Image_Defocus_Deblurring'
# task = 'Motion_Deblurring'
# task = 'Deraining'

# Download the pre-trained models
if task is 'Real_Denoising':
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/real_denoising.pth -P Denoising/pretrained_models
if task is 'Single_Image_Defocus_Deblurring':
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/single_image_defocus_deblurring.pth -P Defocus_Deblurring/pretrained_models
if task is 'Motion_Deblurring':
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/motion_deblurring.pth -P Motion_Deblurring/pretrained_models
if task is 'Deraining':
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/deraining.pth -P Deraining/pretrained_models


In [None]:
#Get Pretrained Restormer

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from runpy import run_path
from skimage import img_as_ubyte
from natsort import natsorted
from glob import glob
import cv2
from tqdm import tqdm
import argparse
import numpy as np

from basicsr.models.archs.restormer_arch import Restormer

def get_weights_and_parameters(task, parameters):
    """Return the relative checkpoint path and model parameters for a Restormer task.

    Args:
        task: One of 'Motion_Deblurring', 'Single_Image_Defocus_Deblurring',
            'Deraining' or 'Real_Denoising'.
        parameters: Dict of Restormer constructor arguments. For 'Real_Denoising'
            it is updated IN PLACE with ``LayerNorm_type='BiasFree'``.

    Returns:
        ``(weights, parameters)`` where ``weights`` is
        ``<task folder>/pretrained_models/<checkpoint>.pth`` relative to the current
        working directory, and ``parameters`` is the same dict object that was passed in.

    Raises:
        ValueError: If ``task`` is not one of the supported names (previously this
            surfaced as an UnboundLocalError on ``weights``).
    """
    # task name -> (folder inside the repository, checkpoint file name)
    checkpoints = {
        'Motion_Deblurring': ('Motion_Deblurring', 'motion_deblurring.pth'),
        'Single_Image_Defocus_Deblurring': ('Defocus_Deblurring', 'single_image_defocus_deblurring.pth'),
        'Deraining': ('Deraining', 'deraining.pth'),
        'Real_Denoising': ('Denoising', 'real_denoising.pth'),
    }
    try:
        folder, filename = checkpoints[task]
    except KeyError:
        raise ValueError(
            f"Unknown task {task!r}; expected one of {sorted(checkpoints)}"
        ) from None

    if task == 'Real_Denoising':
        parameters['LayerNorm_type'] =  'BiasFree'

    weights = os.path.join(folder, 'pretrained_models', filename)
    return weights, parameters



# use pretrained restormer's weights

#model_path = '/content/Restormer/Defocus_Deblurring/pretrained_models/single_image_defocus_deblurring.pth'

# checkpoint = torch.load(model_path, map_location='cpu')
# try:
#     model.load_state_dict(checkpoint["params"], strict=True)
#     print(f"Loaded weights from {model_path} using 'params' key.")
# except KeyError:
#     try:
#         model.load_state_dict(checkpoint, strict=True)
#         print(f"Loaded weights directly from {model_path}.")
#     except Exception as e:
#         print(f"Error loading state dict: {e}")
#         print("You might need to inspect the checkpoint keys or adjust loading logic.")


## Setting up DeepDeblur model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------------------
# Residual Block
# ---------------------------------------
class ResBlock(nn.Module):
    """Residual block: ``x + conv2(relu(conv1(x)))``.

    Both convolutions are 3x3 with padding 1 and keep ``num_feats`` channels, so
    the output has exactly the shape of the input.
    """

    def __init__(self, num_feats):
        super().__init__()
        self.conv1 = nn.Conv2d(num_feats, num_feats, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_feats, num_feats, 3, padding=1)

    def forward(self, x):
        # The in-place ReLU overwrites conv1's output, not `x`, so the skip path is intact.
        residual = self.conv1(x)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        return x + residual

# ---------------------------------------
# Single-Scale Deblurring Network
# ---------------------------------------
class SingleScaleDeblurNet(nn.Module):
    """One scale of the deblurring network: head conv -> ResBlocks -> tail conv.

    Args:
        in_channels: Number of channels of the input tensor.
        num_feats: Width of the feature maps inside the network.
        num_blocks: Number of ``ResBlock`` modules in the body.

    The tail always produces 3 output channels; every convolution is 3x3 with
    padding 1, so the spatial size is preserved.
    """

    def __init__(self, in_channels, num_feats=64, num_blocks=8):
        super().__init__()
        self.head = nn.Conv2d(in_channels, num_feats, kernel_size=3, padding=1)
        res_blocks = [ResBlock(num_feats) for _ in range(num_blocks)]
        self.body = nn.Sequential(*res_blocks)
        self.tail = nn.Conv2d(num_feats, 3, kernel_size=3, padding=1)

    def forward(self, x):
        return self.tail(self.body(self.head(x)))

# ---------------------------------------
# Multi-Scale Deblurring Network (DeepDeblurMS)
# ---------------------------------------
class DeepDeblurMS(nn.Module):
    """Coarse-to-fine multi-scale deblurring network.

    The input is processed at quarter, half and full resolution. Every stage
    receives 6 channels: the blurry image at that scale concatenated with the
    upsampled output of the previous (coarser) stage. The coarsest stage has no
    previous output and is fed the quarter-resolution image twice.

    ``forward`` returns ``(fine_out, mid_out, coarse_out)`` at full, half and
    quarter resolution respectively.
    """

    def __init__(self):
        super().__init__()
        # Each stage expects concatenated inputs → 6 channels: [blur, upsampled_output]
        self.coarse_net = SingleScaleDeblurNet(in_channels=6)
        self.middle_net = SingleScaleDeblurNet(in_channels=6)
        self.fine_net = SingleScaleDeblurNet(in_channels=6)

    def forward(self, x):
        # Create image pyramid
        x_half = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False)
        x_quarter = F.interpolate(x_half, scale_factor=0.5, mode='bilinear', align_corners=False)

        # Coarse scale input: duplicate x_quarter to simulate [blur, blur]
        coarse_input = torch.cat([x_quarter, x_quarter], dim=1)
        coarse_out = self.coarse_net(coarse_input)
        # Upsample to the exact size of the next pyramid level rather than by a fixed
        # factor of 2: halving rounds down, so doubling would not match when H or W
        # is not divisible by 4 and the concatenation below would fail.
        up_coarse = F.interpolate(coarse_out, size=x_half.shape[-2:], mode='bilinear', align_corners=False)

        # Middle scale input: [blur_half, upsampled_coarse]
        mid_input = torch.cat([x_half, up_coarse], dim=1)
        mid_out = self.middle_net(mid_input)
        up_mid = F.interpolate(mid_out, size=x.shape[-2:], mode='bilinear', align_corners=False)

        # Fine scale input: [blur_full, upsampled_middle]
        fine_input = torch.cat([x, up_mid], dim=1)
        fine_out = self.fine_net(fine_input)

        return fine_out, mid_out, coarse_out



# Training

In [None]:
import cv2
from natsort import natsorted
import os
from tqdm import tqdm
from pdb import set_trace as stx
from joblib import Parallel, delayed
import multiprocessing
import torch.nn as nn
import torch.optim as optim
from functools import partial
import torch

In [None]:
# Report the runtime environment and start training from an empty CUDA cache.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    torch.cuda.empty_cache()

## Training logic for Restormer model

In [None]:
#Training Restormer on Perceptual Loss

import loss.deblur_loss as deblur_loss
from deblur_loss import ReconstructPerceptualLoss as ReconstructLoss
import importlib
from utils import pos_embed
from functools import partial
import torch
import torch.nn as nn
from timm.models.vision_transformer import PatchEmbed, Block as TimmBlock
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from utils.pos_embed import get_2d_sincos_pos_embed
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from tqdm.notebook import tqdm
import os
import time
import datetime
import gc

LOAD_PRETRAINED = True
run_validation = True

MODEL_NAME = 'Restormer'

# Restormer for deblur task

parameters = {'inp_channels':3, 'out_channels':3, 'dim':48, 'num_blocks':[4,6,6,8], 'num_refinement_blocks':4, 'heads':[1,2,4,8], 'ffn_expansion_factor':2.66, 'bias':False, 'LayerNorm_type':'WithBias', 'dual_pixel_task':False}

# NOTE(review): no Restormer checkpoint is loaded here (the loading code in the
# "Get Pretrained Restormer" cell is commented out), so the network starts from its
# default initialisation even though LOAD_PRETRAINED is True — confirm this is intended.
model = Restormer(**parameters)

# Build the perceptual loss and move model + loss backbone to the configured device.
try:
    opt = {'image_size': IMG_SIZE, 'pretrain_mae': PRETRAINED_WEIGHTS_PATH, 'device': DEVICE}
    criterion = ReconstructLoss(opt)
    model = model.to(DEVICE)
    criterion.pretrain_mae = criterion.pretrain_mae.to(DEVICE)
    print("Custom loss criterion initialized.")
except Exception as e: raise SystemExit(f"Error initializing loss: {e}")

# PRETRAINED_WEIGHTS_PATH is the MAE checkpoint used by the loss; its presence selects the lower LR.
effective_lr = 5e-5 if LOAD_PRETRAINED and os.path.exists(PRETRAINED_WEIGHTS_PATH) else LEARNING_RATE
optimizer = optim.Adam(model.parameters(), lr=effective_lr, weight_decay=WEIGHT_DECAY)
print(f"Optimizer: Adam, LR: {effective_lr:.1e}, Weight Decay: {WEIGHT_DECAY:.1e}")

def lr_lambda(epoch):
    """LR multiplier: halve every LR_DECAY_EPOCHS epochs, never going below LR_FINAL."""
    num_halvings = epoch // LR_DECAY_EPOCHS
    lr_multiplier = 0.5 ** num_halvings
    final_multiplier = LR_FINAL / effective_lr # Use the actual starting LR
    return max(lr_multiplier, final_multiplier)
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda)
print(f"LR Scheduler: Halve every {LR_DECAY_EPOCHS} epochs, min LR {LR_FINAL:.1e}")


print(f"\n--- Starting Training for {EPOCHS} Epochs ---")
start_time = time.time()
# Lower validation loss is better, so start from +inf and keep the minimum.
best_val_loss = float('inf')
best_model_saved = False
SAVE_BEST_PATH = os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_best_restormer_new.pth")

for epoch in range(EPOCHS):
    model.train()
    epoch_loss_total = 0.0
    epoch_loss_l1 = 0.0
    epoch_loss_perc = 0.0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)

    for batch_idx, (blur_imgs, sharp_imgs) in enumerate(progress_bar):

        gt_img = sharp_imgs.to(DEVICE)
        b_img = blur_imgs.to(DEVICE)
        recover_img = model(b_img)
        losses = criterion(recover_img, gt_img)

        grad_loss = losses["total_loss"]
        optimizer.zero_grad()
        grad_loss.backward()
        optimizer.step()

        epoch_loss_total += grad_loss.item()
        # Components are optional in the loss dict; a missing one is counted as 0.
        epoch_loss_l1 += losses.get('l1', torch.tensor(0.0)).item()
        epoch_loss_perc += losses.get('Perceptual', torch.tensor(0.0)).item()
        progress_bar.set_postfix(loss=f"{grad_loss.item():.4f}")

    # Read the LR that was used during this epoch, then advance the schedule.
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    avg_loss_total = epoch_loss_total / len(dataloader)
    avg_loss_l1 = epoch_loss_l1 / len(dataloader)
    avg_loss_perc = epoch_loss_perc / len(dataloader)

    print("Clearing cache before validation...")
    # Each name is listed exactly once: deleting a name twice in one `del` raises NameError.
    del gt_img, b_img, blur_imgs, sharp_imgs, recover_img, losses, grad_loss
    gc.collect()
    if DEVICE == 'cuda':
        torch.cuda.empty_cache()

    summary = (f"Epoch {epoch+1}/{EPOCHS} - LR: {current_lr:.6f} - Loss: {avg_loss_total:.4f} "
               f"(L1: {avg_loss_l1:.4f}, Perc: {avg_loss_perc:.4f})")

    # Stays None when validation is skipped so the checkpoint below can still be written.
    val_loss = None
    if run_validation and val_dataloader:
        # NOTE(review): validate_epoch is defined twice earlier in this notebook and the later
        # definition is the one in scope; make sure it accepts a single-tensor model output.
        val_loss = validate_epoch(model, val_dataloader, DEVICE, criterion, visualize=True, num_images_to_show=4, epoch_num=epoch + 1)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), SAVE_BEST_PATH)
            best_model_saved = True
            print(f"{summary} || Validation Loss: {val_loss:.4f} *** Best Model Saved ***")
        else:
            print(f"{summary} || Validation Loss: {val_loss:.4f}")
    else:
        print(f"{summary} || Validation SKIPPED")

    current_checkpoint_path = os.path.join(OUTPUT_DIR, f"checkpoint_epoch_{epoch}.pth")

    # Full training state, so a run can be resumed from any epoch.
    save_data = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': avg_loss_total,
        'val_loss': val_loss, }
    torch.save(save_data, current_checkpoint_path)

status = "Transferred weights" if LOAD_PRETRAINED and os.path.exists(PRETRAINED_WEIGHTS_PATH) else "scratch"
final_model_path = os.path.join(OUTPUT_DIR, f"{status}_customloss_restormer_epoch{EPOCHS}_final.pth")
torch.save(model.state_dict(), final_model_path)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"\n--- Training Finished ---")
print(f"Total Training Time: {total_time_str}")
print(f"Final model saved to: {final_model_path}")
if run_validation and best_model_saved:
    print(f"Best model (Validation Loss: {best_val_loss:.4f}) saved to: {SAVE_BEST_PATH}")
elif run_validation:
    print("Best model not saved (no validation loss was recorded).")
else:
    print("Best model not saved (Validation was skipped).")


if 'dataset' in locals() and hasattr(dataset, 'close'): dataset.close()
print("Done.")

## Training logic for DeepDeblur model

In [None]:
#Training DeepDeblur on Perceptual Loss

import loss.deblur_loss as deblur_loss
from deblur_loss import ReconstructPerceptualLoss as ReconstructLoss
import importlib
from utils import pos_embed
from functools import partial
import torch
import torch.nn as nn
from timm.models.vision_transformer import PatchEmbed, Block as TimmBlock
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from utils.pos_embed import get_2d_sincos_pos_embed
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from tqdm.notebook import tqdm
import os
import time
import datetime
import gc

LOAD_PRETRAINED = True
run_validation = True

model = DeepDeblurMS()

try:
    opt = {'image_size': IMG_SIZE, 'pretrain_mae': PRETRAINED_WEIGHTS_PATH, 'device': DEVICE}
    criterion = ReconstructLoss(opt)
    model = model.cuda()
    criterion.pretrain_mae = criterion.pretrain_mae.to(torch.device('cuda'))
    print("Custom loss criterion initialized.")
except Exception as e: raise SystemExit(f"Error initializing loss: {e}")

effective_lr = 5e-5 if LOAD_PRETRAINED and os.path.exists(PRETRAINED_WEIGHTS_PATH) else LEARNING_RATE
optimizer = optim.Adam(model.parameters(), lr=effective_lr, weight_decay=WEIGHT_DECAY)
print(f"Optimizer: Adam, LR: {effective_lr:.1e}, Weight Decay: {WEIGHT_DECAY:.1e}")

def lr_lambda(epoch):
    """Return the multiplier applied to the optimizer's base LR at ``epoch``.

    The multiplier is halved once for every ``LR_DECAY_EPOCHS`` completed
    epochs and never drops below ``LR_FINAL / effective_lr``, i.e. the
    resulting learning rate is floored at ``LR_FINAL``.
    """
    decayed = 0.5 ** (epoch // LR_DECAY_EPOCHS)
    # Floor expressed relative to the actual starting LR.
    floor = LR_FINAL / effective_lr
    return floor if floor > decayed else decayed
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda)
print(f"LR Scheduler: Halve every {LR_DECAY_EPOCHS} epochs, min LR {LR_FINAL:.1e}")


print(f"\n--- Starting Training for {EPOCHS} Epochs ---")
start_time = time.time()
# Best validation score seen so far. The update rule below treats HIGHER as
# better (and the end-of-run message labels it PSNR) despite the "loss" name.
# NOTE(review): confirm validate_epoch returns a higher-is-better metric.
best_val_loss = 0.0
SAVE_BEST_PATH = os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_best_deepDeblur.pth")

# Fail fast on an empty loader: the per-epoch averages divide by the batch
# count and the cleanup step deletes names that only exist after one batch.
num_batches = len(dataloader)
if num_batches == 0:
    raise ValueError("dataloader yielded no batches; nothing to train on.")

for epoch in range(EPOCHS):
    model.train()
    epoch_loss_total = 0.0
    epoch_loss_l1 = 0.0
    epoch_loss_perc = 0.0
    # Reset every epoch so the checkpoint never stores an undefined or stale
    # value when validation is skipped.
    val_loss = None
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)

    for blur_imgs, sharp_imgs in progress_bar:

        gt_img = sharp_imgs.to(DEVICE)
        b_img = blur_imgs.to(DEVICE)

        # Ground-truth pyramid: half and quarter resolution targets.
        sharp_gt_half = F.interpolate(gt_img, scale_factor=0.5, mode='bilinear', align_corners=False)
        sharp_gt_quarter = F.interpolate(sharp_gt_half, scale_factor=0.5, mode='bilinear', align_corners=False)

        fine_out, mid_out, coarse_out = model(b_img)

        # Resize the two lower scales (targets and predictions) to 224x224
        # before they are passed to the criterion.
        sharp_gt_half_224 = F.interpolate(sharp_gt_half, size=224, mode='bilinear', align_corners=False)
        sharp_gt_quarter_224 = F.interpolate(sharp_gt_quarter, size=224, mode='bilinear', align_corners=False)

        mid_out_224 = F.interpolate(mid_out, size=224, mode='bilinear', align_corners=False)
        coarse_out_224 = F.interpolate(coarse_out, size=224, mode='bilinear', align_corners=False)

        # Uses ReconstructLoss's forward; each call returns a dict of loss terms.
        loss_fine = criterion(fine_out, gt_img)
        loss_mid = criterion(mid_out_224, sharp_gt_half_224)
        loss_coarse = criterion(coarse_out_224, sharp_gt_quarter_224)

        # Backpropagate the sum of the fine, mid and coarse totals.
        grad_loss = loss_fine["total_loss"] + loss_mid["total_loss"] + loss_coarse["total_loss"]
        optimizer.zero_grad()
        grad_loss.backward()
        optimizer.step()

        # Logging: components missing from a loss dict count as zero.
        epoch_loss_total += grad_loss.item()
        epoch_loss_l1 += (
            loss_fine.get('l1', torch.tensor(0.0)).item()
            + loss_mid.get('l1', torch.tensor(0.0)).item()
            + loss_coarse.get('l1', torch.tensor(0.0)).item()
        )
        epoch_loss_perc += (
            loss_fine.get('Perceptual', torch.tensor(0.0)).item()
            + loss_mid.get('Perceptual', torch.tensor(0.0)).item()
            + loss_coarse.get('Perceptual', torch.tensor(0.0)).item()
        )
        progress_bar.set_postfix(loss=f"{grad_loss.item():.4f}")

    # Read the LR that was used for this epoch BEFORE the scheduler advances
    # it, so the summary line is labelled with the correct value.
    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    avg_loss_total = epoch_loss_total / num_batches
    avg_loss_l1 = epoch_loss_l1 / num_batches
    avg_loss_perc = epoch_loss_perc / num_batches

    print("Clearing cache before validation...")
    # Drop references to the last batch's tensors (including the loss dicts)
    # so their memory can be reclaimed before validation runs.
    del (blur_imgs, sharp_imgs, gt_img, b_img, sharp_gt_half, sharp_gt_quarter,
         fine_out, mid_out, coarse_out, sharp_gt_half_224, sharp_gt_quarter_224,
         mid_out_224, coarse_out_224, loss_fine, loss_mid, loss_coarse, grad_loss)
    gc.collect()
    # torch.device() accepts both strings and device objects, so this also
    # matches values such as 'cuda:0' or torch.device('cuda').
    if torch.device(DEVICE).type == 'cuda':
        torch.cuda.empty_cache()

    # Shared prefix of the single per-epoch summary line printed below.
    epoch_summary = (
        f"Epoch {epoch+1}/{EPOCHS} - LR: {current_lr:.6f} - Loss: {avg_loss_total:.4f} "
        f"(L1: {avg_loss_l1:.4f}, Perc: {avg_loss_perc:.4f})"
    )

    if run_validation and val_dataloader:

        # NOTE(review): epoch_num is passed as None; confirm whether
        # validate_epoch should receive the current epoch instead.
        val_loss = validate_epoch(model, val_dataloader, DEVICE, criterion, visualize=True, num_images_to_show=4, epoch_num=None)

        if val_loss > best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), SAVE_BEST_PATH)
            print(f"{epoch_summary} || Validation Loss: {val_loss:.4f} *** Best Model Saved ***")
        else:
            print(f"{epoch_summary} || Validation Loss: {val_loss:.4f}")
    else:
        print(f"{epoch_summary} || Validation SKIPPED")

    # Per-epoch checkpoint; the file name uses the zero-based epoch index.
    current_checkpoint_path = os.path.join(OUTPUT_DIR, f"checkpoint_epoch_{epoch}.pth")

    save_data = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': avg_loss_total,
        'val_loss': val_loss,  # None when validation was skipped this epoch
    }
    torch.save(save_data, current_checkpoint_path)

# Tag the final file name by whether the pretrained-weights path was usable.
# NOTE(review): this cell does not visibly load weights into the model, so
# confirm that the "Transferred weights" tag is accurate for this run.
if LOAD_PRETRAINED and os.path.exists(PRETRAINED_WEIGHTS_PATH):
    status = "Transferred weights"
else:
    status = "scratch"
final_model_path = os.path.join(OUTPUT_DIR, f"{MODEL_NAME}_{status}_customloss_deepDeblur_epoch{EPOCHS}_final.pth")
torch.save(model.state_dict(), final_model_path)

# Elapsed wall-clock time since start_time, truncated to whole seconds for display.
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("\n--- Training Finished ---")
print(f"Total Training Time: {total_time_str}")
print(f"Final model saved to: {final_model_path}")

# Report what happened to the best-model checkpoint. The skipped-validation
# case is handled first; otherwise the presence of the file on disk decides.
if not run_validation:
    print("Best model not saved (Validation was skipped).")
elif os.path.exists(SAVE_BEST_PATH):
    print(f"Best model (Validation PSNR: {best_val_loss:.4f}) saved to: {SAVE_BEST_PATH}")
else:
    print(f"Best model not saved (Validation PSNR did not improve beyond initial {best_val_loss:.4f}).")


# Close the dataset only if such a name exists here and exposes close().
if 'dataset' in locals() and hasattr(dataset, 'close'):
    dataset.close()
print("Done.")