In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from modules.UNet import *
from modules.Discriminator import *
from modules.DataSet import *
from modules.Losses import *

import os,sys

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from PIL import Image

import numpy as np

## Training

In [2]:
def train_adversarial(net,discriminator,device,val_per=0.1,epochs=10,batch_size=10,resize_to=None):
    if resize_to is not None:
        transform_image = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor()
        ])
    else:
        transform_image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.ToTensor()
        ])
            
    
    dataSet = UltraSoundDataSet(root_dir,(transform_image,transform_label))
    nTrain = int(len(dataSet)*(1-val_per))
    nValid = int(len(dataSet)-nTrain)
    
    trainSet,validSet = random_split(dataSet,[nTrain,nValid])
    
    train_loader = DataLoader(trainSet,batch_size=batch_size,shuffle=True,num_workers=4)
    valid_loader = DataLoader(validSet,batch_size=batch_size,shuffle=True,num_workers=4)
    
    optimizer_unet = torch.optim.Adam(net.parameters())
    optimizer_disc = torch.optim.Adam(discriminator.parameters())
    scheduler_unet = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_unet,mode='min',patience=10) #mae: dice-index
    scheduler_disc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_disc,mode='min',patience=10)
    BCELoss = nn.BCELoss()
    
    running_loss_seg = 0
    running_loss_disc = 0
    step = 0
    np.set_printoptions(precision=2)
    for epoch in range(epochs):
        net.train()
        discriminator.train()
        
        for batch in train_loader:
            imgs,labels = batch
            
            imgs = imgs.to(device=device,dtype=torch.float32)
            labels = labels.to(device=device,dtype=torch.float32)
            
            # segmentation network
            pred = net(imgs)
            
            # discriminator network
            pred_valid = discriminator(imgs, labels)
            pred_fake = discriminator(imgs, pred.detach())   
            print("valid: ",np.array(pred_valid.view([1,-1]).tolist()),"; fake: ",np.array(pred_fake.view([1,-1]).tolist()) )
            # seg loss
            #seg_loss = DiceLoss(pred,labels)
            seg_loss = BCELoss(pred,labels)
            
            if epoch<epochs/2:
                optimizer_unet.zero_grad()
                seg_loss.backward(retain_graph=True)
                optimizer_unet.step()
            
            
            valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device=device,dtype=torch.float32)
            fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device=device,dtype=torch.float32)
            
            # discriminator loss
            disc_loss = BCELoss(pred_valid,valid) + BCELoss(pred_fake,fake)

            optimizer_disc.zero_grad()
            disc_loss.backward(retain_graph=False)
            optimizer_disc.step()
            
            running_loss_seg += seg_loss.item()
            running_loss_disc += disc_loss.item()
            step += 1
            
            if step % 10 == 9:    # print every 10 mini-batches
                print('[%d, %5d] segmentation loss: %.3f; discrimination loss: %.3f' %(epoch + 1, step + 1, running_loss_seg / 10, running_loss_disc / 10))
                running_loss_seg = 0.0
                running_loss_disc = 0.0
                
            if step%100 == 99:
                val_loss_seg = 0
                val_loss_disc = 0
                for batch in valid_loader:
                    imgs,labels = batch
                    imgs = imgs.to(device=device,dtype=torch.float32)
                    labels = labels.to(device=device,dtype=torch.float32)
                    
                    with torch.no_grad():
                        pred_seg = net(imgs)
                        pred_valid = discriminator(imgs, labels)
                        pred_fake = discriminator(imgs, pred_seg.detach())
                    
                    val_loss_seg += DiceLoss(pred_seg,labels).item()
                    valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device=device,dtype=torch.float32)
                    fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device=device,dtype=torch.float32)
                    val_loss_disc += (BCELoss(pred_valid,valid) + BCELoss(pred_fake,fake)).item()
                print('[%d, %5d] validation loss(seg): %.3f; validation loss(disc): %.3f' %(epoch + 1, step + 1, val_loss_seg / len(valid_loader), val_loss_disc / len(valid_loader)))
                scheduler_unet.step(val_loss_seg)
                scheduler_disc.step(val_loss_disc)

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet2(init_features=32)
discriminator = Discriminator()

