In [722]:
from itertools import combinations, product, islice
from functools import partial
import networkx as nx
import jax.numpy as jnp
import json
import jax
import numpy as np
import gymnax
import chex

In [723]:
seed = 0
rng = jax.random.PRNGKey(seed)
rng, key_topology_request, key_slot_request, key_node_request, key_arrival_time, key_holding_time, key_traffic = jax.random.split(rng, 7)

In [775]:
@chex.dataclass
class VONEEnvState:
    arrival_time: chex.Array
    link_slot_array: chex.Array
    node_capacity_array: chex.Array
    node_mask: chex.Array
    path_link_array: chex.Array
    link_slot_mask: chex.Array
    request_array: chex.Array
    action_counter: chex.Array
    node_departure_array: chex.Array
    link_slot_departure_array: chex.Array
    node_resource_array: chex.Array
    rng: chex.Array
    key_topology_request: chex.Array
    key_slot_request: chex.Array
    key_node_request: chex.Array
    key_arrival_time: chex.Array
    key_holding_time: chex.Array

@chex.dataclass(frozen=True)
class VONEEnvParams:
    traffic_matrix: chex.Array
    values_nodes: chex.Array
    values_slots: chex.Array
    virtual_topology_patterns: chex.Array
    num_nodes: chex.Scalar
    num_links: chex.Scalar
    node_resources: chex.Scalar
    link_resources: chex.Scalar
    k_paths: chex.Scalar
    load: chex.Scalar
    mean_service_holding_time: chex.Scalar
    arrival_rate: chex.Scalar

@chex.dataclass
class RSAEnvState:
    arrival_time: chex.Array
    link_slot_array: chex.Array
    path_link_array: chex.Array
    link_slot_mask: chex.Array
    request_array: chex.Array
    link_slot_departure_array: chex.Array
    rng: chex.Array
    key_node_request: chex.Array
    key_slot_request: chex.Array
    key_arrival_time: chex.Array
    key_holding_time: chex.Array

@chex.dataclass(frozen=True)
class RSAEnvParams:
    traffic_matrix: chex.Array
    values_slots: chex.Array
    num_nodes: chex.Scalar
    num_links: chex.Scalar
    node_resources: chex.Scalar
    link_resources: chex.Scalar
    k_paths: chex.Scalar
    mean_service_holding_time: chex.Scalar
    load: chex.Scalar
    arrival_rate: chex.Scalar

In [776]:
def init_path_link_array(graph, k):
    """Initialise path-link array
    Each path is defined by a link utilisation array. 1 indicates link corrresponding to index is used, 0 indicates not used."""
    def get_k_shortest_paths(g, source, target, k, weight=None):
        return list(
            islice(nx.shortest_simple_paths(g, source, target, weight=weight), k)
        )

    paths = []
    for node_pair in combinations(graph.nodes, 2):
        k_paths = get_k_shortest_paths(
            graph, node_pair[0], node_pair[1], k
        )
        for k_path in k_paths:
            link_usage = [0]*len(graph.edges) # Initialise empty path
            for link in k_path:
                link_usage[link] = 1
            paths.append(link_usage)

    return jnp.array(paths)

@partial(jax.jit, static_argnums=(2,3))
def get_path_indices(s, d, k, N):
    array = jnp.arange(N, dtype=jnp.int32)
    array = jnp.where(array < s, array, 0)#array.at[s:].set(0)
    return (N*s + d - jnp.sum(array) - 2*s - 1) * k

def init_node_array(num_nodes, node_resources):
    """Initialize node array either with uniform resources"""
    return jnp.array([node_resources] * num_nodes)

def reset_node_array(state, params):
    """Reset node array in-place"""
    state.node_array.at[:].set(params.node_resources[0])
    return state, params

def init_link_slot_array(num_links, link_resources):
    """Initialize link array either with uniform resources"""
    return jnp.zeros((num_links, link_resources))

def reset_link_slot_array(state, link_resources):
    """Reset link array in-place"""
    state.link_slot_array.at[:].set(link_resources)
    return state

def init_vone_request_array(max_edges):
    """Initialize request array either with uniform resources"""
    return jnp.zeros((2, max_edges*2+1, ))

def reset_vone_request_array(state):
    """Reset request array in-place"""
    state.request_array.at[:].set(0)
    return state

def init_rsa_request_array():
    """Initialize request array"""
    return jnp.zeros(3)

def reset_rsa_request_array(state):
    """Reset request array in-place"""
    state.request_array.at[:].set(0)
    return state

def init_node_mask(num_nodes):
    """Initialize node mask"""
    return jnp.ones(num_nodes + 1)

