### UNET Implementation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from torch.nn import init
import numpy as np

def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """Build a 2-D convolution layer with a 3x3 kernel.

    Arguments:
        in_channels: int, number of input channels.
        out_channels: int, number of output channels.
        stride: int, convolution stride.
        padding: int, zero padding added to each border.
        bias: bool, whether the layer has a learnable bias.
        groups: int, number of blocked connections between input and
            output channels.

    Returns:
        The freshly constructed `nn.Conv2d` module.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups,
    )
    return nn.Conv2d(in_channels, out_channels, **conv_options)

def upconv2x2(in_channels, out_channels, mode='transpose'):
    """Build a module that doubles the spatial resolution.

    Arguments:
        in_channels: int, number of input channels.
        out_channels: int, number of output channels.
        mode: string. 'transpose' returns a 2x2 transpose convolution
            with stride 2; any other value returns bilinear x2
            upsampling followed by a 1x1 convolution that maps
            `in_channels` to `out_channels`.

    Returns:
        The constructed `nn.Module`.
    """
    if mode != 'transpose':
        upsample = nn.Upsample(mode='bilinear', scale_factor=2)
        return nn.Sequential(upsample, conv1x1(in_channels, out_channels))

    return nn.ConvTranspose2d(in_channels, out_channels,
                              kernel_size=2, stride=2)

def conv1x1(in_channels, out_channels, groups=1):
    """Build a 1x1 (pointwise) 2-D convolution with stride 1.

    Arguments:
        in_channels: int, number of input channels.
        out_channels: int, number of output channels.
        groups: int, number of blocked connections between input and
            output channels.

    Returns:
        The freshly constructed `nn.Conv2d` module.
    """
    pointwise = nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=1, groups=groups)
    return pointwise


class DownConv(nn.Module):
    """
    Encoder stage of the U-Net: two 3x3 convolutions, each followed by
    a ReLU, and an optional 2x2 max-pool at the end.
    """
    def __init__(self, in_channels, out_channels, pooling=True):
        """
        Arguments:
            in_channels: int, channels of the incoming tensor.
            out_channels: int, channels produced by both convolutions.
            pooling: bool, whether the stage ends with a 2x2 max-pool.
        """
        super(DownConv, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(in_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)

        # The pooling layer only exists when it is going to be used.
        if pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return `(output, before_pool)`.

        `before_pool` is the activation after the second ReLU; `output`
        is its max-pooled version when pooling is enabled and the very
        same tensor otherwise.
        """
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        pooled = self.pool(features) if self.pooling else features
        return pooled, features


