In [1]:
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm import tqdm 
import os
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader 
from PIL import Image
import itertools

In [2]:
# Dataset root and its four sub-folders (rainy / sunny images and their segmentation maps).
# NOTE: this is a machine-specific absolute path; adjust it when running elsewhere.
path = "/net/ens/am4ip/datasets/project-dataset"
rainyImagesPath = f"{path}/rainy_images"
rainySsegPath = f"{path}/rainy_sseg"
sunnyImagesPath = f"{path}/sunny_images"
sunnySsegPath = f"{path}/sunny_sseg"

In [None]:
# Number of images to load from the sunny-images folder
num_images = 10

# Grayscale copies of the loaded images, in sorted filename order
images_list = []

for filename in sorted(os.listdir(sunnyImagesPath)):
    # Skip anything that is not an image file (add other extensions if necessary)
    if not filename.endswith((".png", ".jpg")):
        continue

    img_path = os.path.join(sunnyImagesPath, filename)
    # The context manager closes the underlying file as soon as the pixels are copied out
    with Image.open(img_path) as image:
        # PIL's resize takes (width, height), so the resulting array is 480 rows x 320 columns
        resized_image = image.resize((320, 480))
        grayscale_image = resized_image.convert("L")
        img = np.array(grayscale_image)

    images_list.append(img)
    if len(images_list) >= num_images:  # Stop after reading `num_images`
        break

# Fail with an explicit message instead of np.stack's generic "need at least one array"
if not images_list:
    raise ValueError(f"No .png/.jpg images found in {sunnyImagesPath}")

# Stack the images to create a 3D numpy array
images_array = np.stack(images_list, axis=0)  # Shape: (number of loaded images, 480, 320)

print(f"Images array shape: {images_array.shape}")

In [4]:
def patchig(data, patchsize=(64, 64), step=(16, 16)):
    """Cut a 2D array into (possibly overlapping) fixed-size patches.

    Patches are taken on a regular grid of top-left corners, starting at (0, 0)
    and advancing by `step` along each axis. The grid includes the last position
    at which a full patch still fits (`data.shape[k] - patchsize[k]`) whenever the
    step lands on it, so a patch as large as `data` yields exactly one patch.

    Parameters
    ----------
    data : 2D array-like supporting `.shape` and slicing
    patchsize : sequence of two ints, (rows, cols) of each patch
    step : sequence of two ints, stride between patch corners along (rows, cols)

    Returns
    -------
    numpy.ndarray of shape (number of patches, patchsize[0], patchsize[1]),
    dtype float64, ordered row-major over the grid of corners. Empty along the
    first axis when `data` is smaller than a patch.
    """
    # Starting indices; "+ 1" keeps the final valid corner, which arange's exclusive stop would drop
    x = np.arange(0, data.shape[0] - patchsize[0] + 1, step=step[0])
    y = np.arange(0, data.shape[1] - patchsize[1] + 1, step=step[1])
    TopLefts = list(itertools.product(x, y))

    print('Extracting %i patches' % len(TopLefts))

    patches = np.zeros([len(TopLefts), patchsize[0], patchsize[1]])

    for i, (row, col) in enumerate(TopLefts):
        patches[i] = data[row:row + patchsize[0], col:col + patchsize[1]]

    return patches

In [5]:
# The loaded grayscale images are used as the noisy data for the rest of the notebook.
# This is an alias of `images_array`, not a copy.
noisydata = images_array

In [None]:
# Cut every image into 32x32 patches (stride 20 along both axes) ...
temp = [patchig(noisy, patchsize=[32, 32], step=[20, 20]) for noisy in noisydata]

# ... and pool the per-image patch stacks into one array along the patch axis
noisyPatches = np.concatenate(temp, axis=0)

In [None]:
# Preview the first 18 patches on a 3 x 6 grid
fig, axs = plt.subplots(3, 6, figsize=[15, 7])
for i, ax in enumerate(axs.ravel()):
    ax.imshow(noisyPatches[i])
fig.tight_layout()

In [None]:
# Display the shape of the pooled patch array
noisyPatches.shape

