In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import namedtuple, deque
import gym

In [19]:
env = gym.envs.make("LunarLander-v2")

  deprecation(
  deprecation(


In [20]:
# Define experience tuple for replay buffer
Experience = namedtuple('Experience', ('state', 'action', 'reward', 'next_state', 'done'))

In [21]:
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

In [22]:
class IQN(nn.Module):
    def __init__(self, state_dim, action_dim, num_quantiles, hidden_dim=128):
        super(IQN, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim * num_quantiles)
        self.num_quantiles = num_quantiles
        self.action_dim = action_dim

    def forward(self, x, taus):
        batch_size = x.size(0)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        quantiles = self.fc3(x).view(batch_size, self.num_quantiles, self.action_dim)
        return quantiles

In [23]:
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add_experience(self, experience):
        self.buffer.append(experience)

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        return Experience(*zip(*batch))


In [24]:
num_quantiles = 10
hidden_dim = 128
capacity = 10000
batch_size = 64
gamma = 0.99
update_freq = 10  # Update target network every 10 episodes or steps
num_episodes = 1000  # Number of training episodes
epsilon_start = 1.0
epsilon_end = 0.01
epsilon_decay = 0.995

In [25]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [26]:
# Initialize networks and optimizer
main_net = IQN(state_dim, action_dim, num_quantiles, hidden_dim).to(device)
target_net = IQN(state_dim, action_dim, num_quantiles, hidden_dim).to(device)
# target_net.load_state_dict(main_net.state_dict())  # Initialize target network with main network's parameters
optimizer = optim.Adam(main_net.parameters(), lr=0.001)

In [27]:
def save_checkpoint(state, filename='checkpoint.pth'):
    torch.save(state, filename)

In [28]:
def load_checkpoint(filename='checkpoint.pth', map_location=None):
    if map_location:
        return torch.load(filename, map_location=map_location)
    return torch.load(filename)

In [29]:
# Load model if available
checkpoint_path = 'IQN_lunar_lander.pth'
try:
    checkpoint = load_checkpoint(checkpoint_path)
    main_net.load_state_dict(checkpoint['main_net_state_dict'])
    target_net.load_state_dict(checkpoint['target_net_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epsilon = checkpoint['epsilon']
    start_episode = checkpoint['episode'] + 1
    print(f"Loaded checkpoint from episode {start_episode}")
except FileNotFoundError:
    print("No checkpoint found, starting from scratch.")

Loaded checkpoint from episode 251


In [30]:
target_net.load_state_dict(main_net.state_dict())  # Initialize target network with main network's parameters

<All keys matched successfully>

In [31]:
# Initialize replay buffer
replay_buffer = ReplayBuffer(capacity)

In [32]:
def quantile_huber_loss(predictions, targets, taus, kappa=1.0):
    """Calculates the quantile Huber loss."""
    u = targets - predictions
    abs_u = torch.abs(u)
    huber_loss = torch.where(abs_u <= kappa, 0.5 * u ** 2, kappa * (abs_u - 0.5 * kappa))
    loss = (torch.abs(taus - (u < 0).float()) * huber_loss).mean()
    return loss

In [33]:
# Training loop
episode_rewards = []
epsilon = epsilon_start

In [34]:
for episode in range(num_episodes):
    # Update target network periodically
    if episode % update_freq == 0:
        target_net.load_state_dict(main_net.state_dict())

    state = env.reset()

    episode_reward = 0
    done = False

    while not done:
        # Epsilon-greedy action selection
        if random.random() < epsilon:
            action = env.action_space.sample()  # Random action
        else:
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
            taus = torch.rand((1, num_quantiles), dtype=torch.float32).to(device)  # Sample quantile fractions
            with torch.no_grad():
                q_quantiles = main_net(state_tensor, taus)
            q_values = q_quantiles.mean(dim=1)
            action = q_values.argmax().item()  # Best action

        # Take action and observe next state, reward, and done flag
        next_state, reward, done, _ = env.step(action)

        # Store experience in replay buffer
        replay_buffer.add_experience(Experience(state, action, reward, next_state, done))

        # Update state and episode reward
        state = next_state
        episode_reward += reward

        # Sample minibatch from replay buffer
        if len(replay_buffer.buffer) >= batch_size:
            experiences = replay_buffer.sample(batch_size)

            # Prepare minibatch tensors
            states = torch.tensor(experiences.state, dtype=torch.float32).to(device)
            actions = torch.tensor(experiences.action).unsqueeze(1).to(device)
            rewards = torch.tensor(experiences.reward, dtype=torch.float32).unsqueeze(1).to(device)
            next_states = torch.tensor(experiences.next_state, dtype=torch.float32).to(device)
            dones = torch.tensor(experiences.done, dtype=torch.float32).unsqueeze(1).to(device)

            # Sample quantile fractions
            taus = torch.rand((batch_size, num_quantiles), dtype=torch.float32).to(device)

            # Compute Q-values and target Q-values
            q_quantiles = main_net(states, taus).gather(2, actions.unsqueeze(1).expand(-1, num_quantiles, -1)).squeeze(-1)
            with torch.no_grad():
                next_q_quantiles = target_net(next_states, taus)
                next_q_values = next_q_quantiles.mean(dim=1)
                next_actions = next_q_values.argmax(dim=1, keepdim=True)
                target_quantiles = rewards + gamma * next_q_quantiles.gather(2, next_actions.unsqueeze(1).expand(-1, num_quantiles, -1)).squeeze(-1) * (1 - dones)

            # Compute loss and update main network
            loss = quantile_huber_loss(q_quantiles, target_quantiles, taus)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Decay epsilon
    epsilon = max(epsilon_end, epsilon_decay * epsilon)



    # Logging and monitoring
    episode_rewards.append(episode_reward)
    print(f"Episode {episode}, Reward: {episode_reward}, Epsilon: {epsilon}")

    # Save model periodically
    if episode % 50 == 0:
        save_checkpoint({
            'episode': episode,
            'main_net_state_dict': main_net.state_dict(),
            'target_net_state_dict': target_net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epsilon': epsilon
        }, checkpoint_path)
        print(f"Checkpoint saved at episode {episode}")

    if(sum(episode_rewards[-5:])>1000):
      print(sum(episode_rewards[-5:])>1000)
      print("Training done")
      save_checkpoint({
            'episode': episode,
            'main_net_state_dict': main_net.state_dict(),
            'target_net_state_dict': target_net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epsilon': epsilon
        }, checkpoint_path)
      print(f"Checkpoint saved at episode {episode}")
      break



Episode 0, Reward: -93.33779579527578, Epsilon: 0.995
Checkpoint saved at episode 0
Episode 1, Reward: -41.0730593065657, Epsilon: 0.990025
Episode 2, Reward: -100.04860265017787, Epsilon: 0.985074875
Episode 3, Reward: -288.04664662710957, Epsilon: 0.9801495006250001
Episode 4, Reward: -167.40142037074318, Epsilon: 0.9752487531218751
Episode 5, Reward: -136.78710516066715, Epsilon: 0.9703725093562657
Episode 6, Reward: -377.3619240888442, Epsilon: 0.9655206468094844
Episode 7, Reward: -117.5860409270703, Epsilon: 0.960693043575437
Episode 8, Reward: -97.64211198198075, Epsilon: 0.9558895783575597
Episode 9, Reward: -152.21976949439943, Epsilon: 0.9511101304657719
Episode 10, Reward: -133.5577737952258, Epsilon: 0.946354579813443
Episode 11, Reward: -101.11919574765388, Epsilon: 0.9416228069143757
Episode 12, Reward: -177.5118428332503, Epsilon: 0.9369146928798039
Episode 13, Reward: -69.63324316770924, Epsilon: 0.9322301194154049
Episode 14, Reward: -88.7137438132228, Epsilon: 0.92756