In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

import os,sys

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from PIL import Image

In [2]:
class UNet(nn.Module):
    def __init__(self,in_channels=1,out_channels=1,init_features=64):
        super(UNet, self).__init__()

        features = init_features
        self.encoder1 = UNet.DoubleConv2d(in_channels,features)
        self.pool1 = nn.MaxPool2d(2, 2)
        
        self.encoder2 = UNet.DoubleConv2d(features,2*features)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.encoder3 = UNet.DoubleConv2d(2*features,4*features)
        self.pool3 = nn.MaxPool2d(2, 2)
        
        self.encoder4 = UNet.DoubleConv2d(4*features,8*features)
        self.pool4 = nn.MaxPool2d(2, 2)
        
        self.bottleneck = UNet.DoubleConv2d(8*features,16*features)
        
        self.upconv4 = nn.ConvTranspose2d(16*features,8*features,kernel_size=2,stride=2)        
        self.decoder4 = UNet.DoubleConv2d(16*features,8*features) #concate, 2*8=16
        
        self.upconv3 = nn.ConvTranspose2d(8*features,4*features,kernel_size=2,stride=2)
        self.decoder3 = UNet.DoubleConv2d(8*features,4*features) #concate, 2*4=8
        
        self.upconv2 = nn.ConvTranspose2d(4*features,2*features,kernel_size=2,stride=2)
        self.decoder2 = UNet.DoubleConv2d(4*features,2*features) #concate, 2*2=4
        
        
        self.upconv1 = nn.ConvTranspose2d(2*features,features,kernel_size=2,stride=2)
        self.decoder1 = UNet.DoubleConv2d(2*features,features) #concate, 2*1=2
        
        self.conv_out = nn.Conv2d(features, 1, 1)
        
    def forward(self, input):
        enc1 = self.encoder1(input)
        
        enc2 = self.encoder2(self.pool1(enc1))
        
        enc3 = self.encoder3(self.pool2(enc2))
        
        enc4 = self.encoder4(self.pool3(enc3))
        
        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = torch.cat([enc4,self.upconv4(bottleneck)],dim=1)
        dec4 = self.decoder4(dec4)
        
        dec3 = torch.cat([enc3,self.upconv3(dec4)],dim=1)
        dec3 = self.decoder3(dec3)
        
        dec2 = torch.cat([enc2,self.upconv2(dec3)],dim=1)
        dec2 = self.decoder2(dec2)
        
        dec1 = torch.cat([enc1,self.upconv1(dec2)],dim=1)
        dec1 = self.decoder1(dec1)
        
        output = torch.sigmoid(self.conv_out(dec1))
        
        return output
    
    def DoubleConv2d(in_channels,features):
        return nn.Sequential(
                nn.Conv2d(in_channels, features, kernel_size=3, padding=1),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True),
                nn.Conv2d(features, features, kernel_size=3, padding=1),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=True)
        )

In [3]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.img_conv1 = nn.Conv2d(1, 16, kernel_size = 3, padding=1)
        self.img_conv2 = nn.Conv2d(16, 64, kernel_size = 3, padding=1)
        self.img_pool = nn.MaxPool2d(2, 2)
        
        self.mask_conv = nn.Conv2d(1, 64, kernel_size = 3, padding=1)
        self.mask_pool = nn.MaxPool2d(2, 2)
        
        self.dis_conv1 = nn.Conv2d(128, 256, kernel_size = 3, padding=1)
        self.dis_pool1 = nn.MaxPool2d(2, 2)
        self.dis_conv2 = nn.Conv2d(256, 512, kernel_size = 3, padding=1)
        self.dis_pool2 = nn.MaxPool2d(2, 2)
        
        self.output = nn.Linear(512,1)
        
        
    def forward(self, input, mask):
        x = self.img_conv1(input)
        x = self.img_conv2(x)
        x = self.img_pool(x)
        
        y = self.mask_conv(mask)
        y = self.mask_pool(y)
        
        x = torch.cat([x,y],dim=1)
        x = self.dis_conv1(x)
        x = self.dis_pool1(x)
        x = self.dis_conv2(x)
        x = self.dis_pool2(x)
        x = F.max_pool2d(x,kernel_size=x.size()[2:]).squeeze()

        output = torch.sigmoid(self.output(x))
        
        return output

