In [1]:
import torch, torchvision
import torch.nn.functional as F
import random, math
import numpy as np
import torch.nn as nn
import torchvision.transforms as T
import torch.optim as optim
from itertools import count
from collections import namedtuple

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Record type with four named fields. Presumably one replay-memory entry;
# it is not used in the cells visible here -- TODO confirm.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)

In [200]:
class DQN(nn.Module):
	"""Convolutional Q-network mapping a single-channel image to one value per action.

	Three 5x5, stride-2 convolutions (each followed by batch norm and ReLU)
	feed a single linear head.

	Args:
		h: Height of the input image. Defaults to 102.
		w: Width of the input image. Defaults to 101.
		outputs: Number of action values produced by the head. Defaults to 3
			(Left, Straight, Right).

	Raises:
		ValueError: If ``h`` or ``w`` is too small to survive the three
			convolutions.
	"""

	def __init__(self, h=102, w=101, outputs=3):
		super(DQN, self).__init__()
		self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=2)
		self.bn1 = nn.BatchNorm2d(16)
		self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
		self.bn2 = nn.BatchNorm2d(32)
		self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
		self.bn3 = nn.BatchNorm2d(32)

		def conv_out(size, kernel_size=5, stride=2):
			# Output length of an unpadded, undilated convolution along one axis.
			return (size - (kernel_size - 1) - 1) // stride + 1

		conv_h = conv_out(conv_out(conv_out(h)))
		conv_w = conv_out(conv_out(conv_out(w)))
		if conv_h < 1 or conv_w < 1:
			raise ValueError(
				"input of size %dx%d is too small for three 5x5 stride-2 convolutions" % (h, w)
			)
		# With the defaults this is 10 * 10 * 32 = 3200 flattened features.
		self.head = nn.Linear(conv_h * conv_w * 32, outputs)

	def forward(self, x):
		"""Return the action values for a batch ``x`` of shape (N, 1, h, w)."""
		x = F.relu(self.bn1(self.conv1(x)))
		x = F.relu(self.bn2(self.conv2(x)))
		x = F.relu(self.bn3(self.conv3(x)))
		# Flatten everything except the batch dimension for the linear head.
		x = x.view(x.size(0), -1)
		return self.head(x)
    

# --- Hyperparameters ---------------------------------------------------
# BATCH_SIZE and TARGET_UPDATE are not read by any cell visible here.
BATCH_SIZE = 128
TARGET_UPDATE = 10
# Multiplier applied to the next-state value when the TD target is built.
GAMMA = 0.999
# Exploration schedule read by select_action: the threshold starts at
# EPS_START and decays exponentially towards EPS_END at rate 1 / EPS_DECAY.
EPS_START, EPS_END, EPS_DECAY = 0.9, 0.05, 200

# --- Networks ----------------------------------------------------------
# Two networks of the same architecture; target_net starts as an exact copy
# of policy_net and is put in evaluation mode.
policy_net = DQN().to(device)
target_net = DQN().to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Only policy_net's parameters are handed to the optimizer.
optimizer = optim.RMSprop(policy_net.parameters())

In [145]:
# Random stand-in batch: 10 single-channel 102x101 grids whose cells hold the
# integers 0..3 stored as float32. It is allocated on `device` because
# policy_net was moved there; a CPU batch makes the forward pass fail on a
# CUDA machine.
states = torch.zeros([10, 1, 102, 101], dtype=torch.float32, device=device).random_(0, 4)
policy_net(states)


tensor([[ 0.2601, -0.0231,  0.2603],
        [ 0.0475,  0.0407,  0.3560],
        [ 0.2654,  0.0572,  0.1607],
        [ 0.1011, -0.0887,  0.1221],
        [ 0.2362,  0.1347,  0.0565],
        [ 0.0178, -0.0747,  0.1702],
        [-0.0426, -0.1332,  0.0722],
        [-0.3655, -0.0852,  0.1595],
        [-0.0318, -0.1807,  0.0860],
        [ 0.0242,  0.0762, -0.1082]])

In [167]:
# Number of select_action calls so far; drives the exploration decay.
steps_done = 0
def select_action(state):
	"""Pick an action for ``state`` with an epsilon-greedy rule.

	The exploration threshold decays exponentially from EPS_START towards
	EPS_END as ``steps_done`` grows; every call increments ``steps_done``.

	Returns:
		A (1, 1) tensor holding the chosen action index: either the index of
		the largest policy_net output or a uniformly random index in 0..2.
	"""
	global steps_done
	draw = random.random()
	decay = math.exp(-1. * steps_done / EPS_DECAY)
	eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
	steps_done += 1

	# Explore: uniformly random action.
	if draw <= eps_threshold:
		return torch.tensor([[random.randrange(3)]], device=device)

	# Exploit: index of the largest action value, without tracking gradients.
	with torch.no_grad():
		_, best_index = policy_net(state).max(1)
		return best_index.view(1, 1)

def plain_select_action(states):
    """Return the greedy action index for every state in ``states``.

    ``states`` is viewed as (-1, 1, 102, 101) before being passed through
    policy_net; no exploration is applied and no gradients are recorded.
    """
    with torch.no_grad():
        batch = states.view(-1, 1, 102, 101)
        action_values = policy_net(batch)
        return action_values.argmax(dim=1)

