In [1]:
import os
import subprocess
import tempfile

from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)

In [None]:
assert os.environ["OPENAI_API_KEY"]

In [2]:
llm = ChatOpenAI(temperature=0.3, model_name="gpt-4o")

  llm = ChatOpenAI(temperature=0.3, model_name="gpt-4o")


In [3]:
system_prompt = SystemMessagePromptTemplate.from_template(
    """
You are an expert in writing LAMMPS input scripts for molecular dynamics simulations.
The user will provide requirements, and you will produce a LAMMPS input script.

**Important:**
- Do not output any explanations or reasoning lines outside of LAMMPS-compatible comments.
- If you must explain or clarify something, write it as LAMMPS comments using `#` at the beginning of the line.
- The final output should be a valid LAMMPS input file with optional comments but no extra text outside of the script.
"""
)

In [4]:
human_prompt = HumanMessagePromptTemplate.from_template(
    """
I want you to write a LAMMPS input script that:
1. Uses an Lennard-Jones potential for Argon (Ar)
2. Sets up an FCC lattice
3. Minimizes energy
4. Runs an NVT simulation at 300K for 1000 steps
5. Outputs thermodynamic data and final coordinates

Produce a LAMMPS input script meeting these conditions.
"""
)

In [5]:
import re


def clean_lammps_script(script_content: str) -> str:
    cleaned = re.sub(r"```[A-Za-z]*\n?", "", script_content)
    cleaned = re.sub(r"```", "", cleaned)
    return cleaned

In [6]:
def run_lammps_script(script_content: str):
    with tempfile.NamedTemporaryFile(suffix=".in", delete=False) as f:
        input_file = f.name
        f.write(script_content.encode("utf-8"))
    try:
        result = subprocess.run(
            ["lmp", "-in", input_file], capture_output=True, text=True
        )
    except FileNotFoundError:
        raise RuntimeError("LAMMPS command 'lmp' not found in PATH.")

    stdout = result.stdout
    stderr = result.stderr

    os.remove(input_file)

    return stdout, stderr, result.returncode

In [7]:
def fix_script(llm, original_script: str, error_message: str):
    fix_system_prompt = SystemMessagePromptTemplate.from_template(
        """
You are an expert LAMMPS script writer. The user provided a script that caused LAMMPS errors.
You will receive the original script and the error message LAMMPS produced.
Improve the script so that it runs successfully, while keeping the user's initial intent.
If possible, explain briefly what changes you made.
                                                                  
**Important:**
- Do not output any explanations or reasoning lines outside of LAMMPS-compatible comments.
- If you must explain or clarify something, write it as LAMMPS comments using `#` at the beginning of the line.
- The final output should be a valid LAMMPS input file with optional comments but no extra text outside of the script.

"""
    )
    fix_human_prompt = HumanMessagePromptTemplate.from_template(
        f"""
Original script:
{original_script}

Error message from LAMMPS:
{error_message}

Please fix the script.
"""
    )

    fix_chain = LLMChain(
        llm=llm,
        prompt=ChatPromptTemplate.from_messages([fix_system_prompt, fix_human_prompt]),
    )
    new_script = fix_chain.run({})
    return new_script

In [8]:
chain_prompt = ChatPromptTemplate(messages=[system_prompt, human_prompt])
chain = LLMChain(llm=llm, prompt=chain_prompt)

  chain = LLMChain(llm=llm, prompt=chain_prompt)


In [9]:
current_script = chain.run({})

  current_script = chain.run({})


In [10]:
print(clean_lammps_script(current_script))

# LAMMPS Input Script for Argon using Lennard-Jones Potential

# Initialize simulation
units           lj
atom_style      atomic

# Create simulation box and atoms
lattice         fcc 0.8442
region          box block 0 10 0 10 0 10
create_box      1 box
create_atoms    1 box

# Define Lennard-Jones potential for Argon
mass            1 1.0
pair_style      lj/cut 2.5
pair_coeff      1 1 1.0 1.0 2.5

# Set up neighbor list
neighbor        0.3 bin
neigh_modify    every 20 delay 0 check yes

# Energy minimization
min_style       cg
minimize        1.0e-4 1.0e-6 100 1000

# Define simulation settings
timestep        0.005
thermo          100

# NVT ensemble at 300K
velocity        all create 1.0 87287 loop geom
fix             1 all nvt temp 1.0 1.0 0.1

# Run the simulation
run             1000

# Output final coordinates
write_restart   final.restart
write_data      final.data

# End of LAMMPS input script



In [11]:
max_iterations = 5

script_history = []

output_dir = "scripts_history"
os.makedirs(output_dir, exist_ok=True)

for i in range(max_iterations):
    current_script = clean_lammps_script(current_script)

    script_history.append((i, current_script))
    script_filename = os.path.join(output_dir, f"script_iteration_{i}.in")
    with open(script_filename, "w") as f:
        f.write(current_script)

    stdout, stderr, returncode = run_lammps_script(current_script)

    if returncode == 0:
        print("LAMMPS script ran successfully!")
        print("STDOUT:", stdout)
        break
    else:
        print(f"Iteration {i} failed. Trying to fix...")
        fixed_script = fix_script(llm, current_script, stderr)
        current_script = clean_lammps_script(fixed_script)
        print("Fixed script candidate:\n", current_script)
else:
    print("Max iterations reached, still errors remain.")

for iter_idx, script_content in script_history:
    print(f"Iteration {iter_idx} script:\n{script_content}\n")

LAMMPS script ran successfully!
STDOUT: LAMMPS (29 Sep 2021 - Update 2)
OMP_NUM_THREADS environment is not set. Defaulting to 1 thread. (src/comm.cpp:98)
  using 1 OpenMP thread(s) per MPI task
Lattice spacing in x,y,z = 1.6795962 1.6795962 1.6795962
Created orthogonal box = (0.0000000 0.0000000 0.0000000) to (16.795962 16.795962 16.795962)
  1 by 1 by 1 MPI processor grid
Created 4000 atoms
  using lattice units in orthogonal box = (0.0000000 0.0000000 0.0000000) to (16.795962 16.795962 16.795962)
  create_atoms CPU = 0.000 seconds
Neighbor list info ...
  update every 1 steps, delay 0 steps, check yes
  max neighbors/atom: 2000, page size: 100000
  master list distance cutoff = 2.8
  ghost atom cutoff = 2.8
  binsize = 1.4, bins = 12 12 12
  1 neighbor lists, perpetual/occasional/extra = 1 0 0
  (1) pair lj/cut, perpetual
      attributes: half, newton on
      pair build: half/bin/atomonly/newton
      stencil: half/bin/3d
      bin: standard
Setting up cg style minimization ...
  U