In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from modules.UNet import *
from modules.Discriminator import *
from modules.DataSet import *
from modules.Losses import *

from flownet2 import models

import os,sys

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
from PIL import Image,ImageEnhance

import numpy as np

In [2]:
def get_next_mask_edge(label,flow):
    """Move the outline of a mask to the next frame using optical flow.

    For every row and every column of ``label`` that contains a positive
    pixel, the first and the last positive pixel are collected as edge
    points. Each edge point is shifted by the flow vector stored at its own
    position, truncated toward zero, clipped to the image bounds and marked
    with 1 in the result.

    Args:
        label: 2-D mask (anything ``np.array`` accepts); values > 0 count as
            foreground.
        flow: array indexed as ``flow[:, row, col]`` that yields the pair
            ``(dx, dy)``: column displacement first, row displacement second.

    Returns:
        Array with the shape and dtype of ``label`` holding 1 at the displaced
        edge points and 0 everywhere else.
    """
    label = np.array(label)
    result = np.zeros_like(label)
    n_rows, n_cols = label.shape[0], label.shape[1]

    edge_points = []
    # First / last foreground pixel of every row ...
    for i in range(n_rows):
        hits = np.flatnonzero(label[i, :] > 0)
        if hits.size:
            edge_points.append((i, hits[0]))
            edge_points.append((i, hits[-1]))
    # ... and of every column.
    for j in range(n_cols):
        hits = np.flatnonzero(label[:, j] > 0)
        if hits.size:
            edge_points.append((hits[0], j))
            edge_points.append((hits[-1], j))

    for row, col in edge_points:
        dx, dy = flow[:, row, col]
        # The np.long alias used previously is not available in every NumPy
        # release (it was removed in 1.24); int() gives the same
        # truncation toward zero without depending on the alias.
        new_row = min(max(int(row + dy), 0), n_rows - 1)
        new_col = min(max(int(col + dx), 0), n_cols - 1)
        result[new_row, new_col] = 1

    return result

In [3]:
def get_next_mask(label,flow):
    """Warp a binary mask to the next frame using a dense optical-flow field.

    Every pixel of ``label`` above 0.5 is shifted by the flow vector stored at
    its own position. The target position is truncated toward zero, clipped to
    the image bounds and set to 1 in the result (forward splatting, so holes
    may appear where no source pixel lands).

    Args:
        label: 2-D mask (anything ``np.array`` accepts); values > 0.5 count as
            foreground.
        flow: array of shape ``(2, H, W)`` where ``flow[0]`` is the column
            displacement and ``flow[1]`` the row displacement. Values are
            expected to be finite.

    Returns:
        Array with the shape and dtype of ``label`` holding 1 at the displaced
        foreground positions and 0 everywhere else.
    """
    label = np.array(label)
    flow = np.asarray(flow)
    result = np.zeros_like(label)

    rows, cols = np.where(label > 0.5)

    # Look up all displacement vectors at once instead of one pixel at a time.
    dcol, drow = flow[:, rows, cols]

    # astype truncates toward zero, matching int() on a scalar.
    new_rows = np.clip((rows + drow).astype(np.int64), 0, label.shape[0] - 1)
    new_cols = np.clip((cols + dcol).astype(np.int64), 0, label.shape[1] - 1)

    result[new_rows, new_cols] = 1

    return result

In [4]:
def _estimate_labels_from_flow(flownet, im, labels_prev, labels_curr, device):
    """Warp the previous-frame labels to the current frame with optical flow.

    Runs ``flownet`` on ``im`` without gradients and applies ``get_next_mask``
    to channel 0 of every sample in ``labels_prev``. Returns a float32 tensor
    on ``device`` with the shape of ``labels_curr``.
    """
    im = im.to(device)
    with torch.no_grad():
        flow = flownet(im)
    flow = np.array(flow.cpu())

    warped = np.zeros_like(labels_curr)
    for n in range(im.shape[0]):
        warped[n, ...] = get_next_mask(labels_prev[n, 0, ...], flow[n, ...])

    return torch.Tensor(warped).to(device, dtype=torch.float32)


