In [None]:
# Imports
import torch
from lighter_zoo import SegResEncoder
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground
)
from monai.inferers import SlidingWindowInferer
import nibabel as nib
import numpy as np

In [None]:
# Load the pre-trained feature extractor from a local checkpoint.
#
# Alternative kept for reference: build the model with `from_pretrained` instead
# of reading a local file.
# model = SegResEncoder.from_pretrained(
#     "project-lighter/ct_fm_feature_extractor"
# )

# SECURITY: the checkpoint is unpickled as a complete object (`.eval()` is called
# on the result below), and unpickling can execute arbitrary code. Only load
# checkpoint files you trust.
#   * weights_only=False is spelled out because newer PyTorch releases default to
#     weights_only=True, which refuses to unpickle a full module.
#   * map_location="cpu" keeps the model on the CPU no matter where it was saved,
#     matching the input tensor, which this notebook never moves off the CPU.
model = torch.load(
    'checkpoints/ct_fm_feature_extractor.pth',
    map_location="cpu",
    weights_only=False,
)
model.eval()

In [None]:
# Preprocessing pipeline (transforms are applied in the order listed)
preprocess = Compose(
    transforms=[
        # Read the image and put the channel dimension first
        LoadImage(ensure_channel_first=True),
        # Ensure correct data type
        EnsureType(),
        # Reorient every volume to the "SPL" axis convention
        Orientation(axcodes="SPL"),
        # Rescale intensities from [-1024, 2048] (HU) to [0, 1]; values outside
        # the source range are clipped
        ScaleIntensityRange(a_min=-1024, a_max=2048, b_min=0, b_max=1, clip=True),
        # Crop away background to reduce computation
        CropForeground(),
    ]
)

In [None]:
# Input path
# NOTE(review): absolute, machine-specific path — point this at your own volume
# before running on another machine.
input_path = "/home/vlab/Documents/Collections/Dental/STS-Tooth/SD-Tooth/STS-3D-Tooth/Integrity/Labeled/Image/Integrity_L_004.nii.gz"

# Preprocess input
input_tensor = preprocess(input_path)

# Send the input to whatever device the model's weights live on, so the forward
# pass does not fail with a device mismatch when the model sits on a GPU.
# A model without parameters leaves the input where it already is.
first_param = next(model.parameters(), None)
device = first_param.device if first_param is not None else input_tensor.device
batch = input_tensor.unsqueeze(0).to(device)  # add a leading batch dimension

# Run inference (no gradients are needed for feature extraction)
with torch.no_grad():
    # The model output is indexed as a sequence; keep only its last element.
    output = model(batch)[-1]

    # Average pooling compresses the feature map into a single feature vector across all patches.
    # If this is not desired, remove this line and use the output tensor directly, which gives
    # you the feature maps in a low-dimensional space.
    avg_output = torch.nn.functional.adaptive_avg_pool3d(output, 1).squeeze()

print("✅ Feature extraction completed")
print(f"Output shape: {avg_output.shape}")


In [1]:
# Open the same input file directly with nibabel, bypassing the `preprocess` pipeline.
# Depends on earlier cells: `nib` comes from the imports cell and `input_path` from
# the inference cell, so run those first (a fresh kernel raises NameError otherwise).
nii = nib.load(input_path)
# Voxel array of the unprocessed image.
# NOTE(review): `get_fdata` presumably returns floating-point data with any header
# scaling applied rather than the stored on-disk values — confirm if "raw" matters.
data = nii.get_fdata()

NameError: name 'nib' is not defined

In [None]:
# Plot distribution of features
import matplotlib.pyplot as plt

# Use an explicit figure/axes pair and label the figure so it can be read on its own.
fig, ax = plt.subplots()
# `.cpu()` brings the pooled features back to host memory before converting to NumPy.
ax.hist(avg_output.cpu().numpy(), bins=100)
ax.set(
    title="Distribution of extracted feature values",
    xlabel="Feature value",
    ylabel="Count",
)
plt.show()