# Experiment with MiniWoB++ environment

### Configs

In [1]:
# ---- Browser / episode settings -------------------------------------------
headless = False  # True runs MiniWoB++ without a visible browser window
num_episodes = 1  # number of episodes to run for each task

# ---- Prompting settings ----------------------------------------------------
ALLOW_DEMO = True  # True inserts the task's expert example into the prompt
CODE_CHECK = False  # False executes the generated plan without an LLM check pass
GPT_MODEL = "davinci-003"  # one of "gpt-3.5", "gpt-3", "davinci-003"

# ---- Retry budget ----------------------------------------------------------
NUM_TRY = 20  # maximum number of attempts per task
NUM_TRY_RESET = 10  # after this many attempts, always restart the plan from step 1

# Tasks that use the "with ask" prompt variants (in-loop refinement via ask()).
tasks_with_ask = [
    "simple-algebra",
    "click-shades",
    "click-shape",
    "identify-shape",
    "count-shape",
    "click-checkboxes-soft",
    "search-engine",
    "tic-tac-toe",
    "guess-number",
    "terminal",
]

# Tasks for which environment feedback is available; these are the tasks run below.
tasks_with_feedback = [
    "search-engine",
    "tic-tac-toe",
    "email-inbox-forward-nl-turk",
    "terminal",
    "login-user-popup",
    "guess-number",
    "email-inbox-forward-nl",
    "email-inbox",
    "email-inbox-nl-turk",
]

# Output file names; they are prefixed with the run folder further down.
LOG_FILE = "interaction_log.txt"
PR_FILE = "Prompt_Response.txt"

### Functions

In [2]:
import argparse
import random

import numpy as np
import os
import openai
import computergym
import gym

from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.action_chains import ActionChains
from datetime import datetime
import logging

logging.basicConfig(level=logging.WARNING)

import urllib3
urllib3.disable_warnings()

from computergym.miniwob.miniwob_interface.action import (
    MiniWoBType,
    MiniWoBElementClickId,
    MiniWoBElementClickXpath,
    MiniWoBElementClickOption,
    MiniWoBMoveXpath,
)
import re
from selenium.webdriver.common.keys import Keys
import json
import time
import openai
from pathlib import Path

import io
import traceback

import sys
# Flush anything already buffered on stdout.
sys.stdout.flush()

# Configure the OpenAI client from the environment; this raises KeyError here,
# at setup time, if OPENAI_API_KEY is not set.
openai.api_key = os.environ["OPENAI_API_KEY"]

def save_to_file(directory, filename, content):
    """Append ``content`` followed by a newline to a log file.

    Args:
        directory: folder the log lives in; it is created if missing.
        filename: either a bare file name, which is placed inside
            ``directory``, or a path that already includes its folder (the
            callers in this notebook pass ``save_to_folder + name``), which is
            used unchanged.
        content: text to append.
    """
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Only join bare names: callers already pass prefixed paths, and joining
    # those again would nest the run folder inside itself.
    if os.path.dirname(filename):
        file_path = filename
    else:
        file_path = os.path.join(directory, filename)

    # Prompts and HTML observations can contain non-ASCII text, so do not rely
    # on the platform default encoding.
    with open(file_path, 'a', encoding='utf-8') as file:
        file.write(content + '\n')

def clear_file_content(filename):
    """Truncate ``filename`` to zero bytes, creating it if it does not exist."""
    # Opening in write mode is what empties the file; nothing is written.
    open(filename, 'w').close()

def dict_to_traj(d):
    """Render an interaction history as numbered action/observation pairs.

    Args:
        d: mapping with parallel lists under ``'actions'`` and
            ``'observations'`` (the shape ``Agent.interaction_history`` uses).

    Returns:
        For each action ``i``, an ``Act i: ...`` line and an ``Obs i: ...``
        line, each newline-terminated; ``''`` when there are no actions.
    """
    return ''.join(
        f'Act {step}: {act}\nObs {step}: {d["observations"][step]}\n'
        for step, act in enumerate(d['actions'])
    )

