In [None]:
import skimage
import numpy as np
from torch import cuda, optim, tensor, zeros_like
from torch import device as torch_device
from torch.nn import L1Loss, MSELoss
import torch
import torch.nn as nn


from darts.common_utils import *
from darts.phantom import generate_phantom, phantom_to_torch
from darts.noises import add_selected_noise


class EarlyStop():
    """Patience-based early-stopping tracker plus a bounded buffer of recent images.

    `check_stop` is fed one score per evaluation (lower is better). It remembers the
    lowest score seen and the epoch it came from, and reports True once `patience`
    consecutive calls have failed to strictly beat that best score. The `stop` flag
    is not touched by this class; callers assign it from `check_stop`'s result.

    Args:
        size: maximum number of entries kept by `update_img_collection`.
        patience: consecutive non-improving scores that trigger a stop.
    """

    def __init__(self, size, patience):
        self.size = size
        self.patience = patience
        self.best_score = float('inf')
        self.best_epoch = 0
        self.wait_count = 0
        self.stop = False
        self.img_collection = []

    def check_stop(self, current, cur_epoch):
        """Record `current` (observed at `cur_epoch`); return True when patience is exhausted."""
        if not current < self.best_score:
            # No strict improvement: spend one unit of the patience budget.
            self.wait_count += 1
            return self.wait_count >= self.patience
        # New best score: remember it and reset the patience counter.
        self.best_score, self.best_epoch, self.wait_count = current, cur_epoch, 0
        return False

    def update_img_collection(self, cur_img):
        """Append `cur_img`, dropping the oldest entry once more than `size` are held."""
        collection = self.img_collection
        collection.append(cur_img)
        if len(collection) > self.size:
            del collection[0]

    def get_img_collection(self):
        """Return the live (not copied) list of buffered images, oldest first."""
        return self.img_collection

def MSE(x1, x2):
    """Mean squared error between two arrays (or array-likes / scalars).

    The squared differences are averaged over the broadcast difference, so the
    denominator is the number of compared elements rather than ``x1.size``; for
    equal-shaped numpy arrays the result is identical to the previous formula.
    """
    diff = np.asarray(x1) - np.asarray(x2)
    return (diff ** 2).sum() / diff.size

def MAE(x1, x2):
    """Mean absolute error between two arrays (or array-likes / scalars).

    The absolute differences are averaged over the broadcast difference, so the
    denominator is the number of compared elements rather than ``x1.size``; for
    equal-shaped numpy arrays the result is identical to the previous formula.
    """
    diff = np.asarray(x1) - np.asarray(x2)
    return np.abs(diff).sum() / diff.size

In [None]:

