In [1]:
import os
import gzip
import torch
import nibabel as nib

# Move the working directory up one level so the `src` package and the relative
# `data_dev/...` paths used below resolve (assumes this notebook lives one level
# below the repository root). Guarded so re-running this cell without a kernel
# restart does not keep walking further up the directory tree.
if not os.path.isdir("src"):
    os.chdir("../")

def read_image(file_path):
    """Load a NIfTI image with nibabel.

    Args:
        file_path: Path to the image file (e.g. a ``.nii.gz`` volume).

    Returns:
        A ``(image_data, header)`` tuple: the voxel array produced by the
        image's ``get_fdata()`` and the image's ``header`` object.
    """
    volume = nib.load(file_path)
    # Tuple elements evaluate left to right: voxel data first, then header.
    return volume.get_fdata(), volume.header

In [2]:
# Load one training case: its image volume and the matching label volume.
# The case id is defined once so the two paths cannot drift apart.
CASE_ID = 'RibFrac422'
TRAIN_DIR = 'data_dev/train'

# Keep the image header under a real name: assigning to `_` would shadow
# IPython's output-history variable and stop it from being updated.
image_data, image_header = read_image(f'{TRAIN_DIR}/images/{CASE_ID}-image.nii.gz')
label_data, header = read_image(f'{TRAIN_DIR}/labels/{CASE_ID}-label.nii.gz')

In [3]:
# Convert the volume to a float32 tensor and add batch and channel dimensions,
# giving [1, 1, *volume_shape]. `as_tensor(..., dtype=float32)` converts with a
# single copy instead of first materialising a float64 tensor.
# The guard keeps the cell idempotent: re-running it would otherwise add two
# more leading dimensions each time.
if not torch.is_tensor(image_data):
    image_data = torch.as_tensor(image_data, dtype=torch.float32)[None, None]

In [4]:
from src.model.models import UNet3D

unet3d = UNet3D(1, 3)

# Shape smoke test: run the 3D U-Net on a [1, 1, 112, 112, 112] crop.
# Slicing silently yields a smaller crop if the volume is smaller, so check it.
data = image_data[:, :, :112, :112, :112]
assert data.shape == (1, 1, 112, 112, 112), f"crop is {tuple(data.shape)}, expected 112^3"
# Inference only: skip building the autograd graph and the activations it
# would keep alive for a backward pass that never happens.
with torch.no_grad():
    unet3d_y = unet3d(data)
unet3d_y.shape

torch.Size([1, 3, 112, 112, 112])

In [5]:
from src.model.models import ResNet183D

resnet183d = ResNet183D(1)

# Run the 3D ResNet18 backbone on a [1, 1, 64, 64, 64] crop; its feature maps
# are fed to the feature pyramid in the next cell.
data = image_data[:, :, :64, :64, :64]
assert data.shape == (1, 1, 64, 64, 64), f"crop is {tuple(data.shape)}, expected 64^3"
# Inference only: no autograd graph needed.
with torch.no_grad():
    resnet183d_y = resnet183d(data)
[feature.shape for feature in resnet183d_y]

[torch.Size([1, 128, 16, 16, 16]),
 torch.Size([1, 256, 8, 8, 8]),
 torch.Size([1, 512, 4, 4, 4])]

In [6]:
from src.model.models import PyramidFeatures3D

# Build the 3D feature pyramid from the backbone outputs; the channel counts
# (128, 256, 512) match the three ResNet feature maps from the previous cell.
fpn3d = PyramidFeatures3D(128, 256, 512)
# Inference only: no autograd graph needed.
with torch.no_grad():
    fpn3d_y = fpn3d(*resnet183d_y)
[level.shape for level in fpn3d_y]

[torch.Size([1, 256, 16, 16, 16]),
 torch.Size([1, 256, 8, 8, 8]),
 torch.Size([1, 256, 4, 4, 4]),
 torch.Size([1, 256, 2, 2, 2]),
 torch.Size([1, 256, 1, 1, 1])]

In [7]:
from src.model.models import RegressionBlock3D

# Apply the box-regression head to every pyramid level and concatenate the
# per-level outputs along dim 1 (judging by the [1, 70215, 6] result, one row
# per anchor -- TODO confirm against RegressionBlock3D).
regression_block = RegressionBlock3D(256)
# Inference only: no autograd graph needed.
with torch.no_grad():
    regression_block_y = torch.cat([regression_block(level) for level in fpn3d_y], dim=1)
regression_block_y.shape

torch.Size([1, 70215, 6])

In [8]:
from src.model.models import ClassificationBlock3D

