# Atomic Orbital Visualization

This interactive script visualizes atomic orbitals by:

- Computing and displaying isosurfaces of the wavefunction using the marching cubes algorithm.
- Generating contour plots to explore cross-sections of the orbital at different heights.
- Allowing users to dynamically adjust the quantum numbers (n, l, m) and the height of the cross-section plane to explore different atomic orbitals.

**No need to review all the code!** You can scroll down and start playing with the interactive widgets to visualize different orbitals immediately.

In [1]:
from pytessel import PyTessel
import numpy as np
import trimesh
import pythreejs as p3
import ipywidgets as widgets
from ipywidgets import HBox, Output
from IPython.display import display
import matplotlib.pyplot as plt
import re
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

import numpy as np
from math import factorial
from scipy.special import assoc_laguerre
from scipy.special import lpmv
from scipy.special import sph_harm
from scipy.special import sph_harm_y

def wfcart(n, l, m, x, y, z):
    """
    Evaluate the wave function at Cartesian coordinates.

    The points are converted to spherical coordinates and passed to ``wf``.

    n : principal quantum number
    l : azimuthal quantum number
    m : magnetic quantum number
    x, y, z : Cartesian coordinates (scalars or equally-shaped arrays)

    Returns the real-valued wave function evaluated at every point.
    """
    r = np.linalg.norm([x, y, z], axis=0)
    theta = np.arctan2(y, x)  # azimuthal angle
    # z/r is undefined at the origin (r == 0). There the polar angle is
    # arbitrary: for l > 0 the radial factor rho**l vanishes, and for l == 0
    # the angular part is constant, so any value gives the correct psi.
    # Clipping guards against round-off pushing |z/r| just above 1, which
    # would make arccos return NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        cos_polar = np.where(r > 0, z / r, 1.0)
    phi = np.arccos(np.clip(cos_polar, -1.0, 1.0))  # polar angle

    return wf(n, l, m, r, theta, phi)

def wf(n,l,m,r,theta,phi):
    """
    Evaluate the wave function in spherical coordinates as R_nl(r) * Y_lm(theta, phi).

    n : principal quantum number
    l : azimuthal quantum number
    m : magnetic quantum number
    r : radius
    theta : azimuthal angle
    phi : polar angle
    """
    radial_part = radial(n, l, r)
    angular_part = angular(l, m, theta, phi)
    return radial_part * angular_part
    
def angular(l,m,theta,phi):
    """
    Construct the angular part of the wave function
    
    l : azimuthal quantum number
    m : magnetic quantum number
    theta : azimuthal angle
    phi : polar angle
    """
    # see: https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form
    #
    # this create so-called Tesseral spherical harmonics
    #
    if m == 0:
        return np.real(sph_harm_y(l,m,phi,theta))
    elif m < 0:
        return np.real(1j / np.sqrt(2) * (sph_harm_y(l,m,phi,theta) - (-1)**m * sph_harm_y(l,-m,phi,theta)))
    elif m > 0:
        return np.real(1 / np.sqrt(2) * (sph_harm_y(l,-m,phi,theta) + (-1)**m * sph_harm_y(l,m,phi,theta)))

def radial(n,l,r):
    """
    Radial wave function R_nl(r) of the hydrogen atom, in the form given by
    Griffiths, "Introduction to Quantum Mechanics", 3rd edition.

    n : principal quantum number
    l : azimuthal quantum number
    r : radius (scalar or array), in units of the Bohr radius
    """
    n, l = int(n), int(l)
    bohr = 1.0  # Bohr radius (atomic units)
    scaled = 2.0 * r / (n * bohr)
    normalization = np.sqrt((2.0 / (n * bohr))**3) * \
        np.sqrt(factorial(n - l - 1) / (2 * n * factorial(n + l)))
    return normalization * \
        np.exp(-0.5 * scaled) * \
        scaled**l * \
        assoc_laguerre(scaled, n - l - 1, 2 * l + 1)

