Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NumpyBatch utils from PVNet #287

Merged
merged 6 commits into from
Mar 26, 2024

Conversation

markus-kreft
Copy link
Contributor

Pull Request

Description

Transfer functions from PVNet specific to NumpyBatches here. Type hints are very minimal because recursive type support is still quite tricky. I have tried to use a nested dict type NestedDict = dict[Any, Union[T, "NestedDict[T]"]] but have trouble getting mypy to recognize NumpyBatches as such. I would be very interested if someone can teach me how to do this properly.

Fixes openclimatefix/PVNet#111

How Has This Been Tested?

Logic comes for another repository. I have added a simple test in tests/batch/test_utils.py

Checklist:

  • My code follows OCF's coding style guidelines
  • I have performed a self-review of my own code
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • I have checked my code and corrected any misspellings

Copy link
Member

@dfulu dfulu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @markus-kreft, thanks very much for jumping on this. We really appreciate your help.

For the type-hint, I would use the NumpyBatch class which can be imported from ocf_datapipes.batch. Since we aim to support NumpyBatch with this function rather than any arbitrary nested dict

@markus-kreft
Copy link
Contributor Author

Hi @dfulu, thank you for the hint. Making the functions Batch-specific was also my first idea, but getting the types right also turns out to be quit difficult with the recursion (in the second call it will not be a NumpyBatch anymore), and allowing multiple types makes trouble when accessing the batch keys.

To enforce the desired types in the public facing API I have two ideas:

  1. Wrap the existing generic function (here _batch_to_tensor is the current function):
def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
    return _batch_to_tensor(batch)
  1. Replace the recursion by loops, because we know that a NumpyBatch is at most 3 levels deep. This will lead to a lot of duplicated code.

From this SO post I think turning Batches into TypedDict subclasses could help (though not sure?), but this would mean redoing the batch module. Or maybe someone sees another way?

@markus-kreft
Copy link
Contributor Author

I have implemented option 1 for now.

@dfulu dfulu self-requested a review March 26, 2024 15:48
Copy link
Member

@dfulu dfulu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @markus-kreft, sorry for the delay. We've had a lot of internal deadlines at OCF in the last couple of weeks. The type hinting here is a quite difficult because of the recursion, but I think you've come up with a good pragmatic solution. So thanks for this work

@dfulu
Copy link
Member

dfulu commented Mar 26, 2024

@all-contributors please add @markus-kreft for code

Copy link
Contributor

@dfulu

I've put up a pull request to add @markus-kreft! 🎉

@dfulu dfulu merged commit 2e8b064 into openclimatefix:main Mar 26, 2024
1 check passed
@dfulu dfulu removed the request for review from jacobbieker March 26, 2024 15:55
@markus-kreft
Copy link
Contributor Author

Hi James, no worries, I can imagine you have a lot going on at OCF right now. I am happy my approach to typing worked out :-) I will integrate the functions in PVNet.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Transfer batch transformation functions to ocf_datapipes
2 participants