In [None]:
import skimage
import numpy as np
import torch
import torch.nn as nn
from torch import cuda, optim, tensor, zeros_like
from torch import device as torch_device
from torch.nn import L1Loss, MSELoss
from matplotlib import pyplot as plt


from darts.common_utils import *
from darts.phantom import generate_phantom, phantom_to_torch
from darts.noises import add_selected_noise
from darts.early_stop import EarlyStop, MSE, MAE


In [None]:
# Release cached-but-unused GPU memory held by PyTorch's caching allocator
# (e.g. left over from an earlier run of this notebook in the same kernel).
cuda.empty_cache()

In [None]:

device = torch_device('cuda' if cuda.is_available() else "cpu")
# Tensor type used to cast the input and the model. Fall back to the CPU type
# when no GPU is present: `.type(cuda.FloatTensor)` would otherwise try to
# move data to CUDA and fail on CPU-only machines.
dtype = cuda.FloatTensor if cuda.is_available() else torch.FloatTensor

# Seed the RNGs so weight initialisation and the per-step input perturbation
# are reproducible under "Restart & Run All". NumPy is seeded too in case the
# project helpers (phantom / noise generation) draw from NumPy's global RNG --
# TODO confirm which RNG they use.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# U-Net from torch.hub, randomly initialised (pretrained=False), 1 -> 1 channels.
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
                       in_channels=1, out_channels=1, init_features=64, pretrained=False)

buffer_size = 100   # size of the EarlyStop image buffer used for the variance window
patience = 1000     # passed to EarlyStop
num_iter = 7500     # total optimisation steps
show_every = 1      # evaluate PSNR/SSIM/variance every `show_every` steps
lr = 0.00005

# Std of the Gaussian perturbation added to the fixed network input each step.
# Kept as a plain Python float: `Tensor.normal_` takes a float std, and a CUDA
# scalar tensor here would force a device->host sync on every iteration.
reg_noise_std = 1. / 30.
noise_type = 'gaussian'
noise_factor = 0.1
resolution = 6
n_channles = 1  # number of image channels (original spelling kept for compatibility)

raw_img_np = generate_phantom(resolution=resolution) # 1x64x64 np array
img_np = raw_img_np.copy() # 1x64x64 np array, clean reference for PSNR/SSIM
img_torch = torch.tensor(raw_img_np, dtype=torch.float32).unsqueeze(0) # 1x1x64x64 torch tensor
img_noisy_torch = add_selected_noise(img_torch, noise_type=noise_type, noise_factor=noise_factor) # 1x1x64x64 torch tensor
img_noisy_np = img_noisy_torch.squeeze(0).numpy() # 1x64x64 np array

# The noisy image is the fitting target, so it must live on the network's device.
img_noisy_torch = img_noisy_torch.to(device)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
ax1.imshow(np.mean(raw_img_np, axis=0), cmap='gray')
ax1.set_title("Clean Image")
ax1.axis('off')
ax2.imshow(np.mean(img_noisy_np, axis=0), cmap='gray')
ax2.set_title("Noisy Image")
ax2.axis('off')
plt.show()
plt.close()

# Fixed random network input (1 channel, spatial size taken from the image).
net_input = get_noise(input_depth=1, spatial_size=raw_img_np.shape[1], noise_type=noise_type).type(dtype).to(device)

# Move the network to the target device and tensor type.
net = model.to(device)
net = net.type(dtype)

# Loss: MSE between the network output and the noisy observation.
# MSELoss has no parameters, so only a device move is needed.
criterion = MSELoss().to(device)

# Optimizer
p = get_params('net', net, net_input)  # network parameters to be optimized
optimizer = optim.Adam(p, lr=lr)

