In [1]:
# Make Google Drive visible inside the Colab VM so the input images stored
# under MyDrive can be read by the cells below.
from google.colab import drive

drive.mount(mountpoint='/content/drive')


Mounted at /content/drive


In [2]:
import os

# --- Input images (all live in one folder on the mounted Drive) ---
path = "/content/drive/MyDrive/Deep_Image_Prior/make_images/img2"
clean = f"{path}/clean.png"
noisy = f"{path}/noisy.png"
lr = f"{path}/lr_x4.png"
corrupted = f"{path}/corrupted.png"
mask = f"{path}/mask.png"

# --- Output locations (one sub-folder per task, trailing slash kept so that
# file names can simply be appended) ---
output = "/content/output"
output_denoise = f"{output}/denoise/"
output_superres = f"{output}/superres/"
output_inpaint = f"{output}/inpaint/"

# Create the output folders if they don't exist yet.
for _folder in (output, output_denoise, output_superres, output_inpaint):
    os.makedirs(_folder, exist_ok=True)



In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import math
import matplotlib.pyplot as plt



def compute_psnr(img_path_1, img_path_2):
    """
    Compute the PSNR (in dB) between two image files.

    Both files are opened with PIL and converted to RGB. If their sizes
    differ, the second image is resized (bicubic) to the size of the first.
    Pixel values are scaled to [0, 1] as float32 before the mean squared
    error is taken.

    Returns:
        float: PSNR in dB, or ``inf`` when the two images are identical.
    """
    reference = Image.open(img_path_1).convert("RGB")
    candidate = Image.open(img_path_2).convert("RGB")

    # Bring the second image onto the first image's pixel grid when needed.
    if reference.size != candidate.size:
        candidate = candidate.resize(reference.size, Image.BICUBIC)

    def _unit_float(pil_img):
        # uint8 [0, 255] -> float32 [0, 1]
        return np.asarray(pil_img).astype(np.float32) / 255.0

    diff = _unit_float(reference) - _unit_float(candidate)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")

    peak = 1.0  # images were normalized to [0, 1]
    return 10 * math.log10((peak ** 2) / mse)



# Baseline quality of each degraded input, measured against the clean image
# before any DIP optimisation is run.
PSNR_noisy = compute_psnr(clean, noisy)
PSNR_lowres = compute_psnr(clean, lr)
PSNR_inpaint = compute_psnr(clean, corrupted)

print("Before :")

for _label, _value in (
    ("Noisy Image PSNR : ", PSNR_noisy),
    ("Low-res PSNR : ", PSNR_lowres),
    ("Inpainting PSNR (before): ", PSNR_inpaint),
):
    print(_label, _value)


def psnr_torch(x, y, max_val=1.0):
    """
    PSNR (in dB) between two tensors.

    Args:
        x, y: tensors with identical shapes (an AssertionError is raised
            otherwise); values are expected to lie in [0, max_val].
        max_val: peak signal value used in the PSNR formula.

    Returns:
        float: PSNR in dB, or ``inf`` when the tensors are equal.
    """
    assert x.shape == y.shape, f"Shape mismatch: {x.shape} vs {y.shape}"

    mse = (x - y).pow(2).mean()
    if mse.item() == 0:
        return float("inf")

    peak_power = torch.tensor(max_val**2)
    ratio_log = torch.log10(peak_power / mse)
    return 10 * ratio_log.item()


# ============================================================
# 1. Network: small U-Net-style generator (Deep Image Prior)
# ============================================================

