In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

import os,sys

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from PIL import Image

import numpy as np

In [2]:
class UNet(nn.Module):
    def __init__(self,in_channels=1,out_channels=1,init_features=64):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet.DoubleConv2d(in_channels,features)
        self.pool1 = nn.MaxPool2d(2, 2)
        
        self.encoder2 = UNet.DoubleConv2d(features,2*features)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.encoder3 = UNet.DoubleConv2d(2*features,4*features)
        self.pool3 = nn.MaxPool2d(2, 2)
        
        self.encoder4 = UNet.DoubleConv2d(4*features,8*features)
        self.pool4 = nn.MaxPool2d(2, 2)
        
        self.bottleneck = UNet.DoubleConv2d(8*features,16*features)
        
        self.upconv4 = nn.ConvTranspose2d(16*features,8*features,kernel_size=2,stride=2)        
        self.decoder4 = UNet.DoubleConv2d(16*features,8*features) #concate, 2*8=16
        
        self.upconv3 = nn.ConvTranspose2d(8*features,4*features,kernel_size=2,stride=2)
        self.decoder3 = UNet.DoubleConv2d(8*features,4*features) #concate, 2*4=8
        
        self.upconv2 = nn.ConvTranspose2d(4*features,2*features,kernel_size=2,stride=2)
        self.decoder2 = UNet.DoubleConv2d(4*features,2*features) #concate, 2*2=4
        
        
        self.upconv1 = nn.ConvTranspose2d(2*features,features,kernel_size=2,stride=2)
        self.decoder1 = UNet.DoubleConv2d(2*features,features) #concate, 2*1=2
        
        self.conv_out = nn.Conv2d(features, 1, 1)
        
    def forward(self, input):
        enc1 = self.encoder1(input)
        
        enc2 = self.encoder2(self.pool1(enc1))
        
        enc3 = self.encoder3(self.pool2(enc2))
        
        enc4 = self.encoder4(self.pool3(enc3))
        
        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = torch.cat([enc4,self.upconv4(bottleneck)],dim=1)
        dec4 = self.decoder4(dec4)
        
        dec3 = torch.cat([enc3,self.upconv3(dec4)],dim=1)
        dec3 = self.decoder3(dec3)
        
        dec2 = torch.cat([enc2,self.upconv2(dec3)],dim=1)
        dec2 = self.decoder2(dec2)
        
        dec1 = torch.cat([enc1,self.upconv1(dec2)],dim=1)
        dec1 = self.decoder1(dec1)
        
        output = torch.sigmoid(self.conv_out(dec1))
        
        return output
    
    def DoubleConv2d(in_channels,features):
        return nn.Sequential(
                nn.Conv2d(in_channels, features, kernel_size=3, padding=1),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True),
                nn.Conv2d(features, features, kernel_size=3, padding=1),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True)
        )

In [3]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.img_conv1 = nn.Conv2d(1, 16, kernel_size = 3, padding=1)
        self.img_conv2 = nn.Conv2d(16, 64, kernel_size = 3, padding=1)
        self.img_pool = nn.MaxPool2d(2, 2)
        
        self.mask_conv = nn.Conv2d(1, 64, kernel_size = 3, padding=1)
        self.mask_pool = nn.MaxPool2d(2, 2)
        
        self.dis_conv1 = nn.Conv2d(128, 256, kernel_size = 3, padding=1)
        self.dis_pool1 = nn.MaxPool2d(2, 2)
        self.dis_conv2 = nn.Conv2d(256, 512, kernel_size = 3, padding=1)
        self.dis_pool2 = nn.MaxPool2d(2, 2)
        
        self.output = nn.Linear(512,1)
        
        
    def forward(self, input, mask):
        x = self.img_conv1(input)
        x = self.img_conv2(x)
        x = self.img_pool(x)
        
        y = self.mask_conv(mask)
        y = self.mask_pool(y)
        
        x = torch.cat([x,y],dim=1)
        x = self.dis_conv1(x)
        x = self.dis_pool1(x)
        x = self.dis_conv2(x)
        x = self.dis_pool2(x)
        x = F.max_pool2d(x,kernel_size=x.size()[2:]).squeeze()

        output = torch.sigmoid(self.output(x))
        
        return output

