<a href="https://colab.research.google.com/github/general-molecular-simulations/so3lr/blob/main/examples/so3lr_colab_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title Install SO3LR
#@markdown ## Change runtime type to GPU (if available), then click run. Installation takes ~2 minutes.
import os
import sys

_is_so3lr_setup_file = '/content/.SO3LR_SETUP'
_miniconda_installer = 'Miniconda3-latest-Linux-x86_64.sh'

if not os.path.exists(_is_so3lr_setup_file):
  # Install Miniconda with Python 3.12
  os.system(f'wget -q https://repo.anaconda.com/miniconda/{_miniconda_installer}')
  os.system(f'bash {_miniconda_installer} -b -f -p /usr/local')
  os.system('conda install -y -q --prefix /usr/local python=3.12')

  # Install JAX. The CPU-only build is used only if installing the CUDA 12
  # build fails; it is not chosen based on whether a GPU is present.
  os.system('uv pip install --upgrade pip')
  if os.system('uv pip install "jax[cuda12]==0.5.3"') != 0:
    os.system('uv pip install jax==0.5.3')

  # Install dependencies
  os.system('uv pip install ase py3Dmol')

  # Update paths
  sys.path.append('/usr/local/lib/python3.12/site-packages/')
  os.environ['CONDA_PREFIX'] = '/usr/local/'
  os.environ['CONDA_DEFAULT_ENV'] = 'base'

  # Install SO3LR. The setup marker is written only when the install command
  # succeeds, so a failed installation is retried the next time this cell runs
  # instead of being silently skipped.
  os.system('git clone https://github.com/general-molecular-simulations/so3lr.git /content/so3lr')
  if os.system('cd /content/so3lr && uv pip install .') == 0:
    os.system(f"touch {_is_so3lr_setup_file}")
  else:
    print("SO3LR installation failed (non-zero exit code); re-run this cell to retry.")

  # Cleanup: the installer is absent if the download above failed.
  if os.path.exists(_miniconda_installer):
    os.unlink(_miniconda_installer)
else:
  # Just verify imports
  import jax
  import so3lr
  print(f"SO3LR ready. JAX backend: {jax.default_backend()}")



In [4]:
#@title Run SO3LR Simulation
#@markdown ## Choose simulation type and parameters

import os
import sys
import subprocess
import tempfile
import glob
from IPython.display import display, HTML, clear_output, FileLink
from google.colab import files
import ipywidgets as widgets

# Probe for the SO3LR package; the widgets below are disabled when it is missing.
try:
    import so3lr
except ImportError:
    so3lr_available = False
    print("SO3LR not properly installed. Please run the first cell and restart the runtime.")
else:
    so3lr_available = True

# Prompt for an input structure through Colab's upload dialog.
def upload_input_file():
    """Ask the user to upload an input file via ``google.colab.files.upload()``.

    Clears the cell output first. If several files are uploaded, only the
    name of the first one is used.

    Returns:
        str | None: the uploaded file's name, or None when nothing was uploaded.
    """
    clear_output(wait=True)
    print("Uploading input file...")
    uploaded = files.upload()

    if not uploaded:
        print("No file was uploaded.")
        return None

    first_name = next(iter(uploaded))
    print(f"Successfully uploaded: {first_name}")
    return first_name

# Function to validate the input file path before building the command
def check_input_file(filename):
    """Return True if ``filename`` names an existing regular file.

    Prints a message and returns False when no name was given (None or ""),
    when the path does not exist, or when it is a directory rather than a file.

    Args:
        filename (str | None): path entered by the user or returned by the upload.

    Returns:
        bool: True if the path can be used as an input file.
    """
    if not filename:
        # Avoid the confusing "Input file 'None' not found." for a cancelled upload
        print("No input file specified. Enter a file name or use 'Upload & Run'.")
        return False
    # isfile (not exists) so that a directory path is rejected here rather than
    # failing later inside the simulation command.
    if not os.path.isfile(filename):
        print(f"Input file '{filename}' not found.")
        return False
    return True

