In [18]:
import numpy as onp
import jax.numpy as np
from jax import random, vmap, lax
from jax.config import config
from jax.scipy.special import expit as sigmoid
config.update("jax_enable_x64", True)

import matplotlib.pyplot as plt

from scipy.linalg import eigh


from jaxbo.mcmc_models import ReimannianMFGPclassifierFourier, ReimannianGPclassifierFourier
from jaxbo.input_priors import uniform_prior

from sklearn.metrics import balanced_accuracy_score

from utils.Mesh import Mesh
import meshio
onp.random.seed(1234)

### load the left atrium geometry and normalize

In [2]:
m = Mesh('data/LA_geometry.obj')

# mesh to voxels factor = 5 (mesh size = 0.2 mm)
verts = m.verts*5
centroid = verts.mean(0)
std_max = verts.std(0).max()
verts_new = (verts - centroid)/std_max

m = Mesh(verts = verts_new, connectivity = m.connectivity)



### compute Laplacian and its eigenvalues

In [3]:
print('Computing Laplacian')
K, M = m.computeLaplacian()
print('Computing eigen values')
eigvals, eigvecs = eigh(K,M)

Computing Laplacian
Computing eigen values


### generate all available cases

In [6]:
ablations = ['', '-PVI', '-BOX']
fibrosis = ['50PF', '70PF', '71PF']

fibnames = ['moderate fibrosis', 'severe fibrosis - case 1', 'severe fibrosis - case 2']
abnames = ['', ' - PVI', ' - PVI + BOX']
cases = []
names = []

for f in fibrosis:
    for a in ablations:
        cases.append(f + a)

for f in fibnames:
    for a in abnames:
        names.append(f + a)

print('cases')
print(cases)
print('names')
print(names)

id_case = 3
case = cases[id_case]

print('picked case:', names[id_case])

cases
['50PF', '50PF-PVI', '50PF-BOX', '70PF', '70PF-PVI', '70PF-BOX', '71PF', '71PF-PVI', '71PF-BOX']
names
['moderate fibrosis', 'moderate fibrosis - PVI', 'moderate fibrosis - PVI + BOX', 'severe fibrosis - case 1', 'severe fibrosis - case 1 - PVI', 'severe fibrosis - case 1 - PVI + BOX', 'severe fibrosis - case 2', 'severe fibrosis - case 2 - PVI', 'severe fibrosis - case 2 - PVI + BOX']
picked case: severe fibrosis - case 1


### load data -inducible points- for the selected case

In [12]:

gt = onp.genfromtxt('data/ground_truth_points.csv')[:,3].astype(int)
train_points = onp.genfromtxt('data/train_points.csv')[:,3].astype(int)
X_L = train_points

X_all = X_L

N_H = 40 # number of high fidelity points for training

X_H = train_points[:N_H]

Y_L_all = np.load('data/LF_train-%s.npz' % case)['output']
Y_H_all =  np.load('data/HF_train-%s.npz' % case)['output']

    
Y_L = Y_L_all[:100]
Y_H = Y_H_all[:N_H]

y_true = np.load('data/ground_truth-%s.npz' % case)['output']

X_true = gt

Y = np.concatenate([Y_L, Y_H])

### select the number of eigenfunctions to be used

In [14]:
rng_key = random.PRNGKey(123)

n_eigs = 1000

eigpairs = (np.array(eigvals[:n_eigs]), np.array(eigvecs[:,:n_eigs]).T)

D = 1
lb = 0.0*np.ones(D)
ub = 1.0*np.ones(D)
bounds = {'lb': lb, 'ub': ub}
p_x = uniform_prior(lb, np.ones(D)*m.verts.shape[0])



## Multi-fidelity classifier

In [23]:
options = {'kernel': 'RBF',
           'criterion': 'LW_CLSF', 
           'input_prior': p_x,
           'kappa': 1.0,
           'nIter': 0}
mcmc_settings = {'num_warmup': 500,
                 'num_samples': 500,
                 'num_chains': 1,
                 'target_accept_prob': 0.9}
gp_model = ReimannianMFGPclassifierFourier(options, eigpairs)

### training and testing

In [24]:

