___
### Import libraries and sub-libraries.

In [26]:
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
from torchvision.transforms import functional as TF
import tifffile 

___
### Calling a custom code to change the default font for figures to `Computer Modern`. (Optional)

In [None]:
# from fontsetting import font_cmu
# plt = font_cmu(plt)

___
### Check the hardware that is at your disposal

In [28]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device available:', device)
is_cuda = device.type == 'cuda'

Device available: cpu


___
### Read training data from `data/train-clean-tif`


In [30]:
# Loading TIFF images
class TIFFDataset(Dataset):
    def __init__(self, directory, transform=None):
        self.directory = directory
        self.transform = transform
        self.filenames = [f for f in os.listdir(directory) if f.endswith('.tif')]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = os.path.join(self.directory, self.filenames[idx])
        image = Image.open(img_path)
        if self.transform: # Dynamically apply data transformation
            image = self.transform(image)
        return image

# Create a transform to convert the images to PyTorch tensors
transform = transforms.Compose([
    transforms.ToTensor()
])

train_transform = transforms.Compose([
    transforms.RandomAffine(degrees=0,translate=(0.1,0.1)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor()
])

# Create the dataset for training images
train_dataset = TIFFDataset('train-clean-tif', transform=train_transform)
val_dataset = TIFFDataset('val-clean-tif', transform=transform) # Create the dataset for validation images

# Function to create data loader
def create_loader(train_dataset, batch_size):
    torch.manual_seed(0)  # For reproducibility of random numbers in PyTorch
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  # Creates a training DataLoader from this Dataset
    return train_loader

dataset_size = len(train_dataset), len(val_dataset)
print('Number of images in the training dataset:', dataset_size[0])
print('Number of images in the validation dataset:', dataset_size[1])

Number of images in the training dataset: 52
Number of images in the validation dataset: 16


### Define a denoising network

#### Here, I have defined a trivial network, which has only one convolutional layer and no activation function. We are essentially doing linear filtering.

In [32]:
class TrivialNet(nn.Module):
    def __init__(self):
        super(TrivialNet, self).__init__()
        self.relu = nn.ReLU()
        initial_weights = torch.tensor([
            [[1e-6, 1e-6, 1e-6],
            [1e-6, 1, 1e-6],
            [1e-6, 1e-6, 1e-6]]
        ], dtype=torch.float32) #initial weights for the skip layers, 1e-6 is close to 0
        
        # Encoder Section
        # layer 1
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1) # output 256x256
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1) # output 256x256
        
        #Performing convolution and batch normalization on first skip layer
        self.conv_skip1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=1, padding=1)
        self.conv_skip_bn1 = nn.BatchNorm2d(128)
        self.conv_skip2 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=1, padding=1)
        self.conv_skip_bn2 = nn.BatchNorm2d(256)
        self.conv_skip3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=1, padding=1)
        self.conv_skip_bn3 = nn.BatchNorm2d(256)
        self.conv_skip4 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=1, padding=1)
        self.conv_skip_bn4 = nn.BatchNorm2d(128)
        self.conv_skip5 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=1, padding=1)
        self.conv_skip_bn5 = nn.BatchNorm2d(64)

        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # output 128x128
        self.bn1 = nn.BatchNorm2d(64)
        self.dropout1 = nn.Dropout(0.2)

        # layer 2
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1) # output 128x128
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1) # output 128x128
        
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # output 64x64
        self.bn2 = nn.BatchNorm2d(128)
        self.dropout2 = nn.Dropout(0.2)

        # layer 3
        self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1) # output 64x64
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1) # output 64x64

        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) # output 32x32
        self.bn3 = nn.BatchNorm2d(256)
        self.dropout3 = nn.Dropout(0.2)

        # layer 4
        self.conv7 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1) # output 32x32
        self.conv8 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1) # output 32x32

        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) # output 16x16
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout4 = nn.Dropout(0.2)
        
        # layer 5
        self.conv9 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=1) # output 16x16
        self.conv10 = nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, padding=1) # output 16x16

        # Decoder Section
        self.up1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2) # output 32x32
        self.bn5 = nn.BatchNorm2d(512)
        self.dropout5 = nn.Dropout(0.2)

        # layer 4
        self.conv11 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1) # output 32x32; still using 1024 channels because of concatenation
        self.conv12 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1) # output 32x32
        
        self.up2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2) # output 64x64
        self.bn6 = nn.BatchNorm2d(256)
        self.dropout6 = nn.Dropout(0.2)

        # layer 3
        self.conv13 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1) # output 64x64; still using 512 channels because of concatenation
        self.conv14 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1) # output 64x64

        self.up3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2) # output 128x128
        self.bn7 = nn.BatchNorm2d(128)
        self.dropout7 = nn.Dropout(0.2)

        # layer 2
        self.conv15 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1) # output 128x128; still using 256 channels because of concatenation
        self.conv16 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1) # output 128x128

        self.up4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2) # output 256x256
        self.bn8 = nn.BatchNorm2d(64)
        self.dropout8 = nn.Dropout(0.2)

        # layer 1
        self.conv17 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1) # output 256x256; still using 128 channels because of concatenation
        self.conv18 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1) # output 256x256

        # Output layer
        self.conv19 = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, padding=0) # output 256x256



    def forward(self, x):
        x11 = self.relu(self.conv1(x))
        x12 = self.relu(self.conv2(x11))

        x_skip1 = self.relu(self.conv_skip1(x12))
        x_skip1 = self.conv_skip_bn1(x_skip1)
        x_skip2 = self.relu(self.conv_skip2(x_skip1))
        x_skip2 = self.conv_skip_bn2(x_skip2)
        x_skip3 = self.relu(self.conv_skip3(x_skip2))
        x_skip3 = self.conv_skip_bn3(x_skip3)
        x_skip4 = self.relu(self.conv_skip4(x_skip3))
        x_skip4 = self.conv_skip_bn4(x_skip4)
        x_skip5 = self.relu(self.conv_skip5(x_skip4))
        x_skip5 = self.conv_skip_bn5(x_skip5)

        
        x2 = self.bn1(self.pool1(x_skip5))
        x21 = self.relu(self.conv3(x2))
        x22 = self.relu(self.conv4(x21))

        x3 = self.bn2(self.pool2(x22))
        x31 = self.relu(self.conv5(x3))
        x32 = self.relu(self.conv6(x31))

        x4 = self.bn3(self.pool3(x32))
        x41 = self.relu(self.conv7(x4))
        x42 = self.relu(self.conv8(x41))

        x5 = self.bn4(self.pool4(x42))
        x51 = self.relu(self.conv9(x5))
        x52 = self.relu(self.conv10(x51))

        x6 = self.bn5(self.up1(x52))
        x6 = torch.cat((x6, x42), dim=1)
        x61 = self.relu(self.conv11(x6))
        x62 = self.relu(self.conv12(x61))

        x7 = self.bn6(self.up2(x62))
        x7 = torch.cat((x7, x32), dim=1)
        x71 = self.relu(self.conv13(x7))
        x72 = self.relu(self.conv14(x71))

        x8 = self.bn7(self.up3(x72))
        x8 = torch.cat((x8, x22), dim=1)
        x81 = self.relu(self.conv15(x8))
        x82 = self.relu(self.conv16(x81))

        x9 = self.bn8(self.up4(x82))
        x9 = torch.cat((x9, x_skip5), dim=1)
        x91 = self.relu(self.conv17(x9))
        x92 = self.relu(self.conv18(x91))

        x = self.conv19(x92)
        
        return x