def generate_isosurfaces(n, l, m, sz, res, name, *, isovalue=0.01, cell_size=6.0):
    """
    Tessellate the +isovalue and -isovalue isosurfaces of an orbital and save them as PLY files.

    n : principal quantum number
    l : azimuthal quantum number
    m : magnetic quantum number
    sz : half-width of the cubic sampling box (the grid spans [-sz, sz])
    res : number of grid points along each axis
    name : file prefix; writes ``<name>_pos.ply`` and ``<name>_neg.ply``
    isovalue : magnitude of the wave-function value at which the surfaces are drawn
    cell_size : edge length of the cubic unit cell handed to PyTessel, i.e. the
        extent of the generated mesh coordinates
    """
    # NOTE(review): np.meshgrid defaults to 'xy' indexing, so the first two
    # array axes are (y, x); confirm this is the ordering PyTessel expects.
    x = np.linspace(-sz, sz, res)
    xx, yy, zz = np.meshgrid(x, x, x)
    field = wfcart(n, l, m, xx, yy, zz)

    # Mesh coordinates are expressed in this cell rather than in a.u.; the
    # default edge of 6 presumably matches the 6x6 plane drawn in the 3D view.
    unitcell = np.diag(np.ones(3) * cell_size)

    flat_field = field.flatten()
    flat_cell = unitcell.flatten()
    pytessel = PyTessel()
    for suffix, level in (("pos", isovalue), ("neg", -isovalue)):
        vertices, normals, indices = pytessel.marching_cubes(flat_field, field.shape, flat_cell, level)
        pytessel.write_ply('%s_%s.ply' % (name, suffix), vertices, normals, indices)

def _lobe_objects(mesh, color, wireframe_material):
    """
    Build the solid surface and its edge overlay for one isosurface mesh.

    mesh : triangle mesh exposing ``vertices`` and ``faces`` (as loaded by trimesh)
    color : CSS color string for the solid surface
    wireframe_material : line material used for the edge overlay

    Returns (solid, edges), both rotated by -pi/2 about the Y axis.
    """
    geometry = p3.BufferGeometry(
        attributes={
            "position": p3.BufferAttribute(mesh.vertices.astype(np.float32), normalized=False),
            "index": p3.BufferAttribute(mesh.faces.astype(np.uint32).ravel(), normalized=False),
        }
    )
    material = p3.MeshStandardMaterial(color=color, side="DoubleSide")
    solid = p3.Mesh(geometry=geometry, material=material)
    edges = p3.LineSegments(p3.EdgesGeometry(geometry), wireframe_material)
    # Rotate both objects identically so the overlay stays aligned with the surface.
    solid.rotateY(-np.pi / 2)
    edges.rotateY(-np.pi / 2)
    return solid, edges


def visualize_mesh(name, yheight=0):
    """
    Build a pythreejs renderer showing the orbital isosurfaces and the slicing plane.

    name : file prefix; reads ``<name>_pos.ply`` and ``<name>_neg.ply``
    yheight : slice height; the plane is placed at ``yheight / 10`` on the scene's Y axis

    Returns a 512x512 ``p3.Renderer`` driven by orbit controls.
    """
    mesh_pos = trimesh.load_mesh("%s_pos.ply" % name)
    mesh_neg = trimesh.load_mesh("%s_neg.ply" % name)

    # Thin black edge lines, shared by both overlays.
    wireframe_material = p3.LineBasicMaterial(color="black", linewidth=1.0)
    mesh_object_pos, wireframe_pos = _lobe_objects(mesh_pos, "#276419", wireframe_material)  # green
    mesh_object_neg, wireframe_neg = _lobe_objects(mesh_neg, "#8e0152", wireframe_material)  # magenta

    # Only ambient light is used, so shading does not depend on surface normals.
    ambient_light = p3.AmbientLight(color="white", intensity=4)

    # Semi-transparent plane marking the cross-section shown in the contour plot.
    plane_geometry = p3.PlaneGeometry(6, 6)  # width, height
    plane_material = p3.MeshStandardMaterial(color="#222222", side="DoubleSide", opacity=0.5, transparent=True)
    plane = p3.Mesh(geometry=plane_geometry, material=plane_material)

    # Raise the plane to the slice height, then rotate it about X so it lies
    # horizontally (PlaneGeometry is created in the XY plane).
    # NOTE(review): the /10 factor converts slider units to scene units; confirm
    # it agrees with the unit cell used when the isosurfaces were tessellated.
    plane.position = [0, yheight / 10, 0]
    plane.rotateX(np.pi / 2)

    scene = p3.Scene(children=[mesh_object_pos,
                               mesh_object_neg,
                               wireframe_pos,
                               wireframe_neg,
                               plane,
                               ambient_light,
                               p3.AxesHelper(size=5),
                               ])
    camera = p3.PerspectiveCamera(position=[5, 5, 5], fov=50)
    controller = p3.OrbitControls(controlling=camera)

    renderer = p3.Renderer(
        scene=scene, camera=camera, controls=[controller],
        width=512, height=512, antialias=True
    )

    return renderer

