In [1]:
# Set up the optimizer for one subject: create it, then load the stimulator
# description followed by the subject's E-field model.
from optimizer import StimulusOptimizer
# NOTE(review): both paths are absolute and point into one user's home
# directory, so this cell only runs on a machine with the same layout.
# Consider deriving them from a configurable base directory.
efield_model_path = "/home/mikael/Git/E-field_targeting/subjects/sub0/ID18TMSE-m4-sdROI_LMC_230414103614.mat"
stimulator_path = "/home/mikael/Git/E-field_targeting/subjects/sub0/5coil_example.json"

opt = StimulusOptimizer()
# The stimulator is loaded before the E-field model; keep this order unless
# StimulusOptimizer is confirmed to be order-independent.
opt.load_stimulator(stimulator_path)
opt.load_efield_model(efield_model_path)

In [2]:
import numpy as np
# The optimization target is a (position, direction) pair; the plotting
# helper further down unpacks it in exactly that order.
target_position = np.array([-0.030883094, -0.054985330, 0.068835363])
target_direction = np.array([1.0, 0.0, 0.0])
target = (target_position, target_direction)

# Run the optimizer for this target; the result is later passed to the
# visualization as the solution weights.
pop = opt.run(target)

In [None]:
import plotly.graph_objects as go
import jax.numpy as jnp
# Shorthand for the JAX array type, used in the type hints of the helpers below.
jaxArray = jnp.ndarray
from typing import Callable, Dict, Tuple

def e_to_mag(E: jaxArray) -> Tuple[jaxArray, jaxArray, jaxArray]:
    """
    Reduce a vector E-field to per-vertex magnitudes and locate its peak.

    Args:
        E: Combined E-field with one 3-component vector per row,
            shape (n_vertices, 3).

    Returns:
        A 3-tuple ``(magnitudes, peak_value, peak_index)``:
        - magnitudes: Euclidean norm of every row of ``E``, shape (n_vertices,).
        - peak_value: The largest entry of ``magnitudes`` (scalar).
        - peak_index: Index of the vertex at which that largest value occurs.
    """
    # The norm is taken across the component axis, leaving one value per vertex.
    magnitudes = jnp.linalg.norm(E, axis=1)
    peak_index = jnp.argmax(magnitudes)
    # Reading the peak through the argmax keeps value and index consistent.
    return magnitudes, magnitudes[peak_index], peak_index

def stimulated_target(E: jaxArray, vertices: jaxArray) -> Tuple[jaxArray, jaxArray]:
    """
    Report where, and in which direction, the E-field actually stimulates.

    The realized target is taken at the vertex with the largest E-field
    magnitude (the 'Max' metric); no weighted centre of gravity is computed.

    Args:
        E: Combined E-field, shape (n_vertices, 3).
        vertices: Mesh vertices, shape (n_vertices, 3).

    Returns:
        A tuple containing:
        - loc: Coordinates of the peak-magnitude vertex, shape (3,).
        - dir: Unit vector along the E-field at that vertex, shape (3,);
          all zeros when the field there is numerically null.
    """
    # Only the index of the strongest vertex is needed from the helper.
    peak_index = e_to_mag(E)[2]

    peak_field = E[peak_index, :]
    peak_norm = jnp.linalg.norm(peak_field)

    # The quotient is computed unconditionally; jnp.where then substitutes a
    # zero vector whenever the norm is too small for the result to be usable.
    unit_direction = jnp.where(
        peak_norm > 1e-8,
        peak_field / peak_norm,
        jnp.zeros_like(peak_field),
    )
    return vertices[peak_index, :], unit_direction

def _direction_cone(position, direction, sizeref, color, anchor, name):
    """Build one solid-colored Plotly cone showing `direction` at `position`."""
    return go.Cone(
        x=[position[0]], y=[position[1]], z=[position[2]],
        u=[direction[0]], v=[direction[1]], w=[direction[2]],
        sizeref=sizeref,
        showscale=False,
        # A two-stop scale with identical colors renders the cone in one flat color.
        colorscale=[[0, color], [1, color]],
        anchor=anchor,
        name=name,
    )