___
### Create a function to execute training. Note, we will call this function later.

In [34]:
def train_model(model, opt, criterion, train_loader, test_loader, num_epoch, noise_std, avg_train_losses=[], avg_test_losses=[], epoch=0):

    for epoch in range(epoch, num_epoch): # Loop over the dataset multiple times
        model.train()
        total_train_loss = 0
        for i, y_tr_batch in enumerate(train_loader): # Loop over mini-batches
            if is_cuda:
                y_tr_batch = y_tr_batch.to(device) #GPU STUFF
            x_tr_batch = y_tr_batch.clone()
            # implement data augmentation
            # flip = torch.randint(0,1,(1,)).item()
            # if flip:
            #     x_tr_batch = x_tr_batch.flip(2) # flip the image horizontally
            # rotate = torch.randint(-10, 10,(1,)).item()
            # translate = torch.randint(-10, 10, (2,)).tolist()
            # x_tr_batch = TF.affine(x_tr_batch, angle=rotate, translate=(translate), scale=1, shear=0)
            noise = torch.randn_like(y_tr_batch) * noise_std
            x_tr_batch = x_tr_batch + noise

            # insert translation, rotation, and flipping
            opt.zero_grad() # delete previous gradients
            y_hat_tr_batch = model(x_tr_batch) # forward pass
            loss = criterion(y_hat_tr_batch, y_tr_batch) # compute loss
            loss.backward() # backward pass
            opt.step() # update weights
            total_train_loss += loss.item() # accumulate loss
            # if (i + 1) % 10 == 0:
                # print(f'Epoch {epoch+1}, Iteration {i+1}, Loss: {loss.item():.6f}')

        avg_train_loss = total_train_loss / len(train_loader) # compute average loss
        avg_train_losses.append(avg_train_loss) # accumulate average loss
        
        # Testing
        model.eval()
        with torch.no_grad():
            total_test_loss = 0
            for i, y_te_batch in enumerate(test_loader):
                if is_cuda:
                    y_te_batch= y_te_batch.to(device)#GPU STUFF
                noise = torch.randn_like(y_te_batch) * noise_std
                x_te_batch = y_te_batch + noise
                y_hat_te_batch = model(x_te_batch)
                loss = criterion(y_hat_te_batch, y_te_batch)
                total_test_loss += loss.item()
            avg_test_loss = total_test_loss / len(test_loader)
            print(f'Epoch {epoch+1}, Test Loss: {avg_test_loss:.6f}, Train Loss: {avg_train_loss:.6f}')
            avg_test_losses.append(avg_test_loss)

        # Save the model and optimizer state
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'avg_train_losses': avg_train_losses[0:epoch],
                'avg_test_losses': avg_test_losses[0:epoch]
            }, 'model_checkpoint.pt')

    print('Length of Training loss', len(avg_test_losses))
    print('Length of Testing loss:', len(avg_test_losses))
    # Plotting
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(range(1, num_epoch+1), avg_train_losses, label='training loss')
    ax.plot(range(1, num_epoch+1), avg_test_losses, label='testing loss')
    ax.plot()
    ax.set_xlabel('epochs')
    ax.set_ylabel('NMSE loss')
    # ax.set_yscale('log')  # Set the vertical axis to log scale
    ax.set_title('training loss')
    ax.legend(['training accuracy', 'test accuracy'])
    ax.grid(True)
    ax.legend()
    plt.show()

