In [None]:
import gymnasium as gym
import torch
from torch.nn import functional as F
import numpy as np

In [None]:

class Policy( torch.nn.Module ):
    """MLP that maps an observation to one logit per action.

    Architecture: four hidden ``Linear`` layers of width 256, each followed by
    a ReLU, then a final ``Linear`` projection to ``out_dim`` raw logits.

    Args:
        in_dim: size of the last dimension of the input tensor.
        out_dim: number of logits produced (one per discrete action).
    """

    def __init__(self, in_dim , out_dim):
        super().__init__()

        width = 256
        # Attribute names (and creation order) are part of the saved state_dict
        # keys and of the random-init sequence, so they are kept as-is.
        self.linear1 = torch.nn.Linear(in_dim, width)
        self.linear2 = torch.nn.Linear(width, width)
        self.linear3 = torch.nn.Linear(width, width)
        self.linear4 = torch.nn.Linear(width, width)
        self.linear = torch.nn.Linear(width, out_dim)

    def forward( self , x ):
        """Return unnormalised action logits for input ``x``."""
        hidden = x
        for layer in (self.linear1, self.linear2, self.linear3, self.linear4):
            hidden = F.relu(layer(hidden))
        return self.linear(hidden)


In [None]:

# Discount factor used for the return G_t = r_t + gamma * G_{t+1}.
gamma = .99
# Fall back to CPU when CUDA is unavailable; a hard-coded 'cuda' makes every
# `.to(device)` call fail on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Per-episode buffers; the training cell re-creates them at the start of each episode.
log_probs = []
rewards = []
# NOTE(review): not referenced by any visible cell — candidate for removal.
terminateds = []


In [None]:

def update( log_probs , rewards , discount = None ):
    """Compute the REINFORCE loss for one finished episode.

    Discounted returns G_t = r_t + discount * G_{t+1} are accumulated
    backwards (Horner's / Qin Jiushao's nested-multiplication scheme). They
    are then centred by subtracting their mean, which acts as a simple
    baseline, and weighted against the action log-probabilities.

    Args:
        log_probs: sequence of scalar tensors, log pi(a_t | s_t) for each
            step, still attached to the autograd graph.
        rewards: sequence of per-step rewards, same length as ``log_probs``.
        discount: discount factor; ``None`` (default) uses the module-level
            ``gamma``, looked up at call time.

    Returns:
        ``(loss, total_reward)``: ``loss`` is a scalar tensor on the same
        device/dtype as the log-probs, ready for ``.backward()``;
        ``total_reward`` is the undiscounted sum of ``rewards``.

    Raises:
        ValueError: if the episode is empty or the two sequences differ in
            length (a mismatch would otherwise broadcast silently or fail
            deep inside torch).
    """
    if discount is None:
        discount = gamma

    n = len( rewards )
    if n == 0 or len( log_probs ) != n:
        raise ValueError(
            f'expected equally long, non-empty episodes; got '
            f'{len( log_probs )} log-probs and {n} rewards' )

    returns = np.empty( n )
    future_return = 0.0
    # Horner-style backward accumulation of discounted returns.
    for t in reversed( range( n ) ):
        future_return = rewards[t] + discount * future_return
        returns[t] = future_return

    log_prob = torch.stack( log_probs )
    # Follow the log-probs' device and dtype instead of the global `device`.
    # This also avoids promoting the loss to float64, since numpy returns are float64.
    advantage = torch.as_tensor( returns , dtype = log_prob.dtype , device = log_prob.device )
    advantage = advantage - advantage.mean()

    loss = - ( log_prob * advantage ).sum()

    return loss , sum( rewards )

@torch.no_grad()
def test( policy , env ):

    policy = policy.eval();
    observation , info = env.reset()

    done = False;

    rewards = 0;

    while( not done ):

        logits = policy( torch.from_numpy( observation ).to( device ) )

        # dist = torch.distributions.Categorical( logits = logits )

        action = logits.argmax();

        observation, reward, done, truncated, info = env.step(action.item())
        rewards += reward;

    print( f'testing rewards : { rewards }')



In [None]:
# Initialise the environment and train with REINFORCE.
# Rendering every training step in "human" mode is very slow; set
# RENDER_MODE = None for fast headless training.
RENDER_MODE = "human"
SEED = 42
N_EPISODES = 40000
BATCH_SIZE = 1          # episodes whose losses are averaged into one optimizer step
LEARNING_RATE = 1e-5
TEST_EVERY = 100        # run a greedy evaluation episode every this many episodes

# Seed torch so the network init and the sampled actions are reproducible.
torch.manual_seed( SEED )

env = gym.make("CartPole-v1", render_mode=RENDER_MODE)

policy = Policy( env.observation_space.shape[0] , env.action_space.n ).to( device )

optimizer = torch.optim.Adam( policy.parameters() , lr = LEARNING_RATE )

# Seeds the env's RNG once; later unseeded reset() calls continue from it.
observation, info = env.reset(seed=SEED)

# The accumulators live outside the episode loop and are cleared only after an
# optimizer step, so BATCH_SIZE > 1 really combines several episodes.
batch_loss = 0
batch_reward = 0

try:
    for epoch in range( N_EPISODES ):

        log_probs = []
        rewards = []

        observation , info = env.reset()

        policy = policy.train()

        done = False
        while not done:

            logits = policy( torch.from_numpy( observation ).to( device ) )

            # Build the action distribution from the logits and sample from it.
            dist = torch.distributions.Categorical( logits = logits )
            action = dist.sample()
            log_probs.append( dist.log_prob( action ) )

            # Apply the action.
            observation, reward, terminated, truncated, info = env.step(action.item())
            done = terminated or truncated

            rewards.append( reward )

        loss , reward = update( log_probs , rewards )

        batch_loss += loss
        batch_reward += reward

        if ( epoch + 1 ) % BATCH_SIZE == 0:

            optimizer.zero_grad()
            # Average over the batch so the step size does not scale with BATCH_SIZE.
            ( batch_loss / BATCH_SIZE ).backward()
            optimizer.step()

            print( f'epoch :{ epoch }, loss : { batch_loss.item() / BATCH_SIZE } , reward: { batch_reward / BATCH_SIZE }' )

            batch_loss = 0
            batch_reward = 0

        if ( epoch + 1 ) % TEST_EVERY == 0:
            test( policy , env )
finally:
    # Close the env (and any render window) even if training is interrupted.
    env.close()

In [6]:

# Pickles the entire module object, not just its weights. Reloading it needs the
# `Policy` class to be importable. NOTE(review): recent PyTorch versions default
# torch.load to weights_only=True, which may reject a full module; saving
# policy.state_dict() is the more portable alternative.
torch.save( obj = policy , f = 'policy.pth' )
