<a href="https://colab.research.google.com/github/jimmy93029/Selected_topic_for_RL_sophomore_fall/blob/master/lab2/TA_hw2_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### environment

In [None]:
%pip install gym==0.26.1
%pip install "gym[atari, accept-rom-license]"

In [None]:
import torch
import torch.nn as nn
import numpy as np
import os
import time
from collections import deque
from torch.utils.tensorboard import SummaryWriter
from abc import ABC, abstractmethod
from collections import deque
import random
from torch.utils.tensorboard import SummaryWriter
import gym

In [None]:
import ale_py
from gym.wrappers import AtariPreprocessing
# print(gym.envs.registry.keys())

### Atari Net network

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

class AtariNet(nn.Module):
    """Actor-critic network: a shared convolutional trunk feeding a policy head and a value head.

    The trunk takes 4 input channels. The heads expect a flattened feature
    vector of size 7*7*64, which is what the three convolutions produce for
    84x84 spatial inputs.

    Args:
        num_classes: number of discrete actions (size of the policy logits).
        init_weights: when True, apply orthogonal initialization to every
            Conv2d/Linear layer.
    """

    def __init__(self, num_classes=4, init_weights=True):
        super().__init__()

        self.cnn = nn.Sequential(nn.Conv2d(4, 32, kernel_size=8, stride=4),
                                        nn.ReLU(True),
                                        nn.Conv2d(32, 64, kernel_size=4, stride=2),
                                        nn.ReLU(True),
                                        nn.Conv2d(64, 64, kernel_size=3, stride=1),
                                        nn.ReLU(True)
                                        )
        self.action_logits = nn.Sequential(nn.Linear(7*7*64, 512),
                                        nn.ReLU(True),
                                        nn.Linear(512, num_classes)
                                        )
        self.value = nn.Sequential(nn.Linear(7*7*64, 512),
                                        nn.ReLU(True),
                                        nn.Linear(512, 1)
                                        )

        if init_weights:
            self._initialize_weights()

    def forward(self, x, eval=False, a=None):
        """Run the policy and value heads on a batch of observations.

        Args:
            x: batch of observations; it is cast to float and divided by 255
                before entering the trunk.
            eval: when True the returned action is the argmax of the logits,
                otherwise it is sampled from the categorical distribution.
            a: optional actions whose log-probability should be returned.
                ``None`` (the default) or an empty sequence means "use the
                action chosen in this call".

        Returns:
            Tuple ``(action, action_log_prob, value, dist_entropy)`` where
            ``dist_entropy`` is the mean entropy over the batch and
            ``value``/``action_log_prob`` have singleton dimensions squeezed out.
        """
        x = x.float() / 255.
        x = self.cnn(x)
        x = torch.flatten(x, start_dim=1)
        value = self.value(x)
        # NOTE: squeeze() removes every size-1 dim, so a batch of one yields a 0-d tensor.
        value = torch.squeeze(value)

        logits = self.action_logits(x)
        dist = Categorical(logits=logits)

        if eval:
            action = torch.argmax(logits, dim=1)
        else:
            action = dist.sample()

        # `None` replaces the former mutable `[]` default; an explicitly passed
        # empty sequence is still accepted for backward compatibility.
        if a is None or len(a) == 0:
            action_log_prob = dist.log_prob(action)
        else:
            action_log_prob = dist.log_prob(a)

        dist_entropy = dist.entropy().mean()
        action_log_prob = torch.squeeze(action_log_prob)
        return action, action_log_prob, value, dist_entropy

    def _initialize_weights(self):
        """Orthogonal initialization (gain sqrt(2)) with zero biases for all Conv2d/Linear layers."""
        # orthogonal initialization for PPO
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(m.weight, np.sqrt(2))
                nn.init.constant_(m.bias, 0.0)


