# Tutorial 1 - Gaussian Image Denoising with Dictionary Learning based Unrolled Networks

This tutorial is based on the following paper:

B. Tolooshams and D. Ba, "[Stable and Interpretable Unrolled Dictionary Learning](https://openreview.net/pdf?id=e3S0Bl2RO8)," *Transactions on Machine Learning Research*, 2022.

---

## Tutorial objectives 

In this notebook, we learn how to design and train autoencoders based on dictionary learning for Gaussian image denoising.

 - Build an unrolled neural network based on a dictionary learning model.
 - Load image data.
 - Train the network to denoise images.

---
## Imports and helper functions

Please execute the cell below to initialize the notebook environment.

In [1]:
# @title
import torch
import torch.nn.functional as F
import torchvision

import numpy as np
from tqdm import tqdm, trange

In [2]:
# @title
def compute_psnr(x, xhat):
    """Return the PSNR (in dB) averaged over the first axis of ``x``.

    For each index ``i`` along axis 0 the peak value is the maximum of
    ``x[i]`` and the error is the mean squared difference between ``x[i]``
    and ``xhat[i]``; the per-item score is
    ``20 * log10(peak) - 10 * log10(mse)``.

    Args:
        x: NumPy array of reference items, stacked along axis 0.
        xhat: estimates, indexable by the same positions as ``x``.

    Returns:
        The mean of the per-item PSNR values.
    """

    def item_psnr(reference, estimate):
        # Error first, then peak, combined in log scale.
        squared_error = np.mean((reference - estimate) ** 2)
        peak = np.max(reference)
        return 20 * np.log10(peak) - 10 * np.log10(squared_error)

    # Positions are driven by x.shape[0] (not zip) so that a shorter
    # ``xhat`` still fails loudly instead of being silently truncated.
    scores = [item_psnr(x[i], xhat[i]) for i in range(x.shape[0])]
    return np.mean(scores)


def test_network(data_loader, net, params, name="test"):
    """Evaluate ``net`` on ``data_loader`` and return the mean PSNR.

    The network is switched to evaluation mode and is left in that mode when
    the function returns. When ``params["noise_std"]`` is truthy, Gaussian
    noise with standard deviation ``noise_std / 255`` is added to each batch
    before it is fed to the network; PSNR is always measured against the
    clean batch. Only channel 0 of the images is scored.

    Args:
        data_loader: iterable yielding ``(x, label)`` pairs; labels are ignored.
        net: callable module returning a ``(reconstruction, code)`` pair.
        params: configuration dictionary; ``"device"`` and ``"noise_std"``
            are read here.
        name: unused; kept so existing callers that pass it keep working.

    Returns:
        The average over batches of the per-batch mean PSNR (``nan`` when the
        loader yields nothing).
    """
    net.eval()

    device = params["device"]
    noise_std = params["noise_std"]

    batch_psnr = []

    # Evaluation never back-propagates, so skip building the autograd graph;
    # this saves memory and time without changing the forward results.
    with torch.no_grad():
        for x, _ in data_loader:
            x = x.to(device)

            # forward
            if noise_std:
                # The /255 scaling presumably means noise_std is expressed on
                # a 0-255 intensity scale while x lies in [0, 1] -- confirm.
                net_input = x + noise_std / 255 * torch.randn(
                    x.shape, device=device
                )
            else:
                net_input = x

            xhat, _ = net(net_input)

            # Keep the estimate inside the same range the output is scored on.
            xhat = torch.clamp(xhat, 0, 1)

            batch_psnr.append(
                compute_psnr(
                    x[:, 0].detach().cpu().numpy(),
                    xhat[:, 0].detach().cpu().numpy(),
                )
            )

    return np.mean(np.array(batch_psnr))

### Build an unrolled network based on a convolutional dictionary

