In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from modules.UNet import *
from modules.Discriminator import *
from modules.DataSet import *
from modules.Losses import *

import os,sys

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from PIL import Image,ImageEnhance

import numpy as np

## Training

### With Adversarial Block

In [2]:
def train_adversarial(net,discriminator,device,val_per=0.1,epochs=10,batch_size=10,resize_to=None,alpha=1):
    if resize_to is not None:
        transform_image = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.Resize(resize_to),
        #transforms.ToTensor()
        ])
    else:
        transform_image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        #transforms.ToTensor()
        ])
            
    
    dataSet = UltraSoundDataSet2(root_dir,(transform_image,transform_label))
    nTrain = int(len(dataSet)*(1-val_per))
    nValid = int(len(dataSet)-nTrain)
    
    trainSet,validSet = random_split(dataSet,[nTrain,nValid])
    
    train_loader = DataLoader(trainSet,batch_size=batch_size,shuffle=True,num_workers=4)
    valid_loader = DataLoader(validSet,batch_size=batch_size,shuffle=True,num_workers=4)
    
    optimizer_unet = torch.optim.Adam(net.parameters())
    optimizer_disc = torch.optim.Adam(discriminator.parameters())
    scheduler_unet = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_unet,mode='min',patience=10) #mae: dice-index
    scheduler_disc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_disc,mode='min',patience=10)
    BCELoss = nn.BCELoss()
    
    running_loss_seg = 0
    running_loss_disc = 0
    step = 0
    np.set_printoptions(precision=2)
    
    isAdv = True;
    for epoch in range(epochs):
        net.train()
        discriminator.train()
        
        for batch in train_loader:
            imgs,labels = batch
            
            imgs = imgs.to(device=device,dtype=torch.float32)
            labels = labels.to(device=device,dtype=torch.float32)

            valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device=device,dtype=torch.float32)
            fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device=device,dtype=torch.float32)
            
            if(isAdv):
                # Train adversarial 
                with torch.no_grad():
                    # segmentation
                    pred = net(imgs)
                pred_valid = discriminator(imgs, labels)
                pred_fake = discriminator(imgs, pred.detach())
                print("valid: ",np.array(pred_valid.view([1,-1]).tolist()),"; fake: ",np.array(pred_fake.view([1,-1]).tolist()) )
                
                # discriminator loss
                disc_loss = BCELoss(pred_valid,valid) + BCELoss(pred_fake,fake)
                
                optimizer_disc.zero_grad()
                disc_loss.backward()
                optimizer_disc.step()
                
                running_loss_disc += disc_loss.item()
                
            else:
                # segmentation
                pred = net(imgs)
                with torch.no_grad():
                    pred_fake = discriminator(imgs, pred.detach())
                print("fake: ",np.array(pred_fake.view([1,-1]).tolist()) )
                
                # segmentation loss
                #seg_loss = BCELoss(pred,labels) + alpha*BCELoss(pred_fake,valid)
                seg_loss = BCELoss(pred,labels) + alpha*BCELoss(pred_fake,valid)
                
                optimizer_unet.zero_grad()
                seg_loss.backward()
                optimizer_unet.step()
                
                running_loss_seg += seg_loss.item()
              
            step += 1 
            if step % 10 == 9:    # print every 10 mini-batches
                if(isAdv):
                    print('[%d, %5d] discrimination loss: %.3f' %(epoch + 1, step + 1, running_loss_disc / 10))
                    running_loss_disc = 0.0
                else:
                    print('[%d, %5d] segmentation loss: %.3f' %(epoch + 1, step + 1, running_loss_seg / 10))  
                    running_loss_seg = 0.0
                    
            if step%100 == 99:
                val_loss_seg = 0
                val_loss_disc = 0
                for batch in valid_loader:
                    imgs,labels = batch
                    imgs = imgs.to(device=device,dtype=torch.float32)
                    labels = labels.to(device=device,dtype=torch.float32)
                    
                    with torch.no_grad():
                        pred_seg = net(imgs)
                        pred_valid = discriminator(imgs, labels)
                        pred_fake = discriminator(imgs, pred_seg.detach())
                    
                    valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device=device,dtype=torch.float32)
                    fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device=device,dtype=torch.float32)
                    val_loss_seg += (BCELoss(pred_seg,labels) + alpha*BCELoss(pred_fake,fake)).item()
                    val_loss_disc += (BCELoss(pred_valid,valid) + BCELoss(pred_fake,fake)).item()
                print('[%d, %5d] validation loss(seg): %.3f; validation loss(disc): %.3f' %(epoch + 1, step + 1, val_loss_seg / len(valid_loader), val_loss_disc / len(valid_loader)))
                scheduler_unet.step(val_loss_seg)
                scheduler_disc.step(val_loss_disc)
                
                if(isAdv):
                    scheduler_disc.step(val_loss_disc)
                else:
                    scheduler_unet.step(val_loss_seg)
                
                isAdv = not isAdv
                

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet3(init_features=32)
discriminator = Discriminator()

