In [1]:
import jax
import jax.numpy as jnp
import io
import base64
import time
from functools import partial
from typing import NamedTuple
import subprocess
import optax
import PIL
import numpy as np
import matplotlib.pylab as pl
import einops 
import SimpleITK as sitk
import itertools

def norm(v, axis=-1, keepdims=False, eps=0.0):
  return jnp.sqrt((v*v).sum(axis, keepdims=keepdims).clip(eps))

def normalize(v, axis=-1, eps=1e-20):
  return v/norm(v, axis, keepdims=True, eps=eps)
def standardize_to_sum_1(v, axis=-1, eps=1e-20):
  a=v-jnp.min(v)
  # print(f"a {a}")
  return a/jnp.sum(jnp.ravel(a))


class Balls(NamedTuple):
  pos: jnp.ndarray
  color: jnp.ndarray




In [2]:
def show_slice(sdf, z=0.0, w=400, r=3.5):
  y, x = jnp.mgrid[-r:r:w*1j, -r:r:w*1j].reshape(2, -1)
  p = jnp.c_[x, y, x*0.0+z]
  d = jax.vmap(sdf)(p).reshape(w, w)
  pl.figure(figsize=(5, 5))
  kw = dict(extent=(-r, r, -r, r), vmin=-r, vmax=r)
  pl.contourf(d, 16, cmap='bwr', **kw );
  pl.contour(d, levels=[0.0], colors='black', **kw);
  pl.axis('equal')
  pl.xlabel('x')
  pl.ylabel('y')

In [3]:
# def create_balls(key, n=3, R=3.0):
#   pos, color = jax.random.uniform(key, [2, n, 3])
#   pos = (pos-0.5)*R
#   return Balls(pos, color)
# normalizedf, balls), z=0.0)

idea in general is to have a function that takes a point coordinates and output the value
so we have stored for each supervoxel a center, a voxel characteristic (for now just simple gaussian)
and set of vectors in diffrent directions basically in polar coordinates
next given a point coordinates in the image we can look at a location value for this point of each supervoxel
so we will take the voctor from the center of the analyzed supervoxel and look for the dot products of this vector with the stored vectors in the supervoxel publish_display_data
we than take either sum or max dot product as the score we than take into account wheather the query vecotr is shorter or longer than stored one
so if it is shorter we give high value if it is not small  
we will have a score for each supervoxel that will mark basically are the coordinates in the shape or not
next we will softmax or sth like that to exaggerate the influence of best fit 
lastly we will multiply by the output values of the haussians - hence all the supervoxels becouse of the low scores should not contribute significantly 
and the single one should

later when wavelet or some other we should also have the way to get the value in a particular spot without instantiating whole image 

note 
in order to avoid collapse of the shape to the point there should be additionall loss function that will maximize the distance between the points within supervoxel

In [36]:
def get_initial_points_around(r):
    """
    given center point will give set of points that will create a verticies of shape around this point
    the first try of this function is a cube
    center_point- jax array with 3 float entries
    r - a float that is controlling the initial shape size if it would be a sphere it would be a radius
    """
    res= list(set(itertools.combinations([-r,0.0,r,-r,0.0,r,-r,0.0,r],3)))
    return jnp.array( list(filter(lambda points: not (points[0]==0.0 and points[1]==0.0 and points[0]==0.0)  ,res)))
     


def get_corrected_dist(query_point_normalized,around,eps,pow=3):
    """
    our goal is to get the information how distant in a query point direction it can be to be still within shape
    so first we need to asses the angle in respect to all of its verticies to know in which direction to look
    next we asses how far in this direction the shape goes from the center (which in this case is the center of coordinate system)
    query_point_normalized: the location of the query points given the shape center is the center of coordianate system
    around - set of points that are the verticies of the shape in the same coordinate system as above
    eps - added for numerical stability
    pow - need to be odd number  the bigger the more exact the calculation the smaller the more stable numerically and simpler to differentiate
    """
    #vmap cosine similarity
    cosine_similarity_multi=jax.vmap(partial(optax.cosine_similarity, epsilon=eps),(0,None))

    angle_similarity= jax.nn.softmax(((cosine_similarity_multi(around,query_point_normalized))+1)**pow)
    dists = jnp.sqrt(jnp.sum(around**2,axis=1))
    corrected_dist= jnp.sum(dists*angle_similarity)#try also max
    return corrected_dist

# dist_query= jnp.sqrt(jnp.sum(query_point_normalized**2))
# corrected_dist=get_corrected_dist(query_point_normalized,around,eps)
# dist_query
# dist_query
# aa=sq_dists*angle_similarity
# # print(around)

# jnp.round(aa,decimals=1)
# # angle_similarity

In [5]:
def soft_less(max_cube,curr,to_pow=90):
    """
    soft version of less operator if max cube is bigger than curr we will get number close to 1
    if curr is bigger than max_cube we will gt close to 0
    to_pow - the bigger the more exact but less numerically stable solution we get
    """
    diff_a =  to_pow**(max_cube-curr)# if it is positive we are inside if negative outside
    diff_b =  to_pow**(curr-max_cube)# if it is positive we are outside if negative inside
    # print(f"max_cube {max_cube}  curr {curr} diff_a {diff_a},diff_b {diff_b} ")
    return (diff_a/(diff_a+diff_b))

