In [2]:
import skimage
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import time
import argparse
import cv2
from scipy import io
from tqdm.notebook import tqdm

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim.lr_scheduler as lr_scheduler
from pytorch_msssim import ssim

from modules import utils, lin_inverse
from modules.models import INR

ModuleNotFoundError: No module named 'kornia'

In [None]:
def _str2bool(value):
    """Convert a command-line string such as 'true' / 'False' / '0' to a bool.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including 'False') would be parsed as True. This
    converter interprets the usual textual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='INCODE')

# Shared Parameters
parser.add_argument('--input',type=str, default='./incode_data/Image/img_377.png', help='Input image path')
parser.add_argument('--inr_model',type=str, default='incode', help='[gauss, mfn, relu, siren, wire, wire2d, ffn, incode]')
parser.add_argument('--lr',type=float, default=2e-4, help='Learning rate')
# Parsed with _str2bool so that "--using_schedular False" really disables the scheduler.
parser.add_argument('--using_schedular', type=_str2bool, default=True, help='Whether to use schedular')
parser.add_argument('--scheduler_b', type=float, default=0.4, help='Learning rate scheduler')
parser.add_argument('--maxpoints', type=int, default=256*256, help='Batch size')
parser.add_argument('--niters', type=int, default=2001, help='Number if iterations')
parser.add_argument('--steps_til_summary', type=int, default=500, help='Number of steps till summary visualization')

# CT Parameters
parser.add_argument('--proj', type=int, default=150, help='Number of CT measurements')

# INCODE Parameters (used as weights of the coefficient regularization terms in the training loop)
parser.add_argument('--a_coef',type=float, default=0.1993, help='a coeficient')
parser.add_argument('--b_coef',type=float, default=0.0196, help='b coeficient')
parser.add_argument('--c_coef',type=float, default=0.0588, help='c coeficient')
parser.add_argument('--d_coef',type=float, default=0.0269, help='d coeficient')


# Parse an empty argument list so the notebook always runs with the defaults
# above rather than picking up the kernel's own command-line arguments.
args = parser.parse_args(args=[])

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Loading Data

In [None]:
# Projection angles: args.proj evenly spaced values from 0 to 180 (both ends included).
angles_np = np.linspace(0, 180, args.proj, dtype=np.float32)
thetas = torch.tensor(angles_np).to(device)

# Read the input image, normalize it, and keep only the first channel.
raw_image = plt.imread(args.input).astype(np.float32)
im = utils.normalize(raw_image, True)[..., 0]
H, W = im.shape

# Ground-truth image with two leading singleton axes added.
gt = torch.tensor(im)[None, None, ...].to(device)

# Ground-truth sinogram: forward-project the image without tracking gradients.
with torch.no_grad():
    sinogram = lin_inverse.radon(gt, thetas).detach().cpu().numpy()
sinogram_gt = torch.tensor(sinogram).to(device)

## Defining Model

### Defining desired Positional Encoding

In [None]:
# Frequency encoding; mapping_input is taken from the number of CT projections
pos_encode_freq = dict(type='frequency', use_nyquist=True, mapping_input=int(args.proj))

# Gaussian encoding
pos_encode_gaus = dict(type='gaussian', scale_B=10, mapping_input=256)

# No positional encoding
pos_encode_no = dict(type=None)

### Model Configurations

In [None]:
### Harmonizer Configurations
# Ground-truth sinogram with two leading singleton axes added; the second of
# them is then expanded to size 3.
harmonizer_gt = sinogram_gt[None, None, ...].expand(
    1, 3, sinogram_gt.shape[0], sinogram_gt.shape[1]
)

MLP_configs = dict(
    task='image',
    model='resnet34',
    truncated_layer=5,
    in_channels=64,
    hidden_channels=[64, 32, 4],
    mlp_bias=0.3120,
    activation_layer=nn.SiLU,
    GT=harmonizer_gt,
)

### Model Configurations
# Keyword arguments handed to the INR factory; no positional encoding is used here.
inr_kwargs = dict(
    in_features=2,
    out_features=1,
    hidden_features=256,
    hidden_layers=3,
    first_omega_0=30.0,
    hidden_omega_0=30.0,
    pos_encode_configs=pos_encode_no,
    MLP_configs=MLP_configs,
)
model = INR(args.inr_model).run(**inr_kwargs).to(device)

## Training Code

In [None]:
# Optimizer setup
# For the 'wire' model the learning rate is scaled by the ratio of
# args.maxpoints to the number of pixels, capped at 1.
if args.inr_model == 'wire':
    args.lr = args.lr * min(1, args.maxpoints / (H * W))
optim = torch.optim.Adam(lr=args.lr, params=model.parameters())
# Multiplicative LR factor: goes from 1 at step 0 to args.scheduler_b once
# args.niters scheduler steps have been taken, and stays there afterwards.
scheduler = lr_scheduler.LambdaLR(optim, lambda x: args.scheduler_b ** min(x / args.niters, 1))

# Initialize lists for PSNR and MSE values
psnr_values = []
# Allocate on the selected device (not a hard-coded 'cuda') so the notebook
# also runs on CPU-only machines.
mse_array = torch.zeros(args.niters, device=device)

# Initialize best loss value as positive infinity
best_loss = torch.tensor(float('inf'), device=device)

# Generate coordinate grid (a leading singleton axis is added)
coords = utils.get_coords(H, W, dim=2)[None, ...].to(device)

In [None]:
# Training loop: fit the model so that the sinogram of its output matches the
# ground-truth sinogram, tracking image-domain MSE / PSNR against `gt`.
for step in tqdm(range(args.niters)):

    # Calculate model output ('incode' additionally returns its coefficients)
    if args.inr_model == 'incode':
        model_output, coef = model(coords)
    else:
        model_output = model(coords)

    # Reshape the prediction to (1, -1, H, W)
    model_output = model_output.reshape(-1, H, W)[None, ...]

    # Compute the sinogram of output
    sinogram_output = lin_inverse.radon(model_output, thetas)

    # Calculate the output loss (MSE in the sinogram domain)
    output_loss = ((sinogram_output - sinogram_gt)**2).mean()

    if args.inr_model == 'incode':
        # Regularization loss for 'incode': penalize negative coefficients,
        # each weighted by the corresponding command-line coefficient.
        a_coef, b_coef, c_coef, d_coef = coef[0]
        reg_loss = args.a_coef * torch.relu(-a_coef) + \
                   args.b_coef * torch.relu(-b_coef) + \
                   args.c_coef * torch.relu(-c_coef) + \
                   args.d_coef * torch.relu(-d_coef)

        # Total loss for 'incode' model
        loss = output_loss + reg_loss
    else:
        # Total loss for other models
        loss = output_loss

    # Perform backpropagation and update model parameters
    optim.zero_grad()
    loss.backward()
    optim.step()

    # Adjust learning rate using a scheduler if applicable
    if args.using_schedular:
        scheduler.step()

    # Metrics and best-image bookkeeping need no gradients. Keeping the
    # normalization inside no_grad ensures best_img does not hold on to the
    # autograd graph of this iteration.
    with torch.no_grad():
        # Calculate PSNR from the image-domain MSE
        mse_array[step] = ((gt - model_output)**2).mean().item()
        psnr = -10*torch.log10(mse_array[step])
        psnr_values.append(psnr.item())

        # Check if the current iteration's loss is the best so far
        if (mse_array[step] < best_loss) or (step == 0):
            best_loss = mse_array[step]
            # Min-max normalize the output; model_output is rebound to the
            # normalized tensor, as in the original cell.
            model_output = (model_output - model_output.min()) / (model_output.max() - model_output.min())
            best_img = model_output

    # Display intermediate results at specified intervals
    if step % args.steps_til_summary == 0:
        print("Epoch: {} | Loss: {:.5f} | PSNR: {:.5f}".format(step, loss.item(), psnr.item()))

        ### Plot
        fig, axes = plt.subplots(1, 4, figsize=(10, 10))
        subplot_info = [
            {'title': 'Ground Truth', 'image': im, 'cmap': 'gray'},
            {'title': 'Reconstructed', 'image': best_img[0][0].cpu().detach().numpy(), 'cmap': 'gray'},
            {'title': 'Sinogram GT', 'image': sinogram_gt.cpu().detach().numpy(), 'cmap': 'viridis'},
            {'title': 'Sinogram', 'image': sinogram_output.cpu().detach().numpy(), 'cmap': 'viridis'}]

        for ax, info in zip(axes, subplot_info):
            ax.set_title(info['title'])
            ax.imshow(info['image'], cmap=info['cmap'])
            ax.axis('off')
        plt.show()



# Print maximum PSNR achieved during training
print('--------------------')
print('Max PSNR:', max(psnr_values))
print('--------------------')

# Convergence Rate

In [None]:
# Font settings used for the axis labels
font = {'font': 'Times New Roman', 'size': 12}

fig = plt.figure()

# Global font settings, applied before the axes are created
axfont = {'family' : 'Times New Roman', 'weight' : 'regular', 'size'   : 10}
plt.rc('font', **axfont)

ax = fig.add_subplot()

# PSNR history with the final entry left out, plotted against the step index
psnr_history = psnr_values[:-1]
ax.plot(np.arange(len(psnr_history)), psnr_history, label=args.inr_model.upper())
ax.set_xlabel('# Epochs', fontdict=font)
ax.set_ylabel('PSNR (dB)', fontdict=font)
ax.set_title('CT Reconstruction', fontdict={'family': 'Times New Roman', 'size': 12, 'weight': 'bold'})
ax.legend()
ax.grid(True, color='lightgray')

plt.show()