In [1]:
import random
from collections import deque
from typing import Tuple

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim

from environments.EnvRandomReturn import EnvRandomReturn

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Q-Network

In [2]:
class QNetwork(nn.Module):
    """ 
    The QNetwork should output Q-values for each discrete action.
    """
    def __init__(self, state_dim, action_dim, hidden_size=64):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_dim)
    
    def forward(self, x):
        """ Defines the forward pass of the network """
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)  # Q-values for each action

# World-Model LSTM

In [3]:
class WorldModel(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=64, num_layers=2):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # LSTM for temporal modeling
        self.lstm = nn.LSTM(state_dim + action_dim, hidden_size, num_layers, batch_first=True)

        # Output layers
        self.fc_state = nn.Linear(hidden_size, state_dim)
        self.fc_reward = nn.Linear(hidden_size, 1)
        self.fc_done = nn.Linear(hidden_size, 1)

    def forward(self, observation, action, hidden_state=None):
        x = torch.cat([observation, action], dim=-1)  # (batch, seq_len, input_dim)

        # Initialize hidden state if not provided
        if hidden_state is None:
            h_0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
            c_0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
            hidden_state = (h_0, c_0)

        # Pass through LSTM
        lstm_out, hidden_state = self.lstm(x, hidden_state)  # lstm_out: (batch, seq_len, hidden_size)

        # Compute outputs
        predicted_state = self.fc_state(lstm_out)
        predicted_reward = torch.tanh(self.fc_reward(lstm_out))
        predicted_done = torch.sigmoid(self.fc_done(lstm_out))

        return predicted_state, predicted_reward, predicted_done, hidden_state

# Replay Buffer

In [4]:
class ReplayBuffer:
    def __init__(self, capacity, seq_len):
        self.capacity = capacity
        self.seq_len = seq_len  # Length of sequences to sample
        self.buffer = deque(maxlen=capacity)  # Use deque for efficiency

    def add(self, state, action, reward, next_state, done):
        """Stores a transition in the replay buffer."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Samples a batch of sequence-length transitions."""
        if len(self.buffer) < self.seq_len:  # Ensure enough data
            return None

        # Select valid starting indices
        indices = np.random.choice(len(self.buffer) - self.seq_len, batch_size, replace=False)

        sequences = []
        for idx in indices:
            seq = list(self.buffer)[idx : idx + self.seq_len]  # Get a sequence
            states, actions, rewards, next_states, dones = zip(*seq)

            sequences.append((torch.tensor(states, dtype=torch.float32, device=DEVICE),
                              # Long for dicrete space
                              torch.tensor(actions, dtype=torch.long, device=DEVICE),
                              torch.tensor(rewards, dtype=torch.float32, device=DEVICE).unsqueeze(-1),
                              torch.tensor(next_states, dtype=torch.float32, device=DEVICE),
                              torch.tensor(dones, dtype=torch.bool, device=DEVICE).unsqueeze(-1)))

        # Convert list of sequences into batch tensors
        batch = tuple(torch.stack(b, dim=0) for b in zip(*sequences))  
        return batch  # Shapes: (batch_size, seq_len, feature_dim)

    def __len__(self):
        return len(self.buffer)

# Parameters

In [5]:
EPISODES = 500
BATCH_SIZE = 64
LEARNING_RATE = 0.001
EPSILON = 1.0
EPSILON_DECAY = 0.995
MIN_EPSILON = 0.1
GAMMA = 0.99
CAPACITY_BUFFER = 10_000
OBS_SPACE_SIZE = 5
ACTION_SIZE = 6
SEQ_LEN = 32
SIMULATED_UPDATES = 32
IMAGINARY_HORIZON = 10

# Initialize components
Q_NETWORK = QNetwork(state_dim=OBS_SPACE_SIZE, action_dim=ACTION_SIZE)
WORLD_MODEL = WorldModel(state_dim=OBS_SPACE_SIZE, action_dim=ACTION_SIZE).to(DEVICE)
REPLAY_BUFFER = ReplayBuffer(capacity=CAPACITY_BUFFER, seq_len=SEQ_LEN)
IMAGINARY_REPLAY_BUFFER = ReplayBuffer(capacity=CAPACITY_BUFFER, seq_len=SEQ_LEN)

# Initialize optimizers
optimizer_q_network = torch.optim.Adam(Q_NETWORK.parameters(), lr=LEARNING_RATE)
optimizer_world_model = torch.optim.Adam(WORLD_MODEL.parameters(), lr=LEARNING_RATE)