___
### Now, let us define hyperparameters and train the network. 

#### Note, in addition to the parameters that controls the network architecture or the training process, you need to select/initialize (i) a data loader, (ii) a model, (iii) an optimizer, and (iv) a loss function.

In [36]:
batch_size = 2  # Number of complete images in each batch
lr = 1e-3  # Learning rate
sig = 0.1  # Noise std
num_epoch = 200  # Epochs

# Create a test loader (using validation for testing)
test_loader = create_loader(val_dataset, batch_size)

In [38]:
# Initilize the model, criterion, and optimizer

# Model, criterion, and optimizer
train_loader = create_loader(train_dataset, batch_size)
model = TrivialNet() # Pick a model
if is_cuda:
    model = model.to(device)
opt = optim.Adam(model.parameters(), lr=lr) # Pick an optimizer
criterion = nn.MSELoss() # Pick a loss function
# criterion = nmse_loss
avg_train_losses = []
avg_test_losses = []
epoch = 0

In [None]:
# Checks to see if there is a checkpoint file, if there is, it will load the model and optimizer state

# load in checkpoint if it exists and train
if os.path.exists('model_checkpoint.pt'):
    checkpoint = torch.load('model_checkpoint.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    opt.load_state_dict(checkpoint['optimizer_state_dict'])
    avg_train_losses = checkpoint['avg_train_losses']
    avg_test_losses = checkpoint['avg_test_losses']
    epoch = checkpoint['epoch']
    print('Model loaded from checkpoint at epoch', checkpoint['epoch'])
    print('Length of Training loss', len(avg_test_losses))
    print('Length of Testing loss:', len(avg_test_losses))

In [40]:
# train model (from scratch or a checkpoint)
train_model(model, opt, criterion, train_loader, test_loader, num_epoch, sig, avg_train_losses, avg_test_losses, epoch)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 32 but got size 33 for tensor number 1 in the list.

___
### Apply it to one of the validation image

In [None]:
val_dataset = TIFFDataset('val-clean-tif', transform=transform) # Create the dataset for validation images
val_clean = val_dataset[0] # Load one clean image from the validation dataset
val_noisy = val_clean + (torch.randn_like(val_clean) * sig) # Add noise to the clean image
# val_denoised = model(val_noisy).detach() # Denoise the noisy image using the trained model
# val_clean_fft = torch.fft.fft2(val_clean)
val_noisy_4d = val_noisy.unsqueeze(0) # Add an extra dimension to represent the batch size
val_denoised = model(val_noisy_4d).detach() # Now pass this 4D tensor to the model

# Remove the batch dimension before further processing
val_denoised = val_denoised.squeeze(0)
val_clean_FFT = np.fft.fftn(val_clean)
val_noisy_fft = torch.fft.fftn(val_noisy)
val_denoised_fft = torch.fft.fftn(val_denoised)


# Your existing code to generate the figure and axes
fig, ax = plt.subplots(3, 3, figsize=(10, 7))

# Plot clean image
ax[0, 0].imshow(np.abs(val_clean).squeeze().numpy(), cmap='gray', vmin=0, vmax=1)
ax[0, 0].set_title('clean image')
ax[0, 0].axis('off')

# Plot noisy image
ax[0, 1].imshow(np.abs(val_noisy).squeeze().numpy(), cmap='gray', vmin=0, vmax=1)
ax[0, 1].set_title('noisy image')
ax[0, 1].axis('off')

# Plot denoised image
ax[0, 2].imshow(np.abs(val_denoised).squeeze().numpy(), cmap='gray', vmin=0, vmax=1)
ax[0, 2].set_title('denoised image')
ax[0, 2].axis('off')

# Plot corresponding error images
ax[1, 0].imshow(3*np.abs(val_clean - val_clean).squeeze().numpy(), cmap='gray', vmin=0, vmax=1)
ax[1, 0].axis('off')
ax[1, 0].text(0.02, 0.98, r'$\times 3$', transform=ax[1, 0].transAxes, fontsize=14, va='top', ha='left', color='white')

ax[1, 1].imshow(3*np.abs(val_clean - val_noisy).squeeze().numpy(), cmap='gray', vmin=0, vmax=1)
ax[1, 1].axis('off')
ax[1, 1].text(0.02, 0.98, r'$\times 3$', transform=ax[1, 1].transAxes, fontsize=14, va='top', ha='left', color='white')

ax[1, 2].imshow(3*np.abs(val_clean - val_denoised).squeeze().numpy(), cmap='gray', vmin=0, vmax=1)
ax[1, 2].axis('off')
ax[1, 2].text(0.02, 0.98, r'$\times 3$', transform=ax[1, 2].transAxes, fontsize=14, va='top', ha='left', color='white')

# Plot the Fourier transform of the clean image
ax[2, 0].imshow(np.log(np.abs(np.fft.fftshift(val_clean_FFT.squeeze()))**2), cmap='gray')
ax[2, 0].axis('off')

# Plot the Fourier transform of the noisy image
ax[2, 1].imshow(np.log(np.abs(np.fft.fftshift(val_noisy_fft.squeeze()))**2), cmap='gray')
ax[2, 1].axis('off')

# Plot the Fourier transform of the denoised image
ax[2, 2].imshow(np.log(np.abs(np.fft.fftshift(val_denoised_fft.squeeze()))**2), cmap='gray')
ax[2, 2].axis('off')


plt.tight_layout()
plt.show()


In [None]:
# define NMSE loss
def nmse_loss(output, target):
    diff = output - target
    diff_norm_squared = torch.sum(diff ** 2) ** (1/2)
    target_norm_squared = torch.sum(target ** 2) ** (1/2)
    nmse = diff_norm_squared / (target_norm_squared + 1e-8)  # Add a small constant to avoid division by zero
    return 20 * torch.log10(nmse)  # Convert to dB

# load in validation dataset
val_dataset = TIFFDataset('val-clean-tif', transform=transform) # Create the dataset for validation images

# Create a data loader for the validation dataset
val_loader = create_loader(val_dataset, batch_size)

# Evaluate performance on the validation set using NMSE
criterion = nmse_loss
model.eval()  # Set the model to evaluation mode
total_nmse = 0
with torch.no_grad():  # Disable gradient calculation for efficiency
    for i, val_batch in enumerate(val_loader):
        noisy_batch = val_batch + (torch.randn_like(val_batch) * sig)
        denoised_batch = model(noisy_batch)
        nmse = criterion(denoised_batch, val_batch)
        total_nmse += nmse.item()
nmse = total_nmse / len(val_loader)

print('Validation Normalized Mean Squared Error:', nmse)
print(min(100-5*(nmse+20.7), 100))