In [None]:
!sudo apt-get install -y xvfb ffmpeg x11-utils
!pip install -q 'gym==0.10.11'
!pip install -q 'imageio==2.4.0'
!pip install -q PILLOW
!pip install -q 'pyglet==1.3.2'
!pip install -q pyvirtualdisplay
!pip install -q tf-agents
!pip install colabgymrender
!pip install gym_sokoban
!pip install gym
!apt-get install python-opengl -y
!apt-get install -y xvfb python-opengl > /dev/null 2>&1
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install python-opengl -y
!apt install xvfb -y
!pip install piglet

In [None]:
import math
import random
from collections import namedtuple, deque
from itertools import count

import gym
import gym_sokoban  # presumably imported so gym.make can resolve the Sokoban env ids
from pyvirtualdisplay import Display

# Start a single headless X server (before the plotting/rendering imports
# below, as before) so rendering works without a physical display. Only one
# server is needed; a second default-sized one used to be started as well.
# The handle is kept in its own name so it is not lost when `display` is
# rebound to IPython's display module below.
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()
display = virtual_display  # backward-compatible alias

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T
from colabgymrender.recorder import Recorder
from IPython import display as ipythondisplay
from PIL import Image

# With an inline matplotlib backend, `display` refers to IPython's display module.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

In [None]:
env = gym.make("PushAndPull-Sokoban-v2")
# Raise the per-episode step cap. Assign through `env.unwrapped` so the value
# reaches the base env even if gym.make returned a wrapper: attribute
# assignment on a gym Wrapper stays on the wrapper object.
# NOTE(review): assumes the Sokoban env reads `max_steps` to end episodes, and
# any TimeLimit wrapper applies its own limit independently — confirm.
env.unwrapped.max_steps = 5000
screen = env.render(mode='rgb_array')
print(screen.shape)

In [None]:
# Visual sanity check of the raw RGB frame the agent's observations come from.
fig, ax = plt.subplots()
ax.imshow(screen)
ax.set_title(f"Initial frame, shape {screen.shape}")
plt.show()

In [None]:
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# One step of experience. The training loop in this notebook stores
# `next_state=None` for terminal steps.
Transition = namedtuple('Transition',
                        ['state', 'action', 'next_state', 'reward'])


