# Deep Active Inference

We are creating neural networks to scale certain models within the active inference framework.

**Transition Network**: 
The input of the network is the state plus the action and the output is the likely values for the next state (obs)
- We use this network to compute the Predition Error (Pt) which is the difference between what the predicted next state was and what is actually was:
```python
transition_network(state_t0 + action_t0) - state_t1
```

**Policy Network**: 

The input to this network is the current state (observation in MDP's) of the environment and outputs a distribution over actions to be taken by the agent. Generally we then run this through a softmax with precision of 1

**Value Network**: 

This network is detetmining the EFE given a state, similar to the q-value network in reinforcement learning. 

- Input: State (Obs)
- Output: EFE of each action

The training of this network works based on the ReplayMemory. 

We gather a set of say 64 transitions from memory. We then compute the transition prediction error. This is what is used to train the transition network. We are then aiming to create some value approximate for the true EFE of taking an action (policy) in a certain state. 

We can compute what our policy network is outputing at time t2 given an observation, along with the target EFE from the value target network. We then weight the EFE in each state given by the respective probability of performing these actions, we then add the transition mse and with the reward obtained at t1. We then obtain the bootstrap EFE of performing the action we took at t0. Which we then calculate and use the MSE between the bootstrap EFE and the calculated one to train the value network. 

$$
\begin{aligned}
-F= & -E_{Q\left(s_t\right)}\left[\log p\left(o_t \mid s_t\right)\right]-K L\left[Q\left(s_t\right) \| p\left(s_t \mid s_{t-1}, a_{t-1}\right)\right] \\
& -E_{Q\left(s_t\right)} K L\left[Q\left(a_t \mid s_t\right) \| p\left(a_t \mid s_t\right)\right]
\end{aligned}
$$

Note:

- In this case the observation term, $E_{Q\left(s_t\right)}\left[\log p\left(o_t \mid s_t\right)\right]$ , is not neccesary since this is an mdp so the state is the observation and this infernce does not need to occur

- The different between the inferred states and the acutal states. Can I make this the difference between these 2.



# Neural Network

In [52]:
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim

class NeuralNetwork(nn.Module):

    def __init__(self, input_dimension: int, output_dimention: int, pass_through_softmax: bool = False, lr: float = 0.01):
        
        super(NeuralNetwork, self).__init__()

        self.layer_1 = nn.Linear(input_dimension, 64)
        self.layer_2 = nn.Linear(64, output_dimention)
        
        self.pass_through_softmax = pass_through_softmax
        self.optimizer = optim.Adam(self.parameters(), lr)

    def forward(self, x):

        x = F.relu(self.layer_1(x))

        if self.pass_through_softmax:
            return F.softmax(self.layer_2(x), dim=-1)
        else:
            return self.layer_2(x)

# Agent

In [75]:
import gymnasium
from stable_baselines3.common.buffers import ReplayBuffer
import torch
from torch.distributions import Categorical
import numpy as np

class Agent():

    def __init__(self, observation_space: gymnasium.Space, action_space: gymnasium.Space, memory_size: int = 1000, gamma = 1) -> None:

        self.memory = ReplayBuffer(buffer_size=memory_size, observation_space=observation_space, action_space=action_space, handle_timeout_termination=False)

        state_dimention = observation_space.shape[0]
        num_actions     = action_space.n

        self.transition_network = NeuralNetwork(input_dimension=state_dimention + 1, output_dimention=state_dimention)
        self.policy_network     = NeuralNetwork(input_dimension=state_dimention, output_dimention=num_actions, pass_through_softmax=True)

        self.value_network      = NeuralNetwork(input_dimension=state_dimention, output_dimention=num_actions)
        self.target_network     = NeuralNetwork(input_dimension=state_dimention, output_dimention=num_actions)
        self.target_network.load_state_dict(self.value_network.state_dict())

        self.gamma = gamma

    def compute_bootstrap_loss(self, batch_transition_loss, observation_batch, action_batch, next_observation_batch, reward_batch, dones_batch):

        # observation_batch       = torch.tensor(observation_batch)
        # next_observation_batch  = torch.tensor(next_observation_batch)

        with torch.no_grad():

            target_efe_batch            = self.target_network(next_observation_batch)
            action_distribution_batch   = self.policy_network(next_observation_batch)

            # print(target_efe_batch)
            # print(action_distribution_batch)

            weighted_target_efe = ((1-dones_batch) * action_distribution_batch * target_efe_batch).sum(-1).unsqueeze(1)

            # print(weighted_target_efe)

            bootstrap_efe   = - reward_batch + batch_transition_loss + weighted_target_efe

            # print(bootstrap_efe)

        efe = self.value_network(observation_batch).gather(1, action_batch)

        # print(efe)
        # print(bootstrap_efe)

        bootstrap_loss = F.mse_loss(bootstrap_efe, efe)

        return bootstrap_loss

    def compute_vfe(self, transition_loss, observation_batch) -> torch.Tensor:

        # observation_batch = torch.tensor(observation_batch)

        with torch.no_grad():
            efes_batch = self.value_network(observation_batch)

        action_distribution_batch = self.policy_network(observation_batch)

        boltzmann_efe_dist = torch.softmax(-self.gamma * efes_batch, dim=1)

        energy_batch = -(action_distribution_batch * torch.log(action_distribution_batch)).sum(1).unsqueeze(1)

        entropy_batch = -(action_distribution_batch * torch.log(boltzmann_efe_dist)).sum(1).unsqueeze(1)

        vfe = transition_loss + (energy_batch - entropy_batch)

        mean_vfe = torch.mean(vfe)

        return mean_vfe
    
    def compute_transition_loss(self, observation_batch, action_batch, next_observation_batch):
        
        # observation_batch           = torch.tensor(observation_batch)
        # action_batch                = torch.tensor(action_batch)

        batch_state_action_pairs    = torch.cat((observation_batch, action_batch), dim=1)
        transition_batch            = self.transition_network(batch_state_action_pairs)
        # next_observation_batch      = torch.tensor(next_obs)

        batch_transition_loss = torch.mean(F.mse_loss(transition_batch, next_observation_batch, reduction='none'), dim=1).unsqueeze(1)

        # print(f'transition_loss: {batch_transition_loss}')

        return batch_transition_loss
    
    def sample_action(self, obs):

        with torch.no_grad():

            action_distribution = self.policy_network(torch.tensor(obs))

            action_sample = Categorical(action_distribution).sample()

        return action_sample.numpy()
    
    def learn(self):

        batch           = self.memory.sample(4)

        transition_loss = self.compute_transition_loss(batch.observations, batch.actions, batch.next_observations)
        bootstrap_loss  = self.compute_bootstrap_loss(transition_loss, batch.observations, batch.actions, batch.next_observations, batch.rewards , batch.dones)
        vfe             = self.compute_vfe(transition_loss, batch.observations)

        self.policy_network.optimizer.zero_grad()
        self.transition_network.optimizer.zero_grad()
        self.value_network.optimizer.zero_grad()
        
        vfe.backward()
        bootstrap_loss.backward()

        self.policy_network.optimizer.step()
        self.transition_network.optimizer.step()
        self.value_network.optimizer.step()

    def train(self, env: gymnasium.Env, episodes: int = 100, max_episode_length: int = 1000000) -> None:

        episode_rewards = []

        for e in range(episodes):

            obs, _ = env.reset()

            total_episode_rewards = 0 
            total_length    = 0
            episode_length  = 0
            done            = False

            while not done and episode_length < max_episode_length:

                episode_length += 1
                total_length   += 1

                action = self.sample_action(obs)

                next_obs, reward, terminated, truncated, infos = env.step(action)

                self.memory.add(obs, next_obs, action, reward, done, infos)

                obs     = next_obs
                done    = terminated or truncated
                
                total_episode_rewards   += reward

                if total_length % 5 == 0:
                    self.learn()
                
                if total_length % 50 == 0:
                    self.target_network.load_state_dict(self.value_network.state_dict())

        
            episode_rewards.append(total_episode_rewards)
            if e % 100 == 0 and e != 0:
                print(f'Episode {e} done - Average Rewards Last 100: {np.average(episode_rewards[-100])}')

env = gymnasium.make("CartPole-v1")
agent = Agent(observation_space=env.observation_space, action_space=env.action_space)
agent.train(env=env, episodes=10000)

Episode 100 done - Average Rewards Last 100: 13.0
Episode 200 done - Average Rewards Last 100: 9.0
Episode 300 done - Average Rewards Last 100: 9.0
Episode 400 done - Average Rewards Last 100: 9.0
Episode 500 done - Average Rewards Last 100: 9.0
Episode 600 done - Average Rewards Last 100: 9.0
Episode 700 done - Average Rewards Last 100: 9.0
Episode 800 done - Average Rewards Last 100: 10.0
Episode 900 done - Average Rewards Last 100: 9.0
Episode 1000 done - Average Rewards Last 100: 9.0
Episode 1100 done - Average Rewards Last 100: 9.0
Episode 1200 done - Average Rewards Last 100: 8.0
Episode 1300 done - Average Rewards Last 100: 11.0
Episode 1400 done - Average Rewards Last 100: 10.0
Episode 1500 done - Average Rewards Last 100: 10.0
Episode 1600 done - Average Rewards Last 100: 8.0
Episode 1700 done - Average Rewards Last 100: 9.0
Episode 1800 done - Average Rewards Last 100: 10.0
Episode 1900 done - Average Rewards Last 100: 10.0
Episode 2000 done - Average Rewards Last 100: 9.0
Ep