In [None]:
import os
import gzip
import torch
import nibabel as nib

# Move working dir to access src.
# Guarded so that re-running this cell does not climb one directory higher
# each time (the relative imports and data paths below rely on this location).
if not os.path.isdir("src"):
    os.chdir("../")

def read_image(file_path):
    """Load an image file with nibabel and return its data and header.

    Args:
        file_path: Path of the file, passed unchanged to ``nib.load``.

    Returns:
        A tuple ``(image_data, header)``: the result of ``get_fdata()`` on the
        loaded image, and the loaded image's ``header`` attribute.
    """
    volume = nib.load(file_path)
    return volume.get_fdata(), volume.header

In [None]:
# Load one image volume and its matching label volume (paths are relative to
# the working directory). The image's header is discarded; the label's is kept.
image_data, _ = read_image('data_dev/train/images/RibFrac422-image.nii.gz') 
label_data, header = read_image('data_dev/train/labels/RibFrac422-label.nii.gz')

In [None]:
# Add batch and channel dimensions and convert to float32.
# The type guard makes this cell idempotent: once `image_data` is a tensor,
# re-running the cell no longer stacks two more leading dimensions.
# `as_tensor` with an explicit dtype converts directly to float32 instead of
# first materialising a float64 tensor copy of the whole volume.
if not isinstance(image_data, torch.Tensor):
    image_data = torch.as_tensor(image_data, dtype=torch.float32)[None, None]

In [None]:
from src.model.models import UNet3D

# Instantiate the project's 3D U-Net.
# NOTE(review): the arguments (1, 3) are presumably input channels and output
# classes -- confirm against UNet3D's signature.
unet3d = UNet3D(1, 3)

# Run 3D U-Net on [1, 1, 112, 112, 112] volume
# (the crop keeps the first 112 voxels along each of the last three axes)
data = image_data[:, :, :112, :112, :112]
unet3d_y = unet3d(data)
# Display the output shape as the cell result
unet3d_y.shape

In [None]:
from src.model.models import ResNet3D

# Build the project's 3D ResNet backbone.
# NOTE(review): the argument 1 is presumably the input channel count -- confirm.
resnet3d = ResNet3D(1)

# Run 3D ResNet18 on [1, 1, 64, 64, 64] volume
data = image_data[:, :, :64, :64, :64]
resnet3d_y = resnet3d(data)

# The backbone output is iterable; show the shape of every element it returns
[feature_map.shape for feature_map in resnet3d_y]

In [None]:
from src.model.models import PyramidFeatures3D

# Build the 3D feature pyramid.
# NOTE(review): 128/256/512 are presumably the channel counts of the backbone
# feature maps -- confirm against PyramidFeatures3D's signature.
fpn3d = PyramidFeatures3D(128, 256, 512)

# Feed it the feature maps produced by the ResNet3D cell. The original code
# referenced `resnet183d_y`, which is never defined and raised a NameError;
# the backbone output is bound to `resnet3d_y`.
fpn3d_y = fpn3d(*resnet3d_y)

# Show the shape of every pyramid level
[e.shape for e in fpn3d_y]

In [None]:
from src.model.models import RegressionBlock3D

# Box-regression head (constructor argument 256 -- presumably the FPN channel
# count; confirm against RegressionBlock3D).
regressgion_block = RegressionBlock3D(256)

# Apply the same head to every pyramid level, in order, and join the
# per-level outputs along dimension 1.
regressgion_block_y = torch.cat(tuple(map(regressgion_block, fpn3d_y)), dim=1)
regressgion_block_y.shape

In [None]:
from src.model.models import ClassificationBlock3D

# Classification head (constructor argument 256 -- presumably the FPN channel
# count; confirm against ClassificationBlock3D).
classification_block = ClassificationBlock3D(256)

# Apply the same head to every pyramid level, in order, and join the
# per-level outputs along dimension 1.
classification_block_y = torch.cat(tuple(map(classification_block, fpn3d_y)), dim=1)
classification_block_y.shape

In [None]:
from src.model.anchors import Anchors3D

# Build the anchor generator with its default settings and call it on the
# same 64^3 crop used for the backbone.
anchors = Anchors3D()
data = image_data[:, :, :64, :64, :64]
anchors_y = anchors(data)
# Display the shape of the generated anchors
anchors_y.shape

In [None]:
from src.model.models import RetinaNet3D

# Full detector (first argument 1 -- presumably input channels; confirm).
retinanet = RetinaNet3D(1, num_classes=1)

# Forward pass on the first 64 voxels of each spatial axis
data = image_data[:, :, :64, :64, :64]
retinanet_y = retinanet(data)

# The model output is iterable; show the shape of each returned element
[output.shape for output in retinanet_y]

In [None]:
from src.model.modules import RetinaNetLoss, BoxLabelEncoder


# Smoke test of the label encoder and loss on random tensors
encoder = BoxLabelEncoder()
criterion = RetinaNetLoss()

