In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from modules.UNet import *
from modules.Discriminator import *
from modules.DataSet import *
from modules.Losses import *

import os,sys

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from PIL import Image

import numpy as np

## Training

In [2]:
# def train_adversarial(net,discriminator,device,val_per=0.1,epochs=10,batch_size=10,resize_to=None):
#     if resize_to is not None:
#         transform_image = transforms.Compose([
#         transforms.Resize(resize_to),
#         transforms.ToTensor(),
#         transforms.Normalize(0.5,0.5)
#         ])
#         transform_label = transforms.Compose([
#         transforms.Resize(resize_to),
#         transforms.ToTensor()
#         ])
#     else:
#         transform_image = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize(0.5,0.5)
#         ])
#         transform_label = transforms.Compose([
#         transforms.ToTensor()
#         ])
            
    
#     dataSet = UltraSoundDataSet(root_dir,(transform_image,transform_label))
#     nTrain = int(len(dataSet)*(1-val_per))
#     nValid = int(len(dataSet)-nTrain)
    
#     trainSet,validSet = random_split(dataSet,[nTrain,nValid])
    
#     train_loader = DataLoader(trainSet,batch_size=batch_size,shuffle=True,num_workers=4)
#     valid_loader = DataLoader(validSet,batch_size=batch_size,shuffle=True,num_workers=4)
    
#     optimizer_unet = torch.optim.Adam(net.parameters())
#     optimizer_disc = torch.optim.Adam(discriminator.parameters())
#     scheduler_unet = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_unet,mode='min',patience=10) #mae: dice-index
#     scheduler_disc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_disc,mode='min',patience=10)
#     BCELoss = nn.BCELoss()
    
#     running_loss_seg = 0
#     running_loss_disc = 0
#     step = 0
#     np.set_printoptions(precision=2)
#     for epoch in range(epochs):
#         net.train()
#         discriminator.train()
        
#         for batch in train_loader:
#             imgs,labels = batch
            
#             imgs = imgs.to(device=device,dtype=torch.float32)
#             labels = labels.to(device=device,dtype=torch.float32)
            
#             # segmentation network
#             pred = net(imgs)
            
#             # discriminator network
#             pred_valid = discriminator(imgs, labels)
#             pred_fake = discriminator(imgs, pred.detach())   
#             print("valid: ",np.array(pred_valid.view([1,-1]).tolist()),"; fake: ",np.array(pred_fake.view([1,-1]).tolist()) )
#             # seg loss
#             #seg_loss = DiceLoss(pred,labels)
#             seg_loss = BCELoss(pred,labels)
            
#             if epoch<epochs/2:
#                 optimizer_unet.zero_grad()
#                 seg_loss.backward(retain_graph=True)
#                 optimizer_unet.step()
            
            
#             valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device=device,dtype=torch.float32)
#             fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device=device,dtype=torch.float32)
            
#             # discriminator loss
#             disc_loss = BCELoss(pred_valid,valid) + BCELoss(pred_fake,fake)

#             optimizer_disc.zero_grad()
#             disc_loss.backward(retain_graph=False)
#             optimizer_disc.step()
            
#             running_loss_seg += seg_loss.item()
#             running_loss_disc += disc_loss.item()
#             step += 1
            
#             if step % 10 == 9:    # print every 10 mini-batches
#                 print('[%d, %5d] segmentation loss: %.3f; discrimination loss: %.3f' %(epoch + 1, step + 1, running_loss_seg / 10, running_loss_disc / 10))
#                 running_loss_seg = 0.0
#                 running_loss_disc = 0.0
                
#             if step%100 == 99:
#                 val_loss_seg = 0
#                 val_loss_disc = 0
#                 for batch in valid_loader:
#                     imgs,labels = batch
#                     imgs = imgs.to(device=device,dtype=torch.float32)
#                     labels = labels.to(device=device,dtype=torch.float32)
                    
#                     with torch.no_grad():
#                         pred_seg = net(imgs)
#                         pred_valid = discriminator(imgs, labels)
#                         pred_fake = discriminator(imgs, pred_seg.detach())
                    