In [4]:
class UltraSoundDataSet(Dataset):
    def __init__(self, root_dir, transforms):
        self.root_dir = root_dir
        self.sample_list = os.listdir(root_dir)

        self.transform_image, self.transform_label = transforms
        
    def __len__(self):
        return len(self.sample_list)
    
    def __getitem__(self,idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image_path = os.path.join(self.root_dir,self.sample_list[idx],"image.png")
        label_path = os.path.join(self.root_dir,self.sample_list[idx],"label.png")
        
        image = Image.open(image_path)
        label = Image.open(label_path)
        
        if self.transform_image is not None:
            image = self.transform_image(image)
            
        if self.transform_label is not None:
            label = self.transform_label(label)
        
        return image,label
        #sample = {"image":image,"label":label}
        
        #return sample
        

In [5]:
def DiceLoss(pred,target,slack=10):
    
    index = (2*torch.sum(pred*target)+slack)/(torch.sum(pred)+torch.sum(target)+slack)
    #if torch.sum(target).item() == 0:
    #print("instersection: ",torch.sum(pred*target).item())
    #print("pred: ",torch.sum(pred).item())
    #print("target: ",torch.sum(target).item())
    #print("Index: ", index.item())
    return 1-index

In [6]:
def train(net,device,val_per=0.1,epochs=10,batch_size=10,resize_to=None):
    if resize_to is not None:
        transform_image = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor()
        ])
    else:
        transform_image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.ToTensor()
        ])
            
    
    dataSet = UltraSoundDataSet(root_dir,(transform_image,transform_label))
    nTrain = int(len(dataSet)*(1-val_per))
    nValid = int(len(dataSet)-nTrain)
    
    trainSet,validSet = random_split(dataSet,[nTrain,nValid])
    
    train_loader = DataLoader(trainSet,batch_size=batch_size,shuffle=True,num_workers=4)
    valid_loader = DataLoader(validSet,batch_size=batch_size,shuffle=True,num_workers=4)
    
    optimizer = torch.optim.Adam(net.parameters())
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',patience=10) #mae: dice-index
    
    running_loss = 0
    step = 0
    for epoch in range(epochs):
        net.train()
        for batch in train_loader:
            imgs,labels = batch
            
            imgs = imgs.to(device=device,dtype=torch.float32)
            labels = labels.to(device=device,dtype=torch.float32)
            
            pred = net(imgs)
            
            loss = DiceLoss(pred,labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            step += 1
            
            if step % 10 == 9:    # print every 10 mini-batches
                print('[%d, %5d] loss: %.3f' %(epoch + 1, step + 1, running_loss / 10))
                running_loss = 0.0
                
            if step%50 == 49:
                val_loss = 0
                for batch in valid_loader:
                    imgs,labels = batch
                    imgs = imgs.to(device=device,dtype=torch.float32)
                    labels = labels.to(device=device,dtype=torch.float32)
                    
                    with torch.no_grad():
                        pred = net(imgs)
                    val_loss += DiceLoss(pred,labels)
                print('[%d, %5d] validation loss: %.3f' %(epoch + 1, step + 1, val_loss / len(valid_loader)))
                scheduler.step(val_loss)
            
        
    

In [7]:
def train_adversarial(net,discriminator,device,val_per=0.1,epochs=10,batch_size=10,resize_to=None):
    if resize_to is not None:
        transform_image = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor()
        ])
    else:
        transform_image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.ToTensor()
        ])
            
    
    dataSet = UltraSoundDataSet(root_dir,(transform_image,transform_label))
    nTrain = int(len(dataSet)*(1-val_per))
    nValid = int(len(dataSet)-nTrain)
    
    trainSet,validSet = random_split(dataSet,[nTrain,nValid])
    
    train_loader = DataLoader(trainSet,batch_size=batch_size,shuffle=True,num_workers=4)
    valid_loader = DataLoader(validSet,batch_size=batch_size,shuffle=True,num_workers=4)
    
    optimizer_unet = torch.optim.Adam(net.parameters())
    optimizer_disc = torch.optim.Adam(discriminator.parameters())
    scheduler_unet = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_unet,mode='min',patience=10) #mae: dice-index
    scheduler_disc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_disc,mode='min',patience=10)
    BCELoss = nn.BCELoss()
    
    running_loss_seg = 0
    running_loss_disc = 0
    step = 0
    np.set_printoptions(precision=2)
    for epoch in range(epochs):
        net.train()
        discriminator.train()
        
        for batch in train_loader:
            imgs,labels = batch
            
            imgs = imgs.to(device=device,dtype=torch.float32)
            labels = labels.to(device=device,dtype=torch.float32)
            
            # segmentation network
            pred = net(imgs)
            
            # discriminator network
            pred_valid = discriminator(imgs, labels)
            pred_fake = discriminator(imgs, pred.detach())   
            print("valid: ",np.array(pred_valid.view([1,-1]).tolist()),"; fake: ",np.array(pred_fake.view([1,-1]).tolist()) )
            # seg loss
            seg_loss = DiceLoss(pred,labels)
            
            optimizer_unet.zero_grad()
            seg_loss.backward(retain_graph=True)
            optimizer_unet.step()
            
            
            valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device=device,dtype=torch.float32)
            fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device=device,dtype=torch.float32)
            
            # discriminator loss
            disc_loss = BCELoss(pred_valid,valid) + BCELoss(pred_fake,fake)

            optimizer_disc.zero_grad()
            disc_loss.backward(retain_graph=False)
            optimizer_disc.step()
            
            running_loss_seg += seg_loss.item()
            running_loss_disc += disc_loss.item()
            step += 1
            
            if step % 10 == 9:    # print every 10 mini-batches
                print('[%d, %5d] segmentation loss: %.3f; discrimination loss: %.3f' %(epoch + 1, step + 1, running_loss_seg / 10, running_loss_disc / 10))
                running_loss_seg = 0.0
                running_loss_disc = 0.0
                
            if step%50 == 49:
                val_loss_seg = 0
                val_loss_disc = 0
                for batch in valid_loader:
                    imgs,labels = batch
                    imgs = imgs.to(device=device,dtype=torch.float32)
                    labels = labels.to(device=device,dtype=torch.float32)
                    
                    with torch.no_grad():
                        pred_seg = net(imgs)
                        pred_valid = discriminator(imgs, labels)
                        pred_fake = discriminator(imgs, pred_seg.detach())
                    
                    val_loss_seg += DiceLoss(pred_seg,labels).item()
                    valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device=device,dtype=torch.float32)
                    fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device=device,dtype=torch.float32)
                    val_loss_disc += (BCELoss(pred_valid,valid) + BCELoss(pred_fake,fake)).item()
                print('[%d, %5d] validation loss(seg): %.3f; validation loss(disc): %.3f' %(epoch + 1, step + 1, val_loss_seg / len(valid_loader), val_loss_disc / len(valid_loader)))
                scheduler_unet.step(val_loss_seg)
                scheduler_disc.step(val_loss_disc)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(init_features=32)
