In [1]:
import torch, torchvision
import torch.nn.functional as F
import random, math
import numpy as np
import torch.nn as nn
import torchvision.transforms as T
import torch.optim as optim
from itertools import count
from collections import namedtuple

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

In [55]:
class DQN(nn.Module):
	def __init__(self):
		super(DQN, self).__init__()
		self.conv1 = nn.Conv2d(1, 32, kernel_size=5, stride=2)
		self.bn1 = nn.BatchNorm2d(32)
		self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=2)
		self.bn2 = nn.BatchNorm2d(64)
		self.conv3 = nn.Conv2d(64, 32, kernel_size=5, stride=2)
		self.bn3 = nn.BatchNorm2d(32)
		self.head = nn.Linear(3200, 3) # ???
		# out : Left, Straight, Right

	def forward(self, x):
		x = F.relu(self.bn1(self.conv1(x)))
		x = F.relu(self.bn2(self.conv2(x)))
		x = F.relu(self.bn3(self.conv3(x)))
		# return x
		x = x.view(x.size(0), -1)
		print(x.shape)
		return self.head(x)
    

BATCH_SIZE = 128
GAMMA = 0.999
EPS_START = 0.9
EPS_DECAY = 200
EPS_END = 0.05
TARGET_UPDATE = 10

policy_net = DQN().to(device)
target_net = DQN().to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.RMSprop(policy_net.parameters())

In [56]:
states = torch.zeros([10, 1, 102, 101], dtype=torch.float32).random_(0, 4)
# state = state.view((1, 32*32))
policy_net(states)


torch.Size([10, 3200])


tensor([[ 0.0520, -0.4447,  0.1911],
        [ 0.2817, -0.1657,  0.0655],
        [-0.3039, -0.3061,  0.0703],
        [ 0.0751, -0.0557, -0.3605],
        [-0.0781, -0.0136,  0.4026],
        [-0.0499, -0.2134,  0.3827],
        [-0.2112, -0.0813, -0.0105],
        [ 0.1924, -0.4436,  0.0512],
        [-0.2738, -0.2689, -0.1009],
        [ 0.0495, -0.3047,  0.3029]])

In [5]:
steps_done = 0
def select_action(state):
	global steps_done
	sample = random.random()
	eps_threshold = EPS_END + (EPS_START - EPS_END) * \
		math.exp(-1. * steps_done / EPS_DECAY)
	steps_done += 1
	if sample > eps_threshold:
		with torch.no_grad():
			return policy_net(state).max(1)[1].view(1, 1)
	else:
		return torch.tensor([[random.randrange(3)]], device=device)

def plain_select_action(states):
    with torch.no_grad():
        return policy_net(states.view(-1, 1, 102, 101)).argmax(dim=1)

plain_select_action(states)

def countArea(state, id):
    return state.eq(torch.zeros(state.shape, dtype=torch.float32, device=device) + id).sum(dtype=torch.float32).to(device)

def reward(state, action, id):
    # next_state = get_next_state(state, action)
    next_state = state
    # opponent_done = check_done(1 - id)
    opponent_done = False
    # self_died = check_done(id)
    self_died = False
    return (countArea(state, id) / 10000 + torch.tensor(5. if opponent_done else 0.) + torch.tensor(-5. if self_died else 0.)).to(device)

In [6]:
countArea(states[0], 0)
reward(states[0], (), 2)

tensor(0.2506)

In [7]:
torch.cat([s for s in states]).shape

torch.Size([10, 102, 101])

In [8]:
action_val_batch = policy_net(states)
action_batch = action_val_batch.argmax(1).view(-1,1)
# action_batch

In [9]:
state_action_values = policy_net(states).gather(1, action_batch)

In [10]:
next_states = torch.zeros([10, 1, 102, 101], dtype=torch.float32).random_(0, 4)
# state = state.view((1, 32*32))
policy_net(next_states)

tensor([[-0.0556,  0.1022,  0.2268],
        [-0.4877, -0.0440,  0.1352],
        [-0.7305, -0.1994,  0.0613],
        [-0.3419,  0.1413, -0.1669],
        [-0.1917,  0.3157, -0.1461],
        [-0.5572, -0.0786,  0.1484],
        [-0.1007,  0.1157, -0.0268],
        [-0.2318,  0.0986, -0.1993],
        [-0.7225,  0.5529,  0.0708],
        [-0.3112,  0.0133, -0.0953]])