class EncoderDecoder(nn.Module):
    """Small convolutional encoder-decoder with a sigmoid output.

    The encoder halves the spatial size three times with stride-2 3x3 convolutions
    (channels in_channels -> 16 -> 32 -> 64); the decoder doubles it three times with
    stride-2 4x4 transposed convolutions (64 -> 32 -> 16 -> out_channels). The output
    spatial size is 8 * ceil(size / 8), so it matches the input only when H and W are
    divisible by 8. Output values are squashed into [0, 1] by the final sigmoid.

    Args:
        in_channels: channels of the input tensor (default 3, as before).
        out_channels: channels of the output tensor (default 3, as before).
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # Encoder (shape comments assume a 64 x 64 input)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=2, padding=1),  # 16 x 32 x 32
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 32 x 16 x 16
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 64 x 8 x 8
            nn.ReLU(),
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 32 x 16 x 16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 16 x 32 x 32
            nn.ReLU(),
            nn.ConvTranspose2d(16, out_channels, kernel_size=4, stride=2, padding=1),  # out_channels x 64 x 64
            nn.Sigmoid(),  # bring the output values into [0, 1]
        )

    def forward(self, x):
        """Encode then decode `x` (presumably shaped (N, in_channels, H, W))."""
        return self.decoder(self.encoder(x))

# Example usage: a 3-channel 64 x 64 input should come back with the same shape.
model = EncoderDecoder()
input_tensor = torch.randn(1, 3, 64, 64)
with torch.no_grad():  # shape check only; no autograd graph is needed
    output = model(input_tensor)
print(output.shape)  # torch.Size([1, 3, 64, 64])


In [None]:
def _phantom_reference_chw(img_np, channels):
    """Return the clean phantom as a (channels, H, W) float array min-max scaled to [0, 1].

    A 2-D (H, W) array is replicated `channels` times along a new leading axis,
    mirroring how the second-attempt cell builds its PSNR/SSIM reference; an array
    that is already 3-D is only rescaled. A constant image maps to all zeros instead
    of dividing by zero.
    """
    ref = np.asarray(img_np, dtype=np.float64)
    if ref.ndim == 2:
        ref = np.repeat(ref[np.newaxis, :, :], channels, axis=0)
    lo, hi = ref.min(), ref.max()
    if hi == lo:
        return np.zeros_like(ref)
    return (ref - lo) / (hi - lo)


def test_space(Space, num_iter=2, lr=0.01, buffer_size=100, patience=1000,
               resolution=6, noise_type='gaussian', show_every=1):
    """Smoke-test a network as a Deep-Image-Prior style denoiser on a phantom.

    Builds a noisy phantom with the darts helpers, then runs `num_iter` Adam steps
    fitting `Space` (fed a perturbed copy of the input returned by `get_noise`) to the
    noisy image under an MSE loss. Every `show_every` iterations it records PSNR/SSIM
    against the clean phantom and, once `buffer_size` outputs are buffered, the
    windowed output variance that feeds `EarlyStop`. With the defaults (2 iterations,
    a 100-image window) no variance value is produced; raise `num_iter` for that.

    Args:
        Space: an ``nn.Module`` mapping the noise input (presumably (1, 3, H, W)) to
            an image; it is moved to the selected device and trained in place.
        num_iter, lr, buffer_size, patience, resolution, noise_type, show_every:
            training / data knobs; the defaults are the previously hard-coded values.

    Returns:
        dict with lists under 'loss', 'psnr', 'ssim', 'variance' and 'variance_iters'
        (iteration index of each variance value), and the tracker under 'earlystop'.
    """
    use_cuda = cuda.is_available()
    device = torch_device('cuda' if use_cuda else "cpu")
    # `.type(cuda.FloatTensor)` fails on CPU-only machines, so match the tensor type
    # to the device that was actually selected.
    dtype = cuda.FloatTensor if use_cuda else torch.FloatTensor

    img_np = generate_phantom(resolution=resolution)
    img = phantom_to_torch(img_np)
    img_noisy = add_selected_noise(img, noise_type=noise_type).to(device)
    if img_noisy.dim() == 3:
        # Add the batch axis so the target matches the network output instead of
        # relying on MSELoss broadcasting.
        img_noisy = img_noisy.unsqueeze(0)

    # shape[-1] is the width of both (C, H, W) and (N, C, H, W) tensors, whereas
    # size()[3] only exists for 4-D tensors.
    net_input = get_noise(input_depth=3, spatial_size=img.shape[-1], noise_type=noise_type).to(device)

    net = Space.to(device).type(dtype)
    criterion = MSELoss().type(dtype).to(device)
    p = get_params('net', net, net_input)  # network parameters to be optimized
    optimizer = optim.Adam(p, lr=lr)

    reg_noise_std = 1. / 30.  # std of the fresh perturbation added to net_input each step
    loss_history = []
    psnr_history = []
    ssim_history = []
    variance_history = []
    x_axis = []
    earlystop = EarlyStop(size=buffer_size, patience=patience)

    def closure(iterator):
        # One DIP step: perturb the input, fit the output to the noisy image.
        net_input_perturbed = net_input + zeros_like(net_input).normal_(std=reg_noise_std)
        r_img_torch = net(net_input_perturbed)
        total_loss = criterion(r_img_torch, img_noisy)
        total_loss.backward()
        loss_history.append(total_loss.item())
        if iterator % show_every != 0:
            return total_loss

        # Evaluate the recovered image (PSNR, SSIM) against the clean phantom,
        # brought to the same (C, H, W) layout as the network output.
        r_img_np = torch_to_np(r_img_torch)
        ref_np = _phantom_reference_chw(img_np, channels=r_img_np.shape[0])
        data_range = ref_np.max() - ref_np.min()
        psnr_history.append(
            skimage.metrics.peak_signal_noise_ratio(ref_np, r_img_np, data_range=data_range))
        ssim_history.append(skimage.metrics.structural_similarity(
            np.transpose(ref_np, (1, 2, 0)), np.transpose(r_img_np, (1, 2, 0)),
            win_size=7, channel_axis=-1, data_range=data_range))

        # Windowed variance of the last `buffer_size` outputs (early-stop signal).
        earlystop.update_img_collection(r_img_np.reshape(-1))
        img_collection = earlystop.get_img_collection()
        if len(img_collection) == buffer_size:
            ave_img = np.mean(img_collection, axis=0)
            cur_var = np.mean([MSE(ave_img, tmp) for tmp in img_collection])
            variance_history.append(cur_var)
            x_axis.append(iterator)
            if not earlystop.stop:
                earlystop.stop = earlystop.check_stop(cur_var, iterator)
        return total_loss

    for iterator in range(num_iter):
        optimizer.zero_grad()
        closure(iterator)
        optimizer.step()

    return {
        'loss': loss_history,
        'psnr': psnr_history,
        'ssim': ssim_history,
        'variance': variance_history,
        'variance_iters': x_axis,
        'earlystop': earlystop,
    }

# Smoke-test the training loop on a freshly initialised encoder-decoder
# (trailing `;` keeps Jupyter from echoing the call's return value).
model = EncoderDecoder()
test_space(Space=model);

In [None]:
# Shape exploration: inspect what the phantom / noise helpers return before training.
device = torch_device('cuda' if cuda.is_available() else "cpu")
# cuda.FloatTensor is unusable without a GPU; pick the tensor type matching `device`.
dtype = cuda.FloatTensor if cuda.is_available() else torch.FloatTensor
buffer_size = 100
patience = 1000
lr = 0.01
num_iter = 3000

noise_type = 'gaussian'
raw_img_np = generate_phantom(resolution=6)
img = phantom_to_torch(raw_img_np)
img_np = img.squeeze().numpy()
img_noisy = add_selected_noise(img, noise_type=noise_type).to(device)
# NOTE(review): round trip through torch_to_np / phantom_to_torch — confirm this
# yields the intended layout; the resulting shape is printed below.
img_noisy_np = torch_to_np(img_noisy.squeeze())
img_noisy_np = phantom_to_torch(img_noisy_np).squeeze().numpy()

# shape[-1] is the width of both (C, H, W) and (N, C, H, W) tensors, whereas
# size()[3] only exists for 4-D tensors.
net_input = get_noise(input_depth=3, spatial_size=img.shape[-1], noise_type=noise_type).to(device)

print(f'img_np.shape: {img_np.shape}')
print(f'img shape: {img.shape}')
print(f'img_noisy_np.shape: {img_noisy_np.shape}')
print(f'net_input shape: {net_input.shape}')



# Second attempt: rebuild the pipeline with an explicit 3-channel, [0, 1]-normalized clean reference for PSNR/SSIM

In [None]:
import skimage
import numpy as np
import torch
import torch.nn as nn
from torch import cuda, optim, tensor, zeros_like
from torch import device as torch_device
from torch.nn import L1Loss, MSELoss
from matplotlib import pyplot as plt


from darts.common_utils import *
from darts.phantom import generate_phantom, phantom_to_torch
from darts.noises import add_selected_noise

class EarlyStop():
    """Patience-based early-stopping tracker plus a bounded buffer of recent images.

    `check_stop` is fed one score per evaluation (lower is better). It remembers the
    lowest score seen and the epoch it came from, and reports True once `patience`
    consecutive calls have failed to strictly beat that best score. The `stop` flag
    is not touched by this class; callers assign it from `check_stop`'s result.

    Args:
        size: maximum number of entries kept by `update_img_collection`.
        patience: consecutive non-improving scores that trigger a stop.
    """

    def __init__(self, size, patience):
        self.size = size
        self.patience = patience
        self.best_score = float('inf')
        self.best_epoch = 0
        self.wait_count = 0
        self.stop = False
        self.img_collection = []

    def check_stop(self, current, cur_epoch):
        """Record `current` (observed at `cur_epoch`); return True when patience is exhausted."""
        if not current < self.best_score:
            # No strict improvement: spend one unit of the patience budget.
            self.wait_count += 1
            return self.wait_count >= self.patience
        # New best score: remember it and reset the patience counter.
        self.best_score, self.best_epoch, self.wait_count = current, cur_epoch, 0
        return False

    def update_img_collection(self, cur_img):
        """Append `cur_img`, dropping the oldest entry once more than `size` are held."""
        collection = self.img_collection
        collection.append(cur_img)
        if len(collection) > self.size:
            del collection[0]

    def get_img_collection(self):
        """Return the live (not copied) list of buffered images, oldest first."""
        return self.img_collection

def MSE(x1, x2):
    """Mean squared error between two arrays (or array-likes / scalars).

    The squared differences are averaged over the broadcast difference, so the
    denominator is the number of compared elements rather than ``x1.size``; for
    equal-shaped numpy arrays the result is identical to the previous formula.
    """
    diff = np.asarray(x1) - np.asarray(x2)
    return (diff ** 2).sum() / diff.size

def MAE(x1, x2):
    """Mean absolute error between two arrays (or array-likes / scalars).

    The absolute differences are averaged over the broadcast difference, so the
    denominator is the number of compared elements rather than ``x1.size``; for
    equal-shaped numpy arrays the result is identical to the previous formula.
    """
    diff = np.asarray(x1) - np.asarray(x2)
    return np.abs(diff).sum() / diff.size


class EncoderDecoder(nn.Module):
    """Small convolutional encoder-decoder with a sigmoid output.

    The encoder halves the spatial size three times with stride-2 3x3 convolutions
    (channels in_channels -> 16 -> 32 -> 64); the decoder doubles it three times with
    stride-2 4x4 transposed convolutions (64 -> 32 -> 16 -> out_channels). The output
    spatial size is 8 * ceil(size / 8), so it matches the input only when H and W are
    divisible by 8. Output values are squashed into [0, 1] by the final sigmoid.

    Args:
        in_channels: channels of the input tensor (default 3, as before).
        out_channels: channels of the output tensor (default 3, as before).
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # Encoder (shape comments assume a 64 x 64 input)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=2, padding=1),  # 16 x 32 x 32
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 32 x 16 x 16
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 64 x 8 x 8
            nn.ReLU(),
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 32 x 16 x 16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 16 x 32 x 32
            nn.ReLU(),
            nn.ConvTranspose2d(16, out_channels, kernel_size=4, stride=2, padding=1),  # out_channels x 64 x 64
            nn.Sigmoid(),  # bring the output values into [0, 1]
        )

    def forward(self, x):
        """Encode then decode `x` (presumably shaped (N, in_channels, H, W))."""
        return self.decoder(self.encoder(x))

