In [1]:
from os import path

import gymnasium as gym
import kan
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch import nn
from torch.distributions import Normal

In [2]:
class PolicyNetwork(nn.Module):
    """Gaussian policy whose trunk and heads are `kan.KANLayer` modules.

    The trunk is ``klayer1 -> LayerNorm -> klayer2``. Two separate KAN heads
    then read the trunk output: one yields the mean of the action
    distribution, the other (named ``log_std``) yields the value from which
    the standard deviation is derived.

    Args:
        state_space_dim: input width of the first KAN layer.
        action_space_dim: output width of both heads.
        device: device each sub-layer is moved to after construction.
        hiden_size1: output width of the first KAN layer and the width
            normalised by the LayerNorm.
        hidden_size2: output width of the second KAN layer.
    """

    def __init__(self,state_space_dim,action_space_dim,device='cpu',hiden_size1=100,hidden_size2=50):
        super().__init__()
        # Trunk layers.
        self.klayer1 = kan.KANLayer(state_space_dim, hiden_size1).to(device)
        self.klayer2 = kan.KANLayer(hiden_size1, hidden_size2).to(device)
        # Output heads, both fed by the second trunk layer.
        self.mean = kan.KANLayer(hidden_size2, action_space_dim).to(device)
        self.log_std = kan.KANLayer(hidden_size2, action_space_dim).to(device)
        # Normalisation applied between the two trunk layers.
        self.norm = nn.LayerNorm(hiden_size1).to(device)

    def forward(self,x):
        """Map a batch of states to ``(mean, std)`` of the action distribution.

        Every KAN layer call returns several values; only the first one (the
        layer output) is used here.

        Returns:
            mean: ``0.4 * tanh(.)`` of the mean head, hence within [-0.4, 0.4].
            std: ``0.8 * sigmoid(exp(.))`` of the ``log_std`` head. Because
                ``exp`` is non-negative the sigmoid is at least 0.5, so the
                result lies within [0.4, 0.8].
        """
        normalised = self.norm(self.klayer1(x)[0])
        features = self.klayer2(normalised)[0]
        raw_mean = self.mean(features)[0]
        raw_log_std = self.log_std(features)[0]
        mean = torch.tanh(raw_mean) * 0.4
        std = torch.sigmoid(raw_log_std.exp()) * 0.8
        return mean, std


In [3]:
# Define the Agent class
class REINFORCEAgent:
    """Monte-Carlo policy-gradient (REINFORCE) agent driving a `PolicyNetwork`.

    During an episode the caller repeatedly calls `sample_action` (which
    stores the log-probability of each sampled action) and appends the
    received reward to `episode_rewards`. After the episode, `update_policy`
    performs one optimiser step and clears both buffers.

    Args:
        state_dim: size of the observation vector.
        action_dim: number of action components.
        device: torch device used for the policy and for the input states.
        lr: Adam learning rate.
        gamma: discount factor used when accumulating returns.
    """

    def __init__(self, state_dim, action_dim, device,lr=1e-3,gamma=0.99):
        self.policy = PolicyNetwork(state_dim, action_dim,device)
        # maximize=True: the optimiser ascends the objective built in
        # update_policy, so no sign flip is needed there.
        self.optimizer = optim.Adam(self.policy.parameters(),lr=lr,maximize=True)
        self.episode_rewards = []  # rewards of the running episode, appended by the caller
        self.log_probs = []        # log-probabilities stored by sample_action
        self.gamma = gamma
        self.device = device

    def sample_action(self, state):
        """Sample an action for a single state and remember its log-probability.

        Args:
            state: one observation (array-like of numbers), without a batch axis.

        Returns:
            The sampled action as a NumPy array with the batch axis removed.
        """
        # Build the (1, state_dim) batch directly on the target device.
        state = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        mean, std = self.policy(state)
        # The small offset keeps the scale strictly positive.
        dist = Normal(mean+1e-6, std+1e-6)
        action = dist.sample()
        self.log_probs.append(dist.log_prob(action))
        # squeeze(0) drops only the batch axis; a bare squeeze() would also
        # collapse the action axis when there is a single action component.
        return action.squeeze(0).detach().cpu().numpy()

    def update_policy(self):
        """Take one gradient-ascent step on sum_t log_prob_t * G_t, then reset.

        G_t is the discounted return from step t onwards, computed from
        `episode_rewards` with discount `gamma`. Returns are used as they
        are (no baseline or normalisation). If no log-probabilities or no
        rewards were collected, the buffers are cleared and no optimiser
        step is taken.
        """
        if not self.log_probs or not self.episode_rewards:
            self.episode_rewards = []
            self.log_probs = []
            return

        # Accumulate discounted returns back-to-front, then flip once
        # (linear time, instead of inserting at the head of a list).
        discounted = []
        running = 0.0
        for reward in reversed(self.episode_rewards):
            running = reward + self.gamma * running
            discounted.append(running)
        discounted.reverse()

        # Pair each stored log-prob with the return of the same step; any
        # surplus entries on either side are ignored.
        steps = min(len(self.log_probs), len(discounted))
        log_probs = torch.stack(self.log_probs[:steps])
        returns = torch.tensor(discounted[:steps], dtype=log_probs.dtype, device=log_probs.device)
        # One return per step, broadcast over the remaining log-prob axes.
        returns = returns.reshape(-1, *([1] * (log_probs.dim() - 1)))
        objective = (log_probs * returns).sum()

        self.optimizer.zero_grad()
        objective.backward()
        self.optimizer.step()

        self.episode_rewards = []
        self.log_probs = []

