In [None]:
import gymnasium as gym
import torch
from torch.nn import functional as F
import numpy as np
import random
from collections import deque

In [None]:

def softplus( x ):
    """Element-wise softplus, ``log(1 + exp(x))``, computed in a numerically stable way.

    The naive ``torch.log(torch.exp(x) + 1)`` overflows to ``inf`` once ``exp(x)``
    exceeds the dtype's range (x > ~88 for float32) and rounds to exactly 0 for very
    negative x. ``F.softplus`` switches to the linear regime for large inputs and uses
    ``log1p`` internally, so it stays finite and accurate. It remains differentiable.

    Args:
        x: input tensor of any shape.

    Returns:
        Tensor of the same shape with non-negative values.
    """
    return F.softplus( x )

class Policy( torch.nn.Module ):
    """Q-network: an MLP that maps an observation vector to one value per action.

    Layout: ``in_dim -> 256`` plus four ``256 -> 256`` hidden layers, each followed
    by ReLU. A ``256 -> out_dim`` head is then passed through the module-level
    ``softplus``, so outputs are non-negative.
    """

    def __init__( self , in_dim , out_dim ):
        super().__init__()

        width = 256

        # Attribute names (and their registration order) define the state_dict keys
        # and the order of parameter initialisation; keep them as they are.
        self.linear1 = torch.nn.Linear( in_dim , width )
        self.linear2 = torch.nn.Linear( width , width )
        self.linear3 = torch.nn.Linear( width , width )
        self.linear4 = torch.nn.Linear( width , width )
        self.linear5 = torch.nn.Linear( width , width )
        self.linear = torch.nn.Linear( width , out_dim )

    def forward( self , x ):
        """Return per-action values for observation(s) ``x`` (last dim == ``in_dim``)."""
        hidden = x
        for layer in ( self.linear1 , self.linear2 , self.linear3 , self.linear4 , self.linear5 ):
            hidden = F.relu( layer( hidden ) )
        return softplus( self.linear( hidden ) )


class ReplayBuffer():
    """Fixed-capacity FIFO store of SARSA transitions with uniform minibatch sampling.

    Each transition is ``(state, action, reward, next_state, next_action, done)``.
    When ``capacity`` transitions are stored, recording a new one evicts the oldest.
    Sampled batches are returned as tensors on ``device``.
    """

    def __init__( self , capacity = 2048 , device = 'cpu' ):
        self.queue = deque( maxlen = capacity )
        self.device = device

    def record( self , state , action , reward , next_state , next_action , done ):
        """Append one transition; the oldest one is dropped if the buffer is full."""
        self.queue.append( ( state , action , reward , next_state , next_action , done ) )

    def sample( self , batch_size ):
        """Draw ``batch_size`` distinct transitions uniformly at random, batched as tensors.

        Returns:
            ``(states, actions, rewards, next_states, next_actions, dones)`` on
            ``self.device``:
            - states/next_states: float32, stacked along a new leading batch dimension;
            - actions/next_actions: int64, so they can be used directly as ``gather``
              indices;
            - rewards and dones: float32.

        Raises:
            ValueError: if ``batch_size`` exceeds ``len(self)`` (raised by ``random.sample``).
        """
        data = random.sample( self.queue , batch_size )
        states , actions , rewards , next_states , next_actions , dones = zip( *data )

        # np.stack first: building a tensor from a list of ndarrays is PyTorch's slow
        # path (and emits a UserWarning). Pin dtypes so float64 observations/rewards
        # cannot leak into the float32 network or loss.
        states = torch.tensor( np.stack( states ) , dtype = torch.float32 , device = self.device )
        next_states = torch.tensor( np.stack( next_states ) , dtype = torch.float32 , device = self.device )

        actions = torch.tensor( actions , dtype = torch.long , device = self.device )
        next_actions = torch.tensor( next_actions , dtype = torch.long , device = self.device )

        rewards = torch.tensor( rewards , dtype = torch.float32 , device = self.device )
        dones = torch.tensor( dones , dtype = torch.float32 , device = self.device )

        return states , actions , rewards , next_states , next_actions , dones

    def __len__( self ):
        """Number of transitions currently stored."""
        return len( self.queue )

In [None]:

# Hyper-parameters and run configuration.
gamma = .99          # discount factor for the SARSA target
eps = .3             # initial epsilon for epsilon-greedy exploration
# Fall back to CPU so the notebook still runs on machines without a CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
capacity = 2048      # replay buffer size (number of transitions)
batch_size = 512     # transitions per gradient step