#                     val_loss_seg += DiceLoss(pred_seg,labels).item()
#                     valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device=device,dtype=torch.float32)
#                     fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device=device,dtype=torch.float32)
#                     val_loss_disc += (BCELoss(pred_valid,valid) + BCELoss(pred_fake,fake)).item()
#                 print('[%d, %5d] validation loss(seg): %.3f; validation loss(disc): %.3f' %(epoch + 1, step + 1, val_loss_seg / len(valid_loader), val_loss_disc / len(valid_loader)))
#                 scheduler_unet.step(val_loss_seg)
#                 scheduler_disc.step(val_loss_disc)

In [2]:
def train_adversarial(net,discriminator,device,val_per=0.1,epochs=10,batch_size=10,resize_to=None,alpha=1):
    if resize_to is not None:
        transform_image = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor()
        ])
    else:
        transform_image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.ToTensor()
        ])
            
    
    dataSet = UltraSoundDataSet(root_dir,(transform_image,transform_label))
    nTrain = int(len(dataSet)*(1-val_per))
    nValid = int(len(dataSet)-nTrain)
    
    trainSet,validSet = random_split(dataSet,[nTrain,nValid])
    
    train_loader = DataLoader(trainSet,batch_size=batch_size,shuffle=True,num_workers=4)
    valid_loader = DataLoader(validSet,batch_size=batch_size,shuffle=True,num_workers=4)
    
    optimizer_unet = torch.optim.Adam(net.parameters())
    optimizer_disc = torch.optim.Adam(discriminator.parameters())
    scheduler_unet = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_unet,mode='min',patience=10) #mae: dice-index
    scheduler_disc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_disc,mode='min',patience=10)
    BCELoss = nn.BCELoss()
    
    running_loss_seg = 0
    running_loss_disc = 0
    step = 0
    np.set_printoptions(precision=2)
    
    isAdv = True;
    for epoch in range(epochs):
        net.train()
        discriminator.train()
        
        for batch in train_loader:
            imgs,labels = batch
            
            imgs = imgs.to(device=device,dtype=torch.float32)
            labels = labels.to(device=device,dtype=torch.float32)
            
            valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device=device,dtype=torch.float32)
            fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device=device,dtype=torch.float32)
            
            if(isAdv):
                # Train adversarial 
                with torch.no_grad():
                    # segmentation
                    pred = net(imgs)
                pred_valid = discriminator(imgs, labels)
                pred_fake = discriminator(imgs, pred.detach())
                
                # discriminator loss
                disc_loss = BCELoss(pred_valid,valid) + BCELoss(pred_fake,fake)
                
                optimizer_disc.zero_grad()
                disc_loss.backward()
                optimizer_disc.step()
                
                running_loss_disc += disc_loss.item()
                
            else:
                # segmentation
                pred = net(imgs)
                with torch.no_grad():
                    pred_fake = discriminator(imgs, pred)
                    
                # segmentation loss
                seg_loss = BCELoss(pred,labels) + alpha*BCELoss(pred_fake,valid)
                
                optimizer_unet.zero_grad()
                seg_loss.backward()
                optimizer_unet.step()
                
                running_loss_seg += seg_loss.item()
              
            step += 1 
            if step % 10 == 9:    # print every 10 mini-batches
                if(isAdv):
                    print('[%d, %5d] discrimination loss: %.3f' %(epoch + 1, step + 1, running_loss_disc / 10))
                    running_loss_disc = 0.0
                else:
                    print('[%d, %5d] segmentation loss: %.3f' %(epoch + 1, step + 1, running_loss_seg / 10))  
                    running_loss_seg = 0.0
                    
            if step%100 == 99:
                val_loss_seg = 0
                val_loss_disc = 0
                for batch in valid_loader:
                    imgs,labels = batch
                    imgs = imgs.to(device=device,dtype=torch.float32)
                    labels = labels.to(device=device,dtype=torch.float32)
                    
                    with torch.no_grad():
                        pred_seg = net(imgs)
                        pred_valid = discriminator(imgs, labels)
                        pred_fake = discriminator(imgs, pred_seg.detach())
                    
                    val_loss_seg += (BCELoss(pred_seg,labels) + alpha*BCELoss(pred_fake,fake)).item()
                    valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device=device,dtype=torch.float32)
                    fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device=device,dtype=torch.float32)
                    val_loss_disc += (BCELoss(pred_valid,valid) + BCELoss(pred_fake,fake)).item()
                print('[%d, %5d] validation loss(seg): %.3f; validation loss(disc): %.3f' %(epoch + 1, step + 1, val_loss_seg / len(valid_loader), val_loss_disc / len(valid_loader)))
                scheduler_unet.step(val_loss_seg)
                scheduler_disc.step(val_loss_disc)
                
                if(isAdv):
                    scheduler_disc.step(val_loss_disc)
                else:
                    scheduler_unet.step(val_loss_seg)
                
                isAdv = not isAdv
                

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet2(init_features=32)
discriminator = Discriminator()

