## Masking in PyTorch

Minimal example working out how to (1) pick out the Q-values of the actions actually taken and (2) mask out terminal transitions when computing DQN targets.

In [1]:
import torch

In [164]:
# Seed the global RNG so the randomly initialised weights (and therefore every
# Q-value printed in this notebook) are reproducible on Restart & Run All.
torch.manual_seed(0)

# A tiny Q-network: each observation is a vector of 2 features, and there are
# 2 discrete actions, so the network outputs one Q-value per action.
model = torch.nn.Sequential(
    torch.nn.Linear(2, 20),
    torch.nn.ReLU(),
    torch.nn.Linear(20, 2),
)

# A batch of three observations, one per row.
observations = torch.tensor([[1, 3], [2, 4], [-5, 1]], dtype=torch.float)
observations

tensor([[ 1.,  3.],
        [ 2.,  4.],
        [-5.,  1.]])

In [165]:
# Predicted Q-values for the whole batch: one row per observation and one
# column per action (the model's last layer is Linear(20, 2)).
model(observations)

tensor([[-0.1829,  0.4372],
        [-0.2390,  0.6570],
        [-0.0767,  0.4764]], grad_fn=<AddmmBackward>)

In [166]:
# Index of the action taken for each observation. It is kept as a column,
# shaped (batch, 1), so it can be used directly as the `index` of a gather
# along dim 1.
action_indices = torch.tensor([0, 1, 0], dtype=torch.long).unsqueeze(1)
action_indices

tensor([[0],
        [1],
        [0]])

In [167]:
# For every row i, keep only the Q-value in column action_indices[i, 0], i.e.
# the value of the action that was actually chosen. The result has the same
# (batch, 1) shape as action_indices.
torch.gather(model(observations), dim=1, index=action_indices)

tensor([[-0.1829],
        [ 0.6570],
        [-0.0767]], grad_fn=<GatherBackward>)

## Fitting values with a terminal-state mask

In [190]:
# Transition data for the three observations above: the reward received, the
# observation that followed, and a 0/1 flag that is 1 for terminal transitions.
rewards = torch.tensor([1, 1, 0], dtype=torch.float32)
next_observations = torch.tensor(
    [[1, 1],
     [2, 0],
     [-5, 1]],
    dtype=torch.float32,
)
terminal = torch.tensor([0, 0, 1], dtype=torch.uint8)

In [169]:
# Derive the batch size from the data instead of hard-coding it, so it stays in
# sync with `rewards` / `next_observations` if the example batch is edited.
batch_size = len(rewards)
# One placeholder next-state value per transition, initialised to zero.
next_state_values = torch.zeros(batch_size)
next_state_values

tensor([0., 0., 0.])

In [170]:
# Q-values for the successor observations: one row per transition, one column per action.
model(next_observations)

tensor([[-0.3207,  0.1975],
        [-0.2168,  0.3437],
        [-0.0767,  0.4764]], grad_fn=<AddmmBackward>)

In [171]:
# Best-action Q-value for each NON-terminal successor state.
# `terminal == 0` gives an explicit torch.bool mask whatever the integer dtype
# of `terminal`. Using `1 - terminal` as an index only behaves as a mask while
# `terminal` is uint8 (a deprecated mask dtype). For an int64 flag tensor it
# would silently select rows 1, 1, 0 instead.
# max(dim=1).values is the per-row maximum; detach() keeps the bootstrap
# value out of the autograd graph.
model(next_observations[terminal == 0]).max(dim=1).values.detach()

tensor([0.1975, 0.3437])

In [191]:
# Next-state values for the TD target: max_a Q(s', a) for non-terminal
# transitions. Terminal entries are never written, so they stay 0.
non_terminal_mask = terminal == 0  # explicit torch.bool mask, independent of terminal's integer dtype
next_state_values = torch.zeros(batch_size)
next_state_values[non_terminal_mask] = (
    model(next_observations[non_terminal_mask]).max(dim=1).values.detach()
)
next_state_values

