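# Tutorial on PnP-LADMM (Plug-and-Play Linearized ADMM)

This notebook loads blurred COCO images, restores them with PnP-LADMM using the DRUNet denoiser weights in `model_zoo/drunet_color.pth`, and plots the degraded input, the estimate and the ground truth side by side.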

#### Imports

In [None]:
import os
import tqdm
import torch
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from utils.utils_pnp import get_metrics, get_metrics_bis
import utils.utils_image as util
import utils.utils_sisr as sisr

from models.model_pnp_approximate_admm import PnP_approx_ADMM
from models.model_pnp_ladmm import PnP_linearized_ADMM

#### Data loading

In [9]:
from data.dataset_multiblur import Dataset
from torch.utils.data import DataLoader

# Placeholder: point this at the directory that contains `datasets/COCO/`.
ROOT = 'PATH2ROOT'

# Options handed to the multiblur Dataset; key names must match what that class reads.
opt_data = dict(
    phase="train",
    dataset_type="multiblur",
    dataroot_H=f"{ROOT}/datasets/COCO/val2014",
    dataroot_L=None,
    H_size=128,
    use_flip=True,
    use_rot=True,
    scales=[1],
    sigma=[2, 3],
    sigma_test=15,
    n_channels=3,
    dataloader_shuffle=True,
    dataloader_num_workers=16,
    dataloader_batch_size=16,
    motion_blur=True,
    coco_annotation_path=f"{ROOT}/datasets/COCO/instances_val2014.json",
)

data = Dataset(opt_data)

# One sample per step: the restoration loop below processes and plots images individually.
loader = DataLoader(data, batch_size=1)


loading annotations into memory...
Done (t=3.78s)
creating index...
index created!


#### Model creation

In [2]:
# Denoiser noise level; 20/255 presumably means 20 grey levels on a [0, 1] intensity scale.
# Kept as a single constant so `sigma_d` and `Lx` cannot drift apart when it is tuned.
SIGMA_D = 20 / 255

opt_pnp_ladmm = {
    # Fall back to CPU so the tutorial still starts on machines without a GPU.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'lamb': 1,
    'sigma_d': SIGMA_D,
    'Lx': 1 / SIGMA_D**2,  # identical to the original 1 / ((20 / 255)**2)
    'n_iter': 40,  # solver iterations per image
    'path_denoiser': 'model_zoo/drunet_color.pth',
}

# Pass the options dict defined above (previously referenced an undefined name `opt`).
pnp_ladmm = PnP_linearized_ADMM(opt_pnp_ladmm)

#### Run model

In [None]:
# Number of dataset samples to restore and display.
iter_max = 10

for i, sample in enumerate(loader, start=1):
    # Run PnP-LADMM on this sample; only the final estimate `x` is plotted here.
    pnp_ladmm.feed_data(sample)
    x, x_list, time_list = pnp_ladmm.run()

    # A fresh figure per sample: drawing into one shared implicit figure would only
    # render the last sample when the cell finishes.
    fig, axes = plt.subplots(1, 3)
    panels = ((sample['L'], 'LR'), (x, 'Estimate'), (sample['H'], 'HR'))
    for ax, (img, title) in zip(axes, panels):
        ax.imshow(util.tensor2uint(img))
        ax.axis('off')
        ax.set_title(title)
    plt.show()

    # Stop after `iter_max` samples (distinct from the solver's own `n_iter`).
    if i >= iter_max:
        break