## Code used for producing the results of the EL image section. Only works on a computer that has a GPU

# U-net training

In [None]:
import random

import cv2
import numpy as np
import torch
from fastai.vision.all import *
from skimage.metrics import structural_similarity as ssim
from skimage.util import random_noise

# Seed the RNGs this notebook relies on, to make runs more repeatable:
# Python's `random` (fastai augmentations, the noise variance drawn in the settings cell),
# NumPy, and PyTorch (network weight initialisation, on all devices).
# NOTE: depending on the scikit-image version, random_noise may use its own generator,
# and some GPU kernels are nondeterministic, so results can still vary slightly.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

In [None]:
# Please change these directory paths to the folders that contain your images
path_PV_images = Path("../../Data")  # root data folder
path_poissongaussian_noisy = path_PV_images.joinpath("PoissonGaussianDirty")  # images with Poisson-Gaussian noise
path_poisson_noisy = path_PV_images.joinpath("PoissonDirty")  # images with Poisson noise (the DataLoaders read from here)
path_clean = path_PV_images.joinpath("CleanFull")  # clean ground-truth images, matched to noisy ones by file name

In [None]:
# Metrics
def PSNR(gt, image, max_value=1):
    """Return the peak signal-to-noise ratio (PSNR), in dB, between two images.

    Computes 20 * log10(max_value / sqrt(MSE)), with the MSE taken by fastai's
    MSELossFlat over the flattened inputs. MSE is symmetric, so the argument order
    does not matter (as a fastai metric this receives (prediction, target)).

    Args:
        gt: One of the two images/batches to compare.
        image: The other image/batch, same number of elements as ``gt``.
        max_value: Largest possible pixel value.

    Returns:
        A scalar tensor holding the PSNR in decibels.
    """
    rmse = torch.sqrt(MSELossFlat()(gt, image))
    return 20 * torch.log10(max_value / rmse)

def SSIM(gt, image, data_range=2.0, win_size=3):
    """Return the mean structural similarity (SSIM) between two images or image batches.

    Both inputs must have the same shape: (H, W), (C, H, W) or (N, C, H, W), as torch
    tensors or NumPy arrays. SSIM is computed separately on every 2-D (H, W) slice and
    the scores are averaged. Different images or channels are therefore never mixed
    inside one SSIM window; the old ``multichannel`` call treated the last axis (W)
    as channels. SSIM is symmetric, so the argument order does not matter (as a
    fastai metric this receives (prediction, target)).

    Args:
        gt: One of the two images/batches to compare.
        image: The other image/batch, same shape as ``gt``.
        data_range: Dynamic range of the pixel values, passed to scikit-image. Newer
            scikit-image versions require it for floating-point input. The default
            2.0 is the [-1, 1] float range that scikit-image used to assume implicitly.
        win_size: Odd side length of the square SSIM window; must not exceed H or W.

    Returns:
        The mean SSIM as a Python float.

    Raises:
        ValueError: If the shapes differ or the inputs have fewer than 2 dimensions.
    """
    def to_numpy(t):
        # detach() so the metric also works on tensors that still require grad.
        return t.detach().cpu().numpy() if torch.is_tensor(t) else np.asarray(t)

    a, b = to_numpy(gt), to_numpy(image)
    if a.shape != b.shape:
        raise ValueError(f"SSIM inputs must have the same shape, got {a.shape} and {b.shape}")
    if a.ndim < 2:
        raise ValueError(f"SSIM inputs need at least 2 dimensions (H, W), got shape {a.shape}")
    # Collapse all leading (batch/channel) axes so each entry is one 2-D image.
    a = a.reshape(-1, *a.shape[-2:])
    b = b.reshape(-1, *b.shape[-2:])
    scores = [ssim(x, y, data_range=data_range, win_size=win_size) for x, y in zip(a, b)]
    return float(np.mean(scores))

In [None]:
# Gaussian-noise augmentation and the input image types it targets.
# Giving the network input its own types means the noise transform (dispatched on the
# TensorImageBWInput annotation) is applied to the input x only, never to the clean target y.
class TensorImageBWInput(TensorImage): pass
class PILImageBWInput(PILImageBW): pass
PILImageBWInput._tensor_cls = TensorImageBWInput

class AddGaussianNoise(RandTransform):
    """Add Gaussian noise to the TensorImageBWInput (network input) part of a batch.

    Args:
        mean: Mean of the noise, on scikit-image's [0, 1] floating-point scale.
        var: Variance of the noise on the same scale. Either a single number, used
            for every batch, or a ``(low, high)`` pair. With a pair, a new variance
            is drawn uniformly from that range for every batch.
        **kwargs: Forwarded to ``RandTransform``.
    """
    def __init__(self, mean=0., var=1., **kwargs):
        self.var = var
        self.mean = mean
        super().__init__(**kwargs)

    def _sample_var(self):
        """Return the noise variance to use for the current batch."""
        if isinstance(self.var, (tuple, list)):
            low, high = self.var
            return random.uniform(low, high)
        return self.var

    def encodes(self, x:TensorImageBWInput):
        # random_noise converts x into floating point [0, 1] (normalisation), adds the
        # random noise and returns a NumPy array, so it has to run on the CPU.
        noisy = random_noise(x.cpu().numpy(), mode="gaussian", mean=self.mean, var=self._sample_var())
        # Move the result back to the device the batch came from, so the same code works
        # on CPU and GPU (previously it stayed on the CPU). Multiply by 255 to bring it
        # back from [0, 1] to [0, 255].
        return TensorImage(torch.from_numpy(noisy).float().to(x.device)) * 255