net = net.to(device)
discriminator = discriminator.to(device)

#IMG_SIZE = [256,256]
IMG_SIZE = None

root_dir = os.path.expanduser("~/workspace/us_robot/DataSet/SimRealDataset")
try:
    #train(net=net,device=device,resize_to=None,epochs=50,batch_size=5)
    train_adversarial(net=net,discriminator=discriminator,device=device,resize_to=IMG_SIZE,epochs=5,batch_size=3)
except KeyboardInterrupt:
    sys.exit()

valid:  [[0.52 0.52 0.52]] ; fake:  [[0.52 0.52 0.52]]
valid:  [[0.86 0.86 0.86]] ; fake:  [[0.63 0.63 0.63]]
valid:  [[0.28 0.32 0.33]] ; fake:  [[0.1  0.09 0.1 ]]
valid:  [[0.92 0.92 0.92]] ; fake:  [[0.28 0.3  0.3 ]]
valid:  [[0.98 0.8  0.98]] ; fake:  [[0.36 0.33 0.36]]
valid:  [[0.99 0.99 0.99]] ; fake:  [[0.17 0.16 0.17]]
valid:  [[0.99 0.99 0.79]] ; fake:  [[0.04 0.03 0.03]]
valid:  [[0.99 0.99 0.99]] ; fake:  [[0.01 0.01 0.01]]
valid:  [[0.99 0.99 0.99]] ; fake:  [[0. 0. 0.]]
[1,    10] segmentation loss: 0.616; discrimination loss: 0.510
valid:  [[0.99 0.97 0.99]] ; fake:  [[0. 0. 0.]]
valid:  [[1. 1. 1.]] ; fake:  [[0. 0. 0.]]
valid:  [[1.   1.   0.99]] ; fake:  [[7.66e-05 7.87e-05 7.71e-05]]
valid:  [[1. 1. 1.]] ; fake:  [[7.98e-05 3.53e-05 2.16e-05]]
valid:  [[1. 1. 1.]] ; fake:  [[1.99e-05 1.98e-05 5.18e-06]]
valid:  [[1. 1. 1.]] ; fake:  [[1.04e-05 8.69e-06 3.61e-05]]
valid:  [[1. 1. 1.]] ; fake:  [[6.76e-06 1.43e-05 1.51e-05]]
valid:  [[1. 1. 1.]] ; fake:  [[1.31e-05 7.2

valid:  [[1. 1. 1.]] ; fake:  [[1.53e-05 2.28e-05 3.86e-06]]
valid:  [[1. 1. 1.]] ; fake:  [[5.57e-06 3.09e-05 9.22e-06]]
valid:  [[1. 1. 1.]] ; fake:  [[1.32e-05 1.02e-05 1.42e-04]]
valid:  [[1. 1. 1.]] ; fake:  [[5.53e-05 5.55e-05 1.33e-05]]
[1,   130] segmentation loss: 0.189; discrimination loss: 0.000
valid:  [[1. 1. 1.]] ; fake:  [[4.08e-05 4.82e-05 1.92e-05]]
valid:  [[1. 1. 1.]] ; fake:  [[3.76e-05 7.45e-05 3.32e-05]]
valid:  [[1. 1. 1.]] ; fake:  [[3.38e-05 4.04e-05 2.19e-04]]
valid:  [[1. 1. 1.]] ; fake:  [[0. 0. 0.]]
valid:  [[1. 1. 1.]] ; fake:  [[1.65e-03 1.75e-03 9.15e-05]]
valid:  [[1. 1. 1.]] ; fake:  [[2.32e-04 3.89e-05 3.88e-04]]
valid:  [[1. 1. 1.]] ; fake:  [[4.55e-05 1.02e-04 2.61e-04]]
valid:  [[1. 1. 1.]] ; fake:  [[2.51e-05 1.34e-05 2.50e-05]]
valid:  [[1. 1. 1.]] ; fake:  [[7.83e-06 1.25e-05 3.59e-06]]
valid:  [[1. 1. 1.]] ; fake:  [[1.84e-04 5.67e-05 1.17e-05]]
[1,   140] segmentation loss: 0.179; discrimination loss: 0.000
valid:  [[1. 1. 1.]] ; fake:  [[9.48

valid:  [[1. 1. 1.]] ; fake:  [[0.00e+00 3.83e-39 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[9.59e-25 3.95e-34 4.28e-34]]
[1,   250] segmentation loss: 0.085; discrimination loss: 0.248
valid:  [[1. 1. 1.]] ; fake:  [[1.06e-30 4.22e-30 1.35e-32]]
valid:  [[1. 1. 1.]] ; fake:  [[1.33e-25 2.56e-22 2.18e-27]]
valid:  [[1. 1. 1.]] ; fake:  [[2.35e-24 6.57e-17 3.36e-19]]
valid:  [[1. 1. 1.]] ; fake:  [[7.98e-14 5.91e-18 6.62e-15]]
valid:  [[1. 1. 1.]] ; fake:  [[2.06e-10 5.91e-20 3.66e-09]]
valid:  [[1. 1. 1.]] ; fake:  [[7.96e-04 4.77e-05 2.60e-04]]
valid:  [[1. 1. 1.]] ; fake:  [[1.38e-30 3.50e-24 9.49e-01]]
valid:  [[1. 1. 1.]] ; fake:  [[1.00e+00 1.17e-03 3.64e-31]]
valid:  [[1. 1. 1.]] ; fake:  [[4.41e-37 1.13e-12 6.60e-13]]
valid:  [[1. 1. 1.]] ; fake:  [[4.52e-20 3.58e-13 3.11e-19]]
[1,   260] segmentation loss: 0.087; discrimination loss: 0.385
valid:  [[1. 0. 1.]] ; fake:  [[1.65e-23 0.00e+00 7.12e-26]]
valid:  [[1. 1. 1.]] ; fake:  [[4.61e-30 2.69e-30 2.53e-31]]
valid:  [[1. 1. 1.

valid:  [[1. 1. 1.]] ; fake:  [[3.62e-33 9.50e-10 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[2.95e-13 8.27e-32 4.42e-25]]
valid:  [[1. 1. 1.]] ; fake:  [[2.80e-15 1.36e-38 4.62e-21]]
valid:  [[1. 1. 1.]] ; fake:  [[1.55e-27 6.10e-21 4.97e-22]]
valid:  [[1. 1. 1.]] ; fake:  [[6.25e-29 1.19e-28 1.07e-18]]
valid:  [[1. 1. 1.]] ; fake:  [[7.06e-10 7.23e-25 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[0.00e+00 4.78e-18 0.00e+00]]
[1,   380] segmentation loss: 0.048; discrimination loss: 0.000
valid:  [[1. 1. 1.]] ; fake:  [[5.79e-13 1.73e-10 1.50e-35]]
valid:  [[1. 1. 1.]] ; fake:  [[8.44e-30 1.31e-32 1.76e-15]]
valid:  [[1. 1. 1.]] ; fake:  [[1.56e-29 2.43e-33 8.90e-15]]
valid:  [[1. 1. 1.]] ; fake:  [[1.53e-15 2.20e-27 2.09e-37]]
valid:  [[1. 1. 1.]] ; fake:  [[0.00e+00 1.17e-37 3.14e-12]]
valid:  [[1. 1. 1.]] ; fake:  [[2.97e-13 0.00e+00 7.64e-17]]
valid:  [[1. 1. 1.]] ; fake:  [[2.41e-18 4.35e-14 7.99e-23]]
valid:  [[1. 1. 1.]] ; fake:  [[0.00e+00 6.19e-11 1.02e-26]]
valid:  [[1. 1. 1.]] 

valid:  [[1. 1. 1.]] ; fake:  [[1.49e-37 0.00e+00 3.26e-34]]
valid:  [[1. 1. 1.]] ; fake:  [[4.52e-14 2.42e-30 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[8.94e-18 6.57e-16 2.48e-29]]
valid:  [[1. 1. 1.]] ; fake:  [[6.66e-09 2.98e-17 9.41e-18]]
valid:  [[1. 1. 1.]] ; fake:  [[7.74e-28 4.34e-01 1.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[1.50e-20 1.49e-09 3.65e-07]]
valid:  [[1. 1. 1.]] ; fake:  [[1.88e-31 5.16e-33 4.42e-28]]
[1,   500] segmentation loss: 0.023; discrimination loss: 1.060
[1,   500] validation loss(seg): 0.398; validation loss(disc): 1.333
valid:  [[1. 1. 1.]] ; fake:  [[7.84e-30 0.00e+00 9.26e-22]]
valid:  [[1. 1. 1.]] ; fake:  [[5.57e-34 0.00e+00 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[9.41e-35 0.00e+00 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[0.0e+00 1.9e-32 0.0e+00]]
valid:  [[0.   0.19 0.86]] ; fake:  [[0.00e+00 0.00e+00 8.76e-39]]
valid:  [[1. 1. 1.]] ; fake:  [[0. 0. 0.]]
valid:  [[1. 1. 1.]] ; fake:  [[0.00e+00 4.99e-37 3.05e-35]]
valid:  [[1. 1. 1.]] ; fake:

valid:  [[1. 1. 1.]] ; fake:  [[1.45e-20 6.17e-16 1.80e-34]]
valid:  [[1. 1. 1.]] ; fake:  [[3.41e-21 4.78e-27 4.11e-24]]
valid:  [[1. 1. 1.]] ; fake:  [[1.81e-34 9.37e-22 2.24e-22]]
valid:  [[1. 1. 1.]] ; fake:  [[1.11e-32 1.59e-23 9.67e-31]]
valid:  [[1. 1. 1.]] ; fake:  [[2.37e-20 1.28e-29 2.68e-26]]
valid:  [[1. 1. 1.]] ; fake:  [[1.87e-17 3.38e-30 1.73e-33]]
valid:  [[1. 1. 1.]] ; fake:  [[4.08e-14 1.76e-34 8.34e-34]]
[1,   620] segmentation loss: 0.024; discrimination loss: 0.000
valid:  [[1. 1. 1.]] ; fake:  [[1.04e-18 8.17e-27 2.49e-12]]
valid:  [[1. 1. 1.]] ; fake:  [[3.09e-24 3.59e-24 6.08e-20]]
valid:  [[1. 1. 1.]] ; fake:  [[5.87e-18 2.26e-23 2.12e-21]]
valid:  [[1. 1. 1.]] ; fake:  [[3.87e-22 7.80e-31 1.37e-23]]
valid:  [[1. 1. 1.]] ; fake:  [[6.67e-17 5.85e-24 3.30e-26]]
valid:  [[1. 1. 1.]] ; fake:  [[7.43e-26 1.26e-24 3.84e-06]]
valid:  [[1.00e+00 2.53e-34 1.00e+00]] ; fake:  [[1.40e-26 1.94e-34 8.86e-16]]
valid:  [[1. 1. 1.]] ; fake:  [[4.94e-23 2.24e-22 3.41e-17]]
val

valid:  [[1. 1. 1.]] ; fake:  [[2.39e-32 2.37e-34 1.15e-31]]
valid:  [[1. 1. 1.]] ; fake:  [[1.18e-28 1.76e-32 1.59e-36]]
valid:  [[1. 1. 1.]] ; fake:  [[2.75e-37 3.00e-21 6.48e-35]]
valid:  [[1. 1. 1.]] ; fake:  [[1.14e-33 0.00e+00 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[1.4e-30 0.0e+00 0.0e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[3.70e-37 2.29e-34 7.13e-36]]
[1,   740] segmentation loss: 0.016; discrimination loss: 0.000
valid:  [[1. 1. 1.]] ; fake:  [[3.47e-39 8.23e-39 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[1.51e-35 0.00e+00 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[0. 0. 0.]]
valid:  [[1.   1.   0.99]] ; fake:  [[9.21e-34 7.97e-38 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[0.00e+00 1.87e-38 6.11e-39]]
valid:  [[1. 1. 1.]] ; fake:  [[1.06e-37 0.00e+00 0.00e+00]]
valid:  [[1.   0.98 1.  ]] ; fake:  [[0. 0. 0.]]
valid:  [[1.   0.06 0.98]] ; fake:  [[0. 0. 0.]]
valid:  [[1. 1. 1.]] ; fake:  [[3.50e-39 1.49e-33 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[0.00e+00 3.64e-34 0.00e+00]]

valid:  [[1. 1. 1.]] ; fake:  [[1.00e+00 1.14e-29 2.69e-15]]
valid:  [[1. 1. 1.]] ; fake:  [[3.20e-19 3.62e-32 2.81e-20]]
valid:  [[1. 1. 1.]] ; fake:  [[1.61e-30 1.65e-27 9.01e-27]]
valid:  [[1. 1. 1.]] ; fake:  [[9.74e-29 2.99e-29 6.01e-35]]
valid:  [[1. 1. 1.]] ; fake:  [[2.30e-34 2.22e-30 3.68e-30]]
[1,   860] segmentation loss: 0.014; discrimination loss: 0.367
valid:  [[1. 1. 1.]] ; fake:  [[2.06e-33 0.00e+00 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[4.02e-39 4.36e-35 1.26e-38]]
valid:  [[1. 1. 1.]] ; fake:  [[0.00e+00 1.53e-37 0.00e+00]]
valid:  [[0.97 1.   0.98]] ; fake:  [[0.00e+00 6.85e-38 0.00e+00]]
valid:  [[0.24 0.92 1.  ]] ; fake:  [[0. 0. 0.]]
valid:  [[1. 1. 1.]] ; fake:  [[2.91e-38 0.00e+00 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[3.60e-31 5.72e-32 0.00e+00]]
valid:  [[1. 1. 1.]] ; fake:  [[2.47e-37 1.10e-32 5.62e-37]]
valid:  [[1. 1. 1.]] ; fake:  [[1.28e-31 3.17e-31 2.90e-29]]
valid:  [[1. 1. 1.]] ; fake:  [[1.42e-33 7.35e-30 1.49e-30]]
[1,   870] segmentation los

SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


## Inference

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

PATH = os.path.expanduser("~/workspace/us_robot/unet_gan_usseg.pth")
net = UNet(init_features=32).to(device)
net.load_state_dict(torch.load(PATH))
net = net.eval()


In [5]:
test_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_neck")
pred_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_pred")

testset_list = os.listdir(test_dir)
resize_to=None
transform_image = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])
transform_label = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize([1000,500])
    ])

for sample in testset_list:
    image_path = os.path.join(test_dir,sample)
    label_path = os.path.join(pred_dir,sample)
    
    img = Image.open(image_path).convert("L")
    img = transform_image(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    #label = Image.open(label_path)
    #label = transform_label(label).to(device)
    #label = label.unsqueeze(0)
    
    
    with torch.no_grad():
        pred = net(img)
    
    #DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()

    pred = invtransform_label(pred.cpu().squeeze(0))
    #fname = "pred%.2f.png"%DiceIndex
    fname = sample
    sav_path = os.path.join(pred_dir,fname)
    pred.save(sav_path)

In [6]:
test_dir = os.path.expanduser("~/workspace/us_robot/DataSet/SimRealDatasetTest")
testset_list = os.listdir(test_dir)
resize_to=None
transform_image = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])
transform_label = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize([1000,500])
    ])
for sample in testset_list:
    image_path = os.path.join(test_dir,sample,"image.png")
    label_path = os.path.join(test_dir,sample,"label.png")
    
    img = Image.open(image_path)
    img = transform_image(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    label = Image.open(label_path)
    label = transform_label(label).to(device)
    label = label.unsqueeze(0)
    
    
    with torch.no_grad():
        pred = net(img)
    
    DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()

    pred = invtransform_label(pred.cpu().squeeze(0))
    fname = "pred%.2f.png"%DiceIndex
    sav_path = os.path.join(test_dir,sample,fname)
    pred.save(sav_path)