def reset_node_mask(state):
    """Reset node mask in-place"""
    state.node_mask.at[:].set(1)
    return state

def init_link_slot_mask(k, link_resources):
    """Initialize link mask"""
    return jnp.ones(k*link_resources)

def reset_link_slot_mask(state):
    """Reset link mask in-place"""
    state.link_slot_mask = jnp.where(state.link_slot_mask == 0, 1, state.link_slot_mask)
    return state

def init_action_counter():
    """Initialize action counter.
    First index is num unique nodes, second index is total steps, final is remaining steps until completion of request."""
    return jnp.zeros(3)

def reset_action_counter(state):
    """Reset action counter in-place"""
    state.action_counter.at[:].set(0)
    return state

@jax.jit
def decrement_action_counter(state):
    """Decrement action counter in-place"""
    state.action_counter.at[-1].add(-1)
    return state

def init_node_departure_array(num_nodes, node_resources):
    return jnp.full((num_nodes, node_resources), jnp.inf)

def reset_node_departure_array(state):
    state.node_departure_array.at[:].set(0)
    return state

def init_link_slot_departure_array(num_links, link_resources):
    return jnp.full((num_links, link_resources), jnp.inf)

def reset_link_slot_departure_array(state):
    state.link_departure_array.at[:].set(0)
    return state

def init_selected_nodes_array(num_nodes, node_resources):
    """Array to track nodes selected for virtual topology"""
    return jnp.zeros((num_nodes, node_resources))

def reset_selected_nodes_array(state):
    """Reset selected nodes array in-place"""
    state.selected_nodes_array.at[:].set(0)
    return state

