In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms

import os,sys

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from PIL import Image

In [2]:
class UNet(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Every encoder stage is a double 3x3 convolution followed by a 2x2 max-pool;
    every decoder stage up-samples with a stride-2 transposed convolution,
    concatenates the matching encoder output (skip connection) along the
    channel axis and applies another double convolution.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the predicted map.
        init_features: Channel width of the first encoder stage; it doubles at
            every deeper stage (up to 16x in the bottleneck).

    Note:
        The input height and width should be divisible by 16. The four
        max-pools floor odd sizes, so otherwise the up-sampled tensors no
        longer match the skip connections and ``torch.cat`` fails.
    """

    def __init__(self,in_channels=1,out_channels=1,init_features=64):
        super(UNet, self).__init__()

        features = init_features

        # Contracting path.
        self.encoder1 = UNet.DoubleConv2d(in_channels,features)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.encoder2 = UNet.DoubleConv2d(features,2*features)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.encoder3 = UNet.DoubleConv2d(2*features,4*features)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.encoder4 = UNet.DoubleConv2d(4*features,8*features)
        self.pool4 = nn.MaxPool2d(2, 2)

        self.bottleneck = UNet.DoubleConv2d(8*features,16*features)

        # Expanding path. Each decoder sees twice the up-conv width because the
        # skip connection is concatenated before the double convolution.
        self.upconv4 = nn.ConvTranspose2d(16*features,8*features,kernel_size=2,stride=2)
        self.decoder4 = UNet.DoubleConv2d(16*features,8*features)

        self.upconv3 = nn.ConvTranspose2d(8*features,4*features,kernel_size=2,stride=2)
        self.decoder3 = UNet.DoubleConv2d(8*features,4*features)

        self.upconv2 = nn.ConvTranspose2d(4*features,2*features,kernel_size=2,stride=2)
        self.decoder2 = UNet.DoubleConv2d(4*features,2*features)

        self.upconv1 = nn.ConvTranspose2d(2*features,features,kernel_size=2,stride=2)
        self.decoder1 = UNet.DoubleConv2d(2*features,features)

        # 1x1 convolution projecting onto the requested number of output maps.
        self.conv_out = nn.Conv2d(features, out_channels, kernel_size=1)

    def forward(self, input):
        """Run the network.

        Args:
            input: Tensor of shape ``(N, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)`` after a sigmoid, i.e.
            with values between 0 and 1.
        """
        enc1 = self.encoder1(input)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = torch.cat([enc4,self.upconv4(bottleneck)],dim=1)
        dec4 = self.decoder4(dec4)

        dec3 = torch.cat([enc3,self.upconv3(dec4)],dim=1)
        dec3 = self.decoder3(dec3)

        dec2 = torch.cat([enc2,self.upconv2(dec3)],dim=1)
        dec2 = self.decoder2(dec2)

        dec1 = torch.cat([enc1,self.upconv1(dec2)],dim=1)
        dec1 = self.decoder1(dec1)

        return torch.sigmoid(self.conv_out(dec1))

    @staticmethod
    def DoubleConv2d(in_channels,features):
        """Build two (3x3 conv -> batch norm -> ReLU) layers.

        ``padding=1`` keeps the spatial size unchanged.

        Args:
            in_channels: Channels entering the first convolution.
            features: Channels produced by both convolutions.

        Returns:
            An ``nn.Sequential`` holding the six layers.
        """
        return nn.Sequential(
                nn.Conv2d(in_channels, features, kernel_size=3, padding=1),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True),
                nn.Conv2d(features, features, kernel_size=3, padding=1),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True)
        )

In [3]:
class UltraSoundDataSet(Dataset):
    """Image/label pairs stored as one sub-directory per sample.

    Sample ``i`` is read from ``<root_dir>/<sample_list[i]>/image.png`` and
    ``<root_dir>/<sample_list[i]>/label.png``.

    Args:
        root_dir: Directory that contains the sample sub-directories.
        transforms: Pair ``(transform_image, transform_label)``; either entry
            may be ``None`` to return the PIL image untouched.
    """

    def __init__(self, root_dir, transforms):
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order, so sort to keep the
        # index -> sample mapping (and any seeded split) reproducible. Stray
        # files next to the sample directories are not samples and are skipped.
        self.sample_list = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )

        self.transform_image, self.transform_label = transforms

    def __len__(self):
        """Return the number of sample directories found under ``root_dir``."""
        return len(self.sample_list)

    def __getitem__(self,idx):
        """Load sample ``idx`` and return the ``(image, label)`` pair.

        Args:
            idx: Integer index, or a tensor holding one.

        Returns:
            Tuple ``(image, label)`` after the respective transforms.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample_dir = os.path.join(self.root_dir,self.sample_list[idx])

        image = Image.open(os.path.join(sample_dir,"image.png"))
        label = Image.open(os.path.join(sample_dir,"label.png"))

        if self.transform_image is not None:
            image = self.transform_image(image)

        if self.transform_label is not None:
            label = self.transform_label(label)

        return image,label
        

In [4]:
def DiceLoss(pred,target,slack=10):
    """Soft Dice loss, ``1 - Dice index``, computed over all elements at once.

    Args:
        pred: Tensor of predictions.
        target: Tensor of reference labels, broadcastable against ``pred``.
        slack: Constant added to numerator and denominator; it keeps the ratio
            defined (index = 1, loss = 0) when both tensors sum to zero.

    Returns:
        Scalar tensor holding the loss.
    """
    overlap = (pred * target).sum()
    numerator = 2 * overlap + slack
    denominator = pred.sum() + target.sum() + slack
    return 1 - numerator / denominator

In [5]:
def train(net,device,val_per=0.1,epochs=10,batch_size=10,resize_to=None,data_root=None):
    """Train ``net`` with the Dice loss and periodically report validation loss.

    The dataset is split randomly into a training and a validation part. The
    running training loss is printed every 10 optimisation steps. Every 50
    steps the mean validation loss is printed and handed to a
    ``ReduceLROnPlateau`` scheduler.

    Args:
        net: Network to optimise in place; must already live on ``device``.
        device: Device the batches are moved to.
        val_per: Fraction of the dataset held out for validation.
        epochs: Number of passes over the training split.
        batch_size: Mini-batch size for both loaders.
        resize_to: Optional size passed to ``transforms.Resize`` for images and
            labels; ``None`` keeps the stored resolution.
        data_root: Dataset directory. When ``None`` the module-level variable
            ``root_dir`` is used, so it must be defined before calling.
    """
    if data_root is None:
        # Backward-compatible default: read the notebook-level global.
        data_root = root_dir

    # Images are normalised, labels are only converted to tensors; both get
    # the same optional resize in front.
    image_steps = [transforms.ToTensor(), transforms.Normalize(0.5,0.5)]
    label_steps = [transforms.ToTensor()]
    if resize_to is not None:
        image_steps.insert(0, transforms.Resize(resize_to))
        label_steps.insert(0, transforms.Resize(resize_to))
    transform_image = transforms.Compose(image_steps)
    transform_label = transforms.Compose(label_steps)

    dataSet = UltraSoundDataSet(data_root,(transform_image,transform_label))
    nTrain = int(len(dataSet)*(1-val_per))
    nValid = int(len(dataSet)-nTrain)

    trainSet,validSet = random_split(dataSet,[nTrain,nValid])

    train_loader = DataLoader(trainSet,batch_size=batch_size,shuffle=True,num_workers=4)
    # The validation loss is averaged over the whole split, so order is irrelevant.
    valid_loader = DataLoader(validSet,batch_size=batch_size,shuffle=False,num_workers=4)

    optimizer = torch.optim.Adam(net.parameters())
    # Monitors the mean validation Dice loss (lower is better).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',patience=10)

    running_loss = 0.0
    step = 0
    for epoch in range(epochs):
        net.train()
        for imgs,labels in train_loader:
            imgs = imgs.to(device=device,dtype=torch.float32)
            labels = labels.to(device=device,dtype=torch.float32)

            pred = net(imgs)
            loss = DiceLoss(pred,labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            step += 1

            # `step` already counts the batch just processed, so exactly ten
            # losses have been accumulated whenever it is a multiple of 10.
            if step % 10 == 0:
                print('[%d, %5d] loss: %.3f' %(epoch + 1, step, running_loss / 10))
                running_loss = 0.0

            # Skip validation when the split is empty (avoids dividing by zero).
            if step % 50 == 0 and len(valid_loader) > 0:
                # Evaluation mode: BatchNorm uses its running statistics and
                # does not update them with validation data.
                net.eval()
                val_loss = 0.0
                with torch.no_grad():
                    for val_imgs,val_labels in valid_loader:
                        val_imgs = val_imgs.to(device=device,dtype=torch.float32)
                        val_labels = val_labels.to(device=device,dtype=torch.float32)
                        val_loss += DiceLoss(net(val_imgs),val_labels).item()
                val_loss /= len(valid_loader)
                print('[%d, %5d] validation loss: %.3f' %(epoch + 1, step, val_loss))
                scheduler.step(val_loss)
                # Back to training mode for the remaining batches of the epoch.
                net.train()
            
        
    

In [6]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(init_features=32).to(device)

# Training data location (machine-specific absolute path); train() reads
# this module-level name.
root_dir = "/home/zhenyuli/workspace/us_robot/simulator/SimDataset2"
try:
    train(net=net,device=device,resize_to=None,epochs=50,batch_size=5)
except KeyboardInterrupt:
    # Manual interruption of the training run ends execution via SystemExit.
    sys.exit()

[1,    10] loss: 0.786
[1,    20] loss: 0.898
[1,    30] loss: 0.817
[1,    40] loss: 0.808
[1,    50] loss: 0.840
[1,    50] validation loss: 0.817
[1,    60] loss: 0.812
[1,    70] loss: 0.748
[1,    80] loss: 0.708
[1,    90] loss: 0.755
[1,   100] loss: 0.683
[1,   100] validation loss: 0.654
[1,   110] loss: 0.601
[1,   120] loss: 0.534
[1,   130] loss: 0.504
[1,   140] loss: 0.430
[1,   150] loss: 0.394
[1,   150] validation loss: 0.353
[1,   160] loss: 0.341
[1,   170] loss: 0.288
[1,   180] loss: 0.290
[1,   190] loss: 0.186
[1,   200] loss: 0.150
[1,   200] validation loss: 0.169
[1,   210] loss: 0.140
[1,   220] loss: 0.153
[1,   230] loss: 0.242
[1,   240] loss: 0.197
[1,   250] loss: 0.118
[1,   250] validation loss: 0.123
[1,   260] loss: 0.085
[1,   270] loss: 0.100
[1,   280] loss: 0.111
[1,   290] loss: 0.094
[1,   300] loss: 0.094
[1,   300] validation loss: 0.072
[1,   310] loss: 0.080
[1,   320] loss: 0.106
[1,   330] loss: 0.140
[1,   340] loss: 0.075
[1,   350] los

[8,  2760] loss: 0.029
[8,  2770] loss: 0.029
[8,  2780] loss: 0.030
[8,  2790] loss: 0.028
[8,  2800] loss: 0.025
[8,  2800] validation loss: 0.026
[8,  2810] loss: 0.028
[8,  2820] loss: 0.026
[8,  2830] loss: 0.025
[8,  2840] loss: 0.028
[8,  2850] loss: 0.027
[8,  2850] validation loss: 0.027
[8,  2860] loss: 0.025
[8,  2870] loss: 0.023
[8,  2880] loss: 0.029
[9,  2890] loss: 0.023
[9,  2900] loss: 0.024
[9,  2900] validation loss: 0.026
[9,  2910] loss: 0.029
[9,  2920] loss: 0.023
[9,  2930] loss: 0.025
[9,  2940] loss: 0.026
[9,  2950] loss: 0.025
[9,  2950] validation loss: 0.027
[9,  2960] loss: 0.028
[9,  2970] loss: 0.025
[9,  2980] loss: 0.025
[9,  2990] loss: 0.028
[9,  3000] loss: 0.025
[9,  3000] validation loss: 0.025
[9,  3010] loss: 0.026
[9,  3020] loss: 0.026
[9,  3030] loss: 0.026
[9,  3040] loss: 0.022
[9,  3050] loss: 0.033
[9,  3050] validation loss: 0.025
[9,  3060] loss: 0.026
[9,  3070] loss: 0.023
[9,  3080] loss: 0.021
[9,  3090] loss: 0.027
[9,  3100] los

[16,  5430] loss: 0.021
[16,  5440] loss: 0.022
[16,  5450] loss: 0.026
[16,  5450] validation loss: 0.025
[16,  5460] loss: 0.022
[16,  5470] loss: 0.020
[16,  5480] loss: 0.026
[16,  5490] loss: 0.028
[16,  5500] loss: 0.027
[16,  5500] validation loss: 0.024
[16,  5510] loss: 0.028
[16,  5520] loss: 0.026
[16,  5530] loss: 0.022
[16,  5540] loss: 0.025
[16,  5550] loss: 0.036
[16,  5550] validation loss: 0.023
[16,  5560] loss: 0.024
[16,  5570] loss: 0.026
[16,  5580] loss: 0.021
[16,  5590] loss: 0.021
[16,  5600] loss: 0.021
[16,  5600] validation loss: 0.024
[16,  5610] loss: 0.026
[16,  5620] loss: 0.027
[16,  5630] loss: 0.029
[16,  5640] loss: 0.023
[16,  5650] loss: 0.022
[16,  5650] validation loss: 0.024
[16,  5660] loss: 0.026
[16,  5670] loss: 0.022
[16,  5680] loss: 0.034
[16,  5690] loss: 0.029
[16,  5700] loss: 0.021
[16,  5700] validation loss: 0.023
[16,  5710] loss: 0.025
[16,  5720] loss: 0.020
[16,  5730] loss: 0.024
[16,  5740] loss: 0.029
[16,  5750] loss: 0.02

[23,  8080] loss: 0.025
[23,  8090] loss: 0.024
[23,  8100] loss: 0.023
[23,  8100] validation loss: 0.023
[23,  8110] loss: 0.025
[23,  8120] loss: 0.024
[23,  8130] loss: 0.024
[23,  8140] loss: 0.030
[23,  8150] loss: 0.023
[23,  8150] validation loss: 0.025
[23,  8160] loss: 0.026
[23,  8170] loss: 0.029
[23,  8180] loss: 0.024
[23,  8190] loss: 0.026
[23,  8200] loss: 0.021
[23,  8200] validation loss: 0.024
[23,  8210] loss: 0.024
[23,  8220] loss: 0.022
[23,  8230] loss: 0.025
[23,  8240] loss: 0.030
[23,  8250] loss: 0.027
[23,  8250] validation loss: 0.023
[23,  8260] loss: 0.024
[23,  8270] loss: 0.022
[23,  8280] loss: 0.024
[24,  8290] loss: 0.020
[24,  8300] loss: 0.025
[24,  8300] validation loss: 0.024
[24,  8310] loss: 0.021
[24,  8320] loss: 0.021
[24,  8330] loss: 0.022
[24,  8340] loss: 0.034
[24,  8350] loss: 0.024
[24,  8350] validation loss: 0.024
[24,  8360] loss: 0.029
[24,  8370] loss: 0.022
[24,  8380] loss: 0.024
[24,  8390] loss: 0.022
[24,  8400] loss: 0.01

[30, 10730] loss: 0.022
[30, 10740] loss: 0.021
[30, 10750] loss: 0.026
[30, 10750] validation loss: 0.025
[30, 10760] loss: 0.025
[30, 10770] loss: 0.027
[30, 10780] loss: 0.023
[30, 10790] loss: 0.027
[30, 10800] loss: 0.026
[30, 10800] validation loss: 0.024
[31, 10810] loss: 0.024
[31, 10820] loss: 0.024
[31, 10830] loss: 0.026
[31, 10840] loss: 0.032
[31, 10850] loss: 0.024
[31, 10850] validation loss: 0.032
[31, 10860] loss: 0.024
[31, 10870] loss: 0.022
[31, 10880] loss: 0.023
[31, 10890] loss: 0.024
[31, 10900] loss: 0.033
[31, 10900] validation loss: 0.024
[31, 10910] loss: 0.022
[31, 10920] loss: 0.021
[31, 10930] loss: 0.023
[31, 10940] loss: 0.028
[31, 10950] loss: 0.024
[31, 10950] validation loss: 0.024
[31, 10960] loss: 0.035
[31, 10970] loss: 0.021
[31, 10980] loss: 0.021
[31, 10990] loss: 0.025
[31, 11000] loss: 0.023
[31, 11000] validation loss: 0.032
[31, 11010] loss: 0.024
[31, 11020] loss: 0.022
[31, 11030] loss: 0.026
[31, 11040] loss: 0.025
[31, 11050] loss: 0.02

[38, 13380] loss: 0.023
[38, 13390] loss: 0.028
[38, 13400] loss: 0.025
[38, 13400] validation loss: 0.024
[38, 13410] loss: 0.023
[38, 13420] loss: 0.032
[38, 13430] loss: 0.022
[38, 13440] loss: 0.023
[38, 13450] loss: 0.021
[38, 13450] validation loss: 0.024
[38, 13460] loss: 0.023
[38, 13470] loss: 0.026
[38, 13480] loss: 0.021
[38, 13490] loss: 0.025
[38, 13500] loss: 0.021
[38, 13500] validation loss: 0.024
[38, 13510] loss: 0.030
[38, 13520] loss: 0.029
[38, 13530] loss: 0.028
[38, 13540] loss: 0.024
[38, 13550] loss: 0.021
[38, 13550] validation loss: 0.023
[38, 13560] loss: 0.025
[38, 13570] loss: 0.020
[38, 13580] loss: 0.024
[38, 13590] loss: 0.021
[38, 13600] loss: 0.024
[38, 13600] validation loss: 0.024
[38, 13610] loss: 0.024
[38, 13620] loss: 0.027
[38, 13630] loss: 0.027
[38, 13640] loss: 0.035
[38, 13650] loss: 0.028
[38, 13650] validation loss: 0.024
[38, 13660] loss: 0.025
[38, 13670] loss: 0.024
[38, 13680] loss: 0.020
[39, 13690] loss: 0.021
[39, 13700] loss: 0.02

[45, 16030] loss: 0.025
[45, 16040] loss: 0.023
[45, 16050] loss: 0.030
[45, 16050] validation loss: 0.025
[45, 16060] loss: 0.021
[45, 16070] loss: 0.027
[45, 16080] loss: 0.023
[45, 16090] loss: 0.024
[45, 16100] loss: 0.025
[45, 16100] validation loss: 0.023
[45, 16110] loss: 0.021
[45, 16120] loss: 0.023
[45, 16130] loss: 0.027
[45, 16140] loss: 0.028
[45, 16150] loss: 0.021
[45, 16150] validation loss: 0.024
[45, 16160] loss: 0.022
[45, 16170] loss: 0.026
[45, 16180] loss: 0.029
[45, 16190] loss: 0.022
[45, 16200] loss: 0.027
[45, 16200] validation loss: 0.024
[46, 16210] loss: 0.024
[46, 16220] loss: 0.024
[46, 16230] loss: 0.022
[46, 16240] loss: 0.022
[46, 16250] loss: 0.023
[46, 16250] validation loss: 0.023
[46, 16260] loss: 0.024
[46, 16270] loss: 0.034
[46, 16280] loss: 0.027
[46, 16290] loss: 0.021
[46, 16300] loss: 0.022
[46, 16300] validation loss: 0.024
[46, 16310] loss: 0.021
[46, 16320] loss: 0.022
[46, 16330] loss: 0.023
[46, 16340] loss: 0.027
[46, 16350] loss: 0.02

In [7]:
# Held-out test samples (machine-specific absolute path); one sub-directory per
# sample, each holding image.png and label.png.
test_dir = "/home/zhenyuli/workspace/us_robot/simulator/SimDatasetTest2"
# Sorted for a repeatable processing order; plain files are not samples.
testset_list = sorted(
    entry for entry in os.listdir(test_dir)
    if os.path.isdir(os.path.join(test_dir, entry))
)
resize_to=None
transform_image = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor(),  # rescales 8-bit PIL images to [0, 1]
    transforms.Normalize(0.5,0.5)  # then maps [0, 1] to [-1, 1], as in training
    ])
