In [1]:
import numpy as np
import nibabel as nib
from glob import glob
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets
from models.detr import DETR
from models.segmentation import DETRsegm
from models.matcher import HungarianMatcher
from hubconf import detr_resnet101_panoptic, detr_resnet3d_panoptic
from torchvision.transforms import Resize
import open3d as o3d

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def sort_func(path):
    path_id = int(path.split('/')[-1].split('_')[1])
    return path_id

In [3]:
def show_image_and_label(image, target):
    fig, axs = plt.subplots(nrows=1,ncols=2, squeeze=False,figsize=(12, 12))
    axs[0, 0].imshow(image)
    axs[0, 0].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], title='Image')
    axs[0, 1].imshow(target)
    axs[0, 1].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], title='Target')

In [4]:
def uniform_downsample_image(image, down_scale = 8):
    image_shape = image.shape
    dim_0_indexes = torch.arange(0, image_shape[0], down_scale)
    dim_1_indexes = torch.arange(0, image_shape[1], down_scale)
    dim_2_indexes = torch.arange(0, image_shape[2], down_scale)
    downsampled_image = image[dim_0_indexes,:,:]
    downsampled_image = downsampled_image[:,dim_1_indexes,:]
    downsampled_image = downsampled_image[:,:,dim_2_indexes]
    return downsampled_image

In [5]:
def visualize_segmented_image(segmented_image):
#     segmented_image.squeeze()
    if segmented_image.get_device() > -1: segmented_image = segmented_image.to('cpu')
    classes = segmented_image.unique()
    colors = [[128,128,128],[255,0,0],[255,255,0],[0,255,0],[0,255,255],[0,0,255],[255,0,255]]
    vis = o3d.visualization.Visualizer()
    vis.create_window()
    for _class in classes[1:]:
        points_numpy = (segmented_image == _class).nonzero().numpy()
        o3d_point_cloud = o3d.geometry.PointCloud()
        o3d_point_cloud.points = o3d.utility.Vector3dVector(points_numpy)
        o3d_point_cloud.estimate_normals()
        o3d_point_cloud.paint_uniform_color(np.array(colors[int(_class)-1])/255)
        vis.add_geometry(o3d_point_cloud)
    vis.run()
    vis.destroy_window()

In [6]:
def get_bounding_boxes(segmented_image):
    if segmented_image.get_device() > -1: segmented_image = segmented_image.to('cpu')
    classes = segmented_image.unique()
    bb_list = []
    for class_ in classes: 
        points = (segmented_image == 1).nonzero()
        x_min, x_max = points[:,0].min(), points[:,0].max()
        y_min, y_max = points[:,1].min(), points[:,1].max()
        z_min, z_max = points[:,2].min(), points[:,2].max()
        bb = torch.tensor([(x_max-x_min)/2, (y_max-y_min)/2, (z_max-z_min)/2, x_max-x_min, y_max-y_min, z_max-z_min])
        bb[0::2] = bb[0::2]/torch.tensor(segmented_image.shape)
        bb[1::2] = bb[1::2]/torch.tensor(segmented_image.shape)
        bb_list.append(bb)
    bbs = torch.stack(bb_list)
    return bbs

In [7]:
def get_masks(segmented_image):
    if segmented_image.get_device() > -1: segmented_image = segmented_image.to('cpu')
    classes = segmented_image.unique()
    masks_list = []
    for class_ in classes[1:]:
        mask = segmented_image.clone()
        mask[segmented_image == class_] = 1
        mask[segmented_image != class_] = 0
        masks_list.append(mask)
    masks = torch.stack(masks_list)
    return masks

In [8]:
def get_labels(segmented_image):
    return segmented_image.unique().int()

In [9]:
def create_target_dict(segmented_image):
    labels = get_labels(segmented_image)
    boxes = get_bounding_boxes(segmented_image)
    masks = get_masks(segmented_image)
    target_dict = {'labels': labels, 'boxes': boxes, 'masks': masks}
    return target_dict

In [18]:
class DatasetForSegmentation(Dataset):
    
    def __init__(self, image_paths, target_paths, down_scale=8):
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.down_scale = down_scale
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, i):
        image_np_array = nib.load(image_paths[i]).get_fdata()
        image_torch_tensor = torch.from_numpy(image_np_array)
        downsampled_image = uniform_downsample_image(image_torch_tensor, down_scale=self.down_scale)
        target_np_array = nib.load(target_paths[i]).get_fdata()
        target_torch_tensor = torch.from_numpy(target_np_array)
        downsampled_target = uniform_downsample_image(target_torch_tensor, down_scale=self.down_scale)
        target_dict = create_target_dict(target_torch_tensor)
        return downsampled_image, target_dict

In [19]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [20]:
path = '/home/francisco/workspace/ImageCHD_dataset'

