In [None]:
import sys
import torch
import gym
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import cv2

In [None]:
# Gym environment id shared by the training environment below and by the
# rendering environment that show_game() creates.
env_name = "SpaceInvaders-v4"
# Training environment. obs_type='grayscale' requests single-channel frames,
# which is why the network is later built with one input channel (n_states = 1).
env = gym.make(env_name, obs_type = 'grayscale')

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# Both the model and every preprocessed frame are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class ActorCritic(nn.Module):
    """Actor-critic network with a shared convolutional trunk and two heads.

    The trunk maps an image batch of shape (N, input_shape, H, W) to a
    256-dimensional feature vector. The hard-coded flatten size of 14400
    (= 64 * 15 * 15) matches 128x128 inputs, which is what preprocess()
    produces; other spatial sizes will fail at the first Linear layer.

    Args:
        input_shape: number of input channels (used as Conv2d in_channels).
        n_actions: number of discrete actions, i.e. the width of the actor head.
        lr: unused by the module; kept so existing call sites keep working.
    """

    def __init__(self, input_shape, n_actions, lr):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(input_shape, 32, kernel_size = 8, stride = 4, padding = 1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size = 4, stride = 2, padding = 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(14400, 256),
            nn.ReLU(),
        )
        self.actor_end_layer = nn.Linear(256, n_actions)
        self.critic_end_layer = nn.Linear(256, 1)

    def forward(self, state):
        """Return (value, probs) for a batch of states.

        value has shape (N, 1); probs has shape (N, n_actions) and each row
        sums to 1 (softmax over dim 1).
        """
        # Run the shared trunk once and feed both heads from the same features.
        # Evaluating it separately per head doubles the compute and the size of
        # the autograd graph without changing the result.
        features = self.network(state)
        value = self.critic_end_layer(features)
        probs = F.softmax(self.actor_end_layer(features), dim = 1)
        return value, probs

In [None]:
def show_game(model, max_steps = 800):
    """Play one rendered episode with the current policy.

    A separate environment is created with render_mode='human' so the
    training environment is left untouched. Actions are sampled from the
    policy distribution returned by `model`.

    Args:
        model: callable returning (value, probs) for a preprocessed frame.
        max_steps: upper bound on the number of environment steps shown.

    Returns:
        None.
    """
    env2 = gym.make(env_name, render_mode = 'human', obs_type = 'grayscale')
    try:
        state = env2.reset()[0]
        done = False
        steps = 0
        # Pure inference: no gradients are needed while watching the agent.
        with torch.no_grad():
            # Stop at the end of the episode OR at the step cap, whichever
            # comes first. With `or` the loop kept stepping an already
            # finished environment and never capped a long episode.
            while not done and steps < max_steps:
                steps += 1
                _, probs = model(preprocess(state))
                if steps == 1:
                    # Print the first action distribution as a quick sanity check.
                    print(f'Action_probs: {probs}')
                action = Categorical(probs).sample().item()
                state, _, terminated, truncated, _ = env2.step(action)
                done = terminated or truncated
    finally:
        # Always release the render window, even if stepping raised.
        env2.close()
    return

In [None]:
# NOTE(review): hidden_size is not referenced anywhere else in the visible
# notebook; the network hard-codes 256 itself.
hidden_size = 256
# Learning rate handed to the AdamW optimizer (and, unused, to ActorCritic).
lr = 3e-4
# Discount factor applied when accumulating returns in the training loop.
GAMMA = 0.99
# An optimisation step is taken whenever the in-episode step counter is a
# multiple of num_steps (and once more at the end of every episode).
num_steps = 512
# Number of iterations of the outer training loop.
max_episodes = 900000
# Size of the discrete action space; sets the width of the actor head.
n_actions = env.action_space.n
# Passed as the first ActorCritic argument, i.e. the number of input image
# channels (1 for grayscale frames), despite the name.
n_states = 1

In [None]:
def preprocess(img):
    """Turn a raw frame into a network-ready tensor.

    The frame is resized to 128x128, converted to a float tensor on `device`,
    given leading batch and channel axes of size 1, and scaled by 1/255.
    """
    resized = cv2.resize(img, (128, 128))
    # Indexing with [None, None] prepends the batch and channel dimensions.
    batched = torch.FloatTensor(resized).to(device)[None, None]
    return batched / 255.0

In [None]:
# Visual sanity check of the preprocessing: reset the environment, preprocess
# the first observation and display it.
state = env.reset()[0]
# preprocess() returns a tensor with leading batch and channel axes; [0] drops
# the batch axis and permute(1,2,0) moves the channel axis last for imshow.
plt.imshow(preprocess(state).cpu()[0].permute(1,2,0).numpy())
plt.show()

In [None]:
def select_action(action_dist):
    """Pick an action epsilon-greedily and decay the exploration rate.

    With probability `epsilon` a uniformly random action index is returned;
    otherwise an action is sampled from `action_dist`. The module-level
    `epsilon` is multiplied by `epsilon_decay` on every call.
    """
    global epsilon, epsilon_decay
    explore = np.random.uniform() < epsilon
    chosen = np.random.choice(n_actions) if explore else action_dist.sample().item()
    epsilon = epsilon * epsilon_decay
    return chosen

In [None]:
# Build the policy/value network on the selected device and its optimizer.
actor_critic = ActorCritic(n_states, n_actions, lr).to(device)
optimizer = optim.AdamW(actor_critic.parameters(), lr = lr, amsgrad = True)

