[device] find best logical mesh #2342

Merged

@YuliangLiu0306 (Contributor) commented Jan 5, 2023

What does this PR do?

  1. Implement the search_best_logical_mesh function, which finds the best logical mesh for a given device list.

    The best logical mesh is found in the following steps:

    1. Detect homogeneous device groups. We assume that devices in the alpha_beta_dict are homogeneous if their beta values are close enough.
    2. Find the best homogeneous device group that contains all the physical devices, i.e. the group with the lowest beta value among the groups that contain every physical device. The group is required to contain all the physical devices because devices outside the group would decrease its bandwidth.
    3. If such a group is found, construct the largest ring for each device based on that group; the best logical mesh is the union of all the rings. Otherwise, fall back to a balanced logical mesh, such as shape (2, 2) for 4 devices.

    Usage:

        >>> physical_devices = [0, 1, 2, 3]
        >>> ab_profiler = AlphaBetaProfiler(physical_devices)
        >>> best_logical_mesh = ab_profiler.search_best_logical_mesh()
        >>> print(best_logical_mesh)
        [[0, 1], [2, 3]]
  2. Implement the extract_alpha_beta_for_device_mesh function, which extracts the mesh_alpha and mesh_beta lists based on the best logical mesh.

    Usage:

        >>> physical_devices = [0, 1, 2, 3]
        >>> ab_profiler = AlphaBetaProfiler(physical_devices)
        >>> mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
        >>> print(mesh_alpha)
        [2.5917552411556242e-05, 0.00010312341153621673]
        >>> print(mesh_beta)
        [5.875573704655635e-11, 4.7361584445959614e-12]
  3. Add test cases for the features above.
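
Steps 1 and 2 of the search, plus the balanced-mesh fallback of step 3, can be sketched in Python as follows. This is a minimal, hypothetical illustration, not the PR's implementation: it assumes the profiled data is a dict mapping ordered device pairs (i, j) to (alpha, beta), and it uses an assumed relative tolerance rtol to decide when beta values are "close enough". The helper names detect_homogeneous_groups, best_covering_group and balanced_logical_mesh are made up for this sketch, and the ring construction of step 3 is omitted.

```python
import math
from typing import Dict, List, Optional, Tuple

# Hypothetical input: (i, j) -> (alpha, beta), where alpha is latency in seconds
# and beta is the inverse bandwidth in seconds per byte.
Pair = Tuple[int, int]
AlphaBetaDict = Dict[Pair, Tuple[float, float]]


def detect_homogeneous_groups(ab_dict: AlphaBetaDict, rtol: float = 0.1) -> List[Tuple[float, List[Pair]]]:
    """Step 1: bucket device pairs whose beta values are close to each other.

    Pairs are visited in ascending beta order; a pair joins the current group if its
    beta is within rtol of the group's first (lowest) beta, otherwise it starts a new group.
    """
    groups: List[Tuple[float, List[Pair]]] = []
    for pair, (_alpha, beta) in sorted(ab_dict.items(), key=lambda kv: kv[1][1]):
        if groups and math.isclose(beta, groups[-1][0], rel_tol=rtol):
            groups[-1][1].append(pair)
        else:
            groups.append((beta, [pair]))
    return groups


def best_covering_group(ab_dict: AlphaBetaDict, devices: List[int], rtol: float = 0.1) -> Optional[List[Pair]]:
    """Step 2: the lowest-beta group whose pairs touch every physical device, or None."""
    # Groups come back in ascending beta order, so the first covering group is the best one.
    for _beta, pairs in detect_homogeneous_groups(ab_dict, rtol):
        covered = {d for pair in pairs for d in pair}
        if covered >= set(devices):
            return pairs
    return None


def balanced_logical_mesh(devices: List[int]) -> List[List[int]]:
    """Step 3 fallback: reshape the devices into the most square 2D mesh, e.g. (2, 2) for 4."""
    n = len(devices)
    # The largest divisor of n that does not exceed sqrt(n) becomes the row count.
    rows = max(r for r in range(1, math.isqrt(n) + 1) if n % r == 0)
    cols = n // rows
    return [devices[r * cols:(r + 1) * cols] for r in range(rows)]
```

For example, with four devices where pairs (0, 1) and (2, 3) have a lower beta than the cross pairs, best_covering_group returns the fast pairs, since together they touch all four devices, and balanced_logical_mesh([0, 1, 2, 3]) gives [[0, 1], [2, 3]].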

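One plausible way to derive per-axis mesh_alpha and mesh_beta values from a 2D logical mesh is sketched below. It is an assumption, not necessarily what extract_alpha_beta_for_device_mesh does (the PR may re-profile each axis instead): for each mesh axis, the sketch collects the device pairs that communicate along that axis and takes the largest alpha and beta among them, treating the slowest link as the bottleneck. The function name extract_alpha_beta and the pair-keyed input dict are hypothetical.

```python
from typing import Dict, List, Tuple

# Hypothetical input: (i, j) -> (alpha, beta) for every ordered device pair.
AlphaBetaDict = Dict[Tuple[int, int], Tuple[float, float]]


def extract_alpha_beta(mesh: List[List[int]], ab_dict: AlphaBetaDict) -> Tuple[List[float], List[float]]:
    """Return ([alpha_axis0, alpha_axis1], [beta_axis0, beta_axis1]) for a 2D mesh.

    Assumed rule: each axis is characterized by its slowest link, i.e. the largest
    alpha and beta among the device pairs that communicate along that axis.
    """
    rows, cols = len(mesh), len(mesh[0])
    axis_pairs = [
        # Axis 0: devices in the same column.
        [(mesh[i][c], mesh[j][c]) for c in range(cols) for i in range(rows) for j in range(rows) if i != j],
        # Axis 1: devices in the same row.
        [(mesh[r][i], mesh[r][j]) for r in range(rows) for i in range(cols) for j in range(cols) if i != j],
    ]
    mesh_alpha: List[float] = []
    mesh_beta: List[float] = []
    for pairs in axis_pairs:
        # An axis of size 1 has no pairs and therefore no communication cost.
        mesh_alpha.append(max((ab_dict[p][0] for p in pairs), default=0.0))
        mesh_beta.append(max((ab_dict[p][1] for p in pairs), default=0.0))
    return mesh_alpha, mesh_beta
```

With mesh [[0, 1], [2, 3]], axis 0 pairs devices in the same column (0 with 2, 1 with 3) and axis 1 pairs devices in the same row, so in this sketch slower cross-pair links show up in the first entry of each list.
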
@feifeibear feifeibear merged commit 4e96039 into hpcaitech:main Jan 7, 2023