In [1]:
import nibabel as nb
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np 
import torch 
from torch import nn

import torch
import torchvision
import torchvision.transforms as transforms

import torch.optim as optim
from torch.utils.data import DataLoader

from sklearn.model_selection import train_test_split

import os, glob, cv2, sys

In [2]:
# Modelo

class Unet(nn.Module):
    
    def __init__(self):
        super(Unet,self).__init__()
        
        # Contract
        self.layer1Down = nn.Sequential (
            torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
        )
        
        self.layer2Down = nn.Sequential (
            torch.nn.Conv2d(32, 64, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, kernel_size= 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
        )
        
        self.layer3Down = nn.Sequential (
            torch.nn.Conv2d(64, 128, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
        )
        
        self.layer4Down = nn.Sequential (
            torch.nn.Conv2d(128, 256, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
        )
        
        self.layer5Down = nn.Sequential (
            torch.nn.Conv2d(256, 512, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            torch.nn.Conv2d(512, 512, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
        )
        
        self.Pooling = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        
        #Expand
        
        self.layer5ConvTransposed = nn.Sequential ( 
            torch.nn.ConvTranspose2d(512, 256, kernel_size = 2, stride = 2, padding = 0)
        )
        
        self.layer4Up = nn.Sequential (
            torch.nn.Conv2d(512, 256, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.Conv2d(256, 256, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
        )
        
        self.layer4ConvTransposed = nn.Sequential ( 
            torch.nn.ConvTranspose2d(256, 128, kernel_size = 2, stride = 2, padding = 0)
        )
        
        self.layer3Up = nn.Sequential (
            torch.nn.Conv2d(256, 128, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.Conv2d(128, 128, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
        )
        
        self.layer3ConvTransposed = nn.Sequential ( 
            torch.nn.ConvTranspose2d(128, 64, kernel_size = 2, stride = 2, padding = 0)
        )
        
        self.layer2Up = nn.Sequential (
            torch.nn.Conv2d(128, 64, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
        )
        
        self.layer2ConvTransposed = nn.Sequential ( 
            torch.nn.ConvTranspose2d(64, 32, kernel_size = 2, stride = 2, padding = 0)
        )
        
        self.layer1Up = nn.Sequential (
            torch.nn.Conv2d(64, 32, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 32, kernel_size = 3, stride=1, padding=1),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
        )
        
    
    def forward(self, x):

        # Down
        conv1 = self.layer1Down(x)
        maxPool1 = self.Pooling(conv1)
        
        
        conv2 = self.layer2Down(maxPool1)
        maxPool2 = self.Pooling(conv2)
        
        conv3 = self.layer3Down(maxPool2)
        maxPool3 = self.Pooling(conv3)
        
        conv4 = self.layer4Down(maxPool3)
        maxPool4 = self.Pooling(conv4)
        
        conv5 = self.layer5Down(maxPool4)
        
        # Up
        layerConvTransposed = self.layer5ConvTransposed(conv5)
        layer4UpData = torch.cat((layerConvTransposed,conv4), dim=1)
        convTrans4 = self.layer4Up(layer4UpData)
        
 
        layerConvTransposed = self.layer4ConvTransposed(convTrans4)
        layer3UpData = torch.cat((layerConvTransposed,conv3), dim = 1)
        convTrans3 = self.layer3Up(layer3UpData)
        
        layerConvTransposed = self.layer3ConvTransposed(convTrans3)
        layer2UpData = torch.cat((layerConvTransposed,conv2), dim = 1)
        convTrans2 = self.layer2Up(layer2UpData)
        
        layerConvTransposed = self.layer2ConvTransposed(convTrans2)
        layer1UpData = torch.cat((layerConvTransposed,conv1), dim = 1)
        convTrans1 = self.layer1Up(layer1UpData)
        
        outNet = torch.nn.Conv2d(32,2, kernel_size = 1)(convTrans1)
        return outNet

unet = Unet()

In [34]:
# Carga de datos

noisyDataSet1_nii = sitk.ReadImage('./noisyDataSet1.nii')
img_noisyDataSet1 = sitk.GetArrayFromImage(noisyDataSet1_nii)

noisyDataSet2_nii = sitk.ReadImage('./noisyDataSet2.nii')
img_noisyDataSet2 = sitk.GetArrayFromImage(noisyDataSet2_nii)

groundTruth_nii = sitk.ReadImage('./groundTruth.nii')
img_groundTruth = sitk.GetArrayFromImage(groundTruth_nii)

img_noisyDataSet1 =img_noisyDataSet1[:,44:300,44:300]
img_noisyDataSet2 =img_noisyDataSet2[:,44:300,44:300]
img_groundTruth =img_groundTruth[:,44:300,44:300]

img_noisyDataSet1 = np.expand_dims(img_noisyDataSet1, axis=-1)
img_noisyDataSet2 = np.expand_dims(img_noisyDataSet2, axis=-1)
img_groundTruth = np.expand_dims(img_groundTruth, axis=-1)

img_noisyDataSet1 = img_noisyDataSet1.astype(np.float32)
img_noisyDataSet2 = img_noisyDataSet2.astype(np.float32)
img_groundTruth = img_groundTruth.astype(np.float32)

In [35]:
# Loss and optimizer

criterion = nn.MSELoss()
optimizer = optim.Adam(unet.parameters(), lr=0.0001)

In [36]:
# Conjunto de ENTRENAMIENTO, TESTEO y VALIDACION

train_noisyImage,test_noisyImage,train_groundTruth,test_groundTruth = train_test_split(img_noisyDataSet1, img_groundTruth, test_size=0.2)

valid_noisyImage = train_noisyImage[-5:,:,:,:]
valid_groundTruth = train_groundTruth[-5:,:,:,:]

train_noisyImage = train_noisyImage [:-5,:,:,:]
train_groundTruth = train_groundTruth[:-5:,:,:,:]


In [37]:
from torch.utils.data import Dataset

transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])

class denoisingDataset(Dataset):
    def __init__(self, data_groundTruth ,data_noisyImage, transform = None):
        self.transform = transform
        self.imgs_data       = data_groundTruth
        self.noisy_imgs_data = data_noisyImage
        
    
    def __getitem__(self, index):  
        
        img       = self.imgs_data[index,:,:,:]
        noisy_img = self.noisy_imgs_data[index,:,:,:]

        if self.transform is not None:            
            img = self.transform(img)             
            noisy_img = self.transform(noisy_img)  

        return img, noisy_img
    
    def __len__(self):
        return len(self.imgs_data)


In [38]:
transform = transforms.Compose([transforms.ToTensor()])

train_dataset       = denoisingDataset(train_groundTruth,train_noisyImage, transform = transform)
valid_dataset       = denoisingDataset(test_groundTruth,valid_noisyImage, transform = transform)
test_dataset       = denoisingDataset(valid_groundTruth,test_noisyImage, transform = transform)

In [39]:
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=4,shuffle=True, num_workers=0)
validloader = torch.utils.data.DataLoader(train_dataset, batch_size=4,shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(train_dataset, batch_size=4,shuffle=True, num_workers=0)

In [44]:
# Pytorch

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(trainloader):
        # Every data instance is an input + label pair
        inputs, trainGroundTruth = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = unet(inputs)

        # Compute the loss and its gradients
        loss = criterion(outputs, trainGroundTruth)
        loss.backward()

        # Actualizacion de pesos
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        
        if i % 10 == 9:
            last_loss = running_loss / 10 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(trainloader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

In [None]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 1

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    unet.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    # We don't need gradients on to do reporting
    unet.train(False)

    running_vloss = 0.0
    for i, vdata in enumerate(validloader):
        vinputs, validGroundTruth = vdata
        voutputs = unet(vinputs)
        vloss = criterion(voutputs,  validGroundTruth)
        running_vloss += vloss

    avg_vloss = running_vloss / (i + 1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(unet.state_dict(), model_path)

    epoch_number += 1

EPOCH 1:
  batch 10 loss: 1.8680997848510743
  batch 20 loss: 1.3842426240444183
  batch 30 loss: 1.4859347939491272
  batch 40 loss: 1.4209369838237762
  batch 50 loss: 1.2849078565835952
  batch 60 loss: 1.924480402469635
  batch 70 loss: 1.715177631378174
  batch 80 loss: 1.3233276188373566
  batch 90 loss: 1.331294333934784
  batch 100 loss: 1.6754584789276123
  batch 110 loss: 1.5937489032745362
  batch 120 loss: 1.5057116627693177
  batch 130 loss: 1.7787162899971007
  batch 140 loss: 1.749681407213211
  batch 150 loss: 1.234049129486084
  batch 160 loss: 1.5591177105903626
  batch 170 loss: 1.475886207818985
  batch 180 loss: 1.56326921582222
  batch 190 loss: 1.6132418811321259
  batch 200 loss: 1.3376784861087798
  batch 210 loss: 1.4320926427841187
  batch 220 loss: 1.3655160427093507
  batch 230 loss: 1.231603269279003
  batch 240 loss: 1.6065552711486817
  batch 250 loss: 1.7651351153850556
  batch 260 loss: 1.5759612321853638
  batch 270 loss: 1.4622745215892792
  batch 28