In [11]:
reward_batch=torch.zeros([10,1],dtype=torch.float32).random_(0,4)
# reward_batch

In [12]:
next_state_vals = target_net(next_states).max(1)[0].detach().view(10,1)
# next_state_expected_val
expected_next_state_vals = reward_batch + next_state_vals*GAMMA # (10,1)
# print(state_action_values.shape)
print(expected_next_state_vals.shape)
loss=F.smooth_l1_loss(state_action_values, expected_next_state_vals)

torch.Size([10, 1])


In [13]:
optimizer.zero_grad()
loss.backward()

In [21]:
for param in policy_net.parameters():
    param.grad.data.clamp(-1,1)
optimizer.step()

AttributeError: 'NoneType' object has no attribute 'data'

In [23]:
d=torch.tensor(-3).abs()
d

tensor(3)

In [29]:
d=torch.zeros([4,5,6],dtype=torch.float32).random_(0,4)
d

tensor([[[ 0.,  2.,  1.,  2.,  0.,  1.],
         [ 3.,  2.,  2.,  0.,  2.,  2.],
         [ 0.,  0.,  3.,  1.,  2.,  1.],
         [ 0.,  3.,  0.,  2.,  2.,  1.],
         [ 1.,  3.,  0.,  2.,  0.,  3.]],

        [[ 3.,  1.,  2.,  2.,  0.,  2.],
         [ 1.,  2.,  0.,  1.,  2.,  3.],
         [ 2.,  2.,  0.,  1.,  3.,  0.],
         [ 1.,  3.,  2.,  1.,  3.,  1.],
         [ 1.,  2.,  3.,  2.,  2.,  3.]],

        [[ 2.,  0.,  0.,  1.,  3.,  2.],
         [ 3.,  2.,  0.,  2.,  0.,  0.],
         [ 3.,  2.,  1.,  1.,  0.,  1.],
         [ 2.,  0.,  3.,  3.,  2.,  0.],
         [ 1.,  2.,  1.,  2.,  1.,  2.]],

        [[ 2.,  3.,  3.,  1.,  3.,  2.],
         [ 0.,  0.,  3.,  2.,  0.,  0.],
         [ 1.,  2.,  2.,  2.,  3.,  2.],
         [ 2.,  2.,  0.,  0.,  3.,  2.],
         [ 1.,  2.,  1.,  0.,  0.,  0.]]])

In [30]:
d.eq(2)

tensor([[[ 0,  1,  0,  1,  0,  0],
         [ 0,  1,  1,  0,  1,  1],
         [ 0,  0,  0,  0,  1,  0],
         [ 0,  0,  0,  1,  1,  0],
         [ 0,  0,  0,  1,  0,  0]],

        [[ 0,  0,  1,  1,  0,  1],
         [ 0,  1,  0,  0,  1,  0],
         [ 1,  1,  0,  0,  0,  0],
         [ 0,  0,  1,  0,  0,  0],
         [ 0,  1,  0,  1,  1,  0]],

        [[ 1,  0,  0,  0,  0,  1],
         [ 0,  1,  0,  1,  0,  0],
         [ 0,  1,  0,  0,  0,  0],
         [ 1,  0,  0,  0,  1,  0],
         [ 0,  1,  0,  1,  0,  1]],

        [[ 1,  0,  0,  0,  0,  1],
         [ 0,  0,  0,  1,  0,  0],
         [ 0,  1,  1,  1,  0,  1],
         [ 1,  1,  0,  0,  0,  1],
         [ 0,  1,  0,  0,  0,  0]]], dtype=torch.uint8)

In [31]:
a = torch.tensor([[3]])

In [39]:
int(a[0][0].data)

3

In [40]:
int(a)

3

In [47]:
torch.cat([torch.tensor([[3]]) for _ in range(3)]).unsqueeze(1).shape

torch.Size([3, 1, 1])

In [50]:
policy_net(states).argmax(1).view(-1, 1, 1)

tensor([[[ 0]],

        [[ 0]],

        [[ 2]],

        [[ 1]],

        [[ 0]],

        [[ 1]],

        [[ 0]],

        [[ 2]],

        [[ 0]],

        [[ 2]]])

In [51]:
d.sum()

tensor(178.)