In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Change the working directory to the `app` folder on the mounted Drive.
%cd /content/drive/My Drive/app

In [None]:
# Enter the dmsoroki-gym_interf directory (relative to the current working directory).
%cd dmsoroki-gym_interf

In [None]:
# Install the package found in the current directory in editable mode.
!pip install -e .

In [None]:
# Install gym pinned to version 0.12.1.
!pip install gym==0.12.1

In [None]:
import gym
import gym_interf  # presumably registers the 'interf-v1' environment id on import — TODO confirm

# Shared environment instance used by the training loop and by every evaluation cell.
env = gym.make('interf-v1')

In [None]:
%matplotlib inline
import math
import random
import tqdm
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from collections import deque
from itertools import count
from copy import deepcopy
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torch.distributions import Categorical

# True when the name of the active matplotlib backend contains 'inline'.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display    
print("Is python : {}".format(is_ipython))


# Use CUDA when available, otherwise the CPU; tensors and networks are created on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device : {}".format(device))


# Size of the discrete action set: default output width of DQN and range for random actions.
ACTIONS_NUM = 8
print("Number of actions : {}".format(ACTIONS_NUM))

In [None]:
# One step of experience: the state acted in, the action taken, the state
# reached and the reward received.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)


class ReplayMemory:
    """Fixed-capacity ring buffer of ``Transition`` records.

    Once ``capacity`` records are stored, each new record overwrites the
    oldest one.
    """

    def __init__(self, capacity = 40000):
        self.capacity = capacity
        # Stored transitions; grows until it holds ``capacity`` entries.
        self.memory = []
        # Index of the slot the next push will write to.
        self.position = 0

    def push(self, *args):
        """Store a transition built from ``args`` (state, action, next_state, reward)."""
        slot = self.position
        # While the buffer is still filling up, grow it by one placeholder slot.
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[slot] = Transition(*args)
        # Advance the write index, wrapping back to 0 at the end of the buffer.
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [None]:
class DQN(nn.Module):
    """Dueling convolutional Q-network.

    Four un-padded convolutions produce 1024 feature maps, which are split
    into two halves of 512: one feeds the advantage head, the other the
    state-value head. The heads are combined as
    ``Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)``.

    The linear heads take exactly 512 features, so the last convolution must
    produce a 1x1 spatial map. With these kernels and strides a 64x64 input
    does that (64 -> 15 -> 6 -> 4 -> 1).
    NOTE(review): the environment's observation size is not visible here —
    confirm it is 64x64.

    Args:
        in_channels: number of channels of the input tensor.
        num_actions: number of discrete actions, i.e. width of the output.
    """

    def __init__(self, in_channels=16, num_actions=ACTIONS_NUM):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
        self.conv4 = nn.Conv2d(64, 1024, kernel_size=4, stride=1)
        self.advantage = nn.Linear(512, num_actions)
        self.value = nn.Linear(512, 1)

    def forward(self, x):
        """Return Q-values of shape ``(batch, num_actions)`` for input ``x``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))

        # First 512 channels feed the advantage head, the remaining 512 the value head.
        advantage, value = torch.split(x, 512, dim=1)

        batch_size = x.shape[0]
        advantage = self.advantage(advantage.view(batch_size, -1))
        value = self.value(value.view(batch_size, -1))

        # Broadcasting expands the (batch, 1) value and mean over the action
        # axis, so the result follows the head width chosen at construction
        # time instead of the module-level ACTIONS_NUM.
        q_value = value + advantage - advantage.mean(dim=1, keepdim=True)
        return q_value

In [None]:
# Online network that is trained, and a target network that starts as an exact copy of it.
policy_net = DQN().to(device)
target_net = DQN().to(device)
target_net.load_state_dict(policy_net.state_dict())
# The target network is only evaluated, never optimised (the optimizer below holds policy_net's parameters only).
target_net.eval()

optimizer =optim.Adam(policy_net.parameters(),lr=1e-5)

# Replay buffer with the default capacity of ReplayMemory.
memory = ReplayMemory()

def select_action(state, eps_threshold):
    """Choose an action epsilon-greedily.

    Args:
        state: input tensor for ``policy_net``; it is cast to float before
            the forward pass.
        eps_threshold: probability of picking a uniformly random action.

    Returns:
        A ``(1, 1)`` long tensor holding the chosen action index.
    """
    greedy = random.random() > eps_threshold
    if not greedy:
        # Explore: uniform draw over the action set.
        return torch.tensor([[random.randrange(ACTIONS_NUM)]], device=device, dtype=torch.long)

    # Exploit: index of the largest Q-value, computed without tracking gradients.
    with torch.no_grad():
        q_values = policy_net(state.float())
        return q_values.max(1)[1].view(1, 1)

# Per-episode reward totals; also bound as the default series of plot_rewards.
train_rewards = []

# Window length and stride of the moving average drawn by plot_rewards.
mean_size = 100
mean_step = 1

def plot_rewards(rewards = train_rewards, name = "Train"):
    """Plot per-episode rewards and their moving average on figure 2.

    Args:
        rewards: sequence of per-episode reward totals (defaults to the
            module-level ``train_rewards`` list).
        name: title of the figure.

    The moving average uses windows of ``mean_size`` episodes taken every
    ``mean_step`` episodes. It is left-padded with ``mean_size - 1`` zeros so
    that, for ``mean_step == 1``, point ``k`` is the mean of the window ending
    at episode ``k``.
    """
    plt.figure(2)
    plt.clf()
    plt.title(name)
    plt.xlabel('Episode')
    # The plotted series is the episode reward, not the episode duration.
    plt.ylabel('Total reward')
    plt.plot(rewards)
    if len(rewards) >= mean_size:
        # "+ 1" keeps the last full window, so the most recent episodes are
        # part of the average and the curve is as long as `rewards`.
        window_starts = range(0, len(rewards) - mean_size + 1, mean_step)
        means = np.array([rewards[i:i + mean_size] for i in window_starts]).mean(1)
        means = np.concatenate((np.zeros(mean_size - 1), means))
        plt.plot(means)

In [None]:
# Number of transitions sampled from replay memory per optimisation step.
BATCH_SIZE = 32
# Discount factor applied to the bootstrapped next-state value.
GAMMA = 0.99

def optimize_model():
    """Run one Double-DQN optimisation step on a minibatch from replay memory.

    Does nothing until the memory holds at least ``BATCH_SIZE`` transitions.
    The next action is chosen by ``policy_net`` and valued by ``target_net``.
    Terminal transitions (stored with ``next_state=None``) get a next-state
    value of 0. The Huber loss between predicted and target Q-values is
    minimised, with every gradient element clipped to ``[-1, 1]``.
    """
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    # Turn a list of Transitions into one Transition whose fields are tuples.
    batch = Transition(*zip(*transitions))

    # True where the transition did not end the episode. A bool mask is used
    # because indexing with uint8 masks is deprecated in PyTorch.
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)

    state_batch = torch.cat(batch.state).float()
    action_batch = torch.cat(batch.action)
    # Cast so the target has the same dtype as the float network output.
    # NOTE(review): the stored reward dtype depends on what the environment
    # returns in info['visib'] — not visible here.
    reward_batch = torch.cat(batch.reward).float()

    # Q(s, a) for the actions that were actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # Skip the bootstrap entirely when every sampled transition is terminal;
    # torch.cat would raise on an empty list.
    if non_final_mask.any():
        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None]).float()
        # Targets are constants for the loss, so no autograd graph is needed.
        with torch.no_grad():
            best_actions = policy_net(non_final_next_states).max(1)[1].unsqueeze(1)
            next_state_values[non_final_mask] = \
                target_net(non_final_next_states).gather(1, best_actions).squeeze(1)

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Element-wise clipping of every gradient to [-1, 1].
    nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()

In [None]:
def test(eps_threshold=0.005):
    """Play one episode on the shared ``env`` with a near-greedy policy.

    Args:
        eps_threshold: exploration rate passed to ``select_action``.

    Returns:
        Tuple ``(steps, total_reward, final_visib)``: the number of steps
        taken, the sum over steps of ``-1.0 + info['visib']``, and the
        ``info['visib']`` value of the last step.
    """
    state = env.reset()
    total_reward = 0
    for i in count():
        state = np.array(state, dtype=np.float32)
        state = torch.tensor(state, dtype=torch.float32, device=device)
        state = state.unsqueeze(0)
        action = select_action(state, eps_threshold)
        # Pass a plain int to the environment, as the training loop does.
        state, _, done, info = env.step(action.item())
        reward = -1.0 + info['visib']
        total_reward += reward
        if done:
            break
    return i + 1, total_reward, info['visib']
  

In [None]:
NUM_EPISODES = 200000

# Run one optimisation step every OPTIMIZE_MODEL_STEP environment steps.
OPTIMIZE_MODEL_STEP = 4

# Copy policy_net weights into target_net every TARGET_UPDATE environment steps.
TARGET_UPDATE = 10000

# Environment steps taken with uniformly random actions before learning starts.
STEPS_BEFORE_TRAIN = 30000

# Phase 1 (steps_done < EPS_DECAY): epsilon moves linearly from EPS_START
# towards EPS_END as steps_done / EPS_DECAY grows.
EPS_START = 1
EPS_END = 0.1
EPS_DECAY = 1000000

# Phase 2 (steps_done >= EPS_DECAY): epsilon moves linearly from EPS_START_v2
# to EPS_END_v2 over the next EPS_DECAY steps, then stays at EPS_END_v2.
EPS_START_v2 = 0.1
EPS_END_v2 = 0.01

policy_net.train()
target_net.eval()
test_rewards = []

# Global count of environment steps across all episodes.
steps_done = 0

for e in range(NUM_EPISODES):

    state = env.reset()
    state = np.array(state, dtype=np.float32)
    state = torch.tensor(state, dtype=torch.float32, device=device)
    state = state.unsqueeze(0)
    ep_rewards = 0

    # Hard cap on episode length; the loop normally ends when the env reports done.
    for t in range(180000):

        if steps_done < EPS_DECAY:
            if steps_done > STEPS_BEFORE_TRAIN:
                fraction = min(float(steps_done) / EPS_DECAY, 1)
                eps_threshold = EPS_START + (EPS_END - EPS_START) * fraction
                action = select_action(state, eps_threshold)
            else:
                # Warm-up: fill the replay memory with random experience.
                action = torch.tensor([[random.randrange(ACTIONS_NUM)]], device=device, dtype=torch.long)
        else:
            # The previous expression `float(steps_done)/2*EPS_DECAY` parsed as
            # (steps_done / 2) * EPS_DECAY, which is always >= 1 here, so
            # epsilon dropped to EPS_END_v2 at once. Measure progress from the
            # start of phase 2 so epsilon begins at EPS_START_v2 and reaches
            # EPS_END_v2 at 2 * EPS_DECAY steps.
            fraction = min(float(steps_done - EPS_DECAY) / EPS_DECAY, 1)
            eps_threshold = EPS_START_v2 + (EPS_END_v2 - EPS_START_v2) * fraction
            action = select_action(state, eps_threshold)

        next_state, _, done, info = env.step(action.item())
        # The environment's own reward is ignored; the agent is trained on
        # -1.0 + info['visib'] instead.
        reward = -1.0 + info['visib']
        ep_rewards += reward

        if done:
            # Terminal transitions are stored with next_state=None.
            next_state = None
        else:
            next_state = np.array(next_state, dtype=np.float32)
            next_state = torch.tensor(next_state, dtype=torch.float32, device=device)
            next_state = next_state.unsqueeze(0)

        # Explicit dtype so the stored reward matches the float network output
        # whatever numeric type info['visib'] has.
        reward = torch.tensor([reward], device=device, dtype=torch.float32)
        memory.push(state, action, next_state, reward)

        steps_done += 1
        state = next_state

        if (steps_done > STEPS_BEFORE_TRAIN) and steps_done % OPTIMIZE_MODEL_STEP == 0:
            optimize_model()

        if steps_done % TARGET_UPDATE == 0:
            print("Target net updated!")
            target_net.load_state_dict(policy_net.state_dict())

        if done:
            train_rewards.append(np.sum(ep_rewards))
            plot_rewards()
            break

    # Every 100 episodes, average 10 evaluation episodes and report the result.
    if e % 100 == 0:
        policy_net.eval()
        total = 0
        val = 0
        visib_end = 0.
        for _ in range(10):
            res0, res1, final_visib = test()
            val += res0
            total += res1
            visib_end += final_visib
        policy_net.train()
        print('---- steps_done {}  ---- Test_score {} ----- Number of steps needed {} --- final_visib = {}'.format(steps_done,total/10.,val/10.,visib_end/10.))

In [None]:
# Write the policy network's weights (state_dict) to the file 'policy_net' in the current working directory.
torch.save(policy_net.state_dict(),'policy_net')

In [None]:
# Sends SIGKILL to every process this user is allowed to signal (pid -1), so all in-memory state
# (env, networks, replay memory) is gone afterwards.
# NOTE(review): the cells after this one use `env`, `policy_net` and `select_action`; on a fresh kernel
# they need the definition cells re-run and, presumably, the saved 'policy_net' weights loaded — confirm.
!kill -9 -1

In [None]:
# Evaluate the greedy policy (TEST_EPS = 0.0) over 100 episodes and record,
# for each episode, the 0-based index of the step on which it finished.
# Previously tried value: TEST_EPS = 0.005

dist_all = []
action_all = []
visib_all = []
steps = []
TEST_EPS = 0.0
# NOTE(review): if `env` is a gym wrapper these assignments set attributes on
# the wrapper object only — confirm they reach the underlying environment.
env.reset_actions = 1000
env.max_steps = 200
for _ in range(100):
    state = env.reset()
    total_reward = 0
    for i in count():
        state = np.array(state, dtype=np.float32)
        state = torch.tensor(state, dtype=torch.float32, device=device)
        state = state.unsqueeze(0)
        action = select_action(state, TEST_EPS)
        # Pass a plain int to the environment, as the training loop does.
        state, _, done, info = env.step(action.item())
        reward = -1.0 + info['visib']
        total_reward += reward
        # Optional per-step diagnostics:
        # dist_all.append(info['dist'])
        # visib_all.append(info['visib'])
        # action_all.append(action)
        if done:
            steps.append(i)
            break
    

In [None]:
# Convert the recorded final step indices to an array and show how many of them are below 100.
steps = np.array(steps)
steps[steps<100].shape

In [None]:
# Show the recorded final step indices in ascending order.
np.sort(steps)

In [None]:
# Start from empty lists again; the single-episode evaluation cell appends to both of them.
steps = []
dist_all = []


In [None]:
# Roll out one greedy episode (TEST_EPS = 0.0), keeping the per-step
# visibility and action. The first step's info['dist'] and the final step
# index are appended to the persistent `dist_all` / `steps` lists, so
# re-running this cell accumulates one point per run.
# Previously tried value: TEST_EPS = 0.005

action_all = []
visib_all = []
TEST_EPS = 0.0
# NOTE(review): if `env` is a gym wrapper these assignments set attributes on
# the wrapper object only — confirm they reach the underlying environment.
env.reset_actions = 5000
env.max_steps = 200
state = env.reset()
total_reward = 0
for i in count():
    state = np.array(state, dtype=np.float32)
    state = torch.tensor(state, dtype=torch.float32, device=device)
    state = state.unsqueeze(0)
    action = select_action(state, TEST_EPS)
    # Pass a plain int to the environment, as the training loop does.
    state, _, done, info = env.step(action.item())
    reward = -1.0 + info['visib']
    total_reward += reward
    if i == 0:
        dist_all.append(info['dist'])
    visib_all.append(info['visib'])
    action_all.append(action)
    if done:
        steps.append(i)
        break

In [None]:
# Final step index of each recorded episode (x) against its first-step info['dist'] value (y).
plt.scatter(steps,dist_all)

In [None]:
# First-step info['dist'] values in the order the episodes were recorded.
plt.plot(dist_all)

In [None]:
# info['visib'] at every step of the most recently recorded episode.
plt.plot(visib_all)

In [None]:
# info['visib'] at the last step of the most recently recorded episode.
visib_all[-1]