<a href="https://colab.research.google.com/github/dude123studios/AdvancedReinforcementLearning/blob/main/Distributional_RL_(C51).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import random
from collections import deque
import math
import tensorflow as tf
from tensorflow.keras.layers import *
import gym

In [None]:
# Fetch the Atari ROM archive and register the ROMs with atari_py.
# Each line is an IPython shell escape and runs as an external command.
!wget http://www.atarimania.com/roms/Roms.rar
# Extract the archive (presumably this is what produces ROMS.zip -- TODO confirm).
!unrar x Roms.rar
!unzip ROMS.zip
# Import the ROMs found in the ROMS directory into atari_py.
!python -m atari_py.import_roms ROMS

In [3]:
# Hyperparameters (module-level; read by name in CategoricalDQN and the training loop)
v_min = 0  # lower bound of the value support (smallest atom)
v_max = 1000  # upper bound of the value support (largest atom)
atoms = 51  # size of the network's softmax output (one probability per atom)
gamma = 0.9  # discount factor applied to the support in the Bellman update
batch_size = 64  # the training loop starts training once the buffer holds this many transitions
update_target_net = 50  # copy main -> target weights every this many train() calls
epsilon = 0.5  # probability of picking a uniformly random action in select_action

# Replay memory: a bounded FIFO, the oldest transitions are dropped once full.
buffer_length=20000
replay_buffer = deque(maxlen=buffer_length)

In [4]:
# Gym environment the agent is built around and trained on
# (presumably needs the Atari ROMs imported in the cell above -- TODO confirm).
env = gym.make('Tennis-v0')

In [5]:
class CategoricalDQN(object):
    """Categorical DQN (C51) agent with a main and a target network.

    The network takes a state and an action (as a single float) and outputs a
    softmax distribution over ``atoms`` fixed support values ``z`` spanning
    ``[v_min, v_max]``. The Q-value of a state/action pair is the expectation
    of that distribution, ``sum(p * z)``.

    Reads the module-level hyperparameters ``v_min``, ``v_max``, ``atoms``,
    ``gamma``, ``epsilon`` and ``update_target_net``.

    Attributes:
        time_step: Number of ``train`` calls made so far.
        state_shape: Shape of a single observation (no batch dimension).
        action_shape: Number of discrete actions.
        delta_z: Spacing between two neighbouring atoms.
        z: Support values as a float32 tensor of shape ``(1, atoms)``.
        main: Network that is trained and used to act.
        target: Periodically synchronised copy used to build training targets.
    """

    def __init__(self, env):
        """Build both networks for the observation/action spaces of *env*."""
        self.time_step = 0

        self.v_min = v_min
        self.v_max = v_max

        self.state_shape = env.observation_space.shape
        self.action_shape = env.action_space.n

        # Use the shared hyperparameter so the support and the network's
        # output layer can never disagree on the number of atoms.
        self.atoms = atoms

        self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1)

        # linspace pins the last atom to exactly v_max.
        support = np.linspace(self.v_min, self.v_max, self.atoms, dtype=np.float32)
        self.z = tf.convert_to_tensor(support[np.newaxis, :])

        self.main = self.build_network()
        self.target = self.build_network()
        # Start from identical weights; otherwise the first targets come from
        # an unrelated, randomly initialised network.
        self.update_target_network()

        self.optimizer = tf.keras.optimizers.Adam(1e-3)

    def build_network(self):
        """Return a Keras model mapping ``[state, action]`` to atom probabilities."""
        state = Input(self.state_shape, dtype=tf.float32, name="state")
        action = Input((1,), dtype=tf.float32, name="action")

        x = Conv2D(6, (5, 5), (2, 2), padding='same')(state)
        x = Conv2D(12, (3, 3), (2, 2), padding='same')(x)
        x = Flatten()(x)

        x = Dense(24, 'relu')(x)
        x = Dense(24, 'relu')(x)
        # ``action`` already has shape (batch, 1); concatenating it directly
        # keeps the model usable for any batch size.
        concat = Concatenate(axis=-1)([x, action])

        probs = Dense(self.atoms, activation='softmax')(concat)

        return tf.keras.Model(inputs=[state, action], outputs=probs)

    def _as_state_batch(self, state):
        """Return *state* as a float32 tensor with a leading batch dimension."""
        state = tf.cast(tf.convert_to_tensor(state), tf.float32)
        if state.shape.rank == len(self.state_shape):
            state = tf.expand_dims(state, axis=0)
        return state

    @staticmethod
    def _as_action_batch(action):
        """Return *action* (scalar, list or tensor) as a float32 ``(n, 1)`` tensor."""
        action = tf.cast(tf.convert_to_tensor(action), tf.float32)
        return tf.reshape(action, (-1, 1))

    def _action_distributions(self, model, state):
        """Return *model*'s atom probabilities for every action in one state.

        Args:
            model: ``self.main`` or ``self.target``.
            state: A single state, with or without a batch dimension of 1.

        Returns:
            Float32 tensor of shape ``(action_shape, atoms)``.
        """
        n_actions = int(self.action_shape)
        state = self._as_state_batch(state)
        # One forward pass over all actions instead of one pass per action.
        states = tf.repeat(state, repeats=n_actions, axis=0)
        actions = tf.reshape(tf.range(n_actions, dtype=tf.float32), (-1, 1))
        return model([states, actions])

    def _q_values(self, model, state):
        """Return the expected value of every action in *state* as a NumPy array."""
        probs = self._action_distributions(model, state)
        return tf.reduce_sum(probs * self.z, axis=-1).numpy()

    def target_q(self, state, action):
        """Return the target network's Q-value for one state/action pair."""
        probs = self.target([self._as_state_batch(state), self._as_action_batch(action)])
        return tf.reduce_sum(probs * self.z)

    def main_q(self, state, action):
        """Return the main network's Q-value for one state/action pair."""
        probs = self.main([self._as_state_batch(state), self._as_action_batch(action)])
        return tf.reduce_sum(probs * self.z)

    def _project_distribution(self, p, r, done=False):
        """Project the Bellman-updated distribution back onto the fixed support.

        Args:
            p: Next-state atom probabilities, shape ``(atoms,)``.
            r: Reward of the transition.
            done: If true, the next state is terminal and nothing is bootstrapped.

        Returns:
            NumPy array of shape ``(atoms,)`` holding the projected target
            distribution; it sums to the same total mass as *p*.
        """
        support = self.z.numpy()[0].astype(np.float64)
        discount = 0.0 if done else gamma

        tz = np.clip(float(r) + discount * support, self.v_min, self.v_max)
        # Clip guards against rounding pushing the index just past the last atom.
        bj = np.clip((tz - self.v_min) / self.delta_z, 0, self.atoms - 1)
        lower = np.floor(bj).astype(int)
        upper = np.ceil(bj).astype(int)

        # When bj lands exactly on an atom, lower == upper and both
        # interpolation weights are 0; give that atom the full mass instead
        # of silently dropping it.
        on_atom = lower == upper
        lower_weight = np.where(on_atom, 1.0, upper - bj)
        upper_weight = np.where(on_atom, 0.0, bj - lower)

        m = np.zeros(self.atoms)
        # add.at accumulates correctly when several atoms map to the same index.
        np.add.at(m, lower, p * lower_weight)
        np.add.at(m, upper, p * upper_weight)
        return m

    def train(self, s, r, action, s_, done=False):
        """Run one gradient step on a single transition.

        Args:
            s: State the action was taken in (no batch dimension required).
            r: Reward received.
            action: Action taken, as a scalar or a one-element sequence.
            s_: Resulting next state.
            done: Whether ``s_`` is terminal. Defaults to ``False``, which
                bootstraps from ``s_`` as before.
        """
        self.time_step += 1

        s = self._as_state_batch(s)
        s_ = self._as_state_batch(s_)
        action = self._as_action_batch(action)

        # The greedy action must be chosen in the NEXT state, and its
        # distribution is the one that gets shifted by the Bellman update.
        next_probs = self._action_distributions(self.target, s_)
        next_q = tf.reduce_sum(next_probs * self.z, axis=-1).numpy()
        a_ = int(np.argmax(next_q))
        p = next_probs[a_].numpy()

        m = self._project_distribution(p, r, done)

        self.update_network(s, action, tf.convert_to_tensor(m))

        if self.time_step % update_target_net == 0:
            self.update_target_network()

    @tf.function
    def update_network(self, state, action, m):
        """Apply one cross-entropy gradient step towards target distribution *m*.

        Args:
            state: Batched state tensor.
            action: Batched action tensor of shape ``(batch, 1)``.
            m: Target distribution of shape ``(atoms,)``.

        Returns:
            The scalar cross-entropy loss.
        """
        m = tf.expand_dims(m, axis=0)
        m = tf.cast(m, tf.float32)
        state = tf.cast(state, tf.float32)
        action = tf.cast(action, tf.float32)

        with tf.GradientTape() as tape:

            probs = self.main([state, action])

            # Cross-entropy is the NEGATIVE of sum(m * log p); the clip keeps
            # log() finite if the softmax underflows to 0.
            log_probs = tf.math.log(tf.clip_by_value(probs, 1e-8, 1.0))
            loss = -tf.reduce_sum(m * log_probs)

        grads = tape.gradient(loss, self.main.trainable_variables)

        self.optimizer.apply_gradients(zip(grads, self.main.trainable_variables))

        return loss

    def update_target_network(self):
        """Copy the main network's weights into the target network."""
        self.target.set_weights(self.main.get_weights())

    def select_action(self, state):
        """Return an epsilon-greedy action index for a single *state*."""
        if random.random() < epsilon:
            return random.randint(0, self.action_shape - 1)
        return int(np.argmax(self._q_values(self.main, state)))