# Function to run SO3LR command with error handling
def run_so3lr_command(command):
    """Run a command (normally the ``so3lr`` CLI) and stream its output to the cell.

    stdout and stderr are merged and printed line by line as they are produced.

    Args:
        command (list[str]): argv list, e.g. ``["so3lr", "nvt", "--input", ...]``.

    Returns:
        bool: True if the process exited with code 0; False on a non-zero exit
        code or when the process could not be started or read (the error is
        printed).
    """
    print(f"Running command: {' '.join(command)}")
    print("\n" + "-" * 70)
    print("Starting simulation...\n")

    try:
        # The context manager closes the stdout pipe and waits for the child on
        # every exit path, so no file descriptor or zombie process is leaked.
        with subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr so errors show up inline
            text=True,
            bufsize=1,  # line-buffered: show progress as soon as it is printed
        ) as process:
            try:
                for line in process.stdout:
                    print(line, end='')
            except BaseException:
                # Interrupted or failed while streaming: stop the simulation
                # instead of leaving it running in the background.
                process.kill()
                raise
            exit_code = process.wait()
    except Exception as e:
        print(f"\nError running SO3LR: {str(e)}")
        return False

    if exit_code != 0:
        print(f"\nError: SO3LR exited with code {exit_code}")
        return False
    print("\nSimulation completed successfully!")
    return True

# Define the simulation handler
def run_simulation():
    """Build the ``so3lr`` command line from the widget values and run it.

    Reads the module-level widgets defined in this cell. When triggered from
    the "Upload & Run" button (``upload_button_pressed`` is True), the input
    file is uploaded first and its name is written back into the input field.

    Returns:
        bool: True if the command exited with code 0; False if the input file
        is missing or the command failed.
    """
    simulation_type = simulation_selector.value
    input_file = input_file_entry.value.strip()

    # Upload file if requested
    if upload_button_pressed:
        input_file = upload_input_file()
        if input_file:
            input_file_entry.value = input_file

    # Check if input file exists
    if not check_input_file(input_file):
        return False

    # Prepare output file name. Strip whitespace so a stray space in the text
    # box neither ends up in the file name nor disables auto-naming.
    output_file = output_file_entry.value.strip()
    if not output_file:
        # Default output name based on input file, e.g. mol.xyz -> mol_nvt.xyz
        output_file = os.path.splitext(input_file)[0] + f"_{simulation_type}.xyz"

    # Construct the SO3LR command
    command = ["so3lr", simulation_type, "--input", input_file, "--output", output_file]

    # Common parameters are only passed when they differ from the widget
    # defaults (12.0 Å long-range cutoff, total charge 0).
    if lr_cutoff_entry.value != 12.0:
        command.extend(["--lr-cutoff", str(lr_cutoff_entry.value)])

    if total_charge_entry.value != 0:
        command.extend(["--total-charge", str(total_charge_entry.value)])

    # Add type-specific parameters
    if simulation_type == "opt":
        # NOTE(review): any non-zero force threshold (the widget default is
        # 0.05) takes precedence, so --min-cycles/--min-steps are only sent
        # when the force threshold is set to 0.
        if force_conv_entry.value:
            command.extend(["--force-conv", str(force_conv_entry.value)])
        else:
            command.extend(["--min-cycles", str(min_cycles_entry.value)])
            command.extend(["--min-steps", str(min_steps_entry.value)])

    elif simulation_type in ["nvt", "npt"]:
        command.extend(["--temperature", str(temperature_entry.value)])
        command.extend(["--dt", str(dt_entry.value)])
        command.extend(["--md-cycles", str(md_cycles_entry.value)])
        command.extend(["--md-steps", str(md_steps_entry.value)])
        command.extend(["--seed", str(seed_entry.value)])

        if simulation_type == "npt":
            command.extend(["--pressure", str(pressure_entry.value)])

    elif simulation_type == "eval":
        # eval takes a different argument set: the command is rebuilt here, so
        # --output, --lr-cutoff and --total-charge added above are not passed.
        # The input file is treated as the datafile.
        command = ["so3lr", "eval", "--datafile", input_file]
        if output_file:
            command.extend(["--save-to", output_file])

    # Run the command
    success = run_so3lr_command(command)

    # Report the output file if the run succeeded and the file was created
    if success and os.path.exists(output_file):
        print(f"\nOutput file created: {output_file}")

    return success

# Create UI elements.
# 'initial' lets long labels such as 'Force conv (eV/Å):' use their natural
# width instead of being cut off at ipywidgets' default label width.
_label_style = {'description_width': 'initial'}

# Simulation type selector
simulation_selector = widgets.Dropdown(
    options=[
        ('Geometry Optimization', 'opt'),
        ('NVT Molecular Dynamics', 'nvt'),
        ('NPT Molecular Dynamics', 'npt'),
        ('Model Evaluation', 'eval')
    ],
    value='opt',
    description='Simulation:',
    style=_label_style,
    disabled=not so3lr_available
)