class ReplayMemory:
    """Bounded experience-replay buffer of ``Transition`` records.

    Holds at most ``capacity`` transitions; once full, each new push silently
    evicts the oldest one.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition._make(args))

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions, chosen uniformly at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [None]:
class DQN(nn.Module):
    """Dueling convolutional Q-network over single-channel frames.

    Two (5x5 conv -> ReLU -> 2x2 max-pool) stages feed separate fully
    connected advantage and value streams, combined per sample as
    ``Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')``.

    Args:
        outputs: number of actions (width of the advantage head).
        height: input frame height in pixels. The default 112 together with
            ``width=112`` gives the original 10000 flattened features
            (16 channels * 25 * 25).
        width: input frame width in pixels.

    Raises:
        ValueError: if the frame is too small for the conv/pool stack.
    """

    def __init__(self, outputs, height=112, width=112):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)

        flat_features = 16 * self._feature_size(height) * self._feature_size(width)

        self.pre_adv = nn.Linear(flat_features, 512)
        self.adv = nn.Linear(512, outputs)

        self.pre_val = nn.Linear(flat_features, 512)
        self.val = nn.Linear(512, 1)

    @staticmethod
    def _feature_size(size):
        """Spatial size after conv(5) -> pool(2) -> conv(5) -> pool(2), no padding."""
        for _ in range(2):
            size = (size - 4) // 2  # 5x5 conv shrinks by 4, then 2x2 pool halves (floor)
            if size < 1:
                raise ValueError("input frame is too small for the DQN conv stack")
        return size

    def forward(self, x):
        """Return Q-values of shape (batch, outputs) for a (batch, 1, H, W) input."""
        # Follow the parameters' device instead of relying on a global `device`.
        x = x.to(self.conv1.weight.device)
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)

        adv = self.adv(F.relu(self.pre_adv(x)))
        value = self.val(F.relu(self.pre_val(x)))

        # Dueling aggregation: subtract each sample's own mean advantage.
        # A plain `adv.mean()` averages over the whole batch, making one
        # sample's Q-values depend on the other samples in the batch.
        return value + adv - adv.mean(dim=1, keepdim=True)

    def init_weights(self):
        """Re-initialise every conv/linear weight with Xavier-uniform (biases untouched).

        Returns ``self`` so the call can be chained.
        """
        for layer in (self.conv1, self.conv2,
                      self.pre_adv, self.adv,
                      self.pre_val, self.val):
            nn.init.xavier_uniform_(layer.weight)
        return self


In [None]:
# Frame preprocessing: RGB array -> PIL image -> 1-channel grayscale ->
# float tensor of shape (1, H, W) scaled to [0, 1]. Despite its name, this
# pipeline performs no resizing.
# NOTE(review): assumes render() yields an (H, W, 3) uint8 array — confirm.
resize = T.Compose([
    T.ToPILImage(),
    T.Grayscale(num_output_channels=1),
    T.ToTensor(),
])


def get_screen():
    """Render the current frame of the global `env` and preprocess it with `resize`."""
    return resize(env.render(mode='rgb_array'))

In [None]:
# Optimisation hyper-parameters.
BATCH_SIZE = 32
GAMMA = 0.9
# Epsilon-greedy schedule used by select_action(): epsilon decays
# exponentially from EPS_START towards EPS_END with time constant EPS_DECAY
# (measured in environment steps).
EPS_START = 1.0
EPS_END = 0.1
EPS_DECAY = 60000
TARGET_UPDATE = 10  # copy online weights into the target net every N episodes

init_screen = get_screen()
_, screen_height, screen_width = init_screen.shape

n_actions = env.action_space.n

# Online network: default construction, Xavier re-initialisation, then moved
# to the compute device (init_weights() returns the module itself).
policy_net = DQN(n_actions).init_weights().to(device)

# Target network: starts as an exact copy of the online network and is only
# used for evaluation.
target_net = DQN(n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters(), lr=0.001)
criterion = nn.SmoothL1Loss()


In [None]:
def select_action(state):
    """Choose an action for `state` with an exponentially decaying epsilon-greedy policy.

    Increments the global ``steps_done`` counter on every call. A uniformly
    random action is taken when a ``random.random()`` draw is at most the
    current epsilon; otherwise the greedy action under ``policy_net`` is
    returned. Either way the result is a ``(1, 1)`` long tensor.
    """
    global steps_done
    steps_done += 1

    draw = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay

    if draw <= eps_threshold:
        # Explore: uniformly random action index.
        random_action = random.randrange(n_actions)
        return torch.tensor([[random_action]], device=device, dtype=torch.long)

    # Exploit: index of the largest Q-value for this state.
    with torch.no_grad():
        q_values = policy_net(state)
    return q_values.max(1)[1].view(1, 1)

In [None]:
def optimize_model():
    """Perform one DQN gradient step on a random minibatch from replay memory.

    Does nothing until ``memory`` holds at least ``BATCH_SIZE`` transitions.
    The regression target is ``reward + GAMMA * max_a target_net(next_state)[a]``,
    with the bootstrap term left at zero for terminal transitions
    (``next_state is None``).

    NOTE(review): the bootstrap applies a single factor of GAMMA. Transitions
    pushed by the training loop with ``sarsa_k > 1`` span several steps and
    would need ``GAMMA ** sarsa_k`` — confirm before using sarsa_k > 1.
    """
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of Transitions into one Transition of per-field tuples.
    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    non_final_next_states = [s for s in batch.next_state if s is not None]

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s, a) of the actions actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # torch.cat([]) raises, so skip the target pass when every sampled
    # transition is terminal (all bootstrap values then stay zero).
    if non_final_next_states:
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(
                torch.cat(non_final_next_states)).max(1)[0]

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Training loop: run `num_episodes` epsilon-greedy episodes, store transitions
# in replay memory and periodically optimise the policy network.
num_episodes = 10000
episode_durations = []  # number of steps taken in each finished episode
steps_done = 0  # global step counter, incremented inside select_action()
memory = ReplayMemory(2500)
policy_net.train()
target_net.eval()
# Number of environment steps folded into one stored transition. With
# sarsa_k == 1 each step is pushed directly; larger values buffer steps and
# push multi-step transitions whose rewards are summed.
# NOTE(review): for sarsa_k > 1 the rewards are summed without discounting,
# while optimize_model() bootstraps with a single factor of GAMMA — confirm
# whether GAMMA ** k discounting was intended.
sarsa_k = 1


for i_episode in range(num_episodes):
    
    if i_episode % 50 == 0:
      print(i_episode)

    env.reset()
    # The agent observes preprocessed rendered frames, not env observations.
    state = get_screen().unsqueeze(0)

    prev_states = []  # NOTE(review): never used in this loop
    # Pending multi-step entries: (state, action, None, reward_sum, steps_counted).
    sarsa_k_states = []
    for t in count():
        action = select_action(state)
        _, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        next_state = get_screen().unsqueeze(0)

        if not done:
          if sarsa_k == 1:
            memory.push(state, action, next_state, reward)
          else:
            sarsa_k_states_new = []
            # Add this step's reward to every pending entry; entries that
            # reach sarsa_k steps are pushed with the current next_state.
            for prev_state, prev_action, _, prev_rewards, prev_counter in sarsa_k_states:
              k_rewards = prev_rewards + reward
              counter = prev_counter +  1

              if counter == sarsa_k:
                memory.push(prev_state, prev_action, next_state, k_rewards)
              else:
                sarsa_k_states_new.append((prev_state, prev_action, None, k_rewards, counter))
          
            # Open a new pending entry for the current step.
            sarsa_k_states_new.append((state, action, None, reward, 1))
            sarsa_k_states = sarsa_k_states_new
            sarsa_k_states_new = None
        else:
            # Terminal step: next_state=None tells optimize_model() not to bootstrap.
            next_state = None

            # Flush every pending multi-step entry as a terminal transition.
            for prev_state, prev_action, _, prev_rewards, prev_counter in sarsa_k_states:
              k_rewards = prev_rewards + reward
              memory.push(prev_state, prev_action, next_state, k_rewards)
            
            memory.push(state, action, next_state, reward)

      

        state = next_state

        # One gradient step whenever the global step count is a multiple of
        # BATCH_SIZE, i.e. one update per BATCH_SIZE environment steps.
        if steps_done % BATCH_SIZE == 0:
          optimize_model()

        if done:
            episode_durations.append(t + 1)
            break

    # Periodically copy the online weights into the target network.
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
        target_net.eval()


env.close()