class AtariNetDQN(nn.Module):
    """Convolutional Q-network: three conv layers followed by a two-layer MLP.

    The network takes 4 input channels and emits one Q-value per action
    (``num_classes`` outputs). The MLP input size of 7*7*64 corresponds to
    84x84 spatial inputs.
    """

    def __init__(self, num_classes=4, init_weights=True):
        super().__init__()

        # Feature extractor.
        conv_layers = [
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
        ]
        self.cnn = nn.Sequential(*conv_layers)

        # Q-value head.
        head_layers = [
            nn.Linear(7 * 7 * 64, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return the Q-values for a batch of observations scaled from [0, 255] to [0, 1]."""
        scaled = x.float() / 255.
        features = torch.flatten(self.cnn(scaled), start_dim=1)
        return self.classifier(features)

    def _initialize_weights(self):
        """Kaiming-normal weights and zero biases; BatchNorm2d layers get weight 1 and bias 0."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1.0)
                nn.init.constant_(layer.bias, 0.0)
                continue
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            # Only convolutions are checked for a missing bias; Linear biases are always reset.
            if isinstance(layer, nn.Linear) or layer.bias is not None:
                nn.init.constant_(layer.bias, 0.0)


### ReplayBuffer

In [None]:
import numpy as np
import torch
from collections import deque
import random

class ReplayMemory(object):
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity):
        # A deque with maxlen discards its oldest entry once `capacity` is exceeded.
        self.buffer = deque(maxlen=capacity)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

    def append(self, *transition):
        """Store one transition; every field is converted to a tuple first."""
        frozen = tuple(tuple(field) for field in transition)
        self.buffer.append(frozen)

    def sample(self, batch_size, device):
        """Draw `batch_size` transitions without replacement.

        Returns a generator that yields one float tensor per transition field
        (fields are stacked across the sampled transitions), placed on `device`.
        """
        batch = random.sample(self.buffer, batch_size)
        columns = zip(*batch)
        return (
            torch.tensor(np.asarray(column), dtype=torch.float, device=device)
            for column in columns
        )

### DQNBaseAgent

In [None]:
import torch
import torch.nn as nn
import numpy as np
import os
import time
from collections import deque
from torch.utils.tensorboard import SummaryWriter
from replay_buffer.gae_replay_buffer import GaeSampleMemory
from replay_buffer.replay_buffer import ReplayMemory
from abc import ABC, abstractmethod

class DQNBaseAgent(ABC):
	"""Shared training and evaluation loop for DQN-style agents on vectorized environments.

	This class never assigns ``self.env``, ``self.test_env``, ``self.behavior_net``,
	``self.target_net`` or ``self.optim`` but uses all of them, so subclasses must
	create them. Subclasses must also implement :meth:`decide_agent_actions`.
	"""

	def __init__(self, config):
		"""Read hyper-parameters from ``config`` and build the replay buffer and TensorBoard writer.

		Keys read from ``config``: gpu, training_steps, batch_size, eps_min,
		eps_decay, eval_epsilon, warmup_steps, use_double, num_envs, gamma,
		update_freq, update_target_freq, replay_buffer_capacity, logdir.
		"""
		self.gpu = config["gpu"]
		# Fall back to CPU when CUDA is requested but unavailable.
		self.device = torch.device("cuda" if self.gpu and torch.cuda.is_available() else "cpu")
		self.total_time_step = 0
		self.training_steps = int(config["training_steps"])
		self.batch_size = int(config["batch_size"])
		self.epsilon = 1.0
		self.eps_min = config["eps_min"]
		self.eps_decay = config["eps_decay"]
		self.eval_epsilon = config["eval_epsilon"]
		self.warmup_steps = config["warmup_steps"]
		self.use_double = config["use_double"]
		self.eval_interval = int(2**16)
		self.eval_episode = 16
		self.num_envs = config["num_envs"]
		self.gamma = config["gamma"]
		self.update_freq = config["update_freq"]
		self.update_target_freq = config["update_target_freq"]

		self.replay_buffer = ReplayMemory(int(config["replay_buffer_capacity"]))
		self.writer = SummaryWriter(config["logdir"])

	@abstractmethod
	def decide_agent_actions(self, observation, epsilon=0.0, action_space=None):
		"""Return actions for a batch of observations using epsilon-greedy selection.

		Must be overridden. The base implementation raises so that an accidental
		``super()`` call fails loudly instead of handing back an exception class.
		"""
		raise NotImplementedError

	def update(self):
		"""Run the behavior/target updates whose period divides the current step counter."""
		if self.total_time_step % self.update_freq == 0:
			self.update_behavior_network()
		if self.total_time_step % self.update_target_freq == 0:
			self.update_target_network()

	def update_behavior_network(self):
		"""Do one gradient step on the behavior network from a sampled minibatch."""
		# sample a minibatch of transitions
		state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size, self.device)
		action = action.type(torch.long)
		q_value = self.behavior_net(state).gather(1, action)
		with torch.no_grad():
			if self.use_double:
				# Double DQN: the behavior net picks the action, the target net scores it.
				action_index = self.behavior_net(next_state).max(dim=1)[1].view(-1, 1)
				q_next = self.target_net(next_state).gather(dim=1, index=action_index.long())
			else:
				q_next = self.target_net(next_state).max(1)[0].unsqueeze(1)

			# if episode terminates at next_state, then q_target = reward
			q_target = reward + self.gamma * q_next * (1 - done)

		criterion = nn.SmoothL1Loss()
		loss = criterion(q_value, q_target)

		self.writer.add_scalar('DQN/Loss', loss.item(), self.total_time_step)

		self.optim.zero_grad()
		loss.backward()
		# Element-wise clamp of every gradient to [-1, 1]. Unlike a manual loop over
		# `param.grad`, this skips parameters that received no gradient.
		nn.utils.clip_grad_value_(self.behavior_net.parameters(), 1)

		self.optim.step()

	def update_target_network(self):
		"""Copy the behavior network's weights into the target network."""
		self.target_net.load_state_dict(self.behavior_net.state_dict())

	def epsilon_decay(self):
		"""Lower epsilon by ``(1 - eps_min) / eps_decay``, never going below ``eps_min``."""
		self.epsilon -= (1 - self.eps_min) / self.eps_decay
		self.epsilon = max(self.epsilon, self.eps_min)

	def train(self):
		"""Main training loop; runs until ``total_time_step`` exceeds ``training_steps``.

		Each iteration steps all ``num_envs`` environments once, stores the
		transitions, updates the networks after warm-up, and evaluates and saves a
		checkpoint whenever the step counter is a multiple of ``eval_interval``.
		"""
		self.behavior_net.train()
		observations, infos = self.env.reset()
		episode_rewards = [0] * self.num_envs
		episode_lens = [0] * self.num_envs

		while self.total_time_step <= self.training_steps:
			if self.total_time_step < self.warmup_steps:
				# Warm-up: act fully at random and leave epsilon untouched.
				actions = self.decide_agent_actions(observations, 1.0, self.env.action_space.n)
			else:
				actions = self.decide_agent_actions(observations, self.epsilon, self.env.action_space.n)
				self.epsilon_decay()

			next_observations, rewards, terminates, truncates, infos = self.env.step(actions)

			for i in range(self.num_envs):
				# Only `terminates` is stored as the done flag, so a truncated
				# transition still bootstraps from next_state.
				self.replay_buffer.append(
						observations[i],
						[actions[i]],
						[rewards[i]],
						next_observations[i],
						[int(terminates[i])]
					)

			if self.total_time_step >= self.warmup_steps:
				self.update()

			episode_rewards = [episode_rewards[i] + rewards[i] for i in range(self.num_envs)]
			episode_lens = [episode_lens[i] + 1 for i in range(self.num_envs)]

			for i in range(self.num_envs):
				if terminates[i] or truncates[i]:
					# Only environment 0 is logged to TensorBoard; all are printed.
					if i == 0:
						self.writer.add_scalar('Train/Episode Reward', episode_rewards[0], self.total_time_step)
						self.writer.add_scalar('Train/Episode Len', episode_lens[0], self.total_time_step)
					print(f"[{self.total_time_step}/{self.training_steps}]\
						\tenv {i} \
	   					\tepisode reward: {episode_rewards[i]}\
						\tepisode len: {episode_lens[i]}\
						\tepsilon: {self.epsilon}\
						")
					episode_rewards[i] = 0
					episode_lens[i] = 0

			observations = next_observations
			# One loop iteration advances every environment by one step.
			self.total_time_step += self.num_envs

			# NOTE(review): this only fires when the counter lands exactly on a
			# multiple of eval_interval, which depends on num_envs — confirm.
			if self.total_time_step % self.eval_interval == 0:
				# save model checkpoint
				avg_score = self.evaluate()
				self.save(os.path.join(self.writer.log_dir, f"model_{self.total_time_step}_{int(avg_score)}.pth"))
				self.writer.add_scalar('Evaluate/Episode Reward', avg_score, self.total_time_step)

	def evaluate(self):
		"""Play one episode in each of the ``eval_episode`` test environments.

		Returns:
			The mean of the first-episode returns over all test environments.
		"""
		print("==============================================")
		print("Evaluating...")
		self.behavior_net.eval()
		episode_rewards = [0] * self.eval_episode
		all_rewards = [0] * self.eval_episode
		all_done = [False] * self.eval_episode
		observations, infos = self.test_env.reset()
		while True:
			actions = self.decide_agent_actions(observations, self.eval_epsilon, self.test_env.action_space.n)
			next_observations, rewards, terminates, truncates, infos = self.test_env.step(actions)

			# Accumulate before checking for termination so that the reward of the
			# final transition is part of the recorded episode return.
			episode_rewards = [episode_rewards[i] + rewards[i] for i in range(self.eval_episode)]

			for i in range(self.eval_episode):
				# Record only the first episode of each environment.
				if (terminates[i] or truncates[i]) and not all_done[i]:
					print(f"env {i} terminated, reward: {episode_rewards[i]}")
					all_rewards[i] = episode_rewards[i]
					all_done[i] = True

			observations = next_observations

			# all episodes done, terminate
			if all(all_done):
				break

		avg = sum(all_rewards) / self.eval_episode
		print(f"average score: {avg}")
		print("==============================================")
		self.behavior_net.train()
		return avg

	# save model
	def save(self, save_path):
		"""Write the behavior network's state dict to ``save_path``."""
		torch.save(self.behavior_net.state_dict(), save_path)

	# load model
	def load(self, load_path):
		"""Load a behavior-network state dict from ``load_path`` onto ``self.device``."""
		# map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
		self.behavior_net.load_state_dict(torch.load(load_path, map_location=self.device))

	# load model weights and evaluate
	def load_and_evaluate(self, load_path):
		"""Load weights from ``load_path``, evaluate, and return the average score."""
		self.load(load_path)
		return self.evaluate()