# Initialize hidden state
h = torch.zeros(WORLD_MODEL.num_layers, 1, WORLD_MODEL.hidden_size).to(DEVICE)
c = torch.zeros(WORLD_MODEL.num_layers, 1, WORLD_MODEL.hidden_size).to(DEVICE)
hidden_state = (h, c)

# Initialize loss functions
world_model_loss_fn = torch.nn.MSELoss()

# Environment

In [6]:
ENV = EnvRandomReturn()

  gym.logger.warn(


# Training Loop

In [7]:
for episode in range(EPISODES):
    state, _info = ENV.reset()# Only get observation
    done = False

    while not done:
        # Epsilon greedy policy
        if random.random() < EPSILON:
            # Exploration
            action: int = random.choice(range(6))
        else:
            # Exploitation
            q_values = Q_NETWORK(torch.tensor(state, dtype=torch.float32))
            action = torch.argmax(q_values).item()

        # Transform action len=1 to one_hot len=6 for WorldModle input
        action_one_hot: torch.Tensor = torch.nn.functional.one_hot(
            torch.tensor(action, dtype=torch.long), num_classes=6
        )
        # Transform to numpy array for cleare understanding -> Not needed later i think
        action_one_hot: np.ndarray = action_one_hot.numpy()

        # Step in real environment
        next_state, reward, done, _ = ENV.step(action)

        # Apply to memory of current episode
        REPLAY_BUFFER.add(
            state, # np.ndarray
            action_one_hot, # np.ndarray
            reward, # np.float64
            next_state, # np.ndarray
            done # bool
        )

        state = next_state

    # Decay Epsilon
    EPSILON = max(MIN_EPSILON, EPSILON * EPSILON_DECAY)

    #----------Train WorldModel using real data----------#
    if len(REPLAY_BUFFER) >= BATCH_SIZE + SEQ_LEN:
        (states, actions, rewards, next_states, dones) = REPLAY_BUFFER.sample(BATCH_SIZE)

        states = states.to(DEVICE)
        actions = actions.to(DEVICE)
        rewards = rewards.to(DEVICE)
        next_states = next_states.to(DEVICE)
        dones = dones.to(DEVICE)

        optimizer_world_model.zero_grad()

        h_0 = torch.zeros(WORLD_MODEL.num_layers, BATCH_SIZE, WORLD_MODEL.hidden_size).to(DEVICE)
        c_0 = torch.zeros(WORLD_MODEL.num_layers, BATCH_SIZE, WORLD_MODEL.hidden_size).to(DEVICE)
        hidden_state = (h_0, c_0)

        # Forward pass
        pred_states, pred_rewards, pred_dones, _ = WORLD_MODEL(
            states, actions, hidden_state
        )

        # Compute losses
        state_loss = world_model_loss_fn(pred_states, next_states)
        reward_loss = world_model_loss_fn(pred_rewards, rewards)
        done_loss = world_model_loss_fn(pred_dones, dones.float())
        total_loss = state_loss + reward_loss + done_loss

        print(f"EPISODE: {episode}; WORLD_MODEL LOSS: {total_loss}")

        # Backpropagation
        total_loss.backward()
        optimizer_world_model.step()



    # ---- Use World Model as a simulator to train the Q-Network ----
    for idx, _ in enumerate(range(SIMULATED_UPDATES)):  # Train Q-network multiple times using synthetic experience
        state_sim = torch.tensor(ENV.reset()[0], dtype=torch.float32, device=DEVICE).unsqueeze(0).unsqueeze(0)  # 1 Batch, 1 Seq, 5 elements

        # Initialize hidden states for LSTM
        h_0 = torch.zeros(WORLD_MODEL.lstm.num_layers, 1, WORLD_MODEL.lstm.hidden_size, device=DEVICE)
        c_0 = torch.zeros(WORLD_MODEL.lstm.num_layers, 1, WORLD_MODEL.lstm.hidden_size, device=DEVICE)
        hidden_state = (h_0, c_0)

        for _ in range(IMAGINARY_HORIZON):  # Generate synthetic rollouts
            # Choose action using current Q-network
            q_values = Q_NETWORK(state_sim)
            action_sim = torch.argmax(q_values, dim=-1).item()
            # Convert action to one-hot encoding
            action_one_hot_sim = torch.nn.functional.one_hot(
                torch.tensor(action_sim, dtype=torch.long, device=DEVICE), num_classes=6
            ).unsqueeze(0).unsqueeze(0)
 
            # Predict next state, reward, and done using the World Model
            predicted_next_state, predicted_reward, predicted_done, hidden_state = WORLD_MODEL(
                state_sim, action_one_hot_sim, hidden_state
            )
            # print(action_one_hot_sim.squeeze(0).squeeze(0).detach().cpu().numpy())
            # Store synthetic experience
            IMAGINARY_REPLAY_BUFFER.add(
                state_sim.squeeze(0).squeeze(0).detach().cpu().numpy(),  # Remove batch dims and convert to numpy array
                action_one_hot_sim.squeeze(0).squeeze(0).detach().cpu().numpy(),  # Same for action
                predicted_reward.squeeze().item(),  # Squeeze to remove extra dimension and get scalar
                predicted_next_state.squeeze(0).squeeze(0).detach().cpu().numpy(),  # Squeeze and detach for next_state
                predicted_done.squeeze().item()  # Squeeze and convert to scalar
            )


            if predicted_done.item() > 0.5:  # Terminate if World Model predicts episode end
                break
            
            state_sim = predicted_next_state.detach()  # Move to predicted next state

    # ---- Train Q-Network on synthetic experience ----
    if len(IMAGINARY_REPLAY_BUFFER) >= BATCH_SIZE + SEQ_LEN:
        (states_sim, actions_sim, rewards_sim, next_states_sim, dones_sim) = IMAGINARY_REPLAY_BUFFER.sample(BATCH_SIZE)

        # Move to device
        states_sim = states_sim.to(DEVICE)
        actions_sim = actions_sim.to(DEVICE).long()  # Ensure correct type
        rewards_sim = rewards_sim.to(DEVICE)
        next_states_sim = next_states_sim.to(DEVICE)
        dones_sim = dones_sim.to(DEVICE).float()  # Convert to float

        # Get Q-values
        q_values = Q_NETWORK(states_sim)  
        actions_sim_indices = actions_sim.argmax(dim=-1).view(64, 32, 1)  # Shape [64, 32, 1]
        # Now we gather the Q-values corresponding to the actions
        q_value = q_values.gather(2, actions_sim_indices)
        
        # Compute target using Bellman equation
        next_q_values = Q_NETWORK(next_states_sim).max(1)[0].unsqueeze(1)  # Max Q-value for next state
        target = rewards_sim + GAMMA * next_q_values * (1 - dones_sim)  # Bellman update

        # Compute loss
        loss_q = torch.nn.MSELoss()(q_value, target)

        # Backpropagation
        optimizer_q_network.zero_grad()
        loss_q.backward()
        optimizer_q_network.step()

        print(f"Q-Network Loss: {loss_q.item():.4f}")


  sequences.append((torch.tensor(states, dtype=torch.float32, device=DEVICE),
  return F.mse_loss(input, target, reduction=self.reduction)


Q-Network Loss: 0.0245
Q-Network Loss: 0.0074
Q-Network Loss: 0.0019
Q-Network Loss: 0.0020
Q-Network Loss: 0.0024
Q-Network Loss: 0.0029
Q-Network Loss: 0.0034
Q-Network Loss: 0.0038
Q-Network Loss: 0.0035
EPISODE: 9; WORLD_MODEL LOSS: 2294.375
Q-Network Loss: 0.0047
EPISODE: 10; WORLD_MODEL LOSS: 2289.241943359375
Q-Network Loss: 0.0040
EPISODE: 11; WORLD_MODEL LOSS: 2284.106689453125
Q-Network Loss: 0.0022
EPISODE: 12; WORLD_MODEL LOSS: 2279.398193359375
Q-Network Loss: 0.0032
EPISODE: 13; WORLD_MODEL LOSS: 3549.140869140625
Q-Network Loss: 0.0028
EPISODE: 14; WORLD_MODEL LOSS: 3470.896484375
Q-Network Loss: 0.0035
EPISODE: 15; WORLD_MODEL LOSS: 3465.9599609375
Q-Network Loss: 0.0037
EPISODE: 16; WORLD_MODEL LOSS: 3466.772705078125
Q-Network Loss: 0.0049
EPISODE: 17; WORLD_MODEL LOSS: 3357.652099609375
Q-Network Loss: 0.0060
EPISODE: 18; WORLD_MODEL LOSS: 3363.933837890625
Q-Network Loss: 0.0067
EPISODE: 19; WORLD_MODEL LOSS: 3958.60302734375
Q-Network Loss: 0.0115
EPISODE: 20; WORL

error: Not connected to physics server.