In [1]:
import csv
import time
from pathlib import Path

import click
import numpy as np
import torch
from tqdm import tqdm

from diffdrr import DRR, load_example_ct
from diffdrr.metrics import XCorr2

In [2]:
def get_true_drr():
    volume, spacing = load_example_ct()
    bx, by, bz = np.array(volume.shape) * np.array(spacing) / 2
    true_params = {
        "sdr": 200.0,
        "theta": torch.pi,
        "phi": 0,
        "gamma": torch.pi / 2,
        "bx": bx,
        "by": by,
        "bz": bz,
    }
    return volume, spacing, true_params


def get_initial_parameters(true_params):
    sdr = true_params["sdr"]
    theta = true_params["theta"] + np.random.uniform(-np.pi / 4, np.pi / 4)
    phi = true_params["phi"] + np.random.uniform(-np.pi / 3, np.pi / 3)
    gamma = true_params["gamma"] + np.random.uniform(-np.pi / 3, np.pi / 3)
    bx = true_params["bx"] + np.random.uniform(-30.0, 31.0)
    by = true_params["by"] + np.random.uniform(-30.0, 31.0)
    bz = true_params["bz"] + np.random.uniform(-30.0, 31.0)
    return sdr, theta, phi, gamma, bx, by, bz

In [3]:
# Get the ground truth DRR
volume, spacing, true_params = get_true_drr()
drr = DRR(volume, spacing, height=100, delx=5e-2, device="cuda")
ground_truth = drr(**true_params)

In [4]:
drr

Parameter containing:
tensor([200.], device='cuda:0')
Parameter containing:
tensor([[3.1416, 0.0000, 1.5708]], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[180.0000, 180.0000, 166.2500]], device='cuda:0', requires_grad=True)

In [11]:
# Initialize the DRR and optimization parameters
sdr, theta, phi, gamma, bx, by, bz = get_initial_parameters(true_params)
_ = drr(sdr, theta, phi, gamma, bx, by, bz)  # Initialize the DRR generator
criterion = XCorr2(zero_mean_normalized=True)
optimizer = torch.optim.LBFGS([drr.rotations, drr.translations], line_search_fn="strong_wolfe")

In [12]:
drr

Parameter containing:
tensor([200.], device='cuda:0')
Parameter containing:
tensor([[ 3.2534, -0.0300,  1.3338]], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[150.3246, 151.7261, 189.0789]], device='cuda:0', requires_grad=True)

In [13]:
-criterion(ground_truth, drr())

tensor([-0.7001], device='cuda:0', grad_fn=<NegBackward0>)

In [14]:
for _ in tqdm(range(100)):
    def closure(drr=drr, ground_truth=ground_truth):
        optimizer.zero_grad()
        estimate = drr()
        loss = -criterion(ground_truth, estimate)
        loss.backward(retain_graph=True)
        if loss < -0.999:
            raise StopIteration
        return loss
    
    try:
        optimizer.step(closure)
    except StopIteration:
        break

  4%|███▎                                                                                | 4/100 [00:04<01:51,  1.16s/it]


In [15]:
drr

Parameter containing:
tensor([200.], device='cuda:0')
Parameter containing:
tensor([[3.1327, 0.0045, 1.5744]], device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[180.0062, 180.1529, 165.1486]], device='cuda:0', requires_grad=True)

In [16]:
-criterion(ground_truth, drr())

tensor([-0.9991], device='cuda:0', grad_fn=<NegBackward0>)