In [None]:
!sudo apt-get install -y xvfb ffmpeg x11-utils
!pip install -q 'gym==0.10.11'
!pip install -q 'imageio==2.4.0'
!pip install -q PILLOW
!pip install -q 'pyglet==1.3.2'
!pip install -q pyvirtualdisplay
!pip install -q tf-agents
!pip install colabgymrender
!pip install gym_sokoban
!pip install gym
!apt-get install python-opengl -y
!apt-get install -y xvfb python-opengl > /dev/null 2>&1
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install python-opengl -y
!apt install xvfb -y
!pip install piglet
!pip install gym_sokoban


In [None]:
# Imports (standard library, then third-party), followed by notebook-wide setup.
import math
import random
from collections import namedtuple, deque
from itertools import count

import gym
import gym_sokoban  # not used directly; presumably imported so gym.make() can resolve the Sokoban env ids
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from colabgymrender.recorder import Recorder
from IPython import display as ipythondisplay
from PIL import Image
from pyvirtualdisplay import Display

sns.set_theme(style="whitegrid")

# Start a single headless X server for rendering. Only one is needed; every
# extra Display().start() launches another Xvfb process that is never stopped.
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()
display = virtual_display

is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    # Rebinds `display` to the IPython.display module; the X server handle
    # stays reachable as `virtual_display`.
    from IPython import display

plt.ion()

In [None]:
ENV_ID = "PushAndPull-Sokoban-v2"
env = gym.make(ENV_ID)
# Episode step cap; the training loop's truncation check also reads env.max_steps.
env.max_steps = 200
# Render one raw RGB frame (before get_screen's grayscale conversion) and show its size.
screen = env.render(mode='rgb_array')
print(screen.shape)

In [None]:
# Display the frame captured by env.render(mode='rgb_array') in the cell above.
plt.imshow(screen)

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
device

In [None]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'expected_reward'))


