In [26]:
%reload_ext autoreload
%autoreload 2

In [29]:
import sys
sys.path.append('/media/brainstimmaps/DATA/2009_DeepMaps01/04_Source/01_Development/deepmaps/src/datasets')

from proj_dataset import ProjectionDataset

dataset = ProjectionDataset(
        input_folder_vtas=None, 
        input_folder_table=None,
        space='stn_space_3sigma',
        center='merged',
        hemisphere='flipped',
        resolution=500,
        augmentation_type='augmented_full',
        interpolation='step_interp',
        label_threshold=0.6589893416482424,
        normalize_projections=False,
        noise_factor=1, #20
        tuning=False,
        tweening=True,
        n_split=1)

df = dataset.get_df()

39672
Surface noise augmentation x 1.00 factor
center : Bern, len df_center : 20871, len VTAs : 24195


100%|██████████| 3324/3324 [00:05<00:00, 634.29it/s]


center : Cologne, len df_center : 3324, len VTAs : 24195


100%|██████████| 20871/20871 [00:31<00:00, 664.17it/s]


In [22]:
df['patient'].unique()

array([  1,   3,   6,   8,   9,  11,  12,  13,  14,  15,  16,  17,  18,
        19,  20,  21,  22,  23,  25,  26,  27,  28,  29,  31,  34,  35,
        36,  37,  38,  39,  40,  41,  42,  82, 119, 134, 138, 149, 175,
       176, 177, 178, 179, 180, 181, 183, 201, 202, 203, 204, 205, 206,
       207, 208, 211, 214, 215, 216, 217, 218, 219, 220, 222, 223, 224,
       225, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240])

In [30]:
df_f = df[(df['noisy'] == False) & (df['patient'] == 232)]
df_f[['contact', 'amplitude', 'total_voxels']].sort_values(by=['contact', 'amplitude']).head(30)

Unnamed: 0,contact,amplitude,total_voxels
23058,1,0.5,18.0
23059,1,0.6,18.0
23060,1,0.9,49.0
23061,1,1.0,49.0
23062,1,1.1,49.0
23063,1,1.5,101.0
23064,1,1.7,101.0
23065,1,2.0,173.0
23066,1,2.1,173.0
23067,1,2.2,175.0


In [4]:
import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

#df_loco = df_loco[df_loco['noisy'] == False]

# Grouping the dataframe
grouped_df = df.groupby(['patientID', 'contactID'])

# Extracting the unique group indices
group_indices = grouped_df.groups.keys()
group_indices = sorted(group_indices)

# Extract unique patient IDs and contact IDs for each patient
unique_patient_ids = sorted(df['patientID'].unique())
patient_to_contacts = df.groupby('patientID')['contactID'].unique().to_dict()

# Create widgets
patient_id_widget = widgets.Dropdown(
    options=unique_patient_ids,
    description='Patient ID:'
)

contact_id_widget = widgets.Dropdown(
    options=patient_to_contacts.get(unique_patient_ids[0], []),
    description='Contact ID:'
)

score_type_widget = widgets.Dropdown(
    options=['lin_interp_score', 'step_interp_score'],
    value='lin_interp_score',
    description='Score Type:'
)

# Update function to update contacts based on selected patient
def update_contact_id_widget(*args):
    selected_patient = patient_id_widget.value
    contact_id_widget.options = patient_to_contacts.get(selected_patient, [])

# Attach update function to patient ID widget
patient_id_widget.observe(update_contact_id_widget, 'value')


def plot_group(patient_id, contact_id, score_type):
    # Getting the corresponding group
    group = (patient_id, contact_id)
    data = grouped_df.get_group(group)

    # Extracting the selected score column
    score_column = data[score_type]

    # Setting colors for mapping_score and the selected score type
    mapping_color = 'red'
    score_color = 'blue' if score_type == 'lin_interp_score' else 'green'

    # Sorting the dataframe by amplitude
    data = data.sort_values(by='amplitude')

    a = data['mapping'] != 1
    b = data['noisy'] == True
    # Plotting the data
    plt.figure(figsize=(10, 6))

    # Plotting all points with a semi-transparent line connecting them
    #plt.plot(data['amplitude'], score_column, color='black', linestyle='-', alpha=0.3, zorder=1)

    # Plotting mapping_score points on top in red
    plt.scatter(data[~a]['amplitude'], score_column[~a],
                color=mapping_color, marker='o', label='Mapping', zorder=20)
    # Plotting selected score type points in blue or green
    plt.scatter(data[a & ~b]['amplitude'], score_column[a & ~b],color=score_color, marker='o', label=f'{score_type.capitalize()} Interpolation', zorder=2)
    
    plt.scatter(data[a & b]['amplitude'], score_column[a & b],color='darkgreen', marker='o', label=f'{score_type.capitalize()} Interpolation noisy', zorder=2)

    # Plotting a line from the origin to the first point
    first_point_x = data['amplitude'].iloc[0]
    first_point_y = score_column.iloc[0]
    plt.plot([0, first_point_x], [0, first_point_y], color='black', linestyle='--', alpha=0.5, zorder=1)

    plt.xlabel('Amplitude')
    plt.ylabel(score_type.capitalize())
    plt.title(f'Amplitude vs {score_type.capitalize()} (Group: {group})')
    plt.ylim(-0.1, 1.1)  # Set the y-axis limits between -0.1 and 1.1

    # Adjusting the x-axis range
    x_min = max(0, data['amplitude'].min() - 1)
    x_max = data['amplitude'].max() + 1
    plt.xlim(0, 8.5)

    plt.grid(True, zorder=0)
    plt.legend()
    plt.show()

# Create interactive widget
interactive_plot = widgets.interactive(plot_group, patient_id=patient_id_widget, contact_id=contact_id_widget, score_type=score_type_widget)

