In [2]:
import numpy as np
import plotly.graph_objects as go

def create_sphere_plot(directions, threshold_angle=45):
    """
    Create an interactive 3D plot of directions over the unit upper hemisphere.

    Traces are added in this order:
    - a translucent upper-hemisphere surface (z >= 0),
    - one red line segment from the origin to each direction (plain lines;
      no arrowheads are drawn),
    - one dashed blue segment between the tips of every pair of directions
      whose angle is strictly below ``threshold_angle``.

    Args:
        directions: array-like of shape (N, 3). Vectors need not be unit
            length: they are normalized for the angle test only, while the
            segments are drawn at the supplied coordinates. A zero vector is
            left as-is, so its angle to every other vector is taken as 90 degrees.
        threshold_angle: angle in degrees below which directions are connected.

    Returns:
        plotly.graph_objects.Figure with the traces described above.
    """
    dirs = np.asarray(directions, dtype=float)
    if dirs.size == 0:
        # Give an empty input a well-defined (0, 3) shape for the vector math below.
        dirs = dirs.reshape(0, 3)

    # Hemisphere mesh: polar angle 0..pi/2 (so z = cos(polar) >= 0), azimuth 0..2*pi.
    polar, azimuth = np.meshgrid(
        np.linspace(0, np.pi / 2, 20),
        np.linspace(0, 2 * np.pi, 40),
    )

    fig = go.Figure()

    fig.add_trace(go.Surface(
        x=np.sin(polar) * np.cos(azimuth),
        y=np.sin(polar) * np.sin(azimuth),
        z=np.cos(polar),
        opacity=0.3,
        showscale=False
    ))

    # One segment per direction, from the origin to the vector's tip.
    for vec in dirs:
        fig.add_trace(go.Scatter3d(
            x=[0, vec[0]], y=[0, vec[1]], z=[0, vec[2]],
            mode='lines',
            line=dict(color='red', width=4),
            showlegend=False
        ))

    # Normalize only for the angle test so non-unit inputs are compared by
    # direction; zero-norm rows are divided by 1 to avoid NaNs.
    norms = np.linalg.norm(dirs, axis=1, keepdims=True)
    unit = dirs / np.where(norms == 0, 1.0, norms)
    # Clip guards arccos against dot products drifting just outside [-1, 1].
    angles_deg = np.degrees(np.arccos(np.clip(unit @ unit.T, -1.0, 1.0)))

    # triu_indices(k=1) yields pairs (i, j) with i < j in row-major order,
    # i.e. the same order as a nested i/j loop.
    for i, j in zip(*np.triu_indices(len(dirs), k=1)):
        if angles_deg[i, j] < threshold_angle:
            fig.add_trace(go.Scatter3d(
                x=[dirs[i][0], dirs[j][0]],
                y=[dirs[i][1], dirs[j][1]],
                z=[dirs[i][2], dirs[j][2]],
                mode='lines',
                line=dict(color='blue', width=1, dash='dash'),
                opacity=0.5,
                showlegend=False
            ))

    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            aspectmode='data'
        ),
        showlegend=False,
        title='Gradient Directions and Connections'
    )

    return fig

# Example usage with synthetic data -- swap in your real directions here.
# Convention in this cell: theta is the polar angle in [0, pi/2] (upper
# hemisphere) and phi is the azimuth in [0, 2*pi]. Drawing the polar angle
# uniformly over-samples directions near the pole, which is fine for a demo.
np.random.seed(42)
n_directions = 10
theta = np.random.uniform(low=0, high=np.pi/2, size=n_directions)
phi = np.random.uniform(low=0, high=2*np.pi, size=n_directions)

directions = np.array([
    (np.sin(polar) * np.cos(azim), np.sin(polar) * np.sin(azim), np.cos(polar))
    for polar, azim in zip(theta, phi)
])

# Build the figure and render it.
fig = create_sphere_plot(directions, threshold_angle=45)
fig.show()

In [3]:
import numpy as np
import plotly.graph_objects as go

