In [711]:
import sys

sys.path.append('/Users/siddarth.chaturvedi/Desktop/source/abmax_git/abmax')

from structs import *
from functions import *
import jax.numpy as jnp
import jax.random as random
import jax
from flax import struct

In [712]:
DT = 0.1 # Time step for simulation
KEY = jax.random.PRNGKey(0)

# Order params
ORDER_AGENT_TYPE = 0
MAX_NUM_ORDERS = 10
BUY_ORDER_POLARITY = jnp.array([0], dtype=jnp.int32)
SELL_ORDER_POLARITY = jnp.array([1], dtype=jnp.int32)
STARTING_PRICE = 100.0
MIN_ORDER_PRICE = 0.0
MAX_ORDER_PRICE = 1000.0
NUM_LOBS = 10

#trader params
MAX_NUM_TRADERS = 5
TRADER_AGENT_TYPE = 1

Market_create_params = Params(content={"num_traders": MAX_NUM_TRADERS, "trader_agent_type": TRADER_AGENT_TYPE, "starting_price": STARTING_PRICE, "max_num_orders": MAX_NUM_ORDERS,
                                       "order_agent_type": ORDER_AGENT_TYPE, "buy_order_polarity": BUY_ORDER_POLARITY, "sell_order_polarity": SELL_ORDER_POLARITY, "min_order_price": MIN_ORDER_PRICE, 
                                       "max_order_price": MAX_ORDER_PRICE})

In [713]:
@struct.dataclass
class Order(Agent):
    @staticmethod
    def create_agent(type: jnp.int32, params: Params, id: jnp.int32, active_state: bool, key: jax.random.PRNGKey) -> Agent:
        '''
        Create a blank order agent, assumption: active_state is false

        args:
            type: the type of the agent
            params: parameters to create the agent, should contain polarity
            id: the id of the agent
            active_state: whether the agent is active or not
            key: random key for agent creation, not used here
        returns:
            a blank order agent with the given type, params, id, and active_state
        '''
        num_shares = jnp.array([0])  # number of shares remaining in the order # dont put -1 as cumsum will not work
        remove_flag = jnp.array([0]) # flag to indicate if the order is removed
        agent_state_content = {"num_shares": num_shares, "remove_flag": remove_flag}
        agent_state = State(content=agent_state_content) 

        
        polarity = params.content["polarity"] #params.content["polarity"] # jnp.array([0]) -> buy, jnp.array([1]) ->sell
        min_order_price = params.content["min_order_price"]
        max_order_price = params.content["max_order_price"]
        price = jax.lax.cond(polarity[0] == 0, lambda _: jnp.array([min_order_price]), lambda _: jnp.array([max_order_price]), None) # blank_price, 0 for buy, 1000 for sell, makes it easy to sort orders
        trader_id = jnp.array([-1]) # id of the trader who placed the order, -1 for inactive orders

        agent_params = Params(content = {"polarity": polarity, "price": price, "trader_id": trader_id})

        agent = Order(agent_type=type, params=agent_params, id=id, active_state=active_state, state=agent_state, policy=None, key=key, age = 0.0)
        return agent
        

    @staticmethod
    def step_agent(agent: Agent, input: Signal, step_params: Params) -> Agent:
        '''
        deplete the number of shares remaining in the order by the input amount
        If the number of shares remaining is 0, set the remove_flag to 1, indicating the order is ready to be removed.
        Does nothing if the order is inactive

        args:
            agent: the order agent to be stepped
            input: the input signal containing the number of shares to remove
            step_params: parameters for the step, should contain dt (time step)
        returns:
            the updated order agent 

        '''
        def step_active_agent():
            num_shares_remove = input.content["num_shares_remove"]
            num_shares = agent.state.content["num_shares"]
            
            num_shares = jnp.maximum(num_shares - num_shares_remove, 0)
            remove_flag = jax.lax.cond(num_shares[0] == 0, lambda _: jnp.array([1]), lambda _: jnp.array([0]), None)
            agent_state_content = {"num_shares": num_shares, "remove_flag": remove_flag}
            agent_state = State(content=agent_state_content)

            dt = step_params.content["dt"]
            return agent.replace(state=agent_state, age=agent.age + dt)
        
        def step_inactive_agent():
            # Inactive agents do not change state, just return the agent as is
            return agent
        
        return jax.lax.cond(agent.active_state, lambda _: step_active_agent(), lambda _: step_inactive_agent(), None)


    
    @staticmethod
    def remove_agent(agent:Agent, remove_params:Params)->Agent:
        '''
        remove an agent by replacing it with a blank order agent based on its polarity
        arguments:
            agent: the agent to be removed
            remove_params: information about the agent to be removed, not used here
        returns:
            a blank order agent with the same polarity as the removed agent
        '''
        num_shares = jnp.array([0]) # don't put -1 as cumsum will not work
        remove_flag = jnp.array([0]) 
        agent_state_content = {"num_shares": num_shares, "remove_flag": remove_flag}
        agent_state = State(content=agent_state_content)

        polarity = agent.params.content["polarity"]
        min_order_price = remove_params.content["min_order_price"]
        max_order_price = remove_params.content["max_order_price"]
        price = jax.lax.cond(polarity[0] == 0, lambda _: jnp.array([min_order_price]), lambda _: jnp.array([max_order_price]), None)
        trader_id = jnp.array([-1]) # id of the trader who placed the order, -1 for inactive orders
        agent_params = Params(content = {"polarity": polarity, "price": price, "trader_id": trader_id})

        return agent.replace(state=agent_state, params=agent_params, active_state=0, age=0.0)

    @staticmethod
    def add_agent(agents, idx, add_params):
        '''
        add a new order at idx with the given params
        '''
        order_to_add = jax.tree_util.tree_map(lambda x: x[idx], agents)
        
        num_active_agents = add_params.content["num_active_agents"]
        param_idx = idx - num_active_agents

        trader_id = add_params.content["trader_id_list"][param_idx] # which trader has placed the order
        price = add_params.content["price_list"][param_idx] # whats the price of the order
        num_shares = add_params.content["num_shares_list"][param_idx] # how many shares in the order
        
        polarity = order_to_add.params.content["polarity"]
        
        remove_flag = jnp.array([0])
        agent_state = State(content={"num_shares": num_shares, "remove_flag": remove_flag})

        agent_params = Params(content = {"polarity": polarity, "price": price, "trader_id": trader_id})

        return order_to_add.replace(state=agent_state, params=agent_params, active_state=1, age=0.0)