net = net.to(device)
discriminator = discriminator.to(device)

IMG_SIZE = [256,256]
#IMG_SIZE = None

#root_dir = os.path.expanduser("~/workspace/us_robot/DataSet/SimRealDataset")
root_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_dataset")
try:
    #train(net=net,device=device,resize_to=None,epochs=50,batch_size=5)
    train_adversarial(net=net,discriminator=discriminator,device=device,resize_to=IMG_SIZE,epochs=20,batch_size=10)
except KeyboardInterrupt:
    sys.exit()

valid:  [[0.49 0.49 0.49 0.49 0.49 0.49 0.49 0.49 0.49 0.49]] ; fake:  [[0.49 0.49 0.49 0.49 0.49 0.49 0.49 0.49 0.49 0.49]]
valid:  [[0.55 0.55 0.55 0.55 0.55 0.55 0.55 0.55 0.55 0.55]] ; fake:  [[0.51 0.51 0.51 0.51 0.51 0.51 0.51 0.51 0.51 0.51]]
valid:  [[0.63 0.63 0.63 0.63 0.62 0.62 0.63 0.62 0.62 0.63]] ; fake:  [[0.5  0.5  0.5  0.49 0.49 0.5  0.5  0.49 0.5  0.5 ]]
valid:  [[0.65 0.65 0.69 0.69 0.65 0.69 0.68 0.64 0.64 0.67]] ; fake:  [[0.39 0.39 0.39 0.39 0.39 0.39 0.39 0.39 0.39 0.39]]
valid:  [[0.82 0.82 0.86 0.83 0.83 0.83 0.82 0.83 0.82 0.86]] ; fake:  [[0.33 0.33 0.33 0.33 0.32 0.33 0.33 0.32 0.33 0.33]]
valid:  [[0.87 0.86 0.86 0.87 0.82 0.89 0.91 0.92 0.91 0.91]] ; fake:  [[0.12 0.13 0.13 0.13 0.13 0.13 0.13 0.13 0.13 0.13]]
valid:  [[0.96 0.97 0.96 0.99 0.97 0.99 0.97 0.96 0.99 0.98]] ; fake:  [[0.04 0.04 0.04 0.04 0.04 0.04 0.04 0.04 0.04 0.05]]
valid:  [[1.   1.   1.   1.   1.   1.   0.99 1.   1.   1.  ]] ; fake:  [[0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01 0.01]]


valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[6.27e-20 8.64e-21 1.91e-20 2.44e-20 4.76e-21 1.02e-19 4.41e-20 3.62e-20
  1.23e-20 1.21e-19]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1.68e-19 4.09e-20 1.19e-19 1.21e-20 4.03e-20 1.02e-19 1.47e-20 1.46e-19
  7.14e-21 1.45e-20]]
