In [1]:
import torch
import gymnasium as gym
import numpy as np
import random

ModuleNotFoundError: No module named 'gymnasium'

In [None]:
class PolicyCriticNetwork(torch.nn.Module):
    """Actor-critic network: one shared trunk feeding two linear heads.

    The trunk is two 64-unit fully connected layers, each followed by a ReLU.
    The policy head maps the trunk features to ``action_dim`` outputs and the
    critic head maps the same features to ``critic_dim`` outputs.
    """

    def __init__(self, state_dim, action_dim, critic_dim):
        """Build the trunk and both heads.

        Args:
            state_dim: size of the input state vector.
            action_dim: number of outputs of the policy head.
            critic_dim: number of outputs of the critic head.
        """
        super().__init__()
        nn = torch.nn
        hidden = 64

        # Layers are created trunk -> policy -> critic so that parameter
        # registration (and seeded initialisation) order stays fixed.
        trunk = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        ]
        self.shared_layers = nn.Sequential(*trunk)
        self.policy_layers = nn.Sequential(nn.Linear(hidden, action_dim))
        self.critic_layers = nn.Sequential(nn.Linear(hidden, critic_dim))

    def forward(self, state):
        """Return ``(policy_head_output, critic_head_output)`` for ``state``."""
        hidden_features = self.shared_layers(state)
        policy_out = self.policy_layers(hidden_features)
        value_out = self.critic_layers(hidden_features)
        return policy_out, value_out

In [None]:
class Memory:
    """Append-only buffer of experiences.

    Every stored experience is a list whose entries follow the order of
    ``self.keys``.
    """

    def __init__(self):
        # Field order of every stored experience record.
        self.keys = ['state', 'reward', 'terminated', 'action', 'log_prob', 'v_pred']
        self.data = []

    def reset(self):
        """Discard all stored experiences."""
        self.data = []

    def add_experience(self, exp):
        """Store one experience given as a mapping.

        Args:
            exp: mapping that must contain every name in ``self.keys``;
                extra entries are ignored.

        Returns:
            True if the experience was stored, False (and nothing is stored)
            if any required key is missing.
        """
        if any(key not in exp for key in self.keys):
            return False
        self.data.append([exp[key] for key in self.keys])
        return True

    def __len__(self):
        """Number of stored experiences."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the stored record(s) at ``index`` (list indexing rules)."""
        return self.data[index]

    def sample(self, batch_size=-1, transpose=True, shuffle=False):
        """Return a batch of stored experiences.

        Args:
            batch_size: number of records to return; ``-1`` means all of
                them. Values larger than the buffer are clamped to the buffer
                size and other negative values yield an empty batch.
            transpose: if True, return one tuple per field (column layout)
                instead of one list per experience (row layout).
            shuffle: if False, the first ``batch_size`` records are returned
                in insertion order. If True, ``batch_size`` distinct records
                are drawn uniformly from the *whole* buffer.

        Returns:
            A list of records, or a list of per-field tuples when
            ``transpose`` is True. Empty list when nothing is selected.
        """
        total = len(self)
        if batch_size == -1:
            batch_size = total
        # Clamp so an oversized request cannot index past the buffer.
        batch_size = max(0, min(batch_size, total))

        if shuffle:
            # Draw from every stored index, not just the first batch_size.
            batch_index = random.sample(range(total), batch_size)
        else:
            batch_index = range(batch_size)

        result = [self.data[index] for index in batch_index]

        if transpose:
            result = list(zip(*result))

        return result


In [None]:
# Environment: CartPole-v1 created with on-screen ("human") rendering.
env = gym.make('CartPole-v1', render_mode='human')

# Problem sizes read from the environment's spaces.
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
critic_dim = 1  # the critic head produces a single output per state

# Learning rate passed to Adam below.
lr = 1e-4