discriminator = Discriminator()

net = net.to(device)
discriminator = discriminator.to(device)

root_dir = os.path.expanduser("~/workspace/us_robot/DataSet/SimDataset2")
try:
    #train(net=net,device=device,resize_to=None,epochs=50,batch_size=5)
    train_adversarial(net=net,discriminator=discriminator,device=device,resize_to=None,epochs=20,batch_size=3)
except KeyboardInterrupt:
    sys.exit()

valid:  [[0.52 0.55 0.55]] ; fake:  [[0.54 0.54 0.55]]
valid:  [[0.3  0.31 0.31]] ; fake:  [[0.19 0.19 0.18]]
valid:  [[0.96 0.96 0.88]] ; fake:  [[0.9  0.89 0.88]]
valid:  [[0.74 0.74 0.83]] ; fake:  [[0.51 0.5  0.52]]
valid:  [[0.38 0.39 0.6 ]] ; fake:  [[0.13 0.14 0.15]]
valid:  [[0.73 0.68 0.86]] ; fake:  [[0.2  0.2  0.22]]
valid:  [[0.96 0.98 0.96]] ; fake:  [[0.41 0.44 0.43]]
valid:  [[0.92 0.98 0.98]] ; fake:  [[0.09 0.1  0.09]]
valid:  [[0.66 0.64 0.97]] ; fake:  [[0.01 0.01 0.01]]
[1,    10] segmentation loss: 0.856; discrimination loss: 0.857
valid:  [[0.91 0.93 0.99]] ; fake:  [[0.01 0.01 0.01]]
valid:  [[1.   1.   0.99]] ; fake:  [[0.01 0.01 0.01]]
valid:  [[1. 1. 1.]] ; fake:  [[0.02 0.02 0.02]]
valid:  [[1. 1. 1.]] ; fake:  [[0.03 0.04 0.03]]
valid:  [[1. 1. 1.]] ; fake:  [[0.02 0.03 0.03]]
valid:  [[1. 1. 1.]] ; fake:  [[0.02 0.02 0.02]]
valid:  [[1. 1. 1.]] ; fake:  [[0.02 0.02 0.01]]
valid:  [[1. 1. 1.]] ; fake:  [[0. 0. 0.]]
valid:  [[1. 1. 1.]] ; fake:  [[0. 0. 0.]]