[1,    60] discrimination loss: 0.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1.06e-20 4.26e-20 3.80e-20 1.49e-20 6.39e-20 9.50e-21 7.80e-21 6.15e-20
  3.92e-19 2.32e-19]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[7.05e-20 3.97e-21 3.71e-20 4.22e-20 8.09e-21 2.04e-19 1.21e-20 8.33e-21
  2.82e-21 1.86e-21]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[2.85e-20 4.18e-20 8.11e-20 1.56e-18 2.06e-20 9.88e-21 3.23e-20 1.54e-19
  7.29e-20 9.17e-21]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[8.13e-20 3.01e-20 1.68e-19 9.90e-21 4.07e-21 2.94e-20 1.46e-20 3.13e-20
  1.39e-19 1.47e-19]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[7.67e-21 3.54e-20 1.13e-20 

fake:  [[1.   1.   0.99 1.   0.99 1.   0.91 1.   1.   1.  ]]
[2,   120] segmentation loss: 5.348
fake:  [[1.   1.   0.97 1.   1.   1.   1.   1.   1.   1.  ]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[2,   130] segmentation loss: 0.415
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[2,   140] segmentation los

valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[3,   260] discrimination loss: 100.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 

fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[4,   390] segmentation loss: 0.092
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[4,   400] segmentation loss: 0.084
[4,   400] validation loss(seg): 100.080; validation loss(disc): 100.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1.

valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[5,   490] discrimination loss: 100.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 

valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[6,   630] discrimination loss: 100.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 

fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[7,   740] segmentation loss: 0.033
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[7,   750] segmentation loss: 0.031
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1.

valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[8,   870] discrimination loss: 100.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 

valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[10,  1010] discrimination loss: 100.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1.

valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[10,  1100] discrimination loss: 100.000
[10,  1100] validation loss(seg): 100.023; validation loss(disc): 100.000
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[10,  1110] segmentation loss: 0.023
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1

valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[12,  1240] discrimination loss: 100.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[12,  1250] discriminati

fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[13,  1370] segmentation loss: 0.021
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[13,  1380] segmentation loss: 0.021
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 

valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[14,  1480] discrimination loss: 100.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1.

valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[15,  1620] discrimination loss: 100.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1.

fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[16,  1720] segmentation loss: 0.019
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[16,  1730] segmentation loss: 0.019
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 

valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[17,  1850] discrimination loss: 100.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[17,  1860] discriminati

fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[18,  1990] segmentation loss: 0.019
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[18,  2000] segmentation loss: 0.019
[18,  2000] validation loss(seg): 100.019; validation loss(disc): 100.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1.

valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[19,  2090] discrimination loss: 100.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1.

valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
[20,  2230] discrimination loss: 100.000
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]] ; fake:  [[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
valid:  [[1. 1. 1. 1. 1.

In [6]:
#for python
torch.save(net.state_dict(), './unet_gan_usseg.pth')
torch.save(discriminator.state_dict(), './disc_gan_usseg.pth')
#for c++
traced_script_module = torch.jit.trace(net, img)
traced_script_module.save("./unet_gan_usseg_traced.pt")

## Inference

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

PATH = os.path.expanduser("~/workspace/us_robot/unet_gan_usseg.pth")
net = UNet(init_features=64).to(device)
net.load_state_dict(torch.load(PATH))
net = net.eval()


In [5]:
test_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel")
pred_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_pred")

testset_list = os.listdir(test_dir)
testset_list = list(filter(lambda x: x.endswith('jpg'), testset_list))
resize_to=[256,256]
transform_image = transforms.Compose([
    transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])
transform_label = transforms.Compose([
    transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize([512,512])
    ])

for sample in testset_list:
    image_path = os.path.join(test_dir,sample)
    label_path = os.path.join(pred_dir,sample)
    
    img = Image.open(image_path).convert("L")
    img = ImageEnhance.Contrast(img).enhance(1.5)
    img = transform_image(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    #label = Image.open(label_path)
    #label = transform_label(label).to(device)
    #label = label.unsqueeze(0)
    
    
    with torch.no_grad():
        pred = net(img)
    
    #DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()

    pred = invtransform_label(pred.cpu().squeeze(0))
    #fname = "pred%.2f.png"%DiceIndex
    fname = sample
    sav_path = os.path.join(pred_dir,fname)
    pred.save(sav_path)

In [6]:
test_dir = os.path.expanduser("~/workspace/us_robot/DataSet/SimRealDatasetTest")
testset_list = os.listdir(test_dir)
resize_to=[256,256]
transform_image = transforms.Compose([
    transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])
transform_label = transforms.Compose([
    transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize([512,512])
    ])
for sample in testset_list:
    image_path = os.path.join(test_dir,sample,"image.png")
    label_path = os.path.join(test_dir,sample,"label.png")
    
    img = Image.open(image_path)
    img = transform_image(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    label = Image.open(label_path)
    label = transform_label(label).to(device)
    label = label.unsqueeze(0)
    
    
    with torch.no_grad():
        pred = net(img)
    
    DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()

    pred = invtransform_label(pred.cpu().squeeze(0))
    fname = "pred%.2f.png"%DiceIndex
    sav_path = os.path.join(test_dir,sample,fname)
    pred.save(sav_path)