
Add input_tensor input type #2951

Merged 12 commits into davisking:master on May 12, 2024
Conversation

@kSkip (Contributor) commented on Apr 26, 2024

These changes add a new input type that enables feeding networks with batches of data that already reside in device memory.

For context, I am developing a robot simulator for batch reinforcement learning. A deep Q-network receives inputs generated by an OpenGL pipeline that renders a camera-view representation of the world to a texture. That texture can be read via CUDA graphics interoperability, so the experiences accumulate in device memory. To avoid a round trip through host memory, I added this input layer.

I think the functionality could be useful beyond my application.
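For illustration, usage might look like the sketch below. The network architecture is an arbitrary toy example, the tensor contents are placeholders, and it assumes the new layer is used like dlib's other input layers together with the standard iterable operator() on loss-wrapped networks; only the input_tensor name comes from this PR.

#include <dlib/dnn.h>
#include <vector>
using namespace dlib;

// Toy classifier: only the input_tensor layer at the bottom is the new piece.
using net_type = loss_multiclass_log<fc<10, relu<fc<32, input_tensor>>>>;

int main()
{
    net_type net;

    // Stand-ins for tensors that already live in device memory; in the real
    // application they would be filled through CUDA/OpenGL interop.
    std::vector<resizable_tensor> frames(2);
    for (auto& f : frames)
    {
        f.set_size(1, 3, 64, 64);  // one sample: a 3-channel 64x64 image
        f = 0.0f;                  // placeholder contents
    }

    // The input layer assembles the mini-batch without a host round trip.
    auto predictions = net(frames);
}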

@davisking (Owner)

Thanks for the PR.

Why not just call .forward() with the tensor you have on hand, though? That is, instead of calling operator(), you can call forward() and give it the device tensor directly.
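For reference, the suggestion looks roughly like the sketch below; it uses add_layer's forward() overload, which accepts an already-assembled tensor, and the net and batch names are placeholders.

// "net" is a loss-wrapped dlib network and "batch" is a tensor that was
// already assembled in device memory. subnet() skips the loss layer and
// forward() skips the input layer's to_tensor() step, so the data never
// leaves the device.
const tensor& output = net.subnet().forward(batch);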

@kSkip (Contributor, Author) commented on Apr 28, 2024

As I understand it, .forward() only supports feeding a single tensor that is already assembled. I could concatenate the tensors from replay memory beforehand, but is that not the responsibility of the network input?

There is a larger issue with training, though. The dnn_trainer expects the input layer to properly support to_tensor() for the expected input_type. So without defining this input class, I cannot train the model as below:

// "replay" is a container of tensors that were read from device memory
auto batch = sample(replay.begin(), replay.end(), batch_size, rng);
trainer.train_one_step(batch.begin(), batch.end(), target_values.begin());

I would have to develop my own trainer.

For inference, the same situation applies: I will be streaming video frames from a camera connected to a Jetson system. The frames are captured with nvarguscamerasrc and reside in device memory.

Let me know if I am missing something, and thanks!
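To make the to_tensor() requirement concrete, here is a rough sketch of what an input layer for device-resident tensors has to do. It is a simplified illustration, not the code from this PR; serialization and the rest of dlib's input-layer interface are omitted, and the class name is made up.

#include <dlib/dnn.h>
#include <iterator>
using namespace dlib;

// Sketch of the part of the input-layer interface that dnn_trainer
// (and the loss layers' operator()) funnel samples through.
class input_tensor_sketch
{
public:
    typedef tensor input_type;  // samples handed to the trainer are tensors

    template <typename forward_iterator>
    void to_tensor(forward_iterator ibegin, forward_iterator iend, resizable_tensor& data) const
    {
        const long num = std::distance(ibegin, iend);
        const tensor& first = *ibegin;
        data.set_size(num*first.num_samples(), first.k(), first.nr(), first.nc());

        // Copy each sample into its slot in the batch. dlib::memcpy stays
        // device-to-device when both tensors are on the GPU, so the batch is
        // assembled without touching host memory.
        alias_tensor slot(first.num_samples(), first.k(), first.nr(), first.nc());
        size_t offset = 0;
        for (auto i = ibegin; i != iend; ++i)
        {
            auto dest = slot(data, offset);
            memcpy(dest, *i);
            offset += i->size();
        }
    }
};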

@davisking (Owner)

Ah, I didn't realize you wanted to train with it. Yeah, this is cool, makes sense :D

Can you add a short unit test to check that it works and then I'll merge it?
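The test that was added is not shown in this thread; a minimal check of this kind might look like the sketch below, assuming input_tensor follows the standard input-layer interface. The shapes are arbitrary and DLIB_TEST is dlib's test-suite assertion macro.

// Hypothetical sanity check: feed two single-sample tensors through
// to_tensor() and confirm they are concatenated into one batch of two.
input_tensor layer;
std::vector<resizable_tensor> samples(2);
for (auto& s : samples)
    s.set_size(1, 3, 8, 8);
resizable_tensor batch;
layer.to_tensor(samples.begin(), samples.end(), batch);
DLIB_TEST(batch.num_samples() == 2);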

@kSkip (Contributor, Author) commented on May 6, 2024

Awesome. I added a unit test. Let me know what you think.

@davisking merged commit 51c7a35 into davisking:master on May 12, 2024 (9 checks passed)
@davisking (Owner)

Nice, thanks for the PR :)

@kSkip (Contributor, Author) commented on May 12, 2024

Of course. Btw, dlib is great!
