In [None]:
import os


In [None]:
# Root folder that holds every trained agent checkpoint used in this notebook.
agent_dir = "/Users/lorecampa/Desktop/Projects/ICAIF24-challenge/agents"

# Checkpoints produced by the first PPO trial.
ppo_agent_dir = f'{agent_dir}/ppo/first_trial'
# os.listdir returns entries in arbitrary order and also reports hidden files
# (e.g. `.DS_Store` on macOS). Sort for a reproducible agent order and skip
# dotfiles so that only real checkpoints get registered as PPO agents.
agent_paths = [
    os.path.join(ppo_agent_dir, agent)
    for agent in sorted(os.listdir(ppo_agent_dir))
    if not agent.startswith('.')
]
assert all(os.path.exists(agent_path) for agent_path in agent_paths), "Some agent paths do not exist"

# One specification dict per checkpoint: the agent type plus the file to load.
AGENTS = [{"type": "ppo", "file": agent_path} for agent_path in agent_paths]


In [None]:
import glob


# Locate the pickled FQI policy; glob yields an empty list when nothing matches.
fqi_policy_pattern = f'{agent_dir}/fqi/trial_2_window_stap_gap_2/Policy_iter3.pkl'
policy_list = glob.glob(fqi_policy_pattern)
if not policy_list:
    # Fail with the offending path instead of an opaque IndexError on [0].
    raise FileNotFoundError(f'No FQI policy found at {fqi_policy_pattern}')
fqi_policy = policy_list[0]

# Register the policy only once so that re-running this cell does not keep
# appending duplicate FQI entries to AGENTS.
fqi_agent_spec = {"type": "fqi", "file": fqi_policy}
if fqi_agent_spec not in AGENTS:
    AGENTS.append(fqi_agent_spec)
AGENTS

In [None]:
from agent.factory import AgentsFactory


# First pass: turn every specification in AGENTS into an agent object.
agents = list(map(AgentsFactory.load_agent, AGENTS))
# Second pass: ask each instantiated agent to run its own `load()` step.
for agent in agents:
    agent.load()
agents


In [1]:
from agent.base import AgentBase
from erl_config import build_env
import torch as th
from trade_simulator import EvalTradeSimulator
import numpy as np


def evaluate_agent(agent: AgentBase, eval_env, eval_sequential: bool = False, verbose: int = 0):
    """Roll out ``agent`` in ``eval_env`` once and summarise the collected rewards.

    The environment is reset with its own ``seed`` and then stepped for at most
    ``eval_env.max_step`` steps. The rollout stops early as soon as any
    simulation reports termination or the step is flagged as truncated.

    Args:
        agent: Object exposing ``action(state)``. The returned action may be a
            NumPy array or a tensor; it is converted to a CPU tensor before
            being passed to the environment.
        eval_env: Environment exposing ``seed``, ``num_sims``, ``max_step``,
            ``reset(seed=..., eval_sequential=...)`` and ``step(action=...)``.
        eval_sequential: Forwarded unchanged to ``eval_env.reset``.
        verbose: When truthy, print the three summary statistics.

    Returns:
        A tuple ``(mean_total_reward, std_simulations, mean_std_steps)`` where

        * ``mean_total_reward`` is the cumulative reward averaged over the
          simulations,
        * ``std_simulations`` is the standard deviation of the cumulative
          reward across simulations (``0.0`` with a single simulation),
        * ``mean_std_steps`` is the per-simulation standard deviation of the
          step rewards, averaged over the simulations (``0.0`` when fewer than
          two steps were collected, since the deviation is undefined there).
    """
    device = th.device("cpu")

    seed = eval_env.seed
    num_eval_sims = eval_env.num_sims

    state, _ = eval_env.reset(seed=seed, eval_sequential=eval_sequential)

    total_reward = th.zeros(num_eval_sims, dtype=th.float32, device=device)
    # Per-step rewards are gathered in a list and concatenated once after the
    # loop; growing a tensor with th.cat on every step recopies the whole
    # history each time.
    step_rewards = []

    for _ in range(eval_env.max_step):
        # as_tensor accepts both NumPy arrays and tensors, so agents are free
        # to return either type.
        action = th.as_tensor(agent.action(state)).to(device)
        state, reward, terminated, truncated, _ = eval_env.step(action=action)

        # Store a copy: the environment may reuse the reward buffer in place.
        # assumes reward has shape (num_sims,) - TODO confirm against the env
        step_rewards.append(reward.unsqueeze(0).clone())
        total_reward += reward

        if terminated.any() or truncated:
            break

    # The leading empty float32 tensor keeps the result well defined when no
    # step was taken and yields the same dtype promotion as step-wise th.cat.
    rewards = th.cat(
        [th.empty((0, num_eval_sims), dtype=th.float32, device=device), *step_rewards],
        dim=0,
    )

    mean_total_reward = total_reward.mean().item()
    std_simulations = total_reward.std().item() if num_eval_sims > 1 else 0.
    # The sample standard deviation needs at least two steps; otherwise it is
    # NaN, so fall back to 0 exactly like the single-simulation case above.
    mean_std_steps = rewards.std(dim=0).mean().item() if rewards.shape[0] > 1 else 0.

    if verbose:
        print(f'Sims mean: {mean_total_reward} Sims std: {std_simulations}, Mean std steps: {mean_std_steps}')

    return mean_total_reward, std_simulations, mean_std_steps


In [8]:
from agent.base import AgentBase
from agent.base import AgentBase
from erl_config import build_env
import torch as th
from trade_simulator import EvalTradeSimulator
import numpy as np
import os
from agent.factory import AgentsFactory


