### Import libs

In [None]:
from __future__ import print_function
import warnings
warnings.filterwarnings("ignore")
import matplotlib
import matplotlib.pyplot as plt

import sys
import numpy as np
from models import *
import torch
import torch.optim
import time
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from utils.denoising_utils import *
import _pickle as cPickle
import seaborn as sns

# Seaborn dark grid with a light-grey axes background for every plot below.
sns.set_style(style="darkgrid", rc={"axes.facecolor": ".9"})

# display images
def np_plot(np_matrix, title):
    """Clear the current figure and show a channel-first image in it.

    Args:
        np_matrix: image as a (C, H, W) array (anything ``np.asarray``
            accepts). A single-channel image is drawn as 2-D grayscale;
            otherwise the channel axis is moved last for ``imshow``.
        title: text shown above the image.
    """
    plt.clf()
    image = np.asarray(np_matrix)
    if image.shape[0] == 1:
        # imshow rejects (H, W, 1) arrays, so a 1-channel image is shown as (H, W).
        plt.imshow(image[0], cmap='gray', interpolation='nearest')
    else:
        plt.imshow(image.transpose(1, 2, 0), interpolation='nearest')
    plt.title(title)
    # 'off' hides ticks, tick labels and spines in one call.
    plt.axis('off')
    # Short pause so the figure refreshes inside long training loops.
    plt.pause(0.05)

# Let cuDNN pick the fastest convolution algorithms for the fixed input size.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Legacy tensor type used with .type(...) throughout the notebook.
if torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor
print(f'CUDA available: {torch.cuda.is_available()}')

### Load images

In [None]:
fname = 'data/denoising/Dataset/image_Peppers512rgb.png'
imsize = -1
sigma = 25 / 255.  # noise std: 25 levels on a 0-255 scale

# Clean image (cropped so each side is a multiple of 32) and its noisy copy.
img_pil = crop_image(get_image(fname, imsize)[0], d=32)
img_np = pil_to_np(img_pil)
img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma)

np_plot(img_np, 'Natural image')
np_plot(img_noisy_np, 'Noisy image')

### Hyper-parameters

In [None]:
# Network input / architecture
INPUT = 'noise'
pad = 'reflection'
OPT_OVER = 'net'  # optimize over the net parameters only
input_depth = 32

# Optimisation
reg_noise_std = 1. / 30.
learning_rate = 0.01
LR = learning_rate
exp_weight = 0.99
roll_back = True  # to prevent numerical issues
weight_decay = 5e-8

# Schedule
num_iter = 20000  # max iterations
burnin_iter = 7000  # burn-in iteration for SGLD
show_every = 500

# Loss and regression target
mse = torch.nn.MSELoss().type(dtype)  # loss
img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)

### SGLD 

In [None]:
# PSNR histories: per-iteration output vs. ground truth, and the running
# post-burn-in average vs. ground truth.
sgld_psnr_list, sgld_psnr_mean_list = [], []

# Accumulators for the thinned (every MCMC_iter) and the full post-burn-in sums.
sgld_mean = 0
sgld_mean_each = 0

# Backtracking state: roll back when the noisy-image PSNR suddenly collapses.
roll_back = True  # To solve the oscillation of model training
last_net = None
psrn_noisy_last = 0

# Sampling interval and Langevin noise scale.
MCMC_iter = 50
param_noise_sigma = 2

## SGLD
def add_noise(model, sigma=None, lr=None):
    """Add Gaussian noise to every 4-D (convolution) weight of *model*, in place.

    Each parameter with ``ndim == 4`` is updated as
    ``w += N(0, 1) * sigma * lr``; parameters of lower rank (biases, norms)
    are left untouched.

    Args:
        model: ``torch.nn.Module`` whose parameters are modified in place.
        sigma: noise scale; defaults to the module-level ``param_noise_sigma``.
        lr: step-size multiplier; defaults to the module-level ``learning_rate``.
    """
    if sigma is None:
        sigma = param_noise_sigma
    if lr is None:
        lr = learning_rate
    scale = sigma * lr
    # Sample directly on each parameter's device/dtype and update in place.
    # This avoids drawing on the host and copying to the GPU for every layer
    # on every iteration, and it works whichever device the model lives on.
    with torch.no_grad():
        for param in model.parameters():
            if param.dim() == 4:
                param.add_(torch.randn_like(param) * scale)

