In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `alpine` package) are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2



In [None]:
import os

# Pin the notebook to GPU index 6 unless the caller (shell, job scheduler, ...)
# has already restricted the visible devices. An unconditional assignment would
# overwrite a scheduler-provided mask. This runs before torch is imported in
# the next cell.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '6')

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import torch, torch.nn as nn
import os, os.path as osp
import sys
import time
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics import MetricTracker
import skimage, skimage.io, skimage.transform, skimage.filters

import alpine
from functools import partial


In [None]:
# Number of fitting iterations. It is passed as `n_iters` to `fit_signal` and is
# also the horizon of the learning-rate decay lambda defined alongside the model.
NUM_ITERATIONS = 20000


In [None]:
class MSE_TV_Loss(nn.Module):
    """Mean-squared error plus a weighted anisotropic total-variation penalty.

    The MSE term compares ``x['output']`` against ``y['signal']``. The TV term
    is the mean absolute difference between neighbouring entries of
    ``x['output_img']`` along its axes 1 and 2 (assumed to be the height and
    width of a channels-last ``(B, H, W, C)`` image -- confirm against caller).

    Args:
        weight: Multiplier applied to the TV term before adding it to the MSE.
    """

    def __init__(self, weight=1.0):
        super().__init__()
        self.weight = weight
        self.mse = nn.MSELoss()

    @staticmethod
    def _mean_abs_diff(later, earlier):
        """Return mean(|later - earlier|), or a zero scalar for empty slices.

        The mean of an empty tensor is NaN; an axis of length 1 has no
        neighbouring pairs, so its contribution is defined as 0 instead.
        """
        if later.numel() == 0:
            return later.new_zeros(())
        return torch.mean(torch.abs(later - earlier))

    def forward(self, x, y):
        """Compute ``mse + weight * tv``.

        Args:
            x: Mapping with ``'output'`` (compared to the target) and
                ``'output_img'`` (4-D tensor the TV penalty is computed on).
            y: Mapping with the target under ``'signal'``.

        Returns:
            Scalar loss tensor.
        """
        mse = self.mse(x['output'], y['signal'])

        # Move the last axis to position 1 so the two trailing axes are the
        # ones the finite differences run over.
        tv_img = x['output_img'].permute(0, 3, 1, 2)

        # Each direction is handled on its own so that a degenerate axis
        # (length 1) does not wipe out the penalty along the other axis.
        tv = self._mean_abs_diff(tv_img[:, :, 1:, :], tv_img[:, :, :-1, :]) + \
            self._mean_abs_diff(tv_img[:, :, :, 1:], tv_img[:, :, :, :-1])

        # Retained from the original behaviour: if the image itself contains
        # NaNs the TV term is dropped rather than propagated.
        if torch.isnan(tv):
            tv = 0
        return mse + self.weight * tv

In [None]:
# WIRE network taking 2 input features and producing 1 output feature.
wire_model = alpine.models.Wire(
    in_features=2,
    hidden_features=300,
    hidden_layers=4,
    out_features=1,
    omegas=[10.0],
    sigmas=[10.0],
)
wire_model = wire_model.float().cuda()

# Learning-rate multiplier: decays geometrically from 1.0 to 0.1 over
# NUM_ITERATIONS steps and stays at 0.1 afterwards.
scheduler = partial(
    torch.optim.lr_scheduler.LambdaLR,
    lr_lambda=lambda step: 0.1 ** min(step / NUM_ITERATIONS, 1.0),
)

# Sinogram MSE plus a TV penalty (weight 0.1) on the predicted image.
wire_model.register_loss_function(MSE_TV_Loss(weight=0.1).float().cuda())
wire_model.compile(learning_rate=5e-3, scheduler=scheduler)
print(wire_model)


In [None]:
# Read the ground-truth image as float32 and min-max normalise it to [0, 1].
image = skimage.io.imread("./data/chest.png").astype(np.float32)
lo, hi = image.min(), image.max()
image = (image - lo) / (hi - lo)

# Preview the normalised image with a colour bar.
fig, ax = plt.subplots()
preview = ax.imshow(image)
ax.axis('off')
fig.colorbar(preview, ax=ax)
plt.show()

# The unpacking below requires a 2-D (single-channel) image.
H, W = image.shape
print(H, W)

# gt_signal = torch.from_numpy(image).float().cuda()[None,...,None]
# print(gt_signal.shape)
# print(gt_signal.min(), gt_signal.max())

In [None]:
import kornia.geometry
def radon(imten, angles, is_3d=False):
    '''
        Compute forward radon operation by rotating the image once per angle
        and summing along axis 2 (rows) of the rotated stack.

        Inputs:
            imten: (1, nimg, H, W) image tensor
            angles: (nangles) angles tensor -- should be on same device as
                imten. Passed unchanged to kornia.geometry.rotate (which,
                per kornia's documentation, expects degrees).
            is_3d: if True, return the image axis first (see Outputs)
        Outputs:
            sinogram:
                is_3d=False: (nangles, W) when nimg == 1, otherwise
                    (nangles, nimg, W)
                is_3d=True: (nimg, nangles, W)

        Only the image axis is ever squeezed, so a single angle (or W == 1)
        keeps its axis and is_3d=True also works with nimg == 1.
    '''
    nangles = len(angles)

    # One copy of the image stack per angle: (nangles, nimg, H, W).
    imten_rep = torch.repeat_interleave(imten, nangles, 0)

    imten_rot = kornia.geometry.rotate(imten_rep, angles)

    # Line integrals along the row axis: (nangles, nimg, W).
    projections = imten_rot.sum(2)

    if is_3d:
        sinogram = projections.permute(1, 0, 2)
    else:
        # squeeze(1) is a no-op when nimg > 1, matching the original output.
        sinogram = projections.squeeze(1)

    return sinogram

In [None]:
# Lift the 2-D numpy image to a float tensor on the GPU, then prepend two
# singleton axes to get the (1, 1, H, W) layout that radon() documents.
image_tensor = torch.tensor(image).float().cuda()
image_tensor = image_tensor[None, None, ...]
print(image_tensor.shape)

In [None]:
# Simulate the measurements: project the ground-truth image at 100 angles.
# NOTE(review): np.linspace includes the endpoint, so both 0 and 180 are
# sampled -- confirm that this duplicated (mirrored) view is intended.
with torch.no_grad():
    angle_grid = np.linspace(0, 180, 100, dtype=np.float32)
    thetas = torch.tensor(angle_grid).cuda()
    # Leading singleton axis added so the target is batched like the model output.
    sinogram = radon(image_tensor, thetas)[None, ...]

In [None]:
# Coordinate grid for an H x W image, moved to the GPU as float with a leading
# singleton axis prepended.
coords = alpine.utils.get_coords_spatial(H, W)
coords = coords.float().cuda()[None, ...]
print(coords.shape)

In [None]:
def inverse_ct_closure(model_ctx, input, signal, iteration, return_features=False):
    """Run the model, then push its image through the Radon forward operator.

    Args:
        model_ctx: Callable evaluated on ``input``; must return a mapping with
            the predicted image under ``'output'``.
        input: Coordinates fed to ``model_ctx``.
        signal, iteration, return_features: Accepted for the closure interface
            but not used in this body.

    Returns:
        Dict with ``'output'`` (predicted sinogram with a leading singleton
        axis, computed with the notebook-level ``thetas``) and ``'output_img'``
        (the raw model image).
    """
    predicted_img = model_ctx(input)['output']
    # radon() wants the channel axis second, so move the last axis to position 1.
    channels_first = predicted_img.permute(0, 3, 1, 2)
    projected = radon(channels_first, thetas)
    return {'output': projected[None, ...], 'output_img': predicted_img}
    


In [None]:
# Fit the network so that the Radon transform of its output matches the
# simulated sinogram; PSNR is tracked during fitting.
psnr_tracker = MetricTracker(PeakSignalNoiseRatio().cuda())
fit_output = wire_model.fit_signal(
    input=coords,
    signal=sinogram,
    closure=inverse_ct_closure,
    n_iters=NUM_ITERATIONS,
    enable_tqdm=True,
    save_best_weights=True,
    metric_trackers={'psnr': psnr_tracker},
)

In [None]:
# Evaluate the fitted model on the full coordinate grid, asking for the best
# weights (the fit call above was run with save_best_weights=True).
output = wire_model.render(coords, use_best_weights=True)

In [None]:
# First batch element of the reconstruction as a numpy array, clipped to [0, 1].
# NOTE(review): the plotting and PSNR cells below read output['output']
# directly rather than this clipped copy -- confirm which one is intended.
recon = output['output'].detach().cpu().numpy()[0, ...]
outimg = np.clip(recon, 0, 1)
print(outimg.shape, outimg.min(), outimg.max())

In [None]:
# Show the reconstruction in grayscale and export it as a PDF.
# savefig does not create missing directories; make sure the target exists so
# the export cannot fail after the long fitting run.
os.makedirs('./output', exist_ok=True)
plt.figure()
plt.imshow(output['output'].detach().cpu().numpy()[0,...], cmap='gray')
plt.axis('off')
plt.savefig('./output/ct_recon3.pdf', bbox_inches='tight', pad_inches=0, dpi=300)
plt.show()


In [None]:
import skimage.metrics

# PSNR of the (unclipped) reconstruction against the normalised ground truth,
# both flattened, with the data range fixed to 1.0.
recon_flat = output['output'].detach().cpu().numpy()[0, ...].flatten()
psnr_value = skimage.metrics.peak_signal_noise_ratio(
    image.flatten(), recon_flat, data_range=1.0
)
print(psnr_value)