In [36]:
import math
import scipy
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import minimize
import os
import nibabel as nib

# from mpl_toolkits.mplot3d import Axes3D 
%run utils/gaussian_aux.ipynb

In [38]:
def pointset_register_main(gaussian_centers_3d, mrid_dict, bundle_start, weighted_loss_f, visualization=False):
    """
    Register the MRID design pattern to measured 3D Gaussian centers.

    Parameters
    ----------
    gaussian_centers_3d : array-like
        Measured 3D Gaussian centers; passed unchanged to ``pointsetreg``.
    mrid_dict : mapping
        MRID description forwarded to ``get_mrid_dimensions``.
    bundle_start : int
        Index of the first pattern to use; forwarded to ``get_mrid_dimensions``.
    weighted_loss_f : str or other
        Selects the per-point values handed to the optimiser:
        ``"density"``      -> ionp_amount / pattern_lengths
        ``"length"``       -> pattern_lengths
        ``"ionp_amount"``  -> ionp_amount (the historical misspelling
                              ``"iopn_amount"`` is still accepted)
        anything else      -> uniform (all ones).
        Note that ``bundle_fit3d`` weights each point's error by the
        normalised *inverse* of these values.
    visualization : bool, optional
        When true, plot measured vs. fitted points with ``visualize_pointfit``.

    Returns
    -------
    numpy.ndarray
        Fitted point coordinates as returned by ``get_fitted_points``.
    """
    # The design point coordinates (second return value) are not needed here.
    mrid_design_dist, _, pattern_lengths, ionp_amount = get_mrid_dimensions(mrid_dict, bundle_start)

    if weighted_loss_f == "density":
        loss_f_weights = ionp_amount / pattern_lengths
    elif weighted_loss_f == "length":
        loss_f_weights = pattern_lengths
    elif weighted_loss_f in ("ionp_amount", "iopn_amount"):
        # "iopn_amount" was the only spelling recognised originally (a typo);
        # keep it working for existing callers while accepting the correct one.
        loss_f_weights = ionp_amount
    else:
        # Unrecognised / disabled weighting: every point contributes equally.
        loss_f_weights = np.ones_like(pattern_lengths)

    res = pointsetreg(gaussian_centers_3d, mrid_design_dist, loss_f_weights)

    reg_results = res.x
    print("Registration resulsts: ")
    print(reg_results)
    fitted_mrid_points = get_fitted_points(res, mrid_design_dist)

    if visualization:
        visualize_pointfit(gaussian_centers_3d, fitted_mrid_points)

    return fitted_mrid_points

In [6]:
def get_mrid_dimensions(mrid_dict, bundle_start):
    """
    Slice the MRID description from ``bundle_start`` onwards and derive the
    pattern geometry from it.

    Reads the ``"dimensions"`` (indexed as a 2-D array),
    ``"intersegment_distances"`` and ``"ionp_amount"`` entries of ``mrid_dict``.

    Returns a 4-tuple ``(pattern_dist, pattern_points, pattern_lengths,
    ionp_amount)`` where the first two items are whatever ``get_centomass``
    returns for the sliced dimensions / inter-segment distances, and
    ``pattern_lengths`` is the last column of the sliced dimensions.
    """
    dims = mrid_dict["dimensions"][bundle_start:, :]
    gaps = mrid_dict["intersegment_distances"][bundle_start:]
    amounts = mrid_dict["ionp_amount"][bundle_start:]

    # get_centomass is provided by utils/gaussian_aux.ipynb (loaded via %run).
    dist_and_points = get_centomass(dims, gaps)
    pattern_dist, pattern_points = dist_and_points

    lengths = dims[:, -1]
    return pattern_dist, pattern_points, lengths, amounts