# Common parameters
input_file_entry = widgets.Text(description='Input file:', placeholder='molecule.xyz', style=_label_style)
output_file_entry = widgets.Text(description='Output file:', placeholder='Leave blank for auto-naming', style=_label_style)
lr_cutoff_entry = widgets.FloatText(value=12.0, description='LR cutoff (Å):', style=_label_style)
total_charge_entry = widgets.IntText(value=0, description='Total charge:', style=_label_style)
seed_entry = widgets.IntText(value=0, description='Random seed:', style=_label_style, tooltip='Set to 0 for random seed, or specify an integer for reproducible results')

# Optimization parameters
force_conv_entry = widgets.FloatText(value=0.05, description='Force conv (eV/Å):', style=_label_style)
min_cycles_entry = widgets.IntText(value=10, description='Min cycles:', style=_label_style)
min_steps_entry = widgets.IntText(value=10, description='Steps per cycle:', style=_label_style)

# MD parameters
temperature_entry = widgets.FloatText(value=300.0, description='Temperature (K):', style=_label_style)
dt_entry = widgets.FloatText(value=0.5, description='Time step (fs):', style=_label_style)
md_cycles_entry = widgets.IntText(value=100, description='MD cycles:', style=_label_style)
md_steps_entry = widgets.IntText(value=100, description='Steps per cycle:', style=_label_style)
pressure_entry = widgets.FloatText(value=1.0, description='Pressure (atm):', style=_label_style)

# "Upload & Run" sets this flag so that run_simulation() asks for an upload
# before building the command; it must be False for plain "Run Simulation".
upload_button_pressed = False

def on_upload_button_clicked(b):
    """Click handler for "Upload & Run": upload an input file, then run it."""
    global upload_button_pressed
    upload_button_pressed = True
    try:
        run_simulation()
    finally:
        # Reset even if run_simulation() raises; otherwise every later
        # "Run Simulation" click would also prompt for an upload.
        upload_button_pressed = False

upload_button = widgets.Button(description="Upload & Run", button_style='info', disabled=not so3lr_available)
upload_button.on_click(on_upload_button_clicked)

# Run button
run_button = widgets.Button(description="Run Simulation", button_style='success', disabled=not so3lr_available)
run_button.on_click(lambda b: run_simulation())

# Show/hide the type-specific parameter widgets for the chosen simulation.
def update_parameters_visibility(change):
    """Show only the parameter widgets relevant to the selected simulation type.

    Args:
        change (dict-like): traitlets change notification; only
            ``change['new']`` (the selected type: 'opt', 'nvt', 'npt' or
            'eval') is read. Types without an entry ('eval') show none.
    """
    opt_widgets = [force_conv_entry, min_cycles_entry, min_steps_entry]
    md_widgets = [temperature_entry, dt_entry, md_cycles_entry, md_steps_entry]
    visible_by_type = {
        'opt': opt_widgets,
        'nvt': md_widgets,
        'npt': md_widgets + [pressure_entry],
    }

    # Start with every parameter hidden, then reveal the selected group.
    for widget in opt_widgets + md_widgets + [pressure_entry]:
        widget.layout.display = 'none'
    for widget in visible_by_type.get(change['new'], []):
        widget.layout.display = 'flex'

# Re-apply parameter visibility whenever the simulation type changes,
# and once now for the initially selected type.
simulation_selector.observe(update_parameters_visibility, names='value')
update_parameters_visibility({'new': simulation_selector.value})

# Assemble the form: header, simulation type, common options, the
# type-specific parameters (shown/hidden above), then the action buttons.
ui = widgets.VBox(
    [
        widgets.HTML("<h3>SO3LR Simulation Parameters</h3>"),
        simulation_selector,
        widgets.HBox([input_file_entry]),
        widgets.HBox([output_file_entry]),
        widgets.HBox([lr_cutoff_entry, total_charge_entry]),
        widgets.HBox([seed_entry]),
        # optimization parameters
        force_conv_entry,
        min_cycles_entry,
        min_steps_entry,
        # molecular dynamics parameters
        temperature_entry,
        dt_entry,
        md_cycles_entry,
        md_steps_entry,
        pressure_entry,
        widgets.HBox([upload_button, run_button]),
    ]
)

if not so3lr_available:
    print("Please run the first cell before trying to run simulations.")
else:
    display(ui)

Uploading input file...


Saving ala15_folded.xyz to ala15_folded.xyz
Successfully uploaded: ala15_folded.xyz
Running command: so3lr nvt --input ala15_folded.xyz --output ala15_folded_nvt.xyz --lr-cutoff 100.0 --temperature 300.0 --dt 0.5 --md-cycles 200 --md-steps 1000 --seed 0

