In [250]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense


class Actor(tf.keras.Model):
    def __init__(self, state_shape, action_dim, max_action, units=[40,10], name="Actor"):
        super().__init__(name=name)

        self.l1 = Dense(units[0], name="L1",kernel_initializer='random_uniform',
                bias_initializer='random_uniform')
        self.l2 = Dense(units[1], name="L2",kernel_initializer='random_uniform',
                bias_initializer='random_uniform')
        self.l3 = Dense(action_dim, name="L3",kernel_initializer='random_uniform',
                bias_initializer='random_uniform')

        self.max_action = max_action

        with tf.device("/cpu:0"):
            self(tf.constant(np.zeros(shape=(1,)+state_shape, dtype=np.float32)))

    def call(self, inputs):
        features = tf.nn.relu(self.l1(inputs))
        features = tf.nn.relu(self.l2(features))
        features = self.l3(features)
        action = self.max_action * tf.nn.tanh(features)
        return action
    
class Critic(tf.keras.Model):
    def __init__(self, state_shape, action_dim, units=[400, 300], name="Critic"):
        super().__init__(name=name)

        self.l1 = Dense(units[0], name="L1")
        self.l2 = Dense(units[1], name="L2")
        self.l3 = Dense(1, name="L3")

        dummy_state = tf.constant(
            np.zeros(shape=(1,)+state_shape, dtype=np.float32))
        dummy_action = tf.constant(
            np.zeros(shape=[1, action_dim], dtype=np.float32))
        with tf.device("/cpu:0"):
            self([dummy_state, dummy_action])

    def call(self, inputs):
        states, actions = inputs
        features = tf.concat([states, actions], axis=1)
        features = tf.nn.relu(self.l1(features))
        features = tf.nn.relu(self.l2(features))
        features = self.l3(features)
        return features
    
def update_towards_net2(net1,net2,tau=.01):
    for source_variable,target_variable in zip(net1.trainable_variables, net2.trainable_variables):
        source_variable.assign(tau*source_variable + (1.0 - tau)*target_variable)
    return

In [251]:
critic_optimizer = tf.keras.optimizers.Adam(lr=0.001)
actor_optimizer = tf.keras.optimizers.Adam(lr=0.001)

In [252]:

def give_action():
    with tf.device("/cpu:0"):
        dummy_state = tf.constant([[1.]])
        action = actor(dummy_state)
        action += tf.random.normal(shape=action.shape, mean=0., stddev=0.01, dtype=tf.float32)
        return tf.clip_by_value(action, -actor.max_action, actor.max_action)

In [253]:
def learn(rewards, actions):
    with tf.device("/cpu:0"):
        with tf.GradientTape() as tape:
            tape.watch(critic.trainable_variables)
            
            q_values = critic(actions)
            loss = tf.keras.losses.MSE(rewards, q_values)
            loss = tf.square(q_values)
            critic_grad = tape.gradient(loss,critic.trainable_variables)
            critic_optimizer.apply_gradients(zip(critic_grad, critic.trainable_variables))
            
            
        with tf.GradientTape() as tape:
            tape.watch(actor.trainable_variables)
            actor_loss = -tf.reduce_mean(critic(actions))
            actor_grad = tape.gradient(actor_loss, actor.trainable_variables)
            #actor_optimizer.apply_gradients(zip(actor_grad, actor.trainable_variables))
            return
        
actor=Actor((1,),1,3)


length = 10**3
alpha=0.56
cum_freq=0
batch_length=10
rewards_batch=[]
actions_batch=[]
for i in range(length):
    phase = np.random.choice([-1,1],1)[0]
    beta = give_action()
    p0 = np.exp(-(beta.numpy().flatten()[0]-(phase*alpha))**2)
    outcome = np.random.choice([0,1],1,p=[p0,1-p0])
    if (-1)**(outcome+1) == phase:
        reward=1
    else:
        reward = 0
    rewards_batch.append(tf.constant([reward]))
    actions_batch.append(beta)
    
    if i%batch_length==1:
        actions_batch = tf.stack(actions_batch)
        print(actions_batch)
        learn(rewards_batch, actions_batch)
        actions_batch, rewards_batch = [], []

tf.Tensor(
[[[0.04178788]]

 [[0.01929726]]], shape=(2, 1, 1), dtype=float32)
tf.Tensor(
[[[0.04815506]]

 [[0.03836218]]

 [[0.02831914]]

 [[0.04812082]]

 [[0.02739566]]

 [[0.04378349]]

 [[0.04458015]]

 [[0.05192588]]

 [[0.02413571]]

 [[0.05465184]]], shape=(10, 1, 1), dtype=float32)