In [None]:
def save_image(img_tensor, path):
    """Write `img_tensor` to the file `path`, converting it via fastai's PILImage.create."""
    PILImage.create(img_tensor).save(path)

In [None]:
## HYPERPARAMETERS/SETTINGS ----------------------------------------------------
# Model and optimisation
arch = models.resnet34  # U-Net encoder backbone
wd = 1e-3  # weight decay
y_range = (-3., 3.)  # NOTE(review): not passed to unet_learner in create_unet, so currently unused
loss_func = MSELossFlat()

# Augmentation
item_tfms = [RandomResizedCrop(520, min_scale=0.5)]
# random.uniform is evaluated once, here, so every batch uses this same noise variance.
batch_tfms = [AddGaussianNoise(0, random.uniform(0.0001, 0.001))]

# Training
bs = 4  # batch size
num_of_cycles = 20  # number of epochs passed to fit_one_cycle

# Output locations
save_name = "LocalRandomCrop20CyclesResnet34"
model_save_path = f"Models/{save_name}"
image_save_path = f"Data/Results/{save_name}.png"

In [None]:
def create_data_block(item_tfms):
    """Return the DataBlock that pairs each noisy image with its clean counterpart.

    Inputs are loaded as PILImageBWInput and targets as PILImageBW (both black-and-white).
    Transforms typed on TensorImageBWInput therefore touch only the inputs. The target
    for a noisy file is the file with the same name in ``path_clean``. A RandomSplitter
    with a fixed seed holds out 20% of the images for validation (80% training).

    Args:
        item_tfms: Per-item transforms (e.g. cropping); ``None`` for none.

    Returns:
        The configured DataBlock. Its batch transforms come from the module-level
        ``batch_tfms``.
    """
    def clean_counterpart(noisy_file):
        # Same file name (stem + suffix), but inside the clean-image folder.
        return path_clean / noisy_file.name

    input_block = ImageBlock(cls=PILImageBWInput)
    target_block = ImageBlock(cls=PILImageBW)
    return DataBlock(
        blocks=(input_block, target_block),
        get_items=get_image_files,
        get_y=clean_counterpart,
        splitter=RandomSplitter(valid_pct=0.2, seed=42),
        item_tfms=item_tfms,  # crop size was originally 224
        batch_tfms=batch_tfms,
    )

def create_dataloader(item_tfms):
    """Build DataLoaders over the Poisson-noisy images, preview one batch and return them.

    Args:
        item_tfms: Per-item transforms forwarded to ``create_data_block`` (``None`` for none).

    Returns:
        fastai DataLoaders (batch size from the module-level ``bs``) with ``c`` set to 3.
    """
    dls = create_data_block(item_tfms).dataloaders(path_poisson_noisy, bs=bs)
    # unet_learner reads `c` from the DataLoaders as its number of output channels.
    dls.c = 3
    # cmap='gray' is needed because the images are PILImageBW.
    dls.show_batch(cmap='gray')
    return dls

def create_unet(dl):
    """Create a fastai U-Net learner on the DataLoaders `dl`.

    Uses the module-level ``arch``, ``wd`` and ``loss_func``, a blurred, weight-normalised
    decoder, and reports PSNR and SSIM as metrics.
    """
    settings = dict(
        wd=wd,
        blur=True,
        norm_type=NormType.Weight,
        loss_func=loss_func,
        metrics=[PSNR, SSIM],
    )
    return unet_learner(dl, arch, **settings)

In [None]:
dl = create_dataloader(item_tfms=item_tfms)  # training DataLoaders, with the random resized crop

In [None]:
learner = create_unet(dl=dl)  # U-Net learner on the training DataLoaders

In [None]:
# One-cycle training; pct_start=0.8 spends the first 80% of iterations increasing the learning rate.
learner.fit_one_cycle(n_epoch=num_of_cycles, pct_start=0.8)
# For the experimental validation on the PL images, the only difference is the addition of learner.fine_tune(3)

In [None]:
# Save the trained model under the name held in model_save_path
learner.save(file=model_save_path)

In [None]:
# Rebuild the DataLoaders without item transforms (no random resized crop)
dl_img = create_dataloader(item_tfms=None)
# Recreate the U-Net on these DataLoaders, then load the saved weights into it
learner = create_unet(dl=dl_img)
learner.load(file=model_save_path)

In [None]:
# Run to produce an example of a denoised image.
# Pass the file path itself so fastai loads it exactly like the training inputs
# (PILImageBWInput, i.e. black-and-white), instead of via a 3-channel RGB NumPy array.
# Replace "EL noisy" with the path to a noisy EL image.
img, b, c = learner.predict("EL noisy")
save_image(img, image_save_path)

# BM3D - traditional denoising

In [None]:
from bm3d import bm3d
from skimage.restoration import estimate_sigma

# bm3d() denoises one image array at a time and needs the noise standard deviation
# (sigma_psd), so load and denoise each validation image separately.
bm3d_results = []
for noisy_file in dl.valid_ds.items:
    # Load black-and-white, like the network input, and scale to [0, 1].
    noisy = np.asarray(PILImageBW.create(noisy_file), dtype=np.float64) / 255
    # NOTE(review): this noise is Poisson (signal-dependent), so a single estimated
    # Gaussian sigma is only an approximation; use a calibrated value if one is available.
    sigma = estimate_sigma(noisy)
    bm3d_results.append(bm3d(noisy, sigma_psd=sigma))