In [3]:
class CAE(torch.nn.Module):
    """Unrolled convolutional autoencoder with one shared dictionary ``W``.

    The same filters ``W`` are used for every analysis step (``conv2d``) and
    every synthesis step (``conv_transpose2d``). The threshold of the
    nonlinearity is fixed to ``lam * step``.

    Args:
        params: configuration dictionary. The keys ``device``, ``num_ch``,
            ``lam``, ``num_layers``, ``twosided``, ``num_conv``,
            ``dictionary_dim``, ``stride`` and ``step`` are read.
        W: optional initial dictionary tensor. When omitted, filters of shape
            ``(num_conv, num_ch, dictionary_dim, dictionary_dim)`` are drawn
            from a standard normal distribution, scaled to unit Frobenius
            norm over the two spatial axes and divided by ``num_ch``.
    """

    def __init__(self, params, W=None):
        super().__init__()

        # Mirror the relevant configuration entries as attributes.
        for key in (
            "device",
            "num_ch",
            "lam",
            "num_layers",
            "twosided",
            "num_conv",
            "dictionary_dim",
            "stride",
        ):
            setattr(self, key, params[key])

        self.relu = torch.nn.ReLU()

        if W is None:
            filter_shape = (
                self.num_conv,
                self.num_ch,
                self.dictionary_dim,
                self.dictionary_dim,
            )
            random_filters = torch.randn(filter_shape, device=self.device)
            W = F.normalize(random_filters, p="fro", dim=(-1, -2)) / self.num_ch

        # Assigning a Parameter registers it under the name "W".
        self.W = torch.nn.Parameter(W)
        # The step size is stored with the module state but is not trained.
        self.register_buffer("step", torch.tensor(params["step"]))

    def get_param(self, name):
        """Return the state-dict entry ``name`` without detaching it."""
        state = self.state_dict(keep_vars=True)
        return state[name]

    def normalize(self):
        """Rescale every filter slice of ``W`` to Frobenius norm ``1 / num_ch``."""
        unit_filters = F.normalize(self.W.data, p="fro", dim=(-1, -2))
        self.W.data = unit_filters / self.num_ch

    def nonlin(self, z):
        """Shrink ``z`` by ``lam * step``.

        With ``twosided`` set, magnitudes are shrunk and the sign is restored;
        otherwise values are shifted down and negatives are zeroed.
        """
        threshold = self.lam * self.step
        if not self.twosided:
            return self.relu(z - threshold)
        return self.relu(torch.abs(z) - threshold) * torch.sign(z)

    def forward(self, x):
        """Run ``num_layers`` unrolled iterations and decode the final code.

        Args:
            x: input batch, in a layout ``F.conv2d`` accepts together with
                ``self.W``.

        Returns:
            ``(xhat, zhat)``: the reconstruction and the final code.
        """
        stride = self.stride
        dictionary = self.W

        # First iteration starts from a zero code, so only the analysis
        # term remains.
        code = self.nonlin(F.conv2d(x, dictionary, stride=stride) * self.step)

        for _ in range(self.num_layers - 1):
            # Residual of the current reconstruction, mapped back to code
            # space, drives a step of size ``self.step`` before shrinking.
            residual = F.conv_transpose2d(code, dictionary, stride=stride) - x
            gradient = F.conv2d(residual, dictionary, stride=stride)
            code = self.nonlin(code - gradient * self.step)

        reconstruction = F.conv_transpose2d(code, dictionary, stride=stride)

        return reconstruction, code

### Build an unrolled network based on a convolutional dictionary with learnable bias (lambda)

