Possible code simplification? #9

Closed
Parskatt opened this issue Jun 23, 2021 · 4 comments

Comments

@Parskatt

Hi,

Here:

```python
def prune(self, ct):
    bsz, ch, ha, wa, hb, wb = ct.size()
    if not self.idx_initialized:
        # Sampled row/column positions in the last two (hb, wb) dimensions
        idxh = torch.arange(start=0, end=hb, step=self.stride[2:][0], device=ct.device)
        idxw = torch.arange(start=0, end=wb, step=self.stride[2:][1], device=ct.device)
        self.len_h = len(idxh)
        self.len_w = len(idxw)
        # Flat (row * wb + col) indices into the merged hb*wb axis, cached after the first call
        self.idx = (idxw.repeat(self.len_h, 1) + idxh.repeat(self.len_w, 1).t() * wb).view(-1)
        self.idx_initialized = True
    # Merge (hb, wb), gather the strided positions, then split back into (len_h, len_w)
    ct_pruned = ct.view(bsz, ch, ha, wa, -1).index_select(4, self.idx).view(bsz, ch, ha, wa, self.len_h, self.len_w)
    return ct_pruned
```

This is a rather complicated function which, if I understand it correctly, just strides the final two dimensions.
Could you instead simply do:

```python
out1 = x[..., ::2, ::2]
```

Perhaps there is something I'm missing here? Otherwise, I think it would make the code more readable.
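A minimal sketch to check the claimed equivalence, assuming PyTorch is installed. `flat_indices`, `prune_index_select`, and `prune_slice` are hypothetical helper names, not from the repository. `prune_index_select` follows the logic of `prune()` but builds the flat index with broadcasting (`idxh[:, None] * wb + idxw[None, :]`), which yields the same values as the `repeat`/`t()` expression.

```python
import torch


def flat_indices(hb: int, wb: int, stride_h: int, stride_w: int, device=None) -> torch.Tensor:
    """Row-major flat indices (row * wb + col) of the strided positions in an hb x wb grid."""
    idxh = torch.arange(0, hb, stride_h, device=device)
    idxw = torch.arange(0, wb, stride_w, device=device)
    # Broadcasting gives the same values as idxw.repeat(len_h, 1) + idxh.repeat(len_w, 1).t() * wb
    return (idxh[:, None] * wb + idxw[None, :]).reshape(-1)


def prune_index_select(ct: torch.Tensor, stride_h: int, stride_w: int) -> torch.Tensor:
    """Gather strided positions of the last two dims via index_select on a flattened axis."""
    bsz, ch, ha, wa, hb, wb = ct.shape
    idx = flat_indices(hb, wb, stride_h, stride_w, device=ct.device)
    len_h = len(range(0, hb, stride_h))
    len_w = len(range(0, wb, stride_w))
    flat = ct.reshape(bsz, ch, ha, wa, hb * wb)
    return flat.index_select(4, idx).view(bsz, ch, ha, wa, len_h, len_w)


def prune_slice(ct: torch.Tensor, stride_h: int, stride_w: int) -> torch.Tensor:
    """The slicing alternative suggested in the issue (returns a strided view)."""
    return ct[..., ::stride_h, ::stride_w]


if __name__ == "__main__":
    ct = torch.randn(2, 3, 4, 5, 6, 7)
    print(flat_indices(4, 4, 2, 2).tolist())  # [0, 2, 8, 10]
    print(torch.equal(prune_index_select(ct, 2, 2), prune_slice(ct, 2, 2)))  # True
```

Both produce the same values. One practical difference: slicing returns a non-contiguous view that shares storage with `ct`, while `index_select` always allocates a new contiguous tensor.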

@juhongm999
Owner

juhongm999 commented Jun 24, 2021

Applying strides to the final two dimensions would indeed simplify the code and give the same results as ours.
We wrote the more complicated version because, in our experiments, torch.index_select() gave much faster forward/backward passes than the suggested NumPy-style indexing. As I remember, the dramatic speed-up was in the backward pass: roughly 2~3 times faster, I think.
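Whether `index_select` still beats slicing depends on the PyTorch version and hardware, so it is worth measuring on your own setup. Below is a rough timing sketch, assuming PyTorch; `make_index_select_pruner` and `time_fwd_bwd` are hypothetical helpers, and no timing results are claimed here.

```python
import time

import torch


def make_index_select_pruner(hb, wb, stride_h, stride_w, device):
    """Hypothetical helper: precompute flat indices once (like prune() caches self.idx)."""
    idxh = torch.arange(0, hb, stride_h, device=device)
    idxw = torch.arange(0, wb, stride_w, device=device)
    idx = (idxh[:, None] * wb + idxw[None, :]).reshape(-1)
    len_h, len_w = len(idxh), len(idxw)

    def prune(ct):
        bsz, ch, ha, wa = ct.shape[:4]
        gathered = ct.reshape(bsz, ch, ha, wa, -1).index_select(4, idx)
        return gathered.view(bsz, ch, ha, wa, len_h, len_w)

    return prune


def time_fwd_bwd(fn, ct, n_iters=20, n_warmup=3):
    """Average seconds per forward + backward pass of fn(ct).sum()."""
    def step():
        ct.grad = None  # reset so gradients do not accumulate across iterations
        fn(ct).sum().backward()

    for _ in range(n_warmup):
        step()
    if ct.is_cuda:
        torch.cuda.synchronize()  # finish queued GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(n_iters):
        step()
    if ct.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    shape = (2, 8, 16, 16, 16, 16)
    ct = torch.randn(*shape, device=device, requires_grad=True)
    pruner = make_index_select_pruner(shape[4], shape[5], 2, 2, device)
    t_sel = time_fwd_bwd(pruner, ct)
    t_slice = time_fwd_bwd(lambda x: x[..., ::2, ::2], ct)
    print(f"index_select: {t_sel * 1e3:.3f} ms, slicing: {t_slice * 1e3:.3f} ms")
```

The synchronize calls matter on GPU because CUDA kernels launch asynchronously; without them the timer would mostly measure launch overhead.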

@Parskatt
Author

Ah, that's super interesting!
Perhaps this should be raised with the PyTorch developers?

@juhongm999
Owner

juhongm999 commented Jun 25, 2021

Yup. Here are some related issues:
pytorch/pytorch#14231
pytorch/pytorch#15245
pytorch/pytorch#13420

I ran into this while working on a different project at the time, but it seems they have since fixed it, I guess.

@Parskatt
Author

Ah, good to hear that they seem to have fixed the performance issue.
I'm using similar striding in my own project so I'm happy I don't have to complicate my code.

I'll close the issue.
