Skip to content

Conversation

@a-r-r-o-w
Copy link
Contributor

What does this PR do?

Internal discussion: https://huggingface.slack.com/archives/C065E480NN9/p1727418894443269

Code
import numpy as np
import torch
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, aryan_get_1d_sincos_pos_embed_from_grid
from diffusers.models.embeddings import get_2d_sincos_pos_embed_from_grid, aryan_get_2d_sincos_pos_embed_from_grid
from diffusers.models.embeddings import get_2d_sincos_pos_embed, aryan_get_2d_sincos_pos_embed
from diffusers.models.embeddings import get_3d_sincos_pos_embed, aryan_get_3d_sincos_pos_embed
from diffusers.models.embeddings import get_2d_rotary_pos_embed, aryan_get_2d_rotary_pos_embed
from diffusers.models.embeddings import get_3d_rotary_pos_embed, aryan_get_3d_rotary_pos_embed


@torch.no_grad()
def test__get_1d_sincos_pos_embed_from_grid():
    """Compare the torch 1D sin-cos embedding with the numpy-based one.

    Sweeps several embedding widths and grid lengths, runs the torch variant
    on both "cpu" and "cuda", prints the max and summed absolute difference,
    and asserts the two results agree within a 1e-12 tolerance.
    """
    base_size = 16
    interpolation_scale = 1.0
    embed_dims = [128, 256, 1024]
    grid_sizes = [16, 32, 64, 128]
    devices = ["cpu", "cuda"]

    for embed_dim in embed_dims:
        for grid_size in grid_sizes:
            # Positions are rescaled so every grid length covers [0, base_size).
            numpy_grid = np.arange(grid_size, dtype=np.float32) / (grid_size / base_size) / interpolation_scale
            expected = get_1d_sincos_pos_embed_from_grid(embed_dim, numpy_grid)

            for device in devices:
                torch_grid = torch.from_numpy(numpy_grid).to(device)
                actual = aryan_get_1d_sincos_pos_embed_from_grid(embed_dim, torch_grid)
                actual = actual.cpu().numpy()
                print(f"===== testing {test__get_1d_sincos_pos_embed_from_grid.__name__}({embed_dim=}, {grid_size=}, {device=}) =====")
                abs_error = np.abs(expected - actual)
                print(abs_error.max())
                print(abs_error.sum())
                print("==========")
                assert np.allclose(expected, actual, rtol=1e-12, atol=1e-12)


@torch.no_grad()
def test__get_2d_sincos_pos_embed_from_grid():
    """Compare the torch 2D from-grid sin-cos embedding with the numpy-based one.

    For each embedding width and (height, width) grid size, a (2, 1, H, W)
    coordinate grid is built with numpy, fed to the numpy-based function, and
    then fed as a tensor to the torch variant on "cpu" and "cuda". The max and
    summed absolute difference are printed and agreement within 1e-12 is
    asserted.
    """
    base_size = 16
    interpolation_scale = 1.0
    embed_dims = [128, 256, 1024]
    grid_sizes = [(32, 32), (64, 64), (128, 64), (64, 128)]

    for embed_dim in embed_dims:
        for grid_size in grid_sizes:
            height, width = grid_size
            grid_h = np.arange(height, dtype=np.float32) / (height / base_size) / interpolation_scale
            grid_w = np.arange(width, dtype=np.float32) / (width / base_size) / interpolation_scale
            # Width goes first in meshgrid; the stacked result is (2, H, W).
            grid_numpy = np.stack(np.meshgrid(grid_w, grid_h)).reshape(2, 1, height, width)
            expected = get_2d_sincos_pos_embed_from_grid(embed_dim, grid_numpy)

            for device in ["cpu", "cuda"]:
                grid_torch = torch.from_numpy(grid_numpy).to(device)
                actual = aryan_get_2d_sincos_pos_embed_from_grid(embed_dim, grid_torch)
                actual = actual.cpu().numpy()
                print(f"===== testing {test__get_2d_sincos_pos_embed_from_grid.__name__}({embed_dim=}, {grid_size=}, {device=}) =====")
                abs_error = np.abs(expected - actual)
                print(abs_error.max())
                print(abs_error.sum())
                print("==========")
                assert np.allclose(expected, actual, rtol=1e-12, atol=1e-12)