In [4]:
class CAElearnbias(torch.nn.Module):
    """Unrolled convolutional autoencoder with a trainable threshold ``b``.

    A single dictionary ``W`` is shared by every analysis (``conv2d``) and
    synthesis (``conv_transpose2d``) step. The shrinkage threshold is the
    parameter ``b`` (one value per filter), initialised to ``lam * step``.

    Args:
        params: configuration dictionary. The keys ``device``, ``num_ch``,
            ``num_layers``, ``twosided``, ``num_conv``, ``dictionary_dim``,
            ``stride``, ``step`` and ``lam`` are read.
        W: optional initial dictionary tensor. When omitted, filters of shape
            ``(num_conv, num_ch, dictionary_dim, dictionary_dim)`` are drawn
            from a standard normal distribution, scaled to unit Frobenius
            norm over the two spatial axes and divided by ``num_ch``.
    """

    def __init__(self, params, W=None):
        super().__init__()

        # Mirror the relevant configuration entries as attributes.
        for key in (
            "device",
            "num_ch",
            "num_layers",
            "twosided",
            "num_conv",
            "dictionary_dim",
            "stride",
        ):
            setattr(self, key, params[key])

        self.relu = torch.nn.ReLU()

        if W is None:
            filter_shape = (
                self.num_conv,
                self.num_ch,
                self.dictionary_dim,
                self.dictionary_dim,
            )
            random_filters = torch.randn(filter_shape, device=self.device)
            W = F.normalize(random_filters, p="fro", dim=(-1, -2)) / self.num_ch

        # Assigning a Parameter registers it under the attribute's name.
        self.W = torch.nn.Parameter(W)
        # The step size is stored with the module state but is not trained.
        self.register_buffer("step", torch.tensor(params["step"]))

        # One threshold per filter, broadcast over batch and spatial axes.
        initial_threshold = (
            torch.zeros(1, self.num_conv, 1, 1, device=self.device)
            + params["lam"] * params["step"]
        )
        self.b = torch.nn.Parameter(initial_threshold)  # this is lam * step

    def get_param(self, name):
        """Return the state-dict entry ``name`` without detaching it."""
        state = self.state_dict(keep_vars=True)
        return state[name]

    def normalize(self):
        """Rescale every filter slice of ``W`` to Frobenius norm ``1 / num_ch``."""
        unit_filters = F.normalize(self.W.data, p="fro", dim=(-1, -2))
        self.W.data = unit_filters / self.num_ch

    def nonlin(self, z):
        """Shrink ``z`` by the learned threshold ``b``.

        With ``twosided`` set, magnitudes are shrunk and the sign is restored;
        otherwise values are shifted down and negatives are zeroed.
        """
        if not self.twosided:
            return self.relu(z - self.b)
        return self.relu(torch.abs(z) - self.b) * torch.sign(z)

    def forward(self, x):
        """Run ``num_layers`` unrolled iterations and decode the final code.

        Args:
            x: input batch, in a layout ``F.conv2d`` accepts together with
                ``self.W``.

        Returns:
            ``(xhat, zhat)``: the reconstruction and the final code.
        """
        stride = self.stride
        dictionary = self.W

        # First iteration starts from a zero code, so only the analysis
        # term remains.
        code = self.nonlin(F.conv2d(x, dictionary, stride=stride) * self.step)

        for _ in range(self.num_layers - 1):
            # Residual of the current reconstruction, mapped back to code
            # space, drives a step of size ``self.step`` before shrinking.
            residual = F.conv_transpose2d(code, dictionary, stride=stride) - x
            gradient = F.conv2d(residual, dictionary, stride=stride)
            code = self.nonlin(code - gradient * self.step)

        reconstruction = F.conv_transpose2d(code, dictionary, stride=stride)

        return reconstruction, code

### Build a relaxed unrolled network based on a convolutional dictionary

