# Spherical Bessel harmonic field slices for trabecular bone projections

This notebook loads the trabecular bone dataset used in `sici/bone-forward.ipynb`, converts each projection's angular description into a direction vector, and evaluates planar slices of a randomly initialised `SphericalBesselHarmonicField` along those directions.

In [1]:
import sys
from pathlib import Path

project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
print(f"Added {project_root} to Python path")

Added /Users/lfbarba/GitHub/smartTT to Python path


In [17]:
import math
import numpy as np
import torch
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, Markdown
import lovely_tensors as lt
lt.monkey_patch()

from mumott.data_handling import DataContainer
from pyTT import SphericalBesselHarmonicField

## Load trabecular bone dataset

In [3]:
data_path = project_root / "sici" / "trabecular_bone_9.h5"
dc = DataContainer(str(data_path))
projections = dc.projections
geometry = dc.geometry

num_projections = len(projections._projections)
sample_projection = projections._projections[0]

inner_angle_sample = getattr(sample_projection, "inner_angle", None)
outer_angle_sample = getattr(sample_projection, "outer_angle", None)
data_sample = None
for attr_name in ("data", "measured_data", "_data"):
    if hasattr(sample_projection, attr_name):
        data_sample = getattr(sample_projection, attr_name)
        break

print(f"Loaded data container from {data_path}")
print(f"Number of projections: {num_projections}")
if inner_angle_sample is not None and outer_angle_sample is not None:
    print(f"Example projection angles (inner, outer): ({float(inner_angle_sample):.3f}, {float(outer_angle_sample):.3f}) rad")
if data_sample is not None:
    try:
        print(f"Example projection data shape: {np.asarray(data_sample).shape}")
    except Exception:
        print("Could not infer projection data shape automatically.")
else:
    print("Projection data array attribute not detected on the sample object.")

INFO:Inner axis found in dataset base directory. This will override the default.
INFO:Outer axis found in dataset base directory. This will override the default.
INFO:Outer axis found in dataset base directory. This will override the default.
INFO:Rotation matrices were loaded from the input file.
INFO:Rotation matrices were loaded from the input file.
INFO:Sample geometry loaded from file.
INFO:Detector geometry loaded from file.
Loaded data container from /Users/lfbarba/GitHub/smartTT/sici/trabecular_bone_9.h5
Number of projections: 247
Example projection angles (inner, outer): (0.000, 0.000) rad
Example projection data shape: (65, 55, 8)
INFO:Sample geometry loaded from file.
INFO:Detector geometry loaded from file.
Loaded data container from /Users/lfbarba/GitHub/smartTT/sici/trabecular_bone_9.h5
Number of projections: 247
Example projection angles (inner, outer): (0.000, 0.000) rad
Example projection data shape: (65, 55, 8)


## Convert projection angles to direction vectors

In [4]:
def angles_to_unit_vector(inner_angle: float, outer_angle: float) -> np.ndarray:
    tilt = np.pi / 2.0 - outer_angle
    x = np.sin(tilt) * np.cos(inner_angle)
    y = np.sin(tilt) * np.sin(inner_angle)
    z = np.cos(tilt)
    vec = np.array([x, y, z], dtype=np.float64)
    norm = np.linalg.norm(vec)
    if norm == 0.0:
        raise ValueError("Encountered zero-length direction vector.")
    return vec / norm

direction_vectors = []
for idx, proj in enumerate(projections._projections):
    inner = float(getattr(proj, "inner_angle"))
    outer = float(getattr(proj, "outer_angle"))
    vec = angles_to_unit_vector(inner, outer)
    direction_vectors.append(vec)
direction_vectors = np.stack(direction_vectors)
direction_tensor = torch.from_numpy(direction_vectors).to(torch.float32)

print(f"Computed {direction_tensor.shape[0]} projection direction vectors.")
preview_vectors = [str(direction_vectors[i]) for i in range(min(3, len(direction_vectors)))]
display(Markdown("First three unit vectors:<br>" + "<br>".join(preview_vectors)))

Computed 247 projection direction vectors.


First three unit vectors:<br>[1.000000e+00 0.000000e+00 6.123234e-17]<br>[9.96917334e-01 7.84590957e-02 6.12323400e-17]<br>[9.87688341e-01 1.56434465e-01 6.12323400e-17]

## Instantiate spherical Bessel harmonic field

In [13]:
torch.manual_seed(0)
np.random.seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

volume_shape = getattr(geometry, "volume_shape", None)
if volume_shape is not None:
    volume_shape = np.asarray(volume_shape).astype(int)
    grid_dims = volume_shape
else:
    grid_dims = (12, 12, 12)

spacing_attr = None
for candidate in ("pixel_size", "voxel_size", "voxel_spacing"):
    if hasattr(geometry, candidate):
        spacing_attr = getattr(geometry, candidate)
        break
