In [9]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque


In [10]:
class QuantileNetwork(nn.Module):
    """Fully connected network producing one set of quantile estimates per action.

    Two ReLU hidden layers of 128 units feed a linear head with
    ``action_size * num_quantiles`` outputs, which ``forward`` reshapes so the
    quantiles of each action sit on their own axis.

    Args:
        state_size: Number of input features.
        action_size: Number of discrete actions.
        num_quantiles: Number of quantile estimates emitted per action.
    """

    def __init__(self, state_size, action_size, num_quantiles):
        super().__init__()
        self.action_size = action_size
        self.num_quantiles = num_quantiles

        hidden_units = 128
        # Layers are created in fc1 -> fc2 -> fc3 order so parameter names and
        # seeded initialization stay the same.
        self.fc1 = nn.Linear(state_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, action_size * num_quantiles)

    def forward(self, state):
        """Map ``state`` to a ``(-1, action_size, num_quantiles)`` tensor of quantile values."""
        hidden = self.fc1(state).relu()
        hidden = self.fc2(hidden).relu()
        flat_head = self.fc3(hidden)
        # The leading -1 lets the batch dimension be inferred (an unbatched
        # input therefore comes back with a batch dimension of 1).
        return flat_head.view(-1, self.action_size, self.num_quantiles)