In [4]:
class UltraSoundDataSet(Dataset):
    def __init__(self, root_dir, transforms):
        self.root_dir = root_dir
        self.sample_list = os.listdir(root_dir)

        self.transform_image, self.transform_label = transforms
        
    def __len__(self):
        return len(self.sample_list)
    
    def __getitem__(self,idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image_path = os.path.join(self.root_dir,self.sample_list[idx],"image.png")
        label_path = os.path.join(self.root_dir,self.sample_list[idx],"label.png")
        
        image = Image.open(image_path)
        label = Image.open(label_path)
        
        if self.transform_image is not None:
            image = self.transform_image(image)
            
        if self.transform_label is not None:
            label = self.transform_label(label)
        
        return image,label
        #sample = {"image":image,"label":label}
        
        #return sample
        

In [5]:
def DiceLoss(pred,target,slack=10):
    
    index = (2*torch.sum(pred*target)+slack)/(torch.sum(pred)+torch.sum(target)+slack)
    #if torch.sum(target).item() == 0:
    #print("instersection: ",torch.sum(pred*target).item())
    #print("pred: ",torch.sum(pred).item())
    #print("target: ",torch.sum(target).item())
    #print("Index: ", index.item())
    return 1-index

In [5]:
def train(net,device,val_per=0.1,epochs=10,batch_size=10,resize_to=None):
    if resize_to is not None:
        transform_image = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor()
        ])
    else:
        transform_image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.ToTensor()
        ])
            
    
    dataSet = UltraSoundDataSet(root_dir,(transform_image,transform_label))
    nTrain = int(len(dataSet)*(1-val_per))
    nValid = int(len(dataSet)-nTrain)
    
    trainSet,validSet = random_split(dataSet,[nTrain,nValid])
    
    train_loader = DataLoader(trainSet,batch_size=batch_size,shuffle=True,num_workers=4)
    valid_loader = DataLoader(validSet,batch_size=batch_size,shuffle=True,num_workers=4)
    
    optimizer = torch.optim.Adam(net.parameters())
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',patience=10) #mae: dice-index
    
    running_loss = 0
    step = 0
    for epoch in range(epochs):
        net.train()
        for batch in train_loader:
            imgs,labels = batch
            
            imgs = imgs.to(device=device,dtype=torch.float32)
            labels = labels.to(device=device,dtype=torch.float32)
            
            pred = net(imgs)
            
            loss = DiceLoss(pred,labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            step += 1
            
            if step % 10 == 9:    # print every 10 mini-batches
                print('[%d, %5d] loss: %.3f' %(epoch + 1, step + 1, running_loss / 10))
                running_loss = 0.0
                
            if step%50 == 49:
                val_loss = 0
                for batch in valid_loader:
                    imgs,labels = batch
                    imgs = imgs.to(device=device,dtype=torch.float32)
                    labels = labels.to(device=device,dtype=torch.float32)
                    
                    with torch.no_grad():
                        pred = net(imgs)
                    val_loss += DiceLoss(pred,labels)
                print('[%d, %5d] validation loss: %.3f' %(epoch + 1, step + 1, val_loss / len(valid_loader)))
                scheduler.step(val_loss)
            
        
    

In [6]:
def train_adversarial(net,discriminator,device,val_per=0.1,epochs=10,batch_size=10,resize_to=None):
    if resize_to is not None:
        transform_image = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.ToTensor()
        ])
    else:
        transform_image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
        transform_label = transforms.Compose([
        transforms.ToTensor()
        ])
            
    
    dataSet = UltraSoundDataSet(root_dir,(transform_image,transform_label))
    nTrain = int(len(dataSet)*(1-val_per))
    nValid = int(len(dataSet)-nTrain)
    
    trainSet,validSet = random_split(dataSet,[nTrain,nValid])
    
    train_loader = DataLoader(trainSet,batch_size=batch_size,shuffle=True,num_workers=4)
    valid_loader = DataLoader(validSet,batch_size=batch_size,shuffle=True,num_workers=4)
    
    optimizer_unet = torch.optim.Adam(net.parameters())
    optimizer_disc = torch.optim.Adam(discriminator.parameters())
    scheduler_unet = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_unet,mode='min',patience=10) #mae: dice-index
    scheduler_disc = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_disc,mode='min',patience=10)
    BCELoss = nn.BCELoss()
    
    running_loss_seg = 0
    running_loss_disc = 0
    step = 0
    for epoch in range(epochs):
        net.train()
        discriminator.train()
        
        for batch in train_loader:
            imgs,labels = batch
            
            imgs = imgs.to(device=device,dtype=torch.float32)
            labels = labels.to(device=device,dtype=torch.float32)
            
            # segmentation network
            pred = net(imgs)
            
            # discriminator network
            pred_valid = discriminator(imgs, labels)
            pred_fake = discriminator(imgs, pred.detach())   
            
            # seg loss
            seg_loss = DiceLoss(pred,labels)
            
            optimizer_unet.zero_grad()
            seg_loss.backward(retain_graph=True)
            optimizer_unet.step()
            
            valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device=device,dtype=torch.float32)
            fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device=device,dtype=torch.float32)
            
            # discriminator loss
            disc_loss = BCELoss(pred_valid,valid) + BCELoss(pred_fake,fake)

            optimizer_disc.zero_grad()
            disc_loss.backward(retain_graph=False)
            optimizer_disc.step()
            
            running_loss_seg += seg_loss.item()
            running_loss_disc += disc_loss.item()
            step += 1
            
            if step % 10 == 9:    # print every 10 mini-batches
                print('[%d, %5d] segmentation loss: %.3f; discrimination loss: %.3f' %(epoch + 1, step + 1, running_loss_seg / 10, running_loss_disc / 10))
                running_loss_seg = 0.0
                running_loss_disc = 0.0
                
            if step%50 == 49:
                val_loss_seg = 0
                val_loss_disc = 0
                for batch in valid_loader:
                    imgs,labels = batch
                    imgs = imgs.to(device=device,dtype=torch.float32)
                    labels = labels.to(device=device,dtype=torch.float32)
                    
                    with torch.no_grad():
                        pred_seg = net(imgs)
                        pred_valid = discriminator(imgs, labels)
                        pred_fake = discriminator(imgs, pred_seg.detach())
                    
                    val_loss_seg += DiceLoss(pred_seg,labels).item()
                    valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device=device,dtype=torch.float32)
                    fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device=device,dtype=torch.float32)
                    val_loss_disc += (BCELoss(pred_valid,valid) + BCELoss(pred_fake,fake)).item()
                print('[%d, %5d] validation loss(seg): %.3f; validation loss(disc): %.3f' %(epoch + 1, step + 1, val_loss_seg / len(valid_loader), val_loss_disc / len(valid_loader)))
                scheduler_unet.step(val_loss_seg)
                scheduler_disc.step(val_loss_disc)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(init_features=32)
