-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Python Tutorial and PyTorch model added to train TMVA_RNN_Classification.C #17
base: master
Are you sure you want to change the base?
Python Tutorial and PyTorch model added to train TMVA_RNN_Classification.C #17
Conversation
…ion.C We didn't had a PyTorch model to Train the RNN in TMVA similar to CNN, so here is the PyTorch Model with the Python tutorial for TMVA_RNN_Classification.C
Kindly review this PR then we can move it to Root repository if its fine. |
# CNN Model Definition | ||
net = torch.nn.Sequential( | ||
Reshape(), | ||
nn.Conv2d(1, 10, kernel_size=3, padding=1), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks to me you are creating a CNN model instead of a RNN
# CNN Model Definition | ||
net = torch.nn.Sequential( | ||
Reshape(), | ||
nn.Conv2d(1, 10, kernel_size=3, padding=1), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks to me you are creating a CNN model instead of a RNN
@@ -0,0 +1,308 @@ | |||
import ROOT |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code should be integrated in the tutorial we are adding in the root-project#10442
@@ -0,0 +1,308 @@ | |||
import ROOT |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code should be integrated in the tutorial we are adding in the root-project#10442
We didn't had a PyTorch model to Train the RNN in TMVA similar to CNN, so here is the improvement for the PyTorch Model with the Python tutorial for TMVA_RNN_Classification.C
This Pull request: Adds a PyTorch Model to Train the RNN in TMVA
Checklist: