Document for different specializations and how they affect shapes #54

Closed
dashesy opened this issue Oct 25, 2022 · 2 comments
dashesy (Contributor) commented Oct 25, 2022

After reading some examples, e.g. here:

from aitemplate.compiler import ops
from aitemplate.frontend import nn

# get_shape is a small helper defined in the example; it returns the
# tensor's shape so it can be restored at the end of forward().


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer="GELU",
        drop=0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # specialization fuses the activation into the gemm epilogue
        self.fc1 = nn.Linear(
            in_features,
            hidden_features,
            specialization="fast_gelu" if act_layer == "GELU" else "relu",
        )
        # "add" fuses the residual addition into the gemm epilogue
        self.fc2 = nn.Linear(hidden_features, out_features, specialization="add")

    def forward(self, x, res):
        shape = get_shape(x)
        x = self.fc1(x)
        x = self.fc2(x, res)  # res is added via the fused epilogue
        return ops.reshape()(x, shape)

I was wondering: what is the reason for the ops.reshape() at the end? Does the specialization change the shapes to some canonical form? Which other functions need a reshape?

antinucleon (Contributor) commented

The explicit reshape was introduced for this reason:

In a low-level math library such as cuBLAS/CUTLASS, a GEMM is a strictly 2D problem, e.g. the RCR variant (row-major A, column-major B, row-major C):

Y: [M, N] -> gemm_rcr(X: [M, K], W: [N, K])
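For illustration, a minimal NumPy sketch of what gemm_rcr computes (NumPy stands in for the actual cuBLAS/CUTLASS kernel; the sizes are made up):

import numpy as np

M, N, K = 8, 16, 32
X = np.random.rand(M, K).astype(np.float32)  # row-major activation [M, K]
W = np.random.rand(N, K).astype(np.float32)  # weight stored as [N, K]
Y = X @ W.T                                  # gemm_rcr output: [M, N]
assert Y.shape == (M, N)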

In PyTorch and other frameworks, there is syntactic sugar for the ND problem, e.g.

Y: [B, S, 4H] -> torch.nn.functional.linear(X: [B, S, H], W: [4H, H])
# at the low level, X is reshaped into [B * S, H]; the output Y is initially [B * S, 4H], then reshaped into [B, S, 4H]

To lower this syntactic sugar to the actual low-level implementation, AIT inserts reshape ops.

In AIT, all reshape ops are zero-cost.
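To make the lowering concrete, here is a hedged NumPy sketch of the linear example above (plain Python names, not AIT APIs; the shapes are arbitrary):

import numpy as np

B, S, H = 2, 196, 768
X = np.random.rand(B, S, H).astype(np.float32)   # ND input [B, S, H]
W = np.random.rand(4 * H, H).astype(np.float32)  # weight [4H, H]

X2d = X.reshape(B * S, H)        # reshape inserted before the gemm
Y2d = X2d @ W.T                  # strictly 2D gemm_rcr: [B * S, 4H]
Y = Y2d.reshape(B, S, 4 * H)     # reshape inserted after the gemm
assert Y.shape == (B, S, 4 * H)

These surrounding reshapes are exactly the ones AIT inserts around the gemm, and, as noted above, they are zero-cost.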

You can also try the visualization tool: https://facebookincubator.github.io/AITemplate/tutorial/how_to_visualize.html


dashesy commented Oct 26, 2022

Thanks! That is what I was looking for.

dashesy closed this as completed Oct 26, 2022
tissue3 pushed a commit to tissue3/AITemplate-1 that referenced this issue Feb 7, 2023