In [1]:
import numpy as np
import nibabel as nib
from glob import glob
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets
from models.detr import DETR
from models.segmentation import DETRsegm
from models.matcher import HungarianMatcher
from hubconf import detr_resnet101_panoptic, detr_resnet3d_panoptic
from torchvision.transforms import Resize
import open3d as o3d

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def sort_func(path):
    path_id = int(path.split('/')[-1].split('_')[1])
    return path_id

In [3]:
def show_image_and_label(image, label):
    fig, axs = plt.subplots(nrows=1,ncols=2, squeeze=False,figsize=(12, 12))
    axs[0, 0].imshow(image)
    axs[0, 0].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], title='Image')
    axs[0, 1].imshow(label)
    axs[0, 1].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], title='Label')

In [4]:
def uniform_downsample_image(image, down_scale = 8):
    image_shape = image.shape
    dim_0_indexes = torch.arange(0, image_shape[0], down_scale)
    dim_1_indexes = torch.arange(0, image_shape[1], down_scale)
    dim_2_indexes = torch.arange(0, image_shape[2], down_scale)
    downsampled_image = image[dim_0_indexes,:,:]
    downsampled_image = downsampled_image[:,dim_1_indexes,:]
    downsampled_image = downsampled_image[:,:,dim_2_indexes]
    return downsampled_image

In [5]:
class DatasetForSegmentation(Dataset):
    
    def __init__(self, image_paths, label_paths):
        self.image_paths = image_paths
        self.label_paths = label_paths
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, i):
        image_np_array = nib.load(image_paths[i]).get_fdata()
        image_torch_tensor = torch.from_numpy(image_np_array)
        label_np_array = nib.load(label_paths[i]).get_fdata()
        label_torch_tensor = torch.from_numpy(label_np_array)
        return image_torch_tensor, label_torch_tensor

In [6]:
def visualize_segmented_image(segmented_image):
#     segmented_image.squeeze()
    if segmented_image.get_device() > -1: segmented_image = segmented_image.to('cpu')
    classes = segmented_image.unique()
    colors = [[128,128,128],[255,0,0],[255,255,0],[0,255,0],[0,255,255],[0,0,255],[255,0,255]]
    vis = o3d.visualization.Visualizer()
    vis.create_window()
    for _class in classes[1:]:
        points_numpy = (segmented_image == _class).nonzero().numpy()
        o3d_point_cloud = o3d.geometry.PointCloud()
        o3d_point_cloud.points = o3d.utility.Vector3dVector(points_numpy)
        o3d_point_cloud.estimate_normals()
        o3d_point_cloud.paint_uniform_color(np.array(colors[int(_class)-1])/255)
        vis.add_geometry(o3d_point_cloud)
    vis.run()
    vis.destroy_window()

In [7]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [8]:
path = '/home/francisco/workspace/ImageCHD_dataset'

In [9]:
image_paths = glob(f'{path}/*image.nii.gz',recursive=True)
label_paths = glob(f'{path}/*label.nii.gz',recursive=True)

In [10]:
image_paths.sort(key=sort_func)
label_paths.sort(key=sort_func)

In [11]:
dset = DatasetForSegmentation(image_paths,label_paths)

In [12]:
image, label = dset[2]
image = image.to(device)
label = label.to(device)

pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
INFO - 2022-10-17 18:35:54,597 - batteryrunners - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
INFO - 2022-10-17 18:35:55,177 - batteryrunners - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


In [29]:
downsampled_image = uniform_downsample_image(image)
downsampled_label = uniform_downsample_image(label)
downsampled_image_reshaped = downsampled_image.unsqueeze(0).unsqueeze(0).float()
downsampled_label_reshaped = downsampled_label.unsqueeze(0).unsqueeze(0)
# im_right_shape = im.transpose(2,0).unsqueeze(0).unsqueeze(0).float()
# label_right_shape = label.transpose(2,0).unsqueeze(0).unsqueeze(0).float()
# new_size = [int(im_right_shape.shape[2]/8),int(im_right_shape.shape[3]/8),int(im_right_shape.shape[4]/8)]
# im_resized = im_right_shape.resize_((1,1,new_size[0],new_size[1],new_size[2]))
# label_resized = label_right_shape.resize_((1,1,new_size[0],new_size[1],new_size[2]))

In [31]:
segmented_image = downsampled_label_reshaped
points = (segmented_image.squeeze() == 1).nonzero()
x_min, x_max = points[:,0].min(), points[:,0].max()
y_min, y_max = points[:,1].min(), points[:,1].max()
z_min, z_max = points[:,2].min(), points[:,2].max()

In [32]:
x_min, x_max, y_min, y_max, z_min, z_max

(tensor(35, device='cuda:0'),
 tensor(46, device='cuda:0'),
 tensor(23, device='cuda:0'),
 tensor(36, device='cuda:0'),
 tensor(5, device='cuda:0'),
 tensor(13, device='cuda:0'))

In [44]:
bb_center_point = torch.tensor([x_max-x_min, y_max-y_min, z_max-z_min])/2

In [45]:
bb_center_point

tensor([5.5000, 6.5000, 4.0000])

In [46]:
visualize_segmented_image(downsampled_label_reshaped)

RuntimeError: Unable to cast Python instance to C++ type (compile in debug mode for details)

In [18]:
detr_seg = detr_resnet3d_panoptic()
# detr_seg = detr_resnet11_panoptic()
detr_seg.eval();
detr_seg.to(device);