def model_selection(agent_path: str, num_sims: int, args, eval_sequential: bool = False,
                    windows=range(1, 8), day_offset: int = 7):
    """Evaluate every saved agent on its validation window and pick the best one.

    Files in ``agent_path`` are considered agents when the part of their name
    before the first ``_`` is ``ppo``, ``fqi`` or ``dqn``; that prefix is also
    used as the agent type. For each window ``w`` the agents whose file name
    contains ``_w{w-1}.`` are evaluated on an environment whose ``days``
    setting is ``[w + day_offset, w + day_offset]``.

    Args:
        agent_path: Directory containing the saved agent files.
        num_sims: Number of simulations, written to the ``num_sims`` env setting.
        args: Base environment settings; copied, never modified in place.
        eval_sequential: Stored in the env settings and forwarded to
            ``evaluate_agent``.
        windows: Window indices to evaluate. Defaults to ``1..7``.
        day_offset: Offset added to the window index to obtain the evaluation
            day. Defaults to ``7``.

    Returns:
        Dict keyed by window index. Each value holds the parallel lists
        ``agents``, ``mean_total_rewards`` and ``std_simulations`` and, when at
        least one agent was evaluated, ``best_agent`` and
        ``best_mean_total_reward`` (highest mean total reward).
    """
    eval_env_args = args.copy()
    eval_env_args["num_envs"] = 1
    eval_env_args["num_sims"] = num_sims
    eval_env_args["eval_sequential"] = eval_sequential
    eval_env_args["env_class"] = EvalTradeSimulator

    # os.listdir gives no ordering guarantee; sorting keeps the printed list,
    # the evaluation order and argmax tie-breaking reproducible across runs.
    agent_file_names = sorted(
        name for name in os.listdir(agent_path)
        if name.split('_')[0] in ('ppo', 'fqi', 'dqn')
    )

    print(f'All found agents: {agent_file_names}')
    results = {}
    for w in windows:
        # Agents are stored with a zero-based window suffix, hence `w - 1`.
        curr_agents = [a for a in agent_file_names if f'_w{w-1}.' in a]

        results[w] = {
            "agents": [],
            "mean_total_rewards": [],
            "std_simulations": []
        }
        if not curr_agents:
            # Nothing to evaluate: skip building an environment for this window.
            continue

        curr_eval_env_args = eval_env_args.copy()
        curr_eval_env_args["days"] = [w + day_offset, w + day_offset]
        eval_env = build_env(curr_eval_env_args["env_class"], curr_eval_env_args, gpu_id=-1)

        for agent_file in curr_agents:
            agent_type = agent_file.split('_')[0]
            # NOTE(review): unlike the earlier loading cell, `agent.load()` is not
            # called here - confirm that `load_agent` alone yields a ready agent.
            agent = AgentsFactory.load_agent({"type": agent_type, "file": os.path.join(agent_path, agent_file)})
            print(agent_file, w)
            mean_total_reward, std_simulations, _ = evaluate_agent(agent, eval_env, eval_sequential, verbose=1)
            results[w]["agents"].append(agent_file)
            results[w]["mean_total_rewards"].append(mean_total_reward)
            results[w]["std_simulations"].append(std_simulations)

        best_idx = int(np.argmax(results[w]["mean_total_rewards"]))
        results[w]["best_agent"] = results[w]["agents"][best_idx]
        results[w]["best_mean_total_reward"] = results[w]["mean_total_rewards"][best_idx]

    return results


# Baseline settings for the evaluation environment. model_selection works on a
# copy of this dict and overrides some entries itself (e.g. "num_sims").
eval_env_args = dict(
    env_name="TradeSimulator-v0",
    num_envs=1,
    num_sims=50,
    state_dim=10,
    action_dim=3,
    if_discrete=True,
    max_position=1,
    slippage=7e-7,
    step_gap=2,
    eval_sequential=False,
    eval=True,
    env_class=EvalTradeSimulator,
    max_step=480,
)
# Evaluate every saved agent with 5 simulations per window and show the outcome.
results = model_selection(
    agent_path="/Users/lorecampa/Desktop/Projects/ICAIF24-challenge/agents_new",
    num_sims=5,
    args=eval_env_args,
)
results

All found agents: ['ppo_w4.zip', 'fqi_w2.pkl', 'dqn_w6.zip']
fqi_w2.pkl 3
Sims mean: 639.0226440429688 Sims std: 210.58241271972656, Mean std steps: 6.3891777992248535
ppo_w4.zip 5
Sims mean: 388.995849609375 Sims std: 121.4652328491211, Mean std steps: 4.251712643405026
dqn_w6.zip 7
Sims mean: 1322.0654296875 Sims std: 478.4322204589844, Mean std steps: 10.19191994090365


{1: {'agents': [], 'mean_total_rewards': [], 'std_simulations': []},
 2: {'agents': [], 'mean_total_rewards': [], 'std_simulations': []},
 3: {'agents': ['fqi_w2.pkl'],
  'mean_total_rewards': [639.0226440429688],
  'std_simulations': [210.58241271972656],
  'best_agent': 'fqi_w2.pkl',
  'best_mean_total_reward': 639.0226440429688},
 4: {'agents': [], 'mean_total_rewards': [], 'std_simulations': []},
 5: {'agents': ['ppo_w4.zip'],
  'mean_total_rewards': [388.995849609375],
  'std_simulations': [121.4652328491211],
  'best_agent': 'ppo_w4.zip',
  'best_mean_total_reward': 388.995849609375},
 6: {'agents': [], 'mean_total_rewards': [], 'std_simulations': []},
 7: {'agents': ['dqn_w6.zip'],
  'mean_total_rewards': [1322.0654296875],
  'std_simulations': [478.4322204589844],
  'best_agent': 'dqn_w6.zip',
  'best_mean_total_reward': 1322.0654296875}}