def _mean_validation_loss(net, flownet, valid_loader, device):
    """Return the mean DiceLoss of ``net`` (mode=0) over ``valid_loader`` as a float."""
    total = 0.0
    for im, images_curr, labels_prev, labels_curr in valid_loader:
        labels_curr_of = _estimate_labels_from_flow(flownet, im, labels_prev, labels_curr, device)
        images_curr = images_curr.to(device)
        labels_curr = labels_curr.to(device)
        with torch.no_grad():
            pred = net(images_curr, labels_curr_of, mode=0)
            # .item() keeps the running sum a plain float instead of a tensor
            total += DiceLoss(pred, labels_curr).item()
    return total / len(valid_loader)


def train_OF(net,flownet,device,val_per=0.1,epochs=10,batch_size=10,resize_to=None):
    """Train ``net`` to segment the current frame, guided by a flow-warped prior mask.

    For each batch the previous-frame label is pushed through the optical flow
    predicted by ``flownet`` (see ``get_next_mask``) and passed to ``net``
    together with the current image. ``net`` is called with ``mode=1`` or
    ``mode=0`` chosen at random per batch ("+" / "-" is printed accordingly)
    and optimised with Adam on ``DiceLoss``. Every 10 batches the mean
    training loss is printed; every 50 batches the mean validation loss is
    printed and fed to a ``ReduceLROnPlateau`` scheduler.

    The dataset location is read from the notebook-level global ``root_dir``,
    which must be defined before this function is called.

    Args:
        net: segmentation network called as ``net(image, prior_mask, mode=...)``.
        flownet: optical-flow network; only used under ``torch.no_grad()``.
        device: torch device that batches are moved to.
        val_per: fraction of the dataset held out for validation.
        epochs: number of passes over the training split.
        batch_size: batch size of both data loaders.
        resize_to: forwarded unchanged to ``UltraSoundDataSet_OF``.

    Returns:
        None. ``net`` is updated in place.
    """
    transform_image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
    ])

    dataSet = UltraSoundDataSet_OF(root_dir,transform_image,resize_to)
    nTrain = int(len(dataSet)*(1-val_per))
    nValid = int(len(dataSet)-nTrain)

    trainSet,validSet = random_split(dataSet,[nTrain,nValid])

    train_loader = DataLoader(trainSet,batch_size=batch_size,shuffle=True,num_workers=4)
    # Order does not affect the averaged validation loss, so no shuffling here.
    valid_loader = DataLoader(validSet,batch_size=batch_size,shuffle=False,num_workers=4)

    optimizer = torch.optim.Adam(net.parameters())
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',patience=10)

    log_every = 10        # batches between training-loss reports
    validate_every = 50   # batches between validation passes

    running_loss_seg = 0.0
    step = 0
    np.set_printoptions(precision=2)

    for epoch in range(epochs):
        net.train()

        for batch in train_loader:
            im, images_curr, labels_prev, labels_curr = batch

            # Prior for the current frame: previous label moved along the optical flow
            labels_curr_of = _estimate_labels_from_flow(flownet, im, labels_prev, labels_curr, device)
            images_curr = images_curr.to(device)
            labels_curr = labels_curr.to(device)

            # Randomly alternate between the two forward modes of the network
            if np.random.rand(1)>0.5:
                pred = net(images_curr,labels_curr_of, mode=1)
                print("+",end='')
            else:
                pred = net(images_curr,labels_curr_of, mode=0)
                print("-",end='')

            seg_loss = DiceLoss(pred,labels_curr)
            optimizer.zero_grad()
            seg_loss.backward()
            optimizer.step()

            running_loss_seg += seg_loss.item()

            step += 1
            # Report on exact multiples so each average really covers log_every
            # batches (the old `step % 10 == 9` check fired after only 9).
            if step % log_every == 0:
                print()
                print('[%d, %5d] loss: %.3f' %(epoch + 1, step, running_loss_seg / log_every))
                running_loss_seg = 0.0

            # Skip validation when the split is empty (e.g. val_per == 0)
            if step % validate_every == 0 and len(valid_loader) > 0:
                net.eval()
                val_loss = _mean_validation_loss(net, flownet, valid_loader, device)
                print('[%d, %5d] validation loss: %.3f' %(epoch + 1, step, val_loss))
                scheduler.step(val_loss)
                net.train()

In [5]:
# Pick the GPU when one is available; both networks below are placed on this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#print("use ",device)
unet = UNet_OF2(init_features=64).to(device)

class FlowNet2Args():
    """Minimal settings object handed to the FlowNet2 constructor."""
    fp16 = False
    rgb_max = 255

