In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from modules.UNet import *
from modules.Discriminator import *
from modules.DataSet import *
from modules.Losses import *

from flownet2 import models

import os,sys

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from PIL import Image,ImageEnhance

import numpy as np

In [3]:
def get_next_mask(label,flow):
    #print(flow.shape)
    #print(label.shape)
    label = np.array(label)
    edge_points = []
    result = np.zeros_like(label)
    for i in range(label.shape[0]):
        curr_row = label[i,:]
        if any(curr_row>0):
            lnon_zero = np.where(curr_row>0)[0]
            edge_points.append([i,lnon_zero[0]])
            edge_points.append([i,lnon_zero[-1]])

    for j in range(label.shape[1]):
        curr_col = label[:,j]
        if any(curr_col>0):
            lnon_zero = np.where(curr_col>0)[0]
            edge_points.append([lnon_zero[0],j])
            edge_points.append([lnon_zero[-1],j])

    for coord in edge_points:
        [dx,dy] = flow[:,coord[0],coord[1]]
        coord = (coord+np.array([dy,dx])).astype(np.long)
        
        result[coord[0],coord[1]] = 1
    
    return result

In [4]:
def train_OF(net,flownet,device,val_per=0.1,epochs=10,batch_size=10,resize_to=None):

    transform_image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
    ])

        
    dataSet = UltraSoundDataSet_OF(root_dir,transform_image,resize_to)
    nTrain = int(len(dataSet)*(1-val_per))
    nValid = int(len(dataSet)-nTrain)
    
    trainSet,validSet = random_split(dataSet,[nTrain,nValid])
    
    train_loader = DataLoader(trainSet,batch_size=batch_size,shuffle=True,num_workers=4)
    valid_loader = DataLoader(validSet,batch_size=batch_size,shuffle=True,num_workers=4)
    
    optimizer = torch.optim.Adam(net.parameters())
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',patience=10)
    
    running_loss_seg = 0
    
    step = 0
    np.set_printoptions(precision=2)
    
    for epoch in range(epochs):
        net.train()
        
        for batch in train_loader:
            im, images_curr, labels_prev, labels_curr = batch
                
            #calculate optical flow
            im = im.to(device)
            with torch.no_grad():
                flow = flownet(im)

            flow = np.array(flow.cpu())
            
            labels_curr_of = np.zeros_like(labels_curr)
            for n in range(im.shape[0]):
                #estimate vessel position by using the flow
                labels_curr_of[n,...] = get_next_mask(labels_prev[n,0,...],flow[n,...])
                
            labels_curr_of = torch.Tensor(labels_curr_of).to(device,dtype=torch.float32)
            images_curr = images_curr.to(device)
            labels_curr = labels_curr.to(device)
#             print(images_curr.shape)
#             print(labels_curr_of.shape)
            if np.random.rand(1)>0.5:
                pred = net(images_curr,labels_curr_of, mode=1)
                print("+",end='')
            else:
                pred = net(images_curr,labels_curr_of, mode=0)
                print("-",end='')
                
            seg_loss = DiceLoss(pred,labels_curr)
            #print(seg_loss.item())
            optimizer.zero_grad()
            seg_loss.backward()
            optimizer.step()

            running_loss_seg += seg_loss.item()
            
            step += 1    
            if step % 10 == 9:    # print every 10 mini-batches
                print()
                print('[%d, %5d] loss: %.3f' %(epoch + 1, step + 1, running_loss_seg / 10))
                running_loss_seg = 0.0
                
            if step%50 == 49:
                val_loss = 0
                for batch in valid_loader:
                    im, images_curr, labels_prev, labels_curr = batch
                    
                    #calculate optical flow
                    im = im.to(device)
                    with torch.no_grad():
                        #print(im.shape)
                        flow = flownet(im)
                        #print(flow.shape)

                    flow = np.array(flow.cpu())
                        
                    labels_curr_of = np.zeros_like(labels_curr)
                    for n in range(im.shape[0]):
                        #estimate vessel position by using the flow   
                        labels_curr_of[n,...] = get_next_mask(labels_prev[n,0,...],flow[n,...])
                    
                    labels_curr_of = torch.Tensor(labels_curr_of).to(device,dtype=torch.float32)
                    labels_curr = labels_curr.to(device)
                    images_curr = torch.Tensor(images_curr).to(device)
                    with torch.no_grad():
                        pred = net(images_curr,labels_curr_of,mode=0)

                    val_loss += DiceLoss(pred,labels_curr)
                print('[%d, %5d] validation loss: %.3f' %(epoch + 1, step + 1, val_loss / len(valid_loader)))
                scheduler.step(val_loss)

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#print("use ",device)
unet = UNet_OF(init_features=64).to(device)