# Histories filled in by `closure` during optimisation.
loss_history = []
psnr_history = []
ssim_history = []
variance_history = []
x_axis = []  # iteration numbers at which variance_history entries were recorded
earlystop = EarlyStop(size=buffer_size, patience=patience)
def closure(iterator):
    """Run one DIP forward/backward pass and update metrics and early-stop state.

    Perturbs the fixed ``net_input`` with Gaussian noise of std
    ``reg_noise_std``, fits the network output to ``img_noisy_torch`` with
    ``criterion`` and back-propagates. The caller is responsible for
    ``optimizer.zero_grad()`` / ``optimizer.step()``.

    Every ``show_every`` iterations it:

    * records PSNR/SSIM against the clean image ``img_np``;
    * pushes the flattened reconstruction into ``earlystop``'s image buffer;
    * once that buffer holds ``buffer_size`` images, records their variance
      (mean MSE to the buffer mean) and asks ``earlystop`` whether to stop.

    Args:
        iterator: current iteration index.

    Returns:
        The string ``"STOP"`` once early stopping has been triggered,
        otherwise the loss tensor for this step.
    """
    # DIP: perturb the fixed input with fresh Gaussian noise each step.
    net_input_perturbed = net_input + zeros_like(net_input).normal_(std=reg_noise_std)
    r_img_torch = net(net_input_perturbed)
    total_loss = criterion(r_img_torch, img_noisy_torch)
    total_loss.backward()
    loss_history.append(total_loss.item())

    if iterator % show_every == 0:
        # Evaluate the recovered image against the clean one (PSNR, SSIM).
        r_img_np = torch_to_np(r_img_torch)
        # NOTE(review): no data_range given, so skimage infers it from
        # img_np's dtype (unlike SSIM below) -- confirm this is intended.
        psnr = skimage.metrics.peak_signal_noise_ratio(img_np, r_img_np)
        # (C, H, W) -> (H, W, C): channel-last layout for skimage.
        temp_img_np = np.transpose(img_np, (1, 2, 0))
        temp_r_img_np = np.transpose(r_img_np, (1, 2, 0))
        data_range = temp_img_np.max() - temp_img_np.min()
        # The arrays are always channel-last here, so channel_axis=-1 is right
        # for any channel count. The deprecated `multichannel` kwarg is not
        # passed. On scikit-image releases that still honour it,
        # multichannel=False resets channel_axis to None. The (H, W, 1) array
        # is then treated as a 3-D volume and fails the win_size=7 check.
        ssim = skimage.metrics.structural_similarity(
            temp_img_np, temp_r_img_np, win_size=7, channel_axis=-1, data_range=data_range)
        psnr_history.append(psnr)
        ssim_history.append(ssim)

        # Variance history: feed the flattened reconstruction to EarlyStop.
        r_img_flat = r_img_np.reshape(-1)
        earlystop.update_img_collection(r_img_flat)
        img_collection = earlystop.get_img_collection()
        if iterator % (show_every * 10) == 0:
            print('Iteration %05d    Loss %.4f    PSNR %.4f    SSIM %.4f    Collection Size %d'
                  % (iterator, total_loss.item(), psnr, ssim, len(img_collection)))
        if len(img_collection) == buffer_size:
            # Window variance: mean MSE of each buffered image to the buffer mean.
            ave_img = np.mean(img_collection, axis=0)
            cur_var = np.mean([MSE(ave_img, buffered_img) for buffered_img in img_collection])
            cur_epoch = iterator
            variance_history.append(cur_var)
            x_axis.append(cur_epoch)
            if not earlystop.stop:
                earlystop.stop = earlystop.check_stop(cur_var, cur_epoch)

    if earlystop.stop:
        return "STOP"
    return total_loss
    
# Set to True to break out of training as soon as early stopping fires. Left
# False so the run goes for all `num_iter` steps and the PSNR curve beyond the
# detection point is still recorded for the history plot.
use_early_stop = False

for iterator in range(num_iter):
    optimizer.zero_grad()
    step_result = closure(iterator)  # "STOP" once early stopping has fired, else the loss
    optimizer.step()

    if iterator % (show_every * 100) == 0:
        # Preview on the unperturbed input. No gradients are needed, so avoid
        # building an autograd graph that would immediately be discarded.
        with torch.no_grad():
            r_img_np = torch_to_np(net(net_input))
        plot_side_by_side(np.clip(img_np, 0, 1), np.clip(r_img_np, 0, 1), np.clip(img_noisy_np, 0, 1))

    if use_early_stop and isinstance(step_result, str) and step_result == "STOP":
        print("Early stopping triggered.")
        break



Iteration 01977    Loss 0.0031    PSNR 21.1801    SSIM 0.3640    Collection Size 100.0000

In [None]:
# Show the PSNR and variance histories together with the detected stopping point (ES-WMV).
fig, ax1 = plt.subplots()

# PSNR is appended at iterations 0, show_every, 2*show_every, ...; plot it
# against those iteration numbers so it lines up with the variance curve,
# which is plotted against the iteration numbers stored in `x_axis`.
psnr_iters = np.arange(len(psnr_history)) * show_every

color = 'tab:red'
ax1.set_xlabel('Epoch')
ax1.set_ylabel('PSNR', color=color)
ax1.plot(psnr_iters, psnr_history, color=color)
ax1.tick_params(axis='y', labelcolor=color)

ax2 = ax1.twinx()

color = 'tab:blue'
ax2.set_ylabel('Variance', color=color)
ax2.plot(x_axis, variance_history, color=color)
ax2.tick_params(axis='y', labelcolor=color)

ax1.set_title('ES-WMV')
ax2.axvline(x=earlystop.best_epoch, label='detection', color='y')
ax2.legend()
# Lay out last so the title and both y-axis labels are accounted for.
fig.tight_layout()
plt.show()
     