In [1]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Custom prior layer (not referenced elsewhere in this cell)
class CustomPriorLayer(tfp.layers.DistributionLambda):
    """DistributionLambda layer that emits a normal distribution per input.

    For an input tensor ``t`` the layer produces ``Normal(loc=t * 0.0,
    scale=1.0)``, i.e. a zero-mean, unit-scale normal shaped like ``t``,
    and converts it to a tensor by sampling.

    Args:
        event_size: stored as ``self.event_size``; not otherwise used here.
        dtype: dtype forwarded to the base layer (``tf.float32`` by default).
        **kwargs: forwarded unchanged to ``DistributionLambda``.
    """

    def __init__(self, event_size, dtype=tf.float32, **kwargs):
        # The bound method below is what DistributionLambda calls on each input.
        super().__init__(
            make_distribution_fn=self.make_distribution_fn,
            convert_to_tensor_fn=tfd.Distribution.sample,
            dtype=dtype,
            **kwargs
        )
        self.event_size = event_size

    def make_distribution_fn(self, t):
        """Return the prior for input ``t``; edit here to change the prior."""
        # Multiplying by 0.0 gives a zero location with the same shape/dtype as t.
        zero_loc = t * 0.0
        return tfd.Normal(loc=zero_loc, scale=1.0)

# Define Bayesian Q Network
class BayesianQNetwork(tf.keras.Model):
    """Q-network made of two variational (Flipout) dense layers.

    Maps a batch of states to one Q-value per action through a ReLU hidden
    layer and a linear output layer. Kernel and bias values are drawn from
    mean-field normal posteriors (``d.sample()``) on every forward pass, so
    repeated calls on the same input generally return different Q-values.

    Each layer is given a divergence function that returns the
    posterior-vs-prior KL divided by ``kl_divisor``. Per TFP's documentation
    the layers collect these terms in ``self.losses``; they only influence
    training if the training loop adds them to its loss.

    No priors are passed explicitly, so both layers use their default
    priors. In TFP's ``DenseFlipout`` the default bias prior is ``None``, in
    which case no bias KL term is produced (confirm against the installed
    TFP version).

    Args:
        state_dim: size of each input state vector; also the default KL divisor.
        action_dim: number of actions, i.e. width of the output layer.
        hidden_units: width of the hidden ReLU layer (128 by default).
        kl_divisor: value every KL term is divided by. Defaults to
            ``state_dim``, matching the original scaling.
    """

    def __init__(self, state_dim, action_dim, hidden_units=128, kl_divisor=None):
        super(BayesianQNetwork, self).__init__()
        if kl_divisor is None:
            kl_divisor = state_dim
        # Convert once: dividing a float32 KL tensor by a Python float keeps
        # float32, and avoids re-casting on every divergence evaluation.
        kl_scale = float(kl_divisor)

        def scaled_kl(q, p, _):
            # Posterior (q) vs. prior (p) KL, down-weighted by kl_scale. The
            # third argument is supplied by the TFP layer and is unused.
            return tfd.kl_divergence(q, p) / kl_scale

        self.dense1 = self._make_flipout_layer(hidden_units, 'relu', scaled_kl)
        self.dense2 = self._make_flipout_layer(action_dim, None, scaled_kl)

    @staticmethod
    def _make_flipout_layer(units, activation, divergence_fn):
        """Build a DenseFlipout layer with sampled mean-field normal kernel
        and bias posteriors, using ``divergence_fn`` for both KL terms."""
        return tfp.layers.DenseFlipout(
            units,
            activation=activation,
            kernel_posterior_fn=tfp.layers.default_mean_field_normal_fn(),
            kernel_posterior_tensor_fn=lambda d: d.sample(),
            kernel_divergence_fn=divergence_fn,
            bias_posterior_fn=tfp.layers.default_mean_field_normal_fn(),
            bias_posterior_tensor_fn=lambda d: d.sample(),
            bias_divergence_fn=divergence_fn,
        )

    def call(self, state):
        """Return Q-values for ``state``.

        Presumably ``state`` is (batch, state_dim), giving (batch, action_dim).
        """
        hidden = self.dense1(state)
        return self.dense2(hidden)