net2 = get_net(
    input_depth, 'skip', pad,
    skip_n33d=128,
    skip_n33u=128,
    skip_n11=4,
    num_scales=5,
    upsample_mode='bilinear',
).type(dtype)

## Input random noise: a fixed input z, plus a buffer reused for the
## per-step input jitter in closure_sgld.
net_input = (
    get_noise(input_depth, INPUT, (img_pil.size[1], img_pil.size[0]))
    .type(dtype)
    .detach()
)
net_input_saved = net_input.detach().clone()
noise = net_input.detach().clone()

# Iteration counter and number of averaged SGLD samples.
i = 0
sample_count = 0

def closure_sgld():
    """Run one SGLD iteration's forward/backward pass and bookkeeping.

    Jitters the fixed network input, computes the MSE to the noisy image,
    back-propagates, records PSNRs, optionally rolls back to the last good
    weights, and after ``burnin_iter`` accumulates the outputs. The caller
    performs the optimizer step and the Langevin noise.

    Returns:
        The loss tensor, or ``total_loss * 0`` when a rollback happened.
    """
    global i, net_input, sgld_mean, sample_count, psrn_noisy_last, last_net, sgld_mean_each
    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)
    out = net2(net_input)
    total_loss = mse(out, img_noisy_torch)
    total_loss.backward()
    out_np = out.detach().cpu().numpy()[0]

    psrn_noisy = compare_psnr(img_noisy_np, out_np)
    psrn_gt = compare_psnr(img_np, out_np)

    sgld_psnr_list.append(psrn_gt)

    # Backtracking: if the fit to the noisy image drops by more than 5 dB,
    # restore the last snapshot (skipped on multiples of show_every).
    if roll_back and i % show_every:
        if psrn_noisy - psrn_noisy_last < -5 and last_net is not None:
            print('Falling back to previous checkpoint.')
            for new_param, net_param in zip(last_net, net2.parameters()):
                # copy_ moves the CPU snapshot onto the parameter's own device,
                # so this also works when the network runs on the CPU.
                net_param.detach().copy_(new_param)
            return total_loss*0
        else:
            # copy=True is required: on a CPU model .cpu() returns the live
            # tensor itself, so the "snapshot" would track every update.
            last_net = [x.detach().to('cpu', copy=True) for x in net2.parameters()]
            psrn_noisy_last = psrn_noisy

    if i % show_every == 0:
        np_plot(out_np, 'Iter: %d; gt %.2f' % (i, psrn_gt))

    # Thinned samples (every MCMC_iter) form the final posterior mean.
    if i > burnin_iter and np.mod(i, MCMC_iter) == 0:
        sgld_mean += out_np
        sample_count += 1.

    if i > burnin_iter:
        # Running average of all post-burn-in outputs, for monitoring only.
        sgld_mean_each += out_np
        sgld_mean_tmp = sgld_mean_each / (i - burnin_iter)
        sgld_mean_psnr_each = compare_psnr(img_np, sgld_mean_tmp)
        sgld_psnr_mean_list.append(sgld_mean_psnr_each) # record the PSNR of avg after burn-in
        print('Iter: %d; psnr_gt %.2f; psnr_sgld %.2f' % (i, psrn_gt, sgld_mean_psnr_each))
    else:
        print('Iter: %d; psnr_gt %.2f; loss %.5f' % (i, psrn_gt, total_loss))

    if i == burnin_iter:
        print('Burn-in done, start sampling')

    i += 1
    return total_loss


  ## Optimizing 
print('Starting optimization with SGLD')
optimizer = torch.optim.Adam(net2.parameters(), lr=LR, weight_decay=weight_decay)
for j in range(num_iter):
    optimizer.zero_grad()
    closure_sgld()
    optimizer.step()
    add_noise(net2)

# Samples are only taken after burn-in, every MCMC_iter iterations.
if not sample_count:
    raise RuntimeError(
        'No SGLD samples were collected: num_iter is too small for '
        'burnin_iter=%d and MCMC_iter=%d.' % (burnin_iter, MCMC_iter))

# Posterior-mean estimate: average of the thinned post-burn-in outputs.
sgld_mean = sgld_mean / sample_count
sgld_mean_psnr = compare_psnr(img_np, sgld_mean)

