Improved readability of the VIN model, in addition to minor changes #12
Two notes, otherwise LGTM
model.py (outdated)
:param k: number of iterations
:return: logits and softmaxed logits
"""
h = self.h(input_view) # intermediate output
nit: Please capitalize the first letter of the new comments.
That's a good point! Fixed it.
model.py (outdated)
 v, _ = torch.max(q, dim=1, keepdim=True)
-for i in range(0, config.k - 1):
+for i in range(k):
This should be range(k-1) to stick with the paper.
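Counting the updates makes the off-by-one concrete. This derivation assumes that the loop body convolves the concatenation of r and v. That is the same as adding a reward convolution W_r * r to a value convolution W_v * v. W_r and W_v are notation introduced here, not names from model.py. Under that assumption, the forward pass with range(k - 1) unrolls as:

$$
\begin{aligned}
q^{(1)} &= W_r * r, & v^{(1)} &= \max_a q^{(1)}_a,\\
q^{(n+1)} &= W_r * r + W_v * v^{(n)}, & v^{(n+1)} &= \max_a q^{(n+1)}_a, \quad n = 1, \dots, k-1,\\
q_{\text{out}} &= W_r * r + W_v * v^{(k)}.
\end{aligned}
$$

So v is updated k times: once before the loop and k - 1 times inside it. If the code before and after the loop stays the same, looping over range(k) would give k + 1 updates.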
I see! I made a compromise commit that extracts the q evaluation step into a separate function.
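Below is a pure-Python analogue of that compromise, sketched on a hypothetical 1D grid with three actions (left, stay, right). eval_q and max_over_actions are illustrative names, not functions from model.py. eval_q stands in for the single F.conv2d(...) call, and max_over_actions stands in for torch.max(q, dim=1). The loop runs over range(k - 1) and the final q is read out after it, so the backup expression is written only once.

```python
from typing import List, Tuple

Grid = List[float]

# Assumed action set for this toy 1D grid: move left, stay, move right.
ACTIONS = (-1, 0, 1)


def eval_q(r: Grid, v: Grid, gamma: float) -> List[Grid]:
    """One Bellman backup: q[a][i] = r[i] + gamma * v[neighbour of i under a].

    Hypothetical helper playing the role of the shared F.conv2d(...) call,
    so the loop body and the final read-out use one implementation.
    Moves off the edge of the grid are clamped to the border cell.
    """
    n = len(r)
    return [
        [r[i] + gamma * v[min(max(i + d, 0), n - 1)] for i in range(n)]
        for d in ACTIONS
    ]


def max_over_actions(q: List[Grid]) -> Grid:
    """v[i] = max over actions of q[a][i], the analogue of torch.max(q, dim=1)."""
    return [max(column) for column in zip(*q)]


def value_iteration(r: Grid, k: int, gamma: float = 0.9) -> Tuple[List[Grid], Grid]:
    """Update v k times, then return (final q read-out, final v)."""
    v = [0.0] * len(r)
    q = eval_q(r, v, gamma)  # first backup: v is all zeros, so q depends on r only
    v = max_over_actions(q)
    for _ in range(k - 1):  # k - 1 further backups, so v is updated k times in total
        q = eval_q(r, v, gamma)
        v = max_over_actions(q)
    q_out = eval_q(r, v, gamma)  # final q, computed once more after the loop
    return q_out, v
```

For example, value_iteration([0.0, 0.0, 0.0, 0.0, 0.0, 1.0], k=3) places the goal in the last cell. Each backup spreads the goal's value one cell further. After k updates, only cells within k - 1 steps of the goal hold a nonzero value. Changing the loop to range(k) would spread it one cell further still.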
In addition to the fixes you suggested, I also refactored the …
LGTM
My main modification is in the model's forward method, where q_out is extracted from the q values without repeating q = F.conv2d(...) in two places. I also made minor improvements, such as adding argparse to the dataset creation script and replacing .cuda() with .to(device) in test.py.
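The following is a minimal sketch of the argparse change for the dataset creation script. The flag names, defaults, and help strings are placeholders chosen for illustration, not the script's actual options.

```python
import argparse


def parse_args(argv=None):
    """Parse command-line options for a grid-world dataset generator.

    All flags and defaults below are hypothetical placeholders.
    Passing argv=None makes argparse read sys.argv[1:] as usual.
    """
    parser = argparse.ArgumentParser(
        description="Generate grid-world training data for a VIN model."
    )
    parser.add_argument("--size", type=int, default=8,
                        help="side length of the square grid")
    parser.add_argument("--n-domains", type=int, default=5000,
                        help="number of grid worlds to generate")
    parser.add_argument("--max-obs", type=int, default=30,
                        help="maximum number of obstacles per grid")
    parser.add_argument("--n-traj", type=int, default=7,
                        help="number of trajectories sampled per grid")
    parser.add_argument("--output", default="gridworld_data",
                        help="path of the generated dataset file")
    return parser.parse_args(argv)
```

argparse turns dashed flags into underscored attributes, so --n-domains is read back as args.n_domains. Accepting an explicit argv list, as in parse_args(["--size", "16"]), keeps the parser easy to test without touching sys.argv.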