# [Generalizable Neural Fields as Partially Observed Neural Processes](https://arxiv.org/abs/2309.06660)

In this paper, we use neural processes (NPs) both to speed up neural field training and to exploit neural fields trained on related signals when conditioning the neural field for a signal of interest. Previous approaches used gradient-based meta-learning methods such as Reptile; we show that NPs outperform these gradient-based meta-learning approaches. In this notebook, we replicate our 2D CT reconstruction experiment. 

**Open notebook:** 
[![View on Github](https://img.shields.io/static/v1.svg?logo=github&label=Repo&message=View%20On%20Github&color=lightgrey)](https://github.com/its-gucci/partially-observed-neural-processes/blob/2D_CT_Recon.ipynb)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/its-gucci/partially-observed-neural-processes/blob/2D_CT_Recon.ipynb)  


Let's start by importing the libraries we'll need. 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from map_coordinates import _map_coordinates

import os
os.chdir('neural_process_family')
import pickle

## Forward Map

For 2D CT reconstruction, we would like to reconstruct a 2D CT scan from 1D sensor observations called sinograms. The 1D sinograms are generated from the 2D CT scan through the following forward map (integral projection, adapted to PyTorch from [here](https://github.com/tancik/learnit/blob/main/Experiments/2d_ct.ipynb)). 

In [None]:
def ct_project(img, theta, device='cuda'):
    """Integral projection of a 2D image at a single angle.

    The image's sampling grid is rotated by ``theta``, the image is resampled on
    the rotated grid with ``_map_coordinates`` (order 0), and the resampled image
    is averaged over its first (row) axis.

    Parameters
    ----------
    img : torch.Tensor, shape [H, W]
        Image to project. H and W may differ.
    theta : float or torch.Tensor (0-dim)
        Rotation angle in radians.
    device : str
        Device on which the sampling grid and the resampling are computed.

    Returns
    -------
    torch.Tensor, shape [W, 1]
        The projection, with a trailing singleton axis.
    """
    height, width = img.shape[0], img.shape[1]
    # Normalised coordinates in [-0.5, 0.5). 'ij' indexing makes y vary along
    # rows and x along columns, so the grid has the image's own [H, W] shape.
    y, x = torch.meshgrid(
        torch.arange(height) / height - 0.5,
        torch.arange(width) / width - 0.5,
        indexing='ij',
    )
    x = x.to(device)
    y = y.to(device)
    # Accept plain Python floats as well as tensors for the angle.
    theta = torch.as_tensor(theta, dtype=x.dtype, device=device)
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    x_rot = x * cos_t - y * sin_t
    y_rot = x * sin_t + y * cos_t
    # Back to pixel units along each axis.
    x_rot = (x_rot + 0.5) * width
    y_rot = (y_rot + 0.5) * height
    sample_coords = torch.stack([y_rot, x_rot], dim=0)
    resampled = _map_coordinates(img, sample_coords, 0, device=device).reshape(img.shape)
    return resampled.mean(axis=0)[:, None, ...]

In [None]:
def ct_project_batch(img, thetas, device='cuda'):
    """Project one image at every angle in ``thetas``.

    ``img`` is squeezed once, projected with ``ct_project`` for each angle, and
    the per-angle results are stacked along a new leading axis.
    """
    image = img.squeeze()
    return torch.stack([ct_project(image, angle, device=device) for angle in thetas])

In [None]:
def ct_project_double_batch(imgs, thetas, device='cuda'):
    '''
    Project every image in a batch at its own set of angles.

    imgs: [batch, 256, 256]
    thetas: [batch, n_projs] -- row b holds the angles used for imgs[b]

    Returns the ``ct_project`` outputs stacked as [batch, n_projs, ...].
    '''
    n_angles = thetas.shape[1]
    return torch.stack([
        torch.stack([
            ct_project(imgs[b], thetas[b][k], device=device)
            for k in range(n_angles)
        ])
        for b in range(imgs.shape[0])
    ])

## Architectures

We will implement our NP-based Partially-Observed Neural Process (PONP) with help from the [Neural Process Family library](https://github.com/YannDubs/Neural-Process-Family), which implements many different neural process algorithms. They also have a nice introduction to neural processes [here](https://yanndubs.github.io/Neural-Process-Family/text/Intro.html). 

### MLP

In [None]:
from npf.architectures.mlp import MLP

### CT reconstruction model

For fair comparison, we use the same CT reconstruction neural field as was used in previous work (adapted to PyTorch from [here](https://github.com/tancik/learnit/blob/main/Experiments/2d_ct.ipynb)). 

In [None]:
class CTReconModel(nn.Module):
    """ReLU MLP used as the CT reconstruction neural field.

    Parameters
    ----------
    in_features : int
        Size of the input features fed to the first layer.
    out_features : int
        Size of the output.
    hidden_features : int
        Width of every hidden layer.
    hidden_layers : int
        Number of ``nn.Linear`` layers when ``hidden_layers >= 2``
        (first + ``hidden_layers - 2`` intermediate + last).
    device : str
        Unused; kept for signature compatibility.
    """

    def __init__(self, in_features=2, out_features=1, hidden_features=256, hidden_layers=5, device='cuda'):
        super().__init__()
        # First layer maps the actual input width to the hidden width, so the
        # model works for any in_features, not only in_features == hidden_features.
        modules = [nn.Linear(in_features, hidden_features), nn.ReLU()]

        # Intermediate layers.
        for _ in range(1, hidden_layers - 1):
            modules += [nn.Linear(hidden_features, hidden_features), nn.ReLU()]

        # Last layer; its output also goes through a ReLU, so outputs are >= 0.
        modules += [nn.Linear(hidden_features, out_features), nn.ReLU()]

        # Same module ordering as before, so state_dict keys (net.0, net.1, ...) are unchanged.
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)

### Transformers

We will use these networks to transform the output of our PONP decoder.

In [None]:
class loc_transformer(nn.Module):
    """Turn per-pixel decoder features into predicted projections (location head).

    A linear head reduces each pixel's feature vector to one value. The values
    are reshaped into 256x256 images and squashed with a sigmoid. Each image is
    then projected at its angles with ``ct_project_double_batch``.
    """

    def __init__(self, dim=256):
        super().__init__()
        # `linear` / `activation` are read by id_transformer, so keep these names.
        self.linear = nn.Linear(dim, 1)
        self.activation = nn.Sigmoid()

    def forward(self, x, thetas):
        """
        x: decoder output; assumed [n_z_samples, batch, 256*256, dim] -- TODO confirm.
        thetas: [batch, n_projs]; tiled once per entry of x's leading axis.
        Returns the projections viewed as [*x.shape[:-2], -1, 256].
        """
        leading_dims = x.shape[:-2]
        images = self.activation(self.linear(x).view(-1, 256, 256))
        per_image_thetas = thetas.repeat(x.shape[0], 1)
        sinograms = ct_project_double_batch(images, per_image_thetas)
        return sinograms.view(*leading_dims, -1, 256)

In [None]:
class scale_transformer(nn.Module):
    """Turn per-pixel decoder features into positive projection-space scales.

    Features are reduced to one value per pixel, reshaped into 256x256 images,
    passed through a sigmoid, and projected with ``ct_project_double_batch``.
    The projections are then mapped through ``0.01 + 0.99 * softplus``, which
    keeps every scale above 0.01.
    """

    def __init__(self, dim=256):
        super().__init__()
        # `linear` / `activation` are read by id_transformer, so keep these names.
        self.linear = nn.Linear(dim, 1)
        self.activation = nn.Sigmoid()

    def forward(self, x, thetas):
        """
        x: decoder output; assumed [n_z_samples, batch, 256*256, dim] -- TODO confirm.
        thetas: [batch, n_projs]; tiled once per entry of x's leading axis.
        Returns the scales viewed as [*x.shape[:-2], -1, 256].
        """
        leading_dims = x.shape[:-2]
        images = self.activation(self.linear(x).view(-1, 256, 256))
        sinograms = ct_project_double_batch(images, thetas.repeat(x.shape[0], 1))
        scales = 0.01 + 0.99 * F.softplus(sinograms)
        return scales.view(*leading_dims, -1, 256)

In [None]:
class id_transformer(nn.Module):
    """Image-space version of a trained location/scale head (no projection).

    Reuses the wrapped transformer's ``linear`` and ``activation`` modules. They
    are shared, not copied, so parameters stay tied. The output is the
    per-pixel image itself rather than its projections. With ``stat='scale'``,
    the ``0.01 + 0.99 * softplus`` mapping is applied on top.
    """

    def __init__(self, transformer, stat='loc'):
        super().__init__()
        self.linear = transformer.linear
        self.activation = transformer.activation
        self.stat = stat

    def forward(self, x, _):
        """x: [..., 256*256, dim]; the second argument (angles) is ignored."""
        batch_dims = x.shape[:-2]
        image = self.activation(self.linear(x).view(*batch_dims, 256, 256))
        return 0.01 + 0.99 * F.softplus(image) if self.stat == 'scale' else image

## Neural Process

In [None]:
from functools import partial

from npf import AttnLNP
from npf.architectures import merge_flat_input
from utils.helpers import count_parameters

### Adapt LNP to CT recon

We change the usual neural process model to incorporate training with our forward map. 

In [None]:
class AttnLNPFM(AttnLNP):
    """``AttnLNP`` whose output heads also receive the projection angles.

    ``forward`` mirrors the NeuralProcessFamily forward pass: encode x, encode
    the context, optional latent path, target-dependent representation, decode.
    It additionally threads ``thetas`` through to ``decode``, which hands them
    to ``p_y_loc_transformer`` / ``p_y_scale_transformer``. Those heads can then
    apply the CT forward map, so the predictive distribution is expressed in
    measurement (projection) space.
    """

    def __init__(self, x_dim, y_dim, **kwargs):
        super().__init__(x_dim, y_dim, **kwargs)

    def forward(self, X_cntxt, Y_cntxt, X_trgt, Y_trgt=None, thetas=None):
        """Run the model and return ``(p_yCc, z_samples, q_zCc, q_zCct)``.

        Parameters
        ----------
        X_cntxt : torch.Tensor, size=[batch_size, *n_cntxt, x_dim]
        Y_cntxt : torch.Tensor, size=[batch_size, *n_cntxt, y_dim]
        X_trgt : torch.Tensor, size=[batch_size, *n_trgt, x_dim]
        Y_trgt : torch.Tensor, optional
            Observed targets, forwarded to ``latent_path``.
        thetas : torch.Tensor, optional
            Projection angles, forwarded unchanged to the output transformers.
        """
        # n_z_samples_{train,test} may be a distribution exposing `.rvs()` or a
        # plain int (handled by the AttributeError fallback).
        try:
            self.n_z_samples = (
                self.n_z_samples_train.rvs()
                if self.training
                else self.n_z_samples_test.rvs()
            )
        except AttributeError:
            self.n_z_samples = (
                self.n_z_samples_train if self.training else self.n_z_samples_test
            )

        # NeuralProcessFamily forward
        self._validate_inputs(X_cntxt, Y_cntxt, X_trgt, Y_trgt)

        # size = [batch_size, *n_cntxt, x_transf_dim]
        X_cntxt = self.x_encoder(X_cntxt)
        # size = [batch_size, *n_trgt, x_transf_dim]
        X_trgt = self.x_encoder(X_trgt)

        # {R^u}_u
        # size = [batch_size, *n_rep, r_dim]
        R = self.encode_globally(X_cntxt, Y_cntxt)

        if self.encoded_path in ["latent", "both"]:
            # NOTE(review): X_cntxt (not X_trgt) is passed alongside Y_trgt. In the
            # CT setup, Y_trgt holds one projection per angle, so it is indexed like
            # the context angles rather than like the pixel targets -- presumably
            # deliberate; confirm.
            z_samples, q_zCc, q_zCct = self.latent_path(X_cntxt, R, X_cntxt, Y_trgt)
        else:
            z_samples, q_zCc, q_zCct = None, None, None

        if self.encoded_path == "latent":
            # if only latent path then cannot depend on deterministic representation
            R = None

        # size = [n_z_samples, batch_size, *n_trgt, r_dim]
        R_trgt = self.trgt_dependent_representation(X_cntxt, z_samples, R, X_trgt)

        # p(y|cntxt,trgt)
        # batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim]
        p_yCc = self.decode(X_trgt, R_trgt, thetas)

        return p_yCc, z_samples, q_zCc, q_zCct

    def decode(self, X_trgt, R_trgt, thetas):
        """
        Compute predicted distribution conditioned on representation and
        target positions.
        Parameters
        ----------
        X_trgt: torch.Tensor, size=[batch_size, *n_trgt, x_transf_dim]
            Set of all target features {x^t}_t.
        R_trgt : torch.Tensor, size=[n_z_samples, batch_size, *n_trgt, r_dim]
            Set of all target representations {r^t}_t.
        thetas : torch.Tensor or None
            Angles passed as the second argument to both output transformers.
        Return
        ------
        p_y_trgt: torch.distributions.Distribution
            Predictive distribution built by ``self.PredictiveDistribution`` from
            the transformed location and scale.
        """
        # size = [n_z_samples, batch_size, *n_trgt, y_dim*2]
        p_y_suffstat = self.decoder(X_trgt, R_trgt)

        # size = [n_z_samples, batch_size, *n_trgt, y_dim]
        p_y_loc, p_y_scale = p_y_suffstat.split(self.y_dim, dim=-1)

        p_y_loc = self.p_y_loc_transformer(p_y_loc, thetas)
        p_y_scale = self.p_y_scale_transformer(p_y_scale, thetas)

        if not self.is_heteroskedastic:
            # Homoskedastic noise: share one scale across all target points by
            # mean-pooling over the target dims and broadcasting back. Done inline
            # because `pool_and_replicate_middle` was never imported here.
            n_z_samples, batch_size, *n_trgt, y_dim = p_y_scale.shape
            flat_scale = p_y_scale.reshape(n_z_samples, batch_size, -1, y_dim)
            pooled = flat_scale.mean(dim=2, keepdim=True).expand_as(flat_scale)
            p_y_scale = pooled.reshape(n_z_samples, batch_size, *n_trgt, y_dim)

        # batch shape=[n_z_samples, batch_size, *n_trgt] ; event shape=[y_dim]
        p_yCc = self.PredictiveDistribution(p_y_loc, p_y_scale)

        return p_yCc

### LNP settings

Here, we set the hyperparameters and architecture for our PONP model.

In [None]:
# Shared constructor arguments for AttnLNPFM (architecture and latent settings).
R_DIM = 256  # representation width; also used as the x-encoder's hidden size
KWARGS = dict(
    is_q_zCct=False,  
    n_z_samples_train=1,
    n_z_samples_test=1,  
    attention='transformer',
    XEncoder=partial(MLP, n_hidden_layers=1, hidden_size=R_DIM),
    Decoder=merge_flat_input(  # MLP takes single input but we give x and R so merge them
        partial(CTReconModel, hidden_layers=4, hidden_features=256), is_sum_merge=True,
    ),
    r_dim=R_DIM,
    # NOTE(review): these two transformer instances are created once, here. Every
    # model built from these KWARGS (e.g. each call to model_1d()) therefore shares
    # the same modules and weights -- confirm this is intended.
    p_y_loc_transformer=loc_transformer(),
    p_y_scale_transformer=scale_transformer(),
)

In [None]:
# 1D case
model_1d = partial(
    AttnLNPFM,
    x_dim=2,
    y_dim=256,
    XYEncoder=merge_flat_input(  # MLP takes single input but we give x and y so merge them
        partial(MLP, n_hidden_layers=2, hidden_size=R_DIM * 2), is_sum_merge=True,
    ),
    **KWARGS,
)

In [None]:
# Instantiate a model only to report its parameter count.
n_params_1d = count_parameters(model_1d())
print(f"Number Parameters (1D): {n_params_1d:,d}")

## Dataset + Context/Target Getters

Here, we prepare the data for use in the Neural Process Family framework. 

In [None]:
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

In [None]:
def get_mgrid(sidelen, dim=2):
    '''Generate a flattened grid of (x, y, ...) coordinates spanning [-1, 1].

    sidelen: int, number of samples along each axis
    dim: int, number of axes

    Returns a float tensor of shape [sidelen**dim, dim]. The last coordinate
    varies fastest ('ij' meshgrid ordering).
    '''
    axis = torch.linspace(-1, 1, steps=sidelen)
    # Explicit 'ij' keeps the historical default ordering and silences torch's
    # "meshgrid will require the indexing argument" warning.
    mgrid = torch.stack(torch.meshgrid(*([axis] * dim), indexing='ij'), dim=-1)
    return mgrid.reshape(-1, dim)

In [None]:
import math

# math.pi is exact to double precision; a float32 acos(0) * 2 is off by ~9e-8.
pi = math.pi

In [None]:
# Pickled CT dataset; CTDataset reads its 'data_train' / 'data_test' entries.
data_path = 'data/ct_256.pkl'

In [None]:
class CTDataset(Dataset):
    """CT images paired with projections at freshly sampled random angles.

    Each item is ``(coords, image_projs, thetas)``:
    - ``coords``: the shared [256*256, 2] pixel grid in [-1, 1];
    - ``thetas``: ``n_projs`` angles drawn uniformly from [0, pi), resampled on
      every access;
    - ``image_projs``: projections of the (transformed) image at those angles,
      one row per angle.
    """

    shape = (1, 256, 256)
    coords = get_mgrid(shape[1], dim=2)
    name = 'ctrecon'

    def __init__(self, data_path=data_path,
                 split='train', n_projs=20, transform=None, device='cuda'):
        super().__init__()
        # Fail early instead of leaving `self.data` undefined for an unknown split.
        if split not in ('train', 'test'):
            raise ValueError(f"split must be 'train' or 'test', got {split!r}")

        # NOTE: pickle can execute arbitrary code -- only load trusted files.
        with open(data_path, 'rb') as file:
            dataset = pickle.load(file)

        self.data = dataset['data_train'] if split == 'train' else dataset['data_test']
        self.transform = transform
        self.n_projs = n_projs
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        if self.transform:
            x = self.transform(x)
        thetas = pi * torch.rand(self.n_projs)
        # Drop only the trailing singleton axis from ct_project, so the angle
        # axis survives even when n_projs == 1.
        image_projs = ct_project_batch(x.squeeze(), thetas, device=self.device).squeeze(-1)
        return self.coords, image_projs, thetas

In [None]:
# Convert each stored image to a torch tensor before it is projected.
transform = transforms.Compose([
    transforms.ToTensor(),
])

In [None]:
# device='cpu': projections are computed on CPU inside the dataset (e.g. in loader workers).
train_dataset = CTDataset(data_path, split='train', transform=transform, device='cpu')
test_dataset = CTDataset(data_path, split='test', transform=transform, device='cpu')

In [None]:
# Datasets keyed by name for train_models; the 'ctrecon' key presumably becomes the
# prefix of the trainer keys used later ('ctrecon/AttnLNP/run_{run}') -- confirm.
ct_datasets = {'ctrecon': train_dataset}
ct_test_datasets = {'ctrecon': test_dataset}

### Collate functions

In [None]:
from neural_process_family.npf.utils.datasplit import CntxtTrgtGetter, GetRandomIndcs, get_all_indcs
from utils.data import cntxt_trgt_collate

In [None]:
class CntxtTrgtGetterCT(CntxtTrgtGetter):
    """Context/target splitter for CT reconstruction.

    Context inputs are the projection angles, normalised as ``theta / pi`` and
    duplicated into a 2-vector so they share the 2D input space of the pixel
    coordinates. Context outputs are the matching projections. Target inputs
    are the pixel coordinates ``X`` and target outputs are projections from
    ``y``. The returned ``thetas`` are gathered at the selected target
    projections so predictions and targets stay aligned.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(
        self, X, thetas, y=None, context_indcs=None, target_indcs=None, is_return_indcs=False
    ):
        """Split a collated batch into context and target sets.

        Parameters
        ----------
        X : torch.Tensor, size=[batch_size, n_pixels, x_dim]
            Pixel coordinates (target inputs).
        thetas : torch.Tensor, size=[batch_size, n_projs]
            Projection angles.
        y : torch.Tensor, size=[batch_size, n_projs, proj_width]
            Observed projections.
        context_indcs : torch.Tensor, optional
            Indices into the projection axis; sampled with ``contexts_getter``
            if omitted.
        target_indcs : torch.Tensor or (torch.Tensor, torch.Tensor), optional
            Either a pair ``(pixel_indcs, projection_indcs)`` or one tensor
            used for both. Sampled with ``targets_getter`` if omitted.
        is_return_indcs : bool
            If True, return indices instead of selected values:
            ``(context_indcs, X_pre_cntxt, (target_indcs_x, target_indcs_y), X)``.

        Returns
        -------
        X_cntxt, Y_cntxt, X_trgt, Y_trgt, thetas_trgt
        """
        batch_size, num_points_x = self.getter_inputs(X)
        _, num_points_y = self.getter_inputs(y)

        if context_indcs is None:
            context_indcs = self.contexts_getter(batch_size, num_points_y)

        if target_indcs is None:
            target_indcs_x = self.targets_getter(batch_size, num_points_x)
            target_indcs_y = self.targets_getter(batch_size, num_points_y)
        elif isinstance(target_indcs, (tuple, list)):
            target_indcs_x, target_indcs_y = target_indcs
        else:
            target_indcs_x = target_indcs_y = target_indcs

        if self.is_add_cntxts_to_trgts:
            # Context indices address projections, so they can only be merged
            # into the projection targets; pixel targets use a different index space.
            target_indcs_y = self.add_cntxts_to_trgts(
                num_points_y, target_indcs_y, context_indcs
            )

        # only used if X for context and target should be different (besides selecting indices!)
        X_pre_cntxt = self.preprocess_context(X)

        if is_return_indcs:
            # instead of features return indices / masks, and `Y_cntxt` is replaced
            # with all values Y
            return (
                context_indcs,
                X_pre_cntxt,
                (target_indcs_x, target_indcs_y),
                X,
            )

        # Normalised angles duplicated to 2D: [batch, n_projs, 2].
        X_angles = (thetas / pi).unsqueeze(-1).repeat(1, 1, 2)
        X_cntxt, Y_cntxt = self.select(X_angles, y, context_indcs)
        X_trgt, Y_trgt = self.select(X, y, target_indcs_x, target_indcs_y)
        # Keep one angle per selected target projection so the model predicts
        # exactly the projections in Y_trgt (identity when all indices are used).
        thetas_trgt = torch.gather(thetas, 1, target_indcs_y.to(thetas.device))
        return X_cntxt, Y_cntxt, X_trgt, Y_trgt, thetas_trgt

    def select(self, X, y, indcs, indcs_y=None):
        """Gather rows of ``X`` at ``indcs`` and rows of ``y`` at ``indcs_y``.

        ``indcs_y`` defaults to ``indcs``. Both index tensors are expected to be
        [batch_size, n_selected] and are broadcast over the feature axis.
        """
        batch_size, num_points, x_dim = X.shape
        y_dim = y.size(-1)
        indcs_x = indcs.to(X.device).unsqueeze(-1).expand(batch_size, -1, x_dim)
        if indcs_y is not None:
            indcs_y = indcs_y.to(X.device).unsqueeze(-1).expand(batch_size, -1, y_dim)
        else:
            indcs_y = indcs.to(X.device).unsqueeze(-1).expand(batch_size, -1, y_dim)
        return (
            torch.gather(X, 1, indcs_x).contiguous(),
            torch.gather(y, 1, indcs_y).contiguous(),
        )

In [None]:
def cntxt_trgt_collate_ct(get_cntxt_trgt, is_duplicate_batch=False, **kwargs):
    """Build a collate function that splits CT batches into contexts and targets.

    Parameters
    ----------
    get_cntxt_trgt : callable
        Called as ``get_cntxt_trgt(X, thetas, y, **kwargs)``. Must return
        ``X_cntxt, Y_cntxt, X_trgt, Y_trgt, thetas``.
    is_duplicate_batch : bool, optional
        Whether to repeat the batch so every function gets two different
        context/target splits. If so, ``X``, ``y`` and ``thetas`` are each
        concatenated with themselves along the batch axis.
    **kwargs :
        Extra keyword arguments forwarded to ``get_cntxt_trgt``.

    Returns
    -------
    callable
        ``collate(batch) -> (inputs, targets)``. ``inputs`` is a dict with keys
        ``X_cntxt, Y_cntxt, X_trgt, Y_trgt, thetas``; ``targets`` is ``Y_trgt``.
    """

    def mycollate(batch):
        collated = torch.utils.data.dataloader.default_collate(batch)
        X = collated[0]
        y = collated[1]
        thetas = collated[2]

        if is_duplicate_batch:
            X = torch.cat([X, X], dim=0)
            if y is not None:
                y = torch.cat([y, y], dim=0)
            # Duplicate the angles too, so they stay aligned with X and y.
            thetas = torch.cat([thetas, thetas], dim=0)

        X_cntxt, Y_cntxt, X_trgt, Y_trgt, thetas = get_cntxt_trgt(X, thetas, y, **kwargs)
        inputs = dict(X_cntxt=X_cntxt, Y_cntxt=Y_cntxt, X_trgt=X_trgt, Y_trgt=Y_trgt, thetas=thetas)
        targets = Y_trgt

        return inputs, targets

    return mycollate

In [None]:
# Collate function used for both training and validation. Both index getters
# are get_all_indcs, so (presumably) every point is used for context and target.
_ct_split_getter = CntxtTrgtGetterCT(get_all_indcs, get_all_indcs)
mycollate = cntxt_trgt_collate_ct(_ct_split_getter)

## Training 

In [None]:
import skorch
from npf import ELBOLossLNPF, NLLLossLNPF
from skorch.callbacks import GradientNormClipping, ProgressBar
from utils.ntbks_helpers import add_y_dim
from utils.train import train_models

run = 1  # run index; trainers are later looked up as 'ctrecon/AttnLNP/run_{run}'

# Training configuration passed to train_models.
# NOTE(review): this rebinds KWARGS, which held the model kwargs above. model_1d
# has already captured those values, so the model itself is unaffected.
KWARGS = dict(
    is_retrain=False,  # whether to load precomputed model or retrain
    criterion=NLLLossLNPF,  # NPVI or NPML
    chckpnt_dirname="saved_models",
    device=None,  # use GPU if available
    batch_size=1,
    lr=1e-4,
    decay_lr=10,  # decrease learning rate by 10 during training
    seed=None,
    callbacks=[
        GradientNormClipping(gradient_clip_value=1)
    ],  # clipping gradients can stabilize training
    starting_run=run,
    runs=1,
)


# 1D
# Train (or load, since is_retrain=False) the PONP model on the CT datasets,
# using the CT-specific collate function for both train and validation loaders.
trainers_1d = train_models(
    ct_datasets,
    {"AttnLNP": model_1d},
    test_datasets=ct_test_datasets,
    iterator_train__collate_fn=mycollate,
    iterator_valid__collate_fn=mycollate,
    max_epochs=200,
    **KWARGS
)

## Test Time Optimization + Visualization

Here, we perform test-time optimization as described in the paper. We also visualize the results. 

In [None]:
import torch.optim as optim

import matplotlib.pyplot as plt

import copy

In [None]:
# Replace the projection-space output heads with id_transformer wrappers, which
# share their weights, so the model's predictive mean is the image itself rather
# than its projections. Uses `run` (the training run index); `r` was undefined.
ponp_model = trainers_1d['ctrecon/AttnLNP/run_{}'.format(run)].module_
ponp_model.p_y_loc_transformer = id_transformer(
    ponp_model.p_y_loc_transformer, stat='loc'
)
ponp_model.p_y_scale_transformer = id_transformer(
    ponp_model.p_y_scale_transformer, stat='scale'
)

In [None]:
# Show the trained model's architecture (including the swapped output heads).
print(trainers_1d[f'ctrecon/AttnLNP/run_{run}'].module_)

In [None]:
# turn off gradients for all layers except the CT recon model
for name, param in trainers_1d['ctrecon/AttnLNP/run_{}'.format(run)].module_.named_parameters():
    if name.startswith('decoder.flat_module') or name.startswith('p_y_'):
        param.requires_grad = True
    else:
        param.requires_grad = False
        
    print(name, param.requires_grad)

In [None]:
# Full 256x256 pixel grid in [-1, 1], used as the target inputs during evaluation.
coords = get_mgrid(256, dim=2)

In [None]:
# views to test steps
_views_to_test_steps = {
    1: 50,
    2: 100,
    4: 1000,
    8: 1000,
}

In [None]:
def _show_phantom_and_sinogram(phantom, sinogram):
    """Plot an image and its sinogram side by side."""
    plt.figure(figsize=(15, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(phantom)
    plt.title('Phantom')
    plt.subplot(1, 2, 2)
    plt.imshow(sinogram)
    plt.title('Sinogram')
    plt.show()


def _predict_image(net, X_cntxt, Y_obs, X_trgt, thetas):
    """Return the predictive mean of ``net``.

    The observed projections serve as both context and target outputs.
    Presumably this mean is the image once id_transformer heads are installed.
    """
    dist = net(X_cntxt, Y_obs, X_trgt, Y_obs, thetas)
    return dist[0].base_dist.loc


def find_psnr(model, views, n_steps=None, device='cuda'):
    """Reconstruct every test image from ``views`` projections and report MSE / PSNR.

    For each image in ``test_dataset.data``:
    1. project it at ``views`` angles evenly spaced over [0, pi);
    2. copy ``model`` and fine-tune the copy for ``n_steps`` Adam steps so the
       projections of its reconstruction match the observed projections
       (test-time optimisation; the ground-truth image is not used here);
    3. plot the reconstruction and its sinogram next to the ground truth;
    4. accumulate the reconstruction MSE and PSNR = -10 * log10(MSE)
       (i.e. assuming a peak intensity of 1).

    Parameters
    ----------
    model : torch.nn.Module
        Trained PONP model, presumably with id_transformer output heads so the
        predictive mean is a [256, 256] image -- confirm.
    views : int
        Number of input projections.
    n_steps : int or None
        Test-time optimisation steps. ``None`` uses ``_views_to_test_steps[views]``
        (0 if ``views`` is not listed); pass 0 to evaluate the NP prediction alone.
    device : str
        Device used for the model and the reconstruction projections.

    Returns
    -------
    (avg_mse, avg_psnr) : tensors averaged over the test images.

    Raises
    ------
    ValueError
        If ``test_dataset`` is empty.
    """
    n_images = len(test_dataset.data)
    if n_images == 0:
        raise ValueError('test_dataset is empty; nothing to evaluate')
    if n_steps is None:
        n_steps = _views_to_test_steps.get(views, 0)

    # `views` angles evenly spaced over [0, pi) (the pi endpoint is dropped).
    thetas = torch.linspace(0, pi, views + 1)[:-1]
    # Context inputs: normalised angles duplicated to 2D, as in the training collate.
    X_cntxt = (thetas / pi).view(1, views, 1).repeat(1, 1, 2).to(device)
    X_trgt = coords.unsqueeze(0).to(device)
    thetas_dev = thetas.to(device)

    avg_mse = 0.0
    avg_psnr = 0.0
    for i in range(n_images):
        test_image = test_dataset.data[i]
        gt_image = torch.from_numpy(test_image).to(device)
        # Observed data: projections of the ground-truth image.
        image_projs = ct_project_batch(torch.from_numpy(test_image), thetas, device='cpu')
        Y_obs = image_projs.view(1, views, 256).to(device)
        obs_projs = image_projs.to(device)

        # Test-time optimisation on a copy, so `model` itself is left untouched.
        learner = copy.deepcopy(model).to(device)
        opt = optim.Adam(learner.parameters(), lr=1e-4)
        for _ in range(n_steps):
            preds = _predict_image(learner, X_cntxt, Y_obs, X_trgt, thetas_dev)
            pred_projs = ct_project_batch(preds.squeeze(), thetas, device=device)
            # Fit the observed projections only; the ground-truth image must not
            # leak into test-time optimisation.
            loss = ((pred_projs - obs_projs) ** 2).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()

        # Final reconstruction and its projections.
        with torch.no_grad():
            preds = _predict_image(learner, X_cntxt, Y_obs, X_trgt, thetas_dev)
            pred_projs = ct_project_batch(preds.squeeze(), thetas, device=device)

        # Predicted vs ground-truth reconstruction and projections.
        _show_phantom_and_sinogram(
            preds.reshape(256, 256).detach().cpu().numpy(),
            pred_projs.detach().cpu().numpy()[:, :, 0],
        )
        _show_phantom_and_sinogram(test_image, image_projs.cpu().numpy()[:, :, 0])

        loss = ((gt_image - preds.squeeze()) ** 2).mean()
        psnr = -10 * torch.log10(loss)
        print('MSE loss: {}'.format(loss))
        print('PSNR: {}'.format(psnr))

        avg_mse += loss
        avg_psnr += psnr

        print('Avg MSE loss: {}'.format(avg_mse/(i + 1)))
        print('Avg PSNR: {}'.format(avg_psnr/(i + 1)))

        # Release the per-image copy before clearing the CUDA cache.
        del learner, opt
        torch.cuda.empty_cache()

    print('Avg MSE loss: {}'.format(avg_mse/n_images))
    print('Avg PSNR: {}'.format(avg_psnr/n_images))
    return avg_mse/n_images, avg_psnr/n_images

The number of input views used can be changed via the views argument. In the paper, we use views=1,2,4,8. 

In [None]:
# Evaluate the trained model with 8 input views (the paper uses 1, 2, 4 and 8).
find_psnr(trainers_1d[f'ctrecon/AttnLNP/run_{run}'].module_, views=8)