# Example usage: a 3-channel 64 x 64 input should come back with the same shape.
model = EncoderDecoder()
input_tensor = torch.randn(1, 3, 64, 64)
with torch.no_grad():  # shape check only; no autograd graph is needed
    output = model(input_tensor)
print(output.shape)  # torch.Size([1, 3, 64, 64])

In [None]:
# --- Configuration ---------------------------------------------------------
buffer_size = 100   # window length (in evaluations) for the moving-variance signal
patience = 1000     # non-improving variance evaluations tolerated before flagging a stop
lr = 0.0005
num_iter = 3000
resolution = 8

device = torch_device('cuda' if cuda.is_available() else "cpu")
# cuda.FloatTensor cannot be used without a GPU (the `.type(dtype)` calls below
# would fail), so pick the tensor type that matches `device`.
dtype = cuda.FloatTensor if cuda.is_available() else torch.FloatTensor

noise_type = 'gaussian'

# --- Data ------------------------------------------------------------------
raw_img_np = generate_phantom(resolution=resolution)  # 2-D (H, W) phantom
print(f'raw_img_np.shape: {raw_img_np.shape}')

# Clean reference for PSNR/SSIM: replicate to 3 channels and min-max normalize to [0, 1].
img_np = raw_img_np.reshape(1, raw_img_np.shape[0], raw_img_np.shape[1])  # (1, H, W)
img_np = np.repeat(img_np, 3, axis=0)  # (3, H, W)
img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())
print(f'img_np.shape: {img_np.shape}')