# Random stand-ins for network outputs: batch of 8, 70215 entries per sample,
# 6 box values and 3 class scores per entry
pred_box = torch.rand((8, 70215, 6))
pred_cls = torch.rand((8, 70215, 3))

# Random ground truth: 5 boxes per sample with values in [0, 64)
gt_box = torch.rand((8, 5, 6)) * 64
# NOTE(review): torch.rand yields values in [0, 1), so .long() truncates every
# label to 0 -- confirm that all-zero class labels are what this test intends.
gt_cls = torch.rand((8, 5, 1)).long()

# Encode the ground truth into the form the criterion is called with
gt_box, gt_cls = encoder.encode(gt_box, gt_cls)

# The criterion returns a box loss and a class loss; show the mean of their sum
box_loss, cls_loss = criterion(pred_box, pred_cls, gt_box, gt_cls)
(box_loss + cls_loss).mean()

## Training a RetinaNet on two volumes

In [None]:
import json
from torch.nn.functional import interpolate

# Box annotations loaded from JSON; indexed below by scan id
with open("data_dev/train/train.json", "r") as f:
    train_boxes = json.load(f)

# Load images and corresponding box labels
# (read_image returns (data, header); only the data is kept)
x0 = torch.tensor(read_image("data_dev/train/images/RibFrac421-image.nii.gz")[0])
x1 = torch.tensor(read_image("data_dev/train/images/RibFrac422-image.nii.gz")[0])
y0 = torch.tensor(train_boxes["RibFrac421"])
y1 = torch.tensor(train_boxes["RibFrac422"])

# Convert box labels to be relative
def convert_box_format(box, shape):
    """Scale absolute box values to fractions of the volume size.

    Columns 0 and 3 are divided by ``shape[0]``, columns 1 and 4 by
    ``shape[1]`` and columns 2 and 5 by ``shape[2]``. Any columns beyond the
    first six are returned unchanged.

    The input tensor is not modified: a new tensor is returned. Integer
    tensors are accepted and converted to the default floating-point dtype
    (in-place true division on an integer tensor would raise).

    Args:
        box: 2D tensor with one box per row and at least six columns.
        shape: Sequence whose first three entries are the volume sizes used
            as divisors (e.g. a tensor's ``.shape``).

    Returns:
        A floating-point tensor with the same shape as ``box``.
    """
    dtype = box.dtype if box.is_floating_point() else torch.get_default_dtype()
    # copy=True guarantees a fresh tensor even when no dtype change is needed
    relative = box.to(dtype, copy=True)
    # Per-column divisors for columns 0..5: (s0, s1, s2, s0, s1, s2)
    extent = torch.tensor(
        [shape[0], shape[1], shape[2]] * 2, dtype=dtype, device=box.device
    )
    relative[:, :6] /= extent
    return relative

# Make the boxes relative to the volume size, rescale them to the 64^3 grid
# the volumes are downsampled to below, truncate to integers and add a batch
# dimension.
y0 = (convert_box_format(y0.float(), x0.shape) * 64).long().unsqueeze(0)
y1 = (convert_box_format(y1.float(), x1.shape) * 64).long().unsqueeze(0)

# Add padding to smaller labels.
# The amount of zero padding is derived from the data instead of being
# hard-coded, so the cell keeps working when the scans have a different number
# of boxes. Padding rows are float zeros, hence the explicit float cast.
num_boxes = max(y0.shape[1], y1.shape[1])
y0 = torch.cat([y0.float(), torch.zeros(1, num_boxes - y0.shape[1], y0.shape[2])], dim=1)
y1 = torch.cat([y1.float(), torch.zeros(1, num_boxes - y1.shape[1], y1.shape[2])], dim=1)

# Downsample input volumes
x0 = interpolate(x0.unsqueeze(0).unsqueeze(0), size=(64, 64, 64))
x1 = interpolate(x1.unsqueeze(0).unsqueeze(0), size=(64, 64, 64))

# Create training batches
x_batch = torch.cat([x0, x1], dim=0)
y_batch_box = torch.cat([y0, y1], dim=0)
# NOTE(review): every box, including the all-zero padding rows, gets class
# label 1 -- confirm the encoder treats padded rows as intended.
y_batch_cls = torch.ones(2, num_boxes, 1)

print(x_batch.shape, y_batch_box.shape, y_batch_cls.shape)


In [None]:
from src.model.models import RetinaNet3D
from src.model.modules import RetinaNetLoss, BoxLabelEncoder

# Encode labels
# NOTE(review): this rebinds y_batch_box / y_batch_cls to their encoded form,
# so re-running this cell would pass already-encoded labels to the encoder;
# re-run the batch-building cell first.
encoder = BoxLabelEncoder()
y_batch_box, y_batch_cls = encoder.encode(y_batch_box, y_batch_cls)

# Initialize model
model = RetinaNet3D(1, num_classes=1)

# Initialize optimizer and criterion
criterion = RetinaNetLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)

