In [1]:
!pip install gym==0.25.2
!pip install swig
!pip install gym[box2d]



In [2]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
import os

In [3]:
PATH = '/content/drive/MyDrive/Pytorch/rl/FQF'

  and should_run_async(code)


In [4]:
os.chdir(PATH)

In [5]:
!pwd

/content/drive/MyDrive/Pytorch/rl/FQF


In [6]:


# Define the FQF Network
class FQFDQN(nn.Module):
    def __init__(self, state_dim, action_dim, num_quantiles=51, hidden_dim=256):
        super(FQFDQN, self).__init__()
        self.num_quantiles = num_quantiles
        self.action_dim = action_dim

        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.quantile_head = nn.Linear(hidden_dim, action_dim * num_quantiles)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        quantiles = self.quantile_head(x)
        quantiles = quantiles.view(-1, self.action_dim, self.num_quantiles)
        return quantiles


In [7]:
def select_action(state, network, epsilon, action_dim, device):
    if random.random() < epsilon:
        return random.randrange(action_dim)
    else:
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        with torch.no_grad():
            quantiles = network(state)
            mean_quantiles = quantiles.mean(dim=2)
        return mean_quantiles.max(1)[1].item()

def quantile_huber_loss(pred_quantiles, target_quantiles, taus, kappa=1.0):
    diff = target_quantiles.unsqueeze(1) - pred_quantiles.unsqueeze(2)
    huber_loss = torch.where(
        torch.abs(diff) < kappa,
        0.5 * diff ** 2,
        kappa * (torch.abs(diff) - 0.5 * kappa)
    )
    taus = taus.unsqueeze(-1).expand_as(huber_loss)
    loss = (taus - (diff < 0).float()).abs() * huber_loss
    return loss.mean()

def save_checkpoint(state, filename='checkpoint.pth'):
    torch.save(state, filename)

def load_checkpoint(filename='checkpoint.pth', map_location=None):
    if map_location:
        return torch.load(filename, map_location=map_location)
    return torch.load(filename)

In [8]:
# Hyperparameters
num_episodes = 1000
batch_size = 64
gamma = 0.99
epsilon_start = 1.0
epsilon_end = 0.01
epsilon_decay = 0.995
learning_rate = 1e-3
target_update_steps = 1000
num_quantiles = 51
hidden_dim = 256