# Define BDQN
class BayesianDQNAgent:
    """Bayesian Deep Q-Network (BDQN) agent trained on synthetic transitions.

    Wraps a ``BayesianQNetwork`` and an Adam optimizer. Each update minimizes
    a Huber loss between ``Q(s, a)`` and the one-step TD target
    ``r + discount_factor * max_a' Q(s', a')``, plus the network's KL
    penalties scaled by ``kl_weight``.

    Args:
        state_dim: size of each state vector.
        action_dim: number of discrete actions.
        discount_factor: TD discount (0.99 by default). This was previously
            read from a module-level global of the same name.
        learning_rate: Adam learning rate (1e-3 by default).
        kl_weight: multiplier on the sum of the network's KL losses. Set it
            to 0.0 to reproduce the original objective, which ignored them.
    """

    def __init__(self, state_dim, action_dim, discount_factor=0.99,
                 learning_rate=1e-3, kl_weight=1.0):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.discount_factor = discount_factor
        self.kl_weight = kl_weight
        self.model = BayesianQNetwork(state_dim, action_dim)
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    def train_step(self, states, actions, rewards, next_states):
        """Apply one gradient update on a batch of transitions.

        Args:
            states: current states, presumably (batch, state_dim) float32.
            actions: integer action index per transition, shape (batch,).
            rewards: reward per transition, shape (batch,); cast to float32.
            next_states: successor states, same shape as ``states``.

        Returns:
            The scalar loss tensor (Huber TD loss + weighted KL).
        """
        # Align the reward dtype with the float32 Q-values instead of relying
        # on implicit conversion of a float64 NumPy array.
        rewards = tf.cast(rewards, tf.float32)
        with tf.GradientTape() as tape:
            q_values = self.model(states)
            next_q_values = self.model(next_states)

            # Treat the bootstrapped target as a constant (semi-gradient TD).
            # Otherwise the update also pulls next-state Q-values toward the
            # current estimate.
            target_q_values = tf.stop_gradient(
                rewards + self.discount_factor * tf.reduce_max(next_q_values, axis=1)
            )
            # Select Q(s, a) for the action actually taken in each transition.
            action_masks = tf.one_hot(actions, self.action_dim)
            selected_q_values = tf.reduce_sum(q_values * action_masks, axis=1)
            td_loss = tf.reduce_mean(tf.losses.huber(target_q_values, selected_q_values))

            # KL terms the Bayesian layers collect in model.losses during the
            # forward passes. Without them the posteriors are never pulled
            # toward their priors.
            kl_loss = sum(self.model.losses)
            loss = td_loss + self.kl_weight * kl_loss

        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        return loss

    def train(self, num_episodes, batch_size):
        """Run ``num_episodes`` gradient steps, each on a fresh synthetic batch.

        Each "episode" is a single update on ``batch_size`` newly generated
        transitions, not an environment rollout.

        Returns:
            List of per-step loss values as Python floats.
        """
        losses = []
        for _ in range(num_episodes):
            states, actions, rewards, next_states = generate_synthetic_data_with_turnaround_time(
                batch_size, self.state_dim, self.action_dim
            )
            loss = self.train_step(states, actions, rewards, next_states)
            losses.append(float(loss))
        return losses

