In [None]:
import raster_geometry as rg
import numpy as np
from scipy.ndimage import center_of_mass, shift, binary_fill_holes

def within_radius(A_shape, interpolated_com, interpolated_radius):
    # Generate coordinate matrices
    x = np.arange(A_shape[0])
    y = np.arange(A_shape[1])
    z = np.arange(A_shape[2])
    X, Y, Z = np.meshgrid(x, y, z)
    
    # Compute distance matrix
    distance_matrix = np.sqrt((X - interpolated_com[0])**2 + (Y - interpolated_com[1])**2 + (Z - interpolated_com[2])**2)
    
    # Return binary matrix indicating whether each location is within the interpolated radius
    return (distance_matrix <= interpolated_radius).astype(int)

def compute_radius(array, com):
    # Find all non-zero (i.e. value 1) coordinates
    x, y, z = np.where(array)
    max_distance = 0
    # Compute the distance of each point from the center of mass
    for i in range(len(y)):
        distance = np.sqrt((x[i] - com[0])**2 + (y[i] - com[1])**2 + (z[i] - com[2])**2)
        max_distance = max(max_distance, distance)
    return max_distance

def interpolate_sparse_shapes(shape1, shape2, num_frames):
    """
    Interpolate between two sparse, binary 2D numpy arrays over a number of frames.
    """
    
    assert shape1.shape == shape2.shape, "Shapes should have the same dimensions"
    
    # Find centroids
    centroid1 = np.array(center_of_mass(shape1))
    centroid2 = np.array(center_of_mass(shape2))
    
    radius1 = compute_radius(shape1, centroid1)
    radius2 = compute_radius(shape2, centroid2)
    # Vector connecting centroids from 1 to 2
    vector = centroid2 - centroid1

    print(centroid1)
    frames = [shape1]  # Start with the initial shape
    for i in range(1, num_frames+1):
        alpha = i / (num_frames+1)  # Adjusted alpha to make the step proportional
        
        # Shift shapes without wrapping using mode='nearest'
        shifted_shape1 = shift(shape1*(1-alpha), alpha * vector, mode='constant', order=1)
        shifted_shape2 = shift(shape2*alpha, -(1-alpha) * vector, mode='constant', order=1)

        interpolated_radius = alpha * radius2 + (1-alpha) * radius1

        # Combine (OR) the shifted shapes to generate the frame
        frame = shifted_shape1+shifted_shape2

        new_centroid = alpha * centroid2 + (1-alpha) * centroid1
        print(new_centroid)
        mask = within_radius(shape1.shape, new_centroid, interpolated_radius)

        frame = np.heaviside(frame*mask, 0.0)

        #frames.append()
        frames.append(binary_fill_holes(frame))
    print(centroid2)
    frames.append(shape2)  # End with the final shape

    return frames

A = rg.sphere(100, 6, position = (0.2, 0.2, 0.2))

B = rg.sphere(100, 25, position = (0.2, 0.2, 0.6))

frames = interpolate_sparse_shapes(A, B, 10)

import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt

# Assume you've already run the interpolate_sparse_shapes function and generated frames

def plot_frame(frame_idx, z):
    plt.imshow(frames[frame_idx][z], cmap='gray')
    plt.axis('off')
    plt.show()

# Create an interactive slider
widgets.interact(
    plot_frame, 
    frame_idx=widgets.IntSlider(min=0, max=len(frames)-1, step=1, value=0), 
    z=widgets.IntSlider(min=0, max=len(frames[0])-1, step=1, value=0))



In [None]:
import pandas as pd
import numpy as np

dataset_name = '/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/processed/stn_space_3sigma/merged/flipped/VTAs/250um.npz'

with np.load(dataset_name) as f:
    VTAs = f['arr_0']

df = pd.read_csv('/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/raw/tables/stn_space/merged/flipped/table.csv')

assert len(df) == len(VTAs), "VTAs and table not the same length"

df

In [None]:
import random

# Group by 'patientID' and 'contactID'
grouped = df.groupby(['patient', 'contact'])

# Getting the unique group names/labels from the groupby object
group_names = list(grouped.groups.keys())

# Randomly selecting a group name/label
random_group_name = random.choice(group_names)

# Getting the actual group
random_group = grouped.get_group(random_group_name)

random_row_index = random.randint(0, len(random_group) - 2)

# Getting the corresponding row from the group
random_row = random_group.iloc[random_row_index]
print(random_row)

