# Imports

In [26]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
import os

In [27]:


# Define the FQF Network
class FQFDQN(nn.Module):
    """Two-hidden-layer MLP that outputs a set of quantile values per action.

    Args:
        state_dim: Size of the last dimension of the input passed to ``forward``.
        action_dim: Number of actions; one row of quantiles is produced for each.
        num_quantiles: Number of quantile values produced per action.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, state_dim, action_dim, num_quantiles=51, hidden_dim=256):
        super().__init__()
        self.num_quantiles = num_quantiles
        self.action_dim = action_dim

        # Layer attribute names double as state_dict keys, so they must stay
        # as-is for saved checkpoints to keep loading.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # One flat output vector; forward() folds it into (action, quantile).
        self.quantile_head = nn.Linear(hidden_dim, action_dim * num_quantiles)

    def forward(self, x):
        """Return quantile values shaped ``(-1, action_dim, num_quantiles)``.

        The leading dimension is inferred by ``view``; an unbatched input of
        shape ``(state_dim,)`` therefore yields a leading dimension of 1.
        """
        hidden = torch.relu(self.fc1(x))
        hidden = torch.relu(self.fc2(hidden))
        flat_quantiles = self.quantile_head(hidden)
        return flat_quantiles.view(-1, self.action_dim, self.num_quantiles)


In [28]:
# --- Hyperparameters ---------------------------------------------------------
num_episodes = 10      # number of episodes
learning_rate = 1e-3   # passed to the optimizer as `lr`
num_quantiles = 51     # quantile values the network outputs per action
hidden_dim = 256       # width of the network's hidden layers

# --- Environment -------------------------------------------------------------
# 'human' render mode is requested; the network's input/output sizes are read
# from the environment's observation and action spaces.
env = gym.make('LunarLander-v2', render_mode='human')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# --- Device and networks -----------------------------------------------------
# Use CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Two identically-configured networks, constructed in order: `network` first,
# then `target_network`.
network, target_network = (
    FQFDQN(state_dim, action_dim, num_quantiles, hidden_dim).to(device)
    for _ in range(2)
)

  deprecation(
  deprecation(


In [29]:
# Adam over the parameters of `network` only (`target_network` is not included).
optimizer = optim.Adam(params=network.parameters(), lr=learning_rate)

In [30]:
def load_checkpoint(filename='checkpoint.pth', map_location=None):
    """Read a checkpoint file with ``torch.load`` and return its contents.

    Args:
        filename: Path of the checkpoint file to read.
        map_location: Forwarded to ``torch.load`` when truthy. When falsy
            (including the default ``None``), ``torch.load`` is called
            without a ``map_location`` argument.

    Returns:
        The object ``torch.load`` deserializes from ``filename``.

    Note:
        ``torch.load`` deserializes the file's contents, so only load
        checkpoints that come from a trusted source.
    """
    load_kwargs = {'map_location': map_location} if map_location else {}
    return torch.load(filename, **load_kwargs)

In [31]:
# Checkpoint file to restore from, if it exists.
checkpoint_path = 'fqf.pth'

In [32]:
# Restore both networks, the optimizer, epsilon, and the episode counter from
# the checkpoint when the file exists.
try:
    # On machines without CUDA, remap saved tensors onto the CPU; otherwise
    # leave torch.load's default placement untouched.
    map_location = torch.device('cpu') if not torch.cuda.is_available() else None
    checkpoint = load_checkpoint(checkpoint_path, map_location=map_location)
    network.load_state_dict(checkpoint['main_net_state_dict'])
    target_network.load_state_dict(checkpoint['target_net_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epsilon = checkpoint['epsilon']
    # Continue numbering from the episode after the one that was saved.
    start_episode = checkpoint['episode'] + 1
    print(f"Loaded checkpoint from episode {start_episode}")
except FileNotFoundError:
    # Later cells read `epsilon`; without these defaults a fresh kernel with no
    # checkpoint file on disk would fail there with a NameError.
    epsilon = 1.0  # no trained weights available: act fully at random
    start_episode = 0
    print("No checkpoint found, starting from scratch.")

Loaded checkpoint from episode 344


In [33]:
def select_action(state, network, epsilon, action_dim, device):
    """Choose an action epsilon-greedily.

    Args:
        state: A single observation; converted with ``torch.FloatTensor`` and
            given a leading batch dimension of 1.
        network: Callable that maps the batched state to a tensor whose dim 2
            is averaged and whose dim 1 indexes actions.
        epsilon: Probability threshold for taking a uniformly random action.
        action_dim: Number of actions to sample from when exploring.
        device: Device the state tensor is moved to before the forward pass.

    Returns:
        The chosen action as a Python int.
    """
    # Explore: one draw decides the branch, a second picks the random action.
    if random.random() < epsilon:
        return random.randrange(action_dim)

    # Exploit: average over dim 2 per action and take the best one.
    batched_state = torch.FloatTensor(state).unsqueeze(0).to(device)
    with torch.no_grad():
        per_action_mean = network(batched_state).mean(dim=2)
    _, best_action = per_action_mean.max(1)
    return best_action.item()

In [34]:
# Run `num_episodes` episodes with the loaded policy and print each episode's
# total reward. The try/finally guarantees the environment is closed even when
# an episode is interrupted or raises.
try:
    for episode in range(num_episodes):
        # Newer gym versions return (observation, info) from reset(); older
        # ones return the observation alone. LunarLander observations are
        # arrays, so a tuple here can only be the (observation, info) pair.
        reset_result = env.reset()
        state = reset_result[0] if isinstance(reset_result, tuple) else reset_result
        episode_reward = 0

        while True:
            action = select_action(state, network, epsilon, action_dim, device)

            # Accept both step() conventions: the 5-tuple with separate
            # terminated/truncated flags and the legacy 4-tuple with `done`.
            step_result = env.step(action)
            if len(step_result) == 5:
                next_state, reward, terminated, truncated, _ = step_result
                done = terminated or truncated
            else:
                next_state, reward, done, _ = step_result

            state = next_state
            episode_reward += reward

            if done:
                break

        print(f"Episode: {episode}, Reward: {episode_reward}")
finally:
    env.close()

  if not isinstance(terminated, (bool, np.bool8)):


Episode: 0, Reward: 9.028107613112525
Episode: 1, Reward: 26.916383843211563
Episode: 2, Reward: -247.92586626066748
Episode: 3, Reward: 277.68905029201443
Episode: 4, Reward: 159.66918545341514
Episode: 5, Reward: -27.25110431160665
Episode: 6, Reward: 156.16109161482973
Episode: 7, Reward: 53.13752000032764
Episode: 8, Reward: 238.28824062507
Episode: 9, Reward: -2.59856231795483
