
[autoparallel] add pooling metainfo #1968

Merged
Merged 76 commits on Nov 18, 2022

Conversation

Cypher30 (Contributor)

What’s New?

In this PR, I implement the metainfo generator for pooling operations, including AdaptiveAvgPool and MaxPool. I also found an interesting point while aligning the estimated memory cost with the real one: _split in comm_spec.py is actually triggered twice when the sharding spec contains S01. It splits the tensor along the two device-mesh dimensions one after the other, producing an intermediate piece of memory that can be confusing when you measure memory at runtime.

For example, suppose you have an input with shape [4, 128, 64, 64] and dtype=float32, which takes 8192KB of memory, and you want to shard it on a device mesh of shape (2, 2) with the sharding spec RS01RR. To perform the split, shape consistency first calls _split along one device-mesh dimension, producing a tensor with shape [4, 64, 64, 64]; this consumes 4096KB of extra memory because splitting the tensor along dimension 1 creates a non-contiguous tensor. The second split then produces a tensor with shape [4, 32, 64, 64] to meet our requirement, consuming another 2048KB, after which the previously created 4096KB intermediate is discarded. As a result, you observe a peak of 4096KB while the actual memory allocated is 2048KB. This was not discovered in the previous op patches because their outputs are much bigger than their inputs: since we measure the memory peak and memory allocated over the whole forward phase, the output they produce is larger than the peak that _split produces, which hid this tricky little case.
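To make the numbers above easier to check, here is a minimal PyTorch sketch (not the actual `_split` code in comm_spec.py, just an illustration of the same chunk-and-materialize pattern) that reproduces the tensor sizes and shows why the intermediate split creates an extra allocation:

```python
import torch

def size_kb(t: torch.Tensor) -> float:
    """Tensor payload size in KB."""
    return t.numel() * t.element_size() / 1024

# Full (replicated) input tensor before sharding: 8192 KB.
x = torch.rand(4, 128, 64, 64)
print(size_kb(x))                          # 8192.0

# First split, along one device-mesh dimension. The chunk is a
# non-contiguous view, so materializing it allocates a fresh 4096 KB buffer.
chunk = x.chunk(2, dim=1)[0]
print(chunk.shape, chunk.is_contiguous())  # [4, 64, 64, 64], False
step1 = chunk.contiguous()
print(size_kb(step1))                      # 4096.0

# Second split, along the other device-mesh dimension: another 2048 KB buffer.
step2 = step1.chunk(2, dim=1)[0].contiguous()
print(size_kb(step2))                      # 2048.0

# Afterwards the 4096 KB intermediate is dropped, so the retained extra memory
# is 2048 KB even though a larger transient peak was reached during the split.
del step1
```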

Cypher30 and others added 30 commits July 14, 2022 16:07
Cypher30 merged commit c26f21d into hpcaitech:main on Nov 18, 2022