discriminator = Discriminator()

net = net.to(device)
discriminator = discriminator.to(device)

root_dir = os.path.expanduser("~/workspace/us_robot/DataSet/SimDataset2")
try:
    #train(net=net,device=device,resize_to=None,epochs=50,batch_size=5)
    train_adversarial(net=net,discriminator=discriminator,device=device,resize_to=None,epochs=20,batch_size=3)
except KeyboardInterrupt:
    sys.exit()

[1,    10] segmentation loss: 0.850; discrimination loss: 0.699
[1,    20] segmentation loss: 0.869; discrimination loss: 0.057
[1,    30] segmentation loss: 0.882; discrimination loss: 0.001
[1,    40] segmentation loss: 0.895; discrimination loss: 0.002
[1,    50] segmentation loss: 0.844; discrimination loss: 0.003
[1,    50] validation loss(seg): 0.808; validation loss(disc): 0.001
[1,    60] segmentation loss: 0.808; discrimination loss: 0.002
[1,    70] segmentation loss: 0.801; discrimination loss: 0.344
[1,    80] segmentation loss: 0.791; discrimination loss: 1.758
[1,    90] segmentation loss: 0.658; discrimination loss: 1.716
[1,   100] segmentation loss: 0.713; discrimination loss: 1.993
[1,   100] validation loss(seg): 0.655; validation loss(disc): 0.003
[1,   110] segmentation loss: 0.724; discrimination loss: 0.135
[1,   120] segmentation loss: 0.567; discrimination loss: 0.013
[1,   130] segmentation loss: 0.582; discrimination loss: 0.021
[1,   140] segmentation loss: 

