In [45]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
import random
from collections import deque

In [46]:
class GaussianDQN(nn.Module):
    """MLP with a shared two-layer ReLU trunk and two linear output heads.

    The trunk (``fc1`` then ``fc2``) feeds ``mean_head`` and ``log_std_head``.
    ``forward`` returns the mean head's output together with the exponential
    of the log-std head's output.

    Args:
        state_dim: Number of input features.
        action_dim: Number of outputs produced by each head.
        hidden_dim: Width of both hidden layers (default 256).
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        # The attribute names are also the state_dict keys, so they (and the
        # order in which the layers are created) are kept stable.
        self.fc1 = nn.Linear(in_features=state_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
        self.mean_head = nn.Linear(in_features=hidden_dim, out_features=action_dim)
        self.log_std_head = nn.Linear(in_features=hidden_dim, out_features=action_dim)

    def forward(self, x):
        """Return ``(mean, std)`` computed from the input tensor ``x``."""
        features = x
        for layer in (self.fc1, self.fc2):
            features = torch.relu(layer(features))

        mean = self.mean_head(features)
        # The second head works in log space; exponentiate to get the std.
        std = self.log_std_head(features).exp()
        return mean, std

In [47]:
def select_action(state, network, epsilon, action_dim):
    """Pick an action with an epsilon-greedy rule.

    Args:
        state: Observation accepted by ``torch.FloatTensor``; only used on the
            greedy path.
        network: Callable returning a ``(mean, std)`` pair for a batched state.
        epsilon: A uniform draw from ``np.random.rand()`` below this value
            triggers a random action.
        action_dim: Exclusive upper bound for the random action index.

    Returns:
        A random integer in ``[0, action_dim)`` on the exploration path,
        otherwise the index of the largest entry of the predicted mean.
    """
    explore = np.random.rand() < epsilon
    if explore:
        return np.random.randint(action_dim)

    # Add a leading batch dimension before querying the network.
    batch = torch.FloatTensor(state).unsqueeze(0)
    with torch.no_grad():
        action_means, _std = network(batch)
    return torch.argmax(action_means).item()

In [48]:
def load_checkpoint(filename='checkpoint.pth', map_location=None):
    """Read a checkpoint file with ``torch.load`` and return its contents.

    Args:
        filename: Path of the checkpoint file (default ``'checkpoint.pth'``).
        map_location: Forwarded to ``torch.load`` only when it is truthy;
            otherwise ``torch.load`` is called without it.

    Returns:
        Whatever object ``torch.load`` produces for the file.
    """
    # NOTE: torch.load unpickles the file, so only open checkpoints that come
    # from a trusted source.
    load_kwargs = {'map_location': map_location} if map_location else {}
    return torch.load(filename, **load_kwargs)

In [49]:

# File name of the saved checkpoint. It is a bare relative name, so it is
# looked up in the notebook's current working directory.
PATH = "gdqn_kld2.pth"

In [50]:
# Path that is later handed to load_checkpoint(); nothing is read from disk in
# this cell.
checkpoint_path = PATH

In [51]:
# Hidden-layer width passed to the network constructors. The checkpoint's
# weights are loaded into those networks, so the shapes have to line up.
hidden_dim = 128

# LunarLander environment created with render_mode="human".
env = gym.make("LunarLander-v2", render_mode="human")

# Problem sizes read off the environment's spaces.
action_dim = env.action_space.n
state_dim = env.observation_space.shape[0]

In [52]:
# Online network first, then the target network; both use the same layer sizes
# so that each can receive its state_dict from the checkpoint.
network = GaussianDQN(state_dim=state_dim, action_dim=action_dim, hidden_dim=hidden_dim)
target_network = GaussianDQN(state_dim=state_dim, action_dim=action_dim, hidden_dim=hidden_dim)

In [53]:
# Adam over the online network's parameters (learning rate 1e-3). It is built
# here so that its saved state can be restored from the checkpoint.
optimizer = optim.Adam(params=network.parameters(), lr=1e-3)

In [54]:
# Restore network weights, optimizer state and the exploration rate from the
# checkpoint file. When CUDA is unavailable the stored tensors are remapped to
# the CPU; otherwise torch.load is left to its default placement.
try:
    map_location = torch.device('cpu') if not torch.cuda.is_available() else None
    checkpoint = load_checkpoint(checkpoint_path, map_location=map_location)
except FileNotFoundError:
    # Later cells read `epsilon` (and resume logic would read `start_episode`),
    # so both must exist even when there is nothing to restore. 1.0 means
    # fully random actions -- presumably the training notebook's initial
    # exploration rate; adjust if it starts from a different value.
    epsilon = 1.0
    start_episode = 0
    print("No checkpoint found, starting from scratch.")
else:
    # Only the file read is guarded above; a malformed checkpoint (missing
    # key, shape mismatch) still raises here instead of being mistaken for
    # "no checkpoint".
    network.load_state_dict(checkpoint['main_net_state_dict'])
    target_network.load_state_dict(checkpoint['target_net_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epsilon = checkpoint['epsilon']
    # The stored episode index is the last one recorded; resume from the next.
    start_episode = checkpoint['episode'] + 1
    print(f"Loaded checkpoint from episode {start_episode}")

Loaded checkpoint from episode 76


In [55]:
# Roll out 10 episodes with the loaded network and print each episode's total
# reward.
# NOTE(review): actions are still chosen epsilon-greedily with whatever
# `epsilon` currently holds (the checkpoint's value when one was loaded), so a
# fraction of the actions is random. Pass 0.0 instead for a purely greedy
# evaluation.
try:
    for episode in range(10):
        reset_result = env.reset()
        # Gym >= 0.26 returns (observation, info); older releases return the
        # observation on its own.
        state = reset_result[0] if isinstance(reset_result, tuple) else reset_result
        episode_reward = 0

        while True:
            action = select_action(state, network, epsilon, action_dim)
            step_result = env.step(action)
            if len(step_result) == 5:
                # Newer API: termination and truncation are reported separately.
                next_state, reward, terminated, truncated, _ = step_result
                done = terminated or truncated
            else:
                next_state, reward, done, _ = step_result
            state = next_state
            episode_reward += reward

            if done:
                break

        print(f"Episode: {episode}, Reward: {episode_reward}")
finally:
    # Always release the render window, even if the rollout raises or the
    # cell is interrupted.
    env.close()

Episode: 0, Reward: 233.47002211016206
Episode: 1, Reward: -244.13384359922625
Episode: 2, Reward: -36.576091541151854
Episode: 3, Reward: -42.87249685932948
Episode: 4, Reward: 227.8820148537693
Episode: 5, Reward: -15.991582067515054
Episode: 6, Reward: 236.32616666124818
Episode: 7, Reward: -60.77750511346356
Episode: 8, Reward: -172.27186648239103
Episode: 9, Reward: -126.1109580625549