def create_sphere_plot(directions, threshold_angle=45):
    """
    Create an interactive 3D plot of directions over the unit upper hemisphere.

    Traces are added in this order:
    - a translucent upper-hemisphere surface (z >= 0),
    - one red line segment from the origin to each direction (plain lines;
      no arrowheads are drawn),
    - one dashed blue segment between the tips of every pair of directions
      whose angle is strictly below ``threshold_angle``.

    The layout uses an 800x800 canvas and a camera at eye (1.5, 1.5, 1.5) with
    +z up. Axis ranges cover the unit hemisphere plus a 0.2 margin, i.e. x/y in
    [-1.2, 1.2] and z in [-0.2, 1.2]. They are widened only when a direction
    reaches beyond that box, so no segment is clipped.

    Args:
        directions: array-like of shape (N, 3). Vectors need not be unit
            length: they are normalized for the angle test only, while the
            segments are drawn at the supplied coordinates. A zero vector is
            left as-is, so its angle to every other vector is taken as 90 degrees.
        threshold_angle: angle in degrees below which directions are connected.

    Returns:
        plotly.graph_objects.Figure with the traces and layout described above.
    """
    dirs = np.asarray(directions, dtype=float)
    if dirs.size == 0:
        # Give an empty input a well-defined (0, 3) shape for the vector math below.
        dirs = dirs.reshape(0, 3)

    # Hemisphere mesh: polar angle 0..pi/2 (so z = cos(polar) >= 0), azimuth 0..2*pi.
    polar, azimuth = np.meshgrid(
        np.linspace(0, np.pi / 2, 20),
        np.linspace(0, 2 * np.pi, 40),
    )

    # Figure size is set in update_layout below, not here.
    fig = go.Figure()

    fig.add_trace(go.Surface(
        x=np.sin(polar) * np.cos(azimuth),
        y=np.sin(polar) * np.sin(azimuth),
        z=np.cos(polar),
        opacity=0.3,
        showscale=False
    ))

    # One segment per direction, from the origin to the vector's tip.
    for vec in dirs:
        fig.add_trace(go.Scatter3d(
            x=[0, vec[0]], y=[0, vec[1]], z=[0, vec[2]],
            mode='lines',
            line=dict(color='red', width=4),
            showlegend=False
        ))

    # Normalize only for the angle test so non-unit inputs are compared by
    # direction; zero-norm rows are divided by 1 to avoid NaNs.
    norms = np.linalg.norm(dirs, axis=1, keepdims=True)
    unit = dirs / np.where(norms == 0, 1.0, norms)
    # Clip guards arccos against dot products drifting just outside [-1, 1].
    angles_deg = np.degrees(np.arccos(np.clip(unit @ unit.T, -1.0, 1.0)))

    # triu_indices(k=1) yields pairs (i, j) with i < j in row-major order,
    # i.e. the same order as a nested i/j loop.
    for i, j in zip(*np.triu_indices(len(dirs), k=1)):
        if angles_deg[i, j] < threshold_angle:
            fig.add_trace(go.Scatter3d(
                x=[dirs[i][0], dirs[j][0]],
                y=[dirs[i][1], dirs[j][1]],
                z=[dirs[i][2], dirs[j][2]],
                mode='lines',
                line=dict(color='blue', width=1, dash='dash'),
                opacity=0.5,
                showlegend=False
            ))

    # Axis extents: unit hemisphere + 0.2 margin, widened if any direction
    # reaches further (or dips below z = 0) so nothing is cut off.
    reach = 1.2
    z_low = -0.2
    if len(dirs):
        reach = max(reach, float(np.abs(dirs).max()) + 0.2)
        z_low = min(z_low, float(dirs[:, 2].min()) - 0.2)

    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            aspectmode='data',
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.5),
                up=dict(x=0, y=0, z=1)  # keep +z pointing up on screen
            ),
            xaxis=dict(range=[-reach, reach]),
            yaxis=dict(range=[-reach, reach]),
            zaxis=dict(range=[z_low, reach])
        ),
        showlegend=False,
        title='Gradient Directions and Connections',
        width=800,
        height=800,
        margin=dict(l=0, r=0, b=0, t=30)  # small top margin leaves room for the title
    )

    return fig

# Example usage with synthetic data -- swap in your real directions here.
# Convention in this cell: theta is the polar angle in [0, pi/2] (upper
# hemisphere) and phi is the azimuth in [0, 2*pi]. Drawing the polar angle
# uniformly over-samples directions near the pole, which is fine for a demo.
np.random.seed(42)
n_directions = 10
theta = np.random.uniform(low=0, high=np.pi/2, size=n_directions)
phi = np.random.uniform(low=0, high=2*np.pi, size=n_directions)

directions = np.array([
    (np.sin(polar) * np.cos(azim), np.sin(polar) * np.sin(azim), np.cos(polar))
    for polar, azim in zip(theta, phi)
])

# Build the figure and render it.
fig = create_sphere_plot(directions, threshold_angle=45)
fig.show()