In [714]:
@struct.dataclass
class Trader(Agent):
    @staticmethod
    def create_agent(type, params, id, active_state, key):
        key, *create_keys = random.split(key, 4)
        
        policy = None
        num_lobs = NUM_LOBS # this is not part of params since shapes depend on the number of lobs
        starting_price = params.content["starting_price"] # all lobs start at the same price
        belief_span = params.content["belief_span"]
        lower_price = starting_price * (1.0 - belief_span)
        upper_price = starting_price * (1.0 + belief_span)

        beliefs = jax.random.uniform(key=create_keys[0], shape=(num_lobs,), minval=lower_price, maxval=upper_price)
        cash = jax.random.uniform(key=create_keys[1], shape=(1,), minval=35000.0, maxval=450000.0)  # cash is a random value between 35000 and 45000
        shares = jax.random.randint(key=create_keys[2], shape=(num_lobs,), minval=100, maxval=1000)  # shares is a random value between 100 and 1000
        
        buy_flag = jnp.tile(False, (num_lobs,))
        buy_num_shares = jnp.tile(-1, (num_lobs,))
        buy_price = jnp.tile(0.0, (num_lobs,))


        sell_flag = jnp.tile(False, (num_lobs,))
        sell_num_shares = jnp.tile(-1, (num_lobs,))
        sell_price = jnp.tile(0.0, (num_lobs,))

        agent_state_content = {"cash": cash, "shares": shares, "beliefs": beliefs, "buy_flag": buy_flag, "buy_num_shares": buy_num_shares, 
                               "buy_price": buy_price, "sell_flag": sell_flag, "sell_num_shares": sell_num_shares, "sell_price": sell_price}
        agent_state = State(content=agent_state_content)
        agent_params = Params(content={"starting_price": starting_price, "belief_span": belief_span})

        return Trader(agent_type=type, params=agent_params, id=id, active_state=active_state, state=agent_state, policy=policy, key=key, age = 0.0)
    
    @staticmethod
    def step_agent(agent, input, step_params):
        # update cash, shares
        cash = agent.state.content["cash"]
        shares = agent.state.content["shares"]
        beliefs = agent.state.content["beliefs"]

        cash_diff = input.content["cash_diff"]
        shares_diff = input.content["shares_diff"]
        
        cash = jnp.maximum(cash + cash_diff, 0.0)  # Ensure cash and shares does not go negative
        shares = jnp.maximum(shares + shares_diff, 0)
        
        # for now just a noisy trader
        key, *dec_keys = jax.random.split(agent.key, 8)
        buy_flag = jax.random.uniform(dec_keys[0], (NUM_LOBS,), minval=0.0, maxval=1.0) < 0.5
        buy_num_shares = jax.random.randint(dec_keys[1], (NUM_LOBS,), minval=1, maxval=10)
        buy_price = beliefs - jax.random.uniform(dec_keys[2], (NUM_LOBS,), minval=-10.0, maxval=10.0)

        sell_flag = jax.random.uniform(dec_keys[3], (NUM_LOBS,), minval=0.0, maxval=1.0) < 0.5
        sell_num_shares = jax.random.randint(dec_keys[4], (NUM_LOBS,), minval=1, maxval=10)
        sell_price = beliefs + jax.random.uniform(dec_keys[5], (NUM_LOBS,), minval=-10.0, maxval=10.0)

        beliefs = beliefs + jax.random.uniform(dec_keys[6], (NUM_LOBS,), minval=-10.0, maxval=10.0)  # small noise to beliefs
        beliefs = jnp.clip(beliefs, 50.0, 1000.0)

        # update agent state
        agent_state_content = {"cash": cash, "shares": shares, "beliefs": beliefs, "buy_flag": buy_flag, 
                               "buy_num_shares": buy_num_shares, "buy_price": buy_price, 
                               "sell_flag": sell_flag, "sell_num_shares": sell_num_shares, "sell_price": sell_price}
        agent_state = State(content=agent_state_content)
        return agent.replace(state=agent_state, key=key, age=agent.age + step_params.content["dt"])