VTA1 = VTAs[random_row_index]
VTA2 = VTAs[random_row_index+1]

In [None]:
import raster_geometry as rg
import numpy as np
from scipy.ndimage import center_of_mass, shift, binary_fill_holes

def gaussian_3d(x, y, z, amplitude, mu_x, mu_y, mu_z, sigma_x, sigma_y, sigma_z, offset):
    return amplitude * np.exp(
        - ((x - mu_x)**2 / (2 * sigma_x**2)
        + (y - mu_y)**2 / (2 * sigma_y**2)
        + (z - mu_z)**2 / (2 * sigma_z**2))
    ) + offset

def within_radius(A_shape, interpolated_com, interpolated_radius):
    # Generate coordinate matrices
    x = np.arange(A_shape[0])
    y = np.arange(A_shape[1])
    z = np.arange(A_shape[2])
    X, Y, Z = np.meshgrid(x, y, z)
    
    # Compute distance matrix
    distance_matrix = np.sqrt((X - interpolated_com[0])**2 + (Y - interpolated_com[1])**2 + (Z - interpolated_com[2])**2)
    
    # Return binary matrix indicating whether each location is within the interpolated radius
    return (distance_matrix <= interpolated_radius).astype(int)

def compute_radius(array, com):
    # Find all non-zero (i.e. value 1) coordinates
    x, y, z = np.where(array)
    max_distance = 0
    # Compute the distance of each point from the center of mass
    for i in range(len(y)):
        distance = np.sqrt((x[i] - com[0])**2 + (y[i] - com[1])**2 + (z[i] - com[2])**2)
        max_distance = max(max_distance, distance)
    return max_distance

def interpolate_sparse_shapes(shape1, shape2, num_frames):
    """
    Interpolate between two sparse, binary 2D numpy arrays over a number of frames.
    """
    
    assert shape1.shape == shape2.shape, "Shapes should have the same dimensions"
    #num_frames = num_frames+2
    #centroids = []

    # Find centroids
    centroid1 = np.array(center_of_mass(shape1))
    centroid2 = np.array(center_of_mass(shape2))
    
    radius1 = compute_radius(shape1, centroid1)
    radius2 = compute_radius(shape2, centroid2)
    #print(radius1, radius2)
    # Vector connecting centroids from 1 to 2
    vector = centroid2 - centroid1

    #centroids.append(centroid1)
    #print(centroid1)
    frames = [shape1]  # Start with the initial shape
    for i in range(1, num_frames+1):
        alpha = i / (num_frames+1)  # Adjusted alpha to make the step proportional
        
        # Shift shapes without wrapping using mode='nearest'
        shifted_shape1 = shift(shape1*(1-alpha), alpha * vector, mode='constant', order=1)
        shifted_shape2 = shift(shape2*alpha, -(1-alpha) * vector, mode='constant', order=1)

        interpolated_radius = np.ceil(alpha * radius2 + (1-alpha) * radius1)
        #print(interpolated_radius)
        # Combine (OR) the shifted shapes to generate the frame
        frame = shifted_shape1+shifted_shape2

        new_centroid = alpha * centroid2 + (1-alpha) * centroid1
        #centroids.append(new_centroid)
        #print(new_centroid)
        mask = within_radius(shape1.shape, new_centroid, interpolated_radius)

        frame = np.heaviside((frame-0.5), 0.0)

        frames.append(frame)
        #frames.append(binary_fill_holes(frame))
    #print(centroid2)
    frames.append(shape2)  # End with the final shape
    #centroids.append(centroid2)
    #frames.pop(1)
    #frames.pop(-2)
    return np.array(frames).astype(np.uint8)#, centroids

frames = interpolate_sparse_shapes(VTA1, VTA2, 4)

for frame in frames:
    print(np.sum(frame))

import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt

# Assume you've already run the interpolate_sparse_shapes function and generated frames

def plot_frame(frame_idx, x):
    plt.imshow(frames[frame_idx][x], cmap='gray')
    #plt.scatter(centroids[frame_idx][2], centroids[frame_idx][1], color='red')  # Note the reversed order for scatter
    plt.axis('off')
    plt.show()

# Create an interactive slider
widgets.interact(
    plot_frame, 
    frame_idx=widgets.IntSlider(min=0, max=len(frames)-1, step=1, value=0), 
    x=widgets.IntSlider(min=0, max=len(frames[0])-1, step=1, value=0))


In [None]:
from scipy.optimize import curve_fit
from scipy.ndimage import center_of_mass