# sgld_mean is already a channel-first numpy array (a sum of out_np samples).
# Plot it directly: numpy arrays have no .detach(), and [0] would keep only
# the first colour channel.
np_plot(sgld_mean, 'Iter: %d; gt %.2f' % (i, sgld_mean_psnr))

# SGLD PyTorch Lightning module

In [None]:
from nni.retiarii.evaluator.pytorch import Lightning, Trainer, LightningModule
from nni.retiarii.evaluator.pytorch.lightning import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint

from skimage.metrics import peak_signal_noise_ratio as compare_psnr

import torch
from torch.optim import Optimizer
from torch.utils.data import Dataset
from torch import optim, tensor

from typing import Any

# Single-image dataset used to drive the SGLD PyTorch Lightning module
class SingleImageDataset(Dataset):
    """Dataset of ``num_iter`` items that are all the same ``image``.

    Deep image prior fits one image, so this exists only to give a
    DataLoader something to iterate: its length sets how many steps one
    pass over the loader takes.
    """

    def __init__(self, image, num_iter):
        self.image, self.num_iter = image, num_iter

    def __len__(self):
        return self.num_iter

    def __getitem__(self, index):
        # The index is irrelevant: every item is the stored image.
        del index
        return self.image

class SGLD(LightningModule):
    """Deep-image-prior denoising trained with SGLD-style (noisy Adam) updates.

    Manual optimisation is used: each ``training_step`` runs one
    forward/backward pass of the network against the noisy image, steps Adam,
    then perturbs every 4-D (convolution) weight with Gaussian noise. After
    ``burnin_iter`` steps the network outputs are accumulated, and
    ``on_train_end`` displays their average.

    Args:
        original_np: clean reference image (numpy, channel-first), used for
            PSNR reporting and to size the network input.
        noisy_np: noisy image as a numpy array, used for the backtracking PSNR.
        noisy_torch: noisy image as a tensor; the regression target.
    """

    def __init__(self,
        original_np,
        noisy_np,
        noisy_torch
    ):
        super().__init__()
        print('CUDA available: {}'.format(torch.cuda.is_available()))
        print(f'DTYPE: {dtype}')
        # training_step drives the optimizer itself.
        self.automatic_optimization = False

        # iterators
        self.burnin_iter = 7000  # burn-in iteration for SGLD
        self.show_every = 500
        self.num_iter = 20000
        # Reset again in on_train_start; initialised here so closure_sgld
        # and on_train_end never see missing attributes.
        self.i = 0
        self.sample_count = 0

        # backtracking
        self.psrn_noisy_last = 0
        self.last_net = None
        self.roll_back = True  # To solve the oscillation of model training

        # SGLD output accumulation
        self.sgld_mean = 0
        self.sgld_mean_each = 0
        self.sgld_psnr_list = []       # psnr between sgld out and gt
        self.sgld_psnr_mean_list = []  # psnr of the running average after burn-in
        self.MCMC_iter = 50
        self.param_noise_sigma = 2

        # images
        self.img_np = original_np
        self.img_noisy_np = noisy_np
        self.img_noisy_torch = noisy_torch

        # network and its fixed random input
        self.input_depth = 32
        self.model = get_net(
                    self.input_depth,
                    'skip',
                    'reflection',
                    skip_n33d=128,
                    skip_n33u=128,
                    skip_n11=4,
                    num_scales=5,
                    upsample_mode='bilinear'
                ).type(self.dtype)
        # Size the input from the image passed in (not the notebook-global
        # img_np). Use (height, width) order, the same as the module-level
        # net_input built from (img_pil.size[1], img_pil.size[0]). The old
        # shape[-2:][1], shape[-2:][0] indexing swapped the axes.
        height, width = original_np.shape[-2:]
        self.net_input = get_noise(self.input_depth, 'noise', (height, width)).type(self.dtype).detach()
        self.net_input_saved = self.net_input.detach().clone()
        self.noise = self.net_input.detach().clone()

        # closure
        self.reg_noise_std = tensor(1./30.)
        self.criteria = torch.nn.MSELoss().type(dtype)  # loss

        # optimizer
        self.learning_rate = 0.01
        self.weight_decay = 5e-8

    ## SGLD
    def add_noise(self, net):
        """Add N(0, 1) * param_noise_sigma * learning_rate to each 4-D weight of *net*."""
        scale = self.param_noise_sigma * self.learning_rate
        # Noise is drawn on each parameter's own device, so no host/device
        # mismatch can arise, and the update happens in place.
        with torch.no_grad():
            for param in net.parameters():
                if param.dim() == 4:
                    param.add_(torch.randn_like(param) * scale)

    def forward(self, net_input_saved):
        """Run the network on a freshly jittered input (or the given one if jitter is off)."""
        if self.reg_noise_std > 0:
            self.net_input = self.net_input_saved + (self.noise.normal_() * self.reg_noise_std)
            return self.model(self.net_input)
        else:
            return self.model(net_input_saved)

    def closure_sgld(self):
        """Forward/backward pass plus PSNR logging, backtracking and sample accumulation.

        Returns:
            The loss tensor, or ``total_loss * 0`` when a rollback happened.
        """
        out = self.forward(self.net_input)
        total_loss = self.criteria(out, self.img_noisy_torch)
        total_loss.backward()
        out_np = out.detach().cpu().numpy()[0]

        psrn_noisy = compare_psnr(self.img_noisy_np, out_np)
        psrn_gt = compare_psnr(self.img_np, out_np)

        self.sgld_psnr_list.append(psrn_gt)

        # Backtracking: restore the last snapshot if the noisy-image PSNR
        # drops by more than 5 dB (skipped on multiples of show_every).
        if self.roll_back and self.i % self.show_every:
            if psrn_noisy - self.psrn_noisy_last < -5 and self.last_net is not None:
                print('Falling back to previous checkpoint.')
                for new_param, net_param in zip(self.last_net, self.model.parameters()):
                    # copy_ moves the CPU snapshot to the parameter's device.
                    net_param.detach().copy_(new_param)
                return total_loss*0
            else:
                # Snapshot this module's own network. copy=True is required:
                # on a CPU model .cpu() would alias the live weights.
                self.last_net = [x.detach().to('cpu', copy=True) for x in self.model.parameters()]
                self.psrn_noisy_last = psrn_noisy

        if self.i % self.show_every == 0:
            np_plot(out_np, 'Iter: %d; gt %.2f' % (self.i, psrn_gt))

        # Thinned samples (every MCMC_iter) form the final posterior mean.
        if self.i > self.burnin_iter and np.mod(self.i, self.MCMC_iter) == 0:
            self.sgld_mean += out_np
            self.sample_count += 1.

        if self.i > self.burnin_iter:
            # Running average of all post-burn-in outputs, for monitoring.
            self.sgld_mean_each += out_np
            sgld_mean_tmp = self.sgld_mean_each / (self.i - self.burnin_iter)
            self.sgld_mean_psnr_each = compare_psnr(self.img_np, sgld_mean_tmp)
            self.sgld_psnr_mean_list.append(self.sgld_mean_psnr_each) # record the PSNR of avg after burn-in
            print('Iter: %d; psnr_gt %.2f; psnr_sgld %.2f' % (self.i, psrn_gt, self.sgld_mean_psnr_each))
        else:
            print('Iter: %d; psnr_gt %.2f; loss %.5f' % (self.i, psrn_gt, total_loss))

        if self.i == self.burnin_iter:
            print('Burn-in done, start sampling')

        self.i += 1
        return total_loss

    def configure_optimizers(self) -> Optimizer:
        """Return the Adam optimizer used for the gradient half of the SGLD step.

        The Langevin noise is added by hand in ``training_step``. A ready-made
        SGLD optimizer exists (pysgmcmc:
        https://pysgmcmc.readthedocs.io/en/pytorch/_modules/pysgmcmc/optimizers/sgld.html),
        but adopting it would change how ``training_step`` works.
        """
        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

    def train_dataloader(self):
        """Return a loader over ``num_iter`` copies of the clean image.

        Deep image prior fits a single image; the loader only sets the number
        of steps per epoch.
        """
        dataset = SingleImageDataset(self.img_np, self.num_iter)
        return DataLoader(dataset, batch_size=1)

    def on_train_start(self) -> None:
        """Move the network, inputs and target to the training device and reset counters."""
        self.model.to(self.device)
        self.net_input = self.net_input.to(self.device)
        self.img_noisy_torch = self.img_noisy_torch.to(self.device)
        self.reg_noise_std = self.reg_noise_std.to(self.device)

        self.net_input_saved = self.net_input.clone().to(self.device)
        self.noise = self.net_input.clone().to(self.device)

        # Initialize Iterations
        self.i = 0
        self.sample_count = 0

        print('Starting optimization with SGLD')

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        """One SGLD iteration: gradient step with Adam, then Langevin noise.

        The batch is ignored; the target image is held on the module.
        """
        optimizer = self.optimizers()
        optimizer.zero_grad()
        loss = self.closure_sgld()
        optimizer.step()
        # Perturb the network this module trains, not a notebook global.
        self.add_noise(self.model)
        return loss

    def on_train_end(self) -> None:
        """Average the collected samples and display the posterior-mean image."""
        if not self.sample_count:
            print('No SGLD samples were collected (training ended before burn-in finished).')
            return
        self.sgld_mean = self.sgld_mean / self.sample_count
        # sgld_mean is a channel-first numpy array; np_plot handles it directly.
        np_plot(self.sgld_mean, 'Final after %d iterations' % (self.i))
        