@torch.no_grad()
def test__get_2d_sincos_pos_embed():
    """Compare the torch 2D sin-cos embedding with the numpy-based one.

    Sweeps base size, interpolation scale, embedding width, grid size, the
    class-token flag and the extra-token count. Each configuration is computed
    once with the numpy-based function and then with the torch variant on
    "cpu" and "cuda"; the max and summed absolute difference are printed and
    agreement within 1e-12 is asserted.
    """
    for base_size in [16, 32, 64]:
        for interpolation_scale in [1.0, 2.0]:
            for embed_dim in [256, 1024]:
                for grid_size in [(64, 64), (128, 64), (64, 128)]:
                    for cls_token in [False, True]:
                        for extra_tokens in [0, 16]:
                            # Keyword arguments shared by both implementations.
                            shared_kwargs = {
                                "embed_dim": embed_dim,
                                "grid_size": grid_size,
                                "cls_token": cls_token,
                                "extra_tokens": extra_tokens,
                                "interpolation_scale": interpolation_scale,
                                "base_size": base_size,
                            }
                            expected = get_2d_sincos_pos_embed(**shared_kwargs)

                            for device in ["cpu", "cuda"]:
                                actual = aryan_get_2d_sincos_pos_embed(**shared_kwargs, device=device)
                                actual = actual.cpu().numpy()
                                print(f"===== testing {test__get_2d_sincos_pos_embed.__name__}({base_size=}, {interpolation_scale=}, {embed_dim=}, {grid_size=}, {cls_token=}, {extra_tokens=}, {device=}) =====")
                                abs_error = np.abs(expected - actual)
                                print(abs_error.max())
                                print(abs_error.sum())
                                print("==========")
                                assert np.allclose(expected, actual, rtol=1e-12, atol=1e-12)


@torch.no_grad()
def test__get_3d_sincos_pos_embed():
    """Compare the torch 3D sin-cos embedding with the numpy-based one.

    Sweeps embedding width, spatial size, temporal size and both interpolation
    scales. Each configuration is computed once with the numpy-based function
    and then with the torch variant on "cpu" and "cuda"; the max and summed
    absolute difference are printed and agreement within 1e-12 is asserted.
    """
    for embed_dim in [256, 1024]:
        for spatial_size in [(64, 64), (128, 64), (64, 128)]:
            for temporal_size in [8, 16]:
                for spatial_interpolation_scale in [1.0, 2.0]:
                    for temporal_interpolation_scale in [1.0, 2.0]:
                        sincos_pos_embed_numpy = get_3d_sincos_pos_embed(
                            embed_dim=embed_dim,
                            spatial_size=spatial_size,
                            temporal_size=temporal_size,
                            spatial_interpolation_scale=spatial_interpolation_scale,
                            temporal_interpolation_scale=temporal_interpolation_scale,
                        )

                        for device in ["cpu", "cuda"]:
                            sincos_pos_embed_torch = aryan_get_3d_sincos_pos_embed(
                                embed_dim=embed_dim,
                                spatial_size=spatial_size,
                                temporal_size=temporal_size,
                                spatial_interpolation_scale=spatial_interpolation_scale,
                                temporal_interpolation_scale=temporal_interpolation_scale,
                                device=device,
                            ).cpu().numpy()
                            # The banner names this test (it previously printed the 2D test's
                            # name) and includes the device so CPU/CUDA runs are distinguishable.
                            print(f"===== testing {test__get_3d_sincos_pos_embed.__name__}({embed_dim=}, {spatial_size=}, {temporal_size=}, {spatial_interpolation_scale=}, {temporal_interpolation_scale=}, {device=}) =====")
                            abs_error = np.abs(sincos_pos_embed_numpy - sincos_pos_embed_torch)
                            print(abs_error.max())
                            print(abs_error.sum())
                            print("==========")
                            assert np.allclose(sincos_pos_embed_numpy, sincos_pos_embed_torch, rtol=1e-12, atol=1e-12)