def gaussian_3d(x, y, z, amplitude, mu_x, mu_y, mu_z, sigma_x, sigma_y, sigma_z, offset):
    return amplitude * np.exp(
        - ((x - mu_x)**2 / (2 * sigma_x**2)
        + (y - mu_y)**2 / (2 * sigma_y**2)
        + (z - mu_z)**2 / (2 * sigma_z**2))
    ) + offset

def compute_radius(array, com):
    # Find all non-zero (i.e. value 1) coordinates
    x, y, z = np.where(array)
    max_distance = 0
    # Compute the distance of each point from the center of mass
    for i in range(len(y)):
        distance = np.sqrt((x[i] - com[0])**2 + (y[i] - com[1])**2 + (z[i] - com[2])**2)
        max_distance = max(max_distance, distance)
    return max_distance

binary_volume = VTA1

# Extract the coordinates of the white voxels (1's in the binary volume)
x, y, z= np.where(binary_volume == 1)

# Initial guess for the parameters (amplitude, mu_x, mu_y, mu_z, sigma_x, sigma_y, sigma_z, offset)

centroid = np.array(center_of_mass(VTA1))

print(centroid)

radius = compute_radius(VTA1, centroid)

initial_guess = (1, *centroid, radius, radius, radius, 0)

# Flatten the 3D arrays to 1D and fit
xdata = np.vstack((x, y, z))
popt, _ = curve_fit(lambda xdata, amplitude, mu_x, mu_y, mu_z, sigma_x, sigma_y, sigma_z, offset: 
                    gaussian_3d(xdata[0], xdata[1], xdata[2], amplitude, mu_x, mu_y, mu_z, sigma_x, sigma_y, sigma_z, offset), 
                    xdata, np.ones_like(x), p0=initial_guess, method='lm', maxfev=100000)

amplitude, mu_x, mu_y, mu_z, sigma_x, sigma_y, sigma_z, offset = popt
print("amplitude =", amplitude)
print("mu_x =", mu_x, "mu_y =", mu_y, "mu_z =", mu_z)
print("sigma_x =", sigma_x, "sigma_y =", sigma_y, "sigma_z =", sigma_z)
print("offset =", offset)

import matplotlib.pyplot as plt

size_x = VTA1.shape[0]
size_y = VTA1.shape[1]
size_z = VTA1.shape[2]
X, Y, Z = np.meshgrid(np.linspace(0, size_x, size_x), np.linspace(0, size_y, size_y), np.linspace(0, size_z, size_z))

# Create a Gaussian function using the parameters we found
fitted_gaussian = gaussian_3d(X, Y, Z, amplitude, mu_x, mu_y, mu_z, sigma_x, sigma_y, sigma_z, offset)

# Plot a slice of the original binary volume
def plotting_gaussian(z):
    plt.subplot(1, 2, 1)
    plt.title("Original Binary Slice")
    plt.imshow(binary_volume[:, :, z], cmap='gray')
    plt.scatter(centroid[1], centroid[0], color='red')  # Note the reversed order for scatter

    # Plot a slice of the fitted Gaussian
    plt.subplot(1, 2, 2)
    plt.title("Fitted Gaussian Slice")
    plt.imshow(fitted_gaussian[:, :, z], cmap='hot')
    plt.scatter(centroid[1], centroid[0], color='red')  # Note the reversed order for scatter

    plt.show()

# Create an interactive slider
widgets.interact(
    plotting_gaussian, 
    z=widgets.IntSlider(min=0, max=size_z-1, step=1, value=0))


In [None]:
#iterate through all the pairs of linked VTAs ampA->ampA+0.5
#tween 4 frames
#open corresponding table 
#lin interpolate amplitude and scores

In [6]:
import numpy as np
from scipy.ndimage import center_of_mass, shift
import pandas as pd
import numpy as np
from tqdm import tqdm

dataset_name = '/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/processed/stn_space_3sigma/merged/flipped/VTAs/1000um.npz'

with np.load(dataset_name) as f:
    VTAs = f['arr_0']

df = pd.read_csv('/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/raw/tables/stn_space/merged/flipped/table.csv')

assert len(df) == len(VTAs), "VTAs and table not the same length"

