In [185]:
from ase.atoms import Atoms
from ase.build import bulk
from ase.calculators.emt import EMT
from ase.constraints import UnitCellFilter
from ase.eos import calculate_eos, plot
from ase.data import reference_states, atomic_numbers
from ase.optimize import LBFGS
from ase.units import kJ

In [None]:
from mace.calculators import mace_mp

In [97]:
from langchain.agents import AgentExecutor
from langchain.agents import tool
from langchain_core.pydantic_v1 import BaseModel
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.agents.format_scratchpad.openai_tools import (
    format_to_openai_tool_messages,
)
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser

In [3]:
%matplotlib widget
from typing import List
import numpy as np
import matplotlib.pyplot as plt

In [4]:
import ipywidgets as widgets

In [136]:
class AtomsDict(BaseModel):
    numbers: List[int]
    positions: List[List[float]]
    cell: List[List[float]]
    pbc: List[bool]

In [135]:
def get_calculator(calculator_str: str = "emt"):
    if calculator_str == "emt":
        return EMT()
    elif calculator_str == "mace":
        return mace_mp(
            model="medium",
            dispersion=False,
            default_dtype="float32",
            device='cpu'
        )

In [134]:
@tool 
def get_equilibirum_lattice(chemical_symbol: str, calculator_str: str) -> AtomsDict:
    """Returns equilibrium atoms dictionary for a given chemical symbol and a selected model specified by the calculator string.
    """
    atoms = bulk(name=chemical_symbol)
    atoms.calc = get_calculator(calculator_str=calculator_str)
    ase_optimizer_obj = LBFGS(UnitCellFilter(atoms))
    ase_optimizer_obj.run(fmax=0.000001)
    return AtomsDict(**{k: v.tolist() for k, v in atoms.todict().items()})

In [133]:
@tool
def plot_equation_of_state(atom_dict: AtomsDict, calculator_str: str) -> str:
    """Returns plot of equation of state of chemical symbol for a given atoms dictionary and a selected model specified by the calculator string"""
    atoms = Atoms(**atom_dict.dict())
    atoms.calc = get_calculator(calculator_str=calculator_str)
    eos = calculate_eos(atoms)
    plotdata = eos.getplotdata()
    return plot(*plotdata)

In [132]:
@tool
def get_bulk_modulus(atom_dict: AtomsDict, calculator_str: str) -> str:
    """Returns the bulk modulus of chemcial symbol for a given atoms dictionary and a selected model specified by the calculator string in GPa.
    """
    atoms = atoms = Atoms(**atom_dict.dict())
    atoms.calc = get_calculator(calculator_str=calculator_str)
    eos = calculate_eos(atoms)
    v, e, B = eos.fit()
    return B / kJ * 1.0e24

In [131]:
@tool
def get_equilibrium_volume(atom_dict: AtomsDict, calculator_str: str) -> str:
    """Returns the equilibrium volume of chemcial symbol for a given atoms dictionary and a selected model specified by the calculator string in Angstrom^3.
    """
    atoms = atoms = Atoms(**atom_dict.dict())
    atoms.calc = get_calculator(calculator_str=calculator_str)
    eos = calculate_eos(atoms)
    v, e, B = eos.fit()
    return v

# Simulation LLM
The first large language model for calculating material properties. Example queries: 
* "Calculate the bulk modulus for aluminium with the mace model"
* "Calculate the equilibrium volume of gold"

In [130]:
key_widget = widgets.Text(
    value='sk-********************',
    placeholder='',
    description='OpenAI Token:',
    disabled=False,
    layout=widgets.Layout(width='400px'),
    style = {'description_width': '100px'} 
)

In [121]:
query_widget = widgets.Textarea(
    value='Calculate the equilibrium volume of gold',
    placeholder='',
    description='Next query:',
    disabled=False,
    layout=widgets.Layout(width='400px'),
    style = {'description_width': '100px'}     
)

In [137]:
model_widget = widgets.Dropdown(
    options=['gpt-4-turbo', 'gpt-3.5-turbo', "llama3-70b-8192", "mixtral-8x7b-32768"],
    value='gpt-3.5-turbo',
    description='Model:',
    disabled=False
)

In [123]:
def get_extension(OPENAI_API_KEY, model="gpt-3.5-turbo"):
    if model in ("gpt-3.5-turbo", "gpt-4-turbo"):
        llm = ChatOpenAI(model=model, temperature=0, openai_api_key=OPENAI_API_KEY)
    elif model in ("llama3-70b-8192", "mixtral-8x7b-32768"):
        llm = ChatGroq(model=model, temperature=0, groq_api_key=OPENAI_API_KEY)
    else:
        return "No valid model specified"

    tools = [get_equilibrium_volume, get_equilibirum_lattice, plot_equation_of_state, get_bulk_modulus]
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are very powerful assistant, but don't know current events. To calculate with emt use the calculator string emt and to calculate with mace use the calculator string mace. For each query vailidate that it contains a chemical element and a calculator string and otherwise use Alumninum as the default chemical element and emt as the default calculator string.",
            ),
            ("user", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    llm_with_tools = llm.bind_tools(tools)
    agent = (
        {
            "input": lambda x: x["input"],
            "agent_scratchpad": lambda x: format_to_openai_tool_messages(
                x["intermediate_steps"]
            ),
        }
        | prompt
        | llm_with_tools
        | OpenAIToolsAgentOutputParser()
    )

    return AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True)

In [None]:
out = widgets.Output()

In [None]:
def get_plot(output_steps):            
    for step in output_steps:
        if hasattr(step, 'observation') and isinstance(step.observation, plt.Axes):
            return step.observation

In [None]:
def run_refresh(*ignore):
    agent_executor = get_extension(OPENAI_API_KEY=key_widget.value)
    output = list(agent_executor.stream({"input": query_widget.value}))
    with out:
        for part in output:
            if "steps" in part.keys():
                plot_axes = get_plot(part["steps"])
                if isinstance(plot_axes, plt.Axes):
                    display(plot_axes.figure)          
        display(output[-1]["output"])

In [None]:
button = widgets.Button(description ="Submit")
button.on_click(run_refresh)



[1m> Entering new AgentExecutor chain...[0m


AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided: sk-********************. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}



[1m> Entering new AgentExecutor chain...[0m


AuthenticationError: Error code: 401 - {'error': {'message': 'Incorrect API key provided: gsk_SiHS********************************************fEKV. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}



[1m> Entering new AgentExecutor chain...[0m


BadRequestError: Error code: 400 - {'error': {'message': "Failed to call a function. Please adjust your prompt. See 'failed_generation' for more details.", 'type': 'invalid_request_error', 'code': 'tool_use_failed', 'failed_generation': '```\n{\n  "tool_calls": [\n    {\n      "id": "pending",\n      "type": "function",\n      "function": {\n        "name": "get_equilibrium_volume"\n      },\n      "parameters": {\n        "atom_dict": {\n          "cell": [[1, 0, 0], [0, 1, 0], [0, 0, 1]],\n          "numbers": [79],\n          "pbc": [True, True, True],\n          "positions": [[0, 0, 0]]\n        },\n        "calculator_str": "emt"\n      }\n    }\n  ]\n}\n```'}}



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `get_equilibrium_volume` with `{'atom_dict': {'cell': [[4.078, 0.0, 0.0], [0.0, 4.078, 0.0], [0.0, 0.0, 4.078]], 'numbers': [79, 79, 79], 'pbc': [True, True, True], 'positions': [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5], [0.5, 0.0, 0.5]]}, 'calculator_str': 'emt'}`


[0m

ValueError: No minimum!



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `get_equilibrium_volume` with `{'atom_dict': {'cell': [[4.05, 0.0, 0.0], [0.0, 4.05, 0.0], [0.0, 0.0, 4.05]], 'numbers': [13, 13, 13], 'pbc': [True, True, True], 'positions': [[0.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.5, 0.0, 0.5]]}, 'calculator_str': 'emt'}`


[0m

ValueError: No minimum!



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `get_equilibrium_volume` with `{'atom_dict': {'cell': [[4.05, 0.0, 0.0], [0.0, 4.05, 0.0], [0.0, 0.0, 4.05]], 'numbers': [13, 13, 13], 'pbc': [True, True, True], 'positions': [[0.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.5, 0.0, 0.5]]}, 'calculator_str': 'emt'}`


[0m

ValueError: No minimum!

example queries </br>
* First get the equilibrium lattice for gold, then calculate the bulk modulus
* What is the equilibrium volume of aluminium
* Calculate the bulk modulus of gold

In [None]:
display(query_widget, key_widget, model_widget, button, out)