In [None]:
import gym
import numpy as np

import tensorflow as tf

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "1"

class BasalGangliaMDP(gym.Env):
    """Small MDP whose states are named after basal-ganglia structures.

    The environment starts in 'Cortex' and offers a single action,
    'Change State'.  Each step samples the successor state from
    ``transition_probs`` and pays the reward stored in ``rewards`` for the
    ``(state, action, next_state)`` triple.  An episode is reported as done
    once 'Thalamus' is reached.

    Observations are the state *names* (strings), not indices.
    """

    def __init__(self):
        super().__init__()

        # Define states
        self.states = ['Cortex', 'Striatum', 'GPe', 'STN', 'GPi', 'Thalamus']

        # Define actions
        self.actions = ['Change State']
        self.action_space = gym.spaces.Discrete(len(self.actions))

        # Define transition probabilities (consult experts for accurate values)
        # Layout: state -> action -> {next_state: probability}
        self.transition_probs = {
            'Cortex': {
                'Change State': {'Striatum': 1.0}
            },
            'Striatum': {
                'Change State': {'GPe': 0.5, 'GPi': 0.5}
            },
            'GPe': {
                'Change State': {'STN': 1.0}
            },
            'STN': {
                'Change State': {'GPi': 1.0}
            },
            'GPi': {
                'Change State': {'Thalamus': 1.0}
            },
            'Thalamus': {
                'Change State': {'Thalamus': 1.0}  # Terminal state (self-loop)
            }
        }

        # Define rewards (consult experts for appropriate values)
        # Keyed by (state, action, next_state); missing triples pay 0.
        self.rewards = {
            ('Cortex', 'Change State', 'Striatum'): 0,  # Neutral transition
            ('Striatum', 'Change State', 'GPe'): 1,  # Entering the longer route via GPe/STN
            ('Striatum', 'Change State', 'GPi'): 1,  # Going straight to GPi
            ('GPe', 'Change State', 'STN'): 0.5,  # Smaller reward along the longer route
            ('STN', 'Change State', 'GPi'): 0.5,  # Smaller reward for rejoining at GPi
            ('GPi', 'Change State', 'Thalamus'): 2,  # Largest reward: reaching Thalamus
            ('Thalamus', 'Change State', 'Thalamus'): 0.1,  # Small reward for staying in Thalamus
        }

        self.state = 'Cortex'  # Initial state

    def step(self, action):
        """Advance the environment by one transition.

        Args:
            action: Action name (a key of ``transition_probs[state]``), or an
                integer index into ``self.actions``.

        Returns:
            Tuple ``(next_state, reward, done, info)`` where ``next_state`` is
            the new state's name, ``reward`` is looked up for the transition
            that was just taken, ``done`` is True when the new state is
            'Thalamus', and ``info`` is an empty dict.

        Raises:
            KeyError: If the action name is not defined for the current state.
            IndexError: If an integer action is out of range.
        """
        # Allow callers to pass a sample from ``action_space`` directly.
        if isinstance(action, (int, np.integer)):
            action = self.actions[action]

        current_state = self.state
        next_state_probs = self.transition_probs[current_state][action]

        # np.random.choice returns np.str_; convert so callers get a plain str.
        next_state = str(np.random.choice(list(next_state_probs.keys()),
                                          p=list(next_state_probs.values())))

        # The reward must be keyed on the state we are *leaving*; looking it
        # up after overwriting self.state silently pays the wrong reward.
        reward = self.rewards.get((current_state, action, next_state), 0)

        self.state = next_state

        done = next_state == 'Thalamus'  # Terminal state
        info = {}

        return next_state, reward, done, info

    def reset(self):
        """Put the environment back in 'Cortex' and return that state name."""
        self.state = 'Cortex'
        return self.state

    def render(self):
        """Print the name of the current state."""
        print(self.state)




