In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt

import os

import numpy as np
from models import *

import torch
import torch.optim

from skimage.metrics import peak_signal_noise_ratio
from utils.denoising_utils import *

# Tensors in this notebook are cast with `.type(dtype)`; torch.FloatTensor is
# a CPU float32 tensor type, so the fitting below runs on the CPU.
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
dtype = torch.FloatTensor

# Seed the torch and numpy RNGs so network initialisation and the random
# network inputs are reproducible across "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

imsize = -1  # NOTE(review): not referenced in the visible cells
PLOT = True  # show intermediate reconstructions during optimisation
sigma = 25  # NOTE(review): sigma / sigma_ are not referenced in the visible cells
sigma_ = sigma / 255.0  # sigma expressed on a 0-1 intensity scale

# NOTE(review): created but never used — all tensors stay on the CPU via `dtype`.
mps_device = torch.device("mps")

In [None]:
# Load the fluorescence stack and move the frame axis first -> (n_frames, H, W).
# NOTE(review): assumes the .npy is stored as (H, W, n_frames) — confirm against the export.
# The float32 cast keeps the in-place division below valid even when the file
# holds integer (e.g. uint16) data, where `/=` would raise a casting error.
img_all_np = (
    np.load("../Fluo enhancement/apoptosis.npy").transpose(2, 0, 1).astype(np.float32)
)

# Min-max normalise the whole stack (globally, not per frame) to [0, 1].
img_all_np -= img_all_np.min()
_stack_peak = img_all_np.max()
if _stack_peak > 0:  # a constant stack would otherwise become all-NaN
    img_all_np /= _stack_peak

In [None]:
# --- Network input / architecture ------------------------------------------
INPUT, pad = "noise", "reflection"  # INPUT alternative: 'meshgrid'
OPT_OVER = "net"  # alternative: 'net,input'
input_depth = 3  # channels of the random network input

# --- Optimisation -----------------------------------------------------------
OPTIMIZER, LR = "adam", 0.01  # OPTIMIZER alternative: 'LBFGS'
num_iter = 100  # previously 2400
reg_noise_std = 1 / 30  # std of the per-step input jitter; 1/20 suggested for sigma=50

# --- Monitoring / display ---------------------------------------------------
show_every = 99  # plot every `show_every` iterations; backtracking is skipped on those steps
exp_weight = 0.99  # decay of the exponential moving average of the output
figsize = 5

In [None]:
output_image = []