# Generate synthetic dataset with turnaround time
def generate_synthetic_data_with_turnaround_time(num_samples, state_dim, action_dim, rng=None):
    """Generate random one-step transitions whose rewards come from turnaround time.

    States and next states are drawn uniformly from [0, 1). Actions are
    uniform integers in [0, action_dim). Each reward is
    ``1 / (turnaround_time + 1e-6)``, where the turnaround time comes from
    ``compute_turnaround_times(states, actions)``.

    Args:
        num_samples: number of transitions to generate.
        state_dim: length of each state vector.
        action_dim: number of discrete actions.
        rng: optional randomness source exposing ``rand`` and ``randint``
            (e.g. ``np.random.RandomState(seed)``) for reproducible data.
            Defaults to the global ``np.random`` state, so draws are unchanged.

    Returns:
        Tuple ``(states, actions, rewards, next_states)``. ``states`` and
        ``next_states`` are float32 of shape (num_samples, state_dim).
        ``actions`` is an integer array of shape (num_samples,). ``rewards``
        is float32 of shape (num_samples,).
    """
    if rng is None:
        rng = np.random
    states = rng.rand(num_samples, state_dim).astype(np.float32)
    actions = rng.randint(action_dim, size=num_samples)
    next_states = rng.rand(num_samples, state_dim).astype(np.float32)

    # Reward is inversely proportional to turnaround time. The epsilon only
    # prevents division by zero: a zero turnaround yields a reward of 1e6.
    # NOTE(review): with the current turnaround model, action 0 always gives a
    # zero turnaround, so these spikes are common.
    turnaround_times = compute_turnaround_times(states, actions)
    # float32 * int64 promotes to float64; cast back to match the states.
    rewards = (1.0 / (turnaround_times + 1e-6)).astype(np.float32)

    return states, actions, rewards, next_states

# Compute turnaround times based on states and actions
def compute_turnaround_times(states, actions):
    """Return one turnaround time per sample: the row sums of ``states * actions``.

    A 1-D ``actions`` vector (one action index per sample) is treated as a
    column. Each row of ``states`` is therefore scaled by that sample's
    action before summing, i.e. ``turnaround[i] = actions[i] * sum(states[i])``.
    Any other ``actions`` shape is used as-is with NumPy broadcasting.

    Args:
        states: 2-D array, presumably (num_samples, state_dim).
        actions: 1-D array of length num_samples, or an array that broadcasts
            against ``states`` (e.g. shape (num_samples, 1)).

    Returns:
        1-D array with one turnaround time per row of ``states``.
    """
    column = actions[:, None] if actions.ndim == 1 else actions
    weighted = states * column
    return np.sum(weighted, axis=1)


# Round Robin Scheduling
def round_robin_scheduling(states, action_dim):
    """Assign actions cyclically (0, 1, ..., action_dim - 1, 0, 1, ...) to states.

    Args:
        states: array whose first axis indexes samples; only its length is used.
        action_dim: number of actions to cycle through.

    Returns:
        Integer array of shape (num_samples, 1) with one action per state.
    """
    num_samples = states.shape[0]
    # Use modulo rather than np.tile(arange, num_samples // action_dim). The
    # tile version dropped the trailing num_samples % action_dim rows, which
    # broke downstream broadcasting against `states`.
    return (np.arange(num_samples) % action_dim)[:, np.newaxis]

# First Come First Serve Scheduling
def first_come_first_serve_scheduling(states, action_dim):
    """Give the i-th state, in arrival (row) order, action ``i % action_dim``.

    Only the number of rows in ``states`` is used. As written, this is a
    cyclic assignment over action indices.

    Returns:
        Integer array of shape (num_samples, 1).
    """
    arrival_order = np.arange(states.shape[0])
    return np.remainder(arrival_order, action_dim).reshape(-1, 1)