class DQNAgent:
    """Minimal DQN agent over a discrete, one-hot encoded state space.

    The agent keeps an online Q-network, a target Q-network that is synced on
    demand via ``update_target_network``, and a list-based replay memory of
    ``(state_index, action, reward, next_state_index, done)`` tuples that the
    caller appends to directly.

    Args:
        state_space_size: Number of discrete states (width of the one-hot input).
        action_space_size: Number of discrete actions (width of the Q output).
        learning_rate: Adam learning rate used to train the Q-networks.
        discount_factor: Discount applied to the bootstrapped next-state value.
        exploration_prob: Probability of picking a uniformly random action.
    """

    def __init__(self, state_space_size, action_space_size, learning_rate=0.001, discount_factor=0.9, exploration_prob=0.1):
        self.state_space_size = state_space_size
        self.action_space_size = action_space_size
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.exploration_prob = exploration_prob

        # Build Q-network
        self.q_network = self.build_q_network()

        # Target Q-network (for stability)
        self.target_q_network = self.build_q_network()
        self.target_q_network.set_weights(self.q_network.get_weights())

        # Kept for backward compatibility; each network is compiled with its
        # own Adam instance in build_q_network().
        self.optimizer = tf.keras.optimizers.Adam(learning_rate)

        # Experience replay buffer
        self.memory = []

    def build_q_network(self):
        """Build and compile a two-hidden-layer MLP mapping one-hot states to Q-values."""
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(64, activation='relu', input_shape=(self.state_space_size,), dtype=tf.float32),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(self.action_space_size)
        ])
        # Compile with an explicit optimizer so the configured learning rate
        # is actually used (the string 'adam' would ignore it).
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate),
            loss='mse',
        )
        return model

    def select_action(self, state):
        """Pick an action for ``state`` (an integer state index), epsilon-greedily.

        Returns:
            The chosen action index as an int.
        """
        if np.random.rand() < self.exploration_prob:
            return int(np.random.choice(self.action_space_size))

        # One-hot encode the state as a (1, state_space_size) batch.
        state_one_hot = np.zeros((1, self.state_space_size), dtype=np.float32)
        state_one_hot[0, state] = 1.0

        # verbose=0 keeps Keras from printing a progress bar on every step.
        q_values = self.q_network.predict(state_one_hot, verbose=0)

        # Select action with the highest Q-value
        return int(np.argmax(q_values[0]))

    def update_q_network(self, batch_size, states=None):
        """Run one training step on a random minibatch from replay memory.

        Does nothing until the memory holds at least ``batch_size`` transitions.

        Args:
            batch_size: Number of transitions to sample (without replacement).
            states: Ignored.  Accepted only for backward compatibility with
                callers that pass it; the minibatch always comes from
                ``self.memory``.
        """
        if batch_size <= 0 or len(self.memory) < batch_size:
            return

        # Sample a batch from memory
        sample_indices = np.random.choice(len(self.memory), batch_size, replace=False)
        batch = [self.memory[i] for i in sample_indices]

        # Extract components from the batch
        state_indices, actions, rewards, next_state_indices, dones = zip(*batch)

        # Convert state indices to one-hot encoding
        one_hot = np.eye(self.state_space_size, dtype=np.float32)
        state_batch = one_hot[np.asarray(state_indices, dtype=int)]
        next_state_batch = one_hot[np.asarray(next_state_indices, dtype=int)]

        # Copy so the in-place target assignment below is always legal.
        q_values = np.array(self.q_network.predict(state_batch, verbose=0))
        # Bootstrap from the *next* states, evaluated by the target network.
        next_q_values = self.target_q_network.predict(next_state_batch, verbose=0)

        # Bellman target: r + gamma * max_a' Q_target(s', a'), with the
        # bootstrap term zeroed for terminal transitions.
        rewards = np.asarray(rewards, dtype=np.float32)
        not_done = 1.0 - np.asarray(dones, dtype=np.float32)
        targets = rewards + self.discount_factor * next_q_values.max(axis=1) * not_done

        # Only the taken action's Q-value is moved toward its target; the
        # other outputs keep their current predictions (zero error).
        q_values[np.arange(batch_size), np.asarray(actions, dtype=int)] = targets

        # Train the Q-network
        self.q_network.fit(state_batch, q_values, verbose=0)

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_q_network.set_weights(self.q_network.get_weights())