In [9]:
def modifyActivePixels(patch, numActivePixels, neighbourhoodRadius=5):
    """Corrupt a patch by overwriting random "active" pixels with a nearby pixel's value.

    Parameters
    ----------
    patch : 2D numpy array to corrupt (left unmodified).
    numActivePixels : number of active pixel locations to draw. Locations are
        sampled with replacement, so fewer distinct pixels may end up modified.
    neighbourhoodRadius : controls the range the row/column shifts are drawn from,
        the half-open interval [-radius // 2 + radius % 2, radius // 2 + radius % 2);
        for the default of 5 this is -2..2.

    Returns
    -------
    modifiedPatch : copy of `patch` with the active pixels replaced.
    activePixelMask : array like `patch`, 0 at active pixels and 1 elsewhere.

    Raises
    ------
    ValueError : if the radius leaves no non-zero shift to choose from while a
        pixel would otherwise be replaced by itself (e.g. radius 1).
    """
    radius = neighbourhoodRadius
    numRows, numCols = patch.shape[0], patch.shape[1]

    # Select active pixel locations
    activeXCoords = np.random.randint(0, numRows, numActivePixels)
    activeYCoords = np.random.randint(0, numCols, numActivePixels)
    activePixelIndices = (activeXCoords, activeYCoords)

    # Compute shift for neighbouring pixels; both axes share the same half-open range
    shiftLow = -radius // 2 + radius % 2
    shiftHigh = radius // 2 + radius % 2
    xShift = np.random.randint(shiftLow, shiftHigh, numActivePixels)
    yShift = np.random.randint(shiftLow, shiftHigh, numActivePixels)

    # Ensure no replacement with itself: redraw the row shift wherever both shifts are zero.
    # The redraw uses the same range as the original draw, minus zero.
    selfReplaced = (xShift == 0) & (yShift == 0)
    numSelfReplaced = int(selfReplaced.sum())
    if numSelfReplaced:
        shiftOptions = np.arange(shiftLow, shiftHigh)
        shiftOptions = shiftOptions[shiftOptions != 0]
        if shiftOptions.size == 0:
            raise ValueError(
                "neighbourhoodRadius=%r leaves no non-zero shift to pick a neighbour from" % (radius,)
            )
        xShift[selfReplaced] = np.random.choice(shiftOptions, size=numSelfReplaced)

    # Find coordinates of neighbouring pixels, wrapping around the patch borders
    neighbourXCoords = (activeXCoords + xShift) % numRows
    neighbourYCoords = (activeYCoords + yShift) % numCols
    neighbourPixelIndices = (neighbourXCoords, neighbourYCoords)

    # Replace active pixel values with neighbours (values are read from the untouched original)
    modifiedPatch = patch.copy()
    modifiedPatch[activePixelIndices] = patch[neighbourPixelIndices]

    # Create active pixel mask: 0 marks an active (replaced) pixel, 1 everything else
    activePixelMask = np.ones_like(patch)
    activePixelMask[activePixelIndices] = 0.

    return modifiedPatch, activePixelMask


In [10]:
# Corrupt one example patch (index 6) with 10 active pixels to illustrate the pre-processing step
crpt_patch, mask = modifyActivePixels(noisyPatches[6], numActivePixels=10, neighbourhoodRadius=5)

In [11]:
def plot_corruption(noisy,crpt,mask,seismic_cmap='RdBu',vmin=-0.25,vmax=0.25):
    """Show a patch, its corrupted version and the corruption mask side by side.

    `noisy`, `crpt` and `mask` are each drawn with `imshow` in a 1 x 3 figure.
    `seismic_cmap`, `vmin` and `vmax` are accepted but not used by the plotting code.

    Returns the figure and the array of three axes.
    """
    fig, axs = plt.subplots(1, 3, figsize=[15, 5])

    panels = (
        (noisy, 'Original'),
        (crpt, 'Corrupted'),
        (mask, 'Corruption Mask'),
    )
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image)
        ax.set_title(title)

    fig.tight_layout()
    return fig, axs

In [None]:
# Visualise the example patch, its corrupted version and the corruption mask
fig,axs = plot_corruption(noisyPatches[6], crpt_patch, mask)

In [13]:
# Set the percentage of active pixels per patch
percentActive = 2

# Calculate the total number of pixels in a patch
totalPixels = noisyPatches[0].shape[0] * noisyPatches[0].shape[1]

# Determine the number of active pixels based on the chosen percentage
numActivePixels = int(np.floor((totalPixels / 100) * percentActive))

# Apply the pre-processing function with the selected values
corruptedPatch, mask = modifyActivePixels(noisyPatches[6], numActivePixels=numActivePixels, neighbourhoodRadius=5)


In [14]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

In [15]:
from myUnet import UNet

# Single-channel (grayscale) input and output, placed on the selected device
model = UNet(
    input_channels=1,
    output_channels=1,
    hidden_channels=32,
    depth=2,
).to(device)

In [16]:
# Optimisation setup: Adam on all model parameters, mean-squared-error loss
lr = 1e-4
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [17]:
# Training length (most recommend 150-200 epochs for random noise suppression)
numEpochs = 100