def image_unpack(fname, imsize=-1, sigma=25/255):
    """Load *fname*, crop it to a multiple of 32, and add Gaussian noise.

    Args:
        fname: path to the image file.
        imsize: size argument forwarded to ``get_image``.
        sigma: noise level forwarded to ``get_noisy_image``.

    Returns:
        dict with keys ``'original_pil'``, ``'original_np'``, ``'noisy_pil'``,
        ``'noisy_np'`` and ``'noisy_torch'`` (the noisy image as a float
        tensor on the GPU when one is available).
    """
    if torch.cuda.is_available():
        tensor_type = torch.cuda.FloatTensor
    else:
        tensor_type = torch.FloatTensor

    cropped = crop_image(get_image(fname, imsize)[0], d=32)
    clean = pil_to_np(cropped)
    noisy_pil, noisy = get_noisy_image(clean, sigma)

    return dict(
        original_pil=cropped,
        original_np=clean,
        noisy_pil=noisy_pil,
        noisy_np=noisy,
        noisy_torch=np_to_torch(noisy).type(tensor_type),
    )


# choose iterations
num_iter = 20000 # max iterations

# get image
fname = 'data/denoising/Dataset/image_Peppers512rgb.png'
img_dict = image_unpack(fname)
original_pil = img_dict['original_pil']
original_np = img_dict['original_np']
noisy_pil = img_dict['noisy_pil']
noisy_np = img_dict['noisy_np']
noisy_torch = img_dict['noisy_torch']


