In [1]:
import os
import gzip
import torch
import nibabel as nib

# Move working dir to access src
os.chdir("../")

In [2]:
from src.model.models import ResNet3D

resnet3d = ResNet3D(1)

# Run 3D ResNet18 on [1, 1, 64, 64, 64] volume
data = torch.randn(1, 1, 64, 64, 64)
resnet3d_y = resnet3d(data)
[e.shape for e in resnet3d_y]

[torch.Size([1, 128, 16, 16, 16]),
 torch.Size([1, 256, 8, 8, 8]),
 torch.Size([1, 512, 4, 4, 4])]

In [3]:
from src.model.models import PyramidFeatures3D

fpn3d = PyramidFeatures3D(128, 256, 512)
fpn3d_y = fpn3d(*resnet3d_y)
[e.shape for e in fpn3d_y]

[torch.Size([1, 256, 16, 16, 16]),
 torch.Size([1, 256, 8, 8, 8]),
 torch.Size([1, 256, 4, 4, 4]),
 torch.Size([1, 256, 2, 2, 2]),
 torch.Size([1, 256, 1, 1, 1])]

In [4]:
from src.model.models import RegressionBlock3D

regressgion_block = RegressionBlock3D(256)
regressgion_block_y = torch.cat([regressgion_block(f) for f in fpn3d_y], dim=1)
regressgion_block_y.shape

torch.Size([1, 70215, 6])

In [5]:
from src.model.models import ClassificationBlock3D

classification_block = ClassificationBlock3D(256)
classification_block_y = torch.cat([classification_block(f) for f in fpn3d_y], dim=1)
classification_block_y.shape

torch.Size([1, 70215, 10])

In [6]:
from src.model.anchors import Anchors3D

anchors = Anchors3D()
data = torch.randn(1, 1, 64, 64, 64)
anchors_y = anchors(data)
anchors_y.shape

torch.Size([70080, 6])


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


torch.Size([70080, 6])

In [7]:
from src.model.models import RetinaNet3D

retinanet = RetinaNet3D(1, num_classes=1)
data = torch.randn(1, 1, 64, 64, 64)
retinanet_y = retinanet(data)
[e.shape for e in retinanet_y]

[torch.Size([1, 70080, 6]), torch.Size([1, 70080, 3])]

In [8]:
from src.model.modules import RetinaNetLoss, BoxLabelEncoder

encoder = BoxLabelEncoder(volume_width=64, volume_height=64, volume_depth=64)

criterion = RetinaNetLoss()

pred_box = torch.rand((8, 70080, 6))
pred_cls = torch.rand((8, 70080, 3))

gt_box = torch.rand((8, 5, 6)) * 64
gt_cls = torch.rand((8, 5, 1)).long()

gt_box, gt_cls = encoder.encode(gt_box, gt_cls)

box_loss, cls_loss = criterion(pred_box, pred_cls, gt_box, gt_cls)

torch.Size([70080, 6])


