In [None]:
import sys

# Make the `ares_transverse_tuning` checkout importable (presumably where the
# `src.*` packages imported in the next cell live). The path is relative, so the
# kernel must be started from the directory that contains it.
ARES_REPO_PATH = "ares_transverse_tuning"
if ARES_REPO_PATH not in sys.path:  # keep the cell idempotent on re-runs
    sys.path.append(ARES_REPO_PATH)

In [None]:
import time
from datetime import datetime
from pathlib import Path

import gymnasium as gym
import langchain
import matplotlib.pyplot as plt
import numpy as np
import scienceplots
from dotenv import load_dotenv
from gymnasium.wrappers import TimeLimit
from icecream import ic
from langchain.callbacks import FileCallbackHandler, wandb_tracing_enabled
from langchain.chains import LLMChain
from langchain.chat_models import ChatOllama
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from loguru import logger
from openai import RateLimitError
from src.environments.ea import TransverseTuning
from src.eval import Episode
from src.trial import load_trials
from src.wrappers import RecordEpisode, TQDMWrapper

from pacuna import PACuna

In [None]:
# Pull environment variables from a local .env file (presumably API keys for
# the OpenAI / tracing backends — verify against the .env in use).
load_dotenv()

# Verbose LangChain logging of every chain/LLM call.
langchain.debug = True

# Publication-style figures from the scienceplots package, without LaTeX.
plt.style.use(["science", "nature", "no-latex"])

In [None]:
# LLM under test. The commented-out lines are alternative configurations;
# swap one in (or edit the two constants below) to change the model.
# chat_model = ChatOpenAI(model="gpt-3.5-turbo-0125")
# chat_model = ChatOpenAI(model="gpt-4")
# chat_model = ChatOpenAI(model="gpt-4-32k")
# chat_model = ChatOpenAI(model="gpt-4-0125-preview")
# chat_model = ChatOllama(model="mistral:v0.2", base_url="http://max-wng060:11434")
# chat_model = ChatOllama(model="zephyr", base_url="http://max-wng052:11434")
# chat_model = ChatOllama(model="mixtral:8x7b", base_url="http://max-wng060:11434")
# chat_model = ChatOllama(model="gemma:2b", base_url="http://max-wng060:11434")
# chat_model = ChatOllama(model="gemma:7b", base_url="http://max-wng060:11434")
# chat_model = ChatOllama(model="starling-lm:7b-beta", base_url="http://max-wng060:11434")
# chat_model = ChatOllama(model="vicuna:33b", base_url="http://max-wng054:11434")
# chat_model = ChatOllama(model="llava:34b", base_url="http://max-wng054:11434")
# chat_model = ChatOllama(model="orca2:7b", base_url="http://max-wng060:11434")
# chat_model = ChatOllama(model="orca2:13b", base_url="http://max-wng060:11434")
# chat_model = ChatOllama(model="llama2:7b", base_url="http://max-wng060:11434")
# chat_model = ChatOllama(model="llama2:13b", base_url="http://max-wng060:11434")
# chat_model = ChatOllama(model="llama2:70b", base_url="http://max-wng060:11434")
# chat_model = ChatOllama(model="falcon:180b-chat", base_url="http://max-wng054:11434")
# chat_model = ChatOllama(model="neural-chat:7b")
# chat_model = ChatOllama(model="mistral-openorca:7b", base_url="http://max-wng052:11434")
# chat_model = ChatOllama(model="phi:chat")
# chat_model = ChatOllama(model="megadolphin:120b", base_url="http://max-wng053:11434")
# chat_model = ChatOllama(model="yi:34b-chat", base_url="http://max-wng054:11434")
# chat_model = PACuna()
OLLAMA_MODEL = "vicuna:7b-16k"
OLLAMA_BASE_URL = "http://max-wng058:11434"
chat_model = ChatOllama(model=OLLAMA_MODEL, base_url=OLLAMA_BASE_URL)

In [None]:
# Uncomment to send a single test message to `chat_model` before the run.
# chat_model.invoke("Hello")

In [None]:
# Trial definition (initial magnets, target beam, incoming beam and
# misalignments) that configures the environment for this run.
TRIAL_INDEX = 38
trials_path = Path("ares_transverse_tuning/data/trials.yaml")
trials = load_trials(trials_path)
trial = trials[TRIAL_INDEX]
trial

In [None]:
# Run directory: data/paper/optimisation/<model>/trial-<index>_<timestamp>.
# ":" (as in Ollama tags such as "vicuna:7b-16k") is replaced because it is
# awkward in file paths. ChatOpenAI exposes `model_name`; ChatOllama `model`.
model_name = (
    chat_model.model_name if hasattr(chat_model, "model_name") else chat_model.model
).replace(":", "-")
now = datetime.now()

log_dir = (
    Path("data")
    / "paper"
    / "optimisation"
    / model_name
    / f"trial-{TRIAL_INDEX}_{now.strftime('%Y-%m-%d_%H-%M-%S')}"
)
# Create the directory explicitly instead of relying on the loguru sink below
# to do it as a side effect; FileCallbackHandler presumably opens the file with
# a plain open() and would fail on a missing directory.
log_dir.mkdir(parents=True, exist_ok=True)

langchain_log_file = log_dir / "langchain.log"
# NOTE(review): loguru and FileCallbackHandler both write to this same file, so
# their output may interleave. Re-running this cell adds another loguru sink.
logger.add(langchain_log_file, colorize=True, enqueue=True)
handler = FileCallbackHandler(langchain_log_file)

In [None]:
# Show where this run's logs and recorded episodes are written.
log_dir

In [None]:
# Unwrapped environment used only to inspect the trial's initial state. The
# wrapped environment for the actual optimisation run is built further below.
cheetah_backend_args = {
    "incoming_mode": trial.incoming_beam,
    "misalignment_mode": trial.misalignments,
    "generate_screen_images": False,
}
env = TransverseTuning(
    action_mode="direct",
    magnet_init_mode=trial.initial_magnets,
    target_beam_mode=trial.target_beam,
    backend="cheetah",
    backend_args=cheetah_backend_args,
)
env

In [None]:
# Starting objective: mean absolute deviation of the beam from the target,
# scaled by 1e3 (the same objective and scaling the agent writes into its prompt).
observation, info = env.reset()
beam_error = np.abs(observation["target"] - observation["beam"])
np.mean(beam_error) * 1e3

In [None]:
# Format string for one (inputs, objective) sample in the prompt. Doubled
# braces are literal JSON braces; {q1}..{ch} and {objective} are filled in by
# the agent. The string ends with a newline.
sample_template = """Inputs:
```json
{{
	"Q1": {q1:.2f},
	"Q2": {q2:.2f},
	"CV": {cv:.2f},
	"Q3": {q3:.2f},
	"CH": {ch:.2f}
}}
```
Objective value = {objective:.2f}
"""

In [None]:
# Optimisation prompt. {prior_samples} receives the rendered samples (worst
# objective first, as the agent sorts them) and {format_instructions} receives
# the output parser's formatting instructions.
message_template = """Now you will help me minimise a function with five input variables Q1, Q2, CV, Q3 and CH. I have some (Q1, Q2, CV, Q3, CH) pairs and the corresponding function values at those points. The samples are arranged in descending order based on their function values, where lower values are better.

{prior_samples}

Give me a new sample (Q1, Q2, CV, Q3, CH) that is different from all pairs above, and has a function value lower than any of the above.

{format_instructions}
"""

In [None]:
# One float field per magnet, in the order the agent maps them back into an
# action. Descriptions are generic ordinal labels, matching the abstract
# "minimise a function" framing of the prompt.
magnet_fields = [
    ("Q1", "First input"),
    ("Q2", "Second input"),
    ("CV", "Third input"),
    ("Q3", "Fourth input"),
    ("CH", "Fifth input"),
]

output_parser = StructuredOutputParser.from_response_schemas(
    [
        ResponseSchema(name=key, type="float", description=description)
        for key, description in magnet_fields
    ]
)

print(output_parser.get_format_instructions())

In [None]:
# Prompt with two inputs: {prior_samples} and {format_instructions}.
prompt = PromptTemplate.from_template(template=message_template)
prompt

In [None]:
# Main chain: parses the reply into the five magnet values and logs every call
# to `langchain_log_file` through `handler`.
chain = LLMChain(
    llm=chat_model,
    prompt=prompt,
    output_parser=output_parser,
    callbacks=[handler],
)

# Same prompt without the output parser and without the file handler, so its
# "text" output is the unparsed reply; used by the debug cells below.
decapitated_chain = LLMChain(llm=chat_model, prompt=prompt)

chain

In [None]:
# Debug helper: query the chain WITHOUT output parsing and print the raw reply.
# NOTE: `samples_str` is not defined at notebook level (it is only built inside
# `LLMTransverseTuningAgent.predict`); define it before uncommenting.
# # with wandb_tracing_enabled():
# response = decapitated_chain.invoke(
#     {
#         "prior_samples": samples_str,
#         "format_instructions": output_parser.get_format_instructions(),
#     }
# )
# print(response["text"])

In [None]:
# Debug helper: query the parsing chain once with W&B tracing enabled.
# NOTE: `samples_str` is not defined at notebook level (it is only built inside
# `LLMTransverseTuningAgent.predict`); define it before uncommenting.
# with wandb_tracing_enabled():
#     response = chain.invoke(
#         {
#             "prior_samples": samples_str,
#             "format_instructions": output_parser.get_format_instructions(),
#         }
#     )
# print(response["text"])

In [None]:
class LLMTransverseTuningAgent:
    """
    An agent for doing transverse beam parameter tuning in the ARES experimental
    area using an LLM via LangChain.

    Each call to :meth:`predict` does the following:

    - records the observation;
    - records its objective, which is the mean absolute difference between
      ``observation["target"]`` and ``observation["beam"]``;
    - renders all samples seen so far into the prompt, worst objective first
      and best last;
    - asks the LLM chain for the next magnet settings.

    Args:
        env: Environment whose ``action_space`` is sampled during warmup.
        warmup_steps: Number of samples to collect with random actions before
            the LLM is queried. The count includes the observation returned by
            ``env.reset()``. ``0`` or ``1`` queries the LLM from the first step.
        verbose: If ``True``, print each parsed LLM response and retry notices.
        llm_chain: Chain to invoke. Defaults to the notebook-level ``chain``,
            looked up when the agent is created.
        sample_template_str: Format string for one sample. Defaults to the
            notebook-level ``sample_template``.
        parser: Output parser whose format instructions go into the prompt.
            Defaults to the notebook-level ``output_parser``.
        max_retries: How often a failed chain call (request or output parsing)
            is retried before the error is re-raised.
        retry_delay: Seconds to wait before each retry.
    """

    # Alternative, more descriptive prompt templates (currently unused).
    #     message_template = """Now you will help me optimise the horizontal and vertical position and size of an electron beam on a diagnostic screen in a particle accelerator.
    #
    # The target beam parameters I want you to find are:
    #  - horizontal position: 0.0 μm
    #  - horizontal size: 0.0 μm
    #  - vertical position: 0.0 μm
    #  - vertical size: 0.0 μm
    #
    # You are able to control five magnets in the beam line. The magnets are called:
    #  - Q1
    #  - Q2
    #  - CV
    #  - Q3
    #  - CH
    #
    # Q1, Q2, Q3 are quadrupole magnets. When their k1 strength is increased, the beam becomes more focused in the horizontal plane and more defocused in the vertical plane. When their k1 strength is decreased, the beam becomes more focused in the vertical plane and more defocused in the horizontal plane. When their k1 strength is zero, the beam is not focused in either plane. Quadrupole magnets might also steer the beam in the horizontal or vertical plane depending on their k0 strength, when the beam does not travel through the centre of the magnet. The range of the k1 strength is -30.0 to 30.0 m^-2.
    #
    # CV is vertical steering magnet. When its deflection angle is increased, the beam is steered upwards. When its deflection angle is decreased, the beam is steered downwards. The range of the deflection angle is -6.0 to 6.0 mrad.
    #
    # CH is horizontal steering magnet. When its deflection angle is increased, the beam is steered to the right. When its deflection angle is decreased, the beam is steered to the left. The range of the deflection angle is -6.0 to 6.0 mrad.
    #
    # I have some pairs of magnet settings and the corresponding beam parameters.
    #
    # {prior_samples}
    #
    # Give me new magnet settings that are different from all pairs above, and will result in transverse beam parameters closer to the target beam parameters than any of the above. Beam parameters less than 40 μm from their target are considered optimal. If you do not know which magnet settings would improve the beam parameters, choose magnet settings that maximise information gain. Smooth changes to the magnet settings are preferred. Do not write code.
    #
    # {format_instructions}
    # """

    #     sample_template = """Magnet settings:
    #  - Q1: {q1} m^-2
    #  - Q2: {q2} m^-2
    #  - CV: {cv} mrad
    #  - Q3: {q3} m^-2
    #  - CH: {ch} mrad
    # Beam parameters:
    #  - horizontal position: {mu_x} μm
    #  - horizontal size: {sigma_x} μm
    #  - vertical position: {mu_y} μm
    #  - vertical size: {sigma_y} μm
    # """

    def __init__(
        self,
        env: "gym.Env",
        warmup_steps: int = 0,
        verbose: bool = False,
        llm_chain=None,
        sample_template_str: "str | None" = None,
        parser=None,
        max_retries: int = 1,
        retry_delay: float = 5.0,
    ) -> None:
        self.env = env
        self.warmup_steps = warmup_steps
        self.verbose = verbose
        self.max_retries = max(0, int(max_retries))
        self.retry_delay = retry_delay

        # Fall back to the notebook-level objects when nothing is injected.
        self._chain = llm_chain if llm_chain is not None else chain
        self._sample_template = (
            sample_template_str
            if sample_template_str is not None
            else sample_template
        )
        self._output_parser = parser if parser is not None else output_parser

        self._observations = []
        self._objectives = []

    def predict(self, observation: dict) -> np.ndarray:
        """
        Takes an observation from the environment and returns an action.

        Args:
            observation: Dict with at least the keys ``"magnets"`` (five values
                in the order Q1, Q2, CV, Q3, CH), ``"target"`` and ``"beam"``.

        Returns:
            Array of five magnet settings in the same order as
            ``observation["magnets"]``. During warmup this is a random sample
            from ``env.action_space`` instead.

        Raises:
            Exception: Whatever the chain raises on its last attempt, once
                ``max_retries`` retries are exhausted.
        """
        self._observations.append(observation)
        self._objectives.append(
            np.mean(np.abs(observation["target"] - observation["beam"]))
        )

        # `_observations` already includes the observation from the reset (it
        # was appended above), so no extra offset is needed.
        if len(self._observations) < self.warmup_steps:
            return self.env.action_space.sample()

        response = self._query_llm(self._format_samples())

        if self.verbose:
            # The parsed fields live under "text"; there is no "explanation".
            print(response["text"])

        return self._parse_action(response["text"])

    def _format_samples(self) -> str:
        """Render every recorded sample, worst objective first and best last."""
        # Ascending stable sort, then reversed: keeps the original tie order.
        ranked = reversed(
            sorted(
                zip(self._observations, self._objectives),
                key=lambda sample: sample[1],
            )
        )
        rendered = []
        for past_observation, past_objective in ranked:
            magnets = past_observation["magnets"]
            rendered.append(
                self._sample_template.format(
                    q1=magnets[0],
                    q2=magnets[1],
                    # Steerers are scaled by 1e3 (the unused prompt above labels
                    # them mrad); `_parse_action` undoes this.
                    cv=magnets[2] * 1e3,
                    q3=magnets[3],
                    ch=magnets[4] * 1e3,
                    objective=past_objective * 1e3,
                )
            )
        return "\n".join(rendered)

    def _query_llm(self, samples_str: str) -> dict:
        """Invoke the chain, retrying up to `max_retries` times on any error."""
        inputs = {
            "prior_samples": samples_str,
            "format_instructions": self._output_parser.get_format_instructions(),
        }
        for attempt in range(self.max_retries + 1):
            try:
                return self._chain.invoke(inputs)
            except Exception as error:
                if attempt == self.max_retries:
                    raise
                if self.verbose:
                    print(
                        f"LLM call failed ({error!r}); "
                        f"retrying in {self.retry_delay} s"
                    )
                time.sleep(self.retry_delay)

    @staticmethod
    def _parse_action(parsed: dict) -> np.ndarray:
        """Convert the parsed LLM fields into an action array."""
        # float() also accepts numbers the LLM returned as JSON strings.
        return np.array(
            [
                float(parsed["Q1"]),
                float(parsed["Q2"]),
                float(parsed["CV"]) / 1e3,
                float(parsed["Q3"]),
                float(parsed["CH"]) / 1e3,
            ]
        )

In [None]:
# Fresh environment for the optimisation run. Layers, innermost first:
#   - TimeLimit: at most 50 steps;
#   - RecordEpisode: episodes written to `log_dir / "recorded_episodes"`;
#   - TQDMWrapper: presumably a progress bar.
cheetah_backend_args = {
    "incoming_mode": trial.incoming_beam,
    "misalignment_mode": trial.misalignments,
    "generate_screen_images": False,
}
base_env = TransverseTuning(
    action_mode="direct",
    magnet_init_mode=trial.initial_magnets,
    target_beam_mode=trial.target_beam,
    backend="cheetah",
    backend_args=cheetah_backend_args,
)
env = TQDMWrapper(
    RecordEpisode(
        TimeLimit(base_env, max_episode_steps=50),
        log_dir / "recorded_episodes",
    )
)
env

In [None]:
# Agent that queries the LLM from the very first step (no random warmup).
agent = LLMTransverseTuningAgent(env, warmup_steps=0, verbose=False)
agent

In [None]:
# Run one episode with W&B tracing: feed each observation to the agent until
# the environment terminates or the TimeLimit truncates the episode.
# NOTE(review): `agent` keeps every observation it has seen, so re-running this
# cell without re-creating the agent mixes earlier samples into the prompt.
with wandb_tracing_enabled():
    observation, info = env.reset()
    while True:
        action = agent.predict(observation)
        observation, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            break

In [None]:
# Close the wrapped environment. Presumably this also lets RecordEpisode finish
# writing the episode that is loaded below — confirm against RecordEpisode.
env.close()

In [None]:
# Load the recorded run. Assumes this run produced a single episode saved as
# "recorded_episode_1.pkl". The .pkl suffix suggests pickle; only load
# recordings you created yourself.
recorded_episode_path = log_dir / "recorded_episodes" / "recorded_episode_1.pkl"
episode = Episode.load(recorded_episode_path)

In [None]:
# Summary plot of the recorded episode. The trailing ";" suppresses the return
# value's repr without rebinding `_`, which IPython uses for its output cache.
episode.plot_summary();