In [None]:
# Copyright 2021, Jason Lequyer and Laurence Pelletier, All rights reserved.
# Sinai Health System Lunenfeld-Tanenbaum Research Institute

import os
import time
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image

# --- Configuration ------------------------------------------------------------
# Path to the input image; the denoised result is written next to it as a PNG
# with an "_N2F" suffix.
image_path = 'Confocal_BPAE_R_4.png'  # Replace with your image


def n2f_output_path(path, suffix='_N2F'):
    """Return `path` with `suffix` inserted before its extension, saved as .png.

    Only the final extension is replaced (via os.path.splitext). Unlike
    ``path.replace('.png', ...)``, this never returns the input path unchanged
    (e.g. for 'IMG.PNG' or 'img.jpg'), which would overwrite the source image.
    It also never rewrites '.png' occurring inside directory names.
    """
    root, _ext = os.path.splitext(path)
    return f"{root}{suffix}.png"


out_path = n2f_output_path(image_path)

# Early-stopping patience: training stops once more than `tsince` consecutive
# iterations pass without the reconstruction score improving.
tsince = 100

# Seed the global NumPy RNG (used to pick training pairs) and torch (weight
# initialisation) so reruns are repeatable; set SEED = None for unseeded runs.
# NOTE(review): some CUDA convolution kernels may still be nondeterministic on GPU.
SEED = 0
if SEED is not None:
    np.random.seed(SEED)
    torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class TwoCon(nn.Module):
    """Two stacked 3x3 convolutions, each followed by a ReLU.

    With kernel size 3 and padding=1, the spatial size is unchanged. The
    output has `out_channels` channels and the input's height and width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names and creation order are deliberately unchanged: they
        # define the state_dict keys and the order parameters are initialised in.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

    def forward(self, x):
        out = x
        for conv in (self.conv1, self.conv2):
            out = F.relu(conv(out))
        return out

class Net(nn.Module):
    """Small fully-convolutional denoiser: four TwoCon stages, then 1x1 conv + sigmoid.

    Expects a single-channel (N, 1, H, W) input and returns an (N, 1, H, W)
    tensor with values in (0, 1). Every convolution preserves spatial size, so
    the same network runs on the half-size training views and the full image.
    """

    def __init__(self):
        super().__init__()
        # Sub-module names (including the jump from conv4 to conv6) are kept
        # because they are the state_dict keys; construction order also fixes
        # the order in which parameters are initialised.
        self.conv1 = TwoCon(1, 64)
        self.conv2 = TwoCon(64, 64)
        self.conv3 = TwoCon(64, 64)
        self.conv4 = TwoCon(64, 64)
        self.conv6 = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        features = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            features = stage(features)
        return torch.sigmoid(self.conv6(features))

# --- Load the input image and build checkerboard training pairs ---------------
# NOTE(review): .convert("L") produces 8-bit grayscale; a 16-bit source PNG may
# lose precision here. Confirm that inputs are 8-bit.
with Image.open(image_path) as src:
    img_pil = src.convert("L")  # convert() loads the pixels, so the file can close
img = np.array(img_pil)

# Min-max normalise to [0, 1]; minner/maxer are kept to undo this on output.
typer = type(img[0, 0])  # pixel type of the loaded array
minner = np.amin(img)
img = img - minner
maxer = np.amax(img)
if maxer == 0:
    # A constant image would divide by zero here and produce an all-NaN result.
    raise ValueError(f"{image_path!r} is a constant image; there is nothing to denoise.")
img = (img / maxer).astype(np.float32)
shape = img.shape
if shape[0] < 2 or shape[1] < 2:
    raise ValueError(f"{image_path!r} must be at least 2x2 pixels, got {shape}.")

# Crop to even height and width so every pixel belongs to a complete 2x2 cell.
Zshape = [shape[0] - shape[0] % 2, shape[1] - shape[1] % 2]
imgZ = img[:Zshape[0], :Zshape[1]]

# Horizontal split: output row i draws from source rows 2i and 2i+1, alternating
# which one is used per column. imgin and imgin2 are the two complementary halves.
even_rows, odd_rows = imgZ[0::2, :], imgZ[1::2, :]
imgin = np.empty((Zshape[0] // 2, Zshape[1]), dtype=np.float32)
imgin2 = np.empty_like(imgin)
imgin[:, 0::2] = odd_rows[:, 0::2]
imgin[:, 1::2] = even_rows[:, 1::2]
imgin2[:, 0::2] = even_rows[:, 0::2]
imgin2[:, 1::2] = odd_rows[:, 1::2]

# Vertical split: the same scheme with rows and columns swapped.
even_cols, odd_cols = imgZ[:, 0::2], imgZ[:, 1::2]
imgin3 = np.empty((Zshape[0], Zshape[1] // 2), dtype=np.float32)
imgin4 = np.empty_like(imgin3)
imgin3[0::2, :] = odd_cols[0::2, :]
imgin3[1::2, :] = even_cols[1::2, :]
imgin4[0::2, :] = even_cols[0::2, :]
imgin4[1::2, :] = odd_cols[1::2, :]


def _to_batch(arr):
    """Wrap a 2-D float32 array as a (1, 1, H, W) tensor on `device`."""
    return torch.from_numpy(arr)[None, None].to(device)


listimgH = [_to_batch(imgin), _to_batch(imgin2)]
listimgV = [_to_batch(imgin3), _to_batch(imgin4)]
img_torch = _to_batch(img)

# (input, target) pairs: each half of a split is trained to predict its complement.
listimg = [
    [listimgH[1], listimgH[0]],
    [listimgH[0], listimgH[1]],
    [listimgV[0], listimgV[1]],
    [listimgV[1], listimgV[0]],
]

# --- Train until the full-image reconstruction stops improving ----------------
net = Net().to(device)
criterion = nn.BCELoss()  # targets are normalised to [0, 1]; Net ends in a sigmoid
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Number of most recent full-image predictions averaged into the final output.
N_AVERAGE = 105

running_loss = 0.0
maxpsnr = -np.inf  # best (highest) negative MSE against the noisy input so far
timesince = 0
# Bounded histories holding only real predictions. The old `[0] * 105` list
# mixed scalar placeholders with 2-D arrays, which breaks np.mean(..., axis=0)
# whenever training ends in fewer than ~106 iterations.
last10 = deque(maxlen=N_AVERAGE)      # recent de-normalised predictions
last10psnr = deque(maxlen=N_AVERAGE)  # matching negative-MSE scores
noisy = img_torch[0, 0].cpu().numpy()  # loop-invariant reference image

while timesince <= tsince:
    inputs, labello = listimg[np.random.randint(0, len(listimg))]
    optimizer.zero_grad()
    loss = criterion(net(inputs), labello)
    running_loss += loss.item()
    loss.backward()
    optimizer.step()

    with torch.no_grad():
        cleaned = net(img_torch)[0, 0].cpu().numpy()
    last10.append(cleaned * maxer + minner)  # back to the original intensity range
    ps = -np.mean((noisy - cleaned) ** 2)
    last10psnr.append(ps)
    if ps > maxpsnr:
        maxpsnr = ps
        outclean = last10[-1]
        timesince = 0
    else:
        timesince += 1.0

# Final output: average of the last N_AVERAGE predictions.
H = np.mean(np.stack(last10), axis=0)

# Save as 8-bit PNG; clip defensively so rounding can never wrap around in uint8.
Image.fromarray(np.clip(np.round(H), 0, 255).astype(np.uint8)).save(out_path)

torch.cuda.empty_cache()

print(f"Output saved to: {out_path}")


Output saved to: Confocal_BPAE_R_4_N2F.png


In [None]:
# Download the denoised image when running on Google Colab. The path comes from
# `out_path` so it always matches the file written by the processing cell.
try:
    from google.colab import files
except ImportError:
    print(f"Not running on Colab; the result is at: {out_path}")
else:
    files.download(out_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>