args = FlowNet2Args()
# Follow the selected device instead of a hard-coded .cuda(), and remap the
# checkpoint tensors onto it while loading.
# NOTE(review): the flownet2 package may still require CUDA for its custom layers - confirm.
flownet = models.FlowNet2(args).to(device)
flownet.load_state_dict(
    torch.load(os.path.expanduser("~/Downloads/FlowNet2_checkpoint.pth.tar"), map_location=device)["state_dict"]
)
# The flow network is only used for inference, so keep it in eval mode.
flownet = flownet.eval()

#IMG_SIZE = [256,256]

# Dataset root read by train_OF. Alternative (real linear-probe data):
# root_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_dataset")
root_dir = os.path.expanduser("~/workspace/us_robot/DataSet/phantomDataset/vessel_dataset")

In [6]:
# Start training with the notebook-level networks; a manual interrupt
# (Ctrl-C / "Interrupt kernel") ends the run through sys.exit().
try:
    train_OF(
        net=unet,
        flownet=flownet,
        device=device,
        val_per=0.1,
        epochs=10,
        batch_size=10,
        resize_to=None,
    )
except KeyboardInterrupt:
    sys.exit()

  config = yaml.load(open(root_dir+'/config.yaml'))
  "See the documentation of nn.Upsample for details.".format(mode))


---+++-++
[1,    10] loss: 0.745
-+++++--++
[1,    20] loss: 0.593
+++-+++-+-
[1,    30] loss: 0.238
---+++----
[1,    40] loss: 0.275
-++++++++-
[1,    50] loss: 0.099
[1,    50] validation loss: 0.255
+++-++--++
[1,    60] loss: 0.132
+--+----++
[1,    70] loss: 0.149
+--++---+-
[1,    80] loss: 0.165
++-+++--++
[1,    90] loss: 0.076
--+++++-+-
[1,   100] loss: 0.145
[1,   100] validation loss: 0.479
-++-+-++-+
[2,   110] loss: 0.211
----++--+-
[2,   120] loss: 0.185
--+++++---
[2,   130] loss: 0.141
++----+-++
[2,   140] loss: 0.089
-+--++----
[2,   150] loss: 0.112
[2,   150] validation loss: 0.104
--+-+--+-+
[2,   160] loss: 0.082
--+--+-++-
[2,   170] loss: 0.061
-+---+-+++
[2,   180] loss: 0.053
+---+-+++-
[2,   190] loss: 0.069
------+---
[2,   200] loss: 0.077
[2,   200] validation loss: 0.057
-++--+++--
[3,   210] loss: 0.039
+--+-+-+++
[3,   220] loss: 0.057
+------+++
[3,   230] loss: 0.062
+-------+-
[3,   240] loss: 0.053
----++++-+
[3,   250] loss: 0.042
[3,   250] vali

In [7]:
# Save the trained weights as a state_dict so they can be reloaded from Python.
torch.save(unet.state_dict(), './unet_of_usseg_phantom.pth')
# Disabled: TorchScript export for use from C++ (tracing needs an example input `img`).
# traced_script_module = torch.jit.trace(unet, img)
# traced_script_module.save("./unet_of_usseg_traced.pt")

## Inference

In [2]:
# Pick the GPU when one is available; both networks below are placed on this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): the training section builds UNet_OF2 and saves
# './unet_of_usseg_phantom.pth', while this cell loads UNet_OF weights from a
# different file - confirm that this is the intended model/checkpoint pair.
PATH = os.path.expanduser("~/workspace/us_robot/network/unet_of_usseg.pth")
unet = UNet_OF(init_features=64).to(device)
# map_location lets a checkpoint saved on a GPU load on whatever device was selected.
unet.load_state_dict(torch.load(PATH, map_location=device))
unet = unet.eval()

class FlowNet2Args():
    """Minimal settings object handed to the FlowNet2 constructor."""
    fp16 = False
    rgb_max = 255

args = FlowNet2Args()
# Follow the selected device instead of a hard-coded .cuda().
# NOTE(review): the flownet2 package may still require CUDA for its custom layers - confirm.
flownet = models.FlowNet2(args).to(device)
# Resolve the checkpoint relative to the current user's home directory (as the
# training section does) rather than one specific user's absolute path.
flownet.load_state_dict(
    torch.load(os.path.expanduser("~/Downloads/FlowNet2_checkpoint.pth.tar"), map_location=device)["state_dict"]
)
flownet = flownet.eval()

In [6]:
# Make sure both networks are in inference mode before any prediction is run.
unet, flownet = unet.eval(), flownet.eval()

In [7]:
# Segment every PNG in test_dir independently (no optical-flow prior) and
# write each prediction to pred_dir under the same file name.
test_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_test2/img")
pred_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_pred")
# Create the output folder up front so saving cannot fail on a missing directory.
os.makedirs(pred_dir, exist_ok=True)

testset_list = sorted(name for name in os.listdir(test_dir) if name.endswith('png'))
resize_to=[256,256]
transform_image = transforms.Compose([
    transforms.Resize(resize_to),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5) #Division by 255 is done, when the transformation assumes an image.
    ])

# Converts the network output tensor back into a PIL image for saving.
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize([512,512])
    ])

for sample in testset_list:
    image_path = os.path.join(test_dir,sample)

    # Close the file handle as soon as the grayscale copy has been made.
    with Image.open(image_path) as raw_image:
        img = raw_image.convert("L")
    #img = ImageEnhance.Contrast(img).enhance(1.5)
    img = transform_image(img).unsqueeze(0).to(device)

    # No previous frame exists for a single image, so the prior mask is all zeros.
    labels_curr_of = torch.zeros_like(img)

    with torch.no_grad():
        pred = unet(img,labels_curr_of, mode=0)

    pred = invtransform_label(pred.cpu().squeeze(0))
    sav_path = os.path.join(pred_dir,sample)
    pred.save(sav_path)

In [8]:
# Sequential evaluation: the first frame uses the dataset's previous label as
# prior; every later frame uses the network's own previous prediction, warped
# to the current frame with optical flow.
test_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_test2")
pred_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_pred2")
# Create the output folder up front so saving cannot fail on a missing directory.
os.makedirs(pred_dir, exist_ok=True)

transform_image = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
    ])
# Converts the network output tensor back into a PIL image for saving.
invtransform_label = transforms.Compose([
    transforms.ToPILImage(),
    #transforms.Resize([512,512])
    ])

#eval_dataSet = UltraSoundDataSet_OF(test_dir,transform_image,resize_to=[256,256])
eval_dataSet = UltraSoundDataSet_OF(test_dir,transform_image)
# batch_size=1 and shuffle=False: frames must be visited one by one, in order.
eval_loader = DataLoader(eval_dataSet,batch_size=1,shuffle=False,num_workers=4)

eval_loss = 0.0
for idx,batch in enumerate(eval_loader):
    im, images_curr, labels_prev_init, labels_curr = batch

    # Only the very first frame is seeded with the label from the dataset.
    if idx == 0:
        labels_prev = labels_prev_init

    #calculate optical flow
    im = im.to(device)
    with torch.no_grad():
        flow = flownet(im)

    flow = np.array(flow.cpu())

    #estimate vessel position by using the flow
    labels_curr_of = np.zeros_like(labels_curr)
    labels_curr_of[0,...] = get_next_mask(labels_prev[0,0,...],flow[0,...])

    labels_curr_of = torch.Tensor(labels_curr_of).to(device,dtype=torch.float32)
    labels_curr = labels_curr.type(torch.float32).to(device)
    images_curr = images_curr.to(device)

    with torch.no_grad():
        # mode 0 for the first frame, mode 1 afterwards
        pred = unet(images_curr,labels_curr_of,mode=int(idx>0))
        #pred = unet(images_curr,labels_curr_of,mode=0)
        # .item() keeps the running sum a plain float instead of a tensor
        eval_loss += DiceLoss(pred,labels_curr).item()

    # The prediction becomes the prior of the next frame.
    labels_prev = pred.cpu()

    pred = invtransform_label(pred.cpu().squeeze(0))
    fname = "pred%03d.png" %idx
    sav_path = os.path.join(pred_dir,fname)
    pred.save(sav_path)

print('evaluation loss: %.3f' %(eval_loss / len(eval_loader)))

  "See the documentation of nn.Upsample for details.".format(mode))


evaluation loss: 0.145


In [12]:
# Load the dataset under test_dir and pull out its first sample for inspection.
test_dir = os.path.expanduser("~/workspace/us_robot/DataSet/realDataSet/linear/vessel_dataset")
transform_image = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(0.5,0.5)]
)

eval_dataSet = UltraSoundDataSet_OF(test_dir, transform_image)
eval_loader = DataLoader(
    eval_dataSet,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

# Keep the raw batch around and also unpack it into its four tensors.
batch = next(iter(eval_loader))
(im, images_curr, labels_prev_init, labels_curr) = batch