net = net.to(device)
discriminator = discriminator.to(device)

#IMG_SIZE = [256,256]
IMG_SIZE = None

root_dir = os.path.expanduser("~/workspace/us_robot/DataSet/SimRealDataset")
try:
    #train(net=net,device=device,resize_to=None,epochs=50,batch_size=5)
    train_adversarial(net=net,discriminator=discriminator,device=device,resize_to=IMG_SIZE,epochs=20,batch_size=3)
except KeyboardInterrupt:
    sys.exit()

[1,    10] discrimination loss: 0.485
[1,    20] discrimination loss: 0.003
[1,    30] discrimination loss: 0.001
[1,    40] discrimination loss: 0.000
[1,    50] discrimination loss: 0.000
[1,    60] discrimination loss: 0.000
[1,    70] discrimination loss: 0.000
[1,    80] discrimination loss: 0.000
[1,    90] discrimination loss: 0.000
[1,   100] discrimination loss: 0.000
[1,   100] validation loss(seg): 0.770; validation loss(disc): 0.000
[1,   110] segmentation loss: 9.807
[1,   120] segmentation loss: 0.995
[1,   130] segmentation loss: 0.423
[1,   140] segmentation loss: 0.381
[1,   150] segmentation loss: 0.357
[1,   160] segmentation loss: 0.332
[1,   170] segmentation loss: 0.308
[1,   180] segmentation loss: 0.285
[1,   190] segmentation loss: 0.264
[1,   200] segmentation loss: 0.246
[1,   200] validation loss(seg): 98.834; validation loss(disc): 98.597
[1,   210] discrimination loss: 25.356
[1,   220] discrimination loss: 0.004
[1,   230] discrimination loss: 0.008
[1,  

[3,  1870] discrimination loss: 0.000
[3,  1880] discrimination loss: 0.000
[3,  1890] discrimination loss: 0.000
[3,  1900] discrimination loss: 0.000
[3,  1900] validation loss(seg): 0.009; validation loss(disc): 0.000
[3,  1910] segmentation loss: 23.771
[3,  1920] segmentation loss: 21.432
[3,  1930] segmentation loss: 19.320
[3,  1940] segmentation loss: 17.673
[3,  1950] segmentation loss: 17.074
[3,  1960] segmentation loss: 18.274
[3,  1970] segmentation loss: 15.027
[3,  1980] segmentation loss: 14.801
[3,  1990] segmentation loss: 15.151
[3,  2000] segmentation loss: 15.584
[3,  2000] validation loss(seg): 0.008; validation loss(disc): 0.001
[3,  2010] discrimination loss: 0.000
[3,  2020] discrimination loss: 0.001
[3,  2030] discrimination loss: 0.000
[3,  2040] discrimination loss: 0.000
[3,  2050] discrimination loss: 0.000
[3,  2060] discrimination loss: 0.001
[3,  2070] discrimination loss: 0.000
[3,  2080] discrimination loss: 0.000
[3,  2090] discrimination loss: 0.00

[5,  3720] segmentation loss: 9.876
[5,  3730] segmentation loss: 9.976
[5,  3740] segmentation loss: 9.349
[5,  3750] segmentation loss: 8.527
[5,  3760] segmentation loss: 8.119
[5,  3770] segmentation loss: 9.317
[5,  3780] segmentation loss: 7.263
[5,  3790] segmentation loss: 7.641
[5,  3800] segmentation loss: 8.029
[5,  3800] validation loss(seg): 0.065; validation loss(disc): 0.064
[5,  3810] discrimination loss: 0.137
[5,  3820] discrimination loss: 0.043
[5,  3830] discrimination loss: 0.032
[5,  3840] discrimination loss: 0.067
[5,  3850] discrimination loss: 0.056
[5,  3860] discrimination loss: 0.043
[5,  3870] discrimination loss: 0.101
[5,  3880] discrimination loss: 0.021
[5,  3890] discrimination loss: 0.055
[5,  3900] discrimination loss: 0.128
[5,  3900] validation loss(seg): 0.056; validation loss(disc): 0.055
[5,  3910] segmentation loss: 8.056
[5,  3920] segmentation loss: 9.914
[5,  3930] segmentation loss: 8.040
[5,  3940] segmentation loss: 7.306
[5,  3950] seg

[7,  5600] validation loss(seg): 0.082; validation loss(disc): 0.082
[7,  5610] discrimination loss: 0.079
[7,  5620] discrimination loss: 0.177
[7,  5630] discrimination loss: 0.018
[7,  5640] discrimination loss: 0.080
[7,  5650] discrimination loss: 0.076
[7,  5660] discrimination loss: 0.120
[7,  5670] discrimination loss: 0.030
[7,  5680] discrimination loss: 0.060
[7,  5690] discrimination loss: 0.054
[7,  5700] discrimination loss: 0.034
[7,  5700] validation loss(seg): 0.081; validation loss(disc): 0.081
[7,  5710] segmentation loss: 7.300
[7,  5720] segmentation loss: 6.891
[7,  5730] segmentation loss: 6.243
[7,  5740] segmentation loss: 7.642
[7,  5750] segmentation loss: 6.424
[7,  5760] segmentation loss: 5.193
[7,  5770] segmentation loss: 5.487
[7,  5780] segmentation loss: 6.432
[7,  5790] segmentation loss: 6.640
[7,  5800] segmentation loss: 6.707
[7,  5800] validation loss(seg): 0.080; validation loss(disc): 0.080
[7,  5810] discrimination loss: 0.011
[7,  5820] disc

[9,  7470] discrimination loss: 0.026
[9,  7480] discrimination loss: 0.053
[9,  7490] discrimination loss: 0.024
[9,  7500] discrimination loss: 0.038
[9,  7500] validation loss(seg): 0.063; validation loss(disc): 0.064
[9,  7510] segmentation loss: 7.081
[9,  7520] segmentation loss: 7.445
[9,  7530] segmentation loss: 7.155
[9,  7540] segmentation loss: 6.891
[9,  7550] segmentation loss: 7.322
[9,  7560] segmentation loss: 7.123
[9,  7570] segmentation loss: 6.769
[9,  7580] segmentation loss: 7.468
[9,  7590] segmentation loss: 7.040
[9,  7600] segmentation loss: 8.229
[9,  7600] validation loss(seg): 0.061; validation loss(disc): 0.062
[9,  7610] discrimination loss: 0.046
[9,  7620] discrimination loss: 0.038
[9,  7630] discrimination loss: 0.025
[9,  7640] discrimination loss: 0.015
[9,  7650] discrimination loss: 0.008
[9,  7660] discrimination loss: 0.046
[9,  7670] discrimination loss: 0.016
[9,  7680] discrimination loss: 0.077
[9,  7690] discrimination loss: 0.051
[9,  770

[11,  9310] segmentation loss: 6.107
[11,  9320] segmentation loss: 5.579
[11,  9330] segmentation loss: 7.063
[11,  9340] segmentation loss: 6.762
[11,  9350] segmentation loss: 7.071
[11,  9360] segmentation loss: 7.199
[11,  9370] segmentation loss: 6.957
[11,  9380] segmentation loss: 7.488
[11,  9390] segmentation loss: 6.922
[11,  9400] segmentation loss: 7.307
[11,  9400] validation loss(seg): 0.048; validation loss(disc): 0.050
[11,  9410] discrimination loss: 0.050
[11,  9420] discrimination loss: 0.021
[11,  9430] discrimination loss: 0.029
[11,  9440] discrimination loss: 0.026
[11,  9450] discrimination loss: 0.037
[11,  9460] discrimination loss: 0.013
[11,  9470] discrimination loss: 0.029
[11,  9480] discrimination loss: 0.015
[11,  9490] discrimination loss: 0.041
[11,  9500] discrimination loss: 0.042
[11,  9500] validation loss(seg): 0.050; validation loss(disc): 0.052
[11,  9510] segmentation loss: 6.691
[11,  9520] segmentation loss: 6.282
[11,  9530] segmentation l

[13, 11140] segmentation loss: 7.771
[13, 11150] segmentation loss: 7.121
[13, 11160] segmentation loss: 5.970
[13, 11170] segmentation loss: 7.964
[13, 11180] segmentation loss: 7.371
[13, 11190] segmentation loss: 7.336
[13, 11200] segmentation loss: 6.593
[13, 11200] validation loss(seg): 0.048; validation loss(disc): 0.051
[13, 11210] discrimination loss: 0.069
[13, 11220] discrimination loss: 0.079
[13, 11230] discrimination loss: 0.034
[13, 11240] discrimination loss: 0.022
[13, 11250] discrimination loss: 0.026
[13, 11260] discrimination loss: 0.041
[13, 11270] discrimination loss: 0.031
[13, 11280] discrimination loss: 0.036
[13, 11290] discrimination loss: 0.025
[13, 11300] discrimination loss: 0.024
[13, 11300] validation loss(seg): 0.036; validation loss(disc): 0.039
[13, 11310] segmentation loss: 7.071
[13, 11320] segmentation loss: 6.722
[13, 11330] segmentation loss: 7.527
[13, 11340] segmentation loss: 6.230
[13, 11350] segmentation loss: 7.011
[13, 11360] segmentation l

[15, 12970] segmentation loss: 6.790
[15, 12980] segmentation loss: 7.489
[15, 12990] segmentation loss: 8.375
[15, 13000] segmentation loss: 8.897
[15, 13000] validation loss(seg): 0.037; validation loss(disc): 0.041
[15, 13010] discrimination loss: 0.026
[15, 13020] discrimination loss: 0.024
[15, 13030] discrimination loss: 0.040
[15, 13040] discrimination loss: 0.068
[15, 13050] discrimination loss: 0.020
[15, 13060] discrimination loss: 0.019
[15, 13070] discrimination loss: 0.054
[15, 13080] discrimination loss: 0.045
[15, 13090] discrimination loss: 0.036
[15, 13100] discrimination loss: 0.011
[15, 13100] validation loss(seg): 0.038; validation loss(disc): 0.042
[15, 13110] segmentation loss: 7.052
[15, 13120] segmentation loss: 7.839
[15, 13130] segmentation loss: 7.677
[15, 13140] segmentation loss: 6.450
[15, 13150] segmentation loss: 7.290
[15, 13160] segmentation loss: 7.913
[15, 13170] segmentation loss: 7.340
[15, 13180] segmentation loss: 7.909
[15, 13190] segmentation l

[17, 14800] segmentation loss: 7.851
[17, 14800] validation loss(seg): 0.028; validation loss(disc): 0.034
[17, 14810] discrimination loss: 0.035
[17, 14820] discrimination loss: 0.029
[17, 14830] discrimination loss: 0.010
[17, 14840] discrimination loss: 0.036
[17, 14850] discrimination loss: 0.037
[17, 14860] discrimination loss: 0.007
[17, 14870] discrimination loss: 0.033
[17, 14880] discrimination loss: 0.020
[17, 14890] discrimination loss: 0.025
[17, 14900] discrimination loss: 0.017
[17, 14900] validation loss(seg): 0.029; validation loss(disc): 0.035
[17, 14910] segmentation loss: 7.808
[17, 14920] segmentation loss: 7.362
[17, 14930] segmentation loss: 6.680
[17, 14940] segmentation loss: 6.189
[17, 14950] segmentation loss: 7.175
[17, 14960] segmentation loss: 7.234
[17, 14970] segmentation loss: 8.691
[17, 14980] segmentation loss: 8.308
[17, 14990] segmentation loss: 7.892
[17, 15000] segmentation loss: 6.539
[17, 15000] validation loss(seg): 0.030; validation loss(disc):

[19, 16610] discrimination loss: 0.063
[19, 16620] discrimination loss: 0.009
[19, 16630] discrimination loss: 0.016
[19, 16640] discrimination loss: 0.010
[19, 16650] discrimination loss: 0.017
[19, 16660] discrimination loss: 0.025
[19, 16670] discrimination loss: 0.012
[19, 16680] discrimination loss: 0.019
[19, 16690] discrimination loss: 0.021
[19, 16700] discrimination loss: 0.032
[19, 16700] validation loss(seg): 0.026; validation loss(disc): 0.034
[19, 16710] segmentation loss: 6.727
[19, 16720] segmentation loss: 7.521
[19, 16730] segmentation loss: 8.479
[19, 16740] segmentation loss: 8.609
[19, 16750] segmentation loss: 7.389
[19, 16760] segmentation loss: 7.613
[19, 16770] segmentation loss: 8.227
[19, 16780] segmentation loss: 8.398
[19, 16790] segmentation loss: 7.902
[19, 16800] segmentation loss: 8.548
[19, 16800] validation loss(seg): 0.027; validation loss(disc): 0.034
[19, 16810] discrimination loss: 0.082
[19, 16820] discrimination loss: 0.041
[19, 16830] discrimina

In [6]:
#for python
torch.save(net.state_dict(), './unet_gan_usseg.pth')
torch.save(discriminator.state_dict(), './disc_gan_usseg.pth')
#for c++
traced_script_module = torch.jit.trace(net, img)
traced_script_module.save("./unet_gan_usseg_traced.pt")

## Inference

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

PATH = os.path.expanduser("~/workspace/us_robot/unet_gan_usseg.pth")
net = UNet(init_features=32).to(device)
net.load_state_dict(torch.load(PATH))
net = net.eval()


In [10]:
test_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_neck")
pred_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_pred")

testset_list = os.listdir(test_dir)
testset_list = list(filter(lambda x: x.endswith('jpg'), testset_list))
resize_to=[512,512]
transform_image = transforms.Compose([
    transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])
transform_label = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize([1000,500])
    ])

