# Trading Bot Advanced

Pip installation

In [1]:
!pip install shimmy
!pip install stable_baselines3
!pip install gym
!pip install torch

Collecting shimmy
  Downloading Shimmy-2.0.0-py3-none-any.whl.metadata (3.5 kB)
Collecting gymnasium>=1.0.0a1 (from shimmy)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from gymnasium>=1.0.0a1->shimmy)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading Shimmy-2.0.0-py3-none-any.whl (30 kB)
Downloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Installing collected packages: farama-notifications, gymnasium, shimmy
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0 shimmy-2.0.0
Collecting stable_baselines3
  Downloading stable_baselines3-2.3.2-py3-none-any.whl.metadata (5.1 kB)
Collecting gymnasium<0.30,>=0.28.1 (from stable_baselines3)
  Downloading gymnasium-0.29.1-py3-none-any.whl.meta

Librairies

# Part 1: Data

Librairies.

In [2]:
import os
from PIL import Image
import numpy as np

Mount GDrive.

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Import images and prices.

In [4]:
def load_images_and_prices(folder):
    images = {}
    prices = []
    # ---> for loop to extract files and prices
    for filename in os.listdir(folder):
        if filename.endswith(".png"):
            parts = filename.split('.')
            price_str = parts[1]  # assuming the price is in the third part
            price = float(price_str)
            prices.append(price)
            img_path = os.path.join(folder, filename)
            img = Image.open(img_path).convert('L')
            img = img.resize((128, 128))  # resize.
            img_array = np.array(img, dtype=np.float32) / 255.0
            images[price] = img_array
    return images, prices

In [5]:
folder = '/content/drive/My Drive/chart_pictures'
images, prices = load_images_and_prices(folder)

Test if imported.

In [6]:
print(len(prices))

860


# Part 2: TradingEnv.

Librairies.

In [7]:
import gym
from gym import spaces
import torch as th
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

  from jax import xla_computation as _xla_computation


In [11]:
total_balance = []

class TradingEnv(gym.Env):
    def __init__(self, prices, images):
        super(TradingEnv, self).__init__()
        self.prices = prices
        self.images = images
        self.current_step = 0
        self.balance = 1000.0
        self.equity = 0.0
        self.stock_price = None
        self.total_steps = len(prices)
        self.equity_history = []

        # define action space (0: Enter Long, 1: Close Long, 2: Pass)
        self.action_space = spaces.Discrete(3)

        # observation space: [current_price, current_step, chart_image]
        self.observation_space = spaces.Dict({
            'tabular': spaces.Box(low=np.array([0, 0]), high=np.array([np.inf, self.total_steps]), dtype=np.float32),
            'image': spaces.Box(low=0, high=1, shape=(128, 128), dtype=np.float32)
        })

    def reset(self):
        print(self.balance)
        total_balance.append(self.balance)
        self.current_step = 0
        self.balance = 1000.0
        self.equity = 0.0
        self.stock_price = None
        self.equity_history = []
        return self._get_observation()

    def _get_observation(self):
        if self.current_step >= self.total_steps:
            self.current_step = self.total_steps - 1
        current_price = self.prices[self.current_step]
        image = self.images[current_price]
        tabular_data = np.array([current_price, self.current_step], dtype=np.float32)
        return {'tabular': tabular_data, 'image': image}

    def step(self, action):
        current_price = self.prices[self.current_step]
        reward = 0

        if action == 0:  # enter long
            if self.stock_price is None:
                self.stock_price = current_price

        elif action == 1:  # enter short
            if self.stock_price is not None:
                profit = current_price - self.stock_price
                self.balance += profit
                self.stock_price = None
                self.equity += profit

        elif action == 2:  # pass
            pass

        # save equity for each step
        self.equity_history.append(self.equity)

        # proceed to next step
        self.current_step += 1
        done = self.current_step >= self.total_steps

        if done:
            reward = self.balance

        return self._get_observation(), reward, done, {}

    def render(self, mode='human'):
            if mode == 'human':
                print(f'Step: {self.current_step}, Price: {self.prices[self.current_step]}, Equity: {self.balance}')
            elif mode == 'rgb_array':
                # implement rendering to return an RGB array if needed
                pass
            else:
                super().render(mode=mode)  # fallback to default Gym render

In [9]:
class CustomCombinedExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Dict):
        super(CustomCombinedExtractor, self).__init__(observation_space, features_dim=1)

        # extractors for the image data
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        # calculate the output size of the CNN
        n_flatten = self.cnn(th.zeros(1, 1, 128, 128)).shape[1]

        # extractors for the tabular data
        self.tabular = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()
        )

        # output layer to combine the extracted features
        self.linear = nn.Sequential(
            nn.Linear(n_flatten + 64, 256),
            nn.ReLU()
        )

        self._features_dim = 256

    def forward(self, observations):
        image = observations['image'].unsqueeze(1)
        tabular = observations['tabular']

        # extract features
        image_features = self.cnn(image)
        tabular_features = self.tabular(tabular)

        # concatenate and process the combined features
        combined_features = th.cat((image_features, tabular_features), dim=1)
        return self.linear(combined_features)

# Part 3: Training the agent.

In [12]:
# create the environment
env = TradingEnv(prices, images)

# define the policy network architecture
policy_kwargs = dict(
    features_extractor_class=CustomCombinedExtractor,
    features_extractor_kwargs=dict(),
)

# instantiate the agent
model = PPO("MultiInputPolicy", env, policy_kwargs=policy_kwargs, verbose=1)

# train the agent
model.learn(total_timesteps=1000)

# save the model
model.save("ppo_trading_model_combined")

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




1000.0
-1059.0
1524.0
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 860      |
|    ep_rew_mean     | 232      |
| time/              |          |
|    fps             | 95       |
|    iterations      | 1        |
|    time_elapsed    | 21       |
|    total_timesteps | 2048     |
---------------------------------


# Part 4: Testing the agent.

In [13]:
# load the trained model
model = PPO.load("ppo_trading_model_combined")

# initialize environment
env = TradingEnv(prices, images)  # Ensure prices and images are defined

# reset environment
obs = env.reset()
done = False

# evaluation loop
while not done:
    try:
        action, _states = model.predict(obs, deterministic=False)  # use deterministic=True for evaluation
        obs, reward, done, info = env.step(action)

        # debugging output
        print(f"Step: {env.current_step}, Action: {action}, Reward: {reward}, Done: {done}")
        env.render()  # assuming render method prints the current state or logs it
    except Exception as e:
        print(f"Error during evaluation step: {e}")
        break  # exit the loop if an error occurs during evaluation

# after evaluation, print equity history if available
if hasattr(env, 'equity_history'):
    print(env.equity_history)


1000.0
Step: 1, Action: 1, Reward: 0, Done: False
Step: 1, Price: 107042.0, Equity: 1000.0
Step: 2, Action: 0, Reward: 0, Done: False
Step: 2, Price: 107022.0, Equity: 1000.0
Step: 3, Action: 1, Reward: 0, Done: False
Step: 3, Price: 107038.0, Equity: 980.0
Step: 4, Action: 1, Reward: 0, Done: False
Step: 4, Price: 107175.0, Equity: 980.0
Step: 5, Action: 0, Reward: 0, Done: False
Step: 5, Price: 107154.0, Equity: 980.0
Step: 6, Action: 0, Reward: 0, Done: False
Step: 6, Price: 107298.0, Equity: 980.0
Step: 7, Action: 1, Reward: 0, Done: False
Step: 7, Price: 107302.0, Equity: 1103.0
Step: 8, Action: 2, Reward: 0, Done: False
Step: 8, Price: 107219.0, Equity: 1103.0
Step: 9, Action: 0, Reward: 0, Done: False
Step: 9, Price: 107245.0, Equity: 1103.0
Step: 10, Action: 0, Reward: 0, Done: False
Step: 10, Price: 107011.0, Equity: 1103.0
Step: 11, Action: 0, Reward: 0, Done: False
Step: 11, Price: 107193.0, Equity: 1103.0
Step: 12, Action: 0, Reward: 0, Done: False
Step: 12, Price: 107166.0

In [14]:
print(total_balance)

[1000.0, -1059.0, 1524.0, 1000.0]