In [5]:
class CAElearnbiasuntied(torch.nn.Module):
    """Unrolled convolutional autoencoder with untied filters and threshold ``b``.

    Three filter banks are kept as separate parameters: ``E`` is used by every
    ``conv2d`` step, ``D`` by the ``conv_transpose2d`` inside the iterations,
    and ``W`` only by the final decoding. ``E`` and ``D`` start as copies of
    ``W``. The shrinkage threshold ``b`` (one value per filter) is trainable
    and initialised to ``lam * step``.

    Args:
        params: configuration dictionary. The keys ``device``, ``num_ch``,
            ``num_layers``, ``twosided``, ``num_conv``, ``dictionary_dim``,
            ``stride``, ``step`` and ``lam`` are read.
        W: optional initial dictionary tensor. When omitted, filters of shape
            ``(num_conv, num_ch, dictionary_dim, dictionary_dim)`` are drawn
            from a standard normal distribution, scaled to unit Frobenius
            norm over the two spatial axes and divided by ``num_ch``.
    """

    def __init__(self, params, W=None):
        super().__init__()

        # Mirror the relevant configuration entries as attributes.
        for key in (
            "device",
            "num_ch",
            "num_layers",
            "twosided",
            "num_conv",
            "dictionary_dim",
            "stride",
        ):
            setattr(self, key, params[key])

        self.relu = torch.nn.ReLU()

        if W is None:
            filter_shape = (
                self.num_conv,
                self.num_ch,
                self.dictionary_dim,
                self.dictionary_dim,
            )
            random_filters = torch.randn(filter_shape, device=self.device)
            W = F.normalize(random_filters, p="fro", dim=(-1, -2)) / self.num_ch

        # Assigning a Parameter registers it under the attribute's name.
        # Registration order (W, E, D, then b) fixes the state-dict order.
        self.W = torch.nn.Parameter(W)
        # Independent copies, so the three banks can drift apart in training.
        self.E = torch.nn.Parameter(W.clone())
        self.D = torch.nn.Parameter(W.clone())

        # The step size is stored with the module state but is not trained.
        self.register_buffer("step", torch.tensor(params["step"]))

        # One threshold per filter, broadcast over batch and spatial axes.
        initial_threshold = (
            torch.zeros(1, self.num_conv, 1, 1, device=self.device)
            + params["lam"] * params["step"]
        )
        self.b = torch.nn.Parameter(initial_threshold)  # this is lam * step

    def get_param(self, name):
        """Return the state-dict entry ``name`` without detaching it."""
        state = self.state_dict(keep_vars=True)
        return state[name]

    def normalize(self):
        """Rescale every filter slice of ``W`` to Frobenius norm ``1 / num_ch``.

        Only ``W`` is rescaled; ``E`` and ``D`` are left untouched.
        """
        unit_filters = F.normalize(self.W.data, p="fro", dim=(-1, -2))
        self.W.data = unit_filters / self.num_ch

    def nonlin(self, z):
        """Shrink ``z`` by the learned threshold ``b``.

        With ``twosided`` set, magnitudes are shrunk and the sign is restored;
        otherwise values are shifted down and negatives are zeroed.
        """
        if not self.twosided:
            return self.relu(z - self.b)
        return self.relu(torch.abs(z) - self.b) * torch.sign(z)

    def forward(self, x):
        """Run ``num_layers`` unrolled iterations and decode the final code.

        Args:
            x: input batch, in a layout ``F.conv2d`` accepts together with
                ``self.E``.

        Returns:
            ``(xhat, zhat)``: the reconstruction (decoded with ``W``) and the
            final code.
        """
        stride = self.stride
        encoder, decoder = self.E, self.D

        # First iteration starts from a zero code, so only the analysis
        # term remains.
        code = self.nonlin(F.conv2d(x, encoder, stride=stride) * self.step)

        for _ in range(self.num_layers - 1):
            # Residual of the current reconstruction (via D), mapped back to
            # code space (via E), drives a step of size ``self.step``.
            residual = F.conv_transpose2d(code, decoder, stride=stride) - x
            gradient = F.conv2d(residual, encoder, stride=stride)
            code = self.nonlin(code - gradient * self.step)

        # The output uses W, not the D that is applied inside the loop.
        reconstruction = F.conv_transpose2d(code, self.W, stride=stride)

        return reconstruction, code

### Setup network and training parameters

In [6]:
# Use the first CUDA device when PyTorch can see one, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

params = {
    # which model class to build: "CAE", "CAElearnbias" or "CAElearnbiasuntied"
    "network": "CAE",
    # training data loader settings
    "shuffle": True,
    "batch_size": 8,
    "num_workers": 4,
    # related to the Network
    "num_ch": 1,  # input channels of each filter
    "num_conv": 64,  # number of filters in the dictionary
    "dictionary_dim": 9,  # spatial side of each (square) filter
    "stride": 4,  # stride of every conv / transposed conv
    # side of the random crop; (patch_size - dictionary_dim) is a multiple of
    # stride here, so the transposed conv returns the input size exactly
    "patch_size": 129,
    "lam": 0.12,  # threshold is lam * step
    "step": 0.1,  # step size of each unrolled iteration
    "num_layers": 15,  # number of unrolled iterations
    "twosided": False,  # True: shrink magnitudes and keep the sign
    # related to the optimizer
    "lr": 1e-4,
    "adam_eps": 1e-3,
    "adam_weight_decay": 0,
    # related to noise
    "noise_std": 25,  # divided by 255 before being added to the images
    # whether the dictionary is renormalized after creation and every update
    "normalize": False,
    "num_epochs": 20,
    # progress bar switch and how often (in epochs) the test PSNR is printed
    "tqdm_prints_disable": False,
    "log_info_epoch_period": 2,
    # image folders for training and testing
    "train_image_path": "../data/CBSD432",
    "test_image_path": "../data/BSD68",
    "device": device,
}

### Build train and test datasets with various data augmentations from BSD.

