In [None]:
## Actor-Critic Model: Testing on the HalfCheetah Environment

In [4]:
## IMPORTANT: for MuJoCo to work, Jupyter must be launched from the command line (i.e. run `jupyter lab` in a terminal)
import gym
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense

#############################################################################################

# ENVIRONMENT SUMMARY: 
# -> Continuous action space

# Observation -> 17 items
# These correspond to the position, angle and speed of the joints

# Action -> 6 items
# These correspond to 6 actuators in the cheetah 
# Value specifies the torque

# Reward -> 1 item
# Reward is proportional to the velocity of the cheetah
# A control (ctrl) cost is subtracted from the reward; it appears to scale with the applied torque (unconfirmed).

############################################################################################

# TODO: 
# - Add the correct layer dimensions

# CONSTANTS:
# Upper bound on the number of environment steps taken by the rollout loop below.
NB_FRAMES = 10_000

# returns the value function 
class Critic(Model):
    """State-value network: maps observations to a scalar value estimate V(s).

    Architecture: two ReLU hidden layers followed by a single linear output unit.
    The hidden widths are provisional (the file carries a TODO to choose the
    correct layer dimensions).
    """

    def __init__(self):
        super().__init__()
        # Hidden layers, created in the same order as they are applied in call().
        self.d1 = Dense(2048, activation='relu')
        self.d2 = Dense(1536, activation='relu')
        # Linear head: activation=None leaves the value output unsquashed.
        self.v = Dense(1, activation=None)

    def call(self, input_data):
        """Return the value estimate for ``input_data``.

        NOTE(review): assumes ``input_data`` is a batch of observations shaped
        (batch, obs_dim); Dense acts on the last axis, so the result has a
        trailing dimension of 1.
        """
        return self.v(self.d2(self.d1(input_data)))

# returns the action
class Actor(Model):
    """Policy network: maps observations to a continuous action vector.

    The target environment has a continuous action space of 6 actuator torques.
    The output layer therefore uses ``tanh`` (range [-1, 1]) scaled by
    ``action_bound``. ``softmax`` is not used because it yields a probability
    distribution: every component is non-negative and the vector sums to 1, so
    it could never produce a negative torque and it would couple all actuators.

    Args:
        action_dim: number of action components. Defaults to 6, the actuator
            count for the cheetah.
        action_bound: symmetric bound on each action component, so outputs lie
            in [-action_bound, action_bound]. May be a scalar or a per-component
            array (broadcast against the output).
            NOTE(review): the default 1.0 assumes the env's action space is
            Box(-1, 1); confirm against ``env.action_space``.
    """

    def __init__(self, action_dim=6, action_bound=1.0):
        super().__init__()
        # Hidden layers; widths are provisional (see the layer-dimension TODO).
        self.d1 = Dense(2048, activation='relu')
        self.d2 = Dense(1536, activation='relu')
        # tanh keeps each action component bounded and allows negative torques.
        self.a = Dense(action_dim, activation='tanh')
        self.action_bound = action_bound

    def call(self, input_data):
        """Return actions in [-action_bound, action_bound] for ``input_data``."""
        x = self.d1(input_data)
        x = self.d2(x)
        # Rescale the unit-range tanh output to the environment's action bounds.
        return self.a(x) * self.action_bound
        
        

# Initialise the environment.
# NOTE(review): the header and the Actor/Critic models describe HalfCheetah
# (17-dim observations, 6 continuous actions), but this smoke test runs
# CartPole-v0 with random actions. It is presumably a stand-in until MuJoCo is
# available; swap the id when ready.
env = gym.make('CartPole-v0')

try:
    # NOTE(review): assumes the classic gym API, where reset() returns the
    # observation and step() returns a 4-tuple. Newer gym/gymnasium returns
    # (obs, info) from reset() and a 5-tuple from step(). Confirm the installed
    # version.
    observation = env.reset()

    # Take uniformly random actions until the episode ends or NB_FRAMES steps
    # have been taken.
    for frame in range(NB_FRAMES):
        env.render()

        action = env.action_space.sample()

        observation, reward, done, info = env.step(action)

        if done:
            print(f"Episode finished after {frame + 1} timesteps")
            break
finally:
    # Always release the environment and its render window, even if the loop
    # raises or is interrupted part-way through a long run.
    env.close()


1
0
0
1
0
0
0
1
1
1
1
1
0
1
0
1
0
1
0
1
1
1
1
1
1
1
1
1
0
0
0
Episode finished after 31 timesteps
