In [7]:
# all from plenoxel.py
import jax
import jax.numpy as jnp
from jax import lax
from jax.ops import index, index_update
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
import os
import sh
from tqdm import tqdm
import imageio


In [2]:
def get_rays(H, W, focal, c2w):
    i, j = jnp.meshgrid(jnp.linspace(0, W-1, W) + 0.5, jnp.linspace(0, H-1, H) + 0.5) 
    dirs = jnp.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -jnp.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = jnp.sum(dirs[..., jnp.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    rays_o = jnp.broadcast_to(c2w[:3,-1], rays_d.shape)
    return rays_o, rays_d

In [3]:
@jax.partial(jax.jit, static_argnums=(2,4,5,6,7,8,9))
def render_rays(grid, rays, resolution, keys, radius=1.3, harmonic_degree=0, jitter=0, uniform=0, interpolation='trilinear', nv=False):
  sh_dim = (harmonic_degree + 1)**2
  voxel_len = radius * 2.0 / resolution
  assert (resolution // 2) * 2 == resolution # Renderer assumes resolution is a multiple of 2
  rays_o, rays_d = rays
  # Compute when the rays enter and leave the grid
  offsets_pos = jax.lax.stop_gradient((radius - rays_o) / rays_d)
  offsets_neg = jax.lax.stop_gradient((-radius - rays_o) / rays_d)
  offsets_in = jax.lax.stop_gradient(jnp.minimum(offsets_pos, offsets_neg))
  offsets_out = jax.lax.stop_gradient(jnp.maximum(offsets_pos, offsets_neg))
  start = jax.lax.stop_gradient(jnp.max(offsets_in, axis=-1, keepdims=True))
  stop = jax.lax.stop_gradient(jnp.min(offsets_out, axis=-1, keepdims=True))
  first_intersection = jax.lax.stop_gradient(rays_o + start * rays_d)
  # Compute locations of ray-voxel intersections along each dimension
  interval = jax.lax.stop_gradient(voxel_len / jnp.abs(rays_d)) # replace voxel_len with voxel_len/2
  offset_bigger = jax.lax.stop_gradient((safe_ceil(first_intersection / voxel_len) * voxel_len - first_intersection) / rays_d)
  offset_smaller = jax.lax.stop_gradient((safe_floor(first_intersection / voxel_len) * voxel_len - first_intersection) / rays_d)
  offset = jax.lax.stop_gradient(jnp.maximum(offset_bigger, offset_smaller))
  # print(first_intersection)

  # Compute the samples along each ray
  matrix = None
  powers = None
  if interpolation == 'tricubic':
    matrix, powers = tricubic_interpolation_matrix()
  if len(rays_o.shape) > 2:
    voxel_sh, voxel_sigma, intersections = get_intersections({"start": start, "stop": stop, "offset": offset, "interval": interval, "ray_o": rays_o, "ray_d": rays_d}, grid, resolution, radius, jitter, uniform, keys, sh_dim, interpolation, matrix, powers)
  else:
    voxel_sh, voxel_sigma, intersections = get_intersections_partial({"start": start, "stop": stop, "offset": offset, "interval": interval, "ray_o": rays_o, "ray_d": rays_d}, grid, resolution, radius, jitter, uniform, keys, sh_dim, interpolation, matrix, powers)
  # Apply spherical harmonics
  # voxel_rgb = sh.eval_sh(harmonic_degree, voxel_sh, rays_d)
  # Call volumetric_rendering
  if harmonic_degree >= 0:
    voxel_rgb = sh.eval_sh(harmonic_degree, voxel_sh, rays_d)
  else:
    voxel_rgb = []
  if nv:
    rgb, disp, acc, weights = nv_rendering(voxel_rgb, voxel_sigma, intersections, rays_d)
  else:
    rgb, disp, acc, weights = volumetric_rendering(voxel_rgb, voxel_sigma, intersections, rays_d)
  pts = rays_o[:, jnp.newaxis, :] + intersections[:, :, jnp.newaxis] * rays_d[:, jnp.newaxis, :]  # [n_rays, n_intersections, 3]
  ids = jnp.clip(jnp.array(jnp.floor(pts / voxel_len + eps) + resolution / 2, dtype=int), a_min=0, a_max=resolution-1)
  return rgb, disp, acc, weights, ids

In [8]:
def get_ct_jerry(root, stage):
    all_c2w = []
    all_gt = []

    print('LOAD DATA', root)
    
    projection_matrices = np.genfromtxt(os.path.join('/data/datasets/jerry-cbct/', 'proj_mat.csv'), delimiter=',')  # [719, 12]
    for i in range(len(projection_matrices)-1): 
        index = "{:04d}".format(i)
        im_gt = imageio.imread(os.path.join('/data/datasets/newJerryProj', f'NewJerryProj_{index}.png')).astype(np.float32) / 255.0
        im_gt = 1 - im_gt

        w2c = np.reshape(projection_matrices[i], (3,4))
        w2c[:,-1] = (w2c[:,-1] - [400, 220, 200])
        # invert world -> camera to get camera -> world
        c2w = np.linalg.inv(np.concatenate([w2c, [[0,0,0,1]]], axis=0))

        all_c2w.append(c2w)
        all_gt.append(im_gt) # This one is needed for source projections

    focal = 300

    all_gt = np.asarray(all_gt)
    all_c2w = np.asarray(all_c2w)

    mask = np.zeros(len(all_c2w))
    idx = np.random.choice(len(all_c2w), 100, replace = False) # was 500 idx
    mask[idx] = 1
    mask = mask.astype(bool)

    # train and test can be commented out ot get the full 360 ground truth projections
    if stage == 'train':
        all_gt = all_gt[mask]
        all_c2w = all_c2w[mask]
    elif stage == 'test':
        all_gt = all_gt[~mask]
        all_c2w = all_c2w[~mask]

    return focal, all_c2w, all_gt

In [9]:
root = '/data/datasets'

focal, train_c2w, train_gt = get_ct_jerry(root, "train")
test_focal, test_c2w, test_gt = get_ct_jerry(root, "test")

assert focal == test_focal
H, W = train_gt[0].shape[:2]
n_train_imgs = len(train_c2w)
n_test_imgs = len(test_c2w)

LOAD DATA /data/datasets


  # Remove the CWD from sys.path while we load stuff.


LOAD DATA /data/datasets


In [20]:
for j, (c2w, gt) in tqdm(enumerate(zip(train_c2w, train_gt)), total=len(train_c2w)):
    rays = get_rays(H, W, focal, c2w)
    gt = jnp.concatenate((gt[...,jnp.newaxis], gt[...,jnp.newaxis], gt[...,jnp.newaxis]), axis=-1)
    # rgb, disp, acc, weights, voxel_ids = render_rays(data_dict, rays, resolution, key, radius, harmonic_degree, jitter, uniform, interpolation, nv)
    # vis = jnp.concatenate((gt, rgb), axis = 1)
    imageio.imwrite(f"/data/fabrizio/multi-energy-ct/here/{j:04}_0001.png", (gt*255).astype(np.uint8))


  0%|          | 0/100 [00:00<?, ?it/s]


FileNotFoundError: The directory '/data/fabrizio/multi-energy-ct/here' does not exist