# Comparison function
def compare_algorithms(state_dim, action_dim, num_samples, num_episodes=100, batch_size=32):
    """Compare turnaround times of Round Robin, FCFS and a trained BDQN agent.

    All three strategies are evaluated on the same freshly generated batch of
    ``num_samples`` states. The BDQN agent is trained on separate synthetic
    batches before evaluation.

    Args:
        state_dim: length of each state vector.
        action_dim: number of discrete actions.
        num_samples: number of evaluation states.
        num_episodes: training steps for the BDQN agent (100 by default).
        batch_size: transitions per BDQN training step (32 by default).

    Returns:
        Dict mapping ``"round_robin"``, ``"fcfs"`` and ``"bdqn"`` to the
        per-state turnaround-time arrays.
    """
    # Generate synthetic evaluation states (actions/rewards are not needed).
    states, _, _, _ = generate_synthetic_data_with_turnaround_time(num_samples, state_dim, action_dim)

    # Round Robin Scheduling
    rr_actions = round_robin_scheduling(states, action_dim)
    rr_turnaround_times = compute_turnaround_times(states, rr_actions)
    print("Round Robin Turnaround Times:", rr_turnaround_times)

    # First Come First Serve Scheduling
    fcfs_actions = first_come_first_serve_scheduling(states, action_dim)
    fcfs_turnaround_times = compute_turnaround_times(states, fcfs_actions)
    print("First Come First Serve Turnaround Times:", fcfs_turnaround_times)

    # Bayesian Deep Q Network (BDQN) Agent
    agent = BayesianDQNAgent(state_dim, action_dim)
    agent.train(num_episodes=num_episodes, batch_size=batch_size)
    # predict() runs a single stochastic forward pass, so the greedy actions
    # here come from one posterior sample of the weights.
    bdqn_actions = np.argmax(agent.model.predict(states), axis=1)
    bdqn_turnaround_times = compute_turnaround_times(states, bdqn_actions)
    print("BDQN Agent Turnaround Times:", bdqn_turnaround_times)

    # One-line summary so the strategies can be compared at a glance.
    print(
        f"Mean turnaround time -- Round Robin: {np.mean(rr_turnaround_times):.4f}, "
        f"FCFS: {np.mean(fcfs_turnaround_times):.4f}, "
        f"BDQN: {np.mean(bdqn_turnaround_times):.4f}"
    )

    return {
        "round_robin": rr_turnaround_times,
        "fcfs": fcfs_turnaround_times,
        "bdqn": bdqn_turnaround_times,
    }

# Experiment hyperparameters. They stay as module-level names because code in
# this file may read them as globals.
state_dim, action_dim = 10, 5  # state vector length, number of actions
discount_factor = 0.99  # TD discount factor
num_samples = 100  # number of evaluation states shared by all schedulers

# Run the comparison; prints turnaround times for each scheduling strategy.
compare_algorithms(state_dim=state_dim, action_dim=action_dim, num_samples=num_samples)


Round Robin Turnaround Times: [ 0.          4.91109149  7.68943365 10.26971244 26.1755574   0.
  4.2998883   9.18228814 17.62279399 17.1511758   0.          5.03489941
 11.03335495  7.56094029 19.54314934  0.          4.90240199  8.11998791
 17.0424419  20.75178397  0.          4.19641779 10.25697216 17.07620962
 16.38059296  0.          6.03649858  9.5873768  10.55666609 19.73934418
  0.          5.21571323  9.9474069  16.83855839 25.49075949  0.
  3.88664019  9.66506543 10.32707239 19.28122073  0.          5.67188334
 11.22446498 16.61866951 18.72075322  0.          5.08760454  9.3381322
 12.70178927 19.23595476  0.          3.50312119  9.31405682 13.61681503
 22.74642542  0.          6.7799308   9.02213368 13.22170704 23.35412055
  0.          6.63748199  9.83531651 15.53478035 24.58244264  0.
  4.48033702 11.74939877 11.02520859 21.07400441  0.          5.46844797
 14.0488165  15.73491769 29.32312554  0.          5.02374877 10.12651366
 15.42438024 16.38272485  0.          5.348908

  loc = add_variable_fn(
  untransformed_scale = add_variable_fn(


BDQN Agent Turnaround Times: [ 5.31281598  0.          0.          0.          6.54388935  8.96310205
  8.59977659  9.18228814 11.74852933  8.5755879   7.30143046 15.10469824
 11.03335495  5.04062686 14.657362    6.5621141   4.90240199  8.11998791
 11.36162794 10.37589198 15.30148858  8.39283559 10.25697216 11.38413975
  8.19029648 15.43698223 12.07299715 14.38106519  7.03777739  0.
  0.         10.43142647 19.89481381  0.         12.74537975 20.59410128
  0.         19.33013086  3.44235746  0.         21.20111158 11.34376669
 11.22446498 22.15822601  0.         19.47830671  0.         18.67626441
  4.23392976  9.61797738 10.19368899  3.50312119  9.31405682  4.53893834
 22.74642542  5.93585399  0.         18.04426736  8.81447136  0.
  4.19548065 13.27496397 19.67063302 10.35652024 12.29122132  6.8578963
  8.96067403 11.74939877  7.35013906  5.2685011   0.         10.93689594
 14.0488165  10.48994513  0.          0.         10.04749754 10.12651366
  0.          0.          0.         10