In [27]:
using PyCall
py"""
import random
import numpy as np
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import sys
# Run on a CUDA GPU when torch can see one; otherwise fall back to the CPU.
DEVICE = 'cpu' if not torch.cuda.is_available() else 'cuda'

def write_(text, path="debug.txt"):
    # Append one debug line of the form "python - <text>" to a log file.
    #
    # text: any object; it is rendered with str() by the f-string.
    # path: file to append to. Defaults to "debug.txt" relative to the
    #       current working directory, which was the previously hard-coded
    #       location, so existing callers are unaffected.
    #
    # The encoding is pinned to UTF-8. Otherwise the platform locale decides,
    # and non-ASCII messages can raise UnicodeEncodeError (e.g. cp1252 on
    # Windows).
    with open(path, "a", encoding="utf-8") as file:
        file.write(f"python - {text}\n")

class QNetwork(nn.Module):
    # Fully connected Q-value network: a state vector goes in, and one
    # unnormalised score per action comes out.
    #
    # params must provide 'state_size', 'first_layer_size',
    # 'second_layer_size' and 'third_layer_size'. The three hidden layers use
    # ReLU; the output layer is linear.
    # The attribute names f1..f4 double as state_dict keys, so keep them stable.

    def __init__(self, params, number_of_actions):
        super().__init__()
        in_size = params['state_size']
        h1 = params['first_layer_size']
        h2 = params['second_layer_size']
        h3 = params['third_layer_size']
        self.f1 = nn.Linear(in_size, h1)  # theta
        self.f2 = nn.Linear(h1, h2)
        self.f3 = nn.Linear(h2, h3)
        self.f4 = nn.Linear(h3, number_of_actions)
        self.init_weights()

    def forward(self, x):
        # ReLU after every hidden layer; the last layer yields raw Q-values.
        hidden = x
        for layer in (self.f1, self.f2, self.f3):
            hidden = F.relu(layer(hidden))
        return self.f4(hidden)

    def init_weights(self):
        # Re-draw every Linear weight with Kaiming (He) uniform init.
        # Biases keep the default initialisation from nn.Linear.
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for linear in linear_layers:
            nn.init.kaiming_uniform_(linear.weight)

class DQNAgent():
    # Deep Q-learning agent with an experience-replay buffer and a target
    # network that is re-synchronised from the online network periodically.
    #
    # params keys read here: 'gamma', 'memory_size', 'epsilon',
    # 'learning_rate', 'target_model_update_iterations', plus the layer sizes
    # that QNetwork reads.
    # actions: the discrete action set; network output column i scores actions[i].
    def __init__(self, params, actions):
        super().__init__()
        self.number_of_actions = len(actions)
        self.actions = actions
        self.gamma = params['gamma']
        # Not used inside this class; NOTE(review): possibly read by outside code.
        self.short_memory = np.array([])
        # Replay buffer of (state, action_index, reward, next_state, is_done).
        # Once memory_size is reached, the oldest transitions are dropped.
        self.memory = collections.deque(maxlen=params['memory_size'])
        self.epsilon = params['epsilon']
        self.model = QNetwork(params, self.number_of_actions)
        self.model.to(DEVICE)
        self.optimizer = optim.Adam(self.model.parameters(), lr=params['learning_rate'])
        # Target network: starts as a copy of the online network and is never
        # stepped by the optimizer. It is only refreshed in replay_mem.
        self.target_model = QNetwork(params, self.number_of_actions)
        self.target_model.to(DEVICE)
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model_update_iterations = params['target_model_update_iterations']
        self.current_iteration = 0

    def on_new_sample(self, state, action, reward, next_state, is_done):
        # Store one transition in replay memory. `action` is translated to its
        # position in self.actions, which is the output column it corresponds to.
        # Raises ValueError if `action` is not in self.actions, instead of
        # failing with an opaque UnboundLocalError.
        for action_index, candidate in enumerate(self.actions):
            if candidate == action:
                break
        else:
            raise ValueError(f"unknown action {action!r}; expected one of {self.actions!r}")
        self.memory.append((state, action_index, reward, next_state, is_done))

    def replay_mem(self, batch_size):
        # Run one Q-learning gradient step on a minibatch from replay memory.
        #
        # Draws batch_size transitions without replacement, or uses the whole
        # buffer if it holds batch_size or fewer. Q(s, a) is regressed toward
        # the target-network targets with an MSE loss. Every
        # target_model_update_iterations calls, the online weights are copied
        # into the target network. Returns early when there is nothing to sample.
        if len(self.memory) > batch_size:
            minibatch = random.sample(self.memory, batch_size)
        else:
            minibatch = list(self.memory)
        if not minibatch:
            return  # zip(*[]) below would raise on an empty batch

        states, actions, rewards, next_states, is_dones = zip(*minibatch)
        targets = self.get_targets(rewards, next_states, is_dones)
        # Go through numpy once. Building a tensor element by element from a
        # tuple of arrays is very slow.
        states_tensor = torch.as_tensor(np.asarray(states, dtype=np.float32), device=DEVICE)
        actions_tensor = torch.as_tensor(actions, dtype=torch.long, device=DEVICE)

        self.model.train()
        self.optimizer.zero_grad()
        # Scoped grad mode: the global autograd state is not changed for the caller.
        with torch.enable_grad():
            outputs = self.model(states_tensor)
            # Q(s, a) for the action actually taken in each transition.
            outputs_selected = outputs.gather(1, actions_tensor.unsqueeze(1)).squeeze(1)
            loss = F.mse_loss(outputs_selected, targets)
            loss.backward()
        self.optimizer.step()

        self.current_iteration += 1
        if self.current_iteration % self.target_model_update_iterations == 0:
            self.target_model.load_state_dict(self.model.state_dict())

    def select_action_index(self, state, apply_epsilon_random):
        # Pick an action for `state`. Despite the name, this returns the action
        # itself (an element of self.actions), not its index.
        # When apply_epsilon_random is truthy, a uniformly random action is
        # returned with probability self.epsilon. Otherwise the action with the
        # highest online-network Q-value is returned (first one on ties).
        if apply_epsilon_random and random.uniform(0, 1) < self.epsilon:
            return self.actions[np.random.choice(self.number_of_actions)]  # phidot, psidot actions

        with torch.no_grad():
            batch = np.asarray(state, dtype=np.float32)[np.newaxis, :]
            q_values = self.model(torch.as_tensor(batch, device=DEVICE))
        return self.actions[int(np.argmax(q_values.cpu().numpy()[0]))]

    def get_targets(self, rewards, next_states, is_dones):
        # Q-learning (off-policy) targets for a batch, as a 1-D float32 tensor on DEVICE:
        #   reward                                   if the transition is terminal
        #   reward + gamma * max_a' Q_target(s', a') otherwise
        # torch.where selects rather than multiplies by a mask, so garbage or
        # NaN Q-values for terminal next_states cannot leak into the targets.
        with torch.no_grad():
            rewards_tensor = torch.as_tensor(rewards, dtype=torch.float32, device=DEVICE)
            dones_tensor = torch.as_tensor(is_dones, dtype=torch.bool, device=DEVICE)
            next_states_tensor = torch.as_tensor(
                np.asarray(next_states, dtype=np.float32), device=DEVICE
            )
            max_next_q = self.target_model(next_states_tensor).max(dim=1).values
            bootstrapped = rewards_tensor + self.gamma * max_next_q
            return torch.where(dones_tensor, rewards_tensor, bootstrapped)
"""