# How many patches go to training and to validation
numTrainingSamples, numTestSamples = 2048, 512

# Mini-batch size used while training the model
batchSize = 128

In [None]:
# Initialise arrays to keep track of train and validation metrics
trainLossHistory = np.zeros(numEpochs)
trainAccuracyHistory = np.zeros(numEpochs)
testLossHistory = np.zeros(numEpochs)
testAccuracyHistory = np.zeros(numEpochs)

# For reproducibility
g = torch.Generator()
g.manual_seed(0)

In [19]:
def trainModel(model, lossFunction, optim, dataLoader, device):
    """Run one training epoch and return the mean loss and mean RMSE per batch.

    Each batch supplies (inputs, targets, mask). The loss is evaluated on
    `prediction * (1 - mask)` against `target * (1 - mask)`, i.e. only where the
    mask is 0. The second returned value (named "accuracy" by the callers) is the
    root-mean-square error between targets and predictions over the whole batch,
    averaged over batches.
    """
    model.train()
    lossTotal = 0
    rmseTotal = 0

    for batch in tqdm(dataLoader):
        features, labels, exclusionMask = (tensor.to(device) for tensor in batch[:3])
        lossRegion = 1 - exclusionMask

        # Standard optimisation step on the masked loss
        optim.zero_grad()
        predictedProbs = model(features)
        batchLoss = lossFunction(predictedProbs * lossRegion, labels * lossRegion)
        batchLoss.backward()
        optim.step()

        lossTotal += batchLoss.item()

        # RMSE over every pixel of the batch, computed in NumPy on detached predictions
        detachedPredictions = predictedProbs.detach().cpu().numpy().astype(float)
        residual = labels.cpu().numpy().ravel() - detachedPredictions.ravel()
        rmseTotal += np.sqrt(np.mean(residual ** 2))

    numBatches = len(dataLoader)
    return lossTotal / numBatches, rmseTotal / numBatches


In [20]:
def evaluateModel(model, lossFunction, optim, dataLoader, device):
    """Evaluate the model on a data loader without tracking gradients.

    Each batch supplies (inputs, targets, mask). The loss is evaluated on
    `prediction * (1 - mask)` against `target * (1 - mask)`, i.e. only where the
    mask is 0. The second returned value (named "accuracy" by the callers) is the
    root-mean-square error between targets and predictions over the whole batch.

    Parameters
    ----------
    model : network to evaluate; switched to eval mode.
    lossFunction : callable taking (prediction, target) and returning a scalar tensor.
    optim : optimiser, kept for signature compatibility. No step is taken; its
        gradients are cleared once if it is not None.
    dataLoader : sized iterable of (inputs, targets, mask) batches.
    device : device the batch tensors are moved to.

    Returns
    -------
    (mean loss per batch, mean RMSE per batch)
    """
    model.eval()
    avgAccuracy = 0
    avgLoss = 0

    if optim is not None:
        optim.zero_grad()

    # The whole pass, including the forward call, runs without building an autograd graph
    with torch.no_grad():
        for batchData in tqdm(dataLoader):
            inputFeatures, targetLabels, maskMatrix = (
                batchData[0].to(device),
                batchData[1].to(device),
                batchData[2].to(device)
            )

            predictedOutput = model(inputFeatures)
            batchLoss = lossFunction(predictedOutput * (1 - maskMatrix), targetLabels * (1 - maskMatrix))
            predictedValues = (predictedOutput.detach().cpu().numpy()).astype(float)

            avgLoss += batchLoss.item()
            avgAccuracy += np.sqrt(np.mean((targetLabels.cpu().numpy().ravel() - predictedValues.ravel())**2))

    avgLoss /= len(dataLoader)
    avgAccuracy /= len(dataLoader)

    return avgLoss, avgAccuracy

In [21]:
def createDataLoaders(noisyPatches, corruptedPatches, maskArray, numTrain, numTest, batchSize, torchGen):
    """Build the training and test DataLoaders from aligned patch arrays.

    The first `numTrain` patches form the training set and the following
    `numTest` patches the test set. Every sample is a triple
    (corrupted patch, original patch, mask), each converted to a float tensor
    with a channel axis inserted at position 1.

    The training loader is shuffled using `torchGen`; the test loader keeps order.

    Returns (trainingLoader, testingLoader).
    """
    def buildDataset(window):
        # One float tensor per source array, in (inputs, targets, masks) order
        tensors = [
            torch.from_numpy(np.expand_dims(source[window], axis=1)).float()
            for source in (corruptedPatches, noisyPatches, maskArray)
        ]
        return TensorDataset(*tensors)

    trainingDataset = buildDataset(slice(None, numTrain))
    testDataset = buildDataset(slice(numTrain, numTrain + numTest))

    trainingLoader = DataLoader(
        trainingDataset, batch_size=batchSize, shuffle=True, generator=torchGen
    )
    testingLoader = DataLoader(testDataset, batch_size=batchSize, shuffle=False)

    return trainingLoader, testingLoader