for sample in testset_list:
    image_path = os.path.join(test_dir,sample)
    label_path = os.path.join(pred_dir,sample)
    
    img = Image.open(image_path).convert("L")
    img = transform_image(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    #label = Image.open(label_path)
    #label = transform_label(label).to(device)
    #label = label.unsqueeze(0)
    
    
    with torch.no_grad():
        pred = net(img)
    
    #DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()

    pred = invtransform_label(pred.cpu().squeeze(0))
    #fname = "pred%.2f.png"%DiceIndex
    fname = sample
    sav_path = os.path.join(pred_dir,fname)
    pred.save(sav_path)

In [5]:
test_dir = os.path.expanduser("~/workspace/us_robot/DataSet/SimRealDatasetTest")
testset_list = os.listdir(test_dir)
resize_to=None
transform_image = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])
transform_label = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize([1000,500])
    ])
for sample in testset_list:
    image_path = os.path.join(test_dir,sample,"image.png")
    label_path = os.path.join(test_dir,sample,"label.png")
    
    img = Image.open(image_path)
    img = transform_image(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    label = Image.open(label_path)
    label = transform_label(label).to(device)
    label = label.unsqueeze(0)
    
    
    with torch.no_grad():
        pred = net(img)
    
    DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()

    pred = invtransform_label(pred.cpu().squeeze(0))
    fname = "pred%.2f.png"%DiceIndex
    sav_path = os.path.join(test_dir,sample,fname)
    pred.save(sav_path)