class ConvBlock(nn.Module):
    """
    Two stacked (3x3 conv -> batch norm -> LeakyReLU(0.2)) stages.

    The first stage maps ``in_ch`` to ``out_ch`` channels, the second keeps
    ``out_ch``. Padding of 1 keeps the spatial size unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Same layer order as a hand-written Sequential: conv, BN, activation
        # for the first stage, then again for the second stage.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class DIPUNet(nn.Module):
    """
    Small U-Net-style generator used as the Deep Image Prior.

    Three encoder blocks (average-pooled by 2 between them) and two decoder
    blocks with skip connections; a final 3x3 conv + sigmoid maps to
    ``out_ch`` channels with values in [0, 1].
    """

    def __init__(self, in_ch=32, out_ch=3, base_ch=64):
        super().__init__()
        width1, width2, width3 = base_ch, base_ch * 2, base_ch * 4

        # Encoder path
        self.enc1 = ConvBlock(in_ch, width1)
        self.enc2 = ConvBlock(width1, width2)
        self.enc3 = ConvBlock(width2, width3)

        # Decoder path (inputs are upsampled features concatenated with skips)
        self.dec3 = ConvBlock(width3 + width2, width2)
        self.dec2 = ConvBlock(width2 + width1, width1)
        self.dec1 = nn.Sequential(
            nn.Conv2d(width1, out_ch, kernel_size=3, padding=1),
            nn.Sigmoid(),  # squashes the output into [0, 1]
        )

        self.pool = nn.AvgPool2d(2)

    @staticmethod
    def _upsample_like(features, reference):
        """Bilinearly resize ``features`` to the spatial size of ``reference``."""
        # Using the skip tensor's exact size avoids off-by-one mismatches
        # when the input height/width is not divisible by 4.
        return F.interpolate(
            features,
            size=reference.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )

    def forward(self, x):
        # Encoder
        skip_full = self.enc1(x)                      # [B, C1, H,   W]
        skip_half = self.enc2(self.pool(skip_full))   # [B, C2, H/2, W/2]
        bottleneck = self.enc3(self.pool(skip_half))  # [B, C3, H/4, W/4]

        # Decoder
        merged_half = torch.cat(
            [self._upsample_like(bottleneck, skip_half), skip_half], dim=1
        )
        decoded_half = self.dec3(merged_half)

        merged_full = torch.cat(
            [self._upsample_like(decoded_half, skip_full), skip_full], dim=1
        )
        decoded_full = self.dec2(merged_full)

        return self.dec1(decoded_full)


# =========================
# 2. Utilities
# =========================

def get_device():
    """Return "cuda" when a CUDA device is available, otherwise "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def load_image(path, imsize=None, device="cpu"):
    """
    Read an image file as an RGB tensor of shape [1, 3, H, W] in [0, 1].

    Args:
        path: image file to open.
        imsize: when given, passed to ``transforms.Resize`` (bicubic) before
            the conversion to a tensor; when None the image keeps its size.
        device: device the resulting tensor is moved to.
    """
    rgb = Image.open(path).convert("RGB")

    steps = []
    if imsize is not None:
        steps.append(transforms.Resize(imsize, Image.BICUBIC))
    steps.append(transforms.ToTensor())
    pipeline = transforms.Compose(steps)

    # Add the batch dimension, then move to the requested device.
    return pipeline(rgb).unsqueeze(0).to(device)


def load_mask(path, imsize=None, device="cpu"):
    """
    Read a mask file as a binary tensor of shape [1, 1, H, W].

    The file is converted to 8-bit grayscale, optionally resized with
    nearest-neighbour interpolation, scaled to [0, 1] and thresholded at 0.5.
    Convention: bright pixels become 1 (KNOWN pixel), dark pixels become 0
    (HOLE).
    """
    gray = Image.open(path).convert("L")

    steps = []
    if imsize is not None:
        # Nearest-neighbour keeps the mask edges hard.
        steps.append(transforms.Resize(imsize, Image.NEAREST))
    steps.append(transforms.ToTensor())
    pipeline = transforms.Compose(steps)

    soft_mask = pipeline(gray).unsqueeze(0).to(device)  # [1, 1, H, W], in [0,1]
    return (soft_mask > 0.5).float()


