In [7]:
import torch
import numpy as np
import onnxruntime as ort
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from utils.fetch_data_with_indicators import fetch_data_with_indicators
from utils.fetch_data_with_indicators import Api
import os
from scipy.signal import find_peaks
from meta.peaks_env import CryptoTradingEnv
from wrappers.peaks_wrapper import RecurrentModelWrapper

# Define number of features (used in observation space).
# NOTE(review): not referenced anywhere in this cell — presumably consumed by
# another cell or kept for reference; confirm before removing.
num_features = 5

# Load test data and prepare environment
test_data = fetch_data_with_indicators(Api.YAHOO, 'BTC-USD', '2024-05-01', '2024-10-01', '1d', ['RSI', 'EMA_50'])
test_data = test_data.copy()
# Percentage change of 'Close' relative to the previous row; the first row has
# no predecessor and becomes NaN, which the dropna() below removes (along with
# any other row containing a NaN).
test_data['Pct Change'] = test_data['Close'].pct_change() * 100
test_data.dropna(inplace=True)

# Positional indices of local maxima in the close series.
# `properties` is unused in this cell but kept in case later cells inspect it.
peaks, properties = find_peaks(test_data['Close'], height=100, prominence=5, distance=40)

# Binary marker column: 1 on rows detected as a peak, 0 everywhere else.
test_data['Peak'] = 0
test_data.loc[test_data.index[peaks], 'Peak'] = 1

# Single-environment vectorized wrapper around the trading env.
eval_env = DummyVecEnv([lambda: CryptoTradingEnv(test_data)])
# NOTE(review): this VecNormalize is created fresh, so its running statistics
# start empty and (with the default training=True) keep updating while the
# episode is evaluated. If the model was trained with saved normalization
# statistics, they should probably be loaded here instead — confirm.
eval_env = VecNormalize(eval_env, norm_obs=True, norm_reward=False)

model_name = "PPO_PARTIAL_LSTM_64_1_peaks_with_fee"
# Build the ONNX path once so the existence check, export and session load
# can never drift apart.
onnx_path = f"onnx_exports/{model_name}.onnx"

# Load and prepare models
pytorch_model = PPO.load(f'./models/{model_name}', device='cpu')
wrapped_model = RecurrentModelWrapper(pytorch_model.policy)
wrapped_model.eval()

# Initial observation: used as the example input for ONNX tracing and as the
# first observation fed to the comparison loop.
obs = eval_env.reset()
prices = obs['prices']
portfolio = obs['portfolio']

# Create ONNX model if it doesn't exist
if not os.path.exists(onnx_path):
    # The export does not create missing parent directories; make sure the
    # target folder exists so a first run on a clean checkout does not fail.
    os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
    torch.onnx.export(
        wrapped_model,
        (
            torch.tensor(prices, dtype=torch.float32),
            torch.tensor(portfolio, dtype=torch.float32)
        ),
        onnx_path,
        export_params=True,
        opset_version=12,
        do_constant_folding=True,
        input_names=['prices', 'portfolio'],
        output_names=['action_logits', 'value'],
        # Axes that may vary at inference time are marked dynamic.
        dynamic_axes={
            'prices': {0: 'batch_size', 1: 'sequence'},
            'portfolio': {0: 'batch_size'},
            'action_logits': {0: 'batch_size'},
            'value': {0: 'batch_size'}
        }
    )
    print("Model exported to ONNX format")

# Load ONNX model
session = ort.InferenceSession(onnx_path)

# Run comparison: drive one episode with the PyTorch policy and, at every
# step, feed the same observation to the ONNX session to check that both
# backends choose the same action.
done = False
state = None
total_reward = 0
step = 0
mismatches = 0  # number of steps where the two backends disagreed

try:
    while not done:
        # PyTorch prediction
        action, state = pytorch_model.predict(obs, state=state, deterministic=True)

        # ONNX prediction (inputs are named as in the export's input_names)
        ort_inputs = {
            'prices': obs['prices'].astype(np.float32),
            'portfolio': obs['portfolio'].astype(np.float32)
        }

        ort_outputs = session.run(None, ort_inputs)
        onnx_action_logits, onnx_value = ort_outputs
        # Greedy action from the logits, matching deterministic=True above.
        onnx_action = np.argmax(onnx_action_logits, axis=1)

        # Compare predictions
        print(f"\nStep {step}:")
        print("PyTorch Action:", action)
        print("ONNX Action:", onnx_action)
        print("ONNX Action Logits:", onnx_action_logits)
        print("ONNX Value:", onnx_value)

        # Flag divergence explicitly instead of relying on eyeballing the log.
        if not np.array_equal(np.ravel(action), np.ravel(onnx_action)):
            mismatches += 1
            print("WARNING: PyTorch and ONNX actions differ")

        # Step environment with PyTorch action
        obs, reward, done, info = eval_env.step(action)
        total_reward += reward
        step += 1

        # Optional: render environment
        eval_env.render()

        if done:
            print(f"\nEpisode finished after {step} steps")
            print(f"Total reward: {total_reward}")
            print(f"Action mismatches: {mismatches}/{step}")
finally:
    # Release the environment even if inference or stepping raised.
    eval_env.close()

[*********************100%***********************]  1 of 1 completed




Model exported to ONNX format

Step 0:
PyTorch Action: [1]
ONNX Action: [1]
ONNX Action Logits: [[-0.00158001  0.0738294  -0.06446071]]
ONNX Value: [[0.59566706]]
Step: 11, Net Worth: 999.0, Balance: 0.0, Crypto Held: 0.015938532348002507, Last Reward: 5.57757421875, Last Action: 1, Avg Buy Price: 0, Avg Sell Price: 0

Step 1:
PyTorch Action: [2]
ONNX Action: [2]
ONNX Action Logits: [[-0.07684935 -0.00355085  0.07505915]]
ONNX Value: [[2.057636]]
Step: 12, Net Worth: 1001.7683238371936, Balance: 0.0, Crypto Held: 0.015938532348002507, Last Reward: 0.1, Last Action: 2, Avg Buy Price: 0, Avg Sell Price: 0

Step 2:
PyTorch Action: [2]
ONNX Action: [2]
ONNX Action Logits: [[-0.0659728   0.00671548  0.05684638]]
ONNX Value: [[1.5277923]]
Step: 13, Net Worth: 988.6514720538155, Balance: 0.0, Crypto Held: 0.015938532348002507, Last Reward: 0.1, Last Action: 2, Avg Buy Price: 0, Avg Sell Price: 0

Step 3:
PyTorch Action: [1]
ONNX Action: [1]
ONNX Action Logits: [[-0.15077268  0.07686314  0.060