# Smoke call only: this is not the last expression of the cell, so the
# returned action indices are discarded and nothing is displayed.
plain_select_action(states)

def countArea(state, id):
    """Count how many entries of ``state`` are equal to ``id``.

    The comparison is made against the scalar directly, so no temporary
    tensor the size of ``state`` is built and ``state`` does not have to live
    on the global ``device`` already.

    Args:
        state: Tensor to inspect.
        id: Value to look for (the name shadows the builtin but is kept for
            callers).

    Returns:
        A 0-dim float32 tensor on ``device`` holding the count.
    """
    return state.eq(id).sum(dtype=torch.float32).to(device)

def reward(state, action, id):
    """Score ``state`` for player ``id``.

    The score is the number of cells equal to ``id`` divided by 10000, plus
    5 if the opponent is done and minus 5 if the player itself died. Both
    flags are hard-coded to False for now, and ``action`` is not read.

    Returns:
        A 0-dim tensor on ``device``.
    """
    # Placeholders until the environment hooks exist:
    #   next_state    = get_next_state(state, action)
    #   opponent_done = check_done(1 - id)
    #   self_died     = check_done(id)
    opponent_done = False
    self_died = False

    if opponent_done:
        kill_bonus = 5.
    else:
        kill_bonus = 0.
    if self_died:
        death_penalty = -5.
    else:
        death_penalty = 0.

    area_score = countArea(state, id) / 10000
    total = area_score + kill_bonus + death_penalty
    return total.to(device)

In [168]:
# Not the last expression of the cell, so this count is computed but not shown.
countArea(states[0], 0)
# Displayed result; the empty tuple is a dummy action (reward() does not read it).
reward(states[0], (), 2)

tensor(0.2650)

In [173]:
# Concatenating the per-sample (1, 102, 101) slices along dim 0 drops the
# channel axis of the batch.
torch.cat(tuple(states)).shape

torch.Size([10, 102, 101])

In [186]:
# Action values for the whole batch, then the greedy action of each row as a
# column vector so it can be used as a gather index.
action_val_batch = policy_net(states)
action_batch = action_val_batch.argmax(dim=1).view(-1, 1)

In [205]:
# Value of the action selected in action_batch for every state; shape (batch, 1).
state_action_values = (
    policy_net(states)
    .gather(1, action_batch)
)

In [206]:
# Random stand-in for the successor states, same layout as `states`. It is
# allocated on `device` so it can be fed to networks that were moved there.
next_states = torch.zeros([10, 1, 102, 101], dtype=torch.float32, device=device).random_(0, 4)
policy_net(next_states)

tensor([[ 0.2153,  0.0621,  0.1963],
        [ 0.3528,  0.3629, -0.1931],
        [ 0.4483,  0.4099, -0.0482],
        [ 0.0795, -0.0169, -0.2965],
        [ 0.1319,  0.0505,  0.0290],
        [ 0.5002, -0.0779,  0.0654],
        [-0.1783,  0.0192, -0.0485],
        [-0.1900, -0.1327, -0.1267],
        [ 0.1933,  0.0698,  0.1739],
        [ 0.0375,  0.0214, -0.0948]])

In [207]:
# Random stand-in rewards (integers 0..3 as float32), one per sample. Created
# on `device` because it is later added to the target network's output.
reward_batch = torch.zeros([10, 1], dtype=torch.float32, device=device).random_(0, 4)

In [217]:
# Largest target-network value of each successor state, detached so no
# gradient flows into target_net. view(-1, 1) keeps it a column vector for
# any batch size instead of assuming exactly 10 samples.
next_state_vals = target_net(next_states).max(1)[0].detach().view(-1, 1)
# Regression target: reward plus the discounted next-state value; shape (batch, 1).
expected_next_state_vals = reward_batch + next_state_vals * GAMMA
print(expected_next_state_vals.shape)
# Smooth L1 loss between the selected action values and the target.
loss = F.smooth_l1_loss(state_action_values, expected_next_state_vals)

torch.Size([10, 1])


In [218]:
# Clear the gradients held by the optimizer's parameters before
# back-propagating the loss.
optimizer.zero_grad()
loss.backward()

In [219]:
# Clip every gradient element to [-1, 1] before the update. The in-place
# clamp_ is required: clamp() returns a new tensor and leaves the stored
# gradient unchanged.
for param in policy_net.parameters():
    if param.grad is not None:  # parameters that received no gradient are skipped
        param.grad.data.clamp_(-1, 1)
optimizer.step()

In [222]:
# List the shape of every learnable tensor in policy_net, one per line.
for param in policy_net.parameters():
    print(param.shape)

torch.Size([16, 1, 5, 5])
torch.Size([16])
torch.Size([16])
torch.Size([16])
torch.Size([32, 16, 5, 5])
torch.Size([32])
torch.Size([32])
torch.Size([32])
torch.Size([32, 32, 5, 5])
torch.Size([32])
torch.Size([32])
torch.Size([32])
torch.Size([3, 3200])
torch.Size([3])
