After reading some examples, e.g. here:
# Imports as used in the AITemplate examples (added here for context).
from aitemplate.compiler import ops
from aitemplate.frontend import nn


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer="GELU",
        drop=0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Fuse the activation into the first gemm via the specialization argument.
        self.fc1 = nn.Linear(
            in_features,
            hidden_features,
            specialization="fast_gelu" if act_layer == "GELU" else "relu",
        )
        # Fuse the residual add into the second gemm.
        self.fc2 = nn.Linear(hidden_features, out_features, specialization="add")

    def forward(self, x, res):
        # get_shape is a helper defined in the example; it records the input's
        # shape so it can be restored after the fused gemms.
        shape = get_shape(x)
        x = self.fc1(x)
        x = self.fc2(x, res)
        return ops.reshape()(x, shape)
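For context, here is a minimal sketch of how such a module might be wired into an AITemplate graph and compiled. The shapes, the get_shape stand-in, the work directory, and the model name are illustrative assumptions, not part of the example above.

from aitemplate.compiler import compile_model
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target


def get_shape(x):
    # Assumed stand-in for the helper used by the example: read the static
    # dimensions of an AIT tensor as plain Python ints.
    return [d.value() for d in x._attrs["shape"]]


batch, seq, hidden = 1, 196, 768  # illustrative static shapes

x = Tensor(shape=[batch, seq, hidden], dtype="float16", name="x", is_input=True)
res = Tensor(shape=[batch, seq, hidden], dtype="float16", name="res", is_input=True)

mlp = Mlp(in_features=hidden, hidden_features=4 * hidden, out_features=hidden)
y = mlp(x, res)
y._attrs["name"] = "y"
y._attrs["is_output"] = True

# Codegen + build; "./tmp" and the test name are arbitrary here.
module = compile_model(y, detect_target(), "./tmp", "mlp_demo")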
I was wondering: what is the reason for the ops.reshape() at the end? Does the specialization change the shapes to some canonical form? What other functions need a reshape?
The explicit reshape is introduced for this reason:
In low-level math libraries such as cuBLAS/CUTLASS, a gemm is a strictly 2D problem, e.g. the RCR variant:
Y: [M, N] -> gemm_rcr(X: [M, K], W: [N, K])
In PyTorch and other frameworks, there is syntactic sugar for the ND problem, e.g.
Y: [B, S, 4H] -> torch.nn.functional.linear(X: [B, S, H], W: [4H, H])
# at the low level, X is reshaped into [B * S, H]; the output Y is initially [B * S, 4H] and is then reshaped into [B, S, 4H]
To lower this syntactic sugar to the actual low-level implementation, we insert the reshape in AIT.
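As a quick illustration of that lowering (a minimal sketch in plain PyTorch, with arbitrary shapes), the ND linear is just a flatten, a 2D gemm, and a reshape back:

import torch

B, S, H = 2, 16, 64
x = torch.randn(B, S, H)       # ND input [B, S, H]
w = torch.randn(4 * H, H)      # weight [4H, H]

x2d = x.reshape(B * S, H)      # flatten leading dims -> strictly 2D problem
y2d = x2d @ w.t()              # gemm: [B*S, H] x [H, 4H] -> [B*S, 4H]
y = y2d.reshape(B, S, 4 * H)   # restore the ND shape

assert torch.allclose(y, torch.nn.functional.linear(x, w), atol=1e-4)

In the Mlp above, get_shape(x) captures the original ND shape before the fused gemms run, and the final ops.reshape() restores it, mirroring this round trip.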