In [11]:
class Agent:
    """Quantile-regression Q-learning agent with an online and a target network.

    The agent stores transitions in a bounded replay buffer, picks actions
    epsilon-greedily from the mean of the online network's quantiles, and
    trains the online network with a quantile Huber loss against targets
    computed from the target network.

    Args:
        state_size: Number of features in an observation.
        action_size: Number of discrete actions.
        num_quantiles: Number of quantiles estimated per action.
        gamma: Discount factor applied to the next-state quantiles.
        lr: Learning rate for the Adam optimizer.
        batch_size: Number of transitions sampled per learning step.
        buffer_size: Maximum number of transitions kept in the replay buffer.
        device: ``torch.device`` (or device string) the networks and batches live on.
    """

    def __init__(self, state_size, action_size, num_quantiles, gamma, lr, batch_size, buffer_size, device):
        self.state_size = state_size
        self.action_size = action_size
        self.num_quantiles = num_quantiles
        self.gamma = gamma
        self.batch_size = batch_size
        # Bounded FIFO replay buffer: the oldest transitions fall off once full.
        self.memory = deque(maxlen=buffer_size)
        self.device = device

        self.q_network = QuantileNetwork(state_size, action_size, num_quantiles).to(self.device)
        self.target_network = QuantileNetwork(state_size, action_size, num_quantiles).to(self.device)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)

        # Start with the target network as an exact copy of the online network.
        self.update_target_network()

        # Quantile midpoints (i + 0.5) / num_quantiles for i = 0 .. num_quantiles - 1.
        self.taus = torch.arange(0.5 / self.num_quantiles, 1.0, 1.0 / self.num_quantiles).to(self.device)

    def _to_tensor(self, values, dtype):
        """Stack ``values`` into one numpy array of ``dtype`` and move it to ``self.device``."""
        # Going through a single ndarray avoids the slow element-by-element
        # path PyTorch takes when given a sequence of separate numpy arrays.
        return torch.as_tensor(np.asarray(values, dtype=dtype), device=self.device)

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def act(self, state, epsilon=0.1):
        """Choose an action epsilon-greedily.

        Args:
            state: A single observation (array-like of floats).
            epsilon: Probability of picking a uniformly random action.

        Returns:
            The chosen action index as a Python ``int``.
        """
        if random.random() > epsilon:
            state = self._to_tensor(state, np.float32).unsqueeze(0)
            with torch.no_grad():
                quantiles = self.q_network(state)
                # The Q-value of an action is the mean of its quantiles.
                q_values = quantiles.mean(dim=2)
                return q_values.max(1)[1].item()
        # randrange draws from the RNG the same way random.choice over an
        # index array would, but returns a plain int and allocates nothing.
        return random.randrange(self.action_size)

    def step(self, state, action, reward, next_state, done):
        """Store one transition and run a learning step once the buffer holds more than a batch."""
        self.memory.append((state, action, reward, next_state, done))
        if len(self.memory) > self.batch_size:
            self.learn()

    def learn(self):
        """Sample a batch from the replay buffer and take one optimizer step on the quantile Huber loss."""
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = self._to_tensor(states, np.float32)
        actions = self._to_tensor(actions, np.int64)
        rewards = self._to_tensor(rewards, np.float32)
        next_states = self._to_tensor(next_states, np.float32)
        dones = self._to_tensor(dones, np.float32)

        # Target quantiles: r + gamma * quantiles of the greedy next action,
        # with the bootstrap term zeroed on terminal transitions.
        with torch.no_grad():
            next_quantiles = self.target_network(next_states)
            next_q_values = next_quantiles.mean(dim=2)
            next_actions = next_q_values.max(1)[1].unsqueeze(1).unsqueeze(1).expand(self.batch_size, 1, self.num_quantiles)
            next_best_quantiles = next_quantiles.gather(1, next_actions).squeeze(1)
            target_quantiles = rewards.unsqueeze(1) + (self.gamma * next_best_quantiles * (1 - dones.unsqueeze(1)))

        # Quantiles the online network predicts for the actions actually taken.
        quantiles = self.q_network(states)
        action_index = actions.unsqueeze(1).unsqueeze(1).expand(self.batch_size, 1, self.num_quantiles)
        current_quantiles = quantiles.gather(1, action_index).squeeze(1)

        # Ensure non-decreasing quantiles
        current_quantiles = torch.sort(current_quantiles, dim=1)[0]

        # Pairwise errors: td_errors[b, j, i] = target_j - current_i, so the
        # last axis indexes the current quantiles and lines up with self.taus.
        td_errors = target_quantiles.unsqueeze(2) - current_quantiles.unsqueeze(1)
        abs_errors = td_errors.abs()
        # Huber loss with threshold 1: quadratic near zero, linear beyond.
        huber_loss = torch.where(abs_errors <= 1.0, 0.5 * td_errors.pow(2), abs_errors - 0.5)
        quantile_weight = torch.abs(self.taus.view(1, 1, -1) - (td_errors < 0).float())
        quantile_loss = quantile_weight * huber_loss
        loss = quantile_loss.sum(dim=1).mean(dim=1).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save(self, filename):
        """Write the online network, target network and optimizer state dicts to ``filename``."""
        torch.save({
            'q_network_state_dict': self.q_network.state_dict(),
            'target_network_state_dict': self.target_network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, filename)

    def load(self, filename):
        """Restore the networks and optimizer from a checkpoint written by ``save``."""
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(filename, map_location=self.device)
        self.q_network.load_state_dict(checkpoint['q_network_state_dict'])
        self.target_network.load_state_dict(checkpoint['target_network_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.q_network.to(self.device)
        self.target_network.to(self.device)


In [12]:
def test_agent(num_episodes,load_path=None, device='cpu'):
    env = gym.make('LunarLander-v2',render_mode= 'human')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    num_quantiles = 51
    gamma = 0.99
    lr = 0.001
    batch_size = 64
    buffer_size = 100000
    device = torch.device(device)

    agent = Agent(state_size, action_size, num_quantiles, gamma, lr, batch_size, buffer_size, device)
    
    if load_path:
        agent.load(load_path)
        print(f"Loaded model from {load_path}")

    for e in range(num_episodes):
        state = env.reset()
        total_reward = 0
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            state = next_state
            total_reward += reward
            
        
        print(f"Episode {e+1}/{num_episodes}, Total Reward: {total_reward}")
    env.close()
        



In [13]:
# Play ten rendered episodes with the saved checkpoint, on the GPU when one is available.
test_agent(
    num_episodes=10,
    load_path='nqfn_lunar_lander.pth',
    device='cuda' if torch.cuda.is_available() else 'cpu',
)


Loaded model from nqfn_lunar_lander.pth
Episode 1/10, Total Reward: 278.18287873173983
Episode 2/10, Total Reward: 231.53907621572648
Episode 3/10, Total Reward: 275.96262742313814
Episode 4/10, Total Reward: 256.4528824094258
Episode 5/10, Total Reward: 285.74298931104886
Episode 6/10, Total Reward: 260.1863867506504
Episode 7/10, Total Reward: 276.49075048501606
Episode 8/10, Total Reward: 257.05529109964937
Episode 9/10, Total Reward: 266.20049155292406
Episode 10/10, Total Reward: 247.0169547618857
