In [28]:
import numpy as np
import nibabel as nib
from glob import glob
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets
from models.detr import DETR, LossCustom
from models.segmentation import DETRsegm, PostProcessPanoptic
from models.matcher import HungarianMatcher
from hubconf import detr_resnet101_panoptic, detr_resnet3d_panoptic
from torchvision.transforms import Resize
from util.misc import get_world_size
import open3d as o3d
import torch.optim as optim
import pkbar

In [29]:
nib.imageglobals.logger.level = 51

In [30]:
def sort_func(path):
    path_id = int(path.split('/')[-1].split('_')[1])
    return path_id

In [31]:
def show_image_and_label(image, target):
    fig, axs = plt.subplots(nrows=1,ncols=2, squeeze=False,figsize=(12, 12))
    axs[0, 0].imshow(image)
    axs[0, 0].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], title='Image')
    axs[0, 1].imshow(target)
    axs[0, 1].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], title='Target')

In [32]:
def uniform_downsample_image(image, down_scale = 8):
    image_shape = image.shape
    dim_0_indexes = torch.arange(0, image_shape[0], down_scale)
    dim_1_indexes = torch.arange(0, image_shape[1], down_scale)
    dim_2_indexes = torch.arange(0, image_shape[2], down_scale)
    downsampled_image = image[dim_0_indexes,:,:]
    downsampled_image = downsampled_image[:,dim_1_indexes,:]
    downsampled_image = downsampled_image[:,:,dim_2_indexes]
    return downsampled_image

In [33]:
def downsample_image_to_given_size(image, size = 128):
    image_shape = image.shape
    dim_0_indexes = torch.arange(0, image_shape[0], int(image_shape[0]/size))
    dim_1_indexes = torch.arange(0, image_shape[1], int(image_shape[1]/size))
    dim_2_indexes = torch.arange(0, image_shape[2], int(image_shape[2]/size))
    downsampled_image = image[dim_0_indexes,:,:]
    downsampled_image = downsampled_image[:,dim_1_indexes,:]
    downsampled_image = downsampled_image[:,:,dim_2_indexes]
    return downsampled_image

In [34]:
def visualize_segmented_image(segmented_image):
#     segmented_image.squeeze()
    if segmented_image.get_device() > -1: segmented_image = segmented_image.to('cpu')
    classes = segmented_image.unique()
    colors = [[128,128,128],[255,0,0],[255,255,0],[0,255,0],[0,255,255],[0,0,255],[255,0,255]]
    vis = o3d.visualization.Visualizer()
    vis.create_window()
    for _class in classes[1:]:
        points_numpy = (segmented_image == _class).nonzero().numpy()
        o3d_point_cloud = o3d.geometry.PointCloud()
        o3d_point_cloud.points = o3d.utility.Vector3dVector(points_numpy)
        o3d_point_cloud.estimate_normals()
        o3d_point_cloud.paint_uniform_color(np.array(colors[int(_class)-1])/255)
        vis.add_geometry(o3d_point_cloud)
    vis.run()
    vis.destroy_window()

In [35]:
def get_bounding_boxes(segmented_image):
    if segmented_image.get_device() > -1: segmented_image = segmented_image.to('cpu')
    classes = segmented_image.unique()
    bb_list = []
    for class_ in classes[1:]: 
        points = (segmented_image == class_).nonzero()
        x_min, x_max = points[:,0].min(), points[:,0].max()
        y_min, y_max = points[:,1].min(), points[:,1].max()
        z_min, z_max = points[:,2].min(), points[:,2].max()
        bb = torch.tensor([(x_max-x_min)/2, (y_max-y_min)/2, (z_max-z_min)/2, x_max-x_min, y_max-y_min, z_max-z_min])
        bb[0::2] = bb[0::2]/torch.tensor(segmented_image.shape)
        bb[1::2] = bb[1::2]/torch.tensor(segmented_image.shape)
        bb_list.append(bb)
    bbs = torch.stack(bb_list)
    return bbs

In [36]:
def get_masks(segmented_image):
    if segmented_image.get_device() > -1: segmented_image = segmented_image.to('cpu')
    classes = segmented_image.unique()
    masks_list = []
    for class_ in classes[1:]:
        mask = segmented_image.clone()
        mask[segmented_image == class_] = 1
        mask[segmented_image != class_] = 0
        masks_list.append(mask.short())
    masks = torch.stack(masks_list)
    return masks

In [37]:
def get_labels(segmented_image):
    return segmented_image.unique().long()[1:]

In [38]:
def create_target_dict_panoptic(segmented_image):
    labels = get_labels(segmented_image)
    boxes = get_bounding_boxes(segmented_image)
    masks = get_masks(segmented_image)
    target_dict = {'labels': labels, 'boxes': boxes, 'masks': masks, 'seg_im': segmented_image}
    return target_dict

In [39]:
def create_target_dict_detection(segmented_image):
    labels = get_labels(segmented_image)
    boxes = get_bounding_boxes(segmented_image)
    target_dict = {'labels': labels, 'boxes': boxes}
    return target_dict

In [40]:
def get_box_corners(box):
    x_min = box[0] - box[3]/2
    x_max = box[0] + box[3]/2
    y_min = box[1] - box[4]/2
    y_max = box[1] + box[4]/2
    z_min = box[2] - box[5]/2
    z_max = box[2] + box[5]/2
    return torch.tensor([x_min,y_min,z_min,x_max,y_max,z_max])

In [41]:
def get_intersetion_box_corners(box_1,box_2):
    intersection_box_corners = torch.zeros(6)
    intersection_box_corners[0] = max(box_1[0],box_2[0])
    intersection_box_corners[1] = min(box_1[3],box_2[3])
    intersection_box_corners[2] = max(box_1[1],box_2[1])
    intersection_box_corners[3] = min(box_1[4],box_2[4])
    intersection_box_corners[4] = max(box_1[2],box_2[2])
    intersection_box_corners[5] = min(box_1[5],box_2[5])
    return intersection_box_corners

In [42]:
def compute_box_volume(box):
    box_volume = (box[3]-box[0])*(box[4]-box[1])*(box[5]-box[2])
    if box_volume < 0: box_volume = 0
    return box_volume

In [43]:
def compute_box_iou(box_1,box_2):
    box_1_corners = get_box_corners(box_1)
    box_2_corners = get_box_corners(box_2)
    intersection_box = get_intersetion_box_corners(box_1_corners,box_2_corners)
    intersection_box_volume = compute_box_volume(intersection_box)
    if intersection_box_volume == 0: return 0
    box_1_volume = compute_box_volume(box_1_corners) 
    box_2_volume = compute_box_volume(box_2_corners)
    iou = intersection_box_volume/(box_1_volume+box_2_volume-intersection_box_volume)
    return iou

In [44]:
def compute_average_precision(outputs, target, labels):
    ious = []
    ap = []
    for i, label in enumerate(labels):
        pred_box = outputs['pred_boxes'].squeeze()[i]
        target_box = target['boxes'][label] 
        iou = compute_box_iou(pred_box,target_box)
        ious.append(iou)
    ious_tensor = torch.tensor(ious)
    for t in range(50,95,5):
        tp = (ious_tensor >= t/100).count_nonzero()
        fn = abs(len(target['labels']) - 1 - tp)
        fp = abs((outputs['pred_logits'].squeeze().argmax(1) != 0).count_nonzero() - 1 - tp)
        precision = tp/(tp+fp)
        recall = tp/(tp+fn)
        ap.append(precision*recall)
    ap_tensor = torch.tensor(ap)
    return ap_tensor.mean()

In [45]:
def compute_panoptic_quality(pred_seg, target):
    ious = []
    for i, label in enumerate(target['labels'][1:]):
        target_mask = target['masks'][i] 
        pred_mask = (pred_seg == label).int()
        pred_mask[pred_mask == 0] = -1
        intersection = (pred_mask == target_mask).count_nonzero()
        union = pred_mask[pred_mask == 1].count_nonzero() + target_mask[target_mask == 1].count_nonzero() - intersection
        iou = intersection/union
        ious.append(iou)
    ious_tensor = torch.tensor(ious)
    tp = (ious_tensor > 0.5).count_nonzero()
    fn = abs(len(target['labels']) - 1 - tp)
    fp = abs(len(pred_seg.unique()) - 1 - tp)
    if tp > 0:
        sq = ious_tensor[ious_tensor > 0.5].sum() / tp
        rq = rq = tp / (tp + fp/2 + fn/2)
        pq = sq*rq
    else:
        sq = 0
        rq = 0
        pq = 0
    return pq, sq, rq

In [46]:
class DatasetForDetection(Dataset):
    
    def __init__(self, image_paths, target_paths, down_scale=8):
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.down_scale = down_scale
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, i):
        image_np_array = nib.load(self.image_paths[i]).get_fdata()
        image_torch_tensor = torch.from_numpy(image_np_array).float()
        downsampled_image = uniform_downsample_image(image_torch_tensor, down_scale=self.down_scale)
        downsampled_image = downsampled_image/downsampled_image.max()
#         downsampled_image = downsample_image_to_given_size(image_torch_tensor, size=64)
        target_np_array = nib.load(self.target_paths[i]).get_fdata()
        target_torch_tensor = torch.from_numpy(target_np_array)
        target_torch_tensor[target_torch_tensor > 7] = 0
        downsampled_target = uniform_downsample_image(target_torch_tensor, down_scale=self.down_scale)
#         downsampled_target = downsample_image_to_given_size(target_torch_tensor, size=64)
        target_dict = create_target_dict_detection(downsampled_target)
        return downsampled_image, target_dict

In [47]:
class DatasetForSegmentation(Dataset):
    
    def __init__(self, image_paths, target_paths, down_scale=8):
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.down_scale = down_scale
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, i):
        image_np_array = nib.load(self.image_paths[i]).get_fdata()
        image_torch_tensor = torch.from_numpy(image_np_array).float()
        downsampled_image = uniform_downsample_image(image_torch_tensor, down_scale=self.down_scale)
        downsampled_image = downsampled_image/downsampled_image.max()
#         downsampled_image = downsample_image_to_given_size(image_torch_tensor, size=64)
        target_np_array = nib.load(self.target_paths[i]).get_fdata()
        target_torch_tensor = torch.from_numpy(target_np_array)
        target_torch_tensor[target_torch_tensor > 7] = 0
        downsampled_target = uniform_downsample_image(target_torch_tensor, down_scale=self.down_scale)
#         downsampled_target = downsample_image_to_given_size(target_torch_tensor, size=64)
        target_dict = create_target_dict_panoptic(downsampled_target)
        return downsampled_image, target_dict

In [48]:
def create_dataset(dset_path, dset_type, split_index=87, down_scale=10):
    image_paths = glob(f'{dset_path}/*image.nii.gz',recursive=True)
    target_paths = glob(f'{dset_path}/*label.nii.gz',recursive=True)
    image_paths.sort(key=sort_func)
    target_paths.sort(key=sort_func)
    if dset_type == 'detection':
        train_dset = DatasetForDetection(image_paths[:split_index], target_paths[:split_index], down_scale=down_scale)
        valid_dset = DatasetForDetection(image_paths[split_index:],target_paths[split_index:], down_scale=down_scale)
    if dset_type == 'panoptic':
        train_dset = DatasetForSegmentation(image_paths[:split_index], target_paths[:split_index], down_scale=down_scale)
        valid_dset = DatasetForSegmentation(image_paths[split_index:], target_paths[split_index:], down_scale=down_scale)
    return train_dset, valid_dset

In [59]:
class TrainAndEvaluate():
    def __init__(self, model, optimizer, loss, train_dset, valid_dset):
        self.model = model
        self.optimizer = optimizer
        self.loss = loss
        self.train_dset = train_dset
        self.valid_dset = valid_dset
        self.matcher = loss.matcher 
        aux_target = train_dset[0][1]
        if 'masks' in aux_target.keys():
            self.masks = True
            self.key = 'loss_mask'
            is_thing_map = {'0': False, '1': True, '2': True, '3': True, '4': True, '5': True, '6': True, '7': True}
            self.post_process_panoptic = PostProcessPanoptic(is_thing_map)
        else: 
            self.masks = False
            self.key = 'loss_bbox'        
    
    def train_and_evaluate(self, n_epoch, save=False):
        torch.autograd.set_detect_anomaly(True)

        for epoch in range(n_epoch):
            kbar = pkbar.Kbar(target=(len(self.train_dset)+len(self.valid_dset)-2), epoch=epoch, num_epochs=n_epoch, width=16)
            running_loss_t = 0.0
            running_loss_v = 0.0
            rdm = torch.randperm(len(self.train_dset))
            
            for i, j in enumerate(rdm):
                b = self.train_dset[j]
                image = b[0]
                image = image.unsqueeze(0).unsqueeze(0).to(device)
                target = b[1]
                for t in target.items():
                    target[t[0]] = t[1].to(device)
                outputs = self.model(image)
                self.optimizer.zero_grad()
                if self.masks:
                    l = self.loss.loss_masks(outputs, [target]) #To train masks
                else:
                    l = self.loss.loss_boxes(outputs, [target]) #To train boxes
                l[self.key].backward() 
                self.optimizer.step()
                running_loss_t = l[self.key].item()
                kbar.update(i, values=[("Train Loss", running_loss_t)])
            
            if save:
                torch.save(self.model.state_dict(), '/home/francisco/workspace/CHD_Classifier_by_Francisco_Lourenço/checkpoint_detr_hd.pth')
            
            kbar.add(1, values=[("Validation Loss", 0), ("Accuracy", 0)])        
            with torch.no_grad():
                for ii, b in enumerate(self.valid_dset):
                    image = b[0]
                    image = image.unsqueeze(0).unsqueeze(0).to(device)
                    target = b[1]
                    for t in target.items():
                        target[t[0]] = t[1].to(device)

                    outputs = self.model(image)
                    if self.masks:
                        l = self.loss.loss_masks(outputs, [target]) #To train masks
                    else:
                        l = self.loss.loss_boxes(outputs, [target]) #To train boxes
                    running_loss_v = l[self.key].item()
                    src_idx, trgt_idx = self.matcher(outputs,[target])[0]
                    if self.masks:
                        pred_seg = post_process_panoptic(outputs[trgt_idx],[tuple(torch.tensor(b[0].shape).tolist())]).to(device)
                        pq,sq,rq = compute_panoptic_quality(pred_seg, target)
                        kbar.update(i+ii, values=[("Validation Loss", running_loss_v), ("PQ", 100*pq),("SQ", 100*sq),("RQ", 100*rq)])
                    else:
                        ap = compute_average_precision(outputs, target, trgt_idx) #For boxes
                        kbar.update(i+ii, values=[("Validation Loss", running_loss_v), ("Accuracy", 100*ap)]) #For boxes
        print('Finished Training')
        return self.model

In [50]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
path = '/home/francisco/workspace/ImageCHD_dataset'
# device = torch.device("cpu")

In [51]:
model = detr_resnet3d_panoptic()
model = model.detr #To train boxes
# model.load_state_dict(torch.load('/home/francisco/workspace/CHD_Classifier_by_Francisco_Lourenço/checkpoint_detr.pth'))
model.to(device);

In [52]:
matcher = HungarianMatcher()
loss = LossCustom(matcher)

In [66]:
train_dset, valid_dset = create_dataset(path, 'detection', 87, 4)
optimizer = optim.AdamW(model.parameters(), lr=0.00001, weight_decay=10e-4)
trainer = TrainAndEvaluate(model, optimizer, loss, train_dset, valid_dset)

In [67]:
model = trainer.train_and_evaluate(100, True)

Epoch: 1/100
Epoch: 2/100
Epoch: 3/100
Epoch: 4/100
Epoch: 5/100
Epoch: 6/100
Epoch: 7/100
Epoch: 8/100
Epoch: 9/100
Epoch: 10/100
Epoch: 11/100
Epoch: 12/100
Epoch: 13/100
Epoch: 14/100
Epoch: 15/100
Epoch: 16/100
Epoch: 17/100
Epoch: 18/100
Epoch: 19/100
Epoch: 20/100
Epoch: 21/100
Epoch: 22/100
Epoch: 23/100
Epoch: 24/100
Epoch: 25/100
Epoch: 26/100
Epoch: 27/100
Epoch: 28/100
Epoch: 29/100
Epoch: 30/100
Epoch: 31/100
Epoch: 32/100
Epoch: 33/100
Epoch: 34/100
Epoch: 35/100
Epoch: 36/100
Epoch: 37/100
Epoch: 38/100
Epoch: 39/100
Epoch: 40/100
Epoch: 41/100
Epoch: 42/100
Epoch: 43/100
Epoch: 44/100
Epoch: 45/100
Epoch: 46/100
Epoch: 47/100
Epoch: 48/100
Epoch: 49/100
Epoch: 50/100
Epoch: 51/100
Epoch: 52/100
Epoch: 53/100
Epoch: 54/100
Epoch: 55/100
Epoch: 56/100
Epoch: 57/100
Epoch: 58/100
Epoch: 59/100
Epoch: 60/100
Epoch: 61/100
Epoch: 62/100
Epoch: 63/100
Epoch: 64/100
Epoch: 65/100
Epoch: 66/100


Epoch: 67/100
Epoch: 68/100
Epoch: 69/100
Epoch: 70/100
Epoch: 71/100
Epoch: 72/100
Epoch: 73/100
Epoch: 74/100
Epoch: 75/100
Epoch: 76/100
Epoch: 77/100
Epoch: 78/100
Epoch: 79/100
Epoch: 80/100
Epoch: 81/100
Epoch: 82/100
Epoch: 83/100
Epoch: 84/100
Epoch: 85/100
Epoch: 86/100
Epoch: 87/100
Epoch: 88/100
Epoch: 89/100
Epoch: 90/100
Epoch: 91/100
Epoch: 92/100
Epoch: 93/100
Epoch: 94/100
Epoch: 95/100
Epoch: 96/100
Epoch: 97/100
Epoch: 98/100
Epoch: 99/100
Epoch: 100/100
Finished Training


In [63]:
# torch.save(model.state_dict(), '/home/francisco/workspace/CHD_Classifier_by_Francisco_Lourenço/checkpoint_detr_hd.pth')

In [None]:
hd 1/4:
    Train Loss: 0.6584 - Validation Loss: 0.6669 - Accuracy: 0.0000e+00
    best: Train Loss: 0.6584 - Validation Loss: 0.6498 - Accuracy: 0.0000e+00

Train Loss: 0.6197 - Validation Loss: 0.6037 - Accuracy: 0.0000e+00
best: Train Loss: 0.6129 - Validation Loss: 0.5986 - Accuracy: 0.0000e+00

***

In [49]:
a = valid_dset[0]
image = a[0].unsqueeze(0).unsqueeze(0).to(device)

In [50]:
b = model(image)

  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)


In [51]:
a[1]

{'labels': tensor([1, 2, 3, 4, 5, 6, 7]),
 'boxes': tensor([[0.1058, 0.0865, 0.0288, 0.2115, 0.6429, 0.2143],
         [0.0769, 0.0962, 0.0385, 0.1538, 0.7143, 0.2857],
         [0.2212, 0.1154, 0.0673, 0.4423, 0.8571, 0.5000],
         [0.0962, 0.1442, 0.0865, 0.1923, 1.0714, 0.6429],
         [0.1731, 0.1731, 0.0673, 0.3462, 1.2857, 0.5000],
         [0.0769, 0.1442, 0.0962, 0.1538, 1.0714, 0.7143],
         [0.1731, 0.0769, 0.0769, 0.3462, 0.5714, 0.5714]])}

In [52]:
b['pred_logits'].argmax(2)

tensor([[3, 4, 2, 2, 3, 2, 3, 4]], device='cuda:0')

In [53]:
b['pred_boxes']

tensor([[[0.2162, 0.1905, 0.1379, 0.4125, 0.7468, 0.5607],
         [0.1852, 0.1596, 0.1057, 0.3731, 0.6936, 0.5030],
         [0.2007, 0.1832, 0.1221, 0.3869, 0.7135, 0.5342],
         [0.1799, 0.1373, 0.0945, 0.3724, 0.6251, 0.4239],
         [0.2030, 0.1737, 0.1308, 0.4168, 0.7406, 0.5684],
         [0.1633, 0.1215, 0.0784, 0.3160, 0.4827, 0.3158],
         [0.2196, 0.1901, 0.1439, 0.4387, 0.7995, 0.6313],
         [0.2001, 0.1834, 0.1260, 0.4072, 0.7581, 0.5545]]], device='cuda:0',
       grad_fn=<SelectBackward0>)

# Visualization:

In [53]:
is_thing_map = {'0': False, '1': True, '2': True, '3': True, '4': True, '5': True, '6': True, '7': True}

In [54]:
post_process_panoptic = PostProcessPanoptic(is_thing_map)

In [35]:
pred_masks = outputs['pred_masks']

In [42]:
result = post_process_panoptic(outputs,[tuple(torch.tensor(image.squeeze().shape).tolist())])

In [61]:
visualize_segmented_image(aux_target['seg_im'])

In [60]:
aux_inpt[0,0]

torch.Size([64, 64, 25])