### AtariDQNAgent

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from base_agent import DQNBaseAgent
from models.atari_model import AtariNetDQN, AtariNetDuelingDQN
import envpool
import random

class AtariDQNAgent(DQNBaseAgent):
	"""DQN agent for Atari that runs on vectorized envpool environments."""

	def __init__(self, config):
		"""Create the train/test environments, both Q-networks and the optimizer.

		In addition to the keys consumed by the base class, ``config`` must
		provide ``env_id`` and ``learning_rate``.
		"""
		super().__init__(config)

		# Training uses episodic life and clipped rewards; evaluation uses neither.
		self.env = envpool.make(config["env_id"], env_type="gym", num_envs=self.num_envs, episodic_life=True, reward_clip=True)
		self.test_env = envpool.make(config["env_id"], env_type="gym", num_envs=self.eval_episode, episodic_life=False, reward_clip=False)
		self.behavior_net = AtariNetDQN(self.env.action_space.n)
		self.behavior_net.to(self.device)
		self.target_net = AtariNetDQN(self.env.action_space.n)
		self.target_net.to(self.device)
		# The target network starts as an exact copy and stays in eval mode.
		self.target_net.load_state_dict(self.behavior_net.state_dict())
		self.target_net.eval()
		self.lr = config["learning_rate"]
		self.optim = torch.optim.Adam(self.behavior_net.parameters(), lr=self.lr, eps=1.5e-4)

	def decide_agent_actions(self, observation, epsilon=0.0, action_space=None):
		"""Epsilon-greedy action selection for a batch of observations.

		A single random draw decides for the whole batch: with probability
		``epsilon`` every environment gets a uniformly random action in
		``[0, action_space)``, otherwise every environment gets the greedy action.

		Returns:
			A numpy array holding one action per observation.
		"""
		observation = torch.from_numpy(observation)
		if random.random() < epsilon:
			return np.random.randint(0, action_space, size=observation.shape[0])

		observation = observation.to(self.device, dtype=torch.float32)
		# Action selection needs no gradients; skip building the autograd graph.
		with torch.no_grad():
			q_values = self.behavior_net(observation)
		return q_values.argmax(dim=1).cpu().numpy()

	def evaluate(self):
		"""Play one episode in each test environment and return the mean episode reward.

		Same procedure as the base class, except that any environment with more
		than 200 consecutive zero-reward steps is forced to take action 1.
		"""
		print("==============================================")
		print("Evaluating...")
		self.behavior_net.eval()
		episode_rewards = [0] * self.eval_episode
		all_rewards = [0] * self.eval_episode
		zero_reward_counter = [0] * self.eval_episode
		all_done = [False] * self.eval_episode
		observations, infos = self.test_env.reset()
		while True:
			actions = self.decide_agent_actions(observations, self.eval_epsilon, self.test_env.action_space.n)
			# breakout: help agent to fire
			# NOTE(review): action 1 is presumably FIRE in Breakout — confirm it is
			# meaningful for other env_ids.
			for i in range(self.eval_episode):
				if zero_reward_counter[i] > 200:
					actions[i] = 1
					zero_reward_counter[i] = 0

			next_observations, rewards, terminates, truncates, infos = self.test_env.step(actions)

			# Accumulate before checking for termination so that the reward of the
			# final transition is part of the recorded episode return.
			episode_rewards = [episode_rewards[i] + rewards[i] for i in range(self.eval_episode)]

			for i in range(self.eval_episode):
				# Record only the first episode of each environment.
				if (terminates[i] or truncates[i]) and not all_done[i]:
					print(f"env {i} terminated, reward: {episode_rewards[i]}")
					all_rewards[i] = episode_rewards[i]
					all_done[i] = True

			for i in range(self.eval_episode):
				if rewards[i] == 0:
					zero_reward_counter[i] += 1
				else:
					zero_reward_counter[i] = 0

			observations = next_observations

			# all episodes done, terminate
			if all(all_done):
				break

		avg = sum(all_rewards) / self.eval_episode
		print(f"average score: {avg}")
		print("==============================================")
		self.behavior_net.train()
		return avg

### main

In [None]:
if __name__ == '__main__':

	# Hyper-parameters for the run; every key is read by the agent constructors.
	config = dict(
		# hardware and environment
		gpu=True,
		env_id='ALE/MsPacman-v5',
		num_envs=4,
		logdir='log/dqn_eval/',
		# schedule
		training_steps=1e8,
		warmup_steps=20000,
		update_freq=4,
		update_target_freq=10000,
		# exploration
		eps_min=0.1,
		eps_decay=1000000,
		eval_epsilon=0.01,
		# learning
		gamma=0.99,
		batch_size=32,
		learning_rate=0.0000625,
		use_double=True,
		replay_buffer_capacity=100000,
	)

	# Build the agent and start training.
	AtariDQNAgent(config).train()