# reference model
# NOTE(review): this network is what lightning.fit() receives below;
# presumably nni's Lightning evaluator installs it as module.model. Confirm
# against the nni version in use.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
model = get_net(
                32,
                'skip',
                'reflection',
                skip_n33d=128,
                skip_n33u=128,
                skip_n11=4,
                num_scales=5,
                upsample_mode='bilinear'
            ).type(dtype)

# Create the lightning module
module = SGLD(
        original_np=original_np,
        noisy_np=noisy_np,
        noisy_torch=noisy_torch)

# Create a PyTorch Lightning trainer. The training loader below holds a single
# item, so each epoch is one SGLD iteration and max_epochs is the iteration count.
# NOTE(review): checkpoint_callback=False disables the default checkpointer;
# the explicit ModelCheckpoint is appended afterwards.
trainer = Trainer(
            max_epochs=num_iter,
            fast_dev_run=False,
            gpus=1,
            checkpoint_callback=False
            )

# Checkpoint every 500 epochs. dirpath is deliberately left unset: Lightning
# only expands "{...}" placeholders in `filename`, not in `dirpath`, so the old
# template produced a directory literally named "{lightning_logs}". Without a
# dirpath, checkpoints go to the logger's version_<n>/checkpoints directory.
checkpoint_callback = ModelCheckpoint(
    filename='{epoch}-{step}',
    every_n_epochs=500,
    save_top_k=1,
)

# Add the checkpoint callback to trainer
trainer.callbacks.append(checkpoint_callback)