[2,  1070] segmentation loss: 0.060; discrimination loss: 0.364
[2,  1080] segmentation loss: 0.047; discrimination loss: 0.476
[2,  1090] segmentation loss: 0.054; discrimination loss: 0.401
[2,  1100] segmentation loss: 0.048; discrimination loss: 0.369
[2,  1100] validation loss(seg): 0.092; validation loss(disc): 0.371
[2,  1110] segmentation loss: 0.143; discrimination loss: 0.469
[2,  1120] segmentation loss: 0.069; discrimination loss: 0.276
[2,  1130] segmentation loss: 0.044; discrimination loss: 0.400
[2,  1140] segmentation loss: 0.037; discrimination loss: 0.549
[2,  1150] segmentation loss: 0.113; discrimination loss: 0.596
[2,  1150] validation loss(seg): 0.092; validation loss(disc): 0.417
[2,  1160] segmentation loss: 0.045; discrimination loss: 0.358
[2,  1170] segmentation loss: 0.048; discrimination loss: 0.422
[2,  1180] segmentation loss: 0.142; discrimination loss: 0.366
[2,  1190] segmentation loss: 0.230; discrimination loss: 0.600
[2,  1200] segmentation loss: 

In [9]:
test_dir = os.path.expanduser("~/workspace/us_robot/DataSet/SimDatasetTest2")
testset_list = os.listdir(test_dir)
resize_to=None
transform_image = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])
transform_label = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize([1000,500])
    ])
for sample in testset_list:
    image_path = os.path.join(test_dir,sample,"image.png")
    label_path = os.path.join(test_dir,sample,"label.png")
    
    img = Image.open(image_path)
    img = transform_image(img)
    img = img.to(device)
    img = img.unsqueeze(0)
    label = Image.open(label_path)
    label = transform_label(label).to(device)
    label = label.unsqueeze(0)
    
    
    with torch.no_grad():
        pred = net(img)
    
    DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()

    pred = invtransform_label(pred.cpu().squeeze(0))
    fname = "pred%.2f.png"%DiceIndex
    sav_path = os.path.join(test_dir,sample,fname)
    pred.save(sav_path)
    


In [9]:
#for python
torch.save(net.state_dict(), './unet_usseg.pth')

#for c++
traced_script_module = torch.jit.trace(net, img)
traced_script_module.save("./unet_usseg_traced.pt")

In [3]:
test_dir = os.path.expanduser("~/workspace/us_robot/simulator/SimDatasetTest2")
testset_list = os.listdir(test_dir)
resize_to=None
transform_image = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])
transform_label = transforms.Compose([
    #transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize([1000,500])
    ])

In [18]:
sample = '0002'

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_path = os.path.join(test_dir,sample,"image.png")
label_path = os.path.join(test_dir,sample,"label.png")
pred_path = os.path.join(test_dir,sample,"pred0.06.png")

img = Image.open(image_path)
img = transform_image(img)
img = img.to(device)
img = img.unsqueeze(0)
label = Image.open(label_path)
label = transform_label(label).to(device)
label = label.unsqueeze(0)
pred = Image.open(pred_path)
pred = transform_label(pred).to(device)
pred = pred.unsqueeze(0)

#with torch.no_grad():
#    pred = net(img)

DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()


In [19]:
pred_path = os.path.join(test_dir,sample,"pred0.06.png")
pred = Image.open(pred_path)
pred = transform_label(pred)

In [2]:
loss = nn.BCELoss()

In [5]:
x = torch.Tensor([0.1])
y = torch.Tensor([1])

In [6]:
loss(x,y)

tensor(2.3026)

In [28]:
x = torch.rand([1,2,2])

In [31]:
x.size()[2:]

torch.Size([2])

In [32]:
x.size()

torch.Size([1, 2, 2])