#### Converting `solver_iter` to `solver_niter`
https://github.com/Farama-Foundation/Gymnasium/issues/749

#### Removing `self.env.close()` from the `close` function in `wrappers/monitoring/video_recorder.py`
https://github.com/Farama-Foundation/Gymnasium/issues/455

In [None]:
# ---- Run configuration ----
SEED = 0  # seeds torch and the environment so a fresh-kernel run is repeatable
device = 'cuda' if torch.cuda.is_available() else 'cpu'
video_path = 'KAN_videos'      # folder handed to RecordVideo
vide_episode_per_save = 20     # a video is recorded every this many episodes
episode_per_print = 50         # progress is printed every this many episodes
num_episodes = 1000


def episode_trigger(t):
    """Return True when episode index ``t`` (0-based) should be recorded,
    i.e. when ``t + 1`` is a multiple of ``vide_episode_per_save``."""
    return (t+1) % vide_episode_per_save == 0


# ---- Environment ----
env = gym.make('Humanoid-v4',render_mode='rgb_array')
env = gym.wrappers.RecordVideo(env,video_folder=video_path,episode_trigger=episode_trigger,disable_logger=True)

# ---- Agent ----
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
torch.manual_seed(SEED)  # seed torch's RNG before the policy is built and actions are sampled
agent = REINFORCEAgent(state_dim, action_dim,device=device,lr=1e-3)
# This set-up reset seeds the environment's RNG; the training loop calls
# reset() again at the start of every episode.
observation,_= env.reset(seed=SEED)
Rewards = []  # total (undiscounted) reward of each finished episode


## Training

In [None]:
# One REINFORCE update per episode: roll the policy out until the environment
# reports termination or truncation, then update from the collected
# rewards and log-probabilities.
for t in range(num_episodes):
    observation, info = env.reset()
    total_reward = 0   # undiscounted sum of this episode's rewards
    act_count = 0      # number of environment steps taken this episode
    done = False
    while not done:
        # sample_action also stores the log-probability on the agent
        action = agent.sample_action(observation)
        step_result = env.step(action)
        observation, reward, terminated, truncated, info = step_result
        agent.episode_rewards.append(reward)
        total_reward = total_reward + reward
        act_count = act_count + 1
        done = terminated or truncated
    Rewards.append(total_reward)
    agent.update_policy()

    # Periodic progress report.
    if not (t+1) % episode_per_print:
        print(f"Episode: {t+1}, Reward:{total_reward}, Action in Episode:{act_count}")

env.close()

## Plot results

In [None]:
# Learning curve: total reward of every training episode, in order.
fig, ax = plt.subplots()
ax.plot(range(1, len(Rewards) + 1), Rewards)
ax.set_xlabel('Episode')
ax.set_ylabel('Reward')
ax.set_title('REINFORCE with KAN policy on Humanoid-v4')
plt.show()

# Best and worst episode reward; the tensor is built once and reused.
rewards_tensor = torch.tensor(Rewards)
rewards_tensor.max(), rewards_tensor.min()