In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import gymnasium as gym

import random
import math
import numpy as np

from collections import deque

In [3]:
class QNetwork(nn.Module):
    """Fully connected network mapping a state vector to one Q-value per action.

    The architecture is two ReLU-activated hidden layers followed by a linear
    output layer with no activation, so the outputs are unbounded.
    """

    def __init__(self, state_dims, action_dims, hidden_dims=64):
        """Build the three linear layers.

        Args:
            state_dims: Number of features in the input state vector.
            action_dims: Number of outputs (one Q-value per action).
            hidden_dims: Width of both hidden layers. Defaults to 64, the
                value that was previously hard-coded.
        """
        super().__init__()
        # Attribute names fc1/fc2/fc3 are kept so existing state_dicts still load.
        self.fc1 = nn.Linear(state_dims, hidden_dims)
        self.fc2 = nn.Linear(hidden_dims, hidden_dims)
        self.fc3 = nn.Linear(hidden_dims, action_dims)

    def forward(self, x):
        """Return the Q-values for ``x``.

        Args:
            x: Float tensor whose last dimension has size ``state_dims``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``action_dims``.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    
# Build the Q-network for 8-dimensional states and 4 actions, and attach an
# Adam optimizer (learning rate 1e-3) to its parameters.
q_network = QNetwork(state_dims=8, action_dims=4)
optimizer = optim.Adam(params=q_network.parameters(), lr=1e-3)



In [4]:
# Create the Gymnasium environment the agent is trained on.
# NOTE(review): newer Gymnasium releases appear to register this task as
# "LunarLander-v3" and may reject the "-v2" id — confirm against the
# installed gymnasium version.
env = gym.make("LunarLander-v2")

In [5]:
# Greedy policy: action = argmax_a Q(s_t, a)

def select_action(q_network, state, epsilon=0.0):
    """Choose an action for ``state`` using an epsilon-greedy policy.

    Args:
        q_network: Callable that maps a float32 state tensor to a tensor of
            Q-values, one per action.
        state: A single observation (array-like or tensor); it is converted
            to a float32 tensor before being fed to ``q_network``.
        epsilon: Probability of returning a uniformly random action instead
            of the greedy one. The default of 0.0 is purely greedy, which is
            the behaviour this function had before the parameter existed.

    Returns:
        int: Index of the selected action.
    """
    # as_tensor accepts arrays, lists and existing tensors without the
    # copy-construct warning that torch.tensor emits for tensor inputs.
    state = torch.as_tensor(state, dtype=torch.float32)

    # Choosing an action never needs gradients, so skip building the graph.
    with torch.no_grad():
        q_values = q_network(state)

    # Guard on epsilon first so the greedy default does not consume a draw
    # from the global random number generator.
    if epsilon > 0.0 and random.random() < epsilon:
        return random.randrange(q_values.shape[-1])

    return torch.argmax(q_values).item()

# Discount factor: calculate_loss multiplies the next-state Q-value by this
# when forming the learning target.
gamma = 0.99

# Bellman optimality equation (gamma is the discount factor):
# Q(s_t, a_t) = r_t + gamma * max_a' Q(s_{t+1}, a')

# loss = [ r_t + gamma * max_a' Q(s_{t+1}, a') - Q(s_t, a_t) ]^2
def calculate_loss(q_network, state, action, reward, next_state, done, discount=None):
    """Compute the squared one-step temporal-difference error for one transition.

    The target is ``reward + discount * max_a' Q(next_state, a') * (1 - done)``
    and the loss is the squared difference between that target and
    ``Q(state, action)``. Gradients flow only through ``Q(state, action)``;
    the target is treated as a constant.

    Args:
        q_network: Callable mapping a float32 state tensor to a tensor of
            Q-values, one per action.
        state: Observation before the action (array-like or tensor).
        action: Index of the action taken; used to index the Q-value tensor.
        reward: Scalar reward received for the transition.
        next_state: Observation after the action (array-like or tensor).
        done: Truthy when no future value should be bootstrapped from
            ``next_state``; it zeroes the discounted next-state term.
        discount: Discount factor for the next-state value. When ``None``
            (the default) the module-level ``gamma`` is used.

    Returns:
        Scalar tensor holding the loss, attached to the autograd graph of
        ``q_network``.
    """
    if discount is None:
        # Fall back to the notebook-level discount factor.
        discount = gamma

    state = torch.as_tensor(state, dtype=torch.float32)
    next_state = torch.as_tensor(next_state, dtype=torch.float32)
    reward = torch.as_tensor(reward, dtype=torch.float32)
    done = torch.as_tensor(done, dtype=torch.float32)

    # Q-value of the action that was actually taken.
    current_state_q_value = q_network(state)[action]

    # The bootstrap value is a fixed target: evaluate it without tracking
    # gradients so no graph is built for this forward pass.
    with torch.no_grad():
        next_state_q_value = q_network(next_state).max()

    target_q_value = reward + next_state_q_value * discount * (1 - done)

    return F.mse_loss(current_state_q_value, target_q_value)



In [6]:
# Online Q-learning: run 1000 episodes, taking one gradient step on the
# Q-network after every environment transition.
# NOTE(review): select_action is called with its defaults, so actions are
# chosen greedily and the agent never explores — confirm this is intended.
for i in range(1000):
    state, info = env.reset()
    done = False
    step = 0  # number of transitions taken in the current episode
    while not done:
        step += 1
        action = select_action(q_network, state)

        next_state, reward, terminate, truncated, _ = env.step(action)

        # Either signal ends the rollout for this episode.
        done = terminate or truncated

        # Pass only the termination flag to the loss: calculate_loss zeroes
        # the next-state value when its last argument is true, which is right
        # for a terminal state but wrong when the episode was merely
        # truncated and next_state still has future value.
        loss = calculate_loss(q_network, state, action, reward, next_state, terminate)

        # Standard optimisation step: clear stale gradients, backpropagate,
        # then update the network parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        state = next_state