----------------------------------------------------------------------
Starting simulation...

  """Periodic boundary conditions on a parallelepiped.
  """Computes the pair correlation function at a mesh of distances.
  """Computes the pair correlation function at a mesh of distances.
  """Computes the phop indicator of rearrangements.
  """Helper function to simulate a Nose-Hoover Chain coupled to a system.
  """Code to calculate the elastic modulus tensor for athermal systems.
  """From Miller III et al., S(q) is defined so that \dot q = 1/2S(q)\omega.
  """Convert from the conjugate momentum of a quaternion to angular momentum.

  ███████╗ ██████╗ ██████╗ ██╗     ██████╗ 
  ██╔════╝██╔═══██╗╚════██╗██║     ██╔══██╗
  ███████╗██║   █

In [5]:
#@title Visualize XYZ Trajectory

import os
import glob
import numpy as np
from IPython.display import display, clear_output
import ipywidgets as widgets

# py3Dmol and ASE are optional: when either is missing, the viewer below
# falls back to offering an install button instead of the file UI.
try:
    import py3Dmol
    import ase.io
except ImportError:
    visualization_available = False
    print("Visualization tools not available. Please install required packages:")
    print("!pip install py3Dmol ase")
else:
    visualization_available = True

def display_trajectory(trajectory_file, width=800, height=500, interval=100):
    """
    Display an XYZ trajectory file as an animated py3Dmol view.

    Args:
        trajectory_file (str): Path to the XYZ trajectory file (read with ASE)
        width (int): Width of the viewer
        height (int): Height of the viewer
        interval (int): Delay between animation frames, passed to py3Dmol's
            ``animate``; the default of 100 keeps the previous speed

    Returns:
        The py3Dmol view on success. None if the visualization packages are
        missing, the file does not exist, it contains no frames, or reading or
        rendering fails (the error and traceback are printed).
    """
    if not visualization_available:
        print("Visualization tools not available. Please install ASE and py3Dmol first.")
        return None

    if not os.path.exists(trajectory_file):
        print(f"File not found: {trajectory_file}")
        return None

    try:
        # index=':' reads every frame (ASE's default reads only the last one)
        atoms_list = list(ase.io.read(trajectory_file, index=':'))
        if not atoms_list:
            print("No atoms found in the trajectory file.")
            return None

        n_frames = len(atoms_list)
        print(f"Successfully loaded trajectory with {n_frames} frames")

        # Serialize all frames into one multi-frame XYZ string in a single call
        from io import StringIO
        combined_xyz = StringIO()
        ase.io.write(combined_xyz, atoms_list, format='xyz')
        xyz_data = combined_xyz.getvalue()

        # Create the viewer; multimodel=True makes py3Dmol parse every frame
        view = py3Dmol.view(width=width, height=height)
        view.addModel(xyz_data, 'xyz', {'multimodel': True})

        # Ball-and-stick style for all atoms
        view.setStyle({'sphere': {'radius': 0.4}, 'stick': {'radius': 0.2}})
        view.zoomTo()
        view.setBackgroundColor('white')

        # Animate only when there is more than one frame
        if n_frames > 1:
            view.animate({'loop': 'forward', 'interval': interval})
            print(f"Animation ready with {n_frames} frames")
        else:
            print("Only one frame available - no animation needed")

        view.show()
        return view  # Return the view object for further manipulation

    except Exception as e:
        print(f"Error displaying trajectory: {str(e)}")
        import traceback
        traceback.print_exc()
        return None

# File selection UI components
trajectory_file_selector = widgets.Text(
    description='Trajectory file:',
    placeholder='Enter path to XYZ trajectory file',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='70%')
)

# Add file upload widget for direct uploads
file_upload = widgets.FileUpload(
    accept='.xyz',
    multiple=False,
    description='Upload XYZ:',
    style={'description_width': 'initial'}
)

def _first_uploaded_file(value):
    """Return ``(name, content)`` for the first file in a FileUpload value.

    Supports both value formats: ipywidgets 7 uses a dict
    ``{name: {'metadata': {'name': ...}, 'content': bytes}}``, while
    ipywidgets 8 uses a tuple of dicts with ``'name'`` and ``'content'`` keys.
    """
    if isinstance(value, dict):
        info = next(iter(value.values()))
        return info['metadata']['name'], info['content']
    info = value[0]
    return info['name'], info['content']

def on_file_uploaded(change):
    """Save an uploaded XYZ file to the working directory and select it."""
    if not change.new:
        return

    name, content = _first_uploaded_file(change.new)
    # Keep only the base name so the file is always written to the working
    # directory, whatever name the browser reports.
    filename = os.path.basename(name)

    # Save the uploaded file
    with open(filename, 'wb') as f:
        f.write(content)

    # Update the file selector
    trajectory_file_selector.value = filename
    print(f"Uploaded and saved: {filename}")

