In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import namedtuple, deque
import gym

In [14]:
env = gym.envs.make("LunarLander-v2",render_mode="human")
# env = gym.envs.make("MountainCar-v0",render_mode="human")

  deprecation(
  deprecation(


In [15]:
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

In [16]:
class IQN(nn.Module):
    def __init__(self, state_dim, action_dim, num_quantiles, hidden_dim=128):
        super(IQN, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim * num_quantiles)
        self.num_quantiles = num_quantiles
        self.action_dim = action_dim

    def forward(self, x, taus):
        batch_size = x.size(0)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        quantiles = self.fc3(x).view(batch_size, self.num_quantiles, self.action_dim)
        return quantiles

In [17]:
num_quantiles = 10
hidden_dim = 128
# capacity = 10000
# batch_size = 64
# gamma = 0.99

num_episodes = 10  # Number of training episodes

In [18]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [19]:
PATH= 'IQN_lunar_lander.pth'
# PATH = 'IQN_V5_car.pth'

In [20]:
def save_checkpoint(state, filename='checkpoint.pth'):
    torch.save(state, filename)

def load_checkpoint(filename='checkpoint.pth'):
    return torch.load(filename)

In [21]:
# Initialize networks and optimizer
main_net = IQN(state_dim, action_dim, num_quantiles, hidden_dim).to(device)
target_net = IQN(state_dim, action_dim, num_quantiles, hidden_dim).to(device)
# target_net.load_state_dict(main_net.state_dict())  # Initialize target network with main network's parameters
optimizer = optim.Adam(main_net.parameters(), lr=0.001)

In [22]:
def load_checkpoint(filename='checkpoint.pth', map_location=None):
    if map_location:
        return torch.load(filename, map_location=map_location)
    return torch.load(filename)

# Load model if available
checkpoint_path = PATH
try:
    map_location = torch.device('cpu') if not torch.cuda.is_available() else None
    checkpoint = load_checkpoint(checkpoint_path, map_location=map_location)
    main_net.load_state_dict(checkpoint['main_net_state_dict'])
    target_net.load_state_dict(checkpoint['target_net_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epsilon = checkpoint['epsilon']
    start_episode = checkpoint['episode'] + 1
    print(f"Loaded checkpoint from episode {start_episode}")
except FileNotFoundError:
    print("No checkpoint found, starting from scratch.")


Loaded checkpoint from episode 309


In [23]:
for episode in range(num_episodes):
  
    state = env.reset()
    episode_reward = 0
    done = False

    while not done:
        # Epsilon-greedy action selection
        if random.random() < epsilon:
            action = env.action_space.sample()  # Random action
        else:
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
            taus = torch.rand((1, num_quantiles), dtype=torch.float32).to(device)  # Sample quantile fractions
            with torch.no_grad():
                q_quantiles = main_net(state_tensor, taus)
            q_values = q_quantiles.mean(dim=1)
            action = q_values.argmax().item()  # Best action

        # Take action and observe next state, reward, and done flag
        next_state, reward, done, _ = env.step(action)

       
        # Update state and episode reward
        state = next_state
        episode_reward += reward

    print(f"Episode{episode}: Reward {episode_reward}")


env.close()


Episode0: Reward -104.68130946356618
Episode1: Reward 239.49598295719156
Episode2: Reward 108.26389394741501
Episode3: Reward 266.1209549474313
Episode4: Reward 261.9478788367451
Episode5: Reward 214.05632271692815
Episode6: Reward 237.62608872380812
Episode7: Reward 39.79109498555382
Episode8: Reward 113.13347337425124
Episode9: Reward 176.96015179287696