tensor([0.1975, 0.3437, 0.0000])

In [173]:
# Discount factor applied to the next-state value in the TD target below.
gamma = 0.99

In [174]:
# One-step TD target: r + gamma * max_a Q(s', a). Terminal transitions end up
# with just r because their next_state_values entry was left at 0.
expected_state_action_values = rewards + gamma * next_state_values
expected_state_action_values

tensor([1.1955, 1.3403, 0.0000])

In [175]:
# Raw rewards, for comparison with the TD targets above.
rewards

tensor([1., 1., 0.])

In [176]:
# Mean-squared error between predicted and target Q-values. "mean" is the
# default reduction and is only spelled out here for clarity.
loss_fn = torch.nn.MSELoss(reduction="mean")

In [177]:
# Re-display the chosen action per transition (a (batch, 1) column of indices).
action_indices

tensor([[0],
        [1],
        [0]])

In [178]:
# The targets are already 1-D (one value per transition), so squeeze() changes
# nothing here. It is the (batch, 1) prediction that needs its last axis dropped.
expected_state_action_values.squeeze()

tensor([1.1955, 1.3403, 0.0000])

In [179]:
# Q-values of the actions actually taken, one per row. The result is shaped
# (batch, 1), like action_indices.
predicted_state_action_values = torch.gather(model(observations), 1, action_indices)
predicted_state_action_values

tensor([[-0.1829],
        [ 0.6570],
        [-0.0767]], grad_fn=<GatherBackward>)

In [180]:
# MSELoss takes (input, target): the prediction goes first and the fixed TD
# target second. squeeze(1) drops only the trailing action axis, so a batch of
# size 1 stays shape (1,) instead of collapsing to a 0-d tensor like a bare
# squeeze() would.
loss_fn(predicted_state_action_values.squeeze(1), expected_state_action_values)

tensor(0.7909, grad_fn=<MeanBackward0>)

## Mapping a minibatch (list of transition tuples) to one tuple per field

In [185]:
# Two transitions, each laid out as
# (observation, action, reward, next_observation, terminal_flag).
minibatch = [
    ([0.2, 0.1], 0, 1, [0.3, 0.4], 0),
    ([0.4, 0.9], 1, 1, [0.5, 0.0], 0),
]

In [186]:
# Transpose the list of per-transition tuples into one tuple per field.
# zip() already yields tuples, so tuple() can consume it directly.
tuple(zip(*minibatch))

(([0.2, 0.1], [0.4, 0.9]), (0, 1), (1, 1), ([0.3, 0.4], [0.5, 0.0]), (0, 0))

In [187]:
# Unpack the transposed minibatch straight from zip() into one tuple per field.
# Note: this rebinds `observations`, `rewards` and `next_observations` from the
# tensors used earlier in the notebook to plain Python tuples.
observations, actions, rewards, next_observations, terminal_indicators = zip(*minibatch)

In [188]:
# Convert the observation tuples into a float tensor, one row per transition.
# torch.as_tensor makes this cell idempotent. If `observations` is already a
# float tensor (e.g. the cell is re-run), it is returned as-is rather than
# copied. torch.tensor would copy it and warn.
observations = torch.as_tensor(observations, dtype=torch.float)
observations

tensor([[0.2000, 0.1000],
        [0.4000, 0.9000]])

In [162]:
# After unzipping, actions is a tuple of Python ints, not a tensor. It still
# needs converting (and a trailing axis) before it can index a gather.
actions

(0, 1)

In [3]:
# Small int64 tensor [1, 2, 3] used to try out adding an axis.
test_tensor = torch.arange(1, 4)
test_tensor

tensor([1, 2, 3])

In [6]:
# Add a trailing axis: shape (3,) -> (3, 1), the column layout used above for gather's index.
test_tensor[:, None]

tensor([[1],
        [2],
        [3]])