if not hasattr(trainer, 'optimizer_frequencies'):
    trainer.optimizer_frequencies = []

# Create the lighting object for evaluator
train_loader = DataLoader(SingleImageDataset(noisy_np, num_iter=1), batch_size=1)
val_loader = DataLoader(SingleImageDataset(noisy_np, num_iter=1), batch_size=1)

lightning = Lightning(lightning_module=module, trainer=trainer, train_dataloaders=train_loader, val_dataloaders=val_loader)
lightning.fit(model)

# TPU

In [None]:
import torch_xla
import torch_xla.core.xla_model as xm

# Initialize TPU
device = xm.xla_device()

MCMC_iter = 50
param_noise_sigma = 2

# Initialize these as tensors on the TPU
i = torch.tensor(0, device=device)
sample_count = torch.tensor(0.0, device=device)
sgld_mean = torch.zeros([1, *img_np.shape[1:]], dtype=torch.float32, device=device)
sgld_mean_each = torch.zeros([1, *img_np.shape[1:]], dtype=torch.float32, device=device)

## SGLD
def add_noise(model, sigma=None, lr=None):
    """Add Gaussian noise to every 4-D (convolution) weight of *model*, in place.

    Each parameter with ``ndim == 4`` is updated as
    ``w += N(0, 1) * sigma * lr``; parameters of lower rank (biases, norms)
    are left untouched.

    Args:
        model: ``torch.nn.Module`` whose parameters are modified in place.
        sigma: noise scale; defaults to the module-level ``param_noise_sigma``.
        lr: step-size multiplier; defaults to the module-level ``learning_rate``.
    """
    if sigma is None:
        sigma = param_noise_sigma
    if lr is None:
        lr = learning_rate
    scale = sigma * lr
    # Sample directly on each parameter's device/dtype and update in place.
    # This avoids drawing on the host and copying to the device for every
    # layer on every iteration, and it works whichever device the model is on.
    with torch.no_grad():
        for param in model.parameters():
            if param.dim() == 4:
                param.add_(torch.randn_like(param) * scale)

net2 = get_net(
    input_depth, 'skip', pad,
    skip_n33d=128,
    skip_n33u=128,
    skip_n11=4,
    num_scales=5,
    upsample_mode='bilinear',
).type(dtype)

## Input random noise: a fixed input z, plus a buffer reused for the
## per-step input jitter in closure_sgld.
net_input = (
    get_noise(input_depth, INPUT, (img_pil.size[1], img_pil.size[0]))
    .type(dtype)
    .detach()
)
net_input_saved = net_input.detach().clone()
noise = net_input.detach().clone()

def closure_sgld():
    """One SGLD forward/backward pass for the TPU cell, plus sample accumulation.

    Updates the module-level counter ``i`` and the accumulators
    ``sgld_mean`` / ``sample_count`` / ``sgld_mean_each``. The caller performs
    the optimizer step and the Langevin noise.

    Returns:
        The MSE loss tensor.
    """
    # Without this declaration the assignments below make these names local
    # to the function, and the first read of `i` raises UnboundLocalError.
    global i, net_input, sgld_mean, sample_count, sgld_mean_each
    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)
    out = net2(net_input)
    total_loss = mse(out, img_noisy_torch)
    total_loss.backward()
    # Keep the sample as a tensor on the accumulators' device. Adding a host
    # numpy array to an XLA tensor is not supported.
    out_sample = out.detach()[0].to(sgld_mean.device, torch.float32)

    # Out-of-place adds let the accumulator broadcast up to the output shape.
    if i > burnin_iter and i % MCMC_iter == 0:
        sgld_mean = sgld_mean + out_sample
        sample_count = sample_count + 1.

    if i > burnin_iter:
        sgld_mean_each = sgld_mean_each + out_sample

    i += 1
    return total_loss


  ## Optimizing 
print('Starting optimization with SGLD')
optimizer = torch.optim.Adam(net2.parameters(), lr=LR, weight_decay=weight_decay)
for j in range(num_iter):
    optimizer.zero_grad()
    closure_sgld()
    optimizer.step()
    add_noise(net2)