tf.Tensor(
[[[0.02841054]]

 [[0.04786437]]

 [[0.047496  ]]

 [[0.05577853]]

 [[0.05543279]]

 [[0.04215296]]

 [[0.03653286]]

 [[0.03201769]]

 [[0.03591206]]

 [[0.03287494]]], shape=(10, 1, 1), dtype=float32)
tf.Tensor(
[[[0.03415036]]

 [[0.05028983]]

 [[0.03732533]]

 [[0.04837058]]

 [[0.05198766]]

 [[0.04517936]]

 [[0.04044571]]

 [[0.05877864]]

 [[0.03674575]]

 [[0.04136035]]], shape=(10, 1, 1), dtype=float32)
tf.Tensor(
[[[0.03752531]]

 [[0.05620482]]

 [[0.03494908]]

 [[0.03550891]]

 [[0.03815434]]

 [[0.04292254]]

 [[0.03939573]]

 [[0.02537421]]

 [[0.04692455]]

 [[0.03019547]]], shape=(10, 1, 1), dtype=float32)
tf.Tensor(
[[[0.03139348]]

 [[0.05911915]]

 [[0.03012076]]



tf.Tensor(
[[[0.04271964]]

 [[0.05377543]]

 [[0.03021811]]

 [[0.04763088]]

 [[0.03902977]]

 [[0.05274549]]

 [[0.02930564]]

 [[0.0374557 ]]

 [[0.04729333]]

 [[0.04471428]]], shape=(10, 1, 1), dtype=float32)
tf.Tensor(
[[[0.05199075]]

 [[0.04686821]]

 [[0.03647422]]

 [[0.02120236]]

 [[0.04684663]]

 [[0.05162344]]

 [[0.04598078]]

 [[0.04470536]]

 [[0.02795076]]

 [[0.02860451]]], shape=(10, 1, 1), dtype=float32)
tf.Tensor(
[[[0.03842534]]

 [[0.02807076]]

 [[0.05409264]]

 [[0.02702149]]

 [[0.04404267]]

 [[0.04387343]]

 [[0.0467727 ]]

 [[0.0282722 ]]

 [[0.05566067]]

 [[0.0276312 ]]], shape=(10, 1, 1), dtype=float32)
tf.Tensor(
[[[0.06101277]]

 [[0.04255589]]

 [[0.04683475]]

 [[0.03377665]]

 [[0.03894231]]

 [[0.03080815]]

 [[0.0398516 ]]

 [[0.05670654]]

 [[0.04860478]]

 [[0.04984761]]], shape=(10, 1, 1), dtype=float32)
tf.Tensor(
[[[0.0412303 ]]

 [[0.05659445]]

 [[0.05143331]]

 [[0.04078837]]

 [[0.02967511]]

 [[0.045781  ]]

 [[0.03648749]]

 [[0.06595

tf.Tensor(
[[[0.03921567]]

 [[0.02588993]]

 [[0.04431048]]

 [[0.04469596]]

 [[0.04904335]]

 [[0.02431863]]

 [[0.03717357]]

 [[0.02723612]]

 [[0.03712326]]

 [[0.0483913 ]]], shape=(10, 1, 1), dtype=float32)
tf.Tensor(
[[[0.06319162]]

 [[0.032299  ]]

 [[0.03653095]]

 [[0.04208719]]

 [[0.03679271]]

 [[0.04113271]]

 [[0.04808589]]

 [[0.0405428 ]]

 [[0.0468794 ]]

 [[0.03678706]]], shape=(10, 1, 1), dtype=float32)
tf.Tensor(
[[[0.04317117]]

 [[0.05958069]]

 [[0.02961113]]

 [[0.04050848]]

 [[0.05548758]]

 [[0.03841047]]

 [[0.05736721]]

 [[0.02480891]]

 [[0.03434369]]

 [[0.05065618]]], shape=(10, 1, 1), dtype=float32)
tf.Tensor(
[[[0.03217567]]

 [[0.05733767]]

 [[0.05269406]]

 [[0.04016976]]

 [[0.05173142]]

 [[0.05685483]]

 [[0.04430058]]

 [[0.05437972]]

 [[0.04047025]]

 [[0.05223935]]], shape=(10, 1, 1), dtype=float32)
tf.Tensor(
[[[0.03436744]]

 [[0.05468274]]

 [[0.03366679]]

 [[0.04792485]]

 [[0.04197284]]

 [[0.03125559]]

 [[0.0514247 ]]

 [[0.04772

In [256]:
critic=Actor((1,),1,3)

In [258]:
critic(np.array([[2]]))

<tf.Tensor: id=1406862, shape=(1, 1), dtype=float32, numpy=array([[-0.13987286]], dtype=float32)>