def denoise_group(img_noisy_np):
    """Fit a freshly initialised Deep Image Prior ``skip`` network to one group of frames.

    Every call builds its own network, fixed random input, EMA and backtracking
    snapshot, so groups are fitted independently and no optimisation state
    leaks into module globals.

    Parameters
    ----------
    img_noisy_np : np.ndarray
        Target of shape (3, H, W): three consecutive frames of ``img_all_np``.
        There is no clean ground truth; this noisy stack is the MSE target.

    Returns
    -------
    torch.Tensor
        CPU tensor of shape (3, H, W): the network output on the un-jittered
        input after ``num_iter`` optimisation steps.
    """
    net = skip(
        input_depth,
        3,
        num_channels_down=[8, 16, 32, 64, 128],
        num_channels_up=[8, 16, 32, 64, 128],
        num_channels_skip=[0, 0, 0, 4, 4],
        upsample_mode="bilinear",
        need_sigmoid=True,
        need_bias=True,
        pad=pad,
        act_fun="LeakyReLU",
    ).type(dtype)

    mse = torch.nn.MSELoss().type(dtype)
    img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)

    # The input's spatial size must be (H, W) so the network output lines up
    # with the target in the MSE; passing (W, H) only works for square frames.
    # Assumes get_noise(..., spatial_size) yields (1, depth, size[0], size[1]) as in the DIP utils.
    _, height, width = img_noisy_np.shape
    net_input_saved = get_noise(input_depth, INPUT, (height, width)).type(dtype).detach()
    noise = net_input_saved.detach().clone()  # reusable buffer for the per-step jitter

    # Optimisation state shared with `closure` through `nonlocal`.
    iteration = 0
    out_avg = None  # EMA of the output; starts from the first output, not from a constant
    psrn_noisy_last = 0
    last_net = None  # parameter snapshot used for backtracking

    def closure():
        """One step: forward pass, MSE backward, logging/plotting, backtracking."""
        nonlocal iteration, out_avg, psrn_noisy_last, last_net

        # Jitter the fixed input with fresh Gaussian noise each step (regulariser).
        if reg_noise_std > 0:
            net_input = net_input_saved + noise.normal_() * reg_noise_std
        else:
            net_input = net_input_saved

        out = net(net_input)

        # Exponential moving average of the output; only used for plotting here.
        if out_avg is None:
            out_avg = out.detach()
        else:
            out_avg = out_avg * exp_weight + out.detach() * (1 - exp_weight)

        total_loss = mse(out, img_noisy_torch)
        total_loss.backward()

        # Without ground truth, only PSNR against the noisy target can be tracked.
        psrn_noisy = peak_signal_noise_ratio(
            img_noisy_np, out.detach().cpu().numpy()[0], data_range=1
        )
        print(
            "Iteration %05d    Loss %f   PSNR_noisy: %f"
            % (iteration, total_loss.item(), psrn_noisy),
            "\r",
            end="",
        )

        if PLOT and iteration % show_every == 0:
            plot_image_grid(
                [np.clip(torch_to_np(out), 0, 1), np.clip(torch_to_np(out_avg), 0, 1)],
                factor=figsize,
                nrow=1,
            )

        # Backtracking: if PSNR dropped by more than 5 dB since the last
        # snapshot, restore the snapshot instead of keeping diverged weights.
        if iteration % show_every:
            if last_net is not None and psrn_noisy - psrn_noisy_last < -5:
                print("Falling back to previous checkpoint.")
                for saved, param in zip(last_net, net.parameters()):
                    # copy_ handles any device difference itself (no .cuda() needed).
                    param.data.copy_(saved)
                # NOTE(review): this step's gradient was already backpropagated and is
                # presumably still applied by the optimiser step inside `optimize` — confirm.
                return total_loss * 0
            # clone() is required: on CPU, .detach().cpu() returns a view of the live
            # parameter, so the snapshot would be overwritten by later optimiser steps.
            last_net = [param.detach().clone() for param in net.parameters()]
            psrn_noisy_last = psrn_noisy

        iteration += 1
        return total_loss

    params = get_params(OPT_OVER, net, net_input_saved)
    optimize(OPTIMIZER, params, closure, LR, num_iter)

    # Evaluate on the fixed, un-jittered input so the result does not depend on
    # the last random jitter; no_grad avoids building a graph. net.eval() is not
    # called, matching how the network ran during fitting.
    with torch.no_grad():
        result = net(net_input_saved)
    return result[0].cpu()


# The network maps 3 channels -> 3, so the stack is processed in groups of 3
# consecutive frames; trailing frames that do not fill a group are skipped.
n_groups = img_all_np.shape[0] // 3
for group_idx in range(n_groups):
    output_image.append(denoise_group(img_all_np[3 * group_idx : 3 * group_idx + 3, :, :]))

In [None]:
# Join the per-group outputs along the frame axis, in group order, so that
# oo[k] corresponds to img_all_np[k] for every processed frame.
stacked_output = torch.cat(output_image, dim=0)
oo = stacked_output.numpy()

In [None]:
# Display the stack's shape (frame axis first after the transpose at load time).
np.shape(img_all_np)

In [None]:
# Side-by-side comparison, one row per frame: normalised input (left) vs. DIP output (right).
# Only frames present in both arrays can be compared, so the row count is
# derived from the data instead of being hard-coded.
n_rows = min(img_all_np.shape[0], oo.shape[0])
assert n_rows > 0, "nothing to plot: img_all_np or oo is empty"

RAW_VMAX = 1
# NOTE(review): the output is shown on a 10x tighter display range than the input — confirm intended.
DENOISED_VMAX = 0.1

# squeeze=False keeps `axes` 2-D even when n_rows == 1.
fig, axes = plt.subplots(n_rows, 2, figsize=(6, 61 / 27 * n_rows), squeeze=False)
axes[0, 0].set_title("input")
axes[0, 1].set_title("DIP output")

for row in range(n_rows):
    axes[row, 0].imshow(img_all_np[row], vmin=0, vmax=RAW_VMAX, cmap="gray")
    axes[row, 0].axis("off")
    axes[row, 1].imshow(oo[row], vmin=0, vmax=DENOISED_VMAX, cmap="gray")
    axes[row, 1].axis("off")