Add Learnable Query #19

Merged
merged 18 commits into from Sep 28, 2021

Conversation

jacobbieker
Member

@jacobbieker jacobbieker commented Sep 17, 2021

Pull Request

Description

This adds a Learnable Query constructor that creates a query with some randomness, which may make it possible to "ensemble" predictions by querying the Perceiver output multiple times.

Fixes issue #16

How Has This Been Tested?

Unit tests

  • No
  • Yes

Checklist:

  • My code follows OCF's coding style guidelines
  • I have performed a self-review of my own code
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • I have checked my code and corrected any misspellings

@jacobbieker jacobbieker added the enhancement New feature or request label Sep 17, 2021
@jacobbieker jacobbieker self-assigned this Sep 17, 2021
@jacobbieker jacobbieker linked an issue Sep 17, 2021 that may be closed by this pull request
@jacobbieker jacobbieker marked this pull request as ready for review September 22, 2021 10:45
x = torch.randn((4, 6, 12, 16, 16))
out = query_creator(x)
# Output is flattened, so should be [B, T*H*W, C]
assert out.shape == (4, 16 * 16 * 6, 803)

Sorry, I probably just don't understand: where does the 803 come from?

Member Author

That is the number of channels: the channel_dim, so 32, plus the number of Fourier Features, which for this is 771

Member Author

And the number of Fourier Features comes from (num input axes) * ((num_freq_bands * 2) + 1), so for this 3 * (128 * 2 + 1) = 771
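
The arithmetic in this thread can be checked with a few lines of Python. This is only a sketch: the helper name `fourier_feature_count` and its parameter names are hypothetical and not part of the PR, and the docstring's reading of the formula (a sin/cos pair per band plus the raw coordinate) is an assumption. The values (3 input axes, 128 frequency bands, a channel_dim of 32) are the ones quoted above.

```python
def fourier_feature_count(num_input_axes: int, num_freq_bands: int) -> int:
    """Hypothetical helper mirroring the formula quoted in the thread.

    Assumed reading: each axis contributes a sin and a cos value per
    frequency band, plus one extra value for the raw coordinate.
    """
    return num_input_axes * (num_freq_bands * 2 + 1)


channel_dim = 32  # value quoted in the thread
num_fourier = fourier_feature_count(num_input_axes=3, num_freq_bands=128)
total_channels = channel_dim + num_fourier

print(num_fourier)     # 771
print(total_channels)  # 803
```

The second number matches the last dimension in the test's shape assertion.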

Thanks for adding this

@peterdudfield peterdudfield left a comment

Looks good to me,

hard for me to judge if this is the right method - is it copied from somewhere? or referenced in a paper?

@jacobbieker
Member Author

Looks good to me,

hard for me to judge if this is the right method - is it copied from somewhere? or referenced in a paper?

This isn't copied from anywhere. It's just the general idea of using a random latent space to generate predictions, as GANs do, somewhat like the Skillful Nowcasting GAN paper, I guess

@jacobbieker jacobbieker added this to In progress in National Grid Nowcasting: WP1 via automation Sep 24, 2021
@jacobbieker jacobbieker moved this from In progress to Review in progress in National Grid Nowcasting: WP1 Sep 24, 2021
Member

@JackKelly JackKelly left a comment

Looks really good! I do love this idea of using random queries to create an ensemble!

perceiver_pytorch/queries.py
_LOG.setLevel(logging.WARN)


class LearnableQuery(torch.nn.Module):
Member

On the topic of position encoding for the queries...

I think the position encoding needs to be 'internally consistent' across each entire example. That is, the position encoding should be consistent across data inputs and the queries so the model can see that, say, the last timestep of the recent history in the input is immediately before the first timestep of the query. And that, in general, the model can see that the timeseries of recent history and the timeseries of queries are two parts of a contiguous timeseries.

To give a concrete example: If the recent history in the input spans 11:00 to 11:55, and the query is for 12:00 to 12:55, then we want the model to see that the first timestep of the query is 5 minutes after the last timestep of the input... if that makes sense?!

One way to do this might be to encode the positions once for all timesteps in the example (i.e. the concatenation of the recent history and the forecast timesteps), and then concatenate the last forecast_timesteps of the position encoding to the queries? Or something like that?! Not sure what's best!

Perhaps this could wait for a future PR though!

This is just a hunch, of course!

And, ultimately, we might want the position encoding to include both the relative position ("the ith element in the array") and the absolute position in time and space ("2pm in South London"). But we can worry about that later :)
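
The "encode once, then slice" suggestion above could look roughly like the following. This is a minimal pure-Python sketch, not the PR's implementation: `encode_positions`, the power-of-two band spacing, and the timestep counts are all assumptions made for illustration.

```python
import math
from typing import List


def encode_positions(num_positions: int, num_freq_bands: int = 4) -> List[List[float]]:
    """Hypothetical Fourier encoding of positions spread evenly over [-1, 1]."""
    freqs = [2.0 ** band for band in range(num_freq_bands)]  # assumed band spacing
    encodings = []
    for i in range(num_positions):
        # Map index i onto [-1, 1] so every timestep shares one coordinate system.
        pos = -1.0 + 2.0 * i / max(num_positions - 1, 1)
        sines = [math.sin(math.pi * f * pos) for f in freqs]
        cosines = [math.cos(math.pi * f * pos) for f in freqs]
        encodings.append(sines + cosines + [pos])
    return encodings


history_timesteps = 12   # e.g. 11:00 to 11:55 at 5-minute steps
forecast_timesteps = 12  # e.g. 12:00 to 12:55

# Encode once over the concatenated history + forecast timeline ...
all_encodings = encode_positions(history_timesteps + forecast_timesteps)

# ... then split: the first part goes with the inputs, the last part with the queries.
input_encodings = all_encodings[:history_timesteps]
query_encodings = all_encodings[-forecast_timesteps:]
```

Because both slices come from one shared timeline, the gap between the last input position and the first query position equals the gap between any two neighbouring timesteps, which is the "5 minutes after" relationship described in the comment.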

Member Author

Yeah, I might add that in a follow up PR, but it should be fairly easy to add those as options. I think it's a good idea to try at least!


Args:
channel_dim: Channel dimension for the output of the network
query_shape: The final shape of the query, generally, the (T, H, W) of the output
Member

Sorry for being slow but what's the T in (T, H, W)? Time?

Member Author

Yes, it's the number of timesteps

    elif conv_layer == "2d":
        conv = torch.nn.Conv2d
    else:
        raise ValueError(f"Value for 'layer' is {conv_layer} which is not one of '3d', '2d'")
Member

Nice!

National Grid Nowcasting: WP1 automation moved this from Review in progress to Reviewer approved Sep 28, 2021
@jacobbieker jacobbieker merged commit 31631d2 into main Sep 28, 2021
National Grid Nowcasting: WP1 automation moved this from Reviewer approved to Done Sep 28, 2021
@jacobbieker jacobbieker deleted the jacob/learnable-queries branch September 28, 2021 17:09
Labels
enhancement New feature or request
Projects
No open projects
Development

Successfully merging this pull request may close these issues.

Add Learnable Query
3 participants