In [5]:
def bundle_fit3d(x, *args):
    """
    Objective function: weighted sum of distances between measured points and
    a chain of points built from the parameter vector.

    ``x`` is laid out as ``[x0, y0, z0, theta_1, gamma_1, theta_2, gamma_2, ...]``:
    the first chain point followed by one (polar, azimuth) angle pair per
    subsequent point.

    ``args`` is ``(mriPoints, pattern_dists, pattern_lengths)``:
    measured 3D points, the step length between consecutive chain points, and
    per-point values whose normalised inverses are used as error weights.

    Returns the scalar sum of the weighted Euclidean errors.
    """
    cur_x, cur_y, cur_z = x[0], x[1], x[2]

    measured = args[0]
    steps = args[1]
    reciprocal = 1 / args[2]
    weights = reciprocal / np.sum(reciprocal)

    weighted_errors = []
    for idx, (m_x, m_y, m_z) in enumerate(measured):
        if idx > 0:
            # Advance the chain by one segment using spherical angles
            # (polar angle measured from the +Z axis).
            step = steps[idx - 1]
            polar = x[2 * idx + 1]
            azimuth = x[2 * idx + 2]

            cur_x = cur_x + step * math.sin(polar) * math.cos(azimuth)
            cur_y = cur_y + step * math.sin(polar) * math.sin(azimuth)
            cur_z = cur_z + step * math.cos(polar)

        distance = np.sqrt((m_x - cur_x) ** 2 + (m_y - cur_y) ** 2 + (m_z - cur_z) ** 2)
        weighted_errors.append(distance * weights[idx])

    return np.sum(np.array(weighted_errors))

def get_spherical_coord(cart_coord):
    """
    Convert consecutive Cartesian points into spherical step vectors.

    For each pair of successive rows in ``cart_coord`` the difference vector
    is expressed as ``(r, theta, phi)``:
      * ``r``     - Euclidean length of the step,
      * ``theta`` - polar angle measured from the +Z axis,
      * ``phi``   - azimuth in the XY plane (``arctan2(dy, dx)``).

    Parameters
    ----------
    cart_coord : array-like, shape (n, 3)
        Cartesian coordinates; integer input is allowed.

    Returns
    -------
    numpy.ndarray, shape (n - 1, 3), dtype float
        One ``(r, theta, phi)`` row per step.
    """
    # Cast to float BEFORE differencing: with integer input the original
    # result array inherited an integer dtype and silently truncated the
    # radii/angles (and unsigned input could wrap around in np.diff).
    steps = np.diff(np.asarray(cart_coord, dtype=float), axis=0)

    dx, dy, dz = steps[:, 0], steps[:, 1], steps[:, 2]
    xy_sq = dx ** 2 + dy ** 2

    sph_coord = np.empty_like(steps)
    sph_coord[:, 0] = np.sqrt(xy_sq + dz ** 2)
    sph_coord[:, 1] = np.arctan2(np.sqrt(xy_sq), dz)  # elevation angle defined from Z-axis down
    sph_coord[:, 2] = np.arctan2(dy, dx)

    return sph_coord

In [2]:
def pointsetreg(gaussian_centers_3d, pattern_dist, pattern_lengths, px_size=25):
    """
    Registers bundle to the measured Gaussian centers in 3D.

    The optimisation starts from the first measured center plus the polar and
    azimuth angles of the steps between successive measured centers, and
    minimises ``bundle_fit3d`` with BFGS.

    Parameters
    ----------
    gaussian_centers_3d : array-like, shape (n, 3)
        Measured 3D Gaussian centers.
    pattern_dist : array-like
        Design distances between consecutive pattern points; divided by
        ``px_size`` before being used as step lengths in the fit.
    pattern_lengths : array-like
        Per-point values forwarded to ``bundle_fit3d``, which weights each
        point's error by their normalised inverse.
    px_size : float, optional
        Divisor that converts ``pattern_dist`` into the coordinate units of
        ``gaussian_centers_3d``. Defaults to 25, the previously hard-coded value.

    Returns
    -------
    scipy.optimize.OptimizeResult
        Result object returned by ``scipy.optimize.minimize``.

    Raises
    ------
    ValueError
        If ``px_size`` is not strictly positive.
    """
    if px_size <= 0:
        raise ValueError("px_size must be strictly positive")

    p_init = gaussian_centers_3d[0]
    sph_coord_gaussian_centers = get_spherical_coord(gaussian_centers_3d)
    print("Spherical coordinates for measured 3D gaussian centers: ")
    print(sph_coord_gaussian_centers)

    # Parameter vector: [x0, y0, z0, theta_1, gamma_1, theta_2, gamma_2, ...];
    # the radial column is dropped because step lengths come from the design.
    x_init = np.append(p_init, sph_coord_gaussian_centers[:, 1:].flatten())
    print("Initialization x: ")
    print(x_init)

    # np.asarray lets a plain list of distances be scaled as well.
    scaled_dist = np.asarray(pattern_dist) / px_size

    res = minimize(
        bundle_fit3d,
        x_init,
        method='BFGS',
        args=(gaussian_centers_3d, scaled_dist, pattern_lengths),
    )

    return res

