In [111]:
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import math
import os

In [113]:
# --- 1. Configuration -----------------------------
# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

NUM_PROJS = 50                     # number of projection images to load
DATA_DIR = 'directory_to_images'   # folder holding the overlay_<i>.png projections
PROJ_SIZE = 512                    # nominal source-projection resolution (not referenced elsewhere in this notebook)
REC_SIZE = 512                     # projections are resized to REC_SIZE x REC_SIZE for reconstruction
OUT_SIZE = 512                     # side length (pixels) of each saved video frame

# NUM_PROJS evenly spaced tilt angles from -72 deg to +75 deg, converted to radians.
ANGLES = (torch.linspace(-72, 75, steps=NUM_PROJS) * (math.pi / 180)).to(DEVICE)

# Rotation axis for the MIP fly-around, given as a 3-vector (it is normalised
# inside render_mip_axis): e.g. [1, 0, 0], [0, 1, 0], [0, 0, 1], or any direction.
ROT_AXIS = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float32, device=DEVICE)

In [114]:
# --- 2. Load & Resize Projections (RGB) -----------
def _load_projection(fp):
    """Read one image file as an RGB float32 tensor of shape (3, REC_SIZE, REC_SIZE) in [0, 1]."""
    # The context manager closes the underlying file handle once the pixels are decoded.
    with Image.open(fp) as im:
        rgb = im.convert('RGB').resize((REC_SIZE, REC_SIZE), Image.BILINEAR)
    arr = np.asarray(rgb, dtype=np.float32) / 255.0   # (H, W, 3), normalised to [0, 1]
    return torch.from_numpy(arr).permute(2, 0, 1)     # (3, H, W)


files = [os.path.join(DATA_DIR, f"overlay_{i}.png") for i in range(1, NUM_PROJS + 1)]
# projs shape: (NUM_PROJS, 3, REC_SIZE, REC_SIZE)
projs = torch.stack([_load_projection(fp) for fp in files], dim=0).to(DEVICE)
assert projs.shape == (NUM_PROJS, 3, REC_SIZE, REC_SIZE), tuple(projs.shape)

In [115]:
# --- 3. Ramp Filter --------------------------------
def ramp_filter_1d(signal, pad=True):
    """Apply a ramp (|f|) filter along the last dimension of ``signal``.

    Args:
        signal: real tensor of shape (..., W). Leading dimensions are treated
            as a batch, so a whole (angles, channels, W) sinogram can be
            filtered in one call.
        pad: if True (default), zero-pad to the smallest power of two >= 2*W
            before the FFT. Multiplying un-padded spectra is a *circular*
            convolution that wraps the filter response around the detector
            edges; padding turns it into a linear convolution. Pass False to
            reproduce the un-padded behaviour.

    Returns:
        Filtered tensor of shape (..., W) on the same device as ``signal``.
    """
    n = signal.shape[-1]
    n_fft = (1 << (2 * n - 1).bit_length()) if pad else n
    # Build the filter on the signal's own device rather than a global one.
    freqs = torch.fft.rfftfreq(n_fft, d=1.0, device=signal.device)
    spectrum = torch.fft.rfft(signal, n=n_fft, dim=-1)   # rfft zero-pads to n_fft
    filtered = torch.fft.irfft(spectrum * freqs.abs(), n=n_fft, dim=-1)
    return filtered[..., :n]   # drop the padding

In [116]:
# --- 4. Backprojection -----------------------------
def backproject_slice(z_idx, projections=None, angles=None):
    """Filtered back-projection of one detector row into a 2-D multi-channel slice.

    Args:
        z_idx: row index into the projection images; row ``z_idx`` of every
            projection forms the sinogram of this slice.
        projections: tensor of shape (num_angles, C, H, W). Defaults to the
            notebook-global ``projs``.
        angles: 1-D tensor (or sequence) of projection angles in radians, one
            per projection. Defaults to the notebook-global ``ANGLES``.

    Returns:
        Tensor of shape (C, W, W) on the same device as ``projections``.

    Raises:
        ValueError: if the number of angles differs from the number of projections.
    """
    if projections is None:
        projections = projs
    if angles is None:
        angles = ANGLES
    # Python floats: avoids a device sync per angle inside the loop.
    angle_list = torch.as_tensor(angles).flatten().tolist()

    sinogram = projections[:, :, z_idx, :]              # (num_angles, C, W)
    num_angles, n_chan, width = sinogram.shape
    if len(angle_list) != num_angles:
        raise ValueError(
            f"got {len(angle_list)} angles for {num_angles} projections")
    device = sinogram.device

    # The ramp filter acts on the last (detector) axis, so all angles and
    # channels are filtered in one batched call instead of once per (angle, channel).
    filtered = ramp_filter_1d(sinogram)                 # (num_angles, C, W)

    # Pixel-centre coordinates of the reconstruction grid, centred on 0.
    half = (width - 1) / 2
    coords = torch.linspace(-half, half, width, device=device)
    X, Y = torch.meshgrid(coords, coords, indexing='xy')
    zeros = torch.zeros_like(X)

    recon = torch.zeros((n_chan, width, width), dtype=filtered.dtype, device=device)
    for i, theta in enumerate(angle_list):
        # Detector position hit by each pixel in this view, normalised so that
        # -1 / +1 are the centres of the first / last detector pixels -- the
        # align_corners=True convention used in grid_sample below.
        t = X * math.cos(theta) + Y * math.sin(theta)
        x_norm = t / half if half > 0 else zeros
        # grid_sample reads grid[..., 0] as x (the width / detector axis) and
        # grid[..., 1] as y (the singleton height axis, so it stays at 0).
        grid = torch.stack([x_norm, zeros], dim=-1).unsqueeze(0)   # (1, W, W, 2)
        line = filtered[i].unsqueeze(0).unsqueeze(2)               # (1, C, 1, W)
        sampled = F.grid_sample(line, grid, mode='bilinear', align_corners=True)
        recon += sampled[0]                                        # (C, W, W)

    # Overall scale kept from the original implementation.
    recon *= math.pi / (2 * num_angles)
    return recon

In [117]:
# --- 5. Reconstruct Volume ------------------------
# Each detector row z is back-projected into one (3, W, W) slice. Stacking the
# REC_SIZE slices gives (Z, 3, Y, X), which is re-ordered to the 5-D
# (N, C, D, H, W) layout that F.grid_sample expects.
vol_slices = [backproject_slice(z) for z in range(REC_SIZE)]
vol_stack = torch.stack(vol_slices, dim=0)            # (Z, 3, Y, X)
volume = vol_stack.permute(1, 0, 2, 3).unsqueeze(0)   # (1, 3, Z, Y, X)
assert volume.shape == (1, 3, REC_SIZE, REC_SIZE, REC_SIZE), tuple(volume.shape)

In [118]:
# --- 6. 3D MIP Rendering on Arbitrary Axis --------
def render_mip_axis(vol5d, yaw, axis_vec):
    """Maximum-intensity projection of a volume resampled on a rotated grid.

    The normalised sampling grid is rotated by ``yaw`` radians about
    ``axis_vec`` (Rodrigues' formula), the volume is resampled trilinearly
    (zeros outside), and the maximum is taken over the depth axis.

    Args:
        vol5d: tensor of shape (1, C, Z, Y, X).
        yaw: rotation angle in radians.
        axis_vec: 3-element rotation axis (tensor or sequence); it does not
            need to be unit length.

    Returns:
        Tensor of shape (C, Y, X).

    Raises:
        ValueError: if ``axis_vec`` does not have 3 elements or has zero length.
    """
    axis = torch.as_tensor(axis_vec, dtype=torch.float64).flatten()
    if axis.numel() != 3:
        raise ValueError(f"axis_vec must have 3 elements, got {axis.numel()}")
    norm = axis.norm().item()
    if norm == 0.0:
        raise ValueError("axis_vec must be non-zero")
    # Python floats, so R is built on the host rather than from 0-d device tensors.
    ux, uy, uz = (axis / norm).tolist()
    c, s = math.cos(yaw), math.sin(yaw)
    k = 1.0 - c
    # Rodrigues rotation matrix for angle `yaw` about the unit axis (ux, uy, uz).
    R = torch.tensor([
        [c + ux * ux * k,      ux * uy * k - uz * s, ux * uz * k + uy * s],
        [uy * ux * k + uz * s, c + uy * uy * k,      uy * uz * k - ux * s],
        [uz * ux * k - uy * s, uz * uy * k + ux * s, c + uz * uz * k],
    ], dtype=vol5d.dtype, device=vol5d.device)

    _, n_chan, depth, height, width = vol5d.shape
    # NOTE(review): the original formulation rotates normalised coordinates built
    # in (z, y, x) order, but grid_sample interprets grid[..., :3] as (x, y, z),
    # i.e. (W, H, D). Consequently at yaw=0 the volume's D and W axes are swapped,
    # and the MIP below (over the output depth axis) runs along the input W axis.
    # The ROT_AXIS components follow this swapped convention. Preserved on purpose;
    # confirm the intended view direction before changing it.
    #
    # affine_grid's base coordinates are (x, y, z, 1) on linspace(-1, 1) per axis
    # (align_corners=True), so reversing R's columns reproduces R @ (z, y, x)
    # without materialising separate meshgrid / coordinate / rotated tensors.
    theta = torch.zeros((1, 3, 4), dtype=vol5d.dtype, device=vol5d.device)
    theta[0, :, :3] = R.flip(-1)
    grid = F.affine_grid(theta, [1, n_chan, depth, height, width], align_corners=True)
    sampled = F.grid_sample(vol5d, grid, mode='bilinear', align_corners=True)
    mip = sampled.amax(dim=2)   # (1, C, Y, X): max over the depth axis
    return mip[0]

In [119]:
# --- 7. Save Frames & Upsample Output ------------
def _stretch_channel(chan):
    """Contrast-stretch a 2-D array to [0, 1] between its 1st and 99th percentiles.

    When those percentiles do not increase (e.g. a nearly flat channel), fall
    back to a min/max stretch; the 1e-6 term avoids dividing by zero.
    """
    p_lo, p_hi = np.percentile(chan, (1, 99))
    if not p_hi > p_lo:
        return (chan - chan.min()) / (chan.max() - chan.min() + 1e-6)
    return np.clip((chan - p_lo) / (p_hi - p_lo), 0, 1)


n_frames = 180
os.makedirs('frames', exist_ok=True)
for i in range(n_frames):
    # One full turn about ROT_AXIS spread evenly over n_frames.
    angle = 2 * math.pi * i / n_frames
    mip = render_mip_axis(volume, angle, ROT_AXIS).cpu().numpy()   # (3, Y, X)
    img_accum = np.zeros_like(mip)
    for ch in range(3):
        img_accum[ch] = _stretch_channel(mip[ch])
    # CHW float in [0, 1] -> HWC uint8 for PIL, then resample to the output size.
    img = (np.clip(img_accum, 0, 1).transpose(1, 2, 0) * 255).astype('uint8')
    frame = Image.fromarray(img).resize((OUT_SIZE, OUT_SIZE), Image.BILINEAR)
    frame.save(f'frames/frame_{i:03d}.png')
print("Done.")

Done.
