# Tutorial 1: Introduction to ABMAX
This tutorial walks through the basics of using ABMAX to build a simple agent-based model.
We will learn how to:
- Create a collection of agents.
- Step the agents in the collection.
- Remove agents from the collection.
- Add agents to the collection.
- Run a simulation loop.
- Run multiple simulations in parallel.

This tutorial assumes that we are familiar with the basics of JAX.

## Step 1: Defining the agent
Agents in ABMAX are defined as classes that inherit from the `Agent` class.
They are also decorated with `@struct.dataclass` from `Flax`, which registers them as JAX pytrees. Like a plain JAX array, they can then pass through `jax.vmap`, `jax.jit` and `jax.grad` without errors.

The `Agent` class must have the following simple members:
- `id`: An integer that uniquely identifies the agent.
- `agent_type`: An integer that represents the type of the agent.
- `active_state`: A boolean that represents if the agent is active or not.
- `age`: A float that represents the age of the agent in the simulation.
- `key`: A JAX random key used to generate randomness when needed.

The `Agent` class must also have the following complex members:
- `state`: A `State` object that represents the state of the agent that evolves over time.
- `params`: A `Params` object that represents the parameters of the agent that ideally do not change over time.
- `policy`: A `Policy` object that represents the policy of the agent that determines the actions of the agent.

More information about `State`, `Params`, `Policy` and other helper classes can be found in the documentation.

To comply with the functional programming paradigm that JAX requires, while still keeping some object-oriented structure, the methods of an `Agent` class are defined as static methods that take the agent as their first argument.
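The following is a minimal plain-JAX sketch of that pattern, assuming only JAX is installed. It deliberately swaps ABMAX's `Agent` base class and Flax's `@struct.dataclass` for a `typing.NamedTuple`, which JAX also treats as a pytree, so `ToyAgent` is a hypothetical stand-in rather than ABMAX code. The static method takes an agent and returns an updated copy instead of mutating it, and the same function can then be vmapped over a collection (one pytree whose leaves carry a leading agent axis) and jitted.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


# Hypothetical stand-in for an ABMAX agent: a NamedTuple is already a JAX
# pytree, so it flows through jax.jit and jax.vmap like a @struct.dataclass.
class ToyAgent(NamedTuple):
    id: jax.Array
    active_state: jax.Array
    age: jax.Array

    @staticmethod
    def step_agent(agent):
        # Pure function: return a new agent instead of mutating the old one.
        # Only active agents age.
        return agent._replace(age=agent.age + jnp.where(agent.active_state, 1.0, 0.0))


# A "collection" is the same pytree with a leading batch axis on every leaf.
agents = ToyAgent(
    id=jnp.arange(4),
    active_state=jnp.array([True, True, False, False]),
    age=jnp.zeros(4),
)

step_all = jax.jit(jax.vmap(ToyAgent.step_agent))
agents = step_all(agents)
print(agents.age)  # [1. 1. 0. 0.]
```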


### An unbiased dice agent
Let's define an agent that represents an unbiased die.
- At each step, an active dice agent rolls a number between 1 and 6 with equal probability, while an inactive dice agent does not roll.
- The drawn number is stored in the agent's state.
- We will also define what happens when a dice agent is removed from or added to a collection of agents.

```python
import sys
sys.path.append('../')

from abmax.structs import *
from abmax.functions import *
from flax import struct
import jax.numpy as jnp
import jax.random as random
import jax

@struct.dataclass
class Dice(Agent):
    @staticmethod
    def create_agent(type, params, id, active_state, key):
        key, subkey = random.split(key)

        def create_active_agent():
            # an active die starts with a draw in 1..6
            draw = jax.random.randint(subkey, (1,), 1, 7)
            state_content = {'draw': draw}
            return State(content=state_content)

        def create_inactive_agent():
            # an inactive (padding) die has a draw of 0, with the same shape as an active draw
            state_content = {'draw': jnp.array([0])}
            return State(content=state_content)
        agent_state = jax.lax.cond(active_state, lambda _: create_active_agent(), lambda _: create_inactive_agent(), None)

        return Dice(params=params, id=id, state=agent_state, agent_type=type, key = key, policy = None, age = 0.0,
                    active_state=active_state)

    @staticmethod
    def step_agent(agent, input, step_params):

        def step_active_agent():
            # roll again and advance the age
            key, subkey = random.split(agent.key)
            draw = jax.random.randint(subkey, (1,), 1, 7)
            state_content = {'draw': draw}
            new_state = State(content=state_content)
            return agent.replace(state = new_state, key = key, age = agent.age + 1.0)

        def step_inactive_agent():
            # inactive agents are returned unchanged
            return agent

        new_agent = jax.lax.cond(agent.active_state, lambda _: step_active_agent(), lambda _: step_inactive_agent(), None)
        return new_agent


    @staticmethod
    def remove_agent(agents, idx, remove_params):
        agent_to_remove = jax.tree_util.tree_map(lambda x:x[idx], agents)
        # blank agent: a draw of 0 with the same shape (1,) as the other draws
        new_state_content = {'draw': jnp.array([0])}
        new_state = State(content = new_state_content)
        return agent_to_remove.replace(state = new_state, active_state = False, age = 0.0)

    @staticmethod
    def add_agent(agents, idx, add_params):
        agent_to_add = jax.tree_util.tree_map(lambda x:x[idx], agents)
        key, subkey = random.split(agent_to_add.key)
        # a newly added die rolls immediately
        draw = jax.random.randint(subkey, (1,), 1, 7)
        state_content = {'draw': draw}
        new_state = State(content=state_content)

        return agent_to_add.replace(state = new_state, key = key, active_state = True, age = 0.0)
```

Here we define 4 static methods in the `Dice` class:
- `create_agent(type, params, id, active_state, key)`: Creates a new dice agent. Note how we create active and inactive agents separately, according to the agent's active state. This is needed to "zero-pad" the collection with inactive agents, which in turn keeps the total number of agents in the collection fixed. This is the only method whose name is fixed and must be defined in the agent class. About the arguments:
    - `type`: The type of the agent.
    - `params`: The parameters used to create the agent.
    - `id`: The id of the agent.
    - `active_state`: The active state of the agent.
    - `key`: The random key of the agent.

- `step_agent(agent, input, step_params)`: Defines how an agent steps. Here, an active agent samples a number between 1 and 6 and stores it in its state. Again, note how we step active and inactive agents separately. This is so that we can vmap over the agents to step them in parallel. While vmapping, the `agent` and `input` arguments are vectorized so that different agents can get different inputs, but `step_params` is not vectorized, so all agents get the same `step_params`. Conventionally, the agent's `policy` data structure determines its action; here, for simplicity, we ignore it.

- `remove_agent(agents, idx, remove_params)`: Produces a blank agent. This is called when we want to remove an agent from the collection of agents. Here, `agents` is the collection of agents, `idx` is the index where the agent should be removed from and `remove_params` are the parameters needed to produce the blank agent. 

- `add_agent(agents, idx, add_params)`: Produces an agent that can be added. This is called when we want to add an agent to the collection. Here, `agents` is, again, the collection of agents, `idx` is the index where the agent should be added and `add_params` are the parameters needed to produce the agent that is added.

Note: It is important to use the same method signatures as shown in the example, so that ABMAX can call these methods with the right arguments. We can always pass `None` for the arguments that we do not need.
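One practical consequence of building active and inactive agents through `jax.lax.cond` is that both branches must return arrays with the same shape and dtype. The plain-JAX sketch below (the helper `create_draw` is hypothetical, not an ABMAX function) writes the padding draw as `jnp.array([0])`, so that it matches the `(1,)`-shaped integer array returned by `jax.random.randint`.

```python
import jax
import jax.numpy as jnp


def create_draw(active_state, key):
    # Both branches of jax.lax.cond must return values of identical shape and
    # dtype, so the inactive "padding" draw mirrors the active (1,) int draw.
    def active_branch(_):
        return jax.random.randint(key, (1,), 1, 7)

    def inactive_branch(_):
        return jnp.array([0])

    return jax.lax.cond(active_state, active_branch, inactive_branch, None)


key = jax.random.PRNGKey(0)
print(create_draw(True, key))   # a single draw in 1..6, shape (1,)
print(create_draw(False, key))  # [0]
```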


Alright, let's test the agent and agent set creation.



## Step 2: Initializing the agent and an agent set

```python
num_agents = 10
num_active_agents = 5
key = random.PRNGKey(0)
key, subkey = random.split(key)
agent_type = 1
params = None

dice_agents = create_agents(Dice, params=params, num_agents=num_agents, num_active_agents=num_active_agents, agent_type=agent_type, key=subkey)
print("agent active state: ", dice_agents.active_state)
print("agent draws: ", dice_agents.state.content['draw'].reshape(-1))
```

```
agent active state:  [1 1 1 1 1 0 0 0 0 0]
agent draws:  [5 3 5 3 5 0 0 0 0 0]
```


So here we initialized a collection of agents (a vmapped version of the `Dice` object). This collection has in total 10 agents, of which 5 are active and 5 are inactive. 
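Conceptually, a collection is one `Dice` pytree whose leaves carry a leading axis of length `num_agents`. The sketch below is a rough, hypothetical plain-JAX approximation of the kind of collection `create_agents` returns, not its actual implementation: it marks the first `num_active_agents` slots as active and zero-pads the rest. It uses `jnp.where` instead of `jax.lax.cond`; under `vmap`, a `lax.cond` with a batched predicate is evaluated as a select over both branches anyway, so the two forms agree here. The exact draws will differ from ABMAX's output because the keys are split differently.

```python
import jax
import jax.numpy as jnp


def create_dice_collection(num_agents, num_active_agents, key):
    # Hypothetical helper: the first num_active_agents slots are active,
    # the remaining slots are inactive zero padding.
    ids = jnp.arange(num_agents)
    active_state = ids < num_active_agents
    keys = jax.random.split(key, num_agents)

    def create_one(active, k):
        draw = jax.random.randint(k, (1,), 1, 7)
        return jnp.where(active, draw, jnp.zeros_like(draw))

    draws = jax.vmap(create_one)(active_state, keys)  # shape (num_agents, 1)
    return ids, active_state, draws


ids, active_state, draws = create_dice_collection(10, 5, jax.random.PRNGKey(0))
print(active_state.astype(int))  # [1 1 1 1 1 0 0 0 0 0]
print(draws.reshape(-1))         # five draws in 1..6 followed by five zeros
```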

Now let's initialize an agent set. In addition to the collection of agents, an agent set keeps track of the number of active agents and the total number of agents in the collection. It is an object of class `Set` that is initialized with the following arguments:
- `agents`: The collection of agents i.e. the vmapped version of the `Dice` object in this case.
- `num_agents`: An integer that represents the total number of agents in the collection.
- `num_active_agents`: An integer that represents the total number of active agents in the collection.
- `state`: A `State` object that represents the state of the agent set.
- `params`: A `Params` object that represents the parameters of the agent set.
- `policy`: A `Policy` object that represents the policy of the agent set.
- `id`: An integer that uniquely identifies the agent set.
- `set_type`: An integer that represents the type of the agent set.
- `key`: A JAX random key used to generate randomness when needed.


The set has its own `state`, `params` and `policy` members. These are useful when we want to track states or parameters that affect all the agents in the set. For now, we will just initialize them with `None`. The `id` member uniquely identifies the agent set among other agent sets, and the `key` member is used to generate randomness when needed.

```python
key, subkey = random.split(key)

dice_set = Set(agents = dice_agents,
               num_agents = num_agents,
               num_active_agents = num_active_agents,
               state = None,
               params = None,
               policy = None,
               id = 0,
               set_type = 1,
               key = subkey)

print("number of active agents in the set: ", dice_set.num_active_agents)
```

```
number of active agents in the set:  5
```


## Step 3: Stepping the agents
The `jit_step_agents` function provided by ABMAX vmaps and jits the `step_agent` method that we defined in our agent class, so that all the agents are stepped in parallel under `jax.jit`.

```python
print("agent draws before step: ", dice_set.agents.state.content['draw'].reshape(-1))

dice_set = jit_step_agents(step_func = Dice.step_agent,
                           step_params = None,
                           input = None,
                           set = dice_set)

print("agent draws after step:  ", dice_set.agents.state.content['draw'].reshape(-1))
```

```
agent draws before step:  [5 3 5 3 5 0 0 0 0 0]
agent draws after step:   [5 1 3 1 5 0 0 0 0 0]
```


The `jit_step_agents` function takes in 
- `step_func`: a `callable` that steps (updates) the agent passed to it.
- `input`: a `Signal` object that represents the input to be passed to each agent.
- `step_params`: a `Params` object that represents the parameters to be passed to each agent which can be helpful in stepping the agents.
- `set`: a `Set` object that represents the set of agents that need to be stepped.

It returns the updated agent set, where the old agents are replaced by the new agents. Here, we pass the `Dice.step_agent` that we defined earlier as the `step_func`.
It vmaps the `step_func` over the agents in the `set` and the `input` but not over the `step_params`. This is so that all agents can get different inputs and the same parameters. In order to pass different parameters to different agents, we can consider making these different parameters part of the `input` signal or storing them in the `params` member of the agent.
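In plain JAX, the vectorization described above corresponds to `jax.vmap` with `in_axes=(0, 0, None)`: map over the agent and input axes, and broadcast the parameters. This is a hypothetical sketch of the pattern, not the source of `jit_step_agents`; the toy `step_agent` and its `'scale'` parameter are made up for illustration.

```python
import jax
import jax.numpy as jnp


def step_agent(state, input, step_params):
    # `state` and `input` are per-agent; `step_params` is shared by all agents.
    return state * step_params["scale"] + input


# Map over axis 0 of state and input; pass step_params unbatched (None).
step_all = jax.jit(jax.vmap(step_agent, in_axes=(0, 0, None)))

states = jnp.array([1.0, 2.0, 3.0])
inputs = jnp.array([10.0, 20.0, 30.0])
print(step_all(states, inputs, {"scale": 2.0}))  # [12. 24. 36.]
```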

## Step 4: Removing agents
Let's say we want to remove all the agents that rolled a 5 in the previous step. First, we need to select the agents that rolled a 5. Then we use the `jit_remove_agents` function provided by ABMAX to remove the selected agents from the agent set.

### Step 4.1: Selecting agents to remove
The `jit_select_agents` function provided by ABMAX can be used to select a subset of agents from the agent set. It returns the number of selected agents and their indices, in a somewhat non-trivial format explained below. But first, we need to define a `select_func` that takes in an agent and some parameters and returns a boolean value: `True` if the agent is to be selected and `False` otherwise.

```python
select_params = Params(content={'select_draw': 5})

def select_func(dice_agent, select_params):
    draw = jnp.reshape(dice_agent.state.content['draw'], (-1))
    return draw == select_params.content['select_draw']

# test the select function on the whole collection of agents
value = select_func(dice_set.agents, select_params)
print("agent draws: ", dice_set.agents.state.content['draw'].reshape(-1))
print("value: ", jnp.int16(value)) # convert to int, True = 1, False = 0
```

```
agent draws:  [5 1 3 1 5 0 0 0 0 0]
value:  [1 0 0 0 1 0 0 0 0 0]
```


Again, the `Params` data structure is a helper class that stores a dictionary (more information can be found in the documentation).

Next, we will use the `jit_select_agents` function to select the agents that have rolled a 5.

```python
num_agents_selected, selected_indx = jit_select_agents(select_func = select_func,
                                                       select_params = select_params,
                                                       set = dice_set)

print("number of agents selected: ", num_agents_selected)
print("indices of the agents selected: ", selected_indx)
```

```
number of agents selected:  2
indices of the agents selected:  [0 4 1 2 3 5 6 7 8 9]
```


So the `jit_select_agents` function takes in 
- `select_func`: a `callable` that represents the function that selects the agents.
- `select_params`: a `Params` object that represents the parameters to be passed to the select function.
- `set`: a `Set` object that represents the set of agents from which certain agents need to be selected.

It returns the number of selected agents and an array of indices of the selected agents, sorted in a particular way.

So, this is the non-trivial part. The indices are sorted in such a way that the first `num_agents_selected` indices in `selected_indx` array are the indices of the selected agents and the rest are the indices of the unselected agents. This is done to maintain a constant shape of the `selected_indx` array, to comply with jax's requirements.

In the example above, the dice agents at indices 0 and 4 are selected because they rolled a 5, so `num_agents_selected` is 2.
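One way to build an index array with this layout, while keeping its shape fixed, is to sort by a key that places the selected positions first. The sketch below is an assumption about how such a layout can be produced in plain JAX, not ABMAX's implementation; `select_indices` is a hypothetical helper that works directly on an array of draws. On the draws from this example it reproduces the indices printed above.

```python
import jax
import jax.numpy as jnp


@jax.jit
def select_indices(draws, select_draw):
    # Boolean mask of the agents that satisfy the selection criterion.
    mask = draws == select_draw
    num_selected = mask.sum()
    # Selected agents get the key i and unselected ones i + n. Sorting these
    # unique keys puts selected indices first, keeps the original order inside
    # each group, and never changes the output shape.
    n = draws.shape[0]
    positions = jnp.arange(n)
    order = jnp.argsort(jnp.where(mask, positions, positions + n))
    return num_selected, order


num_selected, selected_indx = select_indices(jnp.array([5, 1, 3, 1, 5, 0, 0, 0, 0, 0]), 5)
print(num_selected)   # 2
print(selected_indx)  # [0 4 1 2 3 5 6 7 8 9]
```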

### Step 4.2: Removing the selected agents
Now that we have the indices of the agents that we want to remove, we can use the `jit_remove_agents` function provided by ABMAX to remove the agents from the agent set.

```python
print(" agents active before remove: ", dice_set.num_active_agents)
print(" agents draws before remove: ", dice_set.agents.state.content['draw'].reshape(-1))
print("\n")

remove_params = Params(content={'remove_indx': selected_indx})

dice_set, sorted_indx = jit_remove_agents(remove_func = Dice.remove_agent,
                                          remove_params = remove_params,
                                          num_agents_remove = num_agents_selected,
                                          set = dice_set)

print(" agents active after remove: ", dice_set.num_active_agents)
print(" agents draws after remove: ", dice_set.agents.state.content['draw'].reshape(-1))
print(" sorted_ids: ", sorted_indx)
```

```
 agents active before remove:  5
 agents draws before remove:  [5 1 3 1 5 0 0 0 0 0]


 agents active after remove:  3
 agents draws after remove:  [1 3 1 0 0 0 0 0 0 0]
 sorted_ids:  [1 2 3 0 4 5 6 7 8 9]
```


So the `jit_remove_agents` function takes in
- `remove_func`: a `callable` that produces a blank agent.
- `remove_params`: a `Params` object holding the parameters to be passed to the remove function. These parameters are the same for all the agents that are removed and are passed to `remove_func` internally. Note that `remove_params` must contain a key `remove_indx` that holds the indices of the agents to be removed.
- `num_agents_remove`: an integer that represents the number of agents to be removed.
- `set`: a `Set` object that represents the set of agents from which agents are removed.

It returns the updated agent set, with the selected agents removed, together with `sorted_indx`. Behind the scenes, `jit_remove_agents`

- Takes the minimum of the number of agents to be removed and the number of active agents.
- Runs a `jax.lax.fori_loop` over the number of agents to be removed.
- Updates the number of active agents in the agent set.
- Sorts the agents such that the active agents after the removal are at the beginning and the inactive agents are at the end. The `sorted_indx` is the indices of the agents in the sorted agent set. This will be useful when we want to add agents to the agent set.
- Returns the updated agent set and the sorted indices.
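The steps listed above can be sketched in plain JAX as follows. This is a simplified, hypothetical model of `jit_remove_agents` that tracks only the draws and the active flags as separate arrays instead of full agents, and `remove_sketch` is not an ABMAX function. Fed the values from this step, it gives the same draws and sorted indices as the ABMAX output above.

```python
import jax
import jax.numpy as jnp


@jax.jit
def remove_sketch(draws, active, remove_indx, num_agents_remove):
    # Never remove more agents than are currently active.
    num_active = active.sum()
    n = jnp.minimum(num_agents_remove, num_active)

    # Blank out the first n agents listed in remove_indx (a "blank" agent here
    # is simply an inactive slot with a zero draw).
    def body(i, carry):
        draws, active = carry
        idx = remove_indx[i]
        return draws.at[idx].set(0), active.at[idx].set(False)

    draws, active = jax.lax.fori_loop(0, n, body, (draws, active))

    # Reorder so the remaining active agents come first, keeping their
    # relative order; sorted_indx records where each slot came from.
    size = active.shape[0]
    positions = jnp.arange(size)
    sorted_indx = jnp.argsort(jnp.where(active, positions, positions + size))
    return draws[sorted_indx], active[sorted_indx], sorted_indx


draws, active, sorted_indx = remove_sketch(
    jnp.array([5, 1, 3, 1, 5, 0, 0, 0, 0, 0]),
    jnp.array([True] * 5 + [False] * 5),
    jnp.array([0, 4, 1, 2, 3, 5, 6, 7, 8, 9]),  # output of the selection step
    2,
)
print(draws)        # [1 3 1 0 0 0 0 0 0 0]
print(sorted_indx)  # [1 2 3 0 4 5 6 7 8 9]
```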

## Step 5: Adding agents
Now, it is time to add some agents back to the agent set. Say, we want to add as many agents as the number of agents that have drawn a 1. Again, we will first need to find the number of agents that drew a 1. Then we will use the `jit_add_agents` function to add the agents to the agent set.

### Step 5.1: Selecting agents to add

```python
select_params = Params(content={'select_draw': 1})

def select_func(dice_agent, select_params):
    draw = jnp.reshape(dice_agent.state.content['draw'], (-1))
    return draw == select_params.content['select_draw']

num_agents_selected, selected_indx = jit_select_agents(select_func = select_func,
                                                       select_params = select_params,
                                                       set = dice_set)

print("agents draws: ", dice_set.agents.state.content['draw'].reshape(-1))
print("number of agents selected: ", num_agents_selected)
print("indices of the selected agents: ", selected_indx)
```

```
agents draws:  [1 3 1 0 0 0 0 0 0 0]
number of agents selected:  2
indices of the selected agents:  [0 2 1 3 4 5 6 7 8 9]
```


Here, there are 2 agents that have drawn a 1 at indices 0 and 2.

### Step 5.2: Adding agents
Now, we can use the `jit_add_agents` function to add agents to the agent set. The new agents are placed right after the last active agent, and their state is determined by our `add_agent` method in the `Dice` class.

```python
print(" agents active before add: ", dice_set.num_active_agents)

dice_set = jit_add_agents(add_func = Dice.add_agent,
                          add_params = None,
                          num_agents_add = num_agents_selected,
                          set = dice_set)

print(" agents active after add: ", dice_set.num_active_agents)
print(" draws after add: ", dice_set.agents.state.content['draw'].reshape(-1))
```

```
 agents active before add:  3
 agents active after add:  5
 draws after add:  [1 3 1 5 5 0 0 0 0 0]
```


Similar to `jit_remove_agents`, `jit_add_agents` takes in
- `add_func`: a `callable` that represents the function that produces an agent that can be added.
- `add_params`: a `Params` object holding the parameters to be passed to the add function. These parameters are the same for all the agents that are added and are passed to `add_func` internally.
- `num_agents_add`: an integer that represents the number of agents to be added.
- `set`: a `Set` object that represents the set of agents to which certain agents need to be added.

It returns the updated agent set, with the new agents added right after the last active agent. Behind the scenes, `jit_add_agents`

- Takes the minimum of the number of agents to be added and the number of free slots, i.e. the difference between the total number of agents and the number of active agents.
- Runs a `jax.lax.fori_loop` over the number of agents to be added and adds agents behind the last active agent.
- Updates the number of active agents in the agent set.
- Returns the updated agent set.
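A similar hypothetical sketch for adding agents, again tracking only draws and active flags in plain JAX (`add_sketch` is illustrative, not ABMAX's code). Capping the loop count with `jnp.minimum` against the number of free slots is what keeps the number of active agents from exceeding the fixed array size.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_sketch(draws, active, num_agents_add, key):
    size = active.shape[0]
    num_active = active.sum()
    # Cap the request at the number of free (inactive) slots.
    n = jnp.minimum(num_agents_add, size - num_active)

    # Activate slots num_active, num_active + 1, ... (right after the last
    # active agent) and give each new agent a fresh roll.
    def body(i, carry):
        draws, active, key = carry
        key, subkey = jax.random.split(key)
        idx = num_active + i
        draw = jax.random.randint(subkey, (), 1, 7, dtype=draws.dtype)
        return draws.at[idx].set(draw), active.at[idx].set(True), key

    draws, active, _ = jax.lax.fori_loop(0, n, body, (draws, active, key))
    return draws, active


draws = jnp.array([1, 3, 1, 0, 0, 0, 0, 0, 0, 0])
active = jnp.array([True] * 3 + [False] * 7)
key = jax.random.PRNGKey(0)

new_draws, new_active = add_sketch(draws, active, 2, key)
print(new_active.sum())                              # 5
print(add_sketch(draws, active, 100, key)[1].sum())  # 10: capped at the free slots
```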

This is where the sorting done in `jit_remove_agents` pays off: since the active agents are kept at the beginning of the array, ABMAX knows where to add new agents in the agent set.

If we now step the agents again, we see that all 5 active agents draw a number between 1 and 6.

```python
dice_set = jit_step_agents(Dice.step_agent, step_params=None, input=None, set=dice_set)
print("draws after step: ", dice_set.agents.state.content['draw'].reshape(-1))
```

```
draws after step:  [5 2 1 4 2 0 0 0 0 0]
```


## Step 6: Looping through the steps
Most simulations require looping through the steps many times. Let's say we want to repeat the steps we went through above 9 times, once for each value in `ts = jnp.arange(1,10)`. We can do this as shown below.

```python
num_agents = 10
num_active_agents = 5
key = jax.random.PRNGKey(0)
key, subkey = random.split(key)
agent_type = 1

dice_agents = create_agents(Dice, params=None, num_agents=num_agents, num_active_agents=num_active_agents,
                            agent_type=agent_type, key=subkey)

key, subkey = random.split(key)
dice_set = Set(agents = dice_agents, num_agents = num_agents, num_active_agents = num_active_agents, state = None,
                params = None, policy = None, id = 0, set_type = 1, key = subkey)

remove_select_params = Params(content={'select_draw': 5})
add_select_params = Params(content={'select_draw': 1})

def select_func(dice_agent, select_params):
    draw = jnp.reshape(dice_agent.state.content['draw'], (-1))
    return draw == select_params.content['select_draw']

def loop_step(set, t):
    # step, then remove the dice that rolled a 5, then add one die per die that rolled a 1
    set = jit_step_agents(Dice.step_agent, step_params=None, input=None, set=set)
    draw_before_remove_add = set.agents.state.content['draw'].reshape(-1)

    num_agents_selected, selected_indx = jit_select_agents(select_func = select_func, select_params = remove_select_params, set = set)

    remove_params = Params(content={'remove_indx': selected_indx})
    set, sorted_indx = jit_remove_agents(remove_func = Dice.remove_agent, remove_params = remove_params, num_agents_remove = num_agents_selected, set = set)

    num_agents_selected, selected_indx = jit_select_agents(select_func = select_func, select_params = add_select_params, set = set)
    set = jit_add_agents(add_func = Dice.add_agent, add_params = None, num_agents_add = num_agents_selected, set = set)
    draw_after_remove_add = set.agents.state.content['draw'].reshape(-1)

    return set, (draw_before_remove_add, draw_after_remove_add)
jit_loop_step = jax.jit(loop_step)

ts = jnp.arange(1,10)  # 9 time steps: 1, 2, ..., 9
dice_set, (draw_before_remove_add, draw_after_remove_add) = jax.lax.scan(jit_loop_step, dice_set, ts)

for i in range(9):
    print("draws before remove and add at time step ", i+1, ": ", draw_before_remove_add[i])
    print("draws after remove and add at time step ", i+1, ": ", draw_after_remove_add[i])
    print("\n")
```

```
draws before remove and add at time step  1 :  [5 1 3 1 5 0 0 0 0 0]
draws after remove and add at time step  1 :  [1 3 1 5 5 0 0 0 0 0]


draws before remove and add at time step  2 :  [5 2 1 4 2 0 0 0 0 0]
draws after remove and add at time step  2 :  [2 1 4 2 2 0 0 0 0 0]


draws before remove and add at time step  3 :  [5 4 5 5 6 0 0 0 0 0]
draws after remove and add at time step  3 :  [4 6 0 0 0 0 0 0 0 0]


draws before remove and add at time step  4 :  [4 4 0 0 0 0 0 0 0 0]
draws after remove and add at time step  4 :  [4 4 0 0 0 0 0 0 0 0]


draws before remove and add at time step  5 :  [2 6 0 0 0 0 0 0 0 0]
draws after remove and add at time step  5 :  [2 6 0 0 0 0 0 0 0 0]


draws before remove and add at time step  6 :  [6 3 0 0 0 0 0 0 0 0]
draws after remove and add at time step  6 :  [6 3 0 0 0 0 0 0 0 0]


draws before remove and add at time step  7 :  [6 2 0 0 0 0 0 0 0 0]
draws after remove and add at time step  7 :  [6 2 0 0 0 0 0 0 0 0]


draws before remove and add
```

Here, we used the `jax.lax.scan` function to loop through the steps. Another option is to use the `jax.lax.fori_loop` function. This will be shown in future tutorials.
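For reference, here is a minimal, self-contained contrast between the two loop primitives. It uses a toy carry of a PRNG key and an array of draws in place of a full agent set, which is a hypothetical simplification. `jax.lax.scan` returns the final carry plus the per-step outputs stacked along a new leading axis, which is how `draw_before_remove_add` ends up with one row per time step; `jax.lax.fori_loop` returns only the final carry.

```python
import jax
import jax.numpy as jnp


def loop_step(carry, t):
    # carry holds (key, draws); t is the current entry of `ts` (unused here).
    key, draws = carry
    key, subkey = jax.random.split(key)
    draws = jax.random.randint(subkey, draws.shape, 1, 7, dtype=draws.dtype)
    # The second return value is collected by scan at every step.
    return (key, draws), draws


init = (jax.random.PRNGKey(0), jnp.zeros(5, dtype=jnp.int32))
ts = jnp.arange(1, 10)  # 9 time steps

(final_key, final_draws), history = jax.lax.scan(loop_step, init, ts)
print(history.shape)  # (9, 5): one row of draws per time step


def body(i, carry):
    # fori_loop passes (index, carry) and keeps only the new carry.
    return loop_step(carry, i)[0]


final_key_2, final_draws_2 = jax.lax.fori_loop(0, 9, body, init)
```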

## Step 7: Running multiple sets in parallel
Since a set is a valid JAX data type (a pytree), we can run multiple sets in parallel. This is useful when we want to run multiple simulations in parallel. Here, we run 5 simulations in parallel by using `vmap` over the sets.

### Step 7.1: Creating multiple sets
First, we define a class called `DiceSet` that inherits from the `Set` class. This class has the same members as the `Set` class. We also define a static method called `create_set` that creates a new set of dice agents. Then we use the `create_sets` function provided by ABMAX to create 5 sets of dice agents.

```python
@struct.dataclass
class DiceSet(Set):
    @staticmethod
    def create_set(num_agents, num_active_agents, agents, set_params, id, set_type, set_subkeys):
        return DiceSet(agents = agents, num_agents = num_agents, num_active_agents = num_active_agents, state = None,
                       params = set_params, policy = None, id = id, set_type = set_type, key = set_subkeys)

num_sets = 5
num_agents = 10
num_active_agents = jnp.array([4, 5, 6, 7, 8]) # different sets can have different number of active agents
key = random.PRNGKey(0)


dice_sets = create_sets(set=DiceSet, set_params = None, set_type = 1,
                          agent=Dice, agent_params=None, agent_type=1,
                          num_sets = num_sets, num_agents = num_agents, num_active_agents = num_active_agents,
                          key = key)

print("number of active agents in each set: ", dice_sets.num_active_agents)
```

```
number of active agents in each set:  [4 5 6 7 8]
```


So, there are 2 important functions used in this step: `create_set` and `create_sets`. `create_set` is the method we defined in the `DiceSet` class. It's important to preserve its name and signature, and to make sure it returns a `DiceSet` object. The `create_sets` function is provided by ABMAX and takes in the following arguments:
- `set`: a `Set` class that represents the set of agents that need to be created. This is used internally by the `create_sets` function to use the `create_set` function that we defined in the `DiceSet` class.
- `set_params`: a `Params` object that represents the parameters to be passed to the `create_set` function. They can represent the parameters that are the same for all the agents in a set. They are vmapped over the number of sets.
- `set_type`: an integer that represents the type of the sets in the collection. This can be useful if there are multiple collections of sets in the simulation.
- `agent`: an `Agent` class that represents the agent that needs to be created. This is used internally by the `create_sets` function to use the `create_agent` function that we defined in the `Dice` class.
- `agent_params`: a `Params` object that represents the parameters to be passed to the `create_agents` function. Different agents can have different parameters, so these are vmapped twice: once over the number of sets and then over the number of agents in each set. In future tutorials, we will see how to use this to create agents with different parameters.
- `agent_type`: an integer that represents the type of the agents in the collection. This can be useful if there are multiple collections of agents in the simulation.
- `num_sets`: an integer that represents the number of sets to be created.
- `num_agents`: an integer that represents the maximum number of agents in each set. Again, due to the way JAX works, each set must have the same maximum number of agents; this fixed shape is what allows the sets to be stacked and simulated in parallel.
- `num_active_agents`: an array of integers that represents the number of active agents in each set. As different sets can have different number of active agents, this is vmapped over the number of sets.
- `key`: a JAX random key used to generate randomness when needed.

It returns a collection of sets of agents. Each set has the same maximum number of agents, a unique id and its own number of active agents.
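Conceptually, the result is a single pytree whose leaves carry an extra leading axis of length `num_sets`. The following plain-JAX sketch, which is a hypothetical simplification and not the implementation of `create_sets`, builds such a batch by vmapping a per-set constructor (`create_one_set`, made up for illustration) over per-set active counts and keys, with `num_agents` fixed for every set.

```python
import jax
import jax.numpy as jnp


def create_one_set(num_active_agents, key, num_agents=10):
    # One set: a fixed number of slots, the first `num_active_agents` active,
    # the rest zero padding.
    active_state = jnp.arange(num_agents) < num_active_agents
    draws = jax.random.randint(key, (num_agents,), 1, 7)
    draws = jnp.where(active_state, draws, 0)
    return active_state, draws


num_sets = 5
num_active_agents = jnp.array([4, 5, 6, 7, 8])  # may differ per set
keys = jax.random.split(jax.random.PRNGKey(0), num_sets)

# vmap over the per-set arguments; num_agents stays a static Python int.
active_state, draws = jax.vmap(create_one_set)(num_active_agents, keys)
print(draws.shape)               # (5, 10)
print(active_state.sum(axis=1))  # [4 5 6 7 8]
```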

Now we are ready to `vmap` our `loop_step` function over the sets and run the simulations in parallel.

```python
ts = jnp.arange(1,10)

def sim(set, ts):
    return jax.lax.scan(loop_step, set, ts)
vmap_sim = jax.vmap(sim, in_axes=(0, None))
jit_vmap_sim = jax.jit(vmap_sim)

sets, (draw_before_remove_add, draw_after_remove_add) = jit_vmap_sim(dice_sets, ts)
```
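The same `in_axes=(0, None)` pattern can be tried without ABMAX. In this hypothetical sketch each set is reduced to a tuple of a key, active flags and draws, and a toy `loop_step` re-rolls the active dice; the per-set state is batched along axis 0 while the shared `ts` array is broadcast to every set.

```python
import jax
import jax.numpy as jnp


def loop_step(carry, t):
    # Toy per-set step: active dice re-roll, padding slots stay at 0.
    key, active_state, draws = carry
    key, subkey = jax.random.split(key)
    new_draws = jax.random.randint(subkey, draws.shape, 1, 7, dtype=draws.dtype)
    draws = jnp.where(active_state, new_draws, 0)
    return (key, active_state, draws), draws


def sim(carry, ts):
    return jax.lax.scan(loop_step, carry, ts)


# Axis 0 of the carry indexes the sets; `ts` is shared (in_axes=None).
jit_vmap_sim = jax.jit(jax.vmap(sim, in_axes=(0, None)))

num_sets, num_agents = 5, 10
keys = jax.random.split(jax.random.PRNGKey(0), num_sets)
num_active_agents = jnp.array([4, 5, 6, 7, 8])
active_state = jnp.arange(num_agents)[None, :] < num_active_agents[:, None]
draws = jnp.zeros((num_sets, num_agents), dtype=jnp.int32)

ts = jnp.arange(1, 10)
(final_keys, _, final_draws), history = jit_vmap_sim((keys, active_state, draws), ts)
print(history.shape)  # (5, 9, 10): sets x time steps x agents
```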

```python
for i in range(5):
    print("for set ", i+1, ": ")
    for j in range(9):
        print("draws before remove and add at time step ", j+1, ": ", draw_before_remove_add[i][j])
        print("draws after remove and add at time step ", j+1, ": ", draw_after_remove_add[i][j])
        print("\n")
    print("\n")
```

for set  1 : 
draws before remove and add at time step  1 :  [4 2 5 2 0 0 0 0 0 0]
draws after remove and add at time step  1 :  [4 2 2 0 0 0 0 0 0 0]


draws before remove and add at time step  2 :  [1 4 5 0 0 0 0 0 0 0]
draws after remove and add at time step  2 :  [1 4 6 0 0 0 0 0 0 0]


draws before remove and add at time step  3 :  [3 6 3 0 0 0 0 0 0 0]
draws after remove and add at time step  3 :  [3 6 3 0 0 0 0 0 0 0]


draws before remove and add at time step  4 :  [4 1 4 0 0 0 0 0 0 0]
draws after remove and add at time step  4 :  [4 1 4 3 0 0 0 0 0 0]


draws before remove and add at time step  5 :  [5 3 2 2 0 0 0 0 0 0]
draws after remove and add at time step  5 :  [3 2 2 0 0 0 0 0 0 0]


draws before remove and add at time step  6 :  [1 6 4 0 0 0 0 0 0 0]
draws after remove and add at time step  6 :  [1 6 4 6 0 0 0 0 0 0]


draws before remove and add at time step  7 :  [6 2 5 4 0 0 0 0 0 0]
draws after remove and add at time step  7 :  [6 2 4 0 0 0 0 0 0 0]


draws before 