def generate_contour_plot(n, l, m, sz, res, height, *, vlim=0.01, levels=15):
    """
    Draw a contour map of the orbital on a planar cross-section.

    The wave function is sampled on a res x res grid spanning [-sz, sz] in both
    in-plane directions, with the first Cartesian argument of ``wfcart`` held
    fixed at ``height``.

    n : principal quantum number
    l : azimuthal quantum number
    m : magnetic quantum number
    sz : half-width of the plotted region
    res : number of grid points along each in-plane axis
    height : fixed coordinate of the cutting plane
    vlim : the color scale is pinned to [-vlim, vlim]
    levels : number of contour levels

    Returns (fig, ax).
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    x = np.linspace(-sz, sz, res)
    r1, r2 = np.meshgrid(x, x)
    # NOTE(review): the horizontal axis is wfcart's second (y) argument, yet it is
    # labelled 'x' below; presumably chosen to match the rotated 3D view — confirm.
    psi = wfcart(n, l, m, np.ones_like(r1) * height, r1, r2)

    # Black iso-lines over a diverging fill. The symmetric color limits place
    # psi == 0 at the middle of PiYG, so the two signs get contrasting colors.
    ax.contour(x, x, psi, colors='black', levels=levels, vmin=-vlim, vmax=vlim)
    filled = ax.contourf(x, x, psi, cmap=plt.get_cmap('PiYG'), levels=levels,
                         vmin=-vlim, vmax=vlim)

    ax.set_xlabel('x [a.u.]')
    ax.set_ylabel('z [a.u.]')
    ax.axis("square")

    # Operate on this figure/axes explicitly instead of pyplot's "current" ones.
    ax.grid(linestyle='--', color='black', alpha=0.5)
    ax.set_xlim(-sz, sz)
    ax.set_ylim(-sz, sz)
    fig.colorbar(filled, ax=ax)

    return fig, ax

def parse_orbital(orb):
    """
    Extract the (possibly negative) integers from an orbital label.

    orb : label such as "n=3, l=2, m=-1"

    Returns the integers in order of appearance as a tuple, e.g. (3, 2, -1).
    """
    numbers = re.findall(r"-?\d+", orb)
    return tuple(int(token) for token in numbers)

def update_plot(N, orb):
    """
    Redraw the 3D isosurface view and the 2D cross-section for one widget state.

    N : cross-section height selected on the slider
    orb : dropdown label such as "n=3, l=2, m=0"
    """
    # Scratch file prefix for the PLY meshes; the same files are overwritten
    # for every orbital, whatever the name suggests.
    name = 'dz2'
    n, l, m = parse_orbital(orb)
    fig, _ = generate_contour_plot(n, l, m, 15, 50, N)
    generate_isosurfaces(n, l, m, 15, 50, name)
    renderer = visualize_mesh(name, yheight=N)
    out = Output()
    with out:
        # Render exactly this figure into the Output widget. plt.show() takes no
        # figure argument and would show every open pyplot figure.
        display(fig)
    # Release the figure so pyplot does not keep it alive or render it again.
    plt.close(fig)
    display(HBox([renderer, out]))

# Dropdown labels for every (n, l, m) combination in the first three shells.
orbitals = []
for n in range(1, 4):  # Principal quantum number (1, 2, 3)
    for l in range(n):  # Azimuthal quantum number (0 to n-1)
        for m in range(-l, l + 1):  # Magnetic quantum number (-l to +l)
            orbitals.append(f"n={n}, l={l}, m={m}")

# Create an interactive dropdown menu for orbitals
orbital_dropdown = widgets.Dropdown(
    options=orbitals,
    value="n=3, l=2, m=0",  # Default selection (3d_z2 orbital)
    description="Orbital:"
)

# Create an interactive slider for the height of the cross-section plane
N_slider = widgets.IntSlider(
    min=-30, max=30, step=1, value=0,
    description="Height"
)

# Link the widgets to update_plot. As the last expression of the cell, the
# returned interactive widget is what the notebook displays.
widgets.interactive(update_plot, N=N_slider, orb=orbital_dropdown)

interactive(children=(IntSlider(value=0, description='Height', max=30, min=-30), Dropdown(description='Orbital…