# Apply the classification head to every pyramid level and concatenate the
# per-level outputs along dim 1, mirroring the regression-head cell.
classification_block = ClassificationBlock3D(256)
# Inference only: no autograd graph needed.
with torch.no_grad():
    classification_block_y = torch.cat(
        [classification_block(level) for level in fpn3d_y], dim=1
    )
classification_block_y.shape

torch.Size([1, 70215, 10])

In [9]:
from src.model.anchors import Anchors3D

# Generate anchor boxes for a 64^3 crop; the anchor count (first dim) is
# expected to line up with the 70215 per-anchor rows produced by the heads.
anchors = Anchors3D()
# `...` covers the batch and channel dims of the 5-D image tensor.
data = image_data[..., :64, :64, :64]
anchors_y = anchors(data)
anchors_y.shape

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


torch.Size([70215, 6])

In [10]:
from src.model.models import RetinaNet3D

# End-to-end RetinaNet forward pass on a [1, 1, 64, 64, 64] crop. The two
# outputs appear to be per-anchor box regressions ([.., 6]) and class scores
# ([.., 3]). NOTE(review): 3 score channels with num_classes=1 -- confirm.
retinanet = RetinaNet3D(1, num_classes=1)
data = image_data[:, :, :64, :64, :64]
assert data.shape == (1, 1, 64, 64, 64), f"crop is {tuple(data.shape)}, expected 64^3"
# Inference only: no autograd graph needed.
with torch.no_grad():
    retinanet_y = retinanet(data)
[output.shape for output in retinanet_y]

[torch.Size([1, 70215, 6]), torch.Size([1, 70215, 3])]

In [11]:
from src.misc.utils import pairwise_iou

# Sanity-check pairwise_iou on two nested cubes: a unit cube and one with half
# the side length. The smaller lies inside the larger whether the 6 values are
# (min corner, max corner) or (center, size), so intersection = 0.125 and
# union = 1, giving IoU = 0.125 (and 1.0 for each box against itself).
boxes1 = torch.tensor([[0, 0, 0, 1, 1, 1], [0, 0, 0, 0.5, 0.5, 0.5]])
boxes2 = boxes1[[1, 0]]  # the same boxes with the rows swapped

iou = pairwise_iou(boxes1, boxes2)
torch.testing.assert_close(
    iou,
    torch.tensor([[0.125, 1.0], [1.0, 0.125]]),
    atol=1e-4,
    rtol=0,
    check_dtype=False,
)
iou

tensor([[0.1250, 1.0000],
        [1.0000, 0.1250]])

In [12]:
from src.model.anchors import match_anchor_boxes

# NOTE(review): judging by the displayed output, the first tensor looks like the
# index of the best-matching box in `boxes2` for each box in `boxes1` ([1, 0],
# since the rows are swapped) and the second its IoU; the meaning of the third
# tensor is not visible here -- confirm against src.model.anchors.
match_anchor_boxes(boxes1, boxes2)

(tensor([1, 0]), tensor([1., 1.]), tensor([0., 0.]))

In [44]:
from src.model.modules import RetinaNetLoss, BoxLabelEncoder


# Smoke-test the RetinaNet loss on random predictions and random ground truth.
# Seeded so the displayed loss is reproducible across runs.
torch.manual_seed(0)

BATCH_SIZE = 8
NUM_ANCHORS = 70215  # anchor count for a 64^3 input (see the Anchors3D cell)
NUM_GT_BOXES = 5

encoder = BoxLabelEncoder()
criterion = RetinaNetLoss()

pred_box = torch.rand((BATCH_SIZE, NUM_ANCHORS, 6))
pred_cls = torch.rand((BATCH_SIZE, NUM_ANCHORS, 3))

# NOTE(review): if boxes are (x1, y1, z1, x2, y2, z2), most of these uniformly
# random boxes have max < min on at least one axis, which may explain the
# astronomically large loss this cell has produced -- confirm the box format
# and generate valid boxes if so.
gt_box = torch.rand((BATCH_SIZE, NUM_GT_BOXES, 6)) * 64
# Every ground-truth label is class 0 (truncating torch.rand's [0, 1) values
# to long also yields all zeros). TODO: use torch.randint if other classes
# should be exercised.
gt_cls = torch.zeros((BATCH_SIZE, NUM_GT_BOXES, 1), dtype=torch.long)

gt_box, gt_cls = encoder.encode(gt_box, gt_cls)

box_loss, cls_loss = criterion(pred_box, pred_cls, gt_box, gt_cls)
(box_loss + cls_loss).mean()

tensor(4.2535e+37)