In [21]:
# im = torch.ones((1, 1, 32, 64, 64),device=device)
# im = torch.ones((1, 3, 128, 128),device=device)

In [19]:
outputs = detr_seg(downsampled_image_reshaped)

  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)


In [22]:
outputs.items()

dict_items([('pred_logits', tensor([[[-0.3566,  0.6328, -0.3596,  1.2090,  0.2509, -0.5311, -0.3903,
          -0.3225],
         [-0.3585,  0.6374, -0.3610,  1.2049,  0.2505, -0.5342, -0.3876,
          -0.3288],
         [-0.3609,  0.6311, -0.3608,  1.2040,  0.2550, -0.5272, -0.3937,
          -0.3231],
         [-0.3507,  0.6286, -0.3637,  1.2081,  0.2587, -0.5313, -0.3967,
          -0.3178],
         [-0.3582,  0.6307, -0.3656,  1.2052,  0.2529, -0.5335, -0.3885,
          -0.3222],
         [-0.3607,  0.6345, -0.3598,  1.2026,  0.2552, -0.5259, -0.3894,
          -0.3241],
         [-0.3558,  0.6281, -0.3557,  1.2108,  0.2522, -0.5297, -0.3985,
          -0.3267],
         [-0.3517,  0.6349, -0.3661,  1.2063,  0.2616, -0.5324, -0.3886,
          -0.3191],
         [-0.3597,  0.6361, -0.3631,  1.2068,  0.2585, -0.5327, -0.3899,
          -0.3274],
         [-0.3545,  0.6325, -0.3581,  1.2059,  0.2526, -0.5274, -0.3874,
          -0.3265],
         [-0.3538,  0.6425, -0.3654,  1.20

In [21]:
outputs['pred_masks'].shape

torch.Size([1, 100, 7, 16, 16])

3D: torch.Size([1, 100, 8, 16, 16])

2D: torch.Size([1, 100, 32, 32])

***

***

In [17]:
# label_cpu = label.to('cpu')

In [20]:
segmented_image = downsampled_label

In [21]:
segmented_image = segmented_image.squeeze()
if segmented_image.get_device() > -1: segmented_image = segmented_image.to('cpu')
classes = segmented_image.unique()
colors = [[128,128,128],[255,0,0],[255,255,0],[0,255,0],[0,255,255],[0,0,255],[255,0,255]]
vis = o3d.visualization.Visualizer()
vis.create_window()
for _class in classes[1:]:
    points_numpy = (segmented_image == _class).nonzero().numpy()
    o3d_point_cloud = o3d.geometry.PointCloud()
    o3d_point_cloud.points = o3d.utility.Vector3dVector(points_numpy)
    o3d_point_cloud.estimate_normals()
    o3d_point_cloud.paint_uniform_color(np.array(colors[int(_class)-1])/255)
    vis.add_geometry(o3d_point_cloud)
vis.run()
vis.destroy_window()

In [28]:
points_numpy

array([[ 8, 43, 19],
       [ 8, 47, 22],
       [ 9, 32, 20],
       ...,
       [19, 38, 24],
       [19, 39, 42],
       [19, 40, 25]])

In [60]:
segmented_image.unique()

tensor([0.])

In [None]:
""" The forward expects a NestedTensor, which consists of:
       - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
       - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

torch.Size([1, 1, 32, 128, 128]) torch.Size([1, 32, 128, 128])

torch.Size([1, 2048, 8, 32, 32]) torch.Size([1, 32, 128, 128])

torch.Size([1, 2048, 4, 4]) torch.Size([1, 4, 4])


torch.Size([1, 256, 4, 4])



torch.Size([1, 2048, 8, 32, 32]) torch.Size([1, 32, 128, 128])


torch.Size([1, 384, 32, 128, 128])

torch.Size([1, 256, 8, 32, 32]) torch.Size([1, 32, 128, 128]) torch.Size([100, 256]) torch.Size([1, 256, 32, 128, 128])

2D:

torch.Size([1, 256, 4, 4]) torch.Size([1, 4, 4]) torch.Size([100, 256]) torch.Size([1, 256, 4, 4])

In [18]:
def loss_masks(outputs, targets, indices, num_boxes):
    """Compute the losses related to the masks: the focal loss and the dice loss.
       targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
    """
    assert "pred_masks" in outputs

    src_idx = self._get_src_permutation_idx(indices)
    tgt_idx = self._get_tgt_permutation_idx(indices)
    src_masks = outputs["pred_masks"]
    src_masks = src_masks[src_idx]
    masks = [t["masks"] for t in targets]
    # TODO use valid to mask invalid areas due to padding in loss
    target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
    target_masks = target_masks.to(src_masks)
    target_masks = target_masks[tgt_idx]

    # upsample predictions to the target size
    src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-3:],
                            mode="bilinear", align_corners=False)
    src_masks = src_masks[:, 0].flatten(1)

    target_masks = target_masks.flatten(1)
    target_masks = target_masks.view(src_masks.shape)
    losses = {
        "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
        "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
    }
    return losses

In [19]:
matcher = HungarianMatcher()

In [22]:
outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

# Retrieve the matching between the outputs of the last layer and the targets
indices = matcher(outputs_without_aux, label_resized)

IndexError: too many indices for tensor of dimension 4

In [None]:
num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
if is_dist_avail_and_initialized():
    torch.distributed.all_reduce(num_boxes)
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

In [None]:
loss_masks()