In [1]:
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import math
import os

In [2]:
# --- 1. Configuration -----------------------------
# Prefer the second GPU when the host has one (the original choice); otherwise fall
# back to the first GPU, and finally to the CPU. Hard-coding 'cuda:1' crashes on
# single-GPU machines as soon as a tensor is moved to DEVICE.
if torch.cuda.device_count() > 1:
    DEVICE = torch.device('cuda:1')
elif torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
DATA_DIR = './images_4k'
NUM_PROJS = 50
PROJ_SIZE = 4096    # native projection resolution (informational only; images are resized on load)
REC_SIZE = 1024     # projections are downsampled to REC_SIZE x REC_SIZE for reconstruction
OUT_SIZE = 1024     # output video frame resolution
# NUM_PROJS angles evenly spaced from -72 to +75 degrees (inclusive), converted to radians.
# Assumed to be in the same order as the files 1.png ... NUM_PROJS.png.
ANGLES = torch.linspace(-72, 75, steps=NUM_PROJS).mul(math.pi / 180).to(DEVICE)

# Rotation axis for the MIP fly-around, given as (x, y, z) components.
# The camera looks down -Z, so an X or Y axis keeps the rotation perpendicular to the view:
# e.g. (1,0,0) for X, (0,1,0) for Y, (0,0,1) for Z, or any arbitrary (non-zero) axis.
ROT_AXIS = torch.tensor((0, 1, 0), dtype=torch.float32, device=DEVICE)  # change as needed

In [3]:
# --- 2. Load & Downsample Projections -------------
# Projections live in DATA_DIR as 1.png ... NUM_PROJS.png. Each one is read as 8-bit
# grayscale, bilinearly resized to REC_SIZE x REC_SIZE, converted to float32, and the
# whole set is stacked into a single tensor on DEVICE.
files = [os.path.join(DATA_DIR, f"{i}.png") for i in range(1, NUM_PROJS+1)]


def _read_projection(path):
    """Return the projection image at `path` as a (REC_SIZE, REC_SIZE) float32 CPU tensor."""
    gray = Image.open(path).convert('L').resize((REC_SIZE, REC_SIZE), Image.BILINEAR)
    return torch.from_numpy(np.array(gray, dtype=np.float32))


projs = torch.stack([_read_projection(path) for path in files], dim=0).to(DEVICE)  # (N, REC_SIZE, REC_SIZE)

In [4]:
# --- 3. Ramp Filter --------------------------------
def ramp_filter(proj, pad=True):
    """
    Apply a Ram-Lak (ramp, |f|) filter along the last dimension of `proj`.

    Every 1-D signal along the last dimension (one detector row per angle) is
    filtered independently; leading dimensions are treated as a batch.

    Args:
        proj: real tensor of rank >= 1.
        pad: if True (default), zero-pad each row to the next power of two that is
            >= 2 * N before the FFT. The frequency-domain product then acts as a
            linear convolution instead of a circular one, so intensity near one
            detector edge does not wrap around to the other edge. Pass False for
            unpadded filtering.

    Returns:
        Contiguous tensor with the same shape as `proj`, on the same device.
    """
    N = proj.shape[-1]
    n_fft = 2 ** math.ceil(math.log2(2 * N)) if pad else N
    # Build the filter directly on the input's device; rfftfreq is already >= 0,
    # so |f| needs no abs().
    ramp = torch.fft.rfftfreq(n_fft, d=1.0, device=proj.device)
    # rfft with n > N zero-pads at the end; the ramp broadcasts over all leading
    # dims, so 1-D input keeps its (N,) shape.
    spectrum = torch.fft.rfft(proj, n=n_fft, dim=-1)
    filtered = torch.fft.irfft(spectrum * ramp, n=n_fft, dim=-1)
    # Drop the padding; .contiguous() keeps later .view() calls valid.
    return filtered[..., :N].contiguous()

In [5]:
# --- 4. Backprojection -----------------------------
def backproject_slice(z_idx, projs, angles):
    """
    Reconstruct one slice by ramp-filtered backprojection.

    Args:
        z_idx: row index into the projections; the slice's sinogram is
            projs[:, z_idx, :].
        projs: (N_angles, H, W) tensor of projection images.
        angles: 1-D tensor with one projection angle (radians) per entry of
            projs, in the same order.

    Returns:
        (W, W) tensor holding the reconstructed slice, scaled by
        pi / (2 * N_angles).

    Raises:
        ValueError: if `angles` is empty or its length differs from projs.shape[0]
            (zip would otherwise silently drop the extra projections/angles).
    """
    n_angles = angles.numel()
    if n_angles == 0:
        raise ValueError("angles must not be empty")
    if n_angles != projs.shape[0]:
        raise ValueError(
            f"expected {projs.shape[0]} angles (one per projection), got {n_angles}"
        )

    sinogram = projs[:, z_idx, :]          # (N_angles, W)
    filt = ramp_filter(sinogram)
    W = filt.shape[-1]
    device, dtype = filt.device, filt.dtype

    # Pixel-centre coordinates of the reconstruction grid, centred on the origin.
    coords = torch.linspace(-(W - 1) / 2, (W - 1) / 2, W, device=device, dtype=dtype)
    X, Y = torch.meshgrid(coords, coords, indexing='xy')
    recon = torch.zeros((W, W), device=device, dtype=dtype)

    for cos_t, sin_t, proj_f in zip(angles.cos(), angles.sin(), filt):
        # Detector coordinate hit by each reconstruction pixel at this angle.
        t = X * cos_t + Y * sin_t
        # Detector pixel i sits at t = i - (W-1)/2. With align_corners=False,
        # grid_sample places pixel i at normalized (2i + 1)/W - 1, which equals 2t/W.
        # (The previous 2t/(W-1) mapping belongs to align_corners=True and
        # stretched the detector by W/(W-1).)
        grid_x = t * (2.0 / W)
        grid = torch.stack([grid_x, torch.zeros_like(grid_x)], dim=-1).unsqueeze(0)
        sampled = F.grid_sample(proj_f.reshape(1, 1, 1, W), grid, align_corners=False)
        recon += sampled.view(W, W)
    return recon * (math.pi / (2 * n_angles))

In [6]:
# --- 5. Reconstruct Volume ------------------------
# Preallocate the (Z, Y, X) volume and fill it one slice at a time. Building a list
# of slices and then torch.stack-ing it holds every slice twice (list + stacked
# copy), which doubles peak memory for a 1024^3 float32 volume (~4 GB -> ~8 GB).
# The slice count comes from the projections themselves (one slice per image row).
n_slices, det_width = projs.shape[1], projs.shape[2]
volume = torch.empty((n_slices, det_width, det_width), dtype=projs.dtype, device=projs.device)  # (Z, Y, X)
for z in range(n_slices):
    volume[z] = backproject_slice(z, projs, ANGLES)

In [7]:
# --- 6. Generic MIP Rendering ----------------------
def render_mip_axis(vol, yaw, axis_vec, chunk_size=32):
    """
    Rotate `vol` about `axis_vec` by `yaw` radians, then take the maximum-intensity
    projection along the volume's first dimension (the camera looks down -Z).

    Args:
        vol: 3-D float tensor indexed (Z, Y, X).
        yaw: rotation angle in radians.
        axis_vec: 3 components (tensor or sequence) of the rotation axis in
            (x, y, z) order, where x runs along vol dim 2, y along dim 1 and
            z along dim 0. Need not be normalised, but must be non-zero.
        chunk_size: number of depth slices resampled at once. This bounds peak
            memory (the full Z*Y*X*3 grid is never materialised) and does not
            change the result, because the max is accumulated incrementally.

    Returns:
        (Y, X) tensor: for each ray, the max of the rotated volume along Z.

    Raises:
        ValueError: if axis_vec does not have 3 components or is the zero vector,
            or if chunk_size < 1.

    Notes:
        Output voxel p takes the value of `vol` at R @ p, so the content appears
        rotated by R^-1. Coordinates are normalised to [-1, 1] per axis, so a
        non-cubic volume is stretched to a cube before rotating.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    axis = [float(v) for v in torch.as_tensor(axis_vec).reshape(-1).tolist()]
    if len(axis) != 3:
        raise ValueError(f"axis_vec must have 3 components, got {len(axis)}")
    norm = math.sqrt(sum(v * v for v in axis))
    if norm == 0.0:
        raise ValueError("axis_vec must be a non-zero vector")
    ux, uy, uz = (v / norm for v in axis)

    # Rodrigues rotation matrix, computed from Python floats (no device syncs
    # from 0-d tensors).
    c, s = math.cos(yaw), math.sin(yaw)
    C = 1.0 - c
    device, dtype = vol.device, vol.dtype
    R = torch.tensor([
        [c + ux*ux*C,    ux*uy*C - uz*s, ux*uz*C + uy*s],
        [uy*ux*C + uz*s, c + uy*uy*C,    uy*uz*C - ux*s],
        [uz*ux*C - uy*s, uz*uy*C + ux*s, c + uz*uz*C],
    ], device=device, dtype=dtype)

    Z, Y, X = vol.shape
    zs = torch.linspace(-1, 1, Z, device=device, dtype=dtype)
    ys = torch.linspace(-1, 1, Y, device=device, dtype=dtype)
    xs = torch.linspace(-1, 1, X, device=device, dtype=dtype)
    yy, xx = torch.meshgrid(ys, xs, indexing='ij')   # (Y, X), shared by every depth slice
    vol_5d = vol.unsqueeze(0).unsqueeze(0)           # (1, 1, Z, Y, X) = (N, C, D, H, W)

    mip = torch.full((Y, X), float('-inf'), device=device, dtype=dtype)
    for z0 in range(0, Z, chunk_size):
        zc = zs[z0:z0 + chunk_size]
        n = zc.numel()
        # 5-D grid_sample reads grid[..., 0] as x (-> W), [..., 1] as y (-> H) and
        # [..., 2] as z (-> D), so the components must be stacked in (x, y, z) order.
        grid = torch.stack([
            xx.expand(n, Y, X),
            yy.expand(n, Y, X),
            zc.view(n, 1, 1).expand(n, Y, X),
        ], dim=-1)                                    # (n, Y, X, 3)
        grid = grid @ R.T                             # rotate every sample point
        # linspace(-1, 1, N) lands exactly on voxel centres when align_corners=True.
        sampled = F.grid_sample(vol_5d, grid.unsqueeze(0), mode='bilinear', align_corners=True)
        mip = torch.maximum(mip, sampled[0, 0].amax(dim=0))
    return mip

In [8]:
# --- 7. Save Frames -------------------------------
# Render one full turn (n_frames evenly spaced yaw angles) around ROT_AXIS and write
# each MIP as an 8-bit PNG at the configured output resolution OUT_SIZE.
n_frames = 180
os.makedirs('frames', exist_ok=True)
for i in range(n_frames):
    angle = 2 * math.pi * i / n_frames
    mip = render_mip_axis(volume, angle, ROT_AXIS).cpu().numpy()
    # Contrast: stretch the 1st..99th percentile range to [0, 1]; fall back to a
    # min/max stretch when the percentiles coincide (e.g. a nearly flat image).
    # NOTE(review): percentiles are recomputed per frame, which can cause brightness
    # flicker in the video; consider fixed (global) limits if that is visible.
    p1, p99 = np.percentile(mip, (1, 99))
    if p99 > p1:
        mip_norm = np.clip((mip - p1) / (p99 - p1), 0, 1)
    else:
        mip_norm = (mip - mip.min()) / (mip.max() - mip.min() + 1e-6)
    img = (mip_norm * 255).astype('uint8')
    # Resize to OUT_SIZE (the output frame resolution), not REC_SIZE.
    Image.fromarray(img).resize((OUT_SIZE, OUT_SIZE)).save(f'frames/frame_{i:03d}.png')

print("Done.")


Done.
