In [1]:
import os
from getpass import getpass

In [2]:
from ase.build import bulk
from ase.data import reference_states, atomic_numbers
from langchain.agents import tool

In [42]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.agents.format_scratchpad.openai_tools import (
    format_to_openai_tool_messages,
)
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.agents import AgentExecutor

In [68]:
from langchain.agents import format_scratchpad

##### Tools

In [16]:
@tool
def get_crystal_structure(chemical_symbol: str) -> str:
    """Return the reference crystal structure of a chemical element.

    Takes an element symbol such as 'Au' and returns ASE's reference-state
    symmetry label for it (e.g. 'fcc'), or an explanatory message if the
    symbol is unknown or no reference structure is recorded for it.
    """
    # Normalize so that inputs like "au" or " AU " still resolve: element
    # symbols are one capital letter optionally followed by lowercase letters.
    symbol = chemical_symbol.strip().capitalize()
    if symbol not in atomic_numbers:
        return f"{chemical_symbol} is not a valid element name as per periodic table."
    ref_state = reference_states[atomic_numbers[symbol]]
    if ref_state is None:
        # ASE stores None for elements without a reference state.
        return "No crystal structure known."
    return ref_state["symmetry"]

In [18]:
# Tools exposed to the agent; each entry is a LangChain @tool-decorated function.
tools = [
    get_crystal_structure,
]

##### Groq

Available models (maximum context length in parentheses):
* llama3-8b-8192 (8,192 tokens)
* llama3-70b-8192 (8,192 tokens)
* mixtral-8x7b-32768 (32,768 tokens)
* gemma-7b-it (8,192 tokens)

In [5]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq

In [8]:
# Prefer the GROQ_API_KEY environment variable so Restart & Run All can proceed
# unattended; otherwise prompt interactively without echoing the key.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY") or getpass(prompt="Enter your GROQ Token:")

In [62]:
# Groq-hosted model backing the agent; alternatives are listed in the markdown above.
MODEL = "llama3-8b-8192"
groq_llm = ChatGroq(
    model_name=MODEL,
    groq_api_key=GROQ_API_KEY,
    temperature=0,  # minimize sampling randomness so tool calls are more repeatable
)

##### OpenAI

In [26]:
# Prefer the OPENAI_API_KEY environment variable so the notebook can run
# unattended; otherwise prompt interactively without echoing the key.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") or getpass(prompt='Enter your OpenAI Token:')

In [27]:
from langchain_openai import ChatOpenAI

# Alternative OpenAI backend; the agent cell below binds groq_llm instead.
oa_llm = ChatOpenAI(
    openai_api_key=OPENAI_API_KEY,
    model="gpt-3.5-turbo",
    temperature=0,
)

##### Agents

In [30]:
# Agent prompt: system instructions, the user's question, then the scratchpad
# of prior tool calls/results filled in by the agent on each step.
# The system message asks the model to reject queries without a chemical element;
# the earlier version without that instruction ("You are very powerful assistant,
# but don't know current events.") failed when given "car" instead of "gold".
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a very powerful assistant, but don't know current events. "
            "For each query, validate that it contains a chemical element and otherwise cancel.",
        ),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)

In [65]:
# Map the executor's state dict onto the prompt's input variables: the raw
# question, plus prior (action, observation) steps rendered as tool messages.
agent_inputs = {
    "input": lambda state: state["input"],
    "agent_scratchpad": lambda state: format_to_openai_tool_messages(
        state["intermediate_steps"]
    ),
}

# Pipeline: inputs -> prompt -> tool-aware LLM -> parsed agent actions/finish.
llm_with_tools = groq_llm.bind_tools(tools)
agent = agent_inputs | prompt | llm_with_tools | OpenAIToolsAgentOutputParser()

# The executor runs the agent loop and dispatches tool calls to `tools`.
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

In [66]:
# Run the agent on a sample question; with verbose=True the executor logs each step.
sample_query = {"input": "What is the crystal structure of gold"}
lst = list(agent_executor.stream(sample_query))



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `get_crystal_structure` with `{'chemical_symbol': 'Au'}`


[0m[36;1m[1;3mfcc[0m[32;1m[1;3mThe crystal structure of gold is fcc.[0m

[1m> Finished chain.[0m