def check_voxel_sum_increase(images):
    """
    Check if the sum of voxel values in each 3D volume in the 
    given array of images is greater than the sum in the previous volume. 
    The array is expected to have shape (n_samples, x_dim, y_dim, z_dim).

    Parameters:
    images (np.ndarray): 4D numpy array containing the 3D image data.

    Returns:
    bool: True if it is NOT the case that the sum of voxel values in 
            each 3D volume is greater than the sum in the previous volume, 
            False otherwise.
    """
    
    # Calculate the sum of voxel values for each 3D volume
    sums = np.sum(images, axis=(1, 2, 3))
    diffs = np.array([sums[i]-sums[i-1] for i in range(1, len(sums))])
    # Compare each sum with the previous one
    for i in range(1, len(sums)):
        if sums[i] < sums[i - 1]:
            print(diffs)
            return True
    return False

def interpolate_sparse_shapes(shape1, shape2, num_frames):
    """
    Interpolate between two sparse, binary 2D numpy arrays over a number of frames.
    """
    
    assert shape1.shape == shape2.shape, "Shapes should have the same dimensions"
    #num_frames = num_frames+2
    #centroids = []

    # Find centroids
    centroid1 = np.array(center_of_mass(shape1))
    centroid2 = np.array(center_of_mass(shape2))
    
    #radius1 = compute_radius(shape1, centroid1)
    #radius2 = compute_radius(shape2, centroid2)
    #print(radius1, radius2)
    # Vector connecting centroids from 1 to 2
    vector = centroid2 - centroid1

    #centroids.append(centroid1)
    #print(centroid1)
    frames = [shape1]  # Start with the initial shape
    for i in range(1, num_frames+1):
        alpha = i / (num_frames+1)  # Adjusted alpha to make the step proportional
        
        # Shift shapes without wrapping using mode='nearest'
        shifted_shape1 = shift(shape1*(1-alpha), alpha * vector, mode='constant', order=1)
        shifted_shape2 = shift(shape2*alpha, -(1-alpha) * vector, mode='constant', order=1)

        #interpolated_radius = np.round(alpha * radius2 + (1-alpha) * radius1)
        #print(interpolated_radius)
        # Combine (OR) the shifted shapes to generate the frame
        frame = shifted_shape1+shifted_shape2

        #new_centroid = alpha * centroid2 + (1-alpha) * centroid1
        #centroids.append(new_centroid)
        #print(new_centroid)
        #mask = within_radius(shape1.shape, new_centroid, interpolated_radius)

        frame = np.heaviside((frame-0.5), 0.0)

        frames.append(frame)
        #frames.append(binary_fill_holes(frame))
    #print(centroid2)
    frames.append(shape2)  # End with the final shape
    #centroids.append(centroid2)
    #frames.pop(1)
    #frames.pop(-2)
    return np.array(frames).astype(np.uint8)#, centroids

newVTAs = []

# Group by 'patientID' and 'contactID'
grouped = df.groupby(['patient', 'contact'])

new_rows = []

# Iterating through the groups
for name, group in tqdm(grouped):
    # Iterating through the rows of the group pairwise
    for i in range(len(group) - 1):
        index1 = group.index[i]
        index2 = group.index[i + 1]

        # Copy the metadata from the original df
        base_row = df.loc[index1].copy()
        sup_row = df.loc[index2].copy()

        #print(f'BASE : {base_row["lin_interp_score"]}  ', end='')

        #print(f"Group {name}: Indices -> {index1}, {index2}")
        VTA1 = VTAs[index1]
        VTA2 = VTAs[index2]

        frames = interpolate_sparse_shapes(VTA1, VTA2, 4)

        #if check_voxel_sum_increase(frames):
        #    print(f'{name} {i*0.5}')
        if i == 0:
            base_row['tweening'] = False
            new_rows.append(base_row)
            
        for j in range(1, 5):  # Assuming 4 interpolated volumes
            new_row = base_row.copy()

            new_amplitude = base_row['amplitude'] + 0.1 * j
            new_row['amplitude'] = new_amplitude

            lin_interp_score = (j * (sup_row['lin_interp_score'] - base_row['lin_interp_score'])/5.0) + base_row['lin_interp_score']
            #print(f'{lin_interp_score} ', end='')
            new_row['lin_interp_score'] = lin_interp_score

            new_row['tweening'] = True

            new_row['mapping'] = 0

            new_row['mapping_score'] = np.nan

            if base_row['part'] == 1 and sup_row['part'] == 0:
                new_row['part'] = 0

            new_rows.append(new_row)

        #print(f'SUP : {sup_row["lin_interp_score"]}')   
        # Append the row corresponding to index2 (the "next" volume in the sequence)
        last_row = df.loc[index2].copy()
        last_row['tweening'] = False
        new_rows.append(last_row)

        newVTAs.extend(frames[:-1])
    newVTAs.append(VTA2)