# Seed the global RNGs (Python `random`, NumPy, torch). NOTE: Gymnasium envs and
# their action spaces keep their own generators; for fully repeatable runs also pass
# a seed to env.reset(seed=...) / env.action_space.seed(...).
seed = 0
random.seed( seed )
np.random.seed( seed )
torch.manual_seed( seed )


In [None]:

def sample( env , q_value , eps = 0.0 ) :
    """Epsilon-greedy action selection (assumes a discrete action space).

    Args:
        env: environment whose ``action_space.sample()`` supplies random actions.
        q_value: tensor of action values; ``argmax`` is taken over all its elements.
        eps: probability of taking a random action instead of the greedy one.

    Returns:
        The chosen action as a plain Python ``int`` in both branches.
    """
    if random.random() < eps:
        # action_space.sample() may return a NumPy integer; normalise so both
        # branches hand the same type to env.step and the replay buffer.
        return int( env.action_space.sample() )

    return q_value.argmax().item()

def test( policy , env ):

    policy = policy.eval();
    
    observation , info = env.reset()

    done = False;

    rewards = 0;

    while( not done ):

        logits = policy( torch.from_numpy( observation ).to( device ) )

        action = logits.argmax();

        observation, reward, done, truncated, info = env.step(action.item())
        
        rewards += reward;

    print( f'testing rewards : { rewards }')



In [5]:
# Initialise the training environment. It deliberately does NOT render:
# render_mode="human" draws a window on every step, which makes tens of thousands
# of episodes impractically slow. The final cell renders the trained policy.
env = gym.make("CartPole-v1")

policy = Policy( env.observation_space.shape[0] , env.action_space.n ).to( device )

# NOTE(review): weight_decay=.08 is a strong coupled L2 penalty for Adam (added to
# the gradient, pulling all weights toward 0) — confirm this value is intended.
optimizer = torch.optim.Adam( policy.parameters() , lr = 1e-5 , weight_decay = .08 )

buffer = ReplayBuffer( capacity , device )

try:
    for epoch in range( 30000 ):

        observation , info = env.reset()
        policy = policy.train()

        # SARSA rollout: choose the first action here; afterwards the next_action
        # stored with each transition is the action actually taken on the next step.
        with torch.no_grad():
            action = sample( env , policy( torch.from_numpy( observation ).to( device ) ) , eps )

        done = False
        while not done:
            with torch.no_grad():
                next_observation , reward , terminated , truncated , info = env.step( action )
                next_q_value = policy( torch.from_numpy( next_observation ).to( device ) )
                next_action = sample( env , next_q_value , eps )

            # Store `terminated` as the bootstrap mask: a time-limit truncation is not
            # a true terminal state, so its target must still bootstrap from Q(s', a').
            buffer.record( observation , action , reward , next_observation , next_action , terminated )

            done = terminated or truncated
            observation , action = next_observation , next_action

        # Train only once a full minibatch can be drawn.
        if len( buffer ) < batch_size:
            continue

        states , actions , rewards , next_states , next_actions , dones = buffer.sample( batch_size )

        q_value = policy( states )
        with torch.no_grad():
            next_q_value = policy( next_states )

        # SARSA target: r + gamma * Q(s', a') * (1 - done), using the *batch* rewards
        # (not the scalar `reward` left over from the last env step).
        pred_q_value = q_value.gather( -1 , actions.unsqueeze( -1 ) ).squeeze( -1 )
        next_sa_value = next_q_value.gather( -1 , next_actions.unsqueeze( -1 ) ).squeeze( -1 )
        target_q_value = rewards + gamma * next_sa_value * ( 1.0 - dones )

        loss = F.mse_loss( pred_q_value , target_q_value , reduction = 'sum' )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_( policy.parameters() , max_norm = 2.0 )
        optimizer.step()

        if ( epoch + 1 ) % 100 == 0:
            # Log periodically instead of every epoch to keep the cell output readable.
            print( f'epoch: {epoch}, loss: {loss.item()}, eps: {eps}' )
            # Epsilon decay, clamped so eps never goes negative.
            eps = max( 0.0 , eps - .01 )
            test( policy , env )

        if ( epoch + 1 ) % 1000 == 0:
            eps = max( 0.0 , eps - .1 )
finally:
    # Release the environment even if training is interrupted.
    env.close()


KeyboardInterrupt: 

In [None]:

# Replace the environment with one that renders on screen, then watch the trained
# policy play 10 greedy episodes.
env.close()
env = gym.make("CartPole-v1", render_mode="human")

try:
    for _ in range( 10 ):
        test( policy , env )
finally:
    # Close the render window even if the run is interrupted (e.g. KeyboardInterrupt).
    env.close()
