diff --git a/pytorch_rl/model.py b/pytorch_rl/model.py index 7d950147..19c78738 100755 --- a/pytorch_rl/model.py +++ b/pytorch_rl/model.py @@ -50,7 +50,7 @@ def __init__(self, num_inputs, action_space, use_gru): self.conv4 = nn.Conv2d(32, 32, 4, stride=1) self.linear1_drop = nn.Dropout(p=0.5) - self.linear1 = nn.Linear(32 * 9 * 14, 256) + self.linear1 = nn.Linear(32 * 74 * 54, 256) if use_gru: self.gru = nn.GRUCell(512, 512) @@ -111,8 +111,7 @@ def forward(self, inputs, states, masks): x = self.conv4(x) x = F.leaky_relu(x) - - x = x.view(-1, 32 * 9 * 14) + x = x.view(-1, 32 * 74 * 54) x = self.linear1_drop(x) x = self.linear1(x) x = F.leaky_relu(x)