In [7]:
# load data
# Every image is converted to grayscale, cut to a random square patch of side
# params["patch_size"] (zero-padded first when the image is smaller) and
# turned into a tensor. Training images are also flipped at random.
_transforms = torchvision.transforms
_patch_crop_options = {
    "padding": None,
    "pad_if_needed": True,
    "fill": 0,
    "padding_mode": "constant",
}

train_dataset = torchvision.datasets.ImageFolder(
    root=params["train_image_path"],
    transform=_transforms.Compose(
        [
            _transforms.Grayscale(),
            _transforms.RandomVerticalFlip(),
            _transforms.RandomHorizontalFlip(),
            _transforms.RandomCrop(params["patch_size"], **_patch_crop_options),
            _transforms.ToTensor(),
        ]
    ),
)

# NOTE(review): the test patches are random crops too, so the reported test
# PSNR is measured on a different patch of each image at every evaluation.
test_dataset = torchvision.datasets.ImageFolder(
    root=params["test_image_path"],
    transform=_transforms.Compose(
        [
            _transforms.Grayscale(),
            _transforms.RandomCrop(params["patch_size"], **_patch_crop_options),
            _transforms.ToTensor(),
        ]
    ),
)

### Make dataloaders

In [8]:
# make dataloader
# Training batches follow the configured batch size and shuffle flag.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=params["batch_size"],
    shuffle=params["shuffle"],
    num_workers=params["num_workers"],
)

# Test images are served one at a time and in a fixed order.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=params["num_workers"],
)

### Create the network (model)

In [9]:
# create model
# Map each supported value of params["network"] to the class that builds it.
_network_classes = {
    "CAE": CAE,
    "CAElearnbias": CAElearnbias,
    "CAElearnbiasuntied": CAElearnbiasuntied,
}

if params["network"] not in _network_classes:
    print("Network is not implemented!")
    raise NotImplementedError

net = _network_classes[params["network"]](params)

# Optionally bring the freshly created dictionary to its normalized scale.
if params["normalize"]:
    net.normalize()

### Build optimizer

In [10]:
# optimizer
# Adam over every trainable parameter of the network, configured from params.
_adam_options = {
    "lr": params["lr"],
    "eps": params["adam_eps"],
    "weight_decay": params["adam_weight_decay"],
}
optimizer = torch.optim.Adam(net.parameters(), **_adam_options)

### Create a loss criterion

In [11]:
# loss criterion
# A loss already defined in PyTorch can be used here instead.

class DLLoss2D(torch.nn.Module):
    """Half the squared reconstruction error of 2D signals.

    The squared difference is summed over the last two axes and the result
    is averaged over all remaining axes.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, xhat):
        """Return ``0.5 * mean(sum((x - xhat)^2 over the last two axes))``."""
        residual = x - xhat
        energy = residual.pow(2).sum(dim=(-1, -2))
        return 0.5 * energy.mean()


criterion = DLLoss2D()

### Train the network

In [None]:
# train
# Each epoch makes one pass over train_loader. Every
# params["log_info_epoch_period"] epochs the network is evaluated on
# test_loader and the loss of the last training batch is printed next to the
# test PSNR.
for epoch in tqdm(
    range(params["num_epochs"]), disable=params["tqdm_prints_disable"]
):
    # test_network() leaves the model in eval mode, so switch back each epoch.
    net.train()

    for x, _ in train_loader:
        x = x.to(device)

        # forward
        if params["noise_std"]:
            # Denoising setup: the network sees a noisy copy while the loss
            # below is measured against the clean batch.
            x_noisy = x + params["noise_std"] / 255 * torch.randn(
                x.shape, device=device
            )
            xhat, zT = net(x_noisy)
        else:
            xhat, zT = net(x)

        loss = criterion(x, xhat)

        # backward (gradients are cleared once, right before they are rebuilt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Optionally rescale the dictionary after every update.
        if params["normalize"]:
            net.normalize()

    if (epoch + 1) % params["log_info_epoch_period"] == 0:

        test_psnr = test_network(test_loader, net, params)

        # Report the 1-based epoch number so successive log lines can be
        # told apart.
        print(
            "Epoch {}: Train loss {}, Test PSNR {}".format(
                epoch + 1, loss.item(), test_psnr
            )
        )

### Write down your own test and visualization functions

# The End!