@torch.no_grad()
def test__get_2d_rotary_pos_embed():
    """Compare the torch 2D rotary embedding with the existing implementation.

    Sweeps embedding width, crop coordinates and grid size. Both functions
    return a pair of arrays/tensors (indexed as [0] and [1] below); the max
    and summed absolute difference of each component are printed and agreement
    within 1e-12 is asserted.
    """
    for embed_dim in [256, 1024]:
        for crops_coords in [[(0, 0), (8, 8)], [(0, 0), (16, 16)]]:
            for grid_size in [(64, 64), (128, 64), (64, 128)]:
                rope_numpy = get_2d_rotary_pos_embed(
                    embed_dim=embed_dim,
                    crops_coords=crops_coords,
                    grid_size=grid_size,
                )

                for device in ["cpu", "cuda"]:
                    # NOTE(review): `device` is not forwarded here, unlike the 3D rotary
                    # test, so both iterations exercise whatever device the function
                    # defaults to. Pass `device=device` once the signature is confirmed
                    # to accept it.
                    rope_torch = aryan_get_2d_rotary_pos_embed(
                        embed_dim=embed_dim,
                        crops_coords=crops_coords,
                        grid_size=grid_size,
                    )
                    rope_torch = rope_torch[0].cpu().numpy(), rope_torch[1].cpu().numpy()
                    # The banner names this test (it previously printed the 2D sincos test's name).
                    print(f"===== testing {test__get_2d_rotary_pos_embed.__name__}({embed_dim=}, {crops_coords=}, {grid_size=}, {device=}) =====")
                    print(np.abs(rope_numpy[0] - rope_torch[0]).max())
                    print(np.abs(rope_numpy[1] - rope_torch[1]).max())
                    print(np.abs(rope_numpy[0] - rope_torch[0]).sum())
                    print(np.abs(rope_numpy[1] - rope_torch[1]).sum())
                    print("==========")
                    assert np.allclose(rope_numpy, rope_torch, rtol=1e-12, atol=1e-12)


@torch.no_grad()
def test__get_3d_rotary_pos_embed():
    """Compare the torch 3D rotary embedding with the existing implementation.

    Sweeps embedding width, crop coordinates, grid size and temporal size, and
    runs the torch variant on "cpu" and "cuda". Both functions return a pair
    (indexed as [0] and [1] below); the max and summed absolute difference of
    each component are printed and agreement is asserted with a looser
    absolute tolerance (1e-6) than the other tests use.
    """
    for embed_dim in [256, 1024]:
        for crops_coords in [[(0, 0), (8, 8)], [(0, 0), (16, 16)]]:
            for grid_size in [(64, 64), (128, 64), (64, 128)]:
                for temporal_size in [8, 16]:
                    rope_numpy = get_3d_rotary_pos_embed(
                        embed_dim=embed_dim,
                        crops_coords=crops_coords,
                        grid_size=grid_size,
                        temporal_size=temporal_size,
                    )

                    for device in ["cpu", "cuda"]:
                        rope_torch = aryan_get_3d_rotary_pos_embed(
                            embed_dim=embed_dim,
                            crops_coords=crops_coords,
                            grid_size=grid_size,
                            temporal_size=temporal_size,
                            device=device,
                        )

                        # NOTE/TODO: not sure why this needs a higher tolerance even though
                        # all the operations are similar-ish to the 2D case:
                        # - the difference is exactly zero on CPU
                        # - but on CUDA it is slightly numerically different

                        rope_torch = rope_torch[0].cpu().numpy(), rope_torch[1].cpu().numpy()
                        # The banner names this test (it previously printed the 3D sincos test's name).
                        print(f"===== testing {test__get_3d_rotary_pos_embed.__name__}({embed_dim=}, {crops_coords=}, {grid_size=}, {temporal_size=}, {device=}) =====")
                        print(np.abs(rope_numpy[0] - rope_torch[0]).max())
                        print(np.abs(rope_numpy[1] - rope_torch[1]).max())
                        print(np.abs(rope_numpy[0] - rope_torch[0]).sum())
                        print(np.abs(rope_numpy[1] - rope_torch[1]).sum())
                        print("==========")
                        assert np.allclose(rope_numpy, rope_torch, rtol=1e-12, atol=1e-6)