transform_label = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize([1000,500])
    ])

# Inference must run in evaluation mode: otherwise BatchNorm normalises each
# single-image batch with its own statistics and keeps updating the running
# averages.
net.eval()

for sample in testset_list:
    image_path = os.path.join(test_dir,sample,"image.png")
    label_path = os.path.join(test_dir,sample,"label.png")

    # Add the batch dimension the network expects.
    img = Image.open(image_path)
    img = transform_image(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    label = Image.open(label_path)
    label = transform_label(label).to(device)
    label = label.unsqueeze(0)

    with torch.no_grad():
        pred = net(img)

    DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()

    # Store the prediction next to its inputs, with the Dice index in the name.
    pred = invtransform_label(pred.cpu().squeeze(0))
    fname = "pred%.2f.png"%DiceIndex
    sav_path = os.path.join(test_dir,sample,fname)
    pred.save(sav_path)
    


In [9]:
#for python
torch.save(net.state_dict(), './unet_usseg.pth')

#for c++
# Tracing freezes mode-dependent layers in whatever mode the module is in, so
# switch to evaluation mode first; otherwise the exported module would run
# BatchNorm with per-batch statistics at inference time.
net.eval()
# `img` is the last test image prepared by the inference cell; it serves as the
# example input for tracing.
traced_script_module = torch.jit.trace(net, img)
traced_script_module.save("./unet_usseg_traced.pt")

In [3]:
# Held-out test samples (machine-specific absolute path).
test_dir = "/home/zhenyuli/workspace/us_robot/simulator/SimDatasetTest2"
testset_list = os.listdir(test_dir)
resize_to=None

# Input images: convert to tensor, then normalise with mean 0.5 and std 0.5.
transform_image = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(0.5,0.5)]
)
# Labels (and stored predictions): tensor conversion only.
transform_label = transforms.Compose([transforms.ToTensor()])
# Tensor -> PIL image, for writing predictions to disk.
invtransform_label = transforms.Compose([transforms.ToPILImage()])

In [18]:
# Name of the test sample sub-directory (under test_dir) examined by the cells below.
sample = '0002'

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Files of the selected sample: input image, reference label and a stored prediction.
image_path = os.path.join(test_dir,sample,"image.png")
label_path = os.path.join(test_dir,sample,"label.png")
pred_path = os.path.join(test_dir,sample,"pred0.06.png")

# Each tensor is moved to `device` and given a leading batch dimension.
img = transform_image(Image.open(image_path)).to(device).unsqueeze(0)
label = transform_label(Image.open(label_path)).to(device).unsqueeze(0)
# The prediction is read back from disk here rather than recomputed with the network.
pred = transform_label(Image.open(pred_path)).to(device).unsqueeze(0)

# Dice index = 1 - Dice loss.
DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()


In [19]:
# Reload the stored prediction as a tensor (no device move, no batch dimension).
pred_path = os.path.join(test_dir,sample,"pred0.06.png")
pred = transform_label(Image.open(pred_path))

In [20]:
# Sum of all values in the stored prediction.
pred.sum()

tensor(0.5137)