file_upload.observe(on_file_uploaded, names='value')

# Browse button for file selection
def on_browse_clicked(b):
    """Replace the viewer UI with a picker of ``*.xyz`` files in the working directory.

    "Use Selected File" copies the choice into the trajectory path field and
    shows the viewer UI again. "Back" shows the viewer UI without changing the
    path. If no XYZ files are found, a message is printed and the viewer UI is
    shown again.
    """
    # Clear the current output before showing the picker
    clear_output(wait=True)

    # Only the current working directory is searched; sort for a stable order.
    xyz_files = sorted(glob.glob("*.xyz"))

    if not xyz_files:
        print("No XYZ files found. Please enter the path manually or upload a file.")
        display(file_ui)
        return

    # Create selection dropdown
    file_dropdown = widgets.Dropdown(
        options=[(os.path.basename(f), f) for f in xyz_files],
        description='Select file:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='70%')
    )

    # Back button: return to the viewer UI without selecting anything
    back_button = widgets.Button(
        description='Back',
        button_style='warning',
        tooltip='Go back to main view'
    )

    def on_back_clicked(b):
        clear_output(wait=True)
        display(file_ui)

    back_button.on_click(on_back_clicked)

    # Select button: copy the chosen file into the path field
    select_button = widgets.Button(
        description='Use Selected File',
        button_style='success',
        tooltip='Use this file'
    )

    def on_select_clicked(b):
        selected_file = file_dropdown.value
        if selected_file:
            trajectory_file_selector.value = selected_file
            clear_output(wait=True)
            print(f"Selected file: {selected_file}")
            display(file_ui)

    select_button.on_click(on_select_clicked)

    # Display the file browser UI
    browser_ui = widgets.VBox([
        widgets.HTML("<h3>Select a trajectory file:</h3>"),
        widgets.HTML(f"<p>Found {len(xyz_files)} XYZ files:</p>"),
        file_dropdown,
        widgets.HBox([back_button, select_button])
    ])

    display(browser_ui)

browse_button = widgets.Button(
    description='Browse Files',
    button_style='info',
    tooltip='Browse for XYZ trajectory files'
)
browse_button.on_click(on_browse_clicked)

# Display button: renders the chosen trajectory, then offers a way back.
display_button = widgets.Button(
    description='Display Trajectory',
    button_style='success',
    tooltip='Display the trajectory'
)

def on_display_clicked(b):
    """Render the trajectory named in the path field and add a back button.

    If the field is empty or the path does not exist, only a message is
    printed and the current output is left as it is.
    """
    path = trajectory_file_selector.value.strip()
    if not path:
        print("Please select a trajectory file first.")
        return
    if not os.path.exists(path):
        print(f"Error: File '{path}' not found. Please verify the path.")
        return

    # Replace previous output with the viewer
    clear_output(wait=True)
    print(f"Loading trajectory from: {path}")
    display_trajectory(path)

    def return_to_file_ui(_):
        clear_output(wait=True)
        display(file_ui)

    back_to_selection = widgets.Button(
        description='Back to File Selection',
        button_style='info',
        tooltip='Return to file selection'
    )
    back_to_selection.on_click(return_to_file_ui)
    display(back_to_selection)

display_button.on_click(on_display_clicked)

# Create the UI
file_ui = widgets.VBox([
    widgets.HTML("<h3>Trajectory Viewer</h3>"),
    widgets.HBox([trajectory_file_selector, browse_button]),
    file_upload,
    display_button
])

# Check if required packages are available and initialize the UI
if visualization_available:
    display(file_ui)
else:
    print("Required visualization packages are not available.")
    print("Please install the required packages with:")
    print("!pip install py3Dmol ase")

    # Create install button for convenience
    install_button = widgets.Button(
        description='Install Required Packages',
        button_style='danger'
    )

    def on_install_clicked(b):
        from IPython.display import display, HTML
        display(HTML("<div>Installing packages...</div>"))
        !pip install py3Dmol ase
        display(HTML("<div>Packages installed. Please restart the runtime and run this cell again.</div>"))

    install_button.on_click(on_install_clicked)
    display(install_button)

Loading trajectory from: ala15_folded_nvt.xyz
Successfully loaded trajectory with 211 frames
Animation ready with 211 frames


Button(button_style='info', description='Back to File Selection', style=ButtonStyle(), tooltip='Return to file…