
What is the best way to reshape the observation before feeding it to the network? #31

Open
LuisFMCuriel opened this issue Feb 3, 2022 · 1 comment

@LuisFMCuriel
Contributor

Hi Dawid,

I am trying to put together an example for Sneks. However, the default network (FcNet) only accepts a 1-dimensional input shape, while the game's observation shape is 16x16x3. I want to flatten the input before feeding the observation to the network, but I also want to keep the spatial information of the state. Therefore, I am using ConvNet with a Flatten layer before feeding the result to FcNet.

import torch.nn as nn
# ConvNet, FcNet and NetChainer come from ai-traineree's networks module.

def network_fn(state_dim, output_dim, device):
    conv_net = ConvNet(state_dim, hidden_layers=(10, 10), device=device)
    return NetChainer(
        net_classes=[
            conv_net,
            nn.Flatten(),
            FcNet(conv_net.output_size, output_dim, hidden_layers=(100, 100, 50), device=device),
        ]
    )

First, I want to ask: is it necessary to reshape the observation from 16x16x3 to 1x3x16x16 every time I compute the logits? If it is, what is the best way to do this in ai-traineree? Creating another Net for it, like this?

import torch
# NetworkType is assumed to be imported from ai-traineree's networks module.

class ReshapeNet(NetworkType):
    def __init__(self, shape) -> None:
        super(ReshapeNet, self).__init__()
        self.shape = shape

    def forward(self, x):
        return torch.reshape(x, self.shape)
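As a side note on this sketch (an editorial illustration, not part of the original thread): a raw reshape from (16, 16, 3) to (1, 3, 16, 16) keeps the element count but scrambles which values land in which channel; moving axes with permute preserves the pixel-to-channel mapping. The helper name to_nchw below is hypothetical:

```python
import torch

# Hypothetical helper: convert a single HWC observation (16, 16, 3) into a
# batched CHW tensor (1, 3, 16, 16), the layout nn.Conv2d expects.
# permute moves axes without mixing values; torch.reshape to the same
# target shape would interleave pixels across channels instead.
def to_nchw(obs: torch.Tensor) -> torch.Tensor:
    return obs.permute(2, 0, 1).unsqueeze(0)

obs = torch.arange(16 * 16 * 3, dtype=torch.float32).reshape(16, 16, 3)
batched = to_nchw(obs)
assert batched.shape == (1, 3, 16, 16)
# Pixel (row=5, col=7) keeps its channel-0 value at the same spatial spot:
assert batched[0, 0, 5, 7] == obs[5, 7, 0]
```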

Thanks.

@laszukdawid
Owner

First of all, I haven't checked that the example works in a long, long time, and as of now non-1D input isn't supported by AI Traineree. But this is likely the best time to officially add that support.

Is it necessary to reshape the observation from 16x16x3 to 1x3x16x16 every time I will compute the logits?

Not in general, but in the case of this Snek environment, yes. The reason for transforming is precisely to preserve spatial information. The 3 corresponds to the RGB colors and 16x16 is the spatial grid. It's common practice to treat colors as channels and perform the convolution over the spatial dimensions. Actually, the ConvNet (code) is just a wrapper around PyTorch's Conv2d.
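To illustrate the channels-first convention described above (a minimal sketch, not code from the thread; the layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

# nn.Conv2d expects (N, C, H, W): the 3 RGB values act as input channels,
# while the 16x16 grid carries the spatial information the kernel scans.
conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3, padding=1)

x = torch.randn(1, 3, 16, 16)  # one observation, channels first
y = conv(x)
assert y.shape == (1, 10, 16, 16)  # padding=1 keeps the 16x16 spatial size
```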

I'm not really sure what your intention is for reshaping (back and forth?), but PyTorch modules should compose within the network. So, just as there is nn.Flatten(), there's also nn.Unflatten() (PyTorch Unflatten). Can you describe what you're trying to do? What's the intended architecture?
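The Flatten/Unflatten pairing mentioned above can be sketched as follows (an editorial illustration with arbitrary shapes, not code from the thread):

```python
import torch
import torch.nn as nn

# nn.Flatten collapses (C, H, W) into one feature vector per sample,
# ready for an FC layer; nn.Unflatten inverts it when the spatial
# shape needs to be restored further down the chain.
flatten = nn.Flatten()                  # keeps the batch dim by default
unflatten = nn.Unflatten(1, (3, 16, 16))  # rebuild dim 1 into (C, H, W)

x = torch.randn(4, 3, 16, 16)
flat = flatten(x)
assert flat.shape == (4, 3 * 16 * 16)
assert torch.equal(unflatten(flat), x)  # the round trip is lossless
```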

I'm honestly trying to avoid adding new networks unless they're necessary. I think that, in the long run, it's easier and more beneficial to make all networks compatible with PyTorch networks. There are plenty of people working on PyTorch and there's only me on AI Traineree; trying to catch up would fail without any gain.