class UpConv(nn.Module):
    """
    Decoder stage of the U-Net: one up-convolution, a merge with the
    matching encoder activation, then two 3x3 convolutions that are
    each followed by a ReLU.
    """
    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', up_mode='transpose'):
        """
        Arguments:
            in_channels: int, channels of the tensor coming from below.
            out_channels: int, channels produced by this stage.
            merge_mode: string, 'concat' concatenates the encoder and
                decoder tensors along the channel axis; any other value
                adds them element-wise.
            up_mode: string, forwarded to `upconv2x2` as `mode`.
        """
        super(UpConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(in_channels, out_channels, mode=up_mode)

        # Concatenation doubles the channels seen by the first conv;
        # element-wise addition leaves the channel count unchanged.
        merged_channels = (2*out_channels if merge_mode == 'concat'
                           else out_channels)
        self.conv1 = conv3x3(merged_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)

    def forward(self, from_down, from_up):
        """ Forward pass
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: tensor from the decoder pathway; it is upconv'd
                here before being merged with `from_down`
        """
        upsampled = self.upconv(from_up)
        if self.merge_mode == 'concat':
            merged = torch.cat((upsampled, from_down), 1)
        else:
            merged = upsampled + from_down
        hidden = F.relu(self.conv1(merged))
        return F.relu(self.conv2(hidden))


from torch.autograd import Variable

class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597
    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).
    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by up_mode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the transpose convolution (specified by up_mode='transpose')
    """

    def __init__(self, num_classes, in_channels=1, depth=5,
                 start_filts=64, up_mode='transpose',
                 merge_mode='add'):
        """
        Arguments:
            num_classes: int, number of output channels produced by
                the final 1x1 convolution.
            in_channels: int, number of channels in the input tensor.
                Default is 1.
            depth: int, number of encoder blocks (must be >= 1). All
                but the last one end with a MaxPool, and the decoder
                has depth-1 blocks.
            start_filts: int, number of convolutional filters for the
                first conv; doubled in every following encoder block.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for
                non-parametric upsampling followed by a 1x1 conv.
            merge_mode: string, how encoder and decoder tensors are
                merged. Choices: 'concat' or 'add'. Default is 'add'.

        Raises:
            ValueError: for an unknown `up_mode` or `merge_mode`, for
                the combination up_mode='upsample' with
                merge_mode='add', or for depth < 1.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))

        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            # Report the offending merge_mode (the message used to be
            # formatted with up_mode and lacked a space after "for").
            raise ValueError("\"{}\" is not a valid mode for "
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(merge_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        # Without at least one encoder block there is no channel count
        # to size the final convolution with.
        if depth < 1:
            raise ValueError(
                "depth must be at least 1, got {}".format(depth))

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # Learnable scalar initialised to log(0.5). It is not used in
        # forward(); it is kept so that the parameter set of existing
        # checkpoints stays unchanged.
        self.noiseSTD = nn.Parameter(data=torch.log(torch.tensor(0.5)))

        # create the encoder pathway and add to a list
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts*(2**i)
            # the deepest block is not followed by a pooling step
            pooling = i < depth-1

            down_conv = DownConv(ins, outs, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth-1):
            ins = outs
            outs = ins // 2
            up_conv = UpConv(ins, outs, up_mode=up_mode,
                merge_mode=merge_mode)
            self.up_convs.append(up_conv)

        self.conv_final = conv1x1(outs, self.num_classes)

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        """Xavier-normal init for Conv2d weights, zeros for the bias.

        Modules that are not `nn.Conv2d` are left untouched, and a
        convolution created with bias=False simply has no bias to
        initialise.
        """
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)


    def reset_params(self):
        """Apply `weight_init` to every sub-module of the network."""
        for m in self.modules():
            self.weight_init(m)


    def forward(self, x):
        """Run the encoder, then the decoder with skip connections.

        Returns the output of the final 1x1 convolution; no softmax or
        other activation is applied to it.
        """
        encoder_outs = []

        # encoder pathway, save outputs for merging
        for module in self.down_convs:
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        # decoder pathway: the deepest encoder output is already `x`,
        # so the skips are the remaining outputs from deep to shallow
        for module, before_pool in zip(self.up_convs,
                                       reversed(encoder_outs[:-1])):
            x = module(before_pool, x)

        # No softmax is used. This means you need to use
        # nn.CrossEntropyLoss in your training script,
        # as this module includes a softmax already.
        x = self.conv_final(x)
        return x

### Helper functions

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision

# Report whether CUDA is usable and pick the matching compute device:
# the first GPU when one is available, the CPU otherwise.
has_cuda = torch.cuda.is_available()
print("CUDA?", has_cuda)
device = torch.device("cuda:0") if has_cuda else torch.device("cpu")

def normalDens(x,m_=0.0,std_=None):
    """Evaluate the density of a normal distribution element-wise.

    Computes exp(-(x - m_)^2 / (2 std_^2)) / sqrt(2 pi std_^2).

    Arguments:
        x: tensor (or anything `torch.as_tensor` accepts) of points at
            which the density is evaluated.
        m_: mean of the distribution, number or tensor.
        std_: standard deviation, Python number or tensor. It has no
            usable default and must be supplied.

    Returns:
        Tensor with the density values.

    Raises:
        TypeError: if `std_` is left at None.
    """
    if std_ is None:
        raise TypeError("normalDens() requires a standard deviation std_")

    x = torch.as_tensor(x)
    # torch.sqrt only accepts tensors, so a plain Python/NumPy number
    # is wrapped first; a tensor std_ keeps its autograd history.
    variance = torch.as_tensor(std_, device=x.device) ** 2

    density = torch.exp(-((x - m_) ** 2) / (2.0 * variance))
    return density / torch.sqrt((2.0 * np.pi) * variance)


In [None]:


# Print once more whether PyTorch reports an available CUDA device.
print("CUDA?",torch.cuda.is_available())


def imgToTensor(img):
    """Convert an image array to a torch tensor via torchvision.

    A trailing channel axis of length 1 is added before the array is
    handed to `torchvision.transforms.functional.to_tensor`.

    Arguments:
        img: NumPy array whose first two axes are height and width and
            whose element count equals height * width.

    Returns:
        The tensor produced by `to_tensor`.
    """
    # Reshape into a new view/copy instead of assigning to img.shape,
    # which would silently change the shape of the caller's array.
    with_channel = np.reshape(img, (img.shape[0], img.shape[1], 1))
    return torchvision.transforms.functional.to_tensor(with_channel)

def get_stratified_coords2D(box_size, shape):
    """Draw one random pixel coordinate per box of a regular grid.

    The area described by `shape` is tiled with boxes of edge length
    `box_size`, visited row by row. In every box one position is drawn
    uniformly at random; positions that fall outside `shape` (possible
    in boxes overhanging the border) are discarded.

    Arguments:
        box_size: edge length of one grid box in pixels.
        shape: sequence whose first two entries are height and width.

    Returns:
        List of `(y, x)` integer tuples.
    """
    rows = int(np.ceil(shape[0] / box_size))
    cols = int(np.ceil(shape[1] / box_size))

    picked = []
    for cell in range(rows * cols):
        row, col = divmod(cell, cols)
        # y is drawn before x, one pair of draws per box
        offset_y = np.random.randint(0, box_size)
        offset_x = np.random.randint(0, box_size)
        pos_y = int(row * box_size + offset_y)
        pos_x = int(col * box_size + offset_x)
        if pos_y >= shape[0] or pos_x >= shape[1]:
            continue
        picked.append((pos_y, pos_x))
    return picked

def jointShuffle(inA, inB):
    """Shuffle two equally shaped arrays with the same permutation.

    Both arrays are stacked along a new last axis, the stack is
    shuffled in place along its first axis, and the two halves are
    split apart again, so entry i of both results still belongs
    together. The inputs themselves are not modified.

    Returns:
        Tuple `(shuffledA, shuffledB)`.
    """
    paired = np.stack((inA, inB), axis=-1)
    np.random.shuffle(paired)
    shuffledA = paired[..., 0]
    shuffledB = paired[..., 1]
    return shuffledA, shuffledB

def randomCropFRI(data, width, height, dataClean=None, counter=None):
    """Take a random crop from the next image of a data set.

    Images are visited in order using `counter`. When `counter` is
    None or has run past the last image it restarts at 0 and the data
    is shuffled: `data` alone is shuffled in place, whereas paired
    data goes through `jointShuffle`, which only rebinds the local
    names and leaves the caller's arrays untouched.

    Arguments:
        data: array of noisy images, indexed along its first axis.
        width: crop width, passed on to `randomCrop`.
        height: crop height, passed on to `randomCrop`.
        dataClean: optional array of clean targets aligned with `data`.
        counter: index of the image to use, or None to start over.

    Returns:
        `(imgOut, imgOutC, mask, counter)` where the first three come
        from `randomCrop` and `counter` is the index for the next call.
    """
    start_over = counter is None or counter >= data.shape[0]
    if start_over:
        counter = 0
        if dataClean is None:
            np.random.shuffle(data)
        else:
            data, dataClean = jointShuffle(data, dataClean)

    index = counter
    imgClean = None if dataClean is None else dataClean[index]

    imgOut, imgOutC, mask = randomCrop(
        data[index], width, height, imgClean=imgClean)
    return imgOut, imgOutC, mask, counter + 1

def randomCrop(img, width, height, imgClean=None, hotPixels=64,
               boxSize=None):
    """Cut a random, randomly rotated/flipped training crop.

    Without `imgClean` the crop is prepared for Noise2Void training:
    one pixel per grid box (see `get_stratified_coords2D`) is replaced
    by a randomly chosen *other* pixel of its 5x5 neighbourhood
    (clipped at the crop border) and flagged in the mask. With
    `imgClean` (Noise2Clean) nothing is replaced and the mask is all
    ones.

    Arguments:
        img: 2-D noisy image.
        width: crop width in pixels.
        height: crop height in pixels.
        imgClean: optional clean image aligned with `img`.
        hotPixels: unused; kept so existing calls keep working.
        boxSize: edge length of the grid boxes used to pick the
            replaced pixels. When None, the module-level `box_size`
            is used.

    Returns:
        `(imgOut, imgOutC, mask)`: network input, training target and
        the mask of pixels that contribute to the loss.
    """
    assert img.shape[0] >= height
    assert img.shape[1] >= width

    n2v = imgClean is None
    # In Noise2Void mode the (unmodified) noisy crop is the target.
    target = img if n2v else imgClean

    # randint's upper bound is exclusive: the +1 makes the last valid
    # offset reachable and allows an image exactly as large as the crop.
    x = np.random.randint(0, img.shape[1] - width + 1)
    y = np.random.randint(0, img.shape[0] - height + 1)

    imgOut = img[y:y+height, x:x+width].copy()
    imgOutC = target[y:y+height, x:x+width].copy()
    mask = np.zeros(imgOut.shape)

    if n2v:
        # Noise2Void training, i.e. no clean targets
        if boxSize is None:
            boxSize = box_size  # global defined by the training setup

        for b, a in get_stratified_coords2D(boxSize, imgOut.shape):
            # 5x5 neighbourhood around (b, a); slice ends are
            # exclusive, so they are clipped to the crop size itself.
            roiMinA = max(a-2, 0)
            roiMaxA = min(a+3, imgOut.shape[1])
            roiMinB = max(b-2, 0)
            roiMaxB = min(b+3, imgOut.shape[0])
            roi = imgOut[roiMinB:roiMaxB, roiMinA:roiMaxA]
            if roi.size <= 1:
                # no neighbour to copy from (1x1 crop)
                continue

            # Position of the pixel itself inside the (possibly
            # clipped) neighbourhood; it must not be drawn.
            centerA = a - roiMinA
            centerB = b - roiMinB
            a_, b_ = centerA, centerB
            while a_ == centerA and b_ == centerB:
                a_ = np.random.randint(0, roi.shape[1])
                b_ = np.random.randint(0, roi.shape[0])

            imgOut[b, a] = roi[b_, a_]
            mask[b, a] = 1.0
    else:
        # Noise2Clean
        mask[:] = 1.0

    # Augmentation: the same rotation / flip for all three arrays.
    rot = np.random.randint(0, 4)
    imgOut = np.array(np.rot90(imgOut, rot))
    imgOutC = np.array(np.rot90(imgOutC, rot))
    mask = np.array(np.rot90(mask, rot))
    if np.random.choice((True, False)):
        # np.flip without an axis reverses every axis
        imgOut = np.array(np.flip(imgOut))
        imgOutC = np.array(np.flip(imgOutC))
        mask = np.array(np.flip(mask))

    return imgOut, imgOutC, mask
    

def PSNR(gt, pred, range_=255.0 ):
    """Peak signal-to-noise ratio between `gt` and `pred` in dB.

    Computed as 20 * log10(range_ / RMSE), where RMSE is the root of
    the mean squared difference over all elements.
    """
    squared_error = (gt - pred)**2
    rmse = np.sqrt(np.mean(squared_error))
    return 20 * np.log10(range_ / rmse)

import numpy as np

def normalize(img, mean, std):
    """Shift `img` by `mean` and scale it by `std`: (img - mean) / std."""
    return (img - mean) / std

def denormalize(x, mean, std):
    """Undo `normalize`: scale `x` by `std`, then add `mean`."""
    rescaled = x*std
    return rescaled + mean



### Load data

In [None]:
# This is the path to the input data
# (machine-specific absolute path; the ground-truth files are read
# from the sibling "../gt/" directory relative to it)
path="/home/florian/projects/Fish/raw/"

# Training data: noisy images and their ground truth
data_raw=np.load(path+'training_big_raw.npy')
data_gt=np.load(path+'../gt/training_big_GT.npy')
#np.random.shuffle(data_raw)
print(data_raw.shape)
# Number of noisy images per ground-truth image (integer part of the
# ratio of the two array lengths)
imgFactor=int(data_raw.shape[0]/data_gt.shape[0])
print(imgFactor)
# Show one noisy image next to the ground truth it maps to
index=604
plt.imshow(data_raw[index])
plt.show()
plt.imshow(data_gt[index//imgFactor])
plt.show()

# Normalize
# The statistics of the noisy training data are applied to every array,
# including the ground truth and the test data.
mean=np.mean(data_raw)
std=np.std(data_raw)
print(mean,std)
data=normalize(data_raw,mean,std)
dataGT=normalize(data_gt,mean,std)
# Repeat every ground-truth image imgFactor times along the first axis
# so that dataGT[i] lines up with data[i]
dataGT=np.repeat(dataGT,imgFactor,axis=0)

# Noisy test images, normalized with the training statistics
dataTest_raw=np.load(path+"test_noisy.npy")
dataTest=normalize(dataTest_raw,mean,std)
plt.imshow(dataTest_raw[5])

# Test ground truth; it is kept un-normalized
dataTestGT=np.load(path+"../gt/test_gt.npy")
print(dataTestGT[0].shape)
plt.imshow(dataTestGT[0])
print(mean,std)

### Train the network

In [None]:
def trainingPred(my_train_data, my_train_data_clean , dataCounter,size,bs):
    """Assemble one mini-batch of random crops and run the network on it.

    Relies on the module-level `net` and `device`.

    Arguments:
        my_train_data: array of noisy images.
        my_train_data_clean: array of clean targets, or None for
            Noise2Void training.
        dataCounter: running image index handed to `randomCropFRI`.
        size: edge length of the square crops.
        bs: number of crops in the mini-batch.

    Returns:
        `(outputs, labels, masks, dataCounter)`: the network output,
        the targets and masks (both moved to `device`) and the updated
        image index.
    """
    # Pre-allocated mini-batch buffers
    batch_inputs = torch.zeros(bs, 1, size, size)
    batch_labels = torch.zeros(bs, size, size)
    batch_masks = torch.zeros(bs, size, size)

    # Assemble the mini-batch, one random crop per slot
    for slot in range(bs):
        crop, target, crop_mask, dataCounter = randomCropFRI(
            my_train_data, size, size,
            counter=dataCounter, dataClean=my_train_data_clean)
        batch_inputs[slot, :, :, :] = imgToTensor(crop)
        batch_labels[slot, :, :] = imgToTensor(target)
        batch_masks[slot, :, :] = imgToTensor(crop_mask)

    # Move everything to the compute device
    batch_inputs = batch_inputs.to(device)
    batch_labels = batch_labels.to(device)
    batch_masks = batch_masks.to(device)

    # Forward step
    return net(batch_inputs), batch_labels, batch_masks, dataCounter
    

def lossFunction(outputs, labels, masks):
    """Masked mean squared error.

    Only channel 0 of `outputs` is compared with `labels`; the squared
    errors are weighted by `masks`, summed and divided by the sum of
    `masks`.
    """
    prediction = outputs[:, 0, ...]
    squared_error = (labels - prediction)**2
    weighted_total = torch.sum(masks * squared_error)
    return weighted_total / torch.sum(masks)
    

In [None]:
import torch.distributions as tdist
import torch.optim as optim

# Work on shuffled copies so the loaded arrays stay untouched.
#data_c=np.concatenate((data.copy(),dataTest.copy()))
data_c=data.copy()
dataGT_c=dataGT.copy()
data_c,dataGT_c=jointShuffle(data_c,dataGT_c)

# Noise2Void training: no clean targets are handed to the crop helpers.
# For supervised (Noise2Clean) training use copies of dataGT_c here.
#my_train_dataGT=dataGT_c.copy()
#my_val_dataGT=dataGT_c.copy()
my_train_dataGT=None
my_val_dataGT=None

# Training and validation crops are drawn from the same images.
my_train_data=data_c.copy()
my_val_data=data_c.copy()

#device = torch.device("cpu")


net = UNet(1, depth=3)

net.to(device)
net.train(True)
bs=24     # crops per mini-batch
size=100  # edge length of the square training crops
num_pix=100*100/32.0
dataCounter=None
# Edge length of the grid boxes from which one pixel each is masked.
# The builtin int replaces the np.int alias, which no longer exists in
# current NumPy releases.
box_size = np.round(np.sqrt(size * size / num_pix)).astype(int)

vbatch=20 # Virtual batch size
optimizer = optim.Adam(net.parameters(), lr=0.0001)
# NOTE(review): `verbose` is flagged as deprecated by recent PyTorch
# releases - confirm against the installed version.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5, verbose=True)

running_loss = 0.0
stepCounter=0


valSize=20       # mini-batches per validation pass
stepsPerEpoch=1  # optimizer steps between two validation passes
trainHist=[]
valHist=[]

for step in range(40000):  # one optimizer update per iteration
    losses=[]
    optimizer.zero_grad()
    stepCounter+=1

    # Iterate over virtual batch: gradients of `vbatch` mini-batches
    # are accumulated before a single optimizer step is taken.
    for a in range (vbatch):

        outputs, labels, masks, dataCounter = trainingPred(my_train_data, my_train_dataGT, dataCounter,size,bs)

        loss=lossFunction(outputs, labels,masks)
        loss.backward()
        running_loss += loss.item()
        losses.append(loss.item())


    optimizer.step()


    if stepCounter % stepsPerEpoch == stepsPerEpoch-1:
        running_loss=(np.mean(losses))
        print("Step:", stepCounter, "| Avg. epoch loss:", running_loss)
        losses=np.array(losses)
        # mean loss +- its standard error over the virtual batch
        print("avg. loss: "+str(np.mean(losses))+"+-"+str(np.std(losses)/np.sqrt(losses.size)))
        trainHist.append(np.mean(losses))
        losses=[]


        # Always keep the most recent network on disk
        torch.save(net,"last"+".net")

        # Validation pass. Gradients are not needed here, so autograd
        # is switched off to avoid building (and holding on to)
        # computation graphs.
        valCounter=0
        net.train(False)
        losses=[]
        with torch.no_grad():
            for i in range(valSize):
                outputs, labels, masks, valCounter = trainingPred(my_val_data,my_val_dataGT, valCounter,size,bs)
                loss=lossFunction(outputs, labels,masks)
                losses.append(loss.item())
        net.train(True)
        avgValLoss=np.mean(losses)
        # Keep a separate copy of the network with the lowest
        # validation loss seen so far
        if len(valHist)==0 or avgValLoss < np.min(valHist):
            torch.save(net,"best"+".net")
        valHist.append(avgValLoss)

        epoch= (stepCounter / stepsPerEpoch)


        np.save("history"+".npy", (np.array( [np.arange(epoch),trainHist,valHist ] ) ) )

        plt.plot(valHist)
        plt.plot(trainHist)
        plt.show()
        # Lower the learning rate when the validation loss plateaus
        scheduler.step(avgValLoss)

        # Stop after a bit more than 200 epochs
        if stepCounter / stepsPerEpoch > 200:
            break


print('Finished Training')

### Evaluation

In [None]:
def predict(im, l):
    """Run the network on one image (tile) and return the prediction.

    Relies on the module-level `net`, `device`, `mean` and `std`.

    Arguments:
        im: 2-D normalized input image.
        l: unused; kept so existing calls keep working.

    Returns:
        2-D NumPy array holding channel 0 of the network output,
        denormalized with `mean` and `std`.
    """
    inputs = torch.zeros(1, 1, im.shape[0], im.shape[1])
    inputs[0, :, :, :] = imgToTensor(im)

    # copy to GPU
    inputs = inputs.to(device)

    # Pure inference: disabling autograd avoids building a computation
    # graph and keeps the memory footprint of every call small.
    with torch.no_grad():
        output = net(inputs)

    # The batch holds a single image; take its first output channel and
    # get the data back from the GPU as a 2-D array.
    means = output[0, 0].cpu().numpy()

    # Denormalize
    return denormalize(means, mean, std)

In [None]:
#from scipy import ndimage, misc

# PSNR of the prediction for every test image
results=[]
meanRes=[]


#load network
# NOTE(review): this un-pickles a complete model object; only load
# files you created yourself. Recent PyTorch releases may need
# weights_only=False for this call - confirm against the installed
# version.
net=torch.load("last.net")

#device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
estimate=torch.tensor(25.0/std).to(device)
for index in range(dataTest.shape[0]):

    im=dataTest[index]
    # Every test image is compared against the first ground-truth image
    # NOTE(review): presumably all test images show the same field of
    # view - confirm.
    l=dataTestGT[0]
    print(im.shape,l.shape)
    means=np.zeros(im.shape)
    mseEst=np.zeros(im.shape)
    
    
    
    # We have to use tiling because of memory constraints on the GPU
    # Tiles of ps x ps pixels are visited column by column; neighbouring
    # tiles overlap by `overlap` pixels, and from every tile but the
    # first in a row/column the leading overlap//2 pixels are discarded.
    ps=128
    overlap=48
    xmin=0
    ymin=0
    xmax=ps
    ymax=ps
    ovLeft=0
    while (xmin<im.shape[1]):
        ovTop=0
        while (ymin<im.shape[0]):     
            a= predict(im[ymin:ymax,xmin:xmax],l[ymin:ymax,xmin:xmax])
            # Write the tile into the result, skipping its top/left margin
            means[ymin:ymax,xmin:xmax][ovTop:,ovLeft:] = a[ovTop:,ovLeft:]
            ymin=ymin-overlap+ps
            ymax=ymin+ps
            ovTop=overlap//2
        # Next tile column: start again at the top
        ymin=0 
        ymax=ps
        xmin=xmin-overlap+ps
        xmax=xmin+ps
        ovLeft=overlap//2
    

    # Bring the noisy input back to the original intensity range
    im=denormalize(im,mean,std)
    vmi=np.percentile(l,0.05)
    vma=np.percentile(l,99.5)
    print(vmi,vma)
    
    
    psnrPrior=PSNR(l, means,255 )
    results.append(psnrPrior)

    
    print ("PSNR raw",PSNR(l, im,255 ))
    print ("PSNR prior",psnrPrior) # PSNR of the network prediction
    print ("index",index) 
    print (np.min(means),np.max(means))
     
    plt.imshow(im[200:328,200:328],cmap='gray',vmin=0,vmax=255) # denormalized noisy input
    plt.show()
    
    plt.imshow(means[200:328,200:328],cmap='gray') # network prediction
    plt.show()
    


print("Avg Prior:", np.mean(np.array(results)))