In [6]:
#should get big
max_cube=1.2
curr=0.1
bigg=soft_less(max_cube,curr)
#should get small
max_cube=1.2
curr=100
small=soft_less(max_cube,curr,to_pow=90)

print(f"bigg {bigg} small {small}")


bigg 0.9999498068236184 small 0.0


In [7]:

@jax.jit
def get_value_for_point(param_s_vox, value_param,query_point,to_pow,eps ):
    """
    given a point coordinates will return the value on the basis of current supervoxel data that is stored in param_s_vox
    param_s_vox - contain supervoxel parameters
        0-3) supervoxel center
        4-end verticies points (3 float each)
    value_param - parameters required to compute value for points inside the shape
    query_point - point that we are analizing wheather is in the given supervoxel or not
    """
    super_voxel_center=param_s_vox[0,:]
    #coordinates of verticies
    around= param_s_vox[1:,:]
    # around= einops.rearrange(around,'(v c)->v c',c=3)
    #normalizing so we are acting as if the center of supervoxel is the center of coordinate system
    query_point_normalized=query_point-super_voxel_center
    #getting the distance from the center of the shape to ith border in the direction of query point
    corrected_dist=get_corrected_dist(query_point_normalized,around,eps,pow=pow)
    # distance from center of the shepe to query point
    dist_query= jnp.sqrt(jnp.sum(query_point_normalized**2))
    #closer to 1 if query point is in shape closer to 0 if outside
    soft_is_in_shape=soft_less(corrected_dist,dist_query,to_pow=to_pow)
    return value_param*soft_is_in_shape

def get_spaced_super_voxels(r,array_shape):
    
    super_voxel_center= jnp.array([10.0,10.0,10.0 ])
    


In [41]:
shapee= (20,20,20)
r=4.5
xx,yy,zz=list(map(lambda index :jnp.arange(r,shapee[index]-r,r),[0,1,2]))
verticies=get_initial_points_around(r)
centers=list(map(lambda x:list(map( lambda y: list(map(lambda z: jnp.array([x,y,z])    
                    ,zz))
                ,yy))            
            ,xx ))
centers= list(itertools.chain(*centers))
centers= list(itertools.chain(*centers))
centers=jnp.array(centers)
centers.shape #27,3
centers=einops.rearrange(centers,'v c-> v 1 c')
verticies_multi=einops.repeat(verticies,'p c->v p c',v=len(centers) )
shape_param=jnp.concatenate([centers,verticies_multi], axis=1)
shape_param[26,:,:]

Array([[13.5, 13.5, 13.5],
       [ 0. ,  4.5,  4.5],
       [-4.5,  0. ,  4.5],
       [-4.5,  4.5,  0. ],
       [ 4.5,  4.5,  4.5],
       [ 4.5,  0. ,  4.5],
       [ 0. ,  4.5,  0. ],
       [-4.5, -4.5,  4.5],
       [ 0. , -4.5,  4.5],
       [ 4.5, -4.5,  4.5],
       [-4.5,  0. ,  0. ],
       [-4.5,  4.5, -4.5],
       [ 4.5,  4.5,  0. ],
       [ 4.5,  0. ,  0. ],
       [-4.5, -4.5,  0. ],
       [ 0. ,  4.5, -4.5],
       [ 0. , -4.5,  0. ],
       [ 4.5, -4.5,  0. ],
       [-4.5,  0. , -4.5],
       [ 4.5,  4.5, -4.5],
       [-4.5,  4.5,  4.5],
       [ 4.5,  0. , -4.5],
       [-4.5, -4.5, -4.5],
       [ 0. , -4.5, -4.5],
       [ 4.5, -4.5, -4.5]], dtype=float32)

In [9]:
eps=1e-20
pow=3
to_pow=90

super_voxel_center= jnp.array([10.0,10.0,10.0 ])
r=5.0
around=get_initial_points_around(r)

param_s_vox=jnp.concatenate([jnp.expand_dims(super_voxel_center, axis=0),around])



# query_point=jnp.array([6.5,6.5,6.5])
# query_point=jnp.array([180.0,180.0,180.0])
value_param=jnp.array([7.0])

image = jnp.zeros((100,100,100))
indicies= einops.rearrange(jnp.indices(image.shape),'c x y z->(x y z) c')

gvfp_points= jax.vmap(get_value_for_point,(None,None,0,None,None))
# get_value_for_point(param_s_vox, value_param,query_point,to_pow,eps )




# super_voxel_center

In [10]:
image_res=gvfp_points(param_s_vox, value_param,indicies,to_pow,eps )
image_res= jnp.reshape(image_res,image.shape )
image_res = sitk.GetImageFromArray(image_res)
sitk.WriteImage(image_res,"/workspaces/Jax_cuda_med/data/explore/cube.nii.gz")

In [11]:
aa=jnp.indices((100,100,100))
aa= einops.rearrange(aa,'c x y z->(x y z) c')
aa[110]

Array([ 0,  1, 10], dtype=int32)

In [12]:
# time to save it in a form readable by sitk