In [715]:
num_traders = 5
num_lobs = jnp.array([10,10,10,10,10])  # each trader has 10 lobs

starting_price = jnp.tile(jnp.array([100.0]), (num_traders,)).reshape(-1)  # all traders start at the same price
belief_span = jnp.tile(jnp.array([0.1]), (num_traders,)).reshape(-1)  # all traders have the same belief span
key,trader_key = random.split(KEY, 2)

trader_create_params = Params(content={"starting_price": starting_price, "belief_span": belief_span})
traders = create_agents(Trader, trader_create_params, num_traders, num_traders, TRADER_AGENT_TYPE, trader_key)
trader_set = Set(num_agents=num_traders, num_active_agents=num_traders, agents=traders, id=0, set_type=2, params=None, state=None, policy=None, key=None)


In [716]:
print(traders.state)

State(content={'beliefs': Array([[108.61428 ,  92.5028  ,  99.80765 ,  97.1813  , 107.69564 ,
        101.1341  , 102.49718 ,  96.06952 , 109.59692 ,  98.82875 ],
       [106.175575,  90.59939 , 104.215004,  93.875786,  94.43253 ,
        106.87343 ,  96.32103 ,  91.8928  ,  96.53785 ,  92.09825 ],
       [ 99.078606, 104.22937 ,  92.38129 , 105.69371 ,  97.416306,
        106.1803  ,  99.2912  , 102.30957 , 101.913124,  94.86038 ],
       [ 90.95743 , 102.67638 ,  91.09129 , 104.34966 ,  96.7095  ,
         98.512405, 104.72549 , 107.32034 , 103.72168 , 102.06541 ],
       [100.12843 , 105.41839 ,  91.09472 , 104.72369 , 103.814514,
        103.80144 ,  98.90672 ,  90.63679 , 107.52612 , 104.424736]],      dtype=float32), 'buy_flag': Array([[False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, False, False, False, False,
        Fa

In [717]:
input = Signal(content={"cash_diff": jnp.array([100.0, 200.0, -50.0, 300.0, -100.0]), "shares_diff": jnp.array([10, -5, 20, -10, 15])})
step_params = Params(content={"dt": DT})

print("before_step")
print("buy flag:", traders.state.content["buy_flag"])
print("buy num shares:", traders.state.content["buy_num_shares"])
print("buy price:", traders.state.content["buy_price"])
print("sell flag:", traders.state.content["sell_flag"])
print("sell num shares:", traders.state.content["sell_num_shares"])
print("sell price:", traders.state.content["sell_price"])

trader_set = step_agents(Trader.step_agent, input=input, step_params=step_params, set=trader_set)

print("after_step")
print("buy flag:", trader_set.agents.state.content["buy_flag"])
print("buy num shares:", trader_set.agents.state.content["buy_num_shares"])
print("buy price:", trader_set.agents.state.content["buy_price"])
print("sell flag:", trader_set.agents.state.content["sell_flag"])
print("sell num shares:", trader_set.agents.state.content["sell_num_shares"])
print("sell price:", trader_set.agents.state.content["sell_price"])

before_step
buy flag: [[False False False False False False False False False False]
 [False False False False False False False False False False]
 [False False False False False False False False False False]
 [False False False False False False False False False False]
 [False False False False False False False False False False]]
buy num shares: [[-1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
 [-1 -1 -1 -1 -1 -1 -1 -1 -1 -1]]
buy price: [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
sell flag: [[False False False False False False False False False False]
 [False False False False False False False False False False]
 [False False False False False False False False False False]
 [False False False False False False False False False False]
 [False False False False False False Fa

In [718]:
@struct.dataclass
class LOB:
    buy_LOB: Set
    sell_LOB: Set
    price: jnp.ndarray # a history of prices, shape (num_time_steps, 1)



def match_orders(buy_orders: Order, sell_orders: Order, traders: Trader):
    # step 1 sort sell orders according to increasing price and buy orders according to decreasing price, first element is the best price
    buy_orders_sorted, b_indx = jit_sort_agents(quantity=-1*buy_orders.params.content["price"], agents=buy_orders)
    sell_orders_sorted, s_indx = jit_sort_agents(quantity=sell_orders.params.content["price"], agents=sell_orders)

    #step 2 compute the cumulative number of shares
    buy_cumulative_shares = jnp.cumsum(buy_orders_sorted.state.content["num_shares"]).reshape(-1)
    sell_cumulative_shares = jnp.cumsum(sell_orders_sorted.state.content["num_shares"]).reshape(-1)

    # step 3 calculate the transaction for each buy order
    def for_each_buy_order(buy_order, buy_cumulative_share, sell_orders, sell_cumulative_shares):
        
        #step 3.1 get max transaction price for the buy order and anything above that price is not reachable by the buy order
        transaction_mask = jnp.where(buy_order.params.content["price"] < sell_orders.params.content["price"],0,1).reshape(-1) #0-> not reachable, 1-> reachable
        
        # step 3.2 multiply the cumulative shares by the transaction mask, this will set the cumulative shares to 0 for all the unreachable orders
        sell_cumulative_shares = jnp.multiply(sell_cumulative_shares, transaction_mask) # this will set the cumulative shares to 0 for all the unreachable orders
            
        num_sell_order_shares = sell_orders.state.content["num_shares"].reshape(-1)
        num_sell_order_shares = jnp.multiply(num_sell_order_shares, transaction_mask)

        # step 3.3 now remove the share from the sell orders that will be matched by buy orders with more priority than the current buy order
        # for this we need to know how many shares have more priority than the current buy order, that = buy_cumulative_share - buy_order.state.content["num_shares"]
        num_higher_priority_shares = buy_cumulative_share - buy_order.state.content["num_shares"][0]
        
        new_sell_cumulative_shares = jnp.maximum(sell_cumulative_shares - num_higher_priority_shares, 0)
        new_num_sell_order_shares = jnp.minimum(num_sell_order_shares, new_sell_cumulative_shares)

        # step 3.5 now using the same method as above remove the shares of this buy order
        new_sell_cumulative_shares = jnp.maximum(new_sell_cumulative_shares - buy_order.state.content["num_shares"][0], 0)
        new_new_num_sell_order_shares = jnp.minimum(new_num_sell_order_shares, new_sell_cumulative_shares)

        share_change = new_num_sell_order_shares - new_new_num_sell_order_shares
        cash_change = jnp.multiply(share_change, sell_orders.params.content["price"].reshape(-1))

        return share_change, cash_change

    share_change, cash_change = jax.vmap(for_each_buy_order, in_axes=(0, 0, None, None))(buy_orders_sorted, buy_cumulative_shares, sell_orders_sorted, sell_cumulative_shares)

    # use change to caculate change in shares and cash for traders
    # change in buy and sell shares, each row is for a buy order and each column is for a sell order
    
    def for_each_trader(trader):
        buy_orders_mask = jnp.where(trader.id == buy_orders_sorted.params.content["trader_id"].reshape(-1), 1, 0)
        sell_orders_mask = jnp.where(trader.id == sell_orders_sorted.params.content["trader_id"].reshape(-1), 1, 0)

        # when buying, first mask then matrix -> mask x matrix multiplication
        cash_change_buy = jnp.sum(jnp.matmul(buy_orders_mask, cash_change))
        shares_change_buy = jnp.sum(jnp.matmul(buy_orders_mask, share_change))

        # when selling, first matrix then mask -> matrix x mask multiplication
        cash_change_sell = jnp.sum(jnp.matmul(cash_change, sell_orders_mask))
        shares_change_sell = jnp.sum(jnp.matmul(share_change, sell_orders_mask))

        # update the trader cash and shares change, when selling add cash and remove shares, when buying remove cash and add shares
        trader_cash_change = cash_change_sell - cash_change_buy
        trader_shares_change = shares_change_buy - shares_change_sell
        return trader_cash_change, trader_shares_change
    
    traders_cash_change, traders_shares_change = jax.vmap(for_each_trader)(traders)

    #traders_step_input = Signal(content={"cash_diff": traders_cash_change.reshape(-1, 1),"shares_diff": traders_shares_change.reshape(-1, 1)})
    
    buy_shares_removed = share_change.sum(axis=1).reshape(-1, 1)  # total shares removed for each buy order
    sell_shares_removed = share_change.sum(axis=0).reshape(-1, 1)  # total shares removed for each sell order

    buy_order_step_input = Signal(content={"num_shares_remove": buy_shares_removed})
    sell_order_step_input = Signal(content={"num_shares_remove": sell_shares_removed})

    
    return buy_orders_sorted, sell_orders_sorted, buy_order_step_input, sell_order_step_input, traders_cash_change, traders_shares_change


jit_match_orders = jax.jit(match_orders)


'''
def is_flag(trader, select_params):
    sell_or_buy = select_params.content["sell_or_buy"] # 0 for buy, 1 for sell
    sell_flag = trader.state.content["sell_flag"].reshape(-1)
    buy_flag = trader.state.content["buy_flag"].reshape(-1)
    return jax.lax.cond(sell_or_buy == 0, 
                            lambda _: buy_flag, 
                            lambda _: sell_flag, 
                            None)


def get_order_add_params(trader_set, buy_LOB, sell_LOB):
    buy_select_params = Params(content={"sell_or_buy": 0})  # 0 for buy, 1 for sell
    
    num_buy_orders, buy_select_indx = jit_select_agents(is_flag, buy_select_params, trader_set)
    buy_price_list = jnp.take(trader_set.agents.state.content["buy_price"], buy_select_indx).reshape(-1, 1)
    buy_num_shares_list = jnp.take(trader_set.agents.state.content["buy_num_shares"], buy_select_indx).reshape(-1, 1)
    buy_trader_id_list = jnp.take(trader_set.agents.id, buy_select_indx).reshape(-1, 1)  # trader id is the index of the trader in the LOB
    
    buy_add_params = Params(content={"price_list": buy_price_list,
                                     "num_shares_list": buy_num_shares_list, 
                                     "trader_id_list": buy_trader_id_list,
                                     "num_active_agents": buy_LOB.num_active_agents})

    sell_select_params = Params(content={"sell_or_buy": 1})  # 0 for buy, 1 for sell
    num_sell_orders, sell_select_indx = jit_select_agents(is_flag, sell_select_params, trader_set)
    
    sell_price_list = jnp.take(trader_set.agents.state.content["sell_price"], sell_select_indx).reshape(-1, 1)
    sell_num_shares_list = jnp.take(trader_set.agents.state.content["sell_num_shares"], sell_select_indx).reshape(-1, 1)
    sell_trader_id_list = jnp.take(trader_set.agents.id, sell_select_indx).reshape(-1, 1)  # trader id is the index of the trader in the LOB
    
    sell_add_params = Params(content={"price_list": sell_price_list,
                                      "num_shares_list": sell_num_shares_list,
                                      "trader_id_list": sell_trader_id_list,
                                      "num_active_agents": sell_LOB.num_active_agents})
    return buy_add_params, sell_add_params, num_buy_orders, num_sell_orders
jit_get_order_add_params = jax.jit(get_order_add_params)
'''
def select_traders(trading_flags):
    trading_flags = trading_flags.reshape(-1)
    selected_mask = jnp.where(trading_flags , 1, 0)
    sort_selected_indx = jnp.argsort(-1*selected_mask)
    num_selected = jnp.sum(selected_mask)
    return num_selected, sort_selected_indx
jit_select_traders = jax.jit(select_traders)

def get_price_list(price_list_arr, indx):
    '''
    Get the price list for the given indices
    '''
    return jnp.take(price_list_arr, indx).reshape(-1, 1)

def get_num_shares_list(num_shares_arr, indx):
    '''
    Get the number of shares list for the given indices
    '''
    return jnp.take(num_shares_arr, indx).reshape(-1, 1)

def get_order_add_params(trader_set, buy_LOB, sell_LOB):
    num_buy_orders, buy_trader_id = jax.vmap(jit_select_traders, in_axes=(1))(trader_set.agents.state.content["buy_flag"])

    buy_price_list = jax.vmap(get_price_list, in_axes=(1,0))(trader_set.agents.state.content["buy_price"], buy_trader_id)
    buy_num_shares_list = jax.vmap(get_num_shares_list, in_axes=(1,0))(trader_set.agents.state.content["buy_num_shares"], buy_trader_id)

    num_sell_orders, sell_trader_id = jax.vmap(jit_select_traders, in_axes=(1))(trader_set.agents.state.content["sell_flag"])
    
    sell_price_list = jax.vmap(get_price_list, in_axes=(1,0))(trader_set.agents.state.content["sell_price"], sell_trader_id)
    sell_num_shares_list = jax.vmap(get_num_shares_list, in_axes=(1,0))(trader_set.agents.state.content["sell_num_shares"], sell_trader_id)

    buy_add_params = Params(content={"price_list": buy_price_list,
                                     "num_shares_list": buy_num_shares_list, 
                                     "trader_id_list": buy_trader_id,
                                     "num_active_agents": buy_LOB.num_active_agents})

    sell_add_params = Params(content={"price_list": sell_price_list,
                                      "num_shares_list": sell_num_shares_list,
                                      "trader_id_list": sell_trader_id,
                                      "num_active_agents": sell_LOB.num_active_agents})
    
    return buy_add_params, sell_add_params, num_buy_orders, num_sell_orders

In [719]:
@struct.dataclass
class Market:
    LOB: LOB
    traders: Set
    
    @staticmethod
    def create_market(params, key):
        key, *subkeys = random.split(key, 3)
        
        # step 1 creating traders
        num_traders = params.content["num_traders"]
        
        starting_price = params.content["starting_price"]
        starting_price_arr = jnp.tile(starting_price, (num_traders, 1))  # shape (num_traders, 1)
        belief_span_arr = jax.random.uniform(subkeys[0], (num_traders, 1), minval=0.05, maxval=0.1).reshape(-1)  # random belief span between 0.1 and 0.5
        

        trader_create_params = Params(content={"starting_price": starting_price_arr, "belief_span": belief_span_arr})

        traders = create_agents(Trader, trader_create_params, num_traders, num_traders, params.content["trader_agent_type"], subkeys[1])
        trader_set = Set(num_agents=num_traders, num_active_agents=num_traders, agents=traders, id=0, set_type=2,
                 params=None, state=None, policy=None, key=None)
        
        #step 2 create buy and sell LOBs
        num_lobs = NUM_LOBS
        key, *lob_keys = random.split(subkeys[1], num_lobs + 1)
        lob_keys = jnp.array(lob_keys)    
        
        def create_lobs(params, key):
            starting_price = params.content["starting_price"]
            max_num_orders = params.content["max_num_orders"]
            
            min_order_price = params.content["min_order_price"]
            max_order_price = params.content["max_order_price"]
            min_order_price_arr = jnp.tile(min_order_price, (max_num_orders, ))
            max_order_price_arr = jnp.tile(max_order_price, (max_num_orders, ))
            
            buy_polarity_arr = jnp.tile(params.content["buy_order_polarity"], (max_num_orders, 1))

            buy_create_params = Params(content={"polarity": buy_polarity_arr, "min_order_price": min_order_price_arr, "max_order_price": max_order_price_arr})
            num_active_buy_orders = 0

            # create buy orders
            buy_orders = create_agents(Order, buy_create_params, max_num_orders, num_active_buy_orders, params.content["order_agent_type"], key)

            buy_LOB = Set(num_agents=max_num_orders, num_active_agents=0, agents=buy_orders, id=0, set_type=0, params=None, state=None, policy=None, key=None)
            
            #create sell orders
            sell_polarity_arr = jnp.tile(params.content["sell_order_polarity"], (max_num_orders, 1))
            sell_create_params = Params(content={"polarity": sell_polarity_arr, "min_order_price": min_order_price_arr, "max_order_price": max_order_price_arr})
            
            num_active_sell_orders = 0
            sell_orders = create_agents(Order, sell_create_params, max_num_orders, num_active_sell_orders, params.content["order_agent_type"], key)
            sell_LOB = Set(num_agents=max_num_orders, num_active_agents=0, agents=sell_orders, id=0, set_type=1, params=None, state=None, policy=None, key=None)

            return LOB(buy_LOB=buy_LOB, sell_LOB=sell_LOB, price=starting_price)
        lobs = jax.vmap(create_lobs, in_axes=(None, 0))(params, lob_keys)
        return Market(LOB=lobs, traders=trader_set)
    
    @staticmethod
    def get_empty_order_mask(orders:Agent):
        return orders.state.content["remove_flag"].reshape(-1)
    
    @staticmethod
    def get_lob_price(buy_lob, sell_lob):
        #return jnp.mean(jnp.array([buy_lob.agents.params.content["price"][0], sell_lob.agents.params.content["price"][0]]), axis=0).reshape(-1, 1)
        min_sell_price = jnp.min(sell_lob.agents.params.content["price"])
        max_buy_price = jnp.max(buy_lob.agents.params.content["price"])
        return (min_sell_price + max_buy_price) / 2.0
        
    @staticmethod
    #@jax.jit
    def step_market(market, _t):
        # step 1: match orders in lobs
        buy_orders, sell_orders, buy_order_step_input, sell_order_step_input, traders_cash_change, traders_shares_change  = jax.vmap(jit_match_orders,in_axes=(0,0,None))(market.LOB.buy_LOB.agents, market.LOB.sell_LOB.agents, market.traders.agents)
        buy_lobs = market.LOB.buy_LOB.replace(agents=buy_orders)
        sell_lobs = market.LOB.sell_LOB.replace(agents=sell_orders)
        lob_prices = jax.vmap(Market.get_lob_price)(buy_lobs, sell_lobs)

        # step 2: step the trader agents
        traders_cash_change = jnp.sum(traders_cash_change, axis=0).reshape(-1, 1)
        traders_shares_change = jnp.transpose(traders_shares_change)  # shape (num_traders, num_lobs)
        
        traders_step_input = Signal(content={"cash_diff": traders_cash_change,"shares_diff": traders_shares_change})
        step_params = Params(content={"dt": DT})
        trader_set = jit_step_agents(Trader.step_agent, input=traders_step_input, step_params=step_params, set=market.traders)

        #step 3 step orders
        buy_lobs = jax.vmap(jit_step_agents, in_axes=(None, None, 0, 0))(Order.step_agent, step_params, buy_order_step_input, buy_lobs)
        sell_lobs = jax.vmap(jit_step_agents, in_axes=(None, None, 0, 0))(Order.step_agent, step_params, sell_order_step_input, sell_lobs)


        # step 5: add new orders to the LOBs 
        buy_add_params, sell_add_params, num_buy_orders, num_sell_orders = get_order_add_params(trader_set, buy_lobs, sell_lobs)
        print("trader buy information")
        print(trader_set.agents.state.content["buy_flag"].reshape(NUM_LOBS, -1))
        print("num buy orders to add", num_buy_orders)
        
        
        buy_lobs = jax.vmap(jit_add_agents, in_axes=(None, 0, 0, 0))(Order.add_agent, buy_add_params, num_buy_orders, buy_lobs)
        sell_lobs = jax.vmap(jit_add_agents, in_axes=(None, 0, 0, 0))(Order.add_agent, sell_add_params, num_sell_orders, sell_lobs)
        
        #print("sell lobs")
        #print(sell_lobs.agents.state.content["num_shares"].reshape(MAX_NUM_ORDERS, -1))

        # step 4: remove orders that are ready to be removed
        '''
        buy_orders_remove_mask = jax.vmap(Market.get_empty_order_mask)(buy_lobs.agents)
        buy_remove_mask_params = Params(content={"set_mask": buy_orders_remove_mask})
        sell_orders_remove_mask = jax.vmap(Market.get_empty_order_mask)(sell_lobs.agents)
        sell_remove_mask_params = Params(content={"set_mask": sell_orders_remove_mask})
        remove_params = Params(content={"min_order_price": MIN_ORDER_PRICE, "max_order_price": MAX_ORDER_PRICE})
        buy_lobs = jax.vmap(jit_set_agents_mask, in_axes=(None, None, 0, None, 0))(Order.remove_agent, remove_params, buy_remove_mask_params, -1, buy_lobs)
        sell_lobs = jax.vmap(jit_set_agents_mask, in_axes=(None, None, 0, None, 0))(Order.remove_agent, remove_params, sell_remove_mask_params, -1, sell_lobs)

        '''

        # step 6: update the market state
        LOB = market.LOB.replace(buy_LOB=buy_lobs, sell_LOB=sell_lobs, price=lob_prices)
        return market.replace(LOB=LOB, traders=trader_set), lob_prices
        


In [720]:
market = Market.create_market(Market_create_params, KEY)

In [721]:
print(market.traders.agents.state.content["beliefs"])

[[103.713326 103.090904  99.99444   94.294525  97.19464  104.16191
   98.39724  100.52996  102.1637   104.98427 ]
 [106.27392  106.16206  101.98166  102.82131  104.619354 106.01331
  107.11544  100.88893   97.72069   93.25188 ]
 [102.376434 102.78041  106.69833   99.372536  96.25616  103.700836
   96.51273  105.860916 103.64505  100.23264 ]
 [101.600624 100.66803   96.22579  102.34541  102.32601   98.08268
   98.22071  100.08997   95.79564  104.20527 ]
 [ 99.519226  92.33847  103.327     98.34014   96.77655   98.92415
  100.39056  101.00722  101.17503   93.60669 ]]


In [653]:
market, lob_prices = Market.step_market(market, 0)
'''
print("lob prices after step:", lob_prices)
print("active market orders:")
print("buy orders")
print(market.LOB.buy_LOB.agents.state.content["num_shares"].reshape(MAX_NUM_ORDERS, -1))
print(market.LOB.buy_LOB.agents.params.content["price"].reshape(MAX_NUM_ORDERS, -1))
print("sell orders")
print(market.LOB.sell_LOB.agents.state.content["num_shares"].reshape(MAX_NUM_ORDERS, -1))
print(market.LOB.sell_LOB.agents.params.content["price"].reshape(MAX_NUM_ORDERS, -1))
print("traders beliefs after step:")
print(market.traders.agents.state.content["beliefs"])
'''


trader buy information
buy order flag shape (5, 10)
[[ True False False False False]
 [ True  True  True False False]
 [False False  True False False]
 [ True False  True  True  True]
 [False False  True  True False]
 [False False False  True  True]
 [ True False False False  True]
 [False False  True  True  True]
 [ True False False  True  True]
 [False False False  True False]]
num buy orders to add [3 0 2 2 2 2 1 3 4 3]


'\nprint("lob prices after step:", lob_prices)\nprint("active market orders:")\nprint("buy orders")\nprint(market.LOB.buy_LOB.agents.state.content["num_shares"].reshape(MAX_NUM_ORDERS, -1))\nprint(market.LOB.buy_LOB.agents.params.content["price"].reshape(MAX_NUM_ORDERS, -1))\nprint("sell orders")\nprint(market.LOB.sell_LOB.agents.state.content["num_shares"].reshape(MAX_NUM_ORDERS, -1))\nprint(market.LOB.sell_LOB.agents.params.content["price"].reshape(MAX_NUM_ORDERS, -1))\nprint("traders beliefs after step:")\nprint(market.traders.agents.state.content["beliefs"])\n'

In [92]:
print(market.LOB.price.shape)  # should be (num_lobs, 1)

(10,)


In [93]:
buy_orders, sell_orders, buy_order_step_input, sell_order_step_input, traders_cash_change, traders_shares_change  = jax.vmap(jit_match_orders,in_axes=(0,0,None))(market.LOB.buy_LOB.agents, market.LOB.sell_LOB.agents, market.traders.agents)
traders_cash_change = jnp.sum(traders_cash_change, axis=0).reshape(-1, 1)
traders_shares_change = jnp.transpose(traders_shares_change)  # shape (num_traders, num_lobs)
traders_step_input = Signal(content={"cash_diff": traders_cash_change,"shares_diff": traders_shares_change})


In [94]:
input = Signal(content={"cash_diff": traders_cash_change, "shares_diff": traders_shares_change})
step_params = Params(content={"dt": DT})
trader_set = step_agents(Trader.step_agent, input=input, step_params=step_params, set=market.traders)



In [95]:
print(trader_set.agents.state.content["buy_flag"])
print(trader_set.agents.state.content["buy_num_shares"])

buy_add_params, sell_add_params, num_buy_orders, num_sell_orders = get_order_add_params(trader_set, market.LOB.buy_LOB, market.LOB.sell_LOB)



#add orders to the LOBs
print("num active orders in buy LOB before adding:", market.LOB.buy_LOB.num_active_agents)
print("buy lobs price:", market.LOB.buy_LOB.agents.params.content["price"].shape)
new_buy_lobs = jax.vmap(jit_add_agents, in_axes=(None, 0, 0, 0))(Order.add_agent, buy_add_params, num_buy_orders, market.LOB.buy_LOB)
new_sell_lobs = jax.vmap(jit_add_agents, in_axes=(None, 0, 0, 0))(Order.add_agent, sell_add_params, num_sell_orders, market.LOB.sell_LOB)
print("num active orders in buy LOB after adding:", new_buy_lobs.num_active_agents)
print("buy lobs num shares:", new_buy_lobs.agents.state.content["num_shares"].reshape(10, -1))


[[ True False False False False  True  True  True False False]
 [False False  True False False  True False  True  True  True]
 [False False  True  True False False False False  True  True]
 [ True False False False  True False False  True  True  True]
 [ True False False  True  True False False False  True False]]
[[8 8 9 5 3 1 8 8 1 2]
 [3 5 8 9 3 9 6 7 7 2]
 [4 1 1 1 1 8 4 8 2 9]
 [5 7 6 1 5 4 3 1 4 3]
 [6 1 8 7 9 7 6 4 1 1]]
num active orders in buy LOB before adding: [0 0 0 0 0 0 0 0 0 0]
buy lobs price: (10, 10, 1)
num active orders in buy LOB after adding: [3 0 2 2 2 2 1 3 4 3]
buy lobs num shares: [[8 5 6 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [8 1 0 0 0 0 0 0 0 0]
 [1 7 0 0 0 0 0 0 0 0]
 [5 9 0 0 0 0 0 0 0 0]
 [1 9 0 0 0 0 0 0 0 0]
 [8 0 0 0 0 0 0 0 0 0]
 [8 7 1 0 0 0 0 0 0 0]
 [7 2 4 1 0 0 0 0 0 0]
 [2 9 3 0 0 0 0 0 0 0]]


In [96]:
buy_orders, sell_orders, buy_order_step_input, sell_order_step_input, traders_cash_change, traders_shares_change  = jax.vmap(jit_match_orders,in_axes=(0,0,None))(new_buy_lobs.agents, new_sell_lobs.agents, market.traders.agents)
print("buy orders before matching:", buy_orders.state.content["num_shares"].reshape(10, -1))
traders_shares_change = jnp.transpose(traders_shares_change)  # shape (num_traders, num_lobs)
#print("traders share change:", traders_shares_change)
print("buy_order_step_input:", buy_order_step_input.content["num_shares_remove"].reshape(10, -1))

new_buy_lobs = new_buy_lobs.replace(agents=buy_orders)

order_step_params = Params(content={"dt": DT})
buy_LOBs = jax.vmap(jit_step_agents, in_axes=(None, None, 0, 0))(Order.step_agent, order_step_params, buy_order_step_input, new_buy_lobs)

print("buy orders after matching:", buy_LOBs.agents.state.content["num_shares"].reshape(10, -1))



buy orders before matching: [[8 5 6 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [1 8 0 0 0 0 0 0 0 0]
 [7 1 0 0 0 0 0 0 0 0]
 [5 9 0 0 0 0 0 0 0 0]
 [9 1 0 0 0 0 0 0 0 0]
 [8 0 0 0 0 0 0 0 0 0]
 [8 7 1 0 0 0 0 0 0 0]
 [2 1 4 7 0 0 0 0 0 0]
 [3 9 2 0 0 0 0 0 0 0]]
buy_order_step_input: [[8 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [1 8 0 0 0 0 0 0 0 0]
 [7 1 0 0 0 0 0 0 0 0]
 [4 0 0 0 0 0 0 0 0 0]
 [9 0 0 0 0 0 0 0 0 0]
 [8 0 0 0 0 0 0 0 0 0]
 [7 0 0 0 0 0 0 0 0 0]
 [2 0 0 0 0 0 0 0 0 0]
 [3 0 0 0 0 0 0 0 0 0]]
buy orders after matching: [[0 5 6 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [1 9 0 0 0 0 0 0 0 0]
 [0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0]
 [1 7 1 0 0 0 0 0 0 0]
 [0 1 4 7 0 0 0 0 0 0]
 [0 9 2 0 0 0 0 0 0 0]]


In [97]:
def get_empty_order_mask(orders:Agent):
        return orders.state.content["remove_flag"].reshape(-1)

buy_orders_remove_mask = jax.vmap(get_empty_order_mask)(buy_LOBs.agents)
mask_params = Params(content={"set_mask": buy_orders_remove_mask})
remove_params = Params(content={"min_order_price": MIN_ORDER_PRICE, "max_order_price": MAX_ORDER_PRICE})

print("buy orders before removing:", buy_LOBs.agents.params.content["price"].reshape(10, -1))
buy_remove_lob = jax.vmap(jit_set_agents_mask, in_axes=(None, None, 0, None, 0))(Order.remove_agent, remove_params, mask_params, -1, buy_LOBs)
print("buy orders after removing:", buy_remove_lob.agents.params.content["price"].reshape(10, -1))

buy orders before removing: [[113.033745 106.18571   98.33305    0.         0.         0.
    0.         0.         0.         0.      ]
 [  0.         0.         0.         0.         0.         0.
    0.         0.         0.         0.      ]
 [121.213486 109.60024    0.         0.         0.         0.
    0.         0.         0.         0.      ]
 [ 93.99754   92.91663    0.         0.         0.         0.
    0.         0.         0.         0.      ]
 [108.11291   83.891205   0.         0.         0.         0.
    0.         0.         0.         0.      ]
 [126.497894 109.904366   0.         0.         0.         0.
    0.         0.         0.         0.      ]
 [ 93.99642    0.         0.         0.         0.         0.
    0.         0.         0.         0.      ]
 [102.9835   100.24437   98.986626   0.         0.         0.
    0.         0.         0.         0.      ]
 [109.57571  102.72777   90.197685  87.7293     0.         0.
    0.         0.         0.         0