#@jax.jit
def generate_vone_request(state, virtual_topology_patterns, values_nodes, values_slots):
    # TODO - update this to be bitrate requests rather than slots
    # Define the four possible patterns for the first row
    shape = state.request_array.shape[1]
    # Randomly select topology, node resources, slot resources
    pattern = jax.random.choice(state.key_topology_request, virtual_topology_patterns)
    action_counter = jax.lax.dynamic_slice(pattern, (0,), (3,))#pattern[:3].copy()
    topology_pattern = jax.lax.dynamic_slice(pattern, (3,), (pattern.shape[0]-3,))#pattern[3:].copy()
    selected_node_values = jax.random.choice(state.key_node_request, values_nodes, shape=(shape,))
    selected_slot_values = jax.random.choice(state.key_slot_request, values_slots, shape=(shape,))
    # Create a mask for odd and even indices
    mask = jnp.tile(jnp.array([0, 1]), (shape+1) // 2)[:shape]
    # Vectorized conditional replacement using mask
    first_row = jnp.where(mask, selected_slot_values, selected_node_values)
    first_row = jnp.where(topology_pattern == 0, 0, first_row)
    state.request_array = jnp.vstack((first_row, topology_pattern))
    state.action_counter = action_counter
    return state

def normalise_traffic_matrix(traffic_matrix):
    """Normalise traffic matrix to sum to 1"""
    traffic_matrix /= jnp.sum(traffic_matrix)
    return traffic_matrix


#@jax.jit
def generate_rsa_request(state, traffic_matrix):
    # TODO - update this to be bitrate requests rather than slots
    # Flatten the probabilities to a 1D array
    shape = traffic_matrix.shape
    probabilities = traffic_matrix.ravel()
    # Use jax.random.choice to select index based on the probabilities
    source_dest_index = jax.random.choice(state.key_node_request, jnp.arange(traffic_matrix.size), p=probabilities)
    # Convert 1D index back to 2D
    source, dest = jnp.unravel_index(source_dest_index, shape)
    # Vectorized conditional replacement using mask
    slots = jax.random.choice(state.key_slot_request, values_slots)
    state.request_array = jnp.stack((source, dest, slots))
    return state

@partial(jax.jit, static_argnums=(1,2))
def get_paths(state, k, N, nodes):
    """Get k paths between source and destination"""
    # get source and destination nodes in order (for accurate indexing of path-link array)
    source, dest = jnp.sort(nodes)
    i = get_path_indices(source, dest, k, N)
    index_array = jax.lax.dynamic_slice(jnp.arange(0, state.path_link_array.shape[0]), (i,), (k,))
    return jnp.take(state.path_link_array, index_array, axis=0)

def traffic_generator(state):
    # TODO - figure out how to scale these with load etc (multiply by e^load or similar?)
    state.arrival_time +=  jax.random.exponential(key_arrival_time, shape=(1,))
    holding_time = jax.random.exponential(key_holding_time, shape=(1,)) * jnp.exp
    departure_time = state.arrival_time + holding_time

#@jax.jit
def decrease_last_element(array):
    last_value_mask = jnp.arange(array.shape[0]) == array.shape[0] - 1
    print(last_value_mask)
    return jnp.where(last_value_mask, array - 1, array)


In [777]:
# Generate the shortest path-link table
k = 2
graph = nx.node_link_graph(json.load(open('topologies/conus.json')))
# 7 node ring
# graph = nx.from_numpy_array(jnp.array([[0, 1, 0, 0, 0, 0, 1],
#                                        [1, 0, 1, 0, 0, 0, 0],
#                                        [0, 1, 0, 1, 0, 0, 0],
#                                        [0, 0, 1, 0, 1, 0, 0],
#                                        [0, 0, 0, 1, 0, 1, 0],
#                                        [0, 0, 0, 0, 1, 0, 1],
#                                        [1, 0, 0, 0, 0, 1, 0]]))
arrival_time = 0.0
load = 100.0
mean_service_holding_time = 10.0
arrival_rate = load / mean_service_holding_time
num_nodes = len(graph.nodes)
num_links = len(graph.edges)
node_resources = 30
link_resources = 100

path_link_array = init_path_link_array(graph, k)
node_capacity_array = init_node_array(num_nodes, node_resources)
node_departure_array = init_node_departure_array(num_nodes, node_resources)
node_mask = init_node_mask(num_nodes)
link_slot_array = init_link_slot_array(num_links, link_resources)
link_slot_departure_array = init_link_slot_departure_array(num_links, link_resources)
link_slot_mask = init_link_slot_mask(k, link_resources)
vone_request_array = init_vone_request_array(5)
rsa_request_array = init_rsa_request_array()
node_resource_array = init_selected_nodes_array(num_nodes, node_resources)
# First three values are for action counter
virtual_topology_patterns = jnp.array([
    [3,2,2, 2,1,3,1,4,0,0,0,0,0,0],
    [3,3,3, 2,1,3,1,4,1,2,0,0,0,0],
    [5,4,4, 2,1,3,1,4,1,5,1,6,0,0],
    [5,5,5, 2,1,3,1,4,1,5,1,6,1,2]
], dtype=jnp.int32)
action_counter = init_action_counter()
traffic_matrix = jax.random.uniform(key_traffic, shape=(num_nodes, num_nodes))
traffic_matrix = normalise_traffic_matrix(traffic_matrix)
values_nodes = jnp.arange(1, 3)
values_slots = jnp.arange(1, 5)

env_state = VONEEnvState(
    arrival_time=arrival_time,
    link_slot_array=link_slot_array,
    node_capacity_array=node_capacity_array,
    node_mask=node_mask,
    path_link_array=path_link_array,
    link_slot_mask=link_slot_mask,
    request_array=vone_request_array,
    action_counter=action_counter,
    node_departure_array=node_departure_array,
    link_slot_departure_array=link_slot_departure_array,
    node_resource_array=node_resource_array,
    rng=rng,
    key_topology_request=key_topology_request,
    key_slot_request=key_slot_request,
    key_node_request=key_node_request,
    key_arrival_time=key_arrival_time,
    key_holding_time=key_holding_time,
)

env_params = VONEEnvParams(
    traffic_matrix=traffic_matrix,
    values_nodes=values_nodes,
    values_slots=values_slots,
    mean_service_holding_time=mean_service_holding_time,
    virtual_topology_patterns=virtual_topology_patterns,
    k_paths=k,
    node_resources=node_resources,
    link_resources=link_resources,
    num_nodes=num_nodes,
    num_links=num_links,
    load=load,
    arrival_rate=arrival_rate,
)

rsa_env_state = RSAEnvState(
    arrival_time=arrival_time,
    link_slot_array=link_slot_array,
    path_link_array=path_link_array,
    link_slot_mask=link_slot_mask,
    request_array=rsa_request_array,
    link_slot_departure_array=link_slot_departure_array,
    rng=rng,
    key_node_request=key_node_request,
    key_slot_request=key_slot_request,
    key_arrival_time=key_arrival_time,
    key_holding_time=key_holding_time,
)

rsa_env_params = RSAEnvParams(
    traffic_matrix=traffic_matrix,
    values_slots=values_slots,
    mean_service_holding_time=mean_service_holding_time,
    k_paths=k,
    node_resources=node_resources,
    link_resources=link_resources,
    num_nodes=num_nodes,
    num_links=num_links,
    load=load,
    arrival_rate=arrival_rate,
)
print(type(k))
print(type(num_nodes))
print(type(env_params.k_paths))
print(type(env_params.num_nodes))

<class 'int'>
<class 'int'>
<class 'int'>
<class 'int'>


In [778]:
print(env_state.action_counter)
print(env_state.request_array)
generate_vone_request(env_state, env_params.virtual_topology_patterns, env_params.values_nodes, env_params.values_slots)
print(env_state.action_counter)
print(env_state.request_array)

[0. 0. 0.]
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
[3 2 2]
[[1 2 1 3 2 0 0 0 0 0 0]
 [2 1 3 1 4 0 0 0 0 0 0]]


In [783]:
# TODO - Check if changing float arrays to int arrays when appropriate improves performance
# TODO - When writing the env, have one key that gets split into multiple keys each step/reset (to avoid lots of splitting within each random function)
# TODO - Remember to check out HiPPO initialisation and SSMs
# TODO - Set departure time values to negative initially, then set back to infinite or turn positive when finalising
# TODO - Implement checks
# TODO - Implement masking
# TODO - Consider implementing dataclasses for request attributes e.g. node request, slot request, etc.

# def no_update_link(link, initial_slot, num_slots, value):
#     return link

#def update_selected_nodes(node_indices, array, node, request):
#     return jnp.where(node_indices == node, request, array)

def update_link(link, initial_slot, num_slots, value):
    slot_indices = jnp.arange(link.shape[0])
    return jnp.where((initial_slot <= slot_indices) & (slot_indices <= initial_slot+num_slots), link-value, link)

def update_path(link, link_in_path, initial_slot, num_slots, value):
    return jax.lax.cond(link_in_path == 1, lambda x: update_link(*x), lambda x: x[0], (link, initial_slot, num_slots, value))

@jax.jit
def vmap_update_path_links(link_array, path, initial_slot, num_slots, value):
    return jax.vmap(update_path, in_axes=(0, 0, None, None, None))(link_array, path, initial_slot, num_slots, value)

def update_node_departure(node_row, inf_index, value):
    row_indices = jnp.arange(node_row.shape[0])
    return jnp.where(row_indices == inf_index, value, node_row)

def update_selected_node_departure(node_row, node_selected, first_inf_index, value):
    return jax.lax.cond(node_selected != 0, lambda x: update_node_departure(*x), lambda x: node_row, (node_row, first_inf_index, value))

@jax.jit
def vmap_update_node_departure(node_departure_array, selected_nodes, value):
    first_inf_indices = jnp.argmax(node_departure_array, axis=1)
    return jax.vmap(update_selected_node_departure, in_axes=(0, 0, 0, None))(node_departure_array, selected_nodes, first_inf_indices, value)

def update_node_resources(node_row, zero_index, value):
    row_indices = jnp.arange(node_row.shape[0])
    return jnp.where(row_indices == zero_index, value, node_row)

def update_selected_node_resources(node_row, request, first_zero_index):
    return jax.lax.cond(request != 0, lambda x: update_node_resources(*x), lambda x: node_row, (node_row, first_zero_index, request))

@jax.jit
def vmap_update_node_resources(node_resource_array, selected_nodes):
    first_zero_indices = jnp.argmin(node_resource_array, axis=1)
    return jax.vmap(update_selected_node_resources, in_axes=(0, 0, 0))(node_resource_array, selected_nodes, first_zero_indices)

def remove_expired_slot_requests(state):
    mask = jnp.where(state.link_slot_departure_array < state.arrival_time, 1, 0)
    state.link_slot_array = jnp.where(mask == 1, jnp.inf, state.link_slot_array)
    state.link_slot_departure_array = jnp.where(mask == 1, jnp.inf, state.link_slot_departure_array)
    return state

def remove_expired_node_requests(state):
    mask = jnp.where(state.node_departure_array < state.arrival_time, 1, 0)
    expired_resources = jnp.sum(jnp.where(mask == 1, state.node_resource_array, 0), axis=1)
    state.node_capacity_array = state.node_capacity_array + expired_resources
    state.node_departure_array = jnp.where(mask == 1, jnp.inf, state.node_departure_array)
    return state

def update_node_array(node_indices, array, node, request):
    return jnp.where(node_indices == node, array-request, array)

def undo_node_action(state):
    mask = jnp.where(state.node_departure_array < 0, 1, 0)
    resources = jnp.sum(jnp.where(mask == 1, state.selected_nodes_array, 0), axis=1)
    state.node_capacity_array = state.node_capacity_array + resources
    state.node_departure_array = jnp.where(mask == 1, jnp.inf, state.node_departure_array)
    state.node_resource_array = jnp.where(mask == 1, 0, state.node_resource_array)
    return state

def check_unique_nodes(node_departure_array, total_requested_nodes):
    # TODO - Count negative values on each node (row) in node departure array, must not exceed 1
    return jnp.any(jnp.sum(jnp.where(node_departure_array < 0, 1, 0), axis=1) > 1)

def check_all_nodes_assigned(node_departure_array, total_requested_nodes):
    # TODO - Count negative values on each node (row) in node departure array, must equal total requested_nodes
    return jnp.any(jnp.sum(jnp.where(node_departure_array < 0, 1, 0), axis=1) != jnp.abs(node_departure_array[:, 0]))

def check_node_capacities(resource_array, max_resources):
    # TODO - Sum selected nodes array and check less than node resources
    return jnp.any(jnp.sum(resource_array, axis=1) > max_resources)

def check_no_spectrum_reuse(link_slot_array):
    # TODO - maybe set (or rather deduct) slot to -1 when used, then check if any < -1 in slot array
    return jnp.any(link_slot_array < -1)

def implement_node_action(state, s_node, d_node, s_request, d_request, n=2):
    node_indices = jnp.arange(state.node_capacity_array.shape[0])

    curr_selected_nodes = jnp.zeros(state.node_capacity_array.shape[0])
    curr_selected_nodes = update_node_array(node_indices, curr_selected_nodes, d_node, d_request)
    curr_selected_nodes = jax.lax.cond(n == 2, lambda x: update_node_array(*x), lambda x: x[1], (node_indices, curr_selected_nodes, s_node, s_request))

    state.node_capacity_array = update_node_array(node_indices, state.node_capacity_array, d_node, d_request)
    state.node_capacity_array = jax.lax.cond(n == 2, lambda x: update_node_array(*x), lambda x: x[1], (node_indices, state.node_capacity_array, s_node, s_request))

    state.node_resource_array = vmap_update_node_resources(state.node_resource_array, curr_selected_nodes)
    state.node_resource_array = jax.lax.cond(n == 2, lambda x: vmap_update_node_resources(*x), lambda x: x[0], (state.node_resource_array, curr_selected_nodes))

    state.node_departure_array = vmap_update_node_departure(node_departure_array, curr_selected_nodes, -state.arrival_time)
    state.node_departure_array = jax.lax.cond(n == 2, lambda x: vmap_update_node_departure(*x), lambda x: x[0], (state.node_departure_array, curr_selected_nodes, -state.arrival_time))
    return state

def implement_path_action(state, path, initial_slot_index, num_slots):
    # Update link-slot array
    state.link_slot_array = vmap_update_path_links(state.link_slot_array, path, initial_slot_index, num_slots, 1)
    # Update link-slot departure array
    state.link_slot_departure_array = vmap_update_path_links(state.link_slot_departure_array, path, initial_slot_index, num_slots, -state.arrival_time)
    return state

@partial(jax.jit, static_argnums=(2,3))
def implement_vone_action(state, action, k, N):
    request = jax.lax.dynamic_slice(state.request_array[0], ((state.action_counter[-1]-1)*2, ), (3, ))
    node_request_s = jax.lax.dynamic_slice(request, (2, ), (1, ))
    node_request_d = jax.lax.dynamic_slice(request, (0, ), (1, ))
    num_slots = jax.lax.dynamic_slice(request, (1, ), (1, ))
    total_actions = jnp.squeeze(jax.lax.dynamic_slice(state.action_counter, (1, ), (1, )))
    remaining_actions = jnp.squeeze(jax.lax.dynamic_slice(state.action_counter, (2, ), (1, )))
    nodes = action[:2]
    path_index = jnp.floor(action[2] / state.link_slot_array.shape[0]).astype(jnp.int32)
    initial_slot_index = jnp.mod(action[2], state.link_slot_array.shape[0])
    path = get_paths(state, k, N, nodes)[path_index]
    n_nodes = jax.lax.cond(total_actions == remaining_actions, lambda x: 2, lambda x: 1, (total_actions, remaining_actions))

    state = implement_node_action(state, nodes[0], nodes[1], node_request_s, node_request_d, n=n_nodes)

    state = implement_path_action(state, path, initial_slot_index, num_slots)

    return state

def finalise_vone_action():
    # TODO - Turn departure times positive, turn slots positive,
    pass


implement_vone_action(env_state, jnp.array([0,1,2]), env_params.k_paths, env_params.num_nodes)

VONEEnvState(arrival_time=Array(0., dtype=float32, weak_type=True), link_slot_array=Array([[0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), node_capacity_array=Array([28, 29, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
       30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
       30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
       30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
       30, 30, 30, 30, 30, 30, 30], dtype=int32), node_mask=Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,

In [793]:
array = jnp.array([[jnp.inf,jnp.inf,jnp.inf,jnp.inf,jnp.inf,jnp.inf],
                   [0.1,0.2,jnp.inf,jnp.inf,jnp.inf,jnp.inf],
                   [jnp.inf,jnp.inf,jnp.inf,jnp.inf,jnp.inf,jnp.inf],
                   [-0.2,jnp.inf,jnp.inf,jnp.inf,jnp.inf,jnp.inf],
                   [jnp.inf,jnp.inf,jnp.inf,jnp.inf,jnp.inf,jnp.inf],
                   [jnp.inf,jnp.inf,jnp.inf,jnp.inf,jnp.inf,jnp.inf],])

resource_array = jnp.array([[0,0,0,0,0,0],
                   [1,2,0,0,0,0],
                   [0,0,0,0,0,0],
                   [2,0,0,0,0,0],
                   [0,0,0,0,0,0],
                   [0,0,0,0,0,0],])

node_capacities = jnp.array([10,7,10,8,10,10])

link_slot_array = jnp.array([[1,-2,0,0,0,0],
                   [1,1,0,0,0,0],
                   [0,0,0,0,0,0],
                   [0,0,0,1,0,0],
                   [0,0,0,1,0,0],
                   [0,0,0,0,0,0]])



def update_first_inf(node_row, inf_index, value):
    row_indices = jnp.arange(node_row.shape[0])
    return jnp.where(row_indices == inf_index, value, node_row)

def conditional_update_first_inf(node_row, node_selected, first_inf_index, value):
    return jax.lax.cond(node_selected != 0, lambda x: update_first_inf(*x), lambda x: node_row, (node_row, first_inf_index, value))

def vmap_update_first_inf(node_departure_array, selected_nodes, value):
    first_inf_indices = jnp.argmax(node_departure_array, axis=1)
    return jax.vmap(conditional_update_first_inf, in_axes=(0, 0, 0, None))(node_departure_array, selected_nodes, first_inf_indices, value)


check_no_spectrum_reuse(link_slot_array)
jnp.any(jnp.sum(resource_array, axis=1) > 3)

Array(False, dtype=bool)

In [764]:
mask = jnp.where(array < 0.25, 1, 0)
print(f"mask: {mask}")
expired_resources = jnp.sum(jnp.where(mask == 1, resource_array, 0), axis=1)
print(f"expired_resources: {expired_resources}")
node_capacities = node_capacities + expired_resources
print(f"node_capacities: {node_capacities}")
array = jnp.where(mask == 1, jnp.inf, array)
print(f"array: {array}")
jnp.argmax(array, axis=1)
vmap_update_first_inf(array, jnp.array([0,1,2,3,0,0]), -100)


mask: [[0 0 0 0 0 0]
 [0 1 0 0 0 0]
 [0 0 0 0 0 0]
 [0 0 0 0 0 0]
 [0 0 0 0 0 0]
 [0 0 0 0 0 0]]
expired_resources: [0 2 0 0 0 0]
node_capacities: [10 10 10 10 10 10]
array: [[inf inf inf inf inf inf]
 [inf inf inf inf inf inf]
 [inf inf inf inf inf inf]
 [inf inf inf inf inf inf]
 [inf inf inf inf inf inf]
 [inf inf inf inf inf inf]]


Array([[  inf,   inf,   inf,   inf,   inf,   inf],
       [-100.,   inf,   inf,   inf,   inf,   inf],
       [-100.,   inf,   inf,   inf,   inf,   inf],
       [-100.,   inf,   inf,   inf,   inf,   inf],
       [  inf,   inf,   inf,   inf,   inf,   inf],
       [  inf,   inf,   inf,   inf,   inf,   inf]], dtype=float32)

In [787]:
implement_vone_action(env_state, jnp.array([0,1,2]), env_params.k_paths, env_params.num_nodes)

VONEEnvState(arrival_time=Array(0., dtype=float32, weak_type=True), link_slot_array=Array([[0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), node_capacity_array=Array([28, 29, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
       30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
       30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
       30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
       30, 30, 30, 30, 30, 30, 30], dtype=int32), node_mask=Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,

In [456]:
def add_one(elem):
    return elem + 1

def add_one_conditional(elem, other_elem):
    return jax.lax.cond(other_elem == 1, lambda x: x + 9, lambda x: x, elem)

@jax.jit
def vmap_add_one_conditional(array, path):
    return jax.vmap(add_one_conditional, in_axes=(0, 0))(array, path)

path = jnp.array([1,1,1,0,0,0,0,1,0,0])
array = jnp.zeros(10)
print(array)
vmap_add_one_conditional(array, path)


[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


Array([9., 9., 9., 0., 0., 0., 0., 9., 0., 0.], dtype=float32)

In [438]:
request_array = env_state.request_array[0]
counter=jnp.array([3,2,2])
print(request_array)
#action_counter = env_state.action_counter
print(counter)
print(env_state.action_counter)
env_state.action_counter = decrease_last_element(env_state.action_counter)
#env_state = decrement_action_counter(env_state)
print(env_state.action_counter)

env_state.action_counter = decrease_last_element(env_state.action_counter)
print(counter)


In [691]:
@jax.jit
def decrease_last_element(array):
    last_value_mask = jnp.arange(array.shape[0]) == array.shape[0] - 1
    return jnp.where(last_value_mask, array - 1, array)

#array = jnp.array([1, 2, 3, 4, 5])
array = decrease_last_element(array)

print(array)  # prints: [1 2 3 4 4]

[ 1  2  3  4 -2]


In [None]:
def init_path_link_arra_old(graph, k):

    def get_k_shortest_paths(g, source, target, k, weight=None):
        return list(
            islice(nx.shortest_simple_paths(g, source, target, weight=weight), k)
        )

    def create_path_link_dict(graph, k):
        link_selection_dict = {}
        for node_pair in combinations(graph.nodes, 2):
            k_paths = get_k_shortest_paths(
                graph, node_pair[0], node_pair[1], k
            )
            link_selection_dict[node_pair] = k_paths
            #self.link_selection_dict[(node_pair[1], node_pair[0])] = k_paths
        return link_selection_dict

    def find_longest_path_length(dictionary):
        return max(
            len(sublist) for list_of_lists in dictionary.values()
            for sublist in list_of_lists if
            isinstance(list_of_lists, list) and all(isinstance(sublist, list)
            for sublist in list_of_lists)
        )

    def ensure_same_number_of_paths(dictionary, k):
        for item in dictionary.values():
            if isinstance(item, list) and all(isinstance(sublist, list) for sublist in item):
                while len(item) < k:
                    item.append([])
        return dictionary

    path_link_dict = create_path_link_dict(graph, k)
    path_link_dict = ensure_same_number_of_paths(path_link_dict, k)
    longest_length = find_longest_path_length(path_link_dict)
    sublists = []

    for sublist_list in path_link_dict.values():
        if isinstance(sublist_list, list) and all(isinstance(sublist, list) for sublist in sublist_list):
            for sublist in sublist_list:
                padded_sublist = sublist + [jnp.nan] * (longest_length - len(sublist))
                sublists.append(padded_sublist)

    return jnp.array(sublists)


def update_slot_conditional(slot_value, slot_index, initial_slot, num_slots, value):
    return jax.lax.cond(initial_slot <= slot_index <= initial_slot+num_slots, lambda x: update_slot(slot_value, value), lambda x: no_update_slot(slot_value, value), (slot_value, value))
    # update_values = jnp.ones(100)
    # print(update_values)
    # update_values = update_values.at[:num_slots].set(value)
    # print(update_values)
    # return jax.lax.dynamic_update_slice(link, update_values, (0,))

def vmap_update_slot_conditional(link, initial_slot, num_slots, value):
    slot_indices = jnp.arange(100)
    return jax.vmap(update_slot_conditional, in_axes=(0, 0, None, None, None))(link, slot_indices, initial_slot, num_slots, value)

# def update_link_conditional(link, link_in_path, initial_slot, num_slots, value):
#     return jax.lax.cond(link_in_path == 1, lambda x: vmap_update_slot_conditional(*x), lambda x: no_update_slot_conditional(*x), (link, initial_slot, num_slots, value))

def update_link_where(link, initial_slot, num_slots, value):
    slot_indices = jnp.arange(100)
    return jnp.where((initial_slot <= slot_indices) & (slot_indices <= initial_slot+num_slots), value, link)

def update_link_conditional(link, link_in_path, initial_slot, num_slots, value):
    return jax.lax.cond(link_in_path == 1, lambda x: update_link_where(*x), lambda x: no_update_link(*x), (link, initial_slot, num_slots, value))

@jax.jit
def vmap_update_link(link_array, path, initial_slot, num_slots, value):
    return jax.vmap(update_link_conditional, in_axes=(0, 0, None, None, None))(link_array, path, initial_slot, num_slots, value)

def update_slot(slot_value, value):
    return value

def no_update_slot(slot_value, value):
    return slot_value

def no_update_slot_conditional(link, initial_slot, num_slots, value):
    return link


def no_update_link(link, initial_slot, num_slots, value):
    return link

def update_link_conditional_new(link, link_in_path, initial_slot, num_slots, value):
    return jax.lax.cond(link_in_path == 1, lambda x: update_link(*x), lambda x: no_update_link(*x), (link, initial_slot, num_slots, value))

#@jax.jit
def vmap_update_link_conditional_new(link_array, path, initial_slot, num_slots, value):
    return jax.vmap(update_link_conditional_new, in_axes=(0, 0, None, None, None))(link_array, path, initial_slot, num_slots, value)

def update_link_1(link, initial_slot, value):
    update_values = jnp.full((1, ), value)
    return jax.lax.dynamic_update_slice(link, update_values, (initial_slot,))

def update_link_2(link, initial_slot, value):
    update_values = jnp.full((2, ), value)
    return jax.lax.dynamic_update_slice(link, update_values, (initial_slot,))

def update_link_3(link, initial_slot, value):
    update_values = jnp.full((3, ), value)
    return jax.lax.dynamic_update_slice(link, update_values, (initial_slot,))

def update_link_4(link, initial_slot, value):
    update_values = jnp.full((4, ), value)
    return jax.lax.dynamic_update_slice(link, update_values, (initial_slot,))

def update_link_5(link, initial_slot, value):
    update_values = jnp.full((5, ), value)
    return jax.lax.dynamic_update_slice(link, update_values, (initial_slot,))

def no_update(link, initial_slot, value):
    return link

def no_update_conditional(link, link_in_path, initial_slot, value):
    return link

def update_link_conditional_1(link, link_in_path, initial_slot, value):
    return jax.lax.cond(link_in_path == 1, lambda x: update_link_1(*x), lambda x: no_update(*x), (link, initial_slot, value))

def update_link_conditional_2(link, link_in_path, initial_slot, value):
    return jax.lax.cond(link_in_path == 1, lambda x: update_link_2(*x), lambda x: no_update(*x), (link, initial_slot, value))

def update_link_conditional_3(link, link_in_path, initial_slot, value):
    return jax.lax.cond(link_in_path == 1, lambda x: update_link_3(*x), lambda x: no_update(*x), (link, initial_slot, value))

def update_link_conditional_4(link, link_in_path, initial_slot, value):
    return jax.lax.cond(link_in_path == 1, lambda x: update_link_4(*x), lambda x: no_update(*x), (link, initial_slot, value))

def update_link_conditional_5(link, link_in_path, initial_slot, value):
    return jax.lax.cond(link_in_path == 1, lambda x: update_link_5(*x), lambda x: no_update(*x), (link, initial_slot, value))

#TODO - Problem - jax.lax.switch cannot be jitted
#def update_link_conditional(num_slots, branches, link, link_in_path, initial_slot, value):
#    return jax.lax.switch(num_slots, branches, *(link, link_in_path, initial_slot, value))

#@jax.jit
def vmap_update_link_conditional(num_slots, branches, link_array, path, initial_slot, value):
    return jax.vmap(update_link_conditional, in_axes=(None, None, 0, 0, None, None))(num_slots, branches, link_array, path, initial_slot, value)

branches = [no_update_conditional, update_link_conditional_1, update_link_conditional_2, update_link_conditional_3, update_link_conditional_4, update_link_conditional_5]


request = jax.lax.dynamic_slice(
    request_array,
    ((action_counter[-1]-1)*2, ),
    (3, )
)
print(request)

def update_node_departure_array(node_array, s_node, d_node, value, n=2):
    node_indices = jnp.arange(node_array.shape[0])
    node_array = vmap_update_node_departure(node_indices, node_array, s_node, d_node, value)
    return node_array

# def update_node_array(state, node, value):
#     node_indices = jnp.arange(state.node_array.shape[0])
#     state.node_array = jnp.where(node_indices == node, value, state.node_array)
#     return state

# def implement_selected_nodes(selected_nodes, s_node, d_node, s_request, d_request, n=2):
#     node_indices = jnp.arange(selected_nodes.shape[0])
#     selected_nodes = update_selected_nodes(node_indices, selected_nodes, d_node, d_request)
#     selected_nodes = jax.lax.cond(n == 2, lambda x: update_selected_nodes(*x), lambda x: x[1], (node_indices, selected_nodes, s_node, s_request))
#     return selected_node