In [6]:
def sample_transitions(batch_size):
    """Return up to *batch_size* distinct transitions drawn from the replay buffer.

    Sampling is uniform and without replacement. If the buffer holds fewer
    than *batch_size* transitions, all of them are returned in random order.

    Args:
        batch_size: Maximum number of transitions to draw.

    Returns:
        A list of the stored transition objects (not copies).
    """
    # Sampling straight from the deque avoids copying the whole buffer into a
    # NumPy array on every call; that conversion also fails on recent NumPy
    # because each transition mixes arrays, scalars and lists (ragged input).
    count = min(batch_size, len(replay_buffer))
    return random.sample(replay_buffer, count)

In [None]:
# Train a fresh agent on `env`, printing the undiscounted return of each episode.
agent = CategoricalDQN(env)

num_episodes = 800

for i in range(num_episodes):

    # Per-episode bookkeeping.
    state = env.reset()
    Return = 0
    done = False

    while not done:

        action = agent.select_action(state)
        next_state, reward, done, info = env.step(action)
        Return += reward

        # Stored layout: [state, reward, [action], next_state].
        # NOTE(review): `done` is not stored, so the agent cannot tell
        # terminal transitions apart from ordinary ones.
        replay_buffer.append([state, reward, [action], next_state])

        # Start learning only once the buffer holds at least `batch_size`
        # transitions; each environment step then trains on 2 sampled ones.
        if len(replay_buffer) >= batch_size:
            trans = sample_transitions(2)
            for item in trans:
                s, r, a, s_next = item[0], item[1], item[2], item[3]
                agent.train(s, r, a, s_next)

        state = next_state

    print('Episode: {}, Return: {}'.format(i, Return))