## Imports

In [1]:
import time
import pickle

import scipy
import numpy as np
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio

import torch
import torch.nn as nn

import deepinv

from pnp_unrolling.unrolled_cdl import UnrolledCDL
from utils.measurement_tools import get_operators
from utils.tools import op_norm2
from pnp_unrolling.datasets import (
    create_imagenet_dataloader,
)


def plot_img(img, ax, title=None):
    """Display a channel-first image tensor on a matplotlib axis.

    Parameters
    ----------
    img : torch.Tensor
        Image laid out as (C, H, W). It is detached, moved to the CPU and
        clipped to [0, 1] before display.
    ax : matplotlib.axes.Axes
        Axis to draw on; its frame and ticks are hidden.
    title : str, optional
        Axis title, only set when truthy.
    """
    img = img.detach().cpu().numpy().transpose(1, 2, 0).clip(0, 1)
    if img.shape[-1] == 1:
        # Single-channel (grayscale) image: matplotlib color-maps 2-D data and
        # auto-scales it to the data range, so force a gray colormap with
        # fixed [0, 1] limits to show intensities as they are.
        ax.imshow(img[..., 0], cmap="gray", vmin=0, vmax=1)
    else:
        ax.imshow(img)
    ax.set_axis_off()
    if title:
        ax.set_title(title)


# Work on RGB images (3 channels) when True, grayscale (1 channel) otherwise.
COLOR = True
DEVICE = "cuda:0"
# Noise level used both as the min and max training noise of the unrolled denoisers.
STD_NOISE = 0.05

# The images are BSDS500, but they are loaded with the same
# create_imagenet_dataloader function, which needs dataset="imagenet".
create_dataloader = create_imagenet_dataloader
DATA_PATH = "./BSDS500/BSDS500/data/images"
DATASET = "imagenet"

  from .autonotebook import tqdm as notebook_tqdm
INFO:datasets:PyTorch version 2.5.1 available.


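## Creating and training all denoisers

Unrolled denoisers are named `<type>_<C>C_<L>L_<R>R`:
- `SD`/`AD` stands for synthesis or analysis unrolling.
- `C` is `n_components`.
- `L` is the number of layers used for training.
- `R` is the number of layers applied at inference.

When `R == L`, the trained network is used as-is. Otherwise, its first layer is repeated `R` times.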

In [27]:
def get_denoiser(model, **kwargs):

    if model == "drunet":
        nc = 3 if COLOR else 1
        net = deepinv.models.DRUNet(
            in_channels=nc,
            out_channels=nc,
            nc=[64, 128, 256, 512],
            nb=4,
            act_mode="R",
            downsample_mode="strideconv",
            upsample_mode="convtranspose",
            pretrained="download",
        )
        net = nn.DataParallel(net, device_ids=[int(DEVICE[-1])])
    elif model in ["analysis", "synthesis"]:
        unrolled_cdl = UnrolledCDL(type_unrolling=model, **kwargs)
        # Training unrolled networks
        net, *_ = unrolled_cdl.fit()
    else:
        raise ValueError(
            f"Requested denoiser {model} which is not available."
        )
    return net


# Hyper-parameters shared by every unrolled CDL denoiser.
params_model = {
    "kernel_size": 5,
    "lmbd": 1e-4,
    "color": COLOR,
    "device": DEVICE,
    "dtype": torch.float,
    "optimizer": "adam",
    "path_data": DATA_PATH,
    "max_sigma_noise": STD_NOISE,
    "min_sigma_noise": STD_NOISE,
    "mini_batch_size": 1,
    "max_batch": 10,
    "epochs": 50,
    "avg": False,
    "rescale": False,
    "fixed_noise": True,
    "D_shared": True,
    "step_size_scaling": 1.8,
    "lr": 1e-3,
    "dataset": DATASET,
}

# Grid of unrolled-denoiser configurations.
components_list = [10, 50, 100]  # values of n_components
layers_list = [1, 20]  # n_layers used for training
n_rep_list = [20, 50, 100]  # times the first trained layer is repeated at inference

# Name prefix -> UnrolledCDL ``type_unrolling``.
UNROLLING_TYPES = {"SD": "synthesis", "AD": "analysis"}


def _repeat_first_layer(source, n_rep, params):
    """Build a DENOISERS entry applying the first layer of ``source["net"]`` ``n_rep`` times.

    A fresh unrolled network is created from ``source``'s configuration, then
    its ``parameter`` and layers are replaced by those of ``source["net"]``.
    The layer object is reused, not copied, so every repetition shares the
    same weights. The returned entry carries ``params`` with ``n_layers`` set
    to ``n_rep``.
    """
    old_net = source["net"]
    net = UnrolledCDL(
        type_unrolling=source["model"],
        **{k: v for k, v in source.items() if k not in ["model", "net"]}
    ).unrolled_net
    # Sanity check: a freshly built network has one module per layer.
    assert len(net.model) == source["n_layers"]
    net.parameter = old_net.parameter
    net.model = torch.nn.ModuleList([old_net.model[0]] * n_rep)
    entry = dict(net=net, model=source["model"], **params)
    entry["n_layers"] = n_rep
    return entry


DENOISERS = {"DRUNet": dict(model="drunet")}
for denoiser_type, model_type in UNROLLING_TYPES.items():
    for components in components_list:
        for layers in layers_list:
            params = dict(params_model)
            params["n_layers"] = layers
            params["n_components"] = components
            base_name = f"{denoiser_type}_{components}C_{layers}L"

            # ----- #REPEAT = #LAYERS: the network exactly as trained -----
            name = f"{base_name}_{layers}R"
            DENOISERS[name] = {"model": model_type, **params}
            print(f"Training {base_name}...")
            DENOISERS[name]["net"] = get_denoiser(**DENOISERS[name])

            # ----- #REPEAT = 1: only the first trained layer -----
            # (when layers == 1 the trained network above is already "_1R")
            if layers > 1:
                DENOISERS[f"{base_name}_1R"] = _repeat_first_layer(
                    DENOISERS[name], 1, params
                )

            # ----- #REPEAT = N_REP: that single layer applied n_rep times -----
            for n_rep in n_rep_list:
                if n_rep == layers:
                    continue
                DENOISERS[f"{base_name}_{n_rep}R"] = _repeat_first_layer(
                    DENOISERS[f"{base_name}_1R"], n_rep, params
                )

Loading SD_10C_1L...


Epoch 1 - Average train loss: 0.13360649 - Average test loss: 0.05825477:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 50 - Average train loss: 0.05113075 - Average test loss: 0.04356655: 100%|██████████| 50/50 [00:03<00:00, 14.62it/s]


Done
Loading SD_10C_20L...


Epoch 50 - Average train loss: 0.00214265 - Average test loss: 0.00222964: 100%|██████████| 50/50 [00:14<00:00,  3.42it/s]


Done
Loading SD_50C_1L...


Epoch 50 - Average train loss: 0.06309576 - Average test loss: 0.04775674: 100%|██████████| 50/50 [00:03<00:00, 14.99it/s]


Done
Loading SD_50C_20L...


Epoch 50 - Average train loss: 0.00235013 - Average test loss: 0.00238540: 100%|██████████| 50/50 [00:15<00:00,  3.13it/s]


Done
Loading SD_100C_1L...


Epoch 50 - Average train loss: 0.04878776 - Average test loss: 0.05639055: 100%|██████████| 50/50 [00:03<00:00, 13.92it/s]


Done
Loading SD_100C_20L...


Epoch 50 - Average train loss: 0.00239302 - Average test loss: 0.00238578: 100%|██████████| 50/50 [00:24<00:00,  2.07it/s]


Done
Loading AD_10C_1L...


Epoch 50 - Average train loss: 0.00240988 - Average test loss: 0.00238601: 100%|██████████| 50/50 [00:03<00:00, 15.45it/s]


Done
Loading AD_10C_20L...


Epoch 50 - Average train loss: 0.00238264 - Average test loss: 0.00239395: 100%|██████████| 50/50 [00:16<00:00,  3.05it/s]


Done
Loading AD_50C_1L...


Epoch 50 - Average train loss: 0.00233999 - Average test loss: 0.00235471: 100%|██████████| 50/50 [00:03<00:00, 14.64it/s]


Done
Loading AD_50C_20L...


Epoch 50 - Average train loss: 0.00239718 - Average test loss: 0.00234325: 100%|██████████| 50/50 [00:17<00:00,  2.90it/s]


Done
Loading AD_100C_1L...


Epoch 50 - Average train loss: 0.00226619 - Average test loss: 0.00232102: 100%|██████████| 50/50 [00:03<00:00, 12.59it/s]


Done
Loading AD_100C_20L...


Epoch 50 - Average train loss: 0.00230095 - Average test loss: 0.00232986: 100%|██████████| 50/50 [00:28<00:00,  1.74it/s]

Done





In [28]:
# Show the names of all denoiser variants registered in DENOISERS.
DENOISERS.keys()

dict_keys(['DRUNet', 'SD_10C_1L_1R', 'SD_10C_1L_20R', 'SD_10C_1L_50R', 'SD_10C_1L_100R', 'SD_10C_20L_20R', 'SD_10C_20L_1R', 'SD_10C_20L_50R', 'SD_10C_20L_100R', 'SD_50C_1L_1R', 'SD_50C_1L_20R', 'SD_50C_1L_50R', 'SD_50C_1L_100R', 'SD_50C_20L_20R', 'SD_50C_20L_1R', 'SD_50C_20L_50R', 'SD_50C_20L_100R', 'SD_100C_1L_1R', 'SD_100C_1L_20R', 'SD_100C_1L_50R', 'SD_100C_1L_100R', 'SD_100C_20L_20R', 'SD_100C_20L_1R', 'SD_100C_20L_50R', 'SD_100C_20L_100R', 'AD_10C_1L_1R', 'AD_10C_1L_20R', 'AD_10C_1L_50R', 'AD_10C_1L_100R', 'AD_10C_20L_20R', 'AD_10C_20L_1R', 'AD_10C_20L_50R', 'AD_10C_20L_100R', 'AD_50C_1L_1R', 'AD_50C_1L_20R', 'AD_50C_1L_50R', 'AD_50C_1L_100R', 'AD_50C_20L_20R', 'AD_50C_20L_1R', 'AD_50C_20L_50R', 'AD_50C_20L_100R', 'AD_100C_1L_1R', 'AD_100C_1L_20R', 'AD_100C_1L_50R', 'AD_100C_1L_100R', 'AD_100C_20L_20R', 'AD_100C_20L_1R', 'AD_100C_20L_50R', 'AD_100C_20L_100R'])