def save_image(tensor, path):
    """
    Write a tensor to ``path`` as an image file.

    The tensor is detached, moved to the CPU and clamped to [0, 1]; its
    leading dimension is squeezed away before the PIL conversion. The
    destination folder is created when it does not exist.
    """
    clipped = tensor.detach().cpu().clamp(0, 1)
    to_pil = transforms.ToPILImage()
    picture = to_pil(clipped.squeeze(0))

    # A bare file name has no directory part; fall back to the current folder.
    target_dir = os.path.dirname(path) or "."
    os.makedirs(target_dir, exist_ok=True)
    picture.save(path)


# ============================================================
# 3. DIP for DENOISING
#    minimize ||f_theta(z) - x_noisy||^2
# ============================================================

# Intermediate snapshots are written on logging iterations whose 1-based
# index is a multiple of this value.
_DIP_SNAPSHOT_EVERY = 500

# Weight of the previous average in the exponential moving average of the
# network outputs.
_DIP_EMA_WEIGHT = 0.99


def _check_dip_schedule(num_iter, log_every):
    """Reject schedules that would otherwise fail later with an obscure error."""
    if num_iter < 1:
        # With zero iterations there is no network output to save.
        raise ValueError(f"num_iter must be >= 1, got {num_iter}")
    if log_every < 1:
        # log_every is used as a modulus in the training loop.
        raise ValueError(f"log_every must be >= 1, got {log_every}")


def _plot_dip_curves(iters_log, losses_log, psnr_log, title, plot_path, tag):
    """Plot loss (and PSNR when available) vs iteration, save and show it."""
    fig, ax1 = plt.subplots(figsize=(8, 5))

    # Loss curve
    loss_lines = ax1.plot(iters_log, losses_log, label="Loss")
    ax1.set_xlabel("Iteration")
    ax1.set_ylabel("Loss")

    # PSNR curve on a secondary y-axis; entries are None when no clean
    # reference was supplied, in which case the curve is skipped entirely.
    psnr_lines = []
    if any(p is not None for p in psnr_log):
        psnr_values = [float("nan") if p is None else p for p in psnr_log]
        ax2 = ax1.twinx()
        psnr_lines = ax2.plot(
            iters_log, psnr_values, linestyle="--", label="PSNR (dB)"
        )
        ax2.set_ylabel("PSNR (dB)")

    # One legend for the lines of both axes
    all_lines = loss_lines + psnr_lines
    ax1.legend(all_lines, [ln.get_label() for ln in all_lines], loc="best")

    plt.title(title)
    plt.tight_layout()

    plt.savefig(plot_path, dpi=150)
    print(f"{tag} Saved loss/PSNR plot to {plot_path}")

    plt.show()
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)


def _run_dip(
    *,
    tag,
    title,
    out_path,
    out_hw,
    out_channels,
    fidelity_loss,
    num_iter,
    lr,
    input_depth,
    input_noise_std,
    log_every,
    x_clean,
    device,
):
    """
    Shared Deep Image Prior optimisation loop.

    Creates the fixed random input and a fresh DIPUNet, minimises
    ``fidelity_loss(net_output)`` with Adam, keeps an exponential moving
    average of the outputs, logs loss/PSNR, saves snapshots, the final image
    and the curves plot.

    Args:
        tag: prefix used in printed messages, e.g. "[DENOISE]".
        title: title of the curves plot.
        out_path: file the final averaged output is saved to; snapshot and
            plot names are derived from it.
        out_hw: (height, width) of the network input/output.
        out_channels: number of output channels of the network.
        fidelity_loss: callable mapping the network output to a scalar loss.
        x_clean: optional reference tensor for PSNR, or None.

    Returns:
        (out_avg, final_psnr, (iters_log, losses_log, psnr_log))
    """
    height, width = out_hw

    # Fixed random input z (created before the network, as the seed-dependent
    # results rely on this order).
    net_input = torch.rand(1, input_depth, height, width, device=device)

    net = DIPUNet(in_ch=input_depth, out_ch=out_channels).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    out_avg = None
    iters_log = []
    losses_log = []
    psnr_log = []

    base_name, ext = os.path.splitext(out_path)

    for i in range(num_iter):
        optimizer.zero_grad()

        # Jitter the fixed input a little on every step.
        perturbed_input = net_input + input_noise_std * torch.randn_like(net_input)
        out = net(perturbed_input)

        loss = fidelity_loss(out)
        loss.backward()
        optimizer.step()

        # Exponential moving average of the outputs
        latest = out.detach()
        if out_avg is None:
            out_avg = latest
        else:
            out_avg = out_avg * _DIP_EMA_WEIGHT + latest * (1 - _DIP_EMA_WEIGHT)

        # Logging & intermediate saving (first iteration is always logged)
        current_iter = i + 1
        if current_iter % log_every == 0 or i == 0:
            if current_iter % _DIP_SNAPSHOT_EVERY == 0:
                save_image(out_avg, f"{base_name}_iter_{current_iter}{ext}")

            iters_log.append(current_iter)
            losses_log.append(loss.item())
            if x_clean is not None:
                psnr_log.append(psnr_torch(out_avg, x_clean))
            else:
                psnr_log.append(None)

    # Final save
    save_image(out_avg, out_path)
    print(f"{tag} Saved final result to {out_path}")

    # Final PSNR
    final_psnr = None
    if x_clean is not None:
        final_psnr = psnr_torch(out_avg, x_clean)
        print(f"{tag} Final PSNR vs clean: {final_psnr:.2f} dB")

    _plot_dip_curves(
        iters_log,
        losses_log,
        psnr_log,
        title=title,
        plot_path=f"{base_name}_curves.png",
        tag=tag,
    )

    return out_avg, final_psnr, (iters_log, losses_log, psnr_log)


