In [1]:
import tensorflow as tf

class MultiHeadValueNetwork(tf.keras.Model):
    """Two-headed value network sharing a two-layer ReLU trunk.

    The trunk (``dense1`` -> ``dense2``) feeds two independent scalar heads:
    a "fair" head and a "util" head. Both heads are trained jointly in
    ``train_step`` with the combined loss ``loss_util + beta * loss_fair``.

    Args:
        num_features: Stored on the instance; not otherwise used by the
            layers defined here (Dense layers infer their input size).
        hidden_size: Width of each of the two hidden layers.
        learning_rate: Learning rate passed to the Adam optimizer.
        learning_beta: Stored as ``learning_beta`` and also used as the
            initial value of ``eval_beta``; both are read by ``get``.
    """

    def __init__(self, num_features, hidden_size, learning_rate=0.01, learning_beta=0.0):
        super().__init__()
        self.num_features = num_features
        self.hidden_size = hidden_size
        self.learning_beta = learning_beta
        # eval_beta starts equal to learning_beta; callers may change it
        # afterwards to re-weight the fair head in `get`.
        self.eval_beta = learning_beta

        # Shared trunk.
        self.dense1 = tf.keras.layers.Dense(hidden_size, activation='relu', name='dense1')
        self.dense2 = tf.keras.layers.Dense(hidden_size, activation='relu', name='dense2')
        # One linear scalar output per head.
        self.dense3_fair = tf.keras.layers.Dense(1, name='dense3fair')
        self.dense3_util = tf.keras.layers.Dense(1, name='dense3util')

        self.optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    def call(self, inputs):
        """Run the forward pass.

        Returns:
            A ``(output_fair, output_util)`` pair, each head's output
            flattened to rank 1 via ``tf.reshape(..., [-1])``.
        """
        x = self.dense1(inputs)
        x = self.dense2(x)
        output_fair = tf.reshape(self.dense3_fair(x), [-1])
        output_util = tf.reshape(self.dense3_util(x), [-1])
        return output_fair, output_util

    @staticmethod
    def _mse(target, prediction):
        """Mean squared error between ``target`` and a flat ``prediction``.

        The target is cast to the prediction's dtype and flattened first, so
        a target shaped ``(N, 1)`` cannot silently broadcast against the
        ``(N,)`` prediction into an ``(N, N)`` difference matrix.
        """
        target = tf.reshape(tf.cast(target, prediction.dtype), [-1])
        return tf.reduce_mean(tf.square(target - prediction))

    def train_step(self, inputs, targets, beta):
        """Apply one Adam update on ``loss_util + beta * loss_fair``.

        NOTE(review): this method name shadows ``tf.keras.Model.train_step``
        with a different signature, so the model is presumably meant to be
        trained by calling this method directly rather than via ``fit`` --
        confirm.

        Args:
            inputs: Batch of input features passed to the forward pass.
            targets: Indexable pair; ``targets[0]`` is the regression target
                for the fair head and ``targets[1]`` for the util head.
            beta: Weight of the fair-head loss in the combined loss.

        Returns:
            ``(loss_fair, loss_util)``: the two unweighted mean squared
            errors computed before the update.
        """
        # A single gradient() call is made, so a non-persistent tape is
        # enough; it releases its resources as soon as gradient() returns.
        with tf.GradientTape() as tape:
            output_fair, output_util = self(inputs)
            loss_fair = self._mse(targets[0], output_fair)
            loss_util = self._mse(targets[1], output_util)
            loss = loss_util + beta * loss_fair

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return loss_fair, loss_util

    def get(self, states):
        """Return the blended value ``util + fair * (eval_beta / learning_beta)``.

        When ``learning_beta`` is 0 the multiplier is 0 (avoiding a division
        by zero), so the fair head is multiplied by 0 in the result.
        """
        value_fair, value_util = self.predict(states)
        mult = 0 if self.learning_beta == 0 else self.eval_beta / self.learning_beta
        return value_util + value_fair * mult

    def get_util(self, states):
        """Return only the util head's prediction for ``states``."""
        _, value_util = self.predict(states)
        return value_util

    def get_fair(self, states):
        """Return only the fair head's prediction for ``states``."""
        value_fair, _ = self.predict(states)
        return value_fair

    def save_model(self, save_path):
        """Write the model weights to ``save_path`` via ``save_weights``."""
        self.save_weights(save_path)

    def load_model(self, save_path):
        """Load model weights from ``save_path`` via ``load_weights``."""
        self.load_weights(save_path)

# Usage example: build a network, run a single training step, then query it.
network = MultiHeadValueNetwork(num_features=10, hidden_size=256)

# Synthetic batch of 32 samples with 10 features each, plus one random
# regression target per sample for each head. train_step reads index 0 as
# the fair-head target and index 1 as the util-head target.
inputs = tf.random.normal((32, 10))
fair_targets = tf.random.normal((32,))
util_targets = tf.random.normal((32,))
targets = (fair_targets, util_targets)

# One optimisation step; the fair-head loss is weighted by beta.
loss_fair, loss_util = network.train_step(inputs, targets, beta=0.5)

# Blended predictions. With the default learning_beta of 0.0 the fair head
# is multiplied by 0 inside get().
predictions = network.get(inputs)
print(loss_fair, loss_util, predictions)


tf.Tensor(1.3094058, shape=(), dtype=float32) tf.Tensor(1.5538275, shape=(), dtype=float32) [ 1.1124276   1.74189     2.3752391   2.0405543   2.0525792   1.2576411
  1.2783599   0.6995169   1.1694058   1.5958024   1.3515745   1.714409
  0.7434132   1.5939664   2.4877775   1.2104254  -0.01271582  1.28111
  1.8360399   1.926757    1.9759626   1.4766139   1.751501    1.6695999
  3.2707632   1.762908    0.22818184  2.1524818   1.0753635   2.8108575
  1.975531    1.3309599 ]
