In [None]:
import random
from collections import deque
from typing import Tuple

import torch
import numpy as np
import torch.nn as nn
import gymnasium as gym
import torch.optim as optim

from environments.EnvRandomReturn import EnvRandomReturn

# Compute device for models and sampled tensors: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Q-Network

In [None]:
class QNetwork(nn.Module):
    """Three-layer MLP that maps a state vector to one Q-value per discrete action.

    Args:
        state_dim: size of the last dimension of the input.
        action_dim: number of discrete actions (width of the output layer).
        hidden_size: width of both hidden layers.
    """

    def __init__(self, state_dim, action_dim, hidden_size=64):
        super().__init__()
        # Attribute names are part of the saved state_dict ("dqn.pth"); keep them stable.
        self.fc1 = nn.Linear(state_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_dim)

    def forward(self, x):
        """Return unbounded Q-values; leading dimensions of ``x`` are preserved."""
        for hidden_layer in (self.fc1, self.fc2):
            x = torch.relu(hidden_layer(x))
        # No activation on the output: Q-values may take any real value.
        return self.fc3(x)

# World-Model LSTM

In [None]:
class WorldModel(nn.Module):
    """LSTM world model predicting next state, reward and termination probability.

    The LSTM is batch-first and consumes the concatenation of observation and action,
    so inputs are 3-D ``(batch, seq_len, features)`` tensors.

    Args:
        state_dim: observation size (also the size of the predicted next state).
        action_dim: size of the action encoding (one-hot in this notebook).
        hidden_size: LSTM hidden width; read by other cells to build hidden states.
        num_layers: number of stacked LSTM layers; also read by other cells.
    """

    def __init__(self, state_dim, action_dim, hidden_size=64, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        lstm_input_size = state_dim + action_dim
        self.lstm = nn.LSTM(lstm_input_size, hidden_size, num_layers, batch_first=True)

        # Separate linear heads read the same LSTM features.
        self.fc_state = nn.Linear(hidden_size, state_dim)
        self.fc_reward = nn.Linear(hidden_size, 1)
        self.fc_done = nn.Linear(hidden_size, 1)

    def forward(self, observation, action, hidden_state=None):
        """Run the LSTM over a batch of (observation, action) sequences.

        Args:
            observation: ``(batch, seq_len, state_dim)`` tensor.
            action: ``(batch, seq_len, action_dim)`` tensor, concatenated onto
                ``observation`` along the last dimension.
            hidden_state: optional ``(h, c)`` pair, each
                ``(num_layers, batch, hidden_size)``; zeros are used when omitted.

        Returns:
            ``(next_state, reward, done_prob, (h, c))``. ``reward`` is squashed into
            (-1, 1) by tanh and ``done_prob`` into (0, 1) by sigmoid.
        """
        lstm_input = torch.cat([observation, action], dim=-1)

        if hidden_state is None:
            zero_shape = (self.num_layers, lstm_input.size(0), self.hidden_size)
            hidden_state = tuple(torch.zeros(zero_shape, device=lstm_input.device) for _ in range(2))

        features, next_hidden = self.lstm(lstm_input, hidden_state)

        return (
            self.fc_state(features),
            torch.tanh(self.fc_reward(features)),
            torch.sigmoid(self.fc_done(features)),
            next_hidden,
        )

# Replay Buffer

In [None]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions that samples contiguous windows.

    Transitions are kept in insertion order in a ``deque(maxlen=capacity)``, so the oldest
    entries are dropped once the buffer is full. ``sample`` returns windows of ``seq_len``
    consecutive transitions. Windows are not cut at episode boundaries; one window may span
    the end of an episode and the start of the next (the ``done`` flags show where).
    """

    def __init__(self, capacity, seq_len):
        self.capacity = capacity
        self.seq_len = seq_len  # number of consecutive transitions per sampled window
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        # TODO: Maybe make it so we can filter good and bad episodes
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size, recent_fraction=0.3):
        """Sample ``batch_size`` windows of ``seq_len`` consecutive transitions.

        Window starts are drawn uniformly from the most recent
        ``max(seq_len, int(len(self) * recent_fraction))`` transitions. If that slice has
        fewer than ``batch_size`` possible starts, it is widened toward older data. Starts
        are drawn with replacement only when the whole buffer has fewer than
        ``batch_size`` possible windows.

        Args:
            batch_size: number of windows to return.
            recent_fraction: fraction of the buffer treated as "recent". The default 0.3
                matches the previously hard-coded value; 1.0 samples the whole buffer.

        Returns:
            None if the buffer holds fewer than ``seq_len`` transitions. Otherwise a tuple
            ``(states, actions, rewards, next_states, dones)`` of tensors on ``DEVICE``,
            each shaped ``(batch_size, seq_len, ...)``:
            - states/next_states: float32
            - actions: int64
            - rewards: float32 with a trailing dimension of 1
            - dones: bool with a trailing dimension of 1
        """
        size = len(self.buffer)
        if size < self.seq_len:
            return None

        # Inclusive: the newest window ends exactly on the newest transition.
        last_start = size - self.seq_len
        recent_size = max(self.seq_len, int(size * recent_fraction))
        # Widen backwards if the recent slice cannot supply batch_size distinct starts;
        # without this, np.random.choice(replace=False) raises on small buffers.
        first_start = max(0, min(size - recent_size, last_start + 1 - batch_size))
        candidates = np.arange(first_start, last_start + 1)
        starts = np.random.choice(candidates, batch_size, replace=len(candidates) < batch_size)

        # Copy the deque once per call (deque slicing is not supported), not once per window.
        transitions = list(self.buffer)
        windows = [transitions[start:start + self.seq_len] for start in starts]
        states, actions, rewards, next_states, dones = (
            [[step[field] for step in window] for window in windows] for field in range(5)
        )

        def to_tensor(nested, dtype):
            # One ndarray conversion avoids torch's slow path for lists of ndarrays.
            return torch.as_tensor(np.asarray(nested, dtype=dtype), device=DEVICE)

        return (
            to_tensor(states, np.float32),
            to_tensor(actions, np.int64),
            to_tensor(rewards, np.float32).unsqueeze(-1),
            to_tensor(next_states, np.float32),
            # Any non-zero value becomes True, as with torch.tensor(..., dtype=torch.bool).
            to_tensor(dones, np.bool_).unsqueeze(-1),
        )

    def __len__(self):
        return len(self.buffer)

# Environment

In [None]:
# Environment shared by the buffer-filling, pretraining and training cells.
# Used with reset() -> (obs, info) and a 4-tuple step() -> (obs, reward, done, info),
# i.e. the old gym step API rather than gymnasium's 5-tuple.
ENV = EnvRandomReturn()

# CartPole Environment (Testing)

In [None]:
# CART_ENV = gym.make("CartPole-v1", render_mode="human")

# Parameters and Initialization

In [None]:
EPISODES = 500
BATCH_SIZE = 64
LEARNING_RATE = 0.001
EPSILON = 1.0                 # initial exploration rate; the training loop decays it in place
EPSILON_DECAY = 0.995
MIN_EPSILON = 0.1
GAMMA = 0.99
CAPACITY_BUFFER = 100_000
OBS_SPACE_SIZE = 5
ACTION_SIZE = 6
SEQ_LEN = 32
SIMULATED_UPDATES = 32        # imagined rollouts generated per real episode
IMAGINARY_HORIZON = 10        # max steps per imagined rollout
TRAIN_WORLD_MODEL_EVERY = 10  # episodes between world-model updates in the training loop
EPOCHS = 30                   # world-model pretraining epochs
SEED = 42

# Seed every RNG this notebook draws from so Restart & Run All is reproducible.
# NOTE(review): EnvRandomReturn's own randomness is not controlled here — confirm.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Initialize components (both networks on DEVICE: sampled tensors are created there)
Q_NETWORK = QNetwork(state_dim=OBS_SPACE_SIZE, action_dim=ACTION_SIZE).to(DEVICE)
WORLD_MODEL = WorldModel(state_dim=OBS_SPACE_SIZE, action_dim=ACTION_SIZE).to(DEVICE)
REPLAY_BUFFER = ReplayBuffer(capacity=CAPACITY_BUFFER, seq_len=SEQ_LEN)
IMAGINARY_REPLAY_BUFFER = ReplayBuffer(capacity=CAPACITY_BUFFER, seq_len=SEQ_LEN)

# Initialize optimizers
optimizer_q_network = optim.Adam(Q_NETWORK.parameters(), lr=LEARNING_RATE)
optimizer_world_model = optim.Adam(WORLD_MODEL.parameters(), lr=LEARNING_RATE)

# Zero LSTM state, shape (num_layers, BATCH_SIZE, hidden_size) for both h and c
h = torch.zeros(WORLD_MODEL.num_layers, BATCH_SIZE, WORLD_MODEL.hidden_size, device=DEVICE)
c = torch.zeros(WORLD_MODEL.num_layers, BATCH_SIZE, WORLD_MODEL.hidden_size, device=DEVICE)
hidden_state_real = (h, c)
# Defined up front so imagined rollouts never hit a NameError before the first
# world-model update has produced a real hidden state.
last_real_hidden_state = hidden_state_real

# Initialize loss functions
world_model_loss_fn = torch.nn.MSELoss()

# Fill Replay Buffer and Pretrain World Model

In [None]:
# Collect CAPACITY_BUFFER transitions with a uniformly random policy, then pretrain
# the world model on them before the main training loop.
# TODO: the plan was to make the random policy less erratic ("make the randomness of up
# and down lesser") so pretraining sees longer episodes and the model learns the
# state/reward dependencies better. Actions are currently still sampled uniformly.
ONE_HOT_ACTIONS = np.eye(ACTION_SIZE, dtype=np.int64)  # row i is the one-hot encoding of action i

t = 0
# `<` not `<=`: stop once `capacity` transitions were collected (the current episode is
# still finished; the deque drops the oldest entries beyond capacity).
while t < REPLAY_BUFFER.capacity:
    state, _info = ENV.reset()
    done = False

    while not done:
        action = random.randrange(ACTION_SIZE)

        # Step in real environment
        next_state, reward, done, _info = ENV.step(action=action)

        REPLAY_BUFFER.add(
            state,                    # np.ndarray
            ONE_HOT_ACTIONS[action],  # np.ndarray, int64 one-hot
            reward,                   # np.float64
            next_state,               # np.ndarray
            done,                     # bool
        )
        state = next_state

        t += 1
        if t % 10_000 == 0:
            print(f"Length REPLAY_BUFFER {t}/{REPLAY_BUFFER.capacity}")


#----Train world model----#
# NOTE(review): REPLAY_BUFFER.sample draws windows from the most recent part of the buffer,
# so an "epoch" does not cover the whole collected dataset — confirm this is intended.
hidden_state_real = None
for epoch in range(EPOCHS):
    num_batches = len(REPLAY_BUFFER) // BATCH_SIZE
    epoch_loss = 0.0
    trained_batches = 0

    for _ in range(num_batches):
        batch = REPLAY_BUFFER.sample(BATCH_SIZE)
        if batch is None:  # buffer shorter than one sequence
            continue

        states, actions, rewards, next_states, dones = batch

        # Each batch holds randomly placed windows that do not continue the previous
        # batch, so the LSTM starts from a zero state instead of unrelated carried context.
        pred_states, pred_rewards, pred_dones, hidden_state_real = WORLD_MODEL(states, actions)
        # Kept (detached) because later cells read hidden_state_real.
        hidden_state_real = tuple(part.detach() for part in hidden_state_real)

        total_loss = (
            world_model_loss_fn(pred_states, next_states)
            + world_model_loss_fn(pred_rewards, rewards)
            + world_model_loss_fn(pred_dones, dones.float())
        )

        optimizer_world_model.zero_grad()
        total_loss.backward()
        optimizer_world_model.step()

        epoch_loss += total_loss.item()
        trained_batches += 1

    # Average over batches actually trained so skipped batches do not dilute the mean.
    avg_epoch_loss = epoch_loss / trained_batches if trained_batches > 0 else 0
    print(f"EPOCH: {epoch}; AVERAGE WORLD_MODEL LOSS: {avg_epoch_loss:.4f}")


# Training Loop

In [None]:
# Training loop:
#   1. act epsilon-greedily in the real environment;
#   2. periodically fit the world model on real sequences;
#   3. generate imagined transitions with the world model;
#   4. update the Q-network on a 50/50 mix of real and imagined experience.
# NOTE: this cell mutates EPSILON and both replay buffers. Re-running it without the
# parameter cell continues from the decayed epsilon and the already-filled buffers.
# TODO: Test code on cartpole environment

DONE_THRESHOLD = 0.9  # an imagined episode ends once predicted done probability exceeds this
ONE_HOT_ACTIONS = np.eye(ACTION_SIZE, dtype=np.int64)  # row i is the one-hot encoding of action i


def _select_action(state, epsilon: float) -> int:
    """Return an epsilon-greedy action for the real environment."""
    if random.random() < epsilon:
        return random.randrange(ACTION_SIZE)  # exploration
    with torch.no_grad():  # action selection needs no autograd graph
        q_values = Q_NETWORK(torch.as_tensor(state, dtype=torch.float32, device=DEVICE))
    return int(torch.argmax(q_values).item())  # exploitation


def _train_world_model_step():
    """Fit WORLD_MODEL on one batch of real sequences from REPLAY_BUFFER.

    Returns ``(loss, hidden_state)``, where ``hidden_state`` is the detached LSTM state at
    the end of the batch. Returns None if the buffer could not produce a batch.
    """
    batch = REPLAY_BUFFER.sample(BATCH_SIZE)
    if batch is None:
        return None
    states, actions, rewards, next_states, dones = batch

    # Sampled windows do not continue any earlier batch, so start from a zero LSTM state.
    pred_states, pred_rewards, pred_dones, hidden_state = WORLD_MODEL(states, actions)

    total_loss = (
        world_model_loss_fn(pred_states, next_states)
        + world_model_loss_fn(pred_rewards, rewards)
        + world_model_loss_fn(pred_dones, dones.float())
    )
    optimizer_world_model.zero_grad()
    total_loss.backward()
    optimizer_world_model.step()
    return total_loss.item(), tuple(part.detach() for part in hidden_state)


def _imagine_rollouts(seed_hidden_state):
    """Generate SIMULATED_UPDATES imagined episodes with WORLD_MODEL.

    Each rollout:
    - starts from a fresh ``ENV.reset()`` observation;
    - follows Q_NETWORK's greedy policy for at most IMAGINARY_HORIZON steps;
    - appends every predicted transition to IMAGINARY_REPLAY_BUFFER.

    ``seed_hidden_state`` is an ``(h, c)`` pair shaped ``(num_layers, batch, hidden_size)``;
    its first batch entry seeds the LSTM. None starts from zeros.
    """
    # Rollouts only produce data. Without no_grad, the un-detached hidden state would
    # grow an autograd graph across the whole horizon.
    with torch.no_grad():
        for _ in range(SIMULATED_UPDATES):
            # (1 batch, 1 step, obs) for the batch-first LSTM
            state_sim = torch.as_tensor(ENV.reset()[0], dtype=torch.float32, device=DEVICE)[None, None]

            # NOTE(review): a rollout starting at a reset observation is seeded with the LSTM
            # state from the end of an arbitrary real window — confirm this is intended.
            if seed_hidden_state is None:
                hidden_sim = None
            else:
                hidden_sim = tuple(part[:, 0:1, :].detach() for part in seed_hidden_state)

            for _ in range(IMAGINARY_HORIZON):
                action_sim = int(torch.argmax(Q_NETWORK(state_sim), dim=-1).item())
                action_one_hot_sim = torch.nn.functional.one_hot(
                    torch.tensor(action_sim, dtype=torch.long, device=DEVICE), num_classes=ACTION_SIZE
                )[None, None]

                next_state_sim, reward_sim, done_prob_sim, hidden_sim = WORLD_MODEL(
                    state_sim, action_one_hot_sim, hidden_sim
                )

                # Store a real boolean. The buffer casts any non-zero value to True, so
                # storing the raw probability would mark every imagined step terminal.
                done_sim = done_prob_sim.item() > DONE_THRESHOLD
                IMAGINARY_REPLAY_BUFFER.add(
                    state_sim[0, 0].cpu().numpy(),
                    action_one_hot_sim[0, 0].cpu().numpy(),
                    reward_sim.item(),
                    next_state_sim[0, 0].cpu().numpy(),
                    done_sim,
                )
                if done_sim:
                    break
                state_sim = next_state_sim


def _train_q_network_step():
    """Run one TD(0) update of Q_NETWORK on half real, half imagined sequences.

    Returns the loss value, or None if either buffer could not produce a batch.
    """
    real_sample_size = BATCH_SIZE // 2
    real = REPLAY_BUFFER.sample(real_sample_size)
    imagined = IMAGINARY_REPLAY_BUFFER.sample(BATCH_SIZE - real_sample_size)
    if real is None or imagined is None:
        return None

    states, actions, rewards, next_states, dones = [
        torch.cat(pair, dim=0) for pair in zip(real, imagined)
    ]

    # Q(s, a) for the action actually taken; actions are stored one-hot.
    q_taken = Q_NETWORK(states).gather(-1, actions.argmax(dim=-1, keepdim=True))

    # The Bellman target must be a constant. Otherwise the loss would also pull Q(s', .)
    # toward Q(s, a).
    with torch.no_grad():
        next_q = Q_NETWORK(next_states).max(dim=-1, keepdim=True).values
        target = rewards + GAMMA * next_q * (1.0 - dones.float())

    loss_q = torch.nn.functional.mse_loss(q_taken, target)
    optimizer_q_network.zero_grad()
    loss_q.backward()
    optimizer_q_network.step()
    return loss_q.item()


# Module.to() moves parameters in place, so optimizer_q_network keeps valid references.
# The Q-network must share DEVICE with the rollout tensors fed to it.
Q_NETWORK.to(DEVICE)
last_real_hidden_state = None  # set by the first world-model update

for episode in range(EPISODES):
    state, _info = ENV.reset()
    done = False

    while not done:
        action = _select_action(state, EPSILON)
        next_state, reward, done, _ = ENV.step(action)
        REPLAY_BUFFER.add(state, ONE_HOT_ACTIONS[action], reward, next_state, done)
        state = next_state

    EPSILON = max(MIN_EPSILON, EPSILON * EPSILON_DECAY)

    # ---- Train WorldModel on real data ----
    if len(REPLAY_BUFFER) >= BATCH_SIZE + SEQ_LEN and episode % TRAIN_WORLD_MODEL_EVERY == 0:
        result = _train_world_model_step()
        if result is not None:
            world_model_loss, hidden_state_real = result
            last_real_hidden_state = hidden_state_real
            print(f"EPISODE: {episode}; WORLD_MODEL LOSS: {world_model_loss:.4f}")

    # ---- Use World Model as a simulator to generate synthetic experience ----
    _imagine_rollouts(last_real_hidden_state)

    # ---- Train Q-Network on real + synthetic experience ----
    if len(IMAGINARY_REPLAY_BUFFER) >= BATCH_SIZE + SEQ_LEN:
        loss_q = _train_q_network_step()
        if loss_q is not None:
            print(f"Q-Network Loss: {loss_q:.4f}")


# Save models
torch.save(WORLD_MODEL.state_dict(), "world_model.pth")
torch.save(Q_NETWORK.state_dict(), "dqn.pth")