SyntaxError: invalid syntax (3087551127.py, line 26)

In [16]:
def get_fitted_points(reg_result, pattern_dist, px_size=25):
    """
    Calculates the fitted point coordinates given the registration results and MRID dimensions.

    Parameters
    ----------
    reg_result : object with attribute ``x``
        Registration result whose ``x`` is laid out as
        ``[x0, y0, z0, theta_1, gamma_1, theta_2, gamma_2, ...]``.
    pattern_dist : sequence
        Design distances between consecutive points; ``pattern_dist[i]`` is
        divided by ``px_size`` to obtain the length of step ``i``.
    px_size : float, optional
        Divisor applied to ``pattern_dist``. Defaults to 25, the previously
        hard-coded value (and the default used by ``pointsetreg``).

    Returns
    -------
    numpy.ndarray, shape (k + 1, 3)
        The starting point followed by one point per angle pair in ``x``.

    Raises
    ------
    ValueError
        If ``px_size`` is not strictly positive.
    """
    if px_size <= 0:
        raise ValueError("px_size must be strictly positive")

    x = reg_result.x

    points = [[x[0], x[1], x[2]]]

    # Two angles (polar, azimuth) per step after the three start coordinates.
    num_steps = (len(x) - 3) // 2

    for i in range(num_steps):
        theta = x[2 * i + 3]
        gamma = x[2 * i + 4]
        x_prev, y_prev, z_prev = points[-1]
        r = pattern_dist[i] / px_size

        points.append([
            x_prev + r * math.sin(theta) * math.cos(gamma),
            y_prev + r * math.sin(theta) * math.sin(gamma),
            z_prev + r * math.cos(theta),
        ])

    return np.array(points)

In [20]:
def visualize_pointfit(gauss3d, fit3d):
    """
    Show an interactive 3D plot comparing measured and fitted points.

    Parameters
    ----------
    gauss3d : array-like, shape (n, 3)
        Measured points, drawn as blue markers.
    fit3d : array-like, shape (m, 3)
        Fitted points, drawn as red markers joined by lines.

    The figure is displayed with ``fig.show()``; nothing is returned.
    """
    # Imported locally so the function works on a fresh kernel: the notebook's
    # import cell does not bring `go` into scope.
    # NOTE(review): if utils/gaussian_aux.ipynb already imports plotly as `go`,
    # consider moving this to the top-level import cell instead.
    import plotly.graph_objects as go

    # Accept plain nested lists as well as arrays for the column slicing below.
    gauss3d = np.asarray(gauss3d)
    fit3d = np.asarray(fit3d)

    fig = go.Figure()

    # Scatter for gauss3d (markers only)
    fig.add_trace(go.Scatter3d(
        x=gauss3d[:, 0],
        y=gauss3d[:, 1],
        z=gauss3d[:, 2],
        mode='markers',
        marker=dict(size=5, color='blue'),
        name='gauss3d'
    ))

    # Scatter for fit3d (markers + connecting lines)
    fig.add_trace(go.Scatter3d(
        x=fit3d[:, 0],
        y=fit3d[:, 1],
        z=fit3d[:, 2],
        mode='markers+lines',
        marker=dict(size=5, color='red'),
        line=dict(width=2, color='red'),
        name='fit3d'
    ))

    fig.update_layout(
        scene=dict(
            xaxis_title="Medial --> Lateral",
            yaxis_title="Ventral --> Dorsal",
            zaxis_title="Posterior --> Anterior"
        ),
        legend=dict(x=0, y=1)
    )

    fig.show()
    