In [21]:
image_paths = glob(f'{path}/*image.nii.gz',recursive=True)
target_paths = glob(f'{path}/*label.nii.gz',recursive=True)

In [22]:
image_paths.sort(key=sort_func)
target_paths.sort(key=sort_func)

In [23]:
dset = DatasetForSegmentation(image_paths,target_paths)

In [24]:
image, target = dset[2]
image = image.to(device)

pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
INFO - 2022-10-18 09:30:13,371 - batteryrunners - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1
INFO - 2022-10-18 09:30:13,966 - batteryrunners - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1


In [25]:
# downsampled_image = uniform_downsample_image(image)
# downsampled_target = uniform_downsample_image()
image_reshaped = image.unsqueeze(0).unsqueeze(0).float()
# downsampled_label_reshaped = downsampled_label.unsqueeze(0).unsqueeze(0)
# im_right_shape = im.transpose(2,0).unsqueeze(0).unsqueeze(0).float()
# label_right_shape = label.transpose(2,0).unsqueeze(0).unsqueeze(0).float()
# new_size = [int(im_right_shape.shape[2]/8),int(im_right_shape.shape[3]/8),int(im_right_shape.shape[4]/8)]
# im_resized = im_right_shape.resize_((1,1,new_size[0],new_size[1],new_size[2]))
# label_resized = label_right_shape.resize_((1,1,new_size[0],new_size[1],new_size[2]))

In [26]:
detr_seg = detr_resnet3d_panoptic()
# detr_seg = detr_resnet11_panoptic()
detr_seg.eval();
detr_seg.to(device);

In [27]:
# im = torch.ones((1, 1, 32, 64, 64),device=device)
# im = torch.ones((1, 3, 128, 128),device=device)

In [28]:
outputs = detr_seg(image_reshaped)

  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)


In [29]:
outputs['pred_boxes'].shape

torch.Size([1, 100, 6])

In [30]:
outputs['pred_logits'].shape

torch.Size([1, 100, 8])

3D: torch.Size([1, 100, 8, 16, 16])

2D: torch.Size([1, 100, 32, 32])

***

***

In [17]:
# label_cpu = label.to('cpu')

In [20]:
segmented_image = downsampled_label

In [28]:
points_numpy

array([[ 8, 43, 19],
       [ 8, 47, 22],
       [ 9, 32, 20],
       ...,
       [19, 38, 24],
       [19, 39, 42],
       [19, 40, 25]])

In [60]:
segmented_image.unique()

tensor([0.])

In [None]:
""" The forward expects a NestedTensor, which consists of:
       - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
       - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

torch.Size([1, 1, 32, 128, 128]) torch.Size([1, 32, 128, 128])

torch.Size([1, 2048, 8, 32, 32]) torch.Size([1, 32, 128, 128])

torch.Size([1, 2048, 4, 4]) torch.Size([1, 4, 4])


torch.Size([1, 256, 4, 4])



torch.Size([1, 2048, 8, 32, 32]) torch.Size([1, 32, 128, 128])


torch.Size([1, 384, 32, 128, 128])

torch.Size([1, 256, 8, 32, 32]) torch.Size([1, 32, 128, 128]) torch.Size([100, 256]) torch.Size([1, 256, 32, 128, 128])

2D:

torch.Size([1, 256, 4, 4]) torch.Size([1, 4, 4]) torch.Size([100, 256]) torch.Size([1, 256, 4, 4])

In [31]:
def loss_masks(outputs, targets, indices, num_boxes):
    """Compute the losses related to the masks: the focal loss and the dice loss.
       targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
    """
    assert "pred_masks" in outputs

    src_idx = self._get_src_permutation_idx(indices)
    tgt_idx = self._get_tgt_permutation_idx(indices)
    src_masks = outputs["pred_masks"]
    src_masks = src_masks[src_idx]
    masks = [t["masks"] for t in targets]
    # TODO use valid to mask invalid areas due to padding in loss
    target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
    target_masks = target_masks.to(src_masks)
    target_masks = target_masks[tgt_idx]

    # upsample predictions to the target size
    src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-3:],
                            mode="bilinear", align_corners=False)
    src_masks = src_masks[:, 0].flatten(1)

    target_masks = target_masks.flatten(1)
    target_masks = target_masks.view(src_masks.shape)
    losses = {
        "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
        "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
    }
    return losses

In [32]:
matcher = HungarianMatcher()

In [36]:
outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}

# Retrieve the matching between the outputs of the last layer and the targets
indices = matcher(outputs_without_aux, [target])

IndexError: tensors used as indices must be long, byte or bool tensors

In [None]:
num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
if is_dist_avail_and_initialized():
    torch.distributed.all_reduce(num_boxes)
num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

In [None]:
loss_masks()