def visualize_solution_plotly(optimizer: StimulusOptimizer, solution_weights: jnp.ndarray, optimization_target: tuple):
    """
    Creates a 3D interactive plot of the solution using Plotly, including target and result vectors.

    The mesh is colored by the E-field magnitude normalized to its maximum.
    A red cone marks the requested target and a cyan cone marks the realized
    stimulation (position and direction of the peak E-field vertex).

    Args:
        optimizer: Optimizer providing ``stimulator['max_current_slope']``,
            ``efield_set`` (indexed as coil, vertex, component by the einsum
            below) and ``mesh`` with ``'vertices'`` and ``'faces'`` entries.
        solution_weights: One weight per coil; divided by the stimulator's
            maximum current slope before the per-coil fields are combined.
        optimization_target: ``(position, direction)`` pair of 3-vectors.

    Returns:
        None. The figure is displayed through ``fig.show()``.
    """
    # --- 1. Calculate E-field and get mesh data ---
    target_pos, target_dir = optimization_target
    normalized_weights = solution_weights / optimizer.stimulator['max_current_slope']
    # Weighted sum over coils: (c,) x (c, v, d) -> (v, d).
    E_combined = jnp.einsum('c,cvd->vd', normalized_weights, optimizer.efield_set)
    E_mag, E_max, _ = e_to_mag(E_combined)
    # The small epsilon keeps the normalization finite for an all-zero field.
    E_mag_normalized = np.array(E_mag / (E_max + 1e-8))
    vertices = np.array(optimizer.mesh['vertices'])
    faces = np.array(optimizer.mesh['faces'])

    # --- 2. Determine the realized position and direction ---
    realized_pos, realized_dir = stimulated_target(E_combined, optimizer.mesh['vertices'])
    realized_pos, realized_dir = np.array(realized_pos), np.array(realized_dir)

    # --- 3. Create the Plotly Figure ---
    fig = go.Figure()

    # Add the brain mesh trace
    fig.add_trace(
        go.Mesh3d(
            x=vertices[:, 0], y=vertices[:, 1], z=vertices[:, 2],
            i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
            intensity=E_mag_normalized,
            colorscale='Viridis',
            colorbar_title='Normalized E-Field Magnitude',
            name='Brain Mesh',
            showscale=True
        )
    )

    # Set arrow size relative to mesh size (20% of the mean bounding-box extent)
    arrow_scale = np.mean(np.ptp(vertices, axis=0)) * 0.2

    # Target direction: the cone's tip sits on the target position
    fig.add_trace(
        _direction_cone(target_pos, target_dir, arrow_scale,
                        color='red', anchor='tip', name='Target Direction')
    )

    # Realized E-field direction: the cone's tail sits on the realized position
    fig.add_trace(
        _direction_cone(realized_pos, realized_dir, arrow_scale,
                        color='cyan', anchor='tail', name='Realized E-Field Direction')
    )

    # --- 4. Update Layout and Formatting ---
    fig.update_layout(
        title='Optimization Sanity Check: E-Field Magnitude and Direction',
        legend=dict(x=0.7, y=0.9),
        scene=dict(
            xaxis_title='X (m)',
            yaxis_title='Y (m)',
            zaxis_title='Z (m)',
            # Draw the axes in proportion to their data ranges. A fixed 1:1:1
            # aspectratio forces the scene into a cube, which stretches the
            # mesh and skews the apparent direction of the arrows.
            aspectmode='data',
            camera_eye=dict(x=1.2, y=1.2, z=1.2) # Set initial camera angle
        ),
        margin=dict(l=0, r=0, b=0, t=40)
    )

    fig.show()

# Render the sanity-check figure for the optimized solution against the requested target.
visualize_solution_plotly(optimizer=opt, solution_weights=pop, optimization_target=target)