In [None]:
# Cast the batch to float32 once: it does not change between iterations, so
# repeating the conversion in each of the 10000 steps is wasted work.
x_batch_f32 = x_batch.float()

# Repeatedly fit the same two-volume batch
for i in range(10000):
    optimizer.zero_grad()
    y_hat_box, y_hat_cls = model(x_batch_f32)

    box_loss, cls_loss = criterion(y_hat_box, y_hat_cls, y_batch_box, y_batch_cls)

    # Optimise the mean of the summed box and class losses
    loss = (box_loss + cls_loss).mean()
    loss.backward()
    optimizer.step()

    # Log the two loss components separately
    print(box_loss.mean().item(), cls_loss.mean().item())

In [None]:
# Persist the trained weights to the working directory
torch.save(model.state_dict(), "dev_model.pt")
# To reuse the saved weights instead of retraining, run this line instead:
# model.load_state_dict(torch.load("dev_model.pt"))

In [None]:
from src.model.modules import BoxLabelDecoder

# Decoder used below on both the model outputs and the encoded targets
decoder = BoxLabelDecoder()

In [None]:
# Predict on the training batch without tracking gradients, then decode the
# raw outputs with the decoder.
# NOTE(review): model.eval() is not called anywhere in the visible notebook;
# if the model contains batch-norm or dropout layers this runs them in
# training mode -- confirm that is intended.
with torch.no_grad():
    y_hat_box, y_hat_cls = model(x_batch.float())
    y_hat_box, y_hat_cls = decoder.decode(y_hat_box, y_hat_cls)

In [None]:
# Highest-scoring class index per prediction
y_hat_cls_idx = y_hat_cls.argmax(dim=-1)
# Keep, per sample, only the boxes whose predicted class index is 2
# NOTE(review): index 2 is presumably the foreground class in the decoded
# label space -- confirm against BoxLabelDecoder.
boxes0 = y_hat_box[0, y_hat_cls_idx[0] == 2]
boxes1 = y_hat_box[1, y_hat_cls_idx[1] == 2]

In [None]:
# Decode the encoded target boxes with the same decoder used for predictions.
# The decoded class output is discarded; the targets are cloned before being
# passed in so that y_batch_box itself is not handed to the decoder.
boxes_tar, _ = decoder.decode(y_batch_box.clone(), y_hat_cls)

# Keep, per sample, only the target boxes whose encoded class label is 2.
# Each sample must be masked with its own labels: the original masked sample 1
# with the labels of sample 0 (copy-paste error).
boxes_tar0 = boxes_tar[0, y_batch_cls[0] == 2]
boxes_tar1 = boxes_tar[1, y_batch_cls[1] == 2]

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

def plot_3d_boxes(box, color):
    """Draw one box as a translucent cuboid on the module-level 3D axes ``ax``.

    Args:
        box: Iterable of six values unpacked as the corner position
            ``(x, y, z)`` followed by the extents ``(width, height, depth)``.
        color: Face colour passed to ``Poly3DCollection.set_facecolor``.

    Relies on a global ``ax`` (a 3D axes object) existing when called.
    """
    x_lo, y_lo, z_lo, size_x, size_y, size_z = box
    x_hi = x_lo + size_x
    y_hi = y_lo + size_y
    z_hi = z_lo + size_z

    # Eight corners: indices 0-3 lie on the z_lo plane, 4-7 on the z_hi plane
    corners = [
        [x_lo, y_lo, z_lo],
        [x_hi, y_lo, z_lo],
        [x_hi, y_hi, z_lo],
        [x_lo, y_hi, z_lo],
        [x_lo, y_lo, z_hi],
        [x_hi, y_lo, z_hi],
        [x_hi, y_hi, z_hi],
        [x_lo, y_hi, z_hi],
    ]

    # Corner indices of the six faces of the bounding box
    face_corner_ids = (
        (0, 1, 2, 3),
        (4, 5, 6, 7),
        (0, 1, 5, 4),
        (2, 3, 7, 6),
        (1, 2, 6, 5),
        (4, 7, 3, 0),
    )
    faces = [[corners[idx] for idx in ids] for ids in face_corner_ids]

    # Build the translucent cuboid and colour it
    cuboid = Poly3DCollection(faces, alpha=0.25, linewidths=0)
    cuboid.set_facecolor(color)

    # Attach it to the shared 3D axes
    ax.add_collection3d(cuboid)

    # Axis labels are (re)applied on every call
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")

In [None]:
# Create the figure and the 3D axes; plot_3d_boxes draws on this global `ax`
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection="3d")

# Target boxes of the second sample in blue
for box in boxes_tar1:
    plot_3d_boxes(box, "b")

# Predicted boxes of the second sample in red
for box in boxes1:
    plot_3d_boxes(box, "r")


# Set plot limits (adjust as needed)
# 64 matches the size the input volumes were downsampled to
ax.set_xlim(0, 64)
ax.set_ylim(0, 64)
ax.set_zlim(0, 64)

plt.show()