# Instantiate the environment and the DQN agent
env = BasalGangliaMDP()
agent = DQNAgent(state_space_size=len(env.states), action_space_size=env.action_space.n)

# Training the DQN agent
num_episodes = 100
batch_size = 32

for episode in range(num_episodes):
    # Rolling record of (at most) the last 10 states visited this episode.
    state_episode = []
    total_reward = 0

    state = env.reset()
    state_episode.append(state)

    print(f"\nEpisode {episode + 1}:")

    while True:
        # Convert state to its corresponding index
        state_index = env.states.index(state)

        action = agent.select_action(state_index)
        next_state, reward, done, _ = env.step(env.actions[action])

        print(f"Transition: {env.states[state_index]} -> {next_state}, Reward: {reward}")

        # Convert next_state to its corresponding index
        next_state_index = env.states.index(next_state)

        # Store experience in replay buffer
        agent.memory.append((state_index, action, reward, next_state_index, done))

        # Update Q-network from replay memory.  Do not pass `states`: the
        # agent skips training entirely whenever that argument is supplied,
        # and it already waits until the memory holds a full batch.
        agent.update_q_network(batch_size)

        total_reward += reward
        state = next_state

        if done:
            break

        # Record the new state and keep only the 10 most recent entries
        state_episode.append(state)
        del state_episode[:-10]

    # Update target network periodically
    if episode % 10 == 0:
        agent.update_target_network()

    print(f"Episode: {episode + 1}, Total Reward: {total_reward}")

# Testing the trained DQN agent
state = env.reset()
done = False

# Evaluate greedily: disable exploration for the final rollout and restore
# the configured value afterwards.
training_exploration_prob = agent.exploration_prob
agent.exploration_prob = 0.0

print("\nFinally chosen pathway: ")
while not done:
    # Convert state to its corresponding index
    state_index = env.states.index(state)

    action = agent.select_action(state_index)
    next_state, reward, done, _ = env.step(env.actions[action])

    print(f"Transition: {env.states[state_index]} -> {next_state}, Reward: {reward}")
    state = next_state

agent.exploration_prob = training_exploration_prob




Episode 1:
Transition: Cortex -> Striatum, Reward: 0
Transition: Striatum -> GPi, Reward: 0
Transition: GPi -> Thalamus, Reward: 0.1
Episode: 1, Total Reward: 0.1

Episode 2:
Transition: Cortex -> Striatum, Reward: 0
Transition: Striatum -> GPi, Reward: 0
Transition: GPi -> Thalamus, Reward: 0.1
Episode: 2, Total Reward: 0.1

Episode 3:
Transition: Cortex -> Striatum, Reward: 0
Transition: Striatum -> GPi, Reward: 0
Transition: GPi -> Thalamus, Reward: 0.1
Episode: 3, Total Reward: 0.1

Episode 4:
Transition: Cortex -> Striatum, Reward: 0
Transition: Striatum -> GPi, Reward: 0
Transition: GPi -> Thalamus, Reward: 0.1
Episode: 4, Total Reward: 0.1

Episode 5:
Transition: Cortex -> Striatum, Reward: 0
Transition: Striatum -> GPi, Reward: 0
Transition: GPi -> Thalamus, Reward: 0.1
Episode: 5, Total Reward: 0.1

Episode 6:
Transition: Cortex -> Striatum, Reward: 0
Transition: Striatum -> GPi, Reward: 0
Transition: GPi -> Thalamus, Reward: 0.1
Episode: 6, Total Reward: 0.1

Episode 7:
Tran