def deep_image_prior_denoise(
    noisy_img_path,
    out_path="dip_denoised.png",
    imsize=None,
    num_iter=2000,
    lr=0.01,
    input_depth=32,
    seed=42,
    show_every=200,
    input_noise_std=0.03,
    clean_img_path=None,
    log_every=50,  # interval for logging loss/PSNR
):
    """
    Denoise an image with Deep Image Prior: minimize ||f_theta(z) - x_noisy||^2.

    Args:
        noisy_img_path: path of the noisy observation.
        out_path: where the final (averaged) output is saved.
        imsize: optional size passed to the image loader.
        num_iter: number of optimisation steps (>= 1).
        lr: Adam learning rate.
        input_depth: number of channels of the random network input.
        seed: value passed to ``torch.manual_seed``.
        show_every: unused; kept so existing calls remain valid.
        input_noise_std: std of the Gaussian jitter added to the input.
        clean_img_path: optional clean reference used only for PSNR.
        log_every: logging interval in iterations (>= 1).

    Returns:
        (out_avg, final_psnr, (iters_log, losses_log, psnr_log)); final_psnr
        and the psnr_log entries are None when no clean image is given.

    Raises:
        ValueError: if ``num_iter`` or ``log_every`` is smaller than 1.
    """
    _check_dip_schedule(num_iter, log_every)

    device = get_device()
    torch.manual_seed(seed)

    # Noisy observation x0
    x0 = load_image(noisy_img_path, imsize=imsize, device=device)
    _, c, h, w = x0.shape

    # Optional clean image for PSNR, brought to the observation's size
    x_clean = None
    if clean_img_path is not None:
        x_clean = load_image(
            clean_img_path,
            imsize=(h, w) if imsize is None else imsize,
            device=device,
        )

    criterion = nn.MSELoss()

    return _run_dip(
        tag="[DENOISE]",
        title="DIP Denoising: Loss & PSNR vs Iteration",
        out_path=out_path,
        out_hw=(h, w),
        out_channels=c,
        fidelity_loss=lambda out: criterion(out, x0),
        num_iter=num_iter,
        lr=lr,
        input_depth=input_depth,
        input_noise_std=input_noise_std,
        log_every=log_every,
        x_clean=x_clean,
        device=device,
    )


# ============================================================
# 4. DIP for SUPER-RESOLUTION
#    minimize || D(f_theta(z)) - y_lr ||^2
#    where D is downsampling operator, y_lr is low-res image
# ============================================================