def _to_display(image):
    """Convert a channel-first image (optionally batched) into an imshow-ready array."""
    arr = np.asarray(image)
    if arr.ndim == 4:  # drop a leading batch axis
        arr = arr[0]
    # (1, H, W) -> (H, W) grayscale; (C, H, W) -> (H, W, C)
    return arr[0] if arr.shape[0] == 1 else arr.transpose(1, 2, 0)


# Posterior mean of the thinned post-burn-in samples, copied back to the host:
# skimage and matplotlib need numpy arrays, not device tensors.
sgld_mean = sgld_mean / sample_count
sgld_mean_np = torch.as_tensor(sgld_mean).detach().cpu().numpy()
sgld_mean_psnr = compare_psnr(img_np, sgld_mean_np)
print('SGLD posterior-mean PSNR: %.2f' % sgld_mean_psnr)


_, ax = plt.subplots(1, 3, figsize=(10, 5))
panels = [
    (img_np, "Original Image"),
    (sgld_mean_np, "Denoised Image"),
    (img_noisy_torch.detach().cpu().numpy(), "Noisy Image"),
]
for axis, (image, title) in zip(ax, panels):
    # cmap only applies to single-channel images; RGB is drawn as-is.
    axis.imshow(_to_display(image), cmap='gray')
    axis.set_title(title)
    axis.axis('off')

plt.tight_layout()
plt.show()

In [None]:

class MyModel:
    """Class-based SGLD state for the TPU loop (the former sketch made runnable).

    The iteration counter and sample accumulators are tensors on ``device``,
    so the loop never copies outputs back to the host. ``closure_sgld`` runs
    only the forward pass and bookkeeping; the caller does ``loss.backward()``,
    the optimizer step and the Langevin weight noise.

    Args:
        net: network being trained.
        net_input: fixed network input; a jittered copy is used each step
            when ``reg_noise_std > 0``.
        target: regression target (the noisy image tensor).
        criterion: loss, called as ``criterion(output, target)``.
        burnin_iter: iterations before samples are accumulated.
        MCMC_iter: thinning interval for ``sgld_mean`` after burn-in.
        device: device holding the counters and accumulators.
        reg_noise_std: std of the input jitter (0 disables it).
    """

    def __init__(self, net, net_input, target, criterion, burnin_iter,
                 MCMC_iter, device, reg_noise_std=0.0):
        self.net = net
        self.net_input_saved = net_input.detach().clone()
        self.target = target
        self.criterion = criterion
        self.burnin_iter = burnin_iter
        self.MCMC_iter = MCMC_iter
        self.reg_noise_std = reg_noise_std
        self.i = torch.tensor(0, device=device)
        self.sample_count = torch.tensor(0.0, device=device)
        # Scalar zeros broadcast up to the output's shape on the first sample.
        self.sgld_mean = torch.zeros((), dtype=torch.float32, device=device)
        self.sgld_mean_each = torch.zeros((), dtype=torch.float32, device=device)
        self.sgld_mean_tmp = None
        self.out = None
        self.total_loss = None

    def closure_sgld(self):
        """Run one forward pass, update the accumulators, and return the loss."""
        net_input = self.net_input_saved
        if self.reg_noise_std > 0:
            net_input = net_input + torch.randn_like(net_input) * self.reg_noise_std
        self.out = self.net(net_input)
        self.total_loss = self.criterion(self.out, self.target)
        sample = self.out.detach()[0].to(self.sgld_mean.device, torch.float32)

        if self.i > self.burnin_iter and self.i % self.MCMC_iter == 0:
            self.sgld_mean = self.sgld_mean + sample
            self.sample_count += 1.0

        if self.i > self.burnin_iter:
            self.sgld_mean_each = self.sgld_mean_each + sample
            self.sgld_mean_tmp = self.sgld_mean_each / (self.i - self.burnin_iter)

        self.i += 1
        return self.total_loss


# Initialize TPU
device = xm.xla_device()

# Initialize class-based model
model = MyModel(net2, net_input_saved, img_noisy_torch, mse,
                burnin_iter, MCMC_iter, device, reg_noise_std)

optimizer = torch.optim.Adam(model.net.parameters(), lr=LR, weight_decay=weight_decay)
for j in range(num_iter):
    optimizer.zero_grad()
    loss = model.closure_sgld()
    loss.backward()
    optimizer.step()
    # Perturb the same network the optimizer is updating.
    add_noise(model.net)