In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scs_toolkit as scs
import torch
import torchvision.models as models

In [None]:
###
### PARAMETERS
###

device = 0 # GPU id or "cpu"

# Random seed for the noise realizations, so that Restart & Run All reproduces the
# same noisy test map and the same result. Set to None to disable seeding.
seed = 0
if seed is not None:
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU generator and all CUDA devices

# Representation
# Choices are "wph" or "vgg"
rep = "wph"
assert rep in ["wph", "vgg"], f"rep must be 'wph' or 'vgg', got {rep!r}"

# Test data
# Choices are "dust", "lss", "imagenet"
test_data = "dust"
assert test_data in ["dust", "lss", "imagenet"], f"unknown test_data {test_data!r}"
assert rep != "vgg" or test_data == "imagenet", "the VGG representation only works with the imagenet data"
x0 = scs.get_exp_data(test_data, vgg=rep=="vgg")
M, N = x0.shape[-2:]  # size of the last two (spatial) axes of the test image

# Noise type
# Choices are "white", "pink", "blue", "crosses"
sigma = 1.0  # noise level (multiplied by x0.std() for the VGG representation)
noise_type = "white"
assert noise_type in ["white", "pink", "blue", "crosses"], f"unknown noise_type {noise_type!r}"

# Optim parameters
optimizer = "scipy_lbfgs"
optim_params = {"maxiter": 30, "gtol": 1e-14, "ftol": 1e-14, "maxcor": 20} # 'maxiter' corresponds to 'T' in the paper
noise_batch_size = 100 # 'Q' in the paper
noise_batch_size_split = 20 # For WPH rep only, this should be adapted to the GPU memory (to be increased if OOM)

# Experiment description
# To reproduce Algo 1: n_steps=1, perturb=False
# To reproduce Algo 2: n_steps=int(10*sigma), perturb=False
# To reproduce Algo 2 pert.: n_steps=int(10*sigma), perturb=True (also, to follow the results of the paper, set the 'maxiter' key of optim_params to 10)
n_steps = 1 # 'P' in the paper
perturb = False
assert (rep == "wph" and noise_type == "white") or not perturb, "perturb=True is only available for the WPH representation with white noise"

In [None]:
###
### OBJECTIVE FUNCTION
###
# Defines the two callables used by the separation cell:
#   - noise_loader(): draws a batch of noise realizations of type `noise_type`;
#   - obj(x, y, alpha): the loss handed to scs.diffusive_denoiser.
# Both look up the notebook-level parameters (sigma, perturb, device, ...) when they
# are *called*, not when this cell runs.

if rep == "wph":
    def noise_loader():
        # noise_batch_size ('Q') maps of size (M, N) at level sigma; see the
        # parameters cell for the role of noise_batch_size_split.
        return scs.get_noises(noise_batch_size, M, N, sigma,
                              noise_batch_size_split=noise_batch_size_split,
                              type=noise_type, device=device)

    wph_op = scs.WPHOp(M, N, 7, device=device)
    # The perturbative variant uses a reduced set of WPH statistics.
    if perturb:
        wph_op.set_model(["S11", "S01"])
    else:
        wph_op.set_model(["S11", "S00", "S01", "C01"])

    def obj(x, y, alpha):
        # In the perturbative variant the step is rescaled by sigma and
        # backward=False is passed to the loss.
        step = alpha * sigma if perturb else alpha
        return scs.wph_loss(x, y, wph_op, dt=step,
                            perturb=perturb, verbose=True,
                            noise_loader=noise_loader, backward=not perturb)

else:
    def noise_loader():
        # Draw 3*Q single maps at level sigma*x0.std() and regroup them into
        # Q three-channel noise images.
        flat_noise = scs.get_noises(3*noise_batch_size, M, N, sigma*x0.std(),
                                    type=noise_type, device=device)
        return flat_noise.view((noise_batch_size, 3, M, N))

    # Pretrained VGG19-BN; only its first 6 feature modules are kept, in eval mode
    # and with gradients disabled on all of their parameters.
    vgg_model = models.vgg19_bn(weights=models.VGG19_BN_Weights.DEFAULT)
    features = vgg_model.features[:6].to(device).eval()
    features.requires_grad_(False)
    print(features)

    def phi(x):
        # Mean of the squared feature maps over the last two (spatial) axes.
        return torch.mean(features(x)**2, dim=(-1, -2))

    def obj(x, y, alpha):
        return scs.general_loss(x, y, phi, dt=alpha, verbose=True,
                                noise_loader=noise_loader, device=device)

In [None]:
###
### STATISTICAL COMPONENT SEPARATION
###

# We make a noisy test map from one realization of the noise model used by `obj`
e0 = noise_loader()[0, 0].cpu().numpy() if rep == "wph" else noise_loader()[0].cpu().numpy()
y = x0 + e0

# Statistical Component Separation
x_hat_0 = scs.diffusive_denoiser(y, n_steps, obj,
                                 optimizer=optimizer, optim_params=optim_params,
                                 device=device)

# Plot images.
# Display versions are computed on the fly instead of overwriting x0, y and x_hat_0:
# overwriting them (as ImageNet-unnormalized, channel-last arrays) made this cell
# non-idempotent, since a re-run would add noise to the already-converted x0.
def _to_display(img):
    """Return `img` ready for `imshow`.

    WPH data is returned unchanged. VGG/ImageNet data is un-normalized with
    `scs.imagenet_unnormalize`, its first batch entry is moved to channel-last
    layout, and values are clipped to [0, 1].
    """
    if rep != "vgg":
        return img
    rgb = scs.imagenet_unnormalize(torch.from_numpy(img))[0].cpu().numpy().transpose((1, 2, 0))
    # imshow clips float RGB data to [0, 1] anyway; clipping here avoids its warning.
    return np.clip(rgb, 0.0, 1.0)

panels = [(x0, "Original $x_0$"),
          (y, "Noisy $y$"),
          (x_hat_0, r"Denoised $\hat x_0$")]
fig, ax = plt.subplots(1, 3, figsize=(15, 5), sharex=True, sharey=True)
vmin, vmax = -3, 3  # color range for single-channel maps (imshow ignores it for RGB)
for panel_ax, (img, title) in zip(ax, panels):
    panel_ax.imshow(_to_display(img), cmap="magma", vmin=vmin, vmax=vmax)
    panel_ax.set_title(title)
plt.show()