In [19]:
import random
import gymnasium as gym
import numpy as np
from collections import deque
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten, Convolution2D, MaxPooling2D
from tensorflow.keras.optimizers import Adam


In [28]:
# preprocessing the image

def preprop(obs):
    gray = 0.2989*obs[:,:,0] + 0.5870*obs[:,:,1] + 0.1140*obs[:,:,2]
    gray = gray[35:195]
    gray = gray[::2, ::2]
    gray = np.expand_dims(gray.reshape(80, 80, 1), axis=0)
    return gray


In [36]:
class DQN:
    def __init__(self, state_size, action_size):
        
        #input size
        self.state_size = state_size
        
        self.action_size = action_size
        
        self.replay_buffer = deque(maxlen=5000)
        
        self.epsilon = 0.2
        
        self.gamma = 0.99
        
        self.learning_rate = 0.00025
        
        self.update_rate = 1000  
        
        self.main_network = self.build_network()
        
        self.target_network = self.build_network()
        
        self.target_network.set_weights(self.main_network.get_weights()) 
        
    def build_network(self):
        
        model = Sequential()
        model.add(Convolution2D(32, (8,8), strides=(4,4), activation='relu', input_shape= self.state_size ))
        model.add(Convolution2D(64,(4,4),(2,2), activation="relu"))
        model.add(Convolution2D(64,(3,3), activation="relu"))
        model.add(Flatten())
        model.add(Dense(512,activation = "relu"))
        model.add(Dense(256,activation = "relu"))
        model.add(Dense(self.action_size,activation = "linear"))
        model.compile(optimizer = "adam", loss = "mse")
        
        return model
    
    def store_transition(self,s,a,r,s_,done):
        self.replay_buffer.append((s,a,r,s_,done))
    
    def choose_action(self, state):
        if random.random() < self.epsilon: 
            return random.randint(0,5)
        else: 
            prob_action = self.main_network.predict(state)
            return np.argmax(prob_action[0])
    
    def train(self):
        minibatch =  random.sample(self.replay_buffer, 8)
        
        for s,a,r,s_,done in minibatch:
            if done:
                y_i = r
            else:
                y_i = r + self.gamma*np.amax(self.target_network.predict(s_))
        
        Q_values = self.main_network.predict(state)
            
        Q_values[0][action] = y_i
            
        self.main_network.fit(state, Q_values, epochs=1, verbose=0)
    
    def update_target_network(self):
        self.target_network.set_weights(self.main_network.get_weights())
        
        
            
        

In [25]:
episodes = 5
frames = 4

env = gym.make("ALE/Pong-v5")



In [37]:
dqnet = DQN((80,80,1), env.action_space.n)
num_transitions = 0

for episode in range(episodes):
    score = 0
    obs, info = env.reset()
    state = preprop(obs)
    done = False
    while not(done):
        num_transitions += 1
        
        action = dqnet.choose_action(state)
        obs, reward, done, trunc, info = env.step(action)
        prev_state = state
        state = preprop(obs)
        dqnet.store_transition(prev_state, action, reward, state, done)
        
        if num_transitions % dqnet.update_rate ==0:
            dqnet.update_target_network()
        
        if len(dqnet.replay_buffer) > 8:
            dqnet.train()
        score += reward
        
    print(f"episode:{episode}, return: {score}")
        
        

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 309ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 64ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 71ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 65ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 69ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 83ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 284ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 66ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 74ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 92ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 

In [39]:
print(episode)

4