In [9]:
env = gym.make('LunarLander-v2')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  deprecation(
  deprecation(


In [10]:
network = FQFDQN(state_dim, action_dim, num_quantiles, hidden_dim).to(device)
target_network = FQFDQN(state_dim, action_dim, num_quantiles, hidden_dim).to(device)

In [11]:
optimizer = optim.Adam(network.parameters(), lr=learning_rate)

In [12]:
checkpoint_path = 'fqf.pth'

  and should_run_async(code)


In [13]:
try:
    map_location = torch.device('cpu') if not torch.cuda.is_available() else None
    checkpoint = load_checkpoint(checkpoint_path, map_location=map_location)
    network.load_state_dict(checkpoint['main_net_state_dict'])
    target_network.load_state_dict(checkpoint['target_net_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epsilon = checkpoint['epsilon']
    start_episode = checkpoint['episode'] + 1
    print(f"Loaded checkpoint from episode {start_episode}")
except FileNotFoundError:
    print("No checkpoint found, starting from scratch.")

No checkpoint found, starting from scratch.


In [14]:

target_network.load_state_dict(network.state_dict())

<All keys matched successfully>

In [15]:



replay_buffer = deque(maxlen=10000)
epsilon = epsilon_start
total_steps = 0
episode_rewards = []

for episode in range(num_episodes):
    state = env.reset()
    episode_reward = 0

    while True:
        action = select_action(state, network, epsilon, action_dim, device)
        next_state, reward, done, _ = env.step(action)
        replay_buffer.append((state, action, reward, next_state, done))

        if len(replay_buffer) > batch_size:
            batch = random.sample(replay_buffer, batch_size)
            states, actions, rewards, next_states, dones = zip(*batch)

            states = torch.FloatTensor(states).to(device)
            actions = torch.LongTensor(actions).to(device)
            rewards = torch.FloatTensor(rewards).to(device)
            next_states = torch.FloatTensor(next_states).to(device)
            dones = torch.FloatTensor(dones).to(device)

            quantiles = network(states)
            actions = actions.unsqueeze(1).unsqueeze(2).expand(batch_size, 1, num_quantiles)
            quantiles = quantiles.gather(1, actions).squeeze(1)

            with torch.no_grad():
                next_quantiles = target_network(next_states)
                next_actions = next_quantiles.mean(dim=2).max(1)[1]
                next_actions = next_actions.unsqueeze(1).unsqueeze(2).expand(batch_size, 1, num_quantiles)
                next_target_quantiles = next_quantiles.gather(1, next_actions).squeeze(1)
                target_quantiles = rewards.unsqueeze(1) + gamma * (1 - dones.unsqueeze(1)) * next_target_quantiles

            taus = torch.rand(batch_size, num_quantiles).to(device)
            loss = quantile_huber_loss(quantiles, target_quantiles, taus)

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(network.parameters(), 1.0)  # Gradient clipping
            optimizer.step()

        state = next_state
        episode_reward += reward
        total_steps += 1

        if done:
            break

        if total_steps % target_update_steps == 0:
            target_network.load_state_dict(network.state_dict())

    epsilon = max(epsilon_end, epsilon * epsilon_decay)
    # Logging and monitoring
    episode_rewards.append(episode_reward)
    print(f"Episode {episode + 1}, Reward: {episode_reward}, Epsilon: {epsilon:.2f}")

    # Save model periodically
    if episode % 50 == 0:
        save_checkpoint({
            'episode': episode,
            'main_net_state_dict': network.state_dict(),
            'target_net_state_dict': target_network.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epsilon': epsilon
        },checkpoint_path)
        print(f"Checkpoint saved at episode {episode}")


    # Early stopping condition
    if sum(episode_rewards[-5:]) > 1000:
        print("Training done")
        save_checkpoint({
            'episode': episode,
            'main_net_state_dict': network.state_dict(),
            'target_net_state_dict': target_network.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epsilon': epsilon
        },checkpoint_path)
        print(f"Checkpoint saved at episode {episode}")
        break


  if not isinstance(terminated, (bool, np.bool8)):
  states = torch.FloatTensor(states).to(device)


Episode 1, Reward: -95.08900329208883, Epsilon: 0.99
Checkpoint saved at episode 0
Episode 2, Reward: -523.5931071922687, Epsilon: 0.99
Episode 3, Reward: -112.40846539255574, Epsilon: 0.99
Episode 4, Reward: 82.0524337492685, Epsilon: 0.98
Episode 5, Reward: -290.1639843312145, Epsilon: 0.98
Episode 6, Reward: -187.79072942766783, Epsilon: 0.97
Episode 7, Reward: -58.93888695483702, Epsilon: 0.97
Episode 8, Reward: -123.25033492164492, Epsilon: 0.96
Episode 9, Reward: -334.2461458971462, Epsilon: 0.96
Episode 10, Reward: -44.34722865072406, Epsilon: 0.95
Episode 11, Reward: -389.44943607561567, Epsilon: 0.95
Episode 12, Reward: -68.37985559047175, Epsilon: 0.94
Episode 13, Reward: -401.55595302330056, Epsilon: 0.94
Episode 14, Reward: -144.27076314964893, Epsilon: 0.93
Episode 15, Reward: -262.0000517320025, Epsilon: 0.93
Episode 16, Reward: -85.17927355132683, Epsilon: 0.92
Episode 17, Reward: -153.64137043627505, Epsilon: 0.92
Episode 18, Reward: -95.70030688169854, Epsilon: 0.91
Ep