# Bookkeeping filled in by the training loop and used for the progress plots.
all_lengths = []        # step count recorded at the end of each episode
average_lengths = []    # mean of the last 10 entries of all_lengths
all_rewards = []        # reward sums appended by the training loop
# Exploration state read and updated by select_action(): probability of taking
# a uniformly random action, multiplied by epsilon_decay on every call.
epsilon = 0.999
epsilon_decay = 0.99999

In [None]:
# Training
def _discounted_returns(rewards, bootstrap_value, gamma):
    """Return the discounted return of every step in a rollout segment.

    Args:
        rewards: per-step rewards in chronological order.
        bootstrap_value: value estimate of the state that follows the last
            reward (0.0 when that state is terminal).
        gamma: discount factor.

    Returns:
        A list of Python floats with the same length as `rewards`.
    """
    returns = []
    running = float(bootstrap_value)
    for reward in reversed(rewards):
        running = float(reward) + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def _bootstrap_value(raw_state):
    """Critic estimate for a raw frame, as a float detached from autograd."""
    with torch.no_grad():
        value, _ = actor_critic(preprocess(raw_state))
    return value.item()


def _a2c_update(log_probs, values, entropies, rewards, bootstrap_value):
    """Take one optimizer step on a rollout segment.

    Args:
        log_probs: list of 0-dim tensors, log-probability of each taken action.
        values: list of (1, 1) critic outputs, one per step.
        entropies: list of 0-dim tensors, policy entropy at each step.
        rewards: list of per-step rewards.
        bootstrap_value: float value estimate for the state after the segment.

    Returns:
        (total_loss, actor_loss, critic_loss, entropy) as Python floats.
    """
    returns = torch.tensor(
        _discounted_returns(rewards, bootstrap_value, GAMMA),
        dtype = torch.float32, device = device,
    )
    # Bring every per-step quantity to shape (T,). Subtracting a (T, 1) value
    # tensor from a (T,) return tensor would broadcast to a (T, T) matrix and
    # silently pair every return with every value.
    value_preds = torch.cat(values).reshape(-1)
    chosen_log_probs = torch.stack(log_probs).reshape(-1)

    advantage = returns - value_preds
    actor_loss = -(chosen_log_probs * advantage.detach()).mean()
    critic_loss = F.smooth_l1_loss(value_preds, returns)
    entropy_term = torch.stack(entropies).mean()

    # Subtracting the entropy rewards more exploratory policies.
    ac_loss = actor_loss + critic_loss - 1*entropy_term

    optimizer.zero_grad()
    ac_loss.backward()
    optimizer.step()
    return ac_loss.item(), actor_loss.item(), critic_loss.item(), entropy_term.item()


for i_episode in range(max_episodes):
    log_probs = []
    values = []
    rewards = []
    entropies = []

    state = env.reset()[0]
    done = False
    steps = 0             # number of environment steps taken this episode
    episode_reward = 0.0  # reward summed over the whole episode, not one segment
    while not done:
        steps += 1
        value, probs = actor_critic(preprocess(state))

        action_dist = Categorical(probs)
        action = select_action(action_dist)

        # Categorical.log_prob / entropy work on clamped logits, so they stay
        # finite even when an action probability underflows to zero.
        # NOTE(review): with probability epsilon the action is uniformly random,
        # yet it is still scored under the policy; this was already the case.
        action_index = torch.tensor([action], device = probs.device)
        log_probs.append(action_dist.log_prob(action_index).squeeze(0))
        entropies.append(action_dist.entropy().squeeze(0))
        values.append(value)

        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        rewards.append(reward)
        episode_reward += reward

        # Update every num_steps steps and once more when the episode ends.
        if done or steps % num_steps == 0:
            # Nothing follows a terminal state, so its value is 0. A truncated
            # or still-running episode bootstraps from the critic instead.
            bootstrap = 0.0 if terminated else _bootstrap_value(state)
            ac_loss, actor_loss, critic_loss, entropy_term = _a2c_update(
                log_probs, values, entropies, rewards, bootstrap
            )
            log_probs = []
            values = []
            rewards = []
            entropies = []
            if not done:
                print(f'Trained Episode: {i_episode} | steps: {steps} | loss = {ac_loss:.4f} |'
                      f' Actor_Loss: {actor_loss:.2f} | Critic_Loss: {critic_loss:.2f} | entropy: {entropy_term:.2f} | epsilon: {epsilon:.2f}')

    # One entry per episode so the curves below really are per-episode.
    all_rewards.append(episode_reward)
    all_lengths.append(steps)
    average_lengths.append(np.mean(all_lengths[-10:]))
    if i_episode % 10 == 0:
        print(f"Episode:{i_episode} | Reward: {episode_reward} | Total_Length: {steps} | Average_Length: {average_lengths[-1]}")

    if (i_episode+1) % 50 == 0:
        smoothed_rewards = pd.Series(all_rewards).rolling(10).mean().tolist()
        plt.plot(all_rewards)
        plt.plot(smoothed_rewards)
        plt.xlabel('Episode')
        plt.ylabel('Reward')
        plt.show()

        plt.plot(all_lengths)
        plt.plot(average_lengths)
        plt.xlabel('Episode')
        plt.ylabel('Episode Length')
        plt.show()
    if (i_episode+1) % 150 == 0:
        show_game(actor_critic)