class FlowNet2Args():
    fp16 = False
    rgb_max = 255

args = FlowNet2Args()
flownet = models.FlowNet2(args).cuda()
flownet.load_state_dict(torch.load('/home/zhenyuli/Downloads/FlowNet2_checkpoint.pth.tar')["state_dict"])
flownet = flownet.eval()

#IMG_SIZE = [256,256]

root_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_dataset")

In [6]:
try:
    train_OF(unet,flownet,device,val_per=0.1,epochs=2,batch_size=10,resize_to=None)
except KeyboardInterrupt:
    sys.exit()

  "See the documentation of nn.Upsample for details.".format(mode))


+---+--++
[1,    10] loss: 0.872
++--++---+
[1,    20] loss: 0.954
+-+-+-+++-
[1,    30] loss: 0.948
++-+-++++-
[1,    40] loss: 0.932
++-++--++-
[1,    50] loss: 0.902
[1,    50] validation loss: 0.916
----+-+--+
[1,    60] loss: 0.889
+++--+-++-
[1,    70] loss: 0.860
+--++---+-
[1,    80] loss: 0.824
-++-+-+---
[1,    90] loss: 0.768
++++-+++-+
[1,   100] loss: 0.616
[1,   100] validation loss: 0.666
----+-+-+-
[1,   110] loss: 0.543
--+--+++--
[2,   120] loss: 0.398
--+++--++-
[2,   130] loss: 0.299
++--++---+
[2,   140] loss: 0.227
--++------
[2,   150] loss: 0.186
[2,   150] validation loss: 0.205
-+-+-+--++
[2,   160] loss: 0.148
-+-++-----
[2,   170] loss: 0.128
++++----++
[2,   180] loss: 0.157
-+-+++-+++
[2,   190] loss: 0.103
-++++++++-
[2,   200] loss: 0.101
[2,   200] validation loss: 0.129
-++-+--++-
[2,   210] loss: 0.104
+++----++-
[2,   220] loss: 0.095
---++

In [11]:
#for python
torch.save(unet.state_dict(), './unet_of_usseg.pth')
# #for c++
# traced_script_module = torch.jit.trace(unet, img)
# traced_script_module.save("./unet_of_usseg_traced.pt")

## Inference

In [15]:
test_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_test2")
pred_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_pred")

testset_list = os.listdir(test_dir)
testset_list = sorted(list(filter(lambda x: x.endswith('png'), testset_list)))
resize_to=[256,256]
transform_image = transforms.Compose([
    transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])
transform_label = transforms.Compose([
    transforms.Resize(resize_to),
    transforms.ToTensor()
    ])
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize([512,512])
    ])

for sample in testset_list:
    image_path = os.path.join(test_dir,sample)
    label_path = os.path.join(pred_dir,sample)
    
    img = Image.open(image_path).convert("L")
    img = ImageEnhance.Contrast(img).enhance(1.5)
    img = transform_image(img).unsqueeze(0)
    labels_curr_of = torch.Tensor(np.zeros_like(img))
    img = img.to(device)
    labels_curr_of = labels_curr_of.to(device)
    #label = Image.open(label_path)
    #label = transform_label(label).to(device)
    #label = label.unsqueeze(0)
    
    
    with torch.no_grad():
        pred = unet(img,labels_curr_of, mode=0)
    
    #DiceIndex = (1 - DiceLoss(pred,label)).cpu().item()

    pred = invtransform_label(pred.cpu().squeeze(0))
    #fname = "pred%.2f.png"%DiceIndex
    fname = sample
    sav_path = os.path.join(pred_dir,fname)
    pred.save(sav_path)