In [1]:
import torch
import torch.nn as nn

In [10]:
class DQN(nn.Module):
    """
    Deep Q Network to model the Q function.
    """
    def __init__(self, input_dim, output_dim):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        #assume 4*7 in dim output
        x = x.view(-1,4,7)
        return x

In [4]:
%pip install torchsummary

Collecting torchsummary
  Using cached torchsummary-1.5.1-py3-none-any.whl.metadata (296 bytes)
Using cached torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.3.2 -> 25.0.1
[notice] To update, run: C:\Users\etien\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip


In [11]:
import numpy as np
from torchsummary import summary
model = DQN(4*42,7*4)
summary(model, (64, 4*42))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1              [-1, 64, 128]          21,632
            Linear-2              [-1, 64, 128]          16,512
            Linear-3               [-1, 64, 28]           3,612
Total params: 41,756
Trainable params: 41,756
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.04
Forward/backward pass size (MB): 0.14
Params size (MB): 0.16
Estimated Total Size (MB): 0.34
----------------------------------------------------------------


In [12]:
input = torch.randn(64, 4*42)
output = model(input).view(-1, 4, 7)

print(output.size())

torch.Size([64, 4, 7])


In [14]:
q_values = torch.softmax(model(input),dim=-1)
print(q_values.size())
actions = q_values.argmax(dim=-1)
actions

torch.Size([64, 4, 7])


tensor([[6, 0, 3, 4],
        [1, 0, 1, 6],
        [1, 0, 2, 6],
        [6, 0, 3, 0],
        [1, 0, 3, 0],
        [1, 1, 3, 4],
        [1, 0, 3, 4],
        [1, 4, 2, 0],
        [6, 5, 1, 0],
        [6, 0, 1, 4],
        [4, 5, 1, 1],
        [1, 4, 2, 0],
        [1, 4, 1, 4],
        [1, 4, 3, 4],
        [3, 0, 2, 1],
        [1, 4, 3, 4],
        [1, 2, 3, 2],
        [2, 0, 2, 6],
        [4, 0, 3, 2],
        [1, 4, 3, 2],
        [3, 0, 2, 6],
        [1, 4, 3, 1],
        [1, 5, 1, 0],
        [1, 2, 2, 4],
        [1, 0, 2, 4],
        [3, 0, 2, 1],
        [1, 2, 2, 4],
        [1, 0, 3, 4],
        [1, 2, 3, 0],
        [6, 0, 3, 1],
        [1, 0, 3, 1],
        [1, 0, 3, 4],
        [1, 4, 3, 4],
        [1, 4, 3, 4],
        [6, 0, 3, 0],
        [1, 0, 2, 1],
        [6, 1, 3, 0],
        [3, 0, 2, 0],
        [1, 0, 1, 1],
        [1, 5, 3, 3],
        [6, 4, 1, 4],
        [3, 2, 3, 0],
        [6, 0, 3, 4],
        [2, 4, 3, 4],
        [1, 4, 3, 0],
        [1