def extract_answers(text, markers):
    """Pull the text that follows each marker out of an LLM response.

    For every marker but the last, the answer is the text between it and the
    next marker (which must also be present). For the last marker it is
    everything up to the end of ``text``. Answers are whitespace-stripped, and
    ``'Not found'`` stands in for any marker whose pattern does not match.

    Args:
        text: response to search.
        markers: section labels in expected order, e.g.
            ``['[Decision]', '[Revised code]']``.

    Returns:
        One answer string per marker, in the same order.
    """
    successors = list(markers[1:]) + [None]
    found_answers = []
    for label, successor in zip(markers, successors):
        pattern = re.escape(label) + r"\s*"
        if successor is None:
            pattern += r"([\s\S]*)"
        else:
            # Lazy match so the answer stops at the first following marker.
            pattern += r"([\s\S]*?)\n*" + re.escape(successor)
        hit = re.search(pattern, text)
        found_answers.append(hit.group(1).strip() if hit else 'Not found')
    return found_answers

def eliminate_content_after(text, eliminate_from):
    """Drop the first line containing ``eliminate_from`` and everything after it.

    Returns ``text`` unchanged when no line contains the substring.
    """
    lines = text.split('\n')
    cut = next(
        (index for index, line in enumerate(lines) if eliminate_from in line),
        len(lines),
    )
    return '\n'.join(lines[:cut])

def get_error_step(error_message):
    """Return N from the first ``[Step N]`` tag in ``error_message``, else None."""
    found = re.search(r'\[Step (\d+)\]', error_message)
    return int(found.group(1)) if found else None
    
def get_first_digit(input_string):
    """Return the first single decimal digit in ``input_string`` as an int.

    Only one character is read, so ``'Step 12'`` yields ``1``. Returns None
    when the string contains no digit.
    """
    found = re.search(r'\d', input_string)
    return None if found is None else int(found.group(0))

# Timestamp that names this run's output folder. It gets its own name so the
# imported ``datetime`` class is not shadowed by a string for later cells.
timestamp = datetime.now().strftime("%m%d-%H%M%S")
save_to_folder = f'./r_{GPT_MODEL}_{timestamp}/'
# basename() makes this idempotent: re-running the cell replaces the folder
# prefix instead of stacking a second one onto it.
LOG_FILE = save_to_folder + os.path.basename(LOG_FILE)
PR_FILE = save_to_folder + os.path.basename(PR_FILE)

save_to_file(save_to_folder, PR_FILE, 'Start: ' + str(timestamp) + '\n')
save_to_file(save_to_folder, LOG_FILE, 'Start: ' + str(timestamp) + '\n')

In [3]:
# Long in-prompt example that ask() strips out when a request is rejected as
# invalid (the handler treats that as "prompt too long").
_ASK_EXAMPLE_CONTEXT = "# for example: You have a list of receptacles, and you want to sort them by the likelihood of a soapbar appearing in them. You can do this by asking the assistant:\nreceptacles = ['countertop 1', 'garbagecan 1', 'sinkbasin 2', 'sinkbasin 1', 'toilet 1', 'toiletpaperhanger 1', 'towelholder 1']\nanswer = ask(f'Sort the list of receptacles, starting from the one a soapbar is most likely to appear: {receptacles}. You should return a Python list.')\n# answer = ['sinkbasin 1', 'sinkbasin 2', 'countertop 1', 'towelholder 1', 'toiletpaperhanger 1', 'garbagecan 1', 'toilet 1']"

# Values of GPT_MODEL that ask() knows how to serve.
_SUPPORTED_GPT_MODELS = ('gpt-3.5', 'gpt-3', 'davinci-003')


def _request_completion(prompt):
    """Issue one API call for ``prompt`` and return the stripped reply text.

    ``'gpt-3.5'`` uses the chat endpoint with ``gpt-3.5-turbo``. ``'gpt-3'``
    and ``'davinci-003'`` use the completion endpoint with
    ``text-davinci-002`` / ``text-davinci-003`` respectively.
    """
    if GPT_MODEL == 'gpt-3.5':
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt.strip()}],
            temperature=0,
            max_tokens=1400,
            top_p=1,
            frequency_penalty=0.0,
            presence_penalty=0.0,
        )
        return response["choices"][0]["message"]['content'].strip()

    response = openai.Completion.create(
        model="text-davinci-002" if GPT_MODEL == 'gpt-3' else "text-davinci-003",
        prompt=prompt,
        temperature=0,
        max_tokens=1000,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
    )
    return response["choices"][0]["text"].strip()


def ask(prompt):
    """Send ``prompt`` to the model selected by ``GPT_MODEL`` and return its reply.

    Each successful exchange is appended to ``PR_FILE`` in ``save_to_folder``.

    Retry policy:
      * Rate-limit errors are retried indefinitely, pausing 3 seconds each time.
      * On an invalid-request error the long example block is removed from the
        prompt before retrying.
      * Invalid-request, API and unexpected errors share one failure counter.
        Once it exceeds three, a short sentinel string is returned instead of
        raising, so the experiment loop can continue.

    Raises:
        ValueError: if ``GPT_MODEL`` is not a supported value. A configuration
            error can never succeed on retry, so it is reported immediately.
    """
    if GPT_MODEL not in _SUPPORTED_GPT_MODELS:
        raise ValueError(f'Wrong GPT_MODEL: {GPT_MODEL!r}')

    cnt = 0
    while True:
        try:
            answer = _request_completion(prompt)
        except openai.error.RateLimitError as e:
            retry_after = 3
            print(f"Rate limit error: {e}. Retrying in {retry_after} seconds.")
            time.sleep(retry_after)
        except openai.error.InvalidRequestError as e:
            prompt = prompt.replace(_ASK_EXAMPLE_CONTEXT, '')
            print(f"Exceed max: {e}.")
            cnt += 1
            if cnt > 3:
                return 'Exceed max limit. Tried 3 times. Skip this one.'
        except openai.error.APIError as e:
            print(f"API error: {e}")
            cnt += 1
            if cnt > 3:
                return 'APIError. Tried 3 times. Skip this one.'
            # Short pause so a transient server error gets a chance to clear.
            time.sleep(1)
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            cnt += 1
            if cnt > 3:
                return 'Error. Tried 3 times. Skip this one.'
        else:
            # Logging happens outside the try so a logging failure is not
            # mistaken for an API failure that throws away a valid answer.
            save_to_file(save_to_folder,
                         PR_FILE,
                         'Prompt: \n' +
                         prompt +
                         '\nResponse: \n' +
                         answer + '\n' + '='*20 + '\n')
            return answer

In [4]:
def capture_output(func, agent, step=1):
    """Run ``func(agent, start_from=step)`` while capturing what it prints.

    Args:
        func: the generated ``solution`` function.
        agent: the agent object passed through to ``func``.
        step: value forwarded as ``start_from``.

    Returns:
        A tuple ``(error_string, checkpoint, combined_output)``:
          * ``error_string`` is the captured stderr, holding the traceback
            when ``func`` raised.
          * ``checkpoint`` is the local-variable mapping of the first frame
            below this one (``func``'s own frame) at the time of the
            exception. It is ``None`` when nothing was raised or no such frame
            exists.
          * ``combined_output`` is captured stdout followed by stderr.
        Both captured streams are also echoed to the real stdout.

    The original ``sys.stdout`` / ``sys.stderr`` are always restored, even if
    ``func`` raises something that is not an ``Exception`` (e.g.
    KeyboardInterrupt, which then propagates).
    """
    original_stdout, original_stderr = sys.stdout, sys.stderr
    temp_stdout, temp_stderr = io.StringIO(), io.StringIO()
    checkpoint = None

    sys.stdout, sys.stderr = temp_stdout, temp_stderr
    try:
        func(agent, start_from=step)
    except Exception as e:
        # Printed while stderr is still redirected, so the traceback becomes
        # part of the captured error string.
        traceback.print_exc()
        # tb_next is None when the call itself failed in this frame (e.g.
        # func is not callable); there are no func locals to keep then.
        func_tb = e.__traceback__.tb_next
        if func_tb is not None:
            checkpoint = func_tb.tb_frame.f_locals
    finally:
        sys.stdout, sys.stderr = original_stdout, original_stderr

    output_string = temp_stdout.getvalue()
    error_string = temp_stderr.getvalue()

    print(output_string)
    print(error_string)
    return error_string, checkpoint, output_string + error_string

def modify_header(checkpoint):
    """Build a ``solution`` header whose extra parameters default to saved values.

    Args:
        checkpoint: mapping of variable names to values (e.g. the locals
            captured by ``capture_output``), or a falsy value.

    Returns:
        ``'def solution(agent, start_from=1):'`` when ``checkpoint`` is empty.
        Otherwise ``'def solution(agent, start_from, name=value, ...):'`` with
        one keyword parameter per entry other than ``agent`` and
        ``start_from``. String values are rendered with ``repr`` so quotes,
        backslashes and newlines still form a valid Python literal. Other
        values are rendered with ``str``.
    """
    if not checkpoint:
        return 'def solution(agent, start_from=1):'

    skip_vars = ('agent', 'start_from')
    defaults = []
    for name, value in checkpoint.items():
        if name in skip_vars:
            continue
        literal = repr(value) if isinstance(value, str) else str(value)
        defaults.append(f', {name}={literal}')
    params = ''.join(defaults)
    return f'def solution(agent, start_from{params}):'

### Define the agent

In [5]:
# define the Agent

# adapted from rci
def get_html_state(env, states):
    """Return the HTML of the first observation in ``states``.

    Args:
        env: MiniWoB++ task name.
        states: observation sequence from the environment; only ``states[0]``
            is read.

    Returns:
        ``None`` when ``states[0]`` is falsy. Otherwise its ``html_body``, with
        ``html_extra`` appended for a handful of tasks (presumably those that
        render part of their UI outside the main body).
    """
    first_state = states[0]
    if not first_state:
        return None

    html_body = first_state.html_body
    needs_extra_html = env in (
        "click-dialog",
        "click-dialog-2",
        "use-autocomplete",
        "choose-date",
    )
    if needs_extra_html:
        html_body += first_state.html_extra
    return html_body

class Agent:
    """Wrapper around one MiniWoB++ gym environment used by generated solutions.

    Generated ``solution`` functions drive the browser through the public
    action methods (``type``, ``press_key``, ``click_option``, ``click_xpath``,
    ``click_id``, ``move_mouse_on``). Each action steps the environment,
    records the (action, observation) pair and returns the new HTML, or the
    string ``'DONE'`` once the episode ends or a non-zero reward arrives.

    Attributes:
        env_name: MiniWoB++ task name.
        env: environment created via ``gym.make("MiniWoBEnv-v0", ...)``.
        description: task utterance of the current episode.
        current_state: HTML of the latest observation (``None`` when the
            environment returned no state).
        interaction_history: ``{'actions': [...], 'observations': [...]}``.
        is_success: set to True when a positive reward is observed.
    """

    def __init__(self, env_name):
        self.env_name = env_name
        self.env = gym.make("MiniWoBEnv-v0", env_name=env_name, headless=headless)
        observation = self.env.reset(seeds=[random.random()], record_screenshots=False)
        self.description = observation[0].utterance
        self.current_state = get_html_state(env_name, observation)
        self.interaction_history = {'actions': [], 'observations': []}
        self.is_success = False

    def reset(self):
        """Start a new episode with a fresh random seed.

        Clears the history and records the initial state as its first entry.
        """
        observation = self.env.reset(seeds=[random.random()], record_screenshots=False)
        self.description = observation[0].utterance
        self.current_state = get_html_state(self.env_name, observation)
        self.interaction_history = {'actions': [], 'observations': []}
        self.is_success = False
        self.add_to_history('Initial state', str(self.current_state))

    def close(self):
        """Close the underlying environment."""
        self.env.close()

    def __del__(self):
        print(f"Current task {self.env_name} is done.")

    def interact_with_env(self, action):
        """Step the environment with ``action``.

        Returns:
            ``(html, reward, done)`` where ``reward`` is the first environment's
            reward and ``done`` is True only when every environment is done.
        """
        observation, reward, done, _ = self.env.step([action])
        html_obs, reward, done = get_html_state(self.env_name, observation), reward[0], all(done)
        return html_obs, reward, done

    def add_to_history(self, action, observation):
        """Append one (action, observation) pair to ``interaction_history``."""
        self.interaction_history['actions'].append(action)
        self.interaction_history['observations'].append(observation)

    def observation(self, action):
        """Perform ``action``, record it, and return the resulting HTML.

        Returns ``'DONE'`` when the episode is over or the reward is non-zero.
        ``is_success`` is set when that reward is positive.
        """
        # Brief pause between steps to give the page time to react.
        time.sleep(0.3)
        html_obs, reward, done = self.interact_with_env(action)
        self.current_state = html_obs
        self.add_to_history(action, html_obs)
        print(f'Act: {action}, reward: {reward}')
        if done or reward != 0:
            if reward > 0:
                self.is_success = True
            print('Done. Success:', reward)
            return 'DONE'
        return html_obs

    # Here are the admissible actions:
    def type(self, characters: str):
        """Type ``characters`` into the focused input; returns the HTML after the action."""
        action = MiniWoBType(characters)
        return self.observation(action)

    def press_key(self, key: str):
        """Press a single key; returns the HTML after the action.

        Args:
            key: one of ``enter``, ``space``, ``arrow_left``, ``arrow_right``,
                ``arrow_up``, ``arrow_down``, ``backspace``.

        Raises:
            ValueError: if ``key`` is not one of the names above.
        """
        if key == 'enter':
            miniwob_key = '\n'
        elif key == 'space':
            miniwob_key = ' '
        elif key == 'arrow_left':
            miniwob_key = Keys.LEFT
        elif key == 'arrow_right':
            miniwob_key = Keys.RIGHT
        elif key == 'arrow_up':
            miniwob_key = Keys.UP
        elif key == 'arrow_down':
            miniwob_key = Keys.DOWN
        elif key == 'backspace':
            miniwob_key = Keys.BACKSPACE
        else:
            # Without this branch miniwob_key would be unbound; a clear message
            # is also more useful feedback for the model that wrote the call.
            raise ValueError(
                f"Unsupported key {key!r}; expected one of: enter, space, "
                "arrow_left, arrow_right, arrow_up, arrow_down, backspace"
            )
        action = MiniWoBType(miniwob_key)
        return self.observation(action)

    def click_option(self, xpath: str):
        """Click the option element located by ``xpath``; returns the HTML after the action."""
        action = MiniWoBElementClickOption(xpath)
        return self.observation(action)

    def click_xpath(self, xpath: str):
        """Click the element located by ``xpath``; returns the HTML after the action."""
        action = MiniWoBElementClickXpath(xpath)
        return self.observation(action)

    def click_id(self, element_id: str):
        """Click the element with id ``element_id``; returns the HTML after the action."""
        action = MiniWoBElementClickId(element_id)
        return self.observation(action)

    def move_mouse_on(self, xpath: str):
        """Move the mouse onto the element located by ``xpath``; returns the HTML after the action."""
        action = MiniWoBMoveXpath(xpath)
        return self.observation(action)

### Closed-loop simulation

In [None]:
from prompt import get_solution_prompt_CL_with_ask, get_solution_prompt_CL_no_ask, feedback_fix_prompt_with_ask, feedback_fix_prompt_no_ask, code_check_prompt, example_list, get_start_from_prompt
import textwrap

# Header every generated program must start with; capture_output() calls the
# resulting function as solution(agent, start_from=step).
SOLUTION_HEADER = 'def solution(agent, start_from=1):'


def _indent_reply_body(body):
    """Indent a reply that holds only a function body so it fits under SOLUTION_HEADER.

    ask() strips each reply, so the first line has lost the indentation the
    model gave it while the later lines keep theirs. The first line is
    re-aligned with the shallowest later line. It is instead left at column 0
    when every later line is nested inside it (a first line ending in ':').
    The whole body is then re-indented by four spaces.
    """
    lines = body.split('\n')
    prefixes = [line[:len(line) - len(line.lstrip())] for line in lines[1:] if line.strip()]
    base = min(prefixes, key=len, default='')
    if prefixes and lines[0].rstrip().endswith(':') and len(base) == len(prefixes[0]):
        base = ''
    lines[0] = base + lines[0]
    return textwrap.indent(textwrap.dedent('\n'.join(lines)), '    ')


def _wrap_solution(response):
    """Return ``response`` as a complete ``solution`` definition.

    A reply that already starts with SOLUTION_HEADER is used verbatim.
    Otherwise it is treated as the function body and placed under the header.
    """
    if response.startswith(SOLUTION_HEADER):
        return response
    return SOLUTION_HEADER + '\n' + _indent_reply_body(response)


def _neutralize_solution(code):
    """Apply the rewrites generated code gets before it is exec'd.

    * ``agent.ask`` becomes ``ask``: ask() is a module-level function, not an
      Agent method.
    * Every ``return`` keyword is commented out. Word boundaries keep
      identifiers such as ``return_value`` intact.
    * Stray ``solution(agent)`` calls are removed, while any
      ``def solution(agent...`` line is kept.
    """
    code = code.replace('agent.ask', 'ask')
    code = re.sub(r'\breturn\b', '# return', code)
    return re.sub(r'(?<!def )solution\(agent\)', '', code)


results = {k: [] for k in tasks_with_feedback}
skills = {k: [] for k in tasks_with_feedback}  # not filled in this loop; only logged at the end

for idx, task in enumerate(tasks_with_feedback):
    env_name = task

    # The initial-plan prompt and the feedback-fix prompt come from the same
    # family: "with ask" for tasks in tasks_with_ask, "no ask" otherwise.
    if env_name in tasks_with_ask:
        get_solution_prompt = get_solution_prompt_CL_with_ask
        feedback_fix_prompt = feedback_fix_prompt_with_ask
    else:
        get_solution_prompt = get_solution_prompt_CL_no_ask
        feedback_fix_prompt = feedback_fix_prompt_no_ask

    print(f'Task: {env_name}')
    agent = Agent(env_name)
    for episode in range(num_episodes):
        agent.reset()
        terminal_output = ''
        print(f'Task: {agent.description}')
        initial_state = str(agent.current_state)

        example = example_list[env_name] if ALLOW_DEMO else ''

        prompt = get_solution_prompt\
                .replace('<initial_state>', initial_state)\
                .replace('<task>', agent.description)\
                .replace('<example>', example)
        response = ask(prompt)

        # Turn the reply into a complete solution() definition and cut off any
        # text the model appended after the function.
        solution_func = _wrap_solution(response)
        solution_func = eliminate_content_after(solution_func, '# Now complete the function')
        solution_func = eliminate_content_after(solution_func, 'solution(agent)')

        if CODE_CHECK:
            prompt = code_check_prompt\
                    .replace('<solution_func>', solution_func)
            response = ask(prompt)
            answers = extract_answers(response, ['[Decision]', '[Revised code]'])

            if 'def solution' in answers[1]:
                solution_func = answers[1].strip('```')\
                    .replace('[Revised code]', '')\
                    .replace('without any other words.', '')\
                    .replace('Revised code:', '').strip()\
                    .replace('```', '')

        solution_func = _neutralize_solution(solution_func)
        start_num = 1
        for num_try in range(NUM_TRY):
            # For the first NUM_TRY_RESET attempts resume from the step the
            # model suggested; after that always restart from step 1.
            if num_try < NUM_TRY_RESET:
                step = start_num if start_num else 1
            else:
                step = 1
            print('start_from_step:', step)
            # execute the solution function
            def_error = False
            try:
                print(solution_func)
                # SECURITY NOTE: this runs model-generated code with full access
                # to this module's globals (it defines `solution` there).
                exec(solution_func)
            except Exception as e:
                error_msg = str(e)
                error_string = str(e)
                checkpoint = None
                def_error = True

            if not def_error:
                error_string, checkpoint, output_string = capture_output(solution, agent, step)
                terminal_output += output_string
                if error_string:
                    # Drop the leading traceback lines (presumably the header and
                    # capture_output's own frame) before feeding it back.
                    error_msg = '\n'.join(error_string.split('\n')[4:])
                else:
                    error_msg = 'You executed the solution function successfully but the task is not completed. Please check your solution function.'

            start_num = None

            if agent.is_success or 'DONE' in terminal_output or 'Done. Success:' in terminal_output:
                break

            prev_solution_func = solution_func

            # Presumably the environment returned no state, which a code fix
            # cannot repair.
            if 'NoneType' in error_string:
                break

            # refine based on environment feedback
            prompt = feedback_fix_prompt\
                    .replace('<solution_func>', prev_solution_func)\
                    .replace('<task>', agent.description)\
                    .replace('<feedback>', error_msg)
            response = ask(prompt)
            solution_func = _neutralize_solution(_wrap_solution(response))

            # Ask the model which step the revised solution can resume from.
            prompt = get_start_from_prompt\
                    .replace('<previous_solution>', prev_solution_func)\
                    .replace('<revised_solution>', solution_func)
            response = ask(prompt)
            start_num = get_first_digit(response)
            # temporarily not needed
            # solution_func = solution_func.replace('def solution(agent, start_from=1):', modify_header(checkpoint))

        results[task].append(agent.is_success)

        save_to_file(save_to_folder, LOG_FILE,
                    f'Task {episode+1}: {agent.env_name}\n' + \
                    agent.description + '\n' + \
                    dict_to_traj(agent.interaction_history) + '\n' + \
                    f'Success: {agent.is_success}\n')
        save_to_file(save_to_folder, LOG_FILE, f'results {results}')
        save_to_file(save_to_folder, LOG_FILE, f'------------\n')

        print(f'Task: {agent.env_name}, Episode {episode+1}, Success: {agent.is_success}')
        print('-'*20)

    agent.close()
    del agent
    print(f'Task {idx}: {env_name}, Success rate: {np.mean(results[env_name])}')
    print('='*20)
save_to_file(save_to_folder, LOG_FILE, f'Final results {results}')
save_to_file(save_to_folder, LOG_FILE, f'Final skills {skills}')
print(f'Final results {results}')