batch = {'XL': X_L, 'XH': X_H, 'y': Y}
key_train, key_test = random.split(rng_key)
samples = gp_model.train(batch,
                            key_train,
                            mcmc_settings,
                            verbose = False)

rng_keys = random.split(key_test, 
                    mcmc_settings['num_samples'] * mcmc_settings['num_chains'])
kwargs = {'samples': samples,
            'batch': batch,
            'bounds': bounds,
            'rng_key': key_test,
            'rng_keys': rng_keys}
n_nodes = eigpairs[1].shape[1]
X_star = np.arange(n_nodes)


# this is a way to speed up the prediction which does not scale well with the amount requested points due to the big size of the covariance matrices
Mean, Std = lax.map(lambda x: gp_model.predict_conditional(x, **kwargs),X_star[:3200].reshape(8,-1))
mean, std = gp_model.predict_conditional(X_star[3200:], **kwargs)
Mean = np.concatenate((np.transpose(Mean, (1,0,2)).reshape((-1,3200)), mean), axis = 1)
Std = np.concatenate((np.transpose(Std, (1,0,2)).reshape((-1,3200)), std), axis = 1)

Mean_all = Mean.mean(0)
Std_all = np.sqrt(np.mean(Std**2, axis = 0))


fmesh = meshio.Mesh(points = m.verts*std_max + centroid, cells = {'triangle':m.connectivity}, point_data = {'probs': onp.array(sigmoid(Mean_all)), 'std': onp.array(Std_all)})
fmesh.write('output/LA_MF_%s_NH_%i.vtu' % (case, N_H))
accuracy = balanced_accuracy_score(y_true, np.rint(sigmoid(Mean_all[X_true])))

print('balanced accuracy:', accuracy)
    

sample: 100%|██████████| 1000/1000 [00:53<00:00, 18.59it/s, 255 steps of size 1.27e-02. acc. prob=0.98]


balanced accuracy: 0.8404074702886248


## Single fidelity classifier

In [19]:
options = {'kernel': 'RBF',
           'criterion': 'LW_CLSF', 
           'input_prior': p_x,
           'kappa': 1.0,
           'nIter': 0}
mcmc_settings = {'num_warmup': 500,
                 'num_samples': 500,
                 'num_chains': 1,
                 'target_accept_prob': 0.9}
gp_model_SF = ReimannianGPclassifierFourier(options, eigpairs)

### training and testing

In [22]:

batch = { 'X': X_H, 'y': Y_H}
key_train, key_test = random.split(rng_key)
samples = gp_model_SF.train(batch,
                            key_train,
                            mcmc_settings,
                            verbose = False)

rng_keys = random.split(key_test, 
                    mcmc_settings['num_samples'] * mcmc_settings['num_chains'])
kwargs = {'samples': samples,
            'batch': batch,
            'bounds': bounds,
            'rng_key': key_test,
            'rng_keys': rng_keys}
n_nodes = eigpairs[1].shape[1]
X_star = np.arange(n_nodes)


# this is a way to speed up the prediction which does not scale well with the amount requested points due to the big size of the covariance matrices
Mean, Std = lax.map(lambda x: gp_model_SF.predict_conditional(x, **kwargs),X_star[:3200].reshape(8,-1))
mean, std = gp_model_SF.predict_conditional(X_star[3200:], **kwargs)
Mean = np.concatenate((np.transpose(Mean, (1,0,2)).reshape((-1,3200)), mean), axis = 1)
Std = np.concatenate((np.transpose(Std, (1,0,2)).reshape((-1,3200)), std), axis = 1)

Mean_all = Mean.mean(0)
Std_all = np.sqrt(np.mean(Std**2, axis = 0))


fmesh = meshio.Mesh(points = m.verts*std_max + centroid, cells = {'triangle':m.connectivity}, point_data = {'probs': onp.array(sigmoid(Mean_all)), 'std': onp.array(Std_all)})
fmesh.write('output/LA_SF_%s_NH_%i.vtu' % (case, N_H))
accuracy = balanced_accuracy_score(y_true, np.rint(sigmoid(Mean_all[X_true])))

print('balanced accuracy:', accuracy)
    

sample: 100%|██████████| 1000/1000 [00:11<00:00, 87.73it/s, 63 steps of size 5.23e-02. acc. prob=0.95] 


balanced accuracy: 0.7504244482173175