if spacing_attr is None:
    spacing = 1.0
else:
    spacing = tuple(np.asarray(spacing_attr).flatten()) if np.ndim(spacing_attr) else float(spacing_attr)

max_l = 8
num_radial = 11
grid_resolution = 96

field = SphericalBesselHarmonicField(
    dims=grid_dims,
    max_l=max_l,
    num_radial=num_radial,
    radius=1.0,
    coeff_init_scale=0.05,
    spacing=spacing,
).to(device)

with torch.no_grad():
    for l_idx, coeff in enumerate(field.coeffs):
        coeff.mul_(1.0 + 0.2 * (l_idx + 1))

print(f"Field dims: {field.dims}")
print(f"Number of spheres: {field.num_spheres}")
print(f"Device: {field.device}, dtype: {field.dtype}")
print(f"Grid resolution per slice: {grid_resolution}")

Field dims: (55, 65, 65)
Number of spheres: 232375
Device: cpu, dtype: torch.float32
Grid resolution per slice: 96


In [21]:
field.coeffs[0].reshape(*grid_dims, -1)

tensor[55, 65, 65, 11] n=2556125 (9.8Mb) x∈[-0.305, 0.290] μ=-4.170e-05 σ=0.060 grad ViewBackward0

In [23]:
import sys
sys.path.append('~/GitHub/astra-torch')
from astra_torch.lamino import fbp_reconstruction_masked

ModuleNotFoundError: No module named 'astra_torch'

In [14]:
total_params = 0
for x in field.parameters():
    total_params += x.numel()
print(f"Total number of learnable parameters in the field: {total_params}")
field

Total number of learnable parameters in the field: 207046224


SphericalBesselHarmonicField(
  (log_k_params): ParameterList(
      (0): Parameter containing: [torch.float32 of size 11]
      (1): Parameter containing: [torch.float32 of size 11]
      (2): Parameter containing: [torch.float32 of size 11]
      (3): Parameter containing: [torch.float32 of size 11]
      (4): Parameter containing: [torch.float32 of size 11]
      (5): Parameter containing: [torch.float32 of size 11]
      (6): Parameter containing: [torch.float32 of size 11]
      (7): Parameter containing: [torch.float32 of size 11]
      (8): Parameter containing: [torch.float32 of size 11]
  )
  (coeffs): ParameterList(
      (0): Parameter containing: [torch.float32 of size 232375x11x1]
      (1): Parameter containing: [torch.float32 of size 232375x11x3]
      (2): Parameter containing: [torch.float32 of size 232375x11x5]
      (3): Parameter containing: [torch.float32 of size 232375x11x7]
      (4): Parameter containing: [torch.float32 of size 232375x11x9]
      (5): Parameter 

## Evaluate slices along projection directions

In [5]:
direction_tensor = direction_tensor.to(device=device, dtype=field.dtype)

slice_images = []
with torch.no_grad():
    for idx, normal in enumerate(direction_tensor):
        slice_volume = field.slice(normal, grid_resolution=grid_resolution, flatten=True)
        slice_mean = slice_volume.mean(dim=0)
        slice_images.append(slice_mean)
        break
slice_stack = torch.stack(slice_images)

print(f"Slice stack shape: {tuple(slice_stack.shape)} (num_projections, grid, grid)")

NameError: name 'direction_tensor' is not defined

## Visualise sample slices

In [None]:
def plot_slice_grid(slice_tensor: torch.Tensor, num_examples: int = 9) -> None:
    slice_tensor = slice_tensor.cpu()
    num_examples = min(num_examples, slice_tensor.shape[0])
    cols = 3
    rows = math.ceil(num_examples / cols)
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows))
    axes = np.atleast_2d(axes)
    for idx in range(rows * cols):
        ax = axes[idx // cols, idx % cols]
        if idx < num_examples:
            im = ax.imshow(slice_tensor[idx].numpy(), cmap="magma", origin="lower")
            ax.set_title(f"Slice {idx}")
            ax.axis("off")
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        else:
            ax.axis("off")
    plt.tight_layout()
    plt.show()
plot_slice_grid(slice_stack, num_examples=9)

## Interactive explorer

In [None]:
slice_stack_cpu = slice_stack.cpu().numpy()
slider = widgets.IntSlider(min=0, max=slice_stack_cpu.shape[0] - 1, step=1, value=0, description="Index")

def update_slice(idx):
    plt.figure(figsize=(5, 5))
    plt.imshow(slice_stack_cpu[idx], cmap="magma", origin="lower")
    plt.title(f"Slice {idx}")
    plt.axis("off")
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.show()

interactive = widgets.interactive_output(lambda idx: update_slice(idx), {"idx": slider})
display(widgets.VBox([slider, interactive]))