class ReplayMemory(object):
    """Bounded FIFO buffer of Transitions with uniform and prioritised sampling.

    Once `capacity` transitions are stored, pushing another evicts the oldest.
    """

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Store one transition; positional args follow the Transition field order."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        """Return `batch_size` distinct transitions chosen uniformly at random."""
        return random.sample(self.memory, batch_size)

    def priority_sample(self, batch_size, alpha=0.8, eps=1e-6):
        """Return `batch_size` distinct transitions, favouring large prediction error.

        Each transition's priority is (|reward - expected_reward| + eps) ** alpha,
        and it is drawn with probability proportional to that priority.
        `eps` keeps zero-error transitions sampleable and avoids 0/0 when every
        error is zero. `reward` must support `.item()` (e.g. a 1-element tensor).

        Raises:
            ValueError: if `batch_size` exceeds the number of stored transitions.
        """
        errors = np.array([abs(t.reward.item() - float(t.expected_reward))
                           for t in self.memory], dtype=np.float64)
        priorities = np.power(errors + eps, alpha)
        probs = priorities / priorities.sum()
        # Pass `p` so the draw actually follows the priorities (not uniform).
        idx = np.random.choice(len(self.memory), batch_size, replace=False, p=probs)
        return [self.memory[i] for i in idx]

    def __len__(self):
        return len(self.memory)

In [None]:
class DQN(nn.Module):
    """Dueling convolutional Q-network over single-channel screen images.

    Two [conv 5x5 -> leaky ReLU -> 2x2 max-pool] stages feed two fully
    connected heads: an advantage head with one output per action and a scalar
    value head. They are combined as Q = V + (A - mean_a A), with the mean
    taken per sample over the action dimension.

    Args:
        outputs: number of actions (width of the returned Q-value vector).
        screen_height, screen_width: spatial size of the input images. The
            defaults (112x112) give the 10000-wide flattened feature vector
            (16 channels x 25 x 25) this network originally hard-coded.

    `forward` expects input shaped (batch, 1, screen_height, screen_width).
    """

    def __init__(self, outputs, screen_height=112, screen_width=112):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        feat_h = self._feature_size(screen_height)
        feat_w = self._feature_size(screen_width)
        if feat_h <= 0 or feat_w <= 0:
            raise ValueError("screen size {}x{} is too small for this network"
                             .format(screen_height, screen_width))
        flat_features = 16 * feat_h * feat_w

        self.pre_adv = nn.Linear(flat_features, 512)
        self.adv = nn.Linear(512, outputs)

        self.pre_val = nn.Linear(flat_features, 512)
        self.val = nn.Linear(512, 1)

    @staticmethod
    def _feature_size(size):
        """Spatial size left after [conv 5x5 (no padding) -> pool 2x2] applied twice."""
        for _ in range(2):
            size = (size - 4) // 2
        return size

    def forward(self, x):
        # Follow the parameters' device rather than a global, so the module
        # works wherever it has been moved with .to().
        x = x.to(self.conv1.weight.device)
        x = self.pool1(F.leaky_relu(self.conv1(x)))
        x = self.pool2(F.leaky_relu(self.conv2(x)))
        x = torch.flatten(x, 1)

        adv = self.adv(F.leaky_relu(self.pre_adv(x)))
        value = self.val(F.leaky_relu(self.pre_val(x)))

        # Center advantages per sample (dim=1 = actions) so other samples in
        # the batch cannot shift an individual sample's Q-values.
        return value + adv - adv.mean(dim=1, keepdim=True)

    def init_weights(self):
        """Xavier-uniform initialise every conv/linear weight (biases untouched); returns self."""
        for layer in (self.conv1, self.conv2, self.pre_adv, self.adv, self.pre_val, self.val):
            nn.init.xavier_uniform_(layer.weight)
        return self


In [None]:
# Despite its name this pipeline does not resize anything: it turns a rendered
# RGB frame (presumably an HxWx3 uint8 array - confirm) into a single-channel
# float tensor of shape (1, H, W) via PIL grayscale conversion.
resize = T.Compose([T.ToPILImage(), T.Grayscale(num_output_channels=1), T.ToTensor()])


def get_screen():
    """Render the current env frame and return it as a (1, H, W) grayscale tensor."""
    frame = env.render(mode='rgb_array')
    gray = resize(frame)
    return gray


In [None]:
# Training hyper-parameters.
BATCH_SIZE = 256   # minibatch size; the training loop also optimises once every BATCH_SIZE steps
GAMMA = 0.9        # NOTE(review): not read by any cell visible in this notebook
# NOTE(review): select_action does not use an epsilon schedule; these are currently unused.
EPS_START, EPS_END, EPS_DECAY = 1.0, 0.1, 60000
TARGET_UPDATE = 10  # copy policy_net's weights into target_net every N episodes

In [None]:
# Probe one frame for the screen size, then build the online and target networks.
init_screen = get_screen()
_, screen_height, screen_width = init_screen.shape
n_actions = env.action_space.n

# init_weights() and Module.to() both return the module itself, so they chain.
policy_net = DQN(n_actions).init_weights().to(device)

# Target network starts as an exact copy of the policy network, in eval mode.
target_net = DQN(n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters(), lr=0.001)
criterion = nn.SmoothL1Loss()

In [None]:
def select_action(state, deterministic = False):
    """Pick an action for `state` from the policy network's Q-values.

    Increments the global `steps_done` counter on every call.

    Args:
        state: observation tensor, passed straight to `policy_net`.
        deterministic: if True take the arg-max action; otherwise sample an
            action with probability proportional to its Q-value, shifted so
            the smallest value is just above zero.

    Returns:
        (action, q_value): a 1x1 long tensor on `device`, and the network's
        predicted Q-value for that action as a Python float. The training loop
        stores the latter as `expected_reward` for prioritised replay.
    """
    global steps_done
    steps_done += 1
    with torch.no_grad():
        q_values = policy_net(state).cpu().numpy().flatten()

    if deterministic:
        action = int(np.argmax(q_values))
    else:
        lowest = q_values.min()
        # Shift so every weight is strictly positive before normalising.
        weights = q_values - lowest + 0.0001 if lowest < 0 else q_values + 0.0001
        action = int(np.random.choice(n_actions, p=weights / np.sum(weights)))

    return torch.tensor([[action]], device=device, dtype=torch.long), float(q_values[action])


In [None]:
def optimize_model():
    """Run one gradient step of `policy_net` on a prioritised minibatch.

    Does nothing until `memory` holds at least BATCH_SIZE transitions.

    The regression target for Q(s, a) is the transition's stored `reward`,
    which the training loop fills with a discounted sum of the next few
    rewards. There is no bootstrapped max_a' Q(s', a') term, so `next_state`
    and `target_net` are not read here.
    """
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.priority_sample(BATCH_SIZE)
    # Transpose a list of Transitions into one Transition of tuples.
    batch = Transition(*zip(*transitions))

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q-value the network currently assigns to each action actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    loss = criterion(state_action_values, reward_batch.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
num_episodes = 10000
episode_durations = []
steps_done = 0
memory = ReplayMemory(1500)
policy_net.train()
target_net.eval()
sarsa_k = 1   # NOTE(review): not read anywhere in the visible notebook
window = 5    # how many rewards (from a step onward) are summed into its stored return

for i_episode in range(num_episodes):
    if i_episode % 50 == 0:
        print(i_episode)

    env.reset()
    state = get_screen().unsqueeze(0)

    episode_steps = []    # (state, action, next_state, predicted_q) for each step
    episode_rewards = []  # raw env reward for each step, aligned with episode_steps

    for t in count():
        action, expected_reward = select_action(state)
        _, reward, done, _ = env.step(action.item())

        next_state = get_screen().unsqueeze(0)
        if done:
            next_state = None
        episode_steps.append((state, action, next_state, expected_reward))
        episode_rewards.append(reward)

        if done:
            # An episode that lasted the full env.max_steps steps is treated as
            # cut off: tail steps lacking `window` rewards are dropped, not stored short.
            hit_step_cap = t == env.max_steps - 1
            for idx, (s, a, s_next, q_pred) in enumerate(episode_steps):
                horizon = episode_rewards[idx: idx + window]
                if hit_step_cap and len(horizon) < window:
                    continue
                # NOTE(review): this 0.8 discount is independent of GAMMA.
                n_step_return = np.sum([r * np.power(0.8, k) for (k, r) in enumerate(horizon)])
                memory.push(s, a, s_next,
                            torch.tensor([n_step_return], device=device).float(),
                            q_pred)

        state = next_state

        if steps_done % BATCH_SIZE == 0:
            optimize_model()
        if done:
            episode_durations.append(t + 1)
            break

    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
        target_net.eval()


In [None]:

# One row per finished episode: its index and how many env steps it lasted.
durations_df = pd.DataFrame(list(enumerate(episode_durations)),
                            columns=["iteration", "duration"])

sns.scatterplot(data=durations_df, x="iteration", y="duration")
