# LLM Agent Environment

Components:
- Agent:
    - Action
        - Inference
        - Response parser
    - Observations
    - Training
- Environment:
    - State
    - Rewards
    - State update
    

In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import re


# Prompt shown to the model on every step. It explains the game, lists the
# four action tokens, shows one worked example, and ends with the live state.
# The placeholders {cash}, {investment}, {rewards} and {previous_actions} are
# filled with str.format by Agent._generate_prompt; the text contains no other
# braces, so no escaping is needed.
# NOTE(review): the prompt says an invalid action loses "reward points", while
# Environment.step deducts cash for "<|invalid|>" — confirm which is intended.
PROMPT_TEMPLATE = """You are an agent in a resource optimization game. Your goal is to maximize the total rewards generated over your lifetime.

At each time step, you will be provided with a state description and a set of actions to choose from.

The state description consists of the following information:
- cash: the number of tokens you currently have
- investment: the number of tokens you have invested
- rewards: the total rewards you have accumulated
- previous actions: the actions you have taken so far

You will be asked to select an action to take at each time step. The action space consists of the following options:
- <|work|>: Work to generate new tokens
- <|invest|>: Invest tokens to generate tokens passively
- <|collect|>: Collect tokens by cashing out investment
- <|spend|>: Spend token to generate rewards

If you do not select a valid action, your action will be recorded as <|invalid|> and you will lose reward points.

Beware that the game may end at any random moment and your final score is based on rewards generated by spending tokens.


Below is an example game:

State:
cash=10
investment=40
rewards=20
previous actions=<|work|>, <|invest|>, <|collect|>, <|spend|>, <|work|>, <|invest|>, <|work|>

Select your action: <|work|>


Here's the state of the game you are currently playing:

State:
cash={cash}
investment={investment}
rewards={rewards}
previous actions={previous_actions}

Select your action: """


class Agent:
    """LLM-backed player for the resource game.

    Turns an observation dict into a prompt, samples a completion from a
    text-generation pipeline, and extracts one action token from the reply.
    """

    def __init__(self):
        # Loading the pipeline happens once, at construction time.
        self._model = self._load_model()
        self._observations = None
        self.reset()

    def reset(self):
        """Replace the stored observation list with a fresh empty one."""
        self._observations = []

    def _load_model(self):
        """Build and return the text-generation pipeline used for inference."""
        checkpoint = "microsoft/Phi-3-mini-4k-instruct"
        # Seed torch's RNG before the model is created.
        torch.random.manual_seed(0)

        language_model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            device_map="cuda",
            torch_dtype="auto",
            trust_remote_code=True,
        )
        tok = AutoTokenizer.from_pretrained(checkpoint)

        return pipeline("text-generation", model=language_model, tokenizer=tok)

    def _generate_prompt(self, observation):
        """Fill PROMPT_TEMPLATE with the values from ``observation``.

        ``observation`` must provide the keys ``cash``, ``investment``,
        ``rewards`` and ``previous_actions`` (an iterable of action strings).
        """
        scalar_fields = {
            key: observation[key] for key in ("cash", "investment", "rewards")
        }
        history = ", ".join(observation["previous_actions"])
        return PROMPT_TEMPLATE.format(previous_actions=history, **scalar_fields)

    def _inference(self, prompt):
        """Send ``prompt`` as a single system message and return the generated text."""
        conversation = [{"role": "system", "content": prompt}]
        outputs = self._model(
            conversation,
            max_new_tokens=50,
            return_full_text=False,
            do_sample=True,
            temperature=1.0,
        )
        return outputs[0]["generated_text"]

    def _parse_response(self, response):
        """Return the first action token found in ``response``.

        Falls back to ``"<|invalid|>"`` when none of the four known action
        tokens appears in the text.
        """
        found = re.search(r"<\|(work|invest|collect|spend)\|>", response)
        if found is None:
            return "<|invalid|>"
        return found.group()

    def act(self, observation):
        """Choose an action for ``observation`` and return its token string.

        The full prompt plus the raw model reply is printed as a side effect.
        """
        prompt = self._generate_prompt(observation)
        response = self._inference(prompt)
        print(prompt + response)
        return self._parse_response(response)

class Environment:
    def __init__(self):
        self.agent = Agent()
        self._state = None
        self._termination = None
        self.reset()

    def reset(self):
        if self.agent:
            self.agent.reset()
        
        self._state = {
            "cash": 0,
            "investment": 0,
            "rewards": 0,
            "previous_actions": [],
        }
        self._termination = False

    def step(self, action):
        match action:
            case "<|work|>":
                self._state["cash"] += 10
            case "<|invest|>":
                self._state["investment"] += self._state["cash"]
                self._state["cash"] = 0
            case "<|collect|>":
                self._state["cash"] += self._state["investment"]
                self._state["investment"] = 0
            case "<|spend|>":
                self._state["rewards"] += self._state["cash"]
                self._state["cash"] = 0
            case "<|invalid|>":
                if self._state["cash"] >= 10:
                    self._state["cash"] -= 10
        self._state["previous_actions"].append(action)
        self._state["investment"] *= 2
                
        if len(self._state["previous_actions"]) >= 50:
            self._termination = True

        return self._state, self._termination

    def last(self):
        return self._state, self._termination

In [None]:
# Play one full episode: ask the agent for an action on each observation and
# feed it to the environment until the environment reports termination.
env = Environment()
agent = env.agent

observation, termination = env.last()
while not termination:
    action = agent.act(observation)
    env.step(action)
    observation, termination = env.last()

# reset() creates a new state dict, so `observation` still refers to the
# finished episode's final state when it is printed.
env.reset()
print(f"Game ended with the stats: {observation}")