img_torch = phantom_to_torch(raw_img_np).unsqueeze(0)  # expected (1, 3, H, W)
print(f'img_torch.shape: {img_torch.shape}')

img_noisy_torch = add_selected_noise(img_torch, noise_type=noise_type, noise_factor=0.15)
print(f'img_noisy_torch.shape: {img_noisy_torch.shape}')

img_noisy_np = img_noisy_torch.squeeze(0).numpy()  # (3, H, W)
print(f'img_noisy_np.shape: {img_noisy_np.shape}')

raw_img_noisy_np = np.mean(img_noisy_np, axis=0)  # channel mean, (H, W), for display
print(f'raw_img_noisy_np.shape: {raw_img_noisy_np.shape}')

# NOTE(review): `img_np` is min-max normalized above, but the training target comes
# from phantom_to_torch — confirm it also maps to [0, 1], otherwise PSNR/SSIM compare
# images on different scales.
img_noisy_torch = img_noisy_torch.to(device)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
fig.suptitle('Clean and Noisy Image')
ax1.imshow(raw_img_np, cmap='gray')
ax2.imshow(raw_img_noisy_np, cmap='gray')
ax1.set_title('Clean')
ax2.set_title('Noisy')
ax1.axis('off')
ax2.axis('off')
plt.show()
plt.close(fig)

