In [1]:
from pathlib import Path

import numpy as np
from numpy.random import default_rng

from simulator.objects.market import Market
from simulator.objects.participant import Participant
from simulator.objects.policies.base_policy import BasePolicy
from simulator.objects.policies.nn_policy import NNPolicy
from simulator.objects.stock import Portfolio, Stock, StockHolding

# Fixed seed so "Restart Kernel -> Run All" reproduces the same generated
# stocks, portfolios and market trajectory. Set SEED = None to draw fresh
# OS entropy on every run instead.
SEED = 42
rng = default_rng(SEED)

In [2]:
def generate_stocks(n_stocks: int, history_length: int = 1825) -> list[Stock]:
    """Create ``n_stocks`` Stock objects with randomly drawn fundamentals.

    Every attribute is drawn independently from a uniform distribution using
    the notebook-level ``rng``:

    - cash in [-10000, 100000)
    - earning_value_of_assets and latest_quarterly_earnings in [10000, 30000)
    - a flat price history of ``history_length`` entries, all equal to a
      single price drawn from [10, 200)
    - quality_of_leadership in [0, 1)
    - stock_volatility in [0, 0.1)

    Args:
        n_stocks: Number of stocks to create.
        history_length: Length of each stock's initial ``price_history``.
            The default 1825 (= 5 * 365) presumably means five years of daily
            prices -- TODO confirm against the simulator.

    Returns:
        A list of ``n_stocks`` newly constructed Stock objects.
    """
    output: list[Stock] = []
    for _ in range(n_stocks):
        # Scalar draws in the same order as the constructor keywords, so the
        # RNG stream is consumed exactly as before.
        cash = rng.uniform(-10000, 100000)
        earning_value_of_assets = rng.uniform(10000, 30000)
        latest_quarterly_earnings = rng.uniform(10000, 30000)
        initial_price = rng.uniform(10, 200)
        quality_of_leadership = rng.uniform(0, 1)
        stock_volatility = rng.uniform(0, 0.1)
        output.append(
            Stock(
                cash=cash,
                earning_value_of_assets=earning_value_of_assets,
                latest_quarterly_earnings=latest_quarterly_earnings,
                # Constant history: the stock starts with no past price movement.
                price_history=np.full(history_length, initial_price),
                quality_of_leadership=quality_of_leadership,
                stock_volatility=stock_volatility,
            )
        )
    return output

In [3]:
def generate_portfolio(
    stocks: list[Stock],
    quantity_choices: tuple[int, ...] = (0, 0, 0, 1, 2, 3),
) -> Portfolio:
    """Build a random Portfolio over ``stocks``.

    For each stock a quantity is drawn uniformly from ``quantity_choices``
    using the notebook-level ``rng``. Stocks whose drawn quantity is 0 are
    left out of the portfolio. With the default choices a stock is skipped
    half the time; otherwise it is held in quantity 1, 2 or 3 with equal
    probability.

    Args:
        stocks: Candidate stocks to hold.
        quantity_choices: Values to sample each stock's quantity from.
            Repeating a value weights it more heavily.

    Returns:
        A Portfolio containing one StockHolding per stock with a positive
        quantity.
    """
    stock_holdings: list[StockHolding] = []
    for stock in stocks:
        # rng.choice returns a numpy integer scalar; convert so the value
        # really is the plain ``int`` the annotation promises.
        stock_quantity: int = int(rng.choice(quantity_choices))
        if stock_quantity > 0:
            stock_holdings.append(
                StockHolding(
                    stock=stock,
                    stock_quantity=stock_quantity,
                )
            )
    return Portfolio(stock_holdings=stock_holdings)

In [4]:
def generate_participants(
    n_participants: int,
    stock_list: list[Stock],
    market: Market,
    policy: BasePolicy,
    cash: float = 3000,
) -> list[Participant]:
    """Create ``n_participants`` participants, each with a fresh random portfolio.

    Args:
        n_participants: Number of participants to create.
        stock_list: Stocks that each participant's portfolio is drawn from
            (via ``generate_portfolio``).
        market: Currently unused by this function; kept so existing call
            sites keep working.
        policy: Policy handed to every participant. NOTE(review): the same
            object is shared by all participants created here -- confirm that
            policies are stateless or meant to be shared.
        cash: Starting cash for each participant.

    Returns:
        A list of newly constructed Participant objects.
    """
    return [
        Participant(
            stock_portfolio=generate_portfolio(stock_list),
            policy=policy,
            cash=cash,
        )
        for _ in range(n_participants)
    ]

In [5]:
# Simulation size. N_PARTICIPANTS is the number of participants created for
# each valuation-model policy group.
N_STOCKS = 100
N_PARTICIPANTS = 50
# Interest rate passed to the market; presumably a fraction per year
# (0.001 == 0.1% APY) -- TODO confirm against Market.
INTEREST_RATE_APY = 0.001

market = Market(stocks=generate_stocks(N_STOCKS), interest_rate_apy=INTEREST_RATE_APY)

In [6]:
# One participant group per valuation model. The groups differ only in which
# model file their NNPolicy loads; each group gets N_PARTICIPANTS participants
# that share a single NNPolicy instance.
VALUATION_MODEL_PATHS = [Path("model2.pt"), Path("model_high_prices.pt")]

for valuation_model_path in VALUATION_MODEL_PATHS:
    market.add_participants(
        generate_participants(
            N_PARTICIPANTS,
            market.stocks,
            market,
            NNPolicy(
                market=market,
                n_stocks_to_sample=30,
                max_stocks_per_timestep=10,
                valuation_model_path=valuation_model_path,
                valuation_model_noise_std=0.05,
            ),
        )
    )

In [7]:
# Number of market steps to simulate.
N_STEPS = 100

for step in range(N_STEPS):
    # Label each step. The "Number of buy/sell order stocks" lines in the
    # output appear to be printed by step_market() itself.
    print(step)
    market.step_market()

0
Number of buy order stocks: 22
Number of sell order stocks: 40
1
Number of buy order stocks: 17
Number of sell order stocks: 36
2
Number of buy order stocks: 16
Number of sell order stocks: 33
3
Number of buy order stocks: 16
Number of sell order stocks: 34
4
Number of buy order stocks: 14
Number of sell order stocks: 34
5
Number of buy order stocks: 15
Number of sell order stocks: 34
6
Number of buy order stocks: 13
Number of sell order stocks: 34
7
Number of buy order stocks: 15
Number of sell order stocks: 32
8
Number of buy order stocks: 12
Number of sell order stocks: 35
9
Number of buy order stocks: 15
Number of sell order stocks: 34
10
Number of buy order stocks: 17
Number of sell order stocks: 37
11
Number of buy order stocks: 16
Number of sell order stocks: 34
12
Number of buy order stocks: 15
Number of sell order stocks: 36
13
Number of buy order stocks: 17
Number of sell order stocks: 34
14
Number of buy order stocks: 18
Number of sell order stocks: 35
15
Number of buy ord

In [8]:
# Spot-check: portfolio value history of one arbitrary participant.
sample_participant = market.participants[10]
print(sample_participant.stock_portfolio.value_history)

[9264.20438104 9264.20438104 9264.20438104 ... 9340.47454792 9355.55000456
 9470.03447738]


In [14]:
# Spot-check: price path and asset value of one arbitrary stock.
sample_stock = market.stocks[10]
print(sample_stock.price_history)
print(sample_stock.earning_value_of_assets)

[ 27.20356975  27.20356975  27.20356975 ... 157.48180713 176.6670472
 176.6670472 ]
28682.293123229992