In [None]:
for epoch in range(numEpochs):
    # Draw a fresh random corruption (and matching mask) for every patch at each epoch.
    # NOTE: createDataLoaders only consumes the first numTrainingSamples + numTestSamples patches.
    corruptedPatches = np.zeros_like(noisyPatches)
    patchMasks = np.zeros_like(corruptedPatches)
    for patchIndex, cleanPatch in enumerate(noisyPatches):
        corruptedPatches[patchIndex], patchMasks[patchIndex] = modifyActivePixels(
            cleanPatch,
            numActivePixels=int(numActivePixels),
            neighbourhoodRadius=5,
        )

    # Wrap this epoch's arrays into train / test loaders
    trainLoader, testLoader = createDataLoaders(
        noisyPatches,
        corruptedPatches,
        patchMasks,
        numTrainingSamples,
        numTestSamples,
        batchSize=batchSize,
        torchGen=g,
    )

    # One optimisation pass over the training loader
    trainLoss, trainAccuracy = trainModel(
        model=model, lossFunction=criterion, optim=optimizer, dataLoader=trainLoader, device=device
    )
    trainLossHistory[epoch] = trainLoss
    trainAccuracyHistory[epoch] = trainAccuracy

    # Validation pass over the test loader
    testLoss, testAccuracy = evaluateModel(
        model=model, lossFunction=criterion, optim=optimizer, dataLoader=testLoader, device=device
    )
    testLossHistory[epoch] = testLoss
    testAccuracyHistory[epoch] = testAccuracy

    # Report this epoch's metrics (the "accuracy" values are RMSE)
    print(f"""Epoch {epoch}, Training Loss: {trainLoss:.4f}, Training Accuracy: {trainAccuracy:.4f},  Test Loss: {testLoss:.4f}, Test Accuracy: {testAccuracy:.4f}""")


In [23]:
def plot_training_metrics(trainAccuracyHistory, testAccuracyHistory, trainLossHistory, testLossHistory):
    """Plot per-epoch RMSE (left panel) and loss (right panel) for train and validation.

    Training curves are drawn in red, validation curves in black.
    Returns the figure and the array of two axes.
    """
    fig, axs = plt.subplots(1, 2, figsize=(15, 4))

    # (axis, panel label, training curve, validation curve)
    panels = (
        (axs[0], 'RMSE', trainAccuracyHistory, testAccuracyHistory),
        (axs[1], 'Loss', trainLossHistory, testLossHistory),
    )
    for ax, label, trainCurve, validationCurve in panels:
        ax.plot(trainCurve, 'r', lw=2, label='train')
        ax.plot(validationCurve, 'k', lw=2, label='validation')
        ax.set_title(label, size=16)
        ax.set_ylabel(label, size=12)
        ax.legend()
        ax.set_xlabel('# Epochs', size=12)

    fig.tight_layout()
    return fig, axs

In [None]:
# Plot the per-epoch RMSE and loss histories for training and validation
fig,axs = plot_training_metrics(trainAccuracyHistory, testAccuracyHistory, trainLossHistory, testLossHistory)

In [None]:
# Pick one full image (index 5) to run the trained model on
testdata = noisydata[5]
print(testdata.shape)

# Add leading batch and channel axes, then convert to a float tensor for prediction
torch_testdata = torch.from_numpy(testdata[np.newaxis, np.newaxis, ...]).float()
print(torch_testdata.shape)

In [26]:
# Inference only: eval mode, and no autograd graph is built for the full-image forward pass
model.eval()
with torch.no_grad():
    test_prediction = model(torch_testdata.to(device))

# Return to numpy for plotting purposes
test_pred = test_prediction.detach().cpu().numpy().squeeze()

In [None]:
# Show the model output for the test image in grayscale, with a colour scale
predFig, predAx = plt.subplots(figsize=(10, 8))
predImage = predAx.imshow(test_pred, cmap='gray')
predAx.set_title('Test Prediction')
predFig.colorbar(predImage, ax=predAx)
plt.show()

In [29]:
# Persist the trained model to disk.
# NOTE(review): this saves the whole model object rather than its state_dict, so loading it
# presumably requires the UNet class from myUnet to be importable — confirm with whoever loads it.
torch.save(model, 'denoisingModel.pth')