def deep_image_prior_superres(
    lowres_img_path,
    out_path="dip_superres.png",
    up_factor=4,
    num_iter=2000,
    lr=0.01,
    input_depth=32,
    seed=42,
    show_every=200,
    input_noise_std=0.03,
    clean_img_path=None,
    log_every=100,
):
    """
    Super-resolve an image with Deep Image Prior.

    The network produces an image ``up_factor`` times larger than the
    low-res observation; the loss compares its bilinearly downsampled
    version with the observation.

    Args:
        lowres_img_path: path of the low-resolution observation.
        out_path: where the final (averaged) high-res output is saved.
        up_factor: integer upscaling factor.
        num_iter: number of optimisation steps (>= 1).
        lr: Adam learning rate.
        input_depth: number of channels of the random network input.
        seed: value passed to ``torch.manual_seed``.
        show_every: unused; kept so existing calls remain valid.
        input_noise_std: std of the Gaussian jitter added to the input.
        clean_img_path: optional high-res reference used only for PSNR; it
            is resized to the output size.
        log_every: logging interval in iterations (>= 1).

    Returns:
        (out_avg, final_psnr, (iters_log, losses_log, psnr_log)).

    Raises:
        ValueError: if ``num_iter`` or ``log_every`` is smaller than 1.
    """
    _check_dip_schedule(num_iter, log_every)

    device = get_device()
    torch.manual_seed(seed)

    # Low-res observation y_lr
    y_lr = load_image(lowres_img_path, imsize=None, device=device)
    _, c, h_lr, w_lr = y_lr.shape
    h_hr, w_hr = h_lr * up_factor, w_lr * up_factor

    # Optional clean HR image
    x_clean = None
    if clean_img_path is not None:
        x_clean = load_image(
            clean_img_path,
            imsize=(h_hr, w_hr),
            device=device,
        )

    criterion = nn.MSELoss()

    def lowres_fidelity(out_hr):
        # NOTE(review): bilinear interpolation without antialiasing reads
        # only a small neighbourhood per low-res pixel, so for up_factor > 2
        # many high-res pixels may not influence the loss. Consider
        # antialias=True or mode="area" - confirm against how the low-res
        # image was produced before changing.
        out_lr = F.interpolate(
            out_hr,
            size=(h_lr, w_lr),
            mode="bilinear",
            align_corners=False,
        )
        return criterion(out_lr, y_lr)

    tag = f"[SUPERRES x{up_factor}]"

    return _run_dip(
        tag=tag,
        title=f"DIP Super-Res x{up_factor}: Loss & PSNR vs Iteration",
        out_path=out_path,
        out_hw=(h_hr, w_hr),
        out_channels=c,
        fidelity_loss=lowres_fidelity,
        num_iter=num_iter,
        lr=lr,
        input_depth=input_depth,
        input_noise_std=input_noise_std,
        log_every=log_every,
        x_clean=x_clean,
        device=device,
    )


