In [None]:
import gymnasium as gym
import torch
from torch.nn import functional as F
import numpy as np

In [None]:

def softplus( x ):
    """Numerically stable softplus, log(1 + exp(x)), applied element-wise.

    The naive form overflows to ``inf`` once ``exp(x)`` leaves the float range
    (x of roughly 89 and above in float32). ``F.softplus`` computes the same
    function but switches to the identity for large inputs, so the result
    stays finite.

    Args:
        x: input tensor.

    Returns:
        Tensor with the same shape as ``x``.
    """
    return F.softplus( x )

class Policy(torch.nn.Module):
    """Fully connected network producing ``out_dim`` raw scores per input.

    Four hidden layers of width 256 with ReLU activations are followed by a
    linear output layer with no activation. The caller decides how to
    interpret the outputs (e.g. as logits of a categorical distribution).
    """

    def __init__(self, in_dim, out_dim):
        """Build the layers.

        Args:
            in_dim: size of the input feature vector.
            out_dim: number of output scores.
        """
        super().__init__()

        width = 256

        # Attribute names and creation order are kept stable so that
        # state_dict keys and seeded initialisation do not change.
        self.linear1 = torch.nn.Linear(in_dim, width)
        self.linear2 = torch.nn.Linear(width, width)
        self.linear3 = torch.nn.Linear(width, width)
        self.linear4 = torch.nn.Linear(width, width)

        self.linear = torch.nn.Linear(width, out_dim)

    def forward(self, x):
        """Return the un-normalised output scores for ``x``."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            hidden = F.relu(layer(hidden))

        return self.linear(hidden)

class Critic(torch.nn.Module):
    """Fully connected network mapping an input vector to a single scalar.

    Five hidden layers of width 256 with ReLU activations are followed by a
    linear layer with one output and no activation.
    """

    def __init__(self, in_dim):
        """Build the layers.

        Args:
            in_dim: size of the input feature vector.
        """
        super().__init__()

        width = 256

        # Attribute names and creation order are kept stable so that
        # state_dict keys and seeded initialisation do not change.
        self.linear1 = torch.nn.Linear(in_dim, width)
        self.linear2 = torch.nn.Linear(width, width)
        self.linear3 = torch.nn.Linear(width, width)
        self.linear4 = torch.nn.Linear(width, width)
        self.linear5 = torch.nn.Linear(width, width)

        self.linear = torch.nn.Linear(width, 1)

    def forward(self, x):
        """Return the scalar output for ``x`` (last dimension has size 1)."""
        hidden = x
        for layer in (
            self.linear1,
            self.linear2,
            self.linear3,
            self.linear4,
            self.linear5,
        ):
            hidden = F.relu(layer(hidden))

        return self.linear(hidden)


In [None]:

# Discount factor applied to future rewards in the return/advantage helpers.
gamma = .99
# Lambda used by gae() to decay the accumulated advantage.
lam = .95
# Weight of the entropy bonus in the policy loss (halved periodically during training).
beta = .01
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still works on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [None]:


def policy_gradient( log_probs , rewards , dones ):
    """Weight each log-probability by the discounted return from that step.

    The return is accumulated backwards with the module-level ``gamma`` and is
    reset wherever ``dones[t]`` is 1.

    Args:
        log_probs: tensor of per-step log-probabilities (length T).
        rewards: tensor of per-step rewards (length T); its device is used for
            the result.
        dones: per-step flags, 1 where the step ended an episode.

    Returns:
        Tensor of length T holding ``return_t * log_probs[t]``.
    """
    steps = len( rewards )
    returns = torch.zeros( steps , dtype = torch.float32 , device = rewards.device )

    running = 0
    for idx in range( steps - 1 , -1 , -1 ):
        # A done flag zeroes the contribution of everything after this step.
        running = rewards[idx] + gamma * running * ( 1 - dones[idx] )
        returns[idx] = running

    return returns * log_probs


def n_step_return ( rewards , values , next_value , dones ):
    """Compute bootstrapped discounted returns and the matching advantages.

    Starting from ``next_value``, the return is accumulated backwards with the
    module-level ``gamma``; a done flag of 1 at step t drops everything that
    follows that step (including the bootstrap value).

    Args:
        rewards: tensor of per-step rewards (length T); its device is used for
            the targets.
        values: value estimates for the same T steps.
        next_value: value estimate used to bootstrap after the last step.
        dones: per-step flags, 1 where the step ended an episode.

    Returns:
        ``(v_target, adv)`` where ``v_target`` holds the discounted returns and
        ``adv = v_target - values``.
    """
    horizon = len( rewards )
    targets = torch.zeros( horizon , dtype = torch.float32 , device = rewards.device )

    running = next_value
    for step in range( horizon - 1 , -1 , -1 ):
        running = rewards[step] + gamma * running * ( 1 - dones[step] )
        targets[step] = running

    return targets , targets - values


def gae( rewards , values , dones ):
    """Generalized Advantage Estimation over one rollout.

    Uses the module-level ``gamma`` (discount) and ``lam`` (GAE lambda).

    Args:
        rewards: per-step rewards (length T).
        values: tensor of T + 1 value estimates; the last entry is the value
            of the state reached after the final step (bootstrap value).
        dones: per-step flags, 1 where the step ended an episode.

    Returns:
        ``(v_target, adv)`` where ``adv`` holds the advantage estimates and
        ``v_target = adv + values[:-1]``. Both live on ``values.device``.

    Raises:
        AssertionError: if ``values`` does not have exactly T + 1 entries.
    """
    T = len( rewards )

    # values = [ values , next_values ]
    assert T + 1 == len( values )

    # Allocate on the same device as the values so the final sum below does
    # not mix CPU and GPU tensors.
    adv = torch.zeros( T , dtype = torch.float32 , device = values.device )

    running_adv = 0

    for t in reversed( range( T ) ):

        not_done = 1.0 - dones[t]

        # delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t
        # The bootstrap value must be dropped when step t ended the episode.
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]

        # A_t = delta_t + gamma * lam * (1 - done_t) * A_{t+1}
        running_adv = delta + gamma * lam * not_done * running_adv

        adv[t] = running_adv

    v_target = adv + values[:-1]

    return v_target , adv


@torch.no_grad()
def test( policy , env ):

    policy = policy.eval();

    observation , info = env.reset()

    done = False;

    rewards = 0;

    while( not done ):
        
        logits = policy( torch.from_numpy( observation ).to( device ) )

        action = logits.argmax();

        observation, reward, done, truncated, info = env.step(action.item())
        rewards += reward;

    print( f'testing rewards : { rewards }')




In [None]:
# Initialise the environment
# env = gym.make("CartPole-v1", render_mode="human")
env = gym.make("CartPole-v1")
# env = gym.make("LunarLander-v3",render_mode="human")

policy = Policy( env.observation_space.shape[0] , env.action_space.n ).to( device )

critic = Critic( env.observation_space.shape[0] ).to( device )

# A single optimizer updates the policy and the critic from the combined loss.
optimizer = torch.optim.Adam( list( policy.parameters() ) + list( critic.parameters() ) , lr = 1e-5 )

# Seed the environment once; the per-episode resets below do not re-seed.
observation , info = env.reset( seed = 42 )

# One epoch = one full episode followed by one gradient step.
for epoch in range( 40000 ):

    # Per-step quantities collected for the loss computation.
    values = []
    rewards = []
    dones = []
    log_probs = []
    entropies = []

    observation , info = env.reset()
    # test() leaves the policy in eval mode, so switch back every episode.
    policy.train()

    done = False

    while not done:

        state = torch.from_numpy( observation ).to( device )

        dist = torch.distributions.Categorical( logits = policy( state ) )

        values.append( critic( state ) )
        entropies.append( dist.entropy() )

        action = dist.sample()

        # apply action
        observation , reward , terminated , truncated , info = env.step( action.item() )

        # The rollout stops on either signal ...
        done = terminated or truncated

        rewards.append( reward )
        # ... but only a true termination cuts the bootstrap. When the episode
        # is merely truncated (time limit), the return still bootstraps from
        # the critic's estimate of the last observed state.
        dones.append( terminated )
        log_probs.append( dist.log_prob( action ) )

    # Bootstrap value of the state reached after the final step. Only the last
    # one is ever used, so it is evaluated once here instead of at every step.
    with torch.no_grad():
        next_value = critic( torch.from_numpy( observation ).to( device ) )

    # loss
    # 1. policy gradient
    log_probs = torch.stack( log_probs )
    rewards = torch.tensor( rewards , dtype = torch.float32 , device = device )
    values = torch.cat( values )
    dones = torch.tensor( dones , dtype = torch.float32 , device = device )

    # Targets are built from detached values so the critic is regressed toward
    # fixed targets and the advantage carries no gradient into the critic.
    v_target , adv = n_step_return( rewards , values.detach() , next_value , dones )

    # Advantage-weighted log-likelihood plus an entropy bonus weighted by beta.
    policy_loss = -( log_probs * adv.detach() ).mean() - beta * torch.stack( entropies ).mean()

    v_loss = F.mse_loss( values , v_target , reduction = 'mean' )

    # The policy term is down-weighted relative to the value regression.
    loss = policy_loss * .8 + v_loss

    # Update / logging interval (currently every episode).
    if( ( epoch + 1 ) % 1 == 0 ):

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print( f'epoch :{ epoch }, loss : { loss.item() / 1 } , reward: { rewards.sum() / 1 }' )

    # Greedy evaluation episode every 100 epochs (reuses the training env).
    if( ( epoch + 1 ) % 100 == 0 ):

        test( policy , env )

    # Halve the entropy bonus every 1000 epochs.
    if( ( epoch + 1 ) % 1000 == 0 ):
        beta /= 2

env.close()


In [None]:
# Release the previous environment before opening a rendering one
# (presumably a no-op if it was already closed by the training cell).
env.close()
env = gym.make("CartPole-v1", render_mode="human")

# Watch ten greedy episodes; always close the render window, even if an
# episode raises or the cell is interrupted.
try:
    for _ in range( 10 ):
        test( policy , env )
finally:
    env.close()


In [None]:
# Persist both trained networks as whole module objects.
# NOTE(review): saving the module itself (rather than its state_dict) pickles
# the object, so loading presumably needs the Policy / Critic classes in scope.
torch.save( obj = policy , f = 'policy.pth' )
torch.save( obj = critic , f = 'critic.pth' )