# Display widgets
display(patient_id_widget, contact_id_widget, score_type_widget, interactive_plot)


Dropdown(description='Patient ID:', options=(1.0, 3.0, 6.0, 8.0, 9.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0…

Dropdown(description='Contact ID:', options=(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.…

Dropdown(description='Score Type:', options=('lin_interp_score', 'step_interp_score'), value='lin_interp_score…

interactive(children=(Dropdown(description='Patient ID:', options=(1.0, 3.0, 6.0, 8.0, 9.0, 11.0, 12.0, 13.0, …

In [5]:
a = df['patientID'] == 237
b = df['contactID'] == 4
c = df['amplitude'] > 4.5
df[a & b & c].sort_values(by=['amplitude'])

Unnamed: 0,centerID,leadModel,patientID,contactID,verciseID,amplitude,massive_filename,mapping,mapping_score,part,lin_interp_score,step_interp_score,zeroed,tweening,tuning,noisy,original_vta,added_voxels,total_voxels
23655,Cologne,Boston Scientific Vercise,237.0,4.0,4,4.7,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,1.0,1.0,0.0,True,False,False,,0,4320.0
66764,Cologne,Boston Scientific Vercise,237.0,4.0,4,4.7,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,1.0,1.0,0.0,True,False,True,23655.0,1,4321.0
66592,Cologne,Boston Scientific Vercise,237.0,4.0,4,4.7,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,1.0,1.0,0.0,True,False,True,23655.0,1,4321.0
61890,Cologne,Boston Scientific Vercise,237.0,4.0,4,4.7,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,1.0,1.0,0.0,True,False,True,23655.0,0,4320.0
61290,Cologne,Boston Scientific Vercise,237.0,4.0,4,4.7,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,1.0,1.0,0.0,True,False,True,23655.0,1,4321.0
61119,Cologne,Boston Scientific Vercise,237.0,4.0,4,4.7,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,1.0,1.0,0.0,True,False,True,23655.0,3,4323.0
60653,Cologne,Boston Scientific Vercise,237.0,4.0,4,4.7,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,1.0,1.0,0.0,True,False,True,23655.0,0,4320.0
54102,Cologne,Boston Scientific Vercise,237.0,4.0,4,4.7,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,1.0,1.0,0.0,True,False,True,23655.0,0,4320.0
51910,Cologne,Boston Scientific Vercise,237.0,4.0,4,4.7,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,1.0,1.0,0.0,True,False,True,23655.0,0,4320.0
47752,Cologne,Boston Scientific Vercise,237.0,4.0,4,4.7,/media/brainstimmaps/DATA/2009_DeepMaps01/03_D...,0.0,,0.0,1.0,1.0,0.0,True,False,True,23655.0,1,4321.0


In [8]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets
from IPython.display import display

# Grouping the dataframe
grouped_df = df.groupby(['patientID', 'contactID'])

# Extracting the unique group indices
group_indices = grouped_df.groups.keys()
group_indices = sorted(group_indices)

# Extract unique patient IDs and contact IDs for each patient
unique_patient_ids = sorted(df['patientID'].unique())
patient_to_contacts = df.groupby('patientID')['contactID'].unique().to_dict()

# Create widgets
patient_id_widget = widgets.Dropdown(
    options=unique_patient_ids,
    description='Patient ID:'
)

contact_id_widget = widgets.Dropdown(
    options=patient_to_contacts.get(unique_patient_ids[0], []),
    description='Contact ID:'
)

score_type_widget = widgets.Dropdown(
    options=['lin_interp_score', 'step_interp_score'],
    value='lin_interp_score',
    description='Score Type:'
)

# Update function to update contacts based on selected patient
def update_contact_id_widget(*args):
    selected_patient = patient_id_widget.value
    contact_id_widget.options = patient_to_contacts.get(selected_patient, [])

# Attach update function to patient ID widget
patient_id_widget.observe(update_contact_id_widget, 'value')


def plot_group(patient_id, contact_id, score_type):
    # Getting the corresponding group
    group = (patient_id, contact_id)
    data = grouped_df.get_group(group).sort_values(by='amplitude')

    # Create a color palette dictionary
    palette = {
        True: 'red',  # mapping
        False: 'blue' if score_type == 'lin_interp_score' else 'green'  # score_type
    }

    plt.figure(figsize=(10, 6))

    sns.scatterplot(data=data, x='amplitude', y=score_type, hue='mapping', palette=palette, s=100, alpha=.3)

    # Handle noisy points specifically
    noisy_data = data[data['noisy'] == True]
    plt.scatter(noisy_data['amplitude'], noisy_data[score_type], color='darkgreen', s=100, alpha=.3)

    plt.xlabel('Amplitude')
    plt.ylabel(score_type.capitalize())
    plt.title(f'Amplitude vs {score_type.capitalize()} (Group: {group})')

    plt.ylim(-0.1, 1.1)
    plt.xlim(0, 8.5)
    
    plt.grid(True)
    plt.legend(title='Mapping')
    plt.show()

# Create interactive widget
interactive_plot = widgets.interactive(plot_group, patient_id=patient_id_widget, contact_id=contact_id_widget, score_type=score_type_widget)

# Display widgets
display(patient_id_widget, contact_id_widget, score_type_widget, interactive_plot)


Dropdown(description='Patient ID:', options=(1.0, 3.0, 6.0, 8.0, 9.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0…

Dropdown(description='Contact ID:', options=(0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.…

Dropdown(description='Score Type:', options=('lin_interp_score', 'step_interp_score'), value='lin_interp_score…

interactive(children=(Dropdown(description='Patient ID:', options=(1.0, 3.0, 6.0, 8.0, 9.0, 11.0, 12.0, 13.0, …