# verified : all voxels are [0, 1] and integers (binary image)
# so conversion to unsigned int8 is legit
newVTAs = np.array(newVTAs).astype(np.uint8)

# Create a new DataFrame from the list of new rows
new_df = pd.DataFrame(new_rows)

# Reset the index of the new DataFrame
new_df.reset_index(drop=True, inplace=True)

# Validate the length
assert len(new_df) == len(newVTAs), "newVTAs and new table not the same length"

#10m10

  results = [sum(input * grids[dir].astype(float), labels, index) / normalizer
  results = [sum(input * grids[dir].astype(float), labels, index) / normalizer
  results = [sum(input * grids[dir].astype(float), labels, index) / normalizer
  results = [sum(input * grids[dir].astype(float), labels, index) / normalizer
  results = [sum(input * grids[dir].astype(float), labels, index) / normalizer
  results = [sum(input * grids[dir].astype(float), labels, index) / normalizer
  results = [sum(input * grids[dir].astype(float), labels, index) / normalizer
  results = [sum(input * grids[dir].astype(float), labels, index) / normalizer
  results = [sum(input * grids[dir].astype(float), labels, index) / normalizer
  results = [sum(input * grids[dir].astype(float), labels, index) / normalizer
  results = [sum(input * grids[dir].astype(float), labels, index) / normalizer
  results = [sum(input * grids[dir].astype(float), labels, index) / normalizer
  results = [sum(input * grids[dir].astype(float), l

In [7]:
print(len(newVTAs), len(new_df))

39687 39687


In [8]:
np.savez_compressed(
    '/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/processed/stn_space_3sigma_subtracted_tweened/merged/flipped/VTAs/1000um.npz', 
    newVTAs)


In [9]:
new_df.to_csv('/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/data/processed/stn_space_3sigma_subtracted_tweened/merged/flipped/VTAs/1000um_table.csv', index=False)

In [23]:
new_df.tail(50)

Unnamed: 0,centerID,leadModel,patientID,contactID,verciseID,amplitude,massive_filename,mapping,mapping_score,part,lin_interp_score,step_interp_score,zeroed,tweening
39637,Cologne,Boston Scientific Vercise,240.0,2.0,2,4.7,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,0.9,0.666667,0.0,True
39638,Cologne,Boston Scientific Vercise,240.0,2.0,2,4.8,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,0.933333,0.666667,0.0,True
39639,Cologne,Boston Scientific Vercise,240.0,2.0,2,4.9,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,0.966667,0.666667,0.0,True
39640,Cologne,Boston Scientific Vercise,240.0,2.0,2,5.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,1.0,0.666667,0.0,False
39641,Cologne,Boston Scientific Vercise,240.0,3.0,3,0.5,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,1.0,0.083333,0.0,0.0,False
39642,Cologne,Boston Scientific Vercise,240.0,3.0,3,0.6,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,1.0,0.1,0.0,0.0,True
39643,Cologne,Boston Scientific Vercise,240.0,3.0,3,0.7,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,1.0,0.116667,0.0,0.0,True
39644,Cologne,Boston Scientific Vercise,240.0,3.0,3,0.8,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,1.0,0.133333,0.0,0.0,True
39645,Cologne,Boston Scientific Vercise,240.0,3.0,3,0.9,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,1.0,0.15,0.0,0.0,True
39646,Cologne,Boston Scientific Vercise,240.0,3.0,3,1.0,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,1.0,0.166667,0.0,0.0,False


In [None]:
df.head()

In [None]:
newVTAs = np.array(newVTAs)

len(newVTAs)

In [None]:



import ipywidgets as widgets
from IPython.display import display
import matplotlib.pyplot as plt

# Assume you've already run the interpolate_sparse_shapes function and generated frames

def plot_frame(frame_idx, x):
    plt.imshow(frames[frame_idx][x], cmap='gray')
    #plt.scatter(centroids[frame_idx][2], centroids[frame_idx][1], color='red')  # Note the reversed order for scatter
    plt.axis('off')
    plt.show()

# Create an interactive slider
widgets.interact(
    plot_frame, 
    frame_idx=widgets.IntSlider(min=0, max=len(frames)-1, step=1, value=0), 
    x=widgets.IntSlider(min=0, max=len(frames[0])-1, step=1, value=0))