# ============================================================
# 5. DIP for INPAINTING
#    minimize || M ⊙ f_theta(z) - M ⊙ y ||^2
#    M is mask of KNOWN pixels, y is corrupted image
# ============================================================
def deep_image_prior_inpaint(
    corrupted_img_path,
    mask_path,
    out_path="dip_inpainted.png",
    imsize=None,
    num_iter=2000,
    lr=0.01,
    input_depth=32,
    seed=42,
    show_every=200,
    input_noise_std=0.03,
    clean_img_path=None,
    log_every=100,
):
    """
    Inpaint an image with Deep Image Prior.

    The loss is the MSE between the masked network output and the masked
    observation, so only pixels where the mask is 1 constrain the network.

    Args:
        corrupted_img_path: path of the image with holes.
        mask_path: path of the mask (1 = known pixel, 0 = hole after
            thresholding by the mask loader).
        out_path: where the final (averaged) output is saved.
        imsize: optional size passed to the image and mask loaders.
        num_iter: number of optimisation steps (>= 1).
        lr: Adam learning rate.
        input_depth: number of channels of the random network input.
        seed: value passed to ``torch.manual_seed``.
        show_every: unused; kept so existing calls remain valid.
        input_noise_std: std of the Gaussian jitter added to the input.
        clean_img_path: optional clean reference used only for PSNR.
        log_every: logging interval in iterations (>= 1).

    Returns:
        (out_avg, final_psnr, (iters_log, losses_log, psnr_log)).

    Raises:
        ValueError: if ``num_iter`` or ``log_every`` is smaller than 1.
    """
    _check_dip_schedule(num_iter, log_every)

    device = get_device()
    torch.manual_seed(seed)

    # Observed image y (with holes) and mask M
    y = load_image(corrupted_img_path, imsize=imsize, device=device)  # [1, 3, H, W]
    M = load_mask(mask_path, imsize=imsize, device=device)            # [1, 1, H, W]

    _, c, h, w = y.shape
    M3 = M.expand(1, c, h, w)

    # Optional clean image
    x_clean = None
    if clean_img_path is not None:
        x_clean = load_image(
            clean_img_path,
            imsize=(h, w) if imsize is None else imsize,
            device=device,
        )

    criterion = nn.MSELoss()
    # The masked observation does not change between iterations.
    y_known = y * M3

    return _run_dip(
        tag="[INPAINT]",
        title="DIP Inpainting: Loss & PSNR vs Iteration",
        out_path=out_path,
        out_hw=(h, w),
        out_channels=c,
        # Fidelity on known pixels only
        fidelity_loss=lambda out: criterion(out * M3, y_known),
        num_iter=num_iter,
        lr=lr,
        input_depth=input_depth,
        input_noise_std=input_noise_std,
        log_every=log_every,
        x_clean=x_clean,
        device=device,
    )




# ============================================================
# 6. Example usage
# ============================================================
if __name__ == "__main__":
    # Usage examples for the three DIP tasks. They are left commented out
    # because each run takes a long time; uncomment the one you want to run.
    # The paths (noisy, clean, lr, corrupted, mask, output_*) come from the
    # configuration cell at the top of the notebook.

    # Example 1: DENOISING
    # out_denoise, final_psnr_denoise, (iters_denoise, losses_denoise, psnrs_denoise) = deep_image_prior_denoise(
    #                       noisy_img_path= noisy ,
    #                       clean_img_path= clean,
    #                       out_path= output_denoise + "dip_denoised.png",
    #                       num_iter=10000,
    #                       log_every=100,
    #                   )



    # # Example 2: SUPER-RESOLUTION (4x)
    # out_superres, final_psnr_superres, (iters_superres, losses_superres, psnrs_superres) = deep_image_prior_superres(
    #     lowres_img_path= lr,
    #     clean_img_path= clean,
    #     out_path= output_superres + "dip_superres_x4.png",
    #     up_factor=4,
    #     num_iter=10000,
    #     log_every=100,
    # )


    # # Example 3: INPAINTING
    # out_inpaint, final_psnr_inpaint, (iters_inpaint, losses_inpaint, psnrs_inpaint) = deep_image_prior_inpaint(
    #     corrupted_img_path= corrupted,
    #     mask_path= mask,
    #     clean_img_path= clean,
    #     out_path= output_inpaint + "dip_inpainted.png",
    #     num_iter=10000,
    #     log_every=100,
    # )

    # An `if` body made only of comments is a syntax error; this placeholder
    # keeps the cell runnable while every example is commented out.
    pass





In [None]:
# zip the output folder and download it
# Bundle everything under the output folder into output.zip and hand it to
# the browser for download.
import zipfile

with zipfile.ZipFile('output.zip', 'w') as archive:
    # Archive names are made relative to the parent of the output folder, so
    # the zip contains a top-level "output/" directory.
    archive_base = os.path.dirname(output)
    for folder, _subfolders, filenames in os.walk(output):
        for filename in filenames:
            source_path = os.path.join(folder, filename)
            archive.write(
                source_path,
                arcname=os.path.relpath(source_path, archive_base),
            )

from google.colab import files
files.download('output.zip')