# Experience buffer, actor-critic network and the optimiser over its weights.
memory = Memory()
model = PolicyCriticNetwork(state_dim, action_dim, critic_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


In [None]:
def n_step(rewards, next_v_pred, terminations, gamma):
    """Compute bootstrapped discounted returns, walking the steps backwards.

    For each step ``t`` (last to first)::

        ret[t] = rewards[t] + gamma * ret[t + 1] * (1 - terminations[t])

    where the value after the final step is ``next_v_pred``.

    Args:
        rewards: indexable sequence of per-step rewards.
        next_v_pred: bootstrap value used after the last step.
        terminations: indexable sequence of 0/1 flags; a 1 at step ``t``
            zeroes the contribution of everything after ``t``.
        gamma: discount factor.

    Returns:
        1-D ``float32`` tensor with one return per step (empty for empty
        input).
    """
    future = next_v_pred
    returns = []
    for t in reversed(range(len(rewards))):
        future = rewards[t] + gamma * future * (1 - terminations[t])
        # Append + single reverse keeps the loop linear; inserting at the
        # front of the list on every step would be quadratic.
        returns.append(future)
    returns.reverse()
    # dtype is pinned so the result matches compute_gae's float32 outputs.
    return torch.tensor(returns, dtype=torch.float32)

def n_step_adv(rewards, v_preds, next_v_pred, terminations, gamma):
    """Return ``(advantages, value_targets)`` based on ``n_step`` returns.

    The value targets are the returns produced by ``n_step``; the advantages
    are those returns minus the supplied value predictions ``v_preds``.
    """
    returns = n_step(rewards, next_v_pred, terminations, gamma)
    return returns - v_preds, returns

def compute_gae(rewards, dones, values, next_value, gamma, lam):
    """Compute generalised advantage estimates and value targets.

    Walking the steps backwards::

        delta[i] = rewards[i] + gamma * V[i + 1] * (1 - dones[i]) - V[i]
        gae[i]   = delta[i] + gamma * lam * (1 - dones[i]) * gae[i + 1]

    where ``V[i] = values[i]`` and the value after the last step is
    ``next_value``.

    Args:
        rewards: indexable sequence of per-step rewards.
        dones: indexable sequence of 0/1 flags; a 1 at step ``i`` cuts both
            the bootstrap value and the accumulated advantage at that step.
        values: per-step value predictions. May be a tensor or a sequence;
            it is flattened to 1-D so a column-shaped ``(N, 1)`` input
            cannot broadcast the targets to ``(N, N)``.
        next_value: value estimate for the state after the last step.
        gamma: discount factor.
        lam: GAE smoothing coefficient.

    Returns:
        ``(advantages, v_target)``: two 1-D ``float32`` tensors with one
        entry per step, where ``v_target = advantages + values``.
    """
    values = torch.as_tensor(values, dtype=torch.float32).reshape(-1)

    advantages = []
    gae = 0.0
    for i in reversed(range(len(rewards))):
        not_done = 1 - dones[i]
        delta = rewards[i] + gamma * next_value * not_done - values[i]
        gae = delta + gamma * lam * not_done * gae
        # Append + single reverse keeps the loop linear; inserting at the
        # front of the list on every step would be quadratic.
        advantages.append(gae)
        # The value of this step is the "next value" of the previous step.
        next_value = values[i]
    advantages.reverse()

    advantages = torch.tensor(advantages, dtype=torch.float32)
    v_target = advantages + values
    return advantages, v_target

In [None]:

# Hyper-parameters read by train():
#   gamma      - discount factor
#   lam        - lambda passed to compute_gae
#   cliped_eps - half-width of the clipping interval around 1 for the
#                probability ratio
gamma = 0.99
lam = 0.95
cliped_eps = 0.2

# Mean-squared-error criterion used for the value loss.
mse_criterion = torch.nn.MSELoss(reduction='mean')

def test(model):
    """Run one greedy evaluation episode on the module-level ``env``.

    The action with the largest policy-head output is taken at every step.
    The episode's summed reward is printed and also returned.

    Args:
        model: network whose forward pass returns ``(policy_output, value)``.

    Returns:
        Sum of the rewards collected during the episode.
    """
    model = model.eval()
    state, info = env.reset()
    total_rewards = 0
    done = False
    while not done:
        # Pure inference: no autograd graph is needed here.
        with torch.no_grad():
            param, v_pred = model(torch.as_tensor(state, dtype=torch.float32))
        action = param.argmax().item()
        state, reward, terminated, truncated, info = env.step(action)
        # The episode is over on failure *or* when the env truncates it
        # (e.g. a time limit); stepping past either signal is invalid and
        # could otherwise keep this loop running indefinitely.
        done = terminated or truncated
        total_rewards += reward
        env.render()

    print(f'Evaluation: {total_rewards}')
    return total_rewards

def train(
    model : torch.nn.Module ,
    epoch : int
    ):
    """Collect one episode with the current policy, then run clipped-ratio
    policy/value updates on it.

    Relies on the module-level ``env``, ``optimizer``, ``gamma``, ``lam``,
    ``cliped_eps``, ``mse_criterion`` and ``compute_gae``.

    Args:
        model: network whose forward pass returns ``(action_logits, value)``.
        epoch: index of the current training iteration. Not used by the
            body; kept so existing callers keep working.

    Returns:
        Sum of the rewards of the collected episode.
    """
    model = model.train()

    # ---- 1. Roll out one episode with the current (sampling) policy ----
    rewards, states, actions, terminations, v_preds, log_probs = [], [], [], [], [], []
    state, info = env.reset()
    done = False
    while not done:
        state = torch.as_tensor(state, dtype=torch.float32)
        with torch.no_grad():
            pd_param, v_pred = model(state)
        dist = torch.distributions.Categorical(logits=pd_param)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        next_state, reward, terminated, truncated, info = env.step(action.item())

        states.append(state)
        rewards.append(reward)
        actions.append(action)
        # Only a true termination is recorded: on truncation the final step
        # still bootstraps from the value of the next state below.
        terminations.append(terminated)
        log_probs.append(log_prob)
        v_preds.append(v_pred)

        state = next_state
        # Stop on truncation as well; stepping past it is invalid and could
        # otherwise keep this loop running indefinitely.
        done = terminated or truncated

    # Value of the state reached after the last step (bootstrap value).
    with torch.no_grad():
        _, next_v_pred = model(torch.as_tensor(state, dtype=torch.float32))

    states = torch.stack(states)
    rewards = torch.tensor(rewards, dtype=torch.float32)
    actions = torch.stack(actions)
    terminations = torch.tensor(terminations, dtype=torch.float)
    log_probs = torch.stack(log_probs)
    v_preds = torch.cat(v_preds)

    # ---- 2. Advantages and value targets ----
    # n_step_adv(rewards, v_preds, next_v_pred, terminations, gamma) has the
    # same return layout and can be swapped in here instead of GAE.
    adv, v_target = compute_gae(rewards, terminations, v_preds, next_v_pred, gamma, lam)
    # std() of a single element is NaN, which would poison every weight, so
    # one-step episodes are left un-normalised.
    if adv.numel() > 1:
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)

    # ---- 3. Several optimisation passes over the same episode ----
    for _ in range(20):
        param, new_v_preds = model(states)
        dist = torch.distributions.Categorical(logits=param)
        new_log_probs = dist.log_prob(actions)

        # Probability ratio between the updated and the rollout policy,
        # clipped to [1 - cliped_eps, 1 + cliped_eps].
        ratio = torch.exp(new_log_probs - log_probs)
        clipped_ratio = torch.clamp(ratio, 1 - cliped_eps, 1 + cliped_eps)
        policy_loss = -torch.min(ratio * adv, clipped_ratio * adv).mean()

        value_loss = mse_criterion(new_v_preds.squeeze(-1), v_target)

        entropy = dist.entropy().mean()

        # Entropy is subtracted, so higher entropy lowers the loss.
        loss = policy_loss + .5 * value_loss - .1 * entropy

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return rewards.sum().item()

# Training driver: 1000 training iterations with a greedy evaluation episode
# after every 10th one. The environment is closed even if the run fails or
# is interrupted.
try:
    for epoch in range(1000):
        # Pass the real iteration index rather than a constant 0.
        train(model, epoch)
        if (epoch + 1) % 10 == 0:
            test(model)
finally:
    env.close()

In [None]:
# Manual cleanup cell: closes the environment. The training cell's `finally`
# clause already calls env.close(), so after that cell has run this call is
# redundant; it remains as a way to close the environment by hand.
env.close();