# --- Model / optimizer -----------------------------------------------------
# Network input produced by get_noise; it is perturbed afresh on every step in `closure`.
net_input = get_noise(input_depth=3, spatial_size=raw_img_np.shape[1], noise_type=noise_type).type(dtype).to(device)

model = EncoderDecoder()
# Move the network to the target device and tensor type.
net = model.to(device).type(dtype)

criterion = MSELoss().type(dtype).to(device)

p = get_params('net', net, net_input)  # network parameters to be optimized
optimizer = optim.Adam(p, lr=lr)

# --- Optimization state ----------------------------------------------------
reg_noise_std = 1. / 30.  # std of the Gaussian perturbation added to net_input each step
show_every = 1            # evaluate PSNR/SSIM/variance every `show_every` iterations
loss_history = []
psnr_history = []
ssim_history = []
variance_history = []
x_axis = []               # iteration index of each variance_history entry
earlystop = EarlyStop(size=buffer_size, patience=patience)
def closure(iterator):
    """Run one DIP forward/backward pass and record metrics.

    Feeds a freshly perturbed copy of `net_input` through `net`, back-propagates the
    MSE against `img_noisy_torch` (the caller zeroes gradients and steps the
    optimizer), and every `show_every` iterations records PSNR/SSIM against the clean
    reference `img_np` plus the windowed output variance that feeds `earlystop`.

    Args:
        iterator: current iteration index.

    Returns:
        The loss tensor for this iteration.
    """
    net_input_perturbed = net_input + zeros_like(net_input).normal_(std=reg_noise_std)
    r_img_torch = net(net_input_perturbed)
    total_loss = criterion(r_img_torch, img_noisy_torch)
    total_loss.backward()
    loss_history.append(total_loss.item())
    if iterator % show_every != 0:
        return total_loss

    # Evaluate the recovered image (PSNR, SSIM) against the clean reference.
    r_img_np = torch_to_np(r_img_torch)
    # Explicit data range for both metrics; channel_axis=-1 alone selects the
    # multichannel SSIM path (the old `multichannel` kwarg is deprecated/ignored).
    data_range = img_np.max() - img_np.min()
    psnr = skimage.metrics.peak_signal_noise_ratio(img_np, r_img_np, data_range=data_range)
    ssim = skimage.metrics.structural_similarity(
        np.transpose(img_np, (1, 2, 0)), np.transpose(r_img_np, (1, 2, 0)),
        win_size=7, channel_axis=-1, data_range=data_range)
    psnr_history.append(psnr)
    ssim_history.append(ssim)

    # Variance history: mean squared deviation of the last `buffer_size` outputs
    # from their average, used as the early-stopping signal.
    earlystop.update_img_collection(r_img_np.reshape(-1))
    img_collection = earlystop.get_img_collection()
    if len(img_collection) == buffer_size:
        ave_img = np.mean(img_collection, axis=0)
        cur_var = np.mean([MSE(ave_img, tmp) for tmp in img_collection])
        variance_history.append(cur_var)
        x_axis.append(iterator)
        if not earlystop.stop:
            earlystop.stop = earlystop.check_stop(cur_var, iterator)
    return total_loss
    
# Training loop. Metrics are recorded every `show_every` steps inside `closure`;
# previews are rendered far less often so the notebook is not flooded with figures.
plot_every = 500  # preview cadence in iterations (the final iteration is always shown)
for iterator in range(num_iter):
    optimizer.zero_grad()
    closure(iterator)
    optimizer.step()

    if iterator % plot_every == 0 or iterator == num_iter - 1:
        # Preview from the unperturbed input; no_grad avoids building an unused graph.
        with torch.no_grad():
            r_img_np = torch_to_np(net(net_input))
        plot_image_grid([np.clip(img_np, 0, 1), np.clip(r_img_np, 0, 1)], factor=1, nrow=1)

if earlystop.stop:
    print(f'Early stopping triggered; lowest windowed variance at iteration {earlystop.best_epoch}')