# Run every comparison in sequence; each function raises AssertionError on the
# first configuration whose outputs fall outside its tolerance.
test__get_1d_sincos_pos_embed_from_grid()
test__get_2d_sincos_pos_embed_from_grid()
test__get_2d_sincos_pos_embed()
test__get_3d_sincos_pos_embed()
test__get_2d_rotary_pos_embed()
test__get_3d_rotary_pos_embed()
import numpy as np
# Sanity check: numpy reports the float32 machine epsilon as 1.1920929e-07.
assert np.finfo(np.float32).eps == np.float32(1.1920929e-07)

Most of the numerical differences observed are acceptable when comparing a numpy array with a tensor converted to numpy: they are well within float32.eps (in fact, on the order of 1e-12, so these changes should be safe to make). They are sometimes even exactly 0 if you do the comparison in tensors instead of numpy. The only peculiar case is the 3D rope embeddings, which seem to require a much higher tolerance (1e-6) instead of 1e-12 to 1e-15. I'm trying to figure out why :(

Once we finalize and decide to go forward with this, I'll rename the functions (currently prefixed), remove the older numpy-based functions, and pass the tensor device at all usage occurrences so that all sincos/rope embedding creation occurs on device directly.


For those who don't have access to the internal link and wonder why we need this change: using torch.compile leads to a graph break and a cudaMemSync on models that create sincos/rope positional embeddings on-the-fly. This is due to creating numpy arrays, converting them to CPU PyTorch tensors, and then moving them to the accelerator device. It's usually not a problem when the embeddings are prepared inside the pipeline, but for multiresolution training and specific models, you have to create them on the fly. These changes ensure that we remain in tensor land and don't have to deal with numpy arrays anywhere in the execution path.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yiyixuxu @sayakpaul

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Seems to be in the right direction.

return emb


def aryan_get_3d_sincos_pos_embed(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could compare this and the numpy implementation on the same inputs and see if there's any divergence. Usually a good start to try to localize if there are any problems.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to match for all cases except the 3D sincos ones, based on my test code above. I'll give it a look again soon


def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.Tensor, np.ndarray, int],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case of an np.array, the device will be a CPU, right? Would we incur any device placement penalty for that in case we're on a non-CPU device?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If an np.array is passed, what you said is right. I did not want to redo this function because it was already mostly torch land, and now it supports the extra torch.Tensor case (previously, it only supported np.ndarray and int)



def aryan_get_2d_rotary_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor, use_real: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
assert embed_dim % 4 == 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could move to raise ValueError() here and elsewhere applicable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

temporal_size: int,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
device: Optional[torch.device] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the code, where we're calling these functions, I think it would be useful to always pass the right device. That way, we won't have to incur any device placement costs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, of course, sounds good!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions
Copy link
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Nov 11, 2024
@a-r-r-o-w a-r-r-o-w removed the stale Issues that haven't received updates label Nov 11, 2024
@a-r-r-o-w a-r-r-o-w added the wip label Nov 18, 2024
@a-r-r-o-w
Copy link
Contributor Author

Closing in favor of @hlky's PRs

@a-r-r-o-w a-r-r-o-w closed this Dec 10, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants