In [11]:

import sys
sys.path.insert(0, '/Users/james/Deep-Reinforcement-Learning-for-Hedging/')
import numpy as np
import torch
import matplotlib.pyplot as plt
from financial_models.asset_price_models import GBM
from financial_models.option_price_models import BSM
from torch import nn
from hedging_env_gymnasium import HedgingEnv


seed = 345
np.random.seed(seed)
torch.manual_seed(seed)

mu = 0.05
dt = 1/128
T = 1
num_steps = T/dt
s_0 = 1
strike_price = s_0
sigma = 0.15
r = 0.01

apm = GBM(mu=mu, dt=dt, s_0=s_0, sigma=sigma)
opm = BSM(strike_price=strike_price, risk_free_interest_rate=r, volatility=sigma, T=T, dt=dt)
env = HedgingEnv(asset_price_model=apm, dt=dt, T=T, num_steps=num_steps, trading_cost_para=1,
                 L=1, strike_price=strike_price, int_holdings=False, initial_holding=0, mode="PL",
                 option_price_model=opm)


In [12]:
env.step([0.8])

(array([0.8       , 1.0199794 , 0.9921875 , 0.07595538, 0.55632836],
       dtype=float32),
 -0.004622942774725858,
 False,
 False,
 {})

In [13]:
env.step([-0.5])

(array([0.3       , 1.0048479 , 0.984375  , 0.0667793 , 0.55632836],
       dtype=float32),
 -0.004636654729156796,
 False,
 False,
 {})

In [14]:
from stable_baselines3.common.env_checker import check_env
check_env(env)

In [15]:
from stable_baselines3 import SAC
from stable_baselines3.sac.policies import MlpPolicy

model = SAC(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=10000)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 128      |
|    ep_rew_mean     | -0.5     |
| time/              |          |
|    episodes        | 4        |
|    fps             | 86       |
|    time_elapsed    | 5        |
|    total_timesteps | 512      |
| train/             |          |
|    actor_loss      | -2.14    |
|    critic_loss     | 0.0272   |
|    ent_coef        | 0.884    |
|    ent_coef_loss   | -0.206   |
|    learning_rate   | 0.0003   |
|    n_updates       | 411      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 128      |
|    ep_rew_mean     | -0.232   |
| time/              |          |
|    episodes        | 8        |
|    fps             | 77       |
|    time_elapsed    | 13       |
|    total_timesteps | 1024     |
| train/             |

<stable_baselines3.sac.sac.SAC at 0x7fba6cafff10>