valid:  [[1. 1. 1.]] ; fake:  [[0.07 0.03 0.08]]
valid:  [[1. 1. 1.]] ; fake:  [[0.16 0.   0.01]]
valid:  [[1. 1. 1.]] ; fake:  [[0.   0.   0.04]]
valid:  [[1. 1. 1.]] ; fake:  [[1.57e-02 5.61e-05 4.81e-05]]
[1,   130] segmentation loss: 0.721; discrimination loss: 0.481
valid:  [[1. 1. 1.]] ; fake:  [[6.80e-06 7.04e-06 8.79e-06]]
valid:  [[1.   1.   0.99]] ; fake:  [[3.68e-03 3.04e-06 1.39e-06]]
valid:  [[1.   1.   0.96]] ; fake:  [[1.41e-03 1.05e-03 4.41e-07]]
valid:  [[0.8  0.84 0.84]] ; fake:  [[8.04e-08 6.79e-08 7.35e-08]]
valid:  [[1. 1. 1.]] ; fake:  [[5.83e-06 5.28e-06 6.47e-06]]
valid:  [[1. 1. 1.]] ; fake:  [[3.13e-04 5.15e-05 3.51e-05]]
valid:  [[1. 1. 1.]] ; fake:  [[0. 0. 0.]]
valid:  [[1. 1. 1.]] ; fake:  [[0.   0.01 0.  ]]
valid:  [[1. 1. 1.]] ; fake:  [[0.03 0.03 0.12]]
valid:  [[1. 1. 1.]] ; fake:  [[0.05 0.06 0.2 ]]
[1,   140] segmentation loss: 0.709; discrimination loss: 0.039
valid:  [[1. 1. 1.]] ; fake:  [[0.03 0.05 0.02]]
valid:  [[1. 1. 1.]] ; fake:  [[0.01 0.  

valid:  [[0.63 0.69 0.56]] ; fake:  [[0.02 0.01 0.01]]
valid:  [[0.94 0.95 0.89]] ; fake:  [[0.04 0.04 0.05]]
valid:  [[0.98 0.99 0.99]] ; fake:  [[0.16 0.14 0.16]]
valid:  [[0.99 0.99 0.99]] ; fake:  [[0.17 0.15 0.2 ]]
valid:  [[0.99 0.99 0.98]] ; fake:  [[0.1  0.08 0.11]]
valid:  [[0.95 0.9  0.9 ]] ; fake:  [[0.06 0.08 0.06]]
valid:  [[0.85 0.86 0.71]] ; fake:  [[0.04 0.03 0.01]]
valid:  [[0.86 0.94 0.92]] ; fake:  [[0.04 0.09 0.09]]
valid:  [[0.96 0.97 0.97]] ; fake:  [[0.22 0.22 0.13]]
[1,   270] segmentation loss: 0.328; discrimination loss: 0.331
valid:  [[0.9  0.94 0.93]] ; fake:  [[0.02 0.04 0.07]]
valid:  [[0.93 0.95 0.92]] ; fake:  [[0.1  0.06 0.01]]
valid:  [[0.83 0.87 0.91]] ; fake:  [[0.01 0.01 0.01]]
valid:  [[0.97 0.97 0.98]] ; fake:  [[0.03 0.06 0.05]]
valid:  [[0.99 0.99 0.99]] ; fake:  [[0.19 0.08 0.12]]
valid:  [[0.99 0.97 0.99]] ; fake:  [[0.49 0.31 0.58]]
valid:  [[0.2  0.17 0.29]] ; fake:  [[0. 0. 0.]]
valid:  [[0.92 0.93 0.96]] ; fake:  [[0.03 0.01 0.03]]
valid: 

valid:  [[0.81 0.98 0.95]] ; fake:  [[0.17 0.34 0.29]]
valid:  [[0.83 0.45 0.79]] ; fake:  [[0.08 0.03 0.08]]
valid:  [[0.65 0.94 0.72]] ; fake:  [[0.12 0.3  0.15]]
valid:  [[0.97 0.99 0.99]] ; fake:  [[0.3  0.43 0.56]]
valid:  [[0.39 0.49 0.59]] ; fake:  [[0.06 0.04 0.06]]
valid:  [[0.96 0.81 0.96]] ; fake:  [[0.19 0.13 0.3 ]]
[1,   400] segmentation loss: 0.099; discrimination loss: 0.493
[1,   400] validation loss(seg): 0.133; validation loss(disc): 0.475
valid:  [[0.95 0.96 0.96]] ; fake:  [[0.29 0.35 0.34]]
valid:  [[0.49 0.46 0.51]] ; fake:  [[0.01 0.01 0.01]]
valid:  [[0.96 0.98 0.98]] ; fake:  [[0.3  0.48 0.51]]
valid:  [[0.47 0.48 0.46]] ; fake:  [[0.02 0.02 0.01]]
valid:  [[1.   0.99 1.  ]] ; fake:  [[0.88 0.37 0.38]]
valid:  [[0.99 0.99 0.86]] ; fake:  [[0.35 0.45 0.03]]
valid:  [[0.09 0.04 0.38]] ; fake:  [[0.   0.   0.01]]
valid:  [[1.   0.99 0.96]] ; fake:  [[0.85 0.32 0.3 ]]
valid:  [[0.91 0.94 0.96]] ; fake:  [[0.46 0.35 0.38]]
valid:  [[0.11 0.48 0.48]] ; fake:  [[0.  

valid:  [[0.97 0.39 0.52]] ; fake:  [[0.12 0.01 0.02]]
valid:  [[0.99 0.97 0.7 ]] ; fake:  [[0.04 0.03 0.02]]
valid:  [[0.99 0.92 0.99]] ; fake:  [[0.19 0.48 0.21]]
valid:  [[0.99 0.99 0.99]] ; fake:  [[0.15 0.42 0.19]]
valid:  [[0.86 0.96 0.88]] ; fake:  [[0.01 0.06 0.01]]
[1,   530] segmentation loss: 0.139; discrimination loss: 0.385
valid:  [[0.87 0.92 0.89]] ; fake:  [[0.03 0.01 0.03]]
valid:  [[0.61 0.8  0.87]] ; fake:  [[0.02 0.02 0.03]]
valid:  [[0.98 0.94 0.92]] ; fake:  [[0.04 0.05 0.04]]
valid:  [[0.98 0.99 0.96]] ; fake:  [[0.04 0.15 0.75]]
valid:  [[0.94 0.97 0.97]] ; fake:  [[0.64 0.08 0.09]]
valid:  [[0.89 0.85 0.89]] ; fake:  [[0.01 0.01 0.01]]
valid:  [[0.92 0.8  0.74]] ; fake:  [[0.   0.01 0.01]]
valid:  [[0.77 0.81 0.74]] ; fake:  [[0.03 0.   0.22]]
valid:  [[0.81 0.8  0.96]] ; fake:  [[0.69 0.02 0.04]]
valid:  [[0.99 0.93 0.93]] ; fake:  [[0.07 0.1  0.1 ]]
[1,   540] segmentation loss: 0.176; discrimination loss: 0.294
valid:  [[0.99 0.99 0.99]] ; fake:  [[0.15 0.23

valid:  [[0.97 0.64 0.56]] ; fake:  [[0.03 0.42 0.4 ]]
valid:  [[0.97 0.54 0.96]] ; fake:  [[0.02 0.44 0.03]]
valid:  [[0.96 0.67 0.95]] ; fake:  [[0.01 0.03 0.01]]
valid:  [[0.56 0.96 0.84]] ; fake:  [[0.18 0.03 0.  ]]
[2,   660] segmentation loss: 0.054; discrimination loss: 0.508
valid:  [[0.98 0.84 0.49]] ; fake:  [[0.01 0.01 0.04]]
valid:  [[0.48 0.98 0.58]] ; fake:  [[0.22 0.01 0.02]]
valid:  [[0.9  0.41 0.8 ]] ; fake:  [[0.02 0.3  0.  ]]
valid:  [[0.99 0.96 0.93]] ; fake:  [[0.03 0.01 0.01]]
valid:  [[0.92 0.9  0.98]] ; fake:  [[0.09 0.03 0.02]]
valid:  [[0.56 0.99 0.99]] ; fake:  [[0.   0.09 0.02]]
valid:  [[0.99 0.6  0.55]] ; fake:  [[0.04 0.04 0.  ]]
valid:  [[0.55 0.98 0.56]] ; fake:  [[0.   0.02 0.27]]
valid:  [[0.99 0.99 0.83]] ; fake:  [[0.19 0.09 0.05]]
valid:  [[0.68 1.   0.96]] ; fake:  [[0.49 0.24 0.17]]
[2,   670] segmentation loss: 0.049; discrimination loss: 0.342
valid:  [[0.99 1.   0.99]] ; fake:  [[0.09 0.22 0.05]]
valid:  [[0.71 0.72 0.99]] ; fake:  [[0.63 0.63

In [9]:
test_dir = os.path.expanduser("~/workspace/us_robot/DataSet/SimDatasetTest2")
testset_list = os.listdir(test_dir)
resize_to=None
transform_image = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])
transform_label = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize([1000,500])
    ])
for sample in testset_list:
    image_path = os.path.join(test_dir,sample,"image.png")
    label_path = os.path.join(test_dir,sample,"label.png")
    
    img = Image.open(image_path)
    img = transform_image(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    label = Image.open(label_path)
    label = transform_label(label).to(device)
    label = label.unsqueeze(0)
    
    
    with torch.no_grad():
        pred = net(img)
    
    DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()

    pred = invtransform_label(pred.cpu().squeeze(0))
    fname = "pred%.2f.png"%DiceIndex
    sav_path = os.path.join(test_dir,sample,fname)
    pred.save(sav_path)
    


In [9]:
#for python
torch.save(net.state_dict(), './unet_usseg.pth')

#for c++
traced_script_module = torch.jit.trace(net, img)
traced_script_module.save("./unet_usseg_traced.pt")

In [3]:
test_dir = os.path.expanduser("~/workspace/us_robot/simulator/SimDatasetTest2")
testset_list = os.listdir(test_dir)
resize_to=None
transform_image = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])
transform_label = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize([1000,500])
    ])

In [18]:
sample = '0002'

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_path = os.path.join(test_dir,sample,"image.png")
label_path = os.path.join(test_dir,sample,"label.png")
pred_path = os.path.join(test_dir,sample,"pred0.06.png")

img = Image.open(image_path)
img = transform_image(img)
img = img.to(device)
img = img.unsqueeze(0)
label = Image.open(label_path)
label = transform_label(label).to(device)
label = label.unsqueeze(0)
pred = Image.open(pred_path)
pred = transform_label(pred).to(device)
pred = pred.unsqueeze(0)

#with torch.no_grad():
#    pred = net(img)

DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()


In [19]:
pred_path = os.path.join(test_dir,sample,"pred0.06.png")
pred = Image.open(pred_path)
pred = transform_label(pred)

In [3]:
loss = nn.BCELoss()

In [11]:
x = torch.Tensor([0.6,0.5,0.5])
y = torch.Tensor([1,1,1])

In [12]:
loss(x,y)

tensor(0.6324)

In [28]:
x = torch.rand([1,2,2])

In [31]:
x.size()[2:]

torch.Size([2])

In [32]:
x.size()

torch.Size([1, 2, 2])

In [9]:
round([1.111,2,222],2)

TypeError: type list doesn't define __round__ method