In [1]:
import os
from glob import glob

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import Resize

from hubconf import detr_resnet101_panoptic
from models.detr import DETR
from models.segmentation import DETRsegm

In [2]:
from models.hub.resnet import _resnet

In [3]:
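# TODO: unfinished sketch of a 3D wrapper around the DETR panoptic model (incomplete, kept for reference).
# class DETRsegm_3D(nn.Module):
#     def __init__(self):
#         detr = detr_resnet101_panoptic()
#         conv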

In [4]:
def sort_func(path):
    """Sort key: the integer in the second '_'-separated field of the file name.

    Everything up to and including the last '/' in ``path`` is ignored; the
    remainder is split on '_' and field 1 is converted with ``int``.
    """
    _, _, file_name = path.rpartition('/')
    fields = file_name.split('_')
    return int(fields[1])

In [5]:
def show_image_and_label(image, label):
    """Draw ``image`` and ``label`` next to each other in a single-row figure.

    Both arguments are handed directly to ``imshow``. Ticks and tick labels
    are removed and the panels are titled 'Image' and 'Label'.
    """
    fig, axs = plt.subplots(nrows=1, ncols=2, squeeze=False, figsize=(12, 12))
    panels = ((image, 'Image'), (label, 'Label'))
    for column, (data, caption) in enumerate(panels):
        panel = axs[0, column]
        panel.imshow(data)
        panel.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], title=caption)

In [6]:
class DatasetForSegmentation(Dataset):
    """Paired image/label volumes read from NIfTI files with nibabel.

    Item ``i`` is built from ``image_paths[i]`` and ``label_paths[i]``; the two
    sequences are expected to be index-aligned.
    """

    def __init__(self, image_paths, label_paths):
        """Store the two path sequences; no file is opened here."""
        self.image_paths = image_paths
        self.label_paths = label_paths

    def __len__(self):
        """Number of samples, taken from the image path sequence."""
        return len(self.image_paths)

    def __getitem__(self, i):
        """Load sample ``i`` from disk.

        Returns:
            tuple: ``(image, label)``, both torch tensors created from the
            arrays returned by ``get_fdata()``.
        """
        # Read through the instance attributes, not the notebook-level globals
        # of the same name, so the dataset serves the paths it was built with.
        image_np_array = nib.load(self.image_paths[i]).get_fdata()
        image_torch_tensor = torch.from_numpy(image_np_array)
        # The label has to come from label_paths; reading image_paths here
        # would hand back the image a second time.
        label_np_array = nib.load(self.label_paths[i]).get_fdata()
        label_torch_tensor = torch.from_numpy(label_np_array)
        return image_torch_tensor, label_torch_tensor

In [7]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [8]:
# Dataset root directory. Defaults to the original hard-coded location; set the
# IMAGECHD_DIR environment variable to point the notebook at another copy.
path = os.environ.get('IMAGECHD_DIR', '/home/francisco/workspace/ImageCHD_dataset')

In [9]:
# Collect the image and label volumes stored under `path`.
# NOTE: recursive=True only affects patterns that contain '**'; these patterns
# have none, so only files directly inside `path` are matched.
image_pattern = f'{path}/*image.nii.gz'
label_pattern = f'{path}/*label.nii.gz'
image_paths = glob(image_pattern, recursive=True)
label_paths = glob(label_pattern, recursive=True)

In [10]:
# Order both lists in place by the numeric id in each file name so that
# image_paths[i] and label_paths[i] line up by position.
for path_list in (image_paths, label_paths):
    path_list.sort(key=sort_func)

In [11]:
# Wrap the sorted path lists in the segmentation dataset.
dset = DatasetForSegmentation(
    image_paths=image_paths,
    label_paths=label_paths,
)

In [12]:
# Pull a single sample (index 2) to experiment with in the following cells.
inpt, outp = dset[2]

In [13]:
# Build the ResNet, convert its parameters to half precision and place it on
# `device`.
model = _resnet().half().to(device)

In [14]:
# Swap axes 0 and 2 of the volume, prepend two singleton dimensions and cast
# to float32.
inpt_right_shape = inpt.swapaxes(0, 2)[None, None].to(torch.float32)

In [15]:
# Target size: each of dimensions 2, 3 and 4 divided by 3, truncated to an int.
new_size = [int(extent / 3) for extent in inpt_right_shape.shape[2:5]]

In [16]:
# Downsample the volume to `new_size` by trilinear interpolation.
# Tensor.resize_ is not a resampling operation: it reinterprets the existing
# storage at the new shape (keeping only the leading elements) and modifies
# inpt_right_shape in place, so it cannot be used to shrink an image.
inpt_resized = torch.nn.functional.interpolate(
    inpt_right_shape,
    size=new_size,
    mode='trilinear',
    align_corners=False,
)

In [17]:
# 1 -> 3 channel convolution applied slice by slice (kernel depth 1, 3x3 in-plane).
# Padding is given per axis: the 3x3 in-plane kernel needs 1 to keep height and
# width, while the depth kernel of 1 needs 0. A scalar padding=1 would also pad
# the depth axis and add two extra slices to the output.
cv_3d = torch.nn.Conv3d(
    in_channels=1,
    out_channels=3,
    kernel_size=(1, 3, 3),
    padding=(0, 1, 1),
)

In [18]:
# Apply the 1 -> 3 channel convolution to the resized volume. This happens
# before anything is moved to `device`.
result = cv_3d(inpt_resized)

In [19]:
# Move the convolution output to `device`; despite the name, that is the CPU
# when CUDA is not available.
result_cuda = result.to(device)

In [20]:
# The model was converted with .half(), so the input is cast to float16 to match.
result_2 = model(result_cuda.half())

In [23]:
# Build the panoptic DETR model with the factory's default arguments.
# NOTE(review): whether the defaults load pretrained weights depends on hubconf — confirm.
detr = detr_resnet101_panoptic()
# Switch to evaluation mode. eval() returns the module itself, which is why
# the model repr is displayed as this cell's output.
detr.eval()



DETRsegm(
  (detr): DETR(
    (transformer): Transformer(
      (encoder): TransformerEncoder(
        (layers): ModuleList(
          (0): TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
            )
            (linear1): Linear(in_features=256, out_features=2048, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (linear2): Linear(in_features=2048, out_features=256, bias=True)
            (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
            (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
            (dropout1): Dropout(p=0.1, inplace=False)
            (dropout2): Dropout(p=0.1, inplace=False)
          )
          (1): TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      

In [51]:
# Dummy batch: one 3-channel 128x128 image filled with ones.
im = torch.ones((1, 3, 128, 128))

In [53]:
# Forward pass of the dummy batch through DETR.
# NOTE(review): this runs with autograd enabled; consider torch.no_grad() if it
# is inference only.
preds = detr(im)

In [58]:
# Shape of the predicted masks for the 1x3x128x128 dummy input.
preds['pred_masks'].shape

torch.Size([1, 100, 32, 32])