In [1]:
import numpy as np
import nibabel as nib
from glob import glob
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets
from models.detr import DETR, LossSegmentation
from models.segmentation import DETRsegm, PostProcessPanoptic
from models.matcher import HungarianMatcher
from hubconf import detr_resnet101_panoptic, detr_resnet3d_panoptic
from torchvision.transforms import Resize
from util.misc import get_world_size
import open3d as o3d
import torch.optim as optim
import pkbar

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Level 51 is one above the stdlib `logging.CRITICAL` (50); assuming
# `nib.imageglobals.logger` is a standard `logging.Logger`, this silences
# every nibabel log record.
NIBABEL_LOG_LEVEL = 51
nib.imageglobals.logger.level = NIBABEL_LOG_LEVEL

In [3]:
def sort_func(path):
    """Sort key: the integer ID embedded in a dataset file name.

    The file name (text after the last '/') is split on '_' and its second
    field is parsed as an int, e.g. '.../x_12_image.nii.gz' -> 12.
    """
    file_name = path.rsplit('/', 1)[-1]
    id_field = file_name.split('_')[1]
    return int(id_field)

In [4]:
def show_image_and_label(image, target):
    """Draw `image` and `target` side by side, titled 'Image' and 'Target', with no ticks."""
    fig, axs = plt.subplots(nrows=1, ncols=2, squeeze=False, figsize=(12, 12))
    panels = (('Image', image), ('Target', target))
    for column, (title, data) in enumerate(panels):
        ax = axs[0, column]
        ax.imshow(data)
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], title=title)

In [5]:
def uniform_downsample_image(image, down_scale = 8):
    """Keep every `down_scale`-th element along each of the first three axes.

    Parameters
    ----------
    image : torch.Tensor
        Tensor with at least three dimensions; only dims 0-2 are subsampled.
    down_scale : int
        Stride applied to each of the first three axes (index 0 is always kept).

    Returns
    -------
    torch.Tensor
        A new tensor that does not share storage with `image`, of shape
        ``(ceil(D0/s), ceil(D1/s), ceil(D2/s), ...)``.
    """
    stride = down_scale
    # Strided slicing yields a view; clone() detaches it from the input's
    # storage so callers can modify the result without touching `image`.
    return image[::stride, ::stride, ::stride].clone()

In [6]:
def visualize_segmented_image(segmented_image):
    """Show each non-background label of a label volume as a coloured Open3D point cloud.

    Every voxel equal to a label becomes a point at its index coordinates.
    Label 0 is background and is not drawn. Label ``k`` is painted with
    ``colors[(k - 1) % len(colors)]``, so labels 1-7 keep their fixed colours
    and larger labels wrap around instead of raising ``IndexError``.

    Parameters
    ----------
    segmented_image : torch.Tensor
        Label volume on any device; it is moved to the CPU first (a no-op if
        it is already there).
    """
    segmented_image = segmented_image.cpu()
    colors = [[128,128,128],[255,0,0],[255,255,0],[0,255,0],[0,255,255],[0,0,255],[255,0,255]]
    classes = segmented_image.unique()
    # Drop the background by value: `classes[1:]` would also discard the
    # smallest real label whenever 0 does not occur in the volume.
    foreground = classes[classes != 0]
    vis = o3d.visualization.Visualizer()
    vis.create_window()
    try:
        for _class in foreground:
            # nonzero() yields int64 indices; hand Open3D explicit float64 points.
            points_numpy = (segmented_image == _class).nonzero().numpy().astype(np.float64)
            o3d_point_cloud = o3d.geometry.PointCloud()
            o3d_point_cloud.points = o3d.utility.Vector3dVector(points_numpy)
            o3d_point_cloud.estimate_normals()
            color = colors[(int(_class) - 1) % len(colors)]
            o3d_point_cloud.paint_uniform_color(np.array(color) / 255)
            vis.add_geometry(o3d_point_cloud)
        vis.run()
    finally:
        # Close the window even if building a point cloud fails.
        vis.destroy_window()

In [7]:
def get_bounding_boxes(segmented_image):
    """Axis-aligned box of every label in a label volume, normalised by the volume size.

    Labels are processed in ``segmented_image.unique()`` order (ascending),
    the same order used by ``get_labels`` and ``get_masks``. For each label,
    the voxels carrying that value are collected. The box is then described
    as its centre followed by its extent (``max - min`` index) along every
    axis. For a 3-D volume this is ``(cx, cy, cz, w, h, d)``. Each centre and
    extent is divided by the size of its own axis.

    Parameters
    ----------
    segmented_image : torch.Tensor
        Label volume (any device; processed on the CPU).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape ``(num_labels, 2 * ndim)``. It has zero rows
        when the volume is empty.
    """
    segmented_image = segmented_image.cpu()
    # (D0, D1, D2, D0, D1, D2): one division normalises centres and extents
    # by the size of their own axis.
    scale = torch.tensor(tuple(segmented_image.shape) * 2, dtype=torch.float32)
    boxes = []
    for class_ in segmented_image.unique():
        points = (segmented_image == class_).nonzero().float()
        lower = points.min(dim=0).values
        upper = points.max(dim=0).values
        boxes.append(torch.cat([(lower + upper) / 2, upper - lower]) / scale)
    if not boxes:
        return torch.zeros((0, 2 * segmented_image.dim()), dtype=torch.float32)
    return torch.stack(boxes)

In [8]:
def get_masks(segmented_image):
    """One binary mask per distinct label value.

    Returns an int16 tensor of shape ``(num_labels, *segmented_image.shape)``.
    Its i-th slice is 1 where the volume equals the i-th value of
    ``segmented_image.unique()`` (ascending) and 0 elsewhere. A CUDA input is
    moved to the CPU first.
    """
    if segmented_image.get_device() >= 0:
        segmented_image = segmented_image.cpu()
    per_label = [(segmented_image == value).short() for value in segmented_image.unique()]
    return torch.stack(per_label)

In [9]:
def get_labels(segmented_image):
    """Sorted distinct values of the label volume, as an int64 tensor."""
    distinct_values = torch.unique(segmented_image)
    return distinct_values.to(torch.long)

In [10]:
def create_target_dict(segmented_image):
    """Bundle the training target for one label volume.

    Keys:
    - 'labels': sorted distinct label values (get_labels).
    - 'boxes': one box per label (get_bounding_boxes).
    - 'masks': one binary mask per label (get_masks).
    - 'seg_im': the input volume itself.
    """
    return {
        'labels': get_labels(segmented_image),
        'boxes': get_bounding_boxes(segmented_image),
        'masks': get_masks(segmented_image),
        'seg_im': segmented_image,
    }

In [11]:
def compute_panoptic_quality(pred_seg, target, iou_threshold=0.5):
    """Panoptic quality of a predicted label volume against a target dict.

    Each non-zero label ``l`` in ``target['labels']`` is compared with the
    predicted region ``pred_seg == l``. The comparison uses the mask stored at
    the same *position* in ``target['masks']``. A target segment counts as a
    true positive when its IoU is strictly greater than ``iou_threshold``.
    Label 0 is background and is ignored on both sides.

    Parameters
    ----------
    pred_seg : torch.Tensor
        Predicted label volume, broadcast-compatible with each target mask.
    target : dict
        Must contain ``'labels'`` (1-D tensor of label values) and
        ``'masks'`` (one 0/1 mask per label, in the same order), as built by
        ``create_target_dict``.
    iou_threshold : float
        Strict lower bound on IoU for a match (default 0.5).

    Returns
    -------
    tuple of float
        ``(pq, sq, rq)``:
        - ``sq`` is the mean IoU of the matched segments.
        - ``rq = tp / (tp + fp/2 + fn/2)``.
        - ``pq = sq * rq``.
        All three are 0.0 when nothing matches.
    """
    matched_ious = []
    n_target = 0
    # Masks are stacked in label order, so the mask for a label sits at its
    # position in `labels`, not at index == label value.
    for position, label in enumerate(target['labels']):
        if label == 0:
            continue
        n_target += 1
        target_mask = target['masks'][position].bool()
        pred_mask = pred_seg == label
        intersection = (pred_mask & target_mask).sum().item()
        union = (pred_mask | target_mask).sum().item()
        iou = intersection / union if union else 0.0
        if iou > iou_threshold:
            matched_ious.append(iou)
    # Count predicted non-background segments without assuming 0 is present.
    n_pred = int((pred_seg.unique() != 0).sum())
    tp = len(matched_ious)
    fn = n_target - tp
    fp = max(n_pred - tp, 0)
    if tp == 0:
        return 0.0, 0.0, 0.0
    sq = sum(matched_ious) / tp
    rq = tp / (tp + fp / 2 + fn / 2)
    return sq * rq, sq, rq

In [12]:
class DatasetForSegmentation(Dataset):
    """Paired NIfTI image / label volumes, uniformly downsampled on load.

    ``image_paths[i]`` and ``target_paths[i]`` must describe the same case.
    Each item is ``(image, target_dict)``:
    - ``image`` is the float32 image volume, subsampled with
      ``uniform_downsample_image``.
    - ``target_dict`` is built by ``create_target_dict`` from the label
      volume, subsampled with the same stride. Labels above ``max_label``
      are reset to 0.

    Parameters
    ----------
    image_paths, target_paths : sequence of str
        Parallel lists of files readable by ``nib.load``.
    down_scale : int
        Stride passed to ``uniform_downsample_image`` (default 8).
    max_label : int
        Largest label value kept; higher values become background (0).
        The default of 7 matches the previous hard-coded behaviour.

    Raises
    ------
    ValueError
        If the two path lists have different lengths.
    """

    def __init__(self, image_paths, target_paths, down_scale=8, max_label=7):
        if len(image_paths) != len(target_paths):
            raise ValueError(
                f'got {len(image_paths)} image paths but {len(target_paths)} target paths')
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.down_scale = down_scale
        self.max_label = max_label

    def __len__(self):
        return len(self.image_paths)

    def _load(self, path):
        """Read one NIfTI file into a torch tensor (dtype as returned by get_fdata)."""
        return torch.from_numpy(nib.load(path).get_fdata())

    def __getitem__(self, i):
        image = self._load(self.image_paths[i]).float()
        image = uniform_downsample_image(image, down_scale=self.down_scale)
        # Subsampling and thresholding commute; subsample first so the
        # threshold touches far fewer voxels.
        target = uniform_downsample_image(self._load(self.target_paths[i]), down_scale=self.down_scale)
        target[target > self.max_label] = 0
        return image, create_target_dict(target)

In [13]:
# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [14]:
import os

# Root folder of the ImageCHD volumes. Set the IMAGECHD_DATASET environment
# variable to point elsewhere instead of editing this machine-specific default.
path = os.environ.get('IMAGECHD_DATASET', '/home/francisco/workspace/ImageCHD_dataset')

In [15]:
# Images and label maps sit side by side in the dataset root and are told
# apart by their file-name suffix.
image_pattern = f'{path}/*image.nii.gz'
target_pattern = f'{path}/*label.nii.gz'
image_paths = glob(image_pattern, recursive=True)
target_paths = glob(target_pattern, recursive=True)

In [16]:
# Order both lists by the numeric case ID in the file name so that
# image_paths[i] and target_paths[i] describe the same case.
image_paths.sort(key=sort_func)
target_paths.sort(key=sort_func)
# A missing or extra file would silently shift every later pair; fail loudly instead.
assert [sort_func(p) for p in image_paths] == [sort_func(p) for p in target_paths], \
    'image and label files do not pair up by case ID'

In [17]:
# Build the 3-D panoptic DETR variant from hubconf and place it on `device`.
# Assuming it returns an nn.Module, `.to` moves it in place and returns the
# same module, so the binding is unchanged.
detr_seg = detr_resnet3d_panoptic().to(device)

In [18]:
# outputs = detr_seg(image_reshaped)

In [19]:
# Matcher handed to LossSegmentation. It presumably pairs predicted segments
# with ground-truth ones before the mask loss is computed (see models.matcher).
matcher = HungarianMatcher()
loss_segmentation = LossSegmentation(matcher)

In [20]:
# Label '0' is background ("stuff"); labels '1'-'7' are "things".
is_thing_map = {str(label): label != 0 for label in range(8)}
post_process_panoptic = PostProcessPanoptic(is_thing_map)

In [21]:
# Optimisation hyper-parameters.
lr = 1e-4
n_epoch = 1
# NOTE(review): equal to the literal 10e-4 (== 1e-3); confirm 1e-4 was not intended.
weight_decay = 1e-3
ACC_Threshold = 0.02  # not referenced by the visible cells

In [22]:
# AdamW over every parameter of the model, using the hyper-parameters defined earlier.
optimizer = optim.AdamW(detr_seg.parameters(),
                        lr=lr,
                        weight_decay=weight_decay)

In [None]:
# Anomaly detection slows down every backward pass; enable it only while
# debugging NaN/inf gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)
train_dset_size = 87  # the first `train_dset_size` sorted cases train, the rest validate
down_scale = 10


def _target_to_device(target, device):
    """Return a copy of a target dict with every tensor moved to `device`."""
    return {key: value.to(device) for key, value in target.items()}


valid_dset = DatasetForSegmentation(image_paths[train_dset_size:], target_paths[train_dset_size:], down_scale=down_scale)
for epoch in range(n_epoch):
    # Reshuffle the training cases every epoch.
    rdm = torch.randperm(train_dset_size).tolist()
    train_ds = DatasetForSegmentation([image_paths[i] for i in rdm], [target_paths[i] for i in rdm], down_scale=down_scale)
    # One progress step per training volume plus one per validation volume.
    kbar = pkbar.Kbar(target=len(train_ds) + len(valid_dset), epoch=epoch, num_epochs=n_epoch, width=16)
    running_loss_t = 0.0
    running_loss_v = 0.0

    # Validation below switches to eval mode, so restore training mode each epoch.
    detr_seg.train()
    for idx, (volume, target) in enumerate(train_ds):
        image = volume.unsqueeze(0).unsqueeze(0).to(device)  # add batch and channel dims
        target = _target_to_device(target, device)
        optimizer.zero_grad()
        outputs = detr_seg(image)
        loss = loss_segmentation.loss_masks(outputs, [target])
        loss['loss_mask'].backward()
        optimizer.step()
        running_loss_t = loss['loss_mask'].item()
        kbar.update(idx + 1, values=[("Train Loss", running_loss_t)])

    # Disable training-only behaviour (dropout / batch-norm updates, if the
    # model has any) while measuring validation metrics.
    detr_seg.eval()
    with torch.no_grad():
        for ids, (volume, target) in enumerate(valid_dset):
            image = volume.unsqueeze(0).unsqueeze(0).to(device)
            target = _target_to_device(target, device)
            outputs = detr_seg(image)
            loss = loss_segmentation.loss_masks(outputs, [target])
            running_loss_v = loss['loss_mask'].item()
            pred_seg = post_process_panoptic(outputs, [tuple(volume.shape)]).to(device)
            pq, sq, rq = compute_panoptic_quality(pred_seg, target)
            kbar.update(len(train_ds) + ids + 1, values=[("Validation Loss", running_loss_v), ("PQ", 100*pq), ("SQ", 100*sq), ("RQ", 100*rq)])
print('Finished Training')

Epoch: 1/1


  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)


 40/108 [====>...........] - ETA: 1:39 - Train Loss: 0.0536

***

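# Visualization

Post-process the most recent model `outputs` into a label volume and render each predicted class as a coloured Open3D point cloud.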

In [33]:
# NOTE(review): same map as the one built before training; rebuilt here so
# this section can be re-run on its own. '0' is background, '1'-'7' are things.
is_thing_map = {str(label): label != 0 for label in range(8)}

In [34]:
# NOTE(review): duplicates the post-processor built before training; rebuilt
# from `is_thing_map` so this section can run on its own.
post_process_panoptic = PostProcessPanoptic(is_thing_map)

In [36]:
# Spatial size of the last input volume, without the batch/channel dims
# that were added before the forward pass.
volume_shape = tuple(image.squeeze().shape)
result = post_process_panoptic(outputs, [volume_shape])

In [37]:
# Open an interactive Open3D window with one coloured cloud per predicted label.
visualize_segmented_image(segmented_image=result)