
[Frontend] Dynamic shape fx trace #294

Merged
merged 24 commits into hidet-org:main on Jul 12, 2023

Conversation

@Aalanli (Collaborator) commented Jun 27, 2023

Enable the option torch.compile(..., dynamic=True).

  • Convert torch FakeTensor inputs to hidet symbolic Tensors (see the sketch after this list)
  • There may be a bug in torch.dynamo, so we filter/pre-process the non-tensor entries both in the example inputs and in the wrapped function
  • Altered the graph interpreter to support non-torch functions, such as the builtins add, getitem, etc., removing the dependence on registered functions
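For illustration only, here is a minimal sketch of mapping a torch FakeTensor (whose dimensions may be SymInt under dynamic=True) to a hidet symbolic tensor. This is not the PR's actual code; the helper name to_hidet_symbolic and the naming scheme for dynamic dimensions are assumptions, though hidet.symbol itself does accept named symbolic dimensions.

import torch
import hidet

def to_hidet_symbolic(fake: torch.Tensor) -> hidet.Tensor:
    # Keep concrete dimensions as plain ints; give every dynamic (SymInt)
    # dimension a symbolic name so hidet treats it as a dynamic shape.
    shape = [int(d) if isinstance(d, int) else 'dim_{}'.format(i)
             for i, d in enumerate(fake.shape)]
    dtype = str(fake.dtype).replace('torch.', '')  # e.g. torch.float32 -> 'float32'
    return hidet.symbol(shape, dtype=dtype, device='cuda')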

@Aalanli (Collaborator, Author) commented Jun 29, 2023

Essentially there was a bug with the norm op. It tries to achieve polymorphism (f32 vs f16) through class inheritance (the fp16 task subclasses the fp32 task), but this produces incorrect behaviour when combined with the automatic mixed precision pass: the Op was originally created in fp32 and gets re-forwarded in the pass, yet the implement_cuda schedule template still assumes the input is fp32. The result is an array of fp16 inputs reinterpreted as fp32; the C++ pointers cast silently.
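As a standalone illustration of that failure mode (plain NumPy, not hidet code), reading fp16 bytes as fp32 raises no error and just produces wrong values:

import numpy as np

half = np.ones(8, dtype=np.float16)   # eight fp16 values of 1.0 (0x3C00 each)
wrong = half.view(np.float32)         # the same 16 bytes reinterpreted as four fp32 values
print(wrong)                          # ~0.0078 repeated four times, with no error or warning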

I think there are two ways to achieve type polymorphism in schedule templates right now.

  1. Write the kernel using generic types
  2. Write the Op and Task only using declarative definitions, then write additional Ops and Tasks that implement cuda, and derive the resolve rules. This works since the declarative definitions are polymorphic, while the resolve_variant pass happens after the automatic_mixed_precision pass.

@Aalanli Aalanli requested a review from yaoyaoding June 29, 2023 19:24
@yaoyaoding (Member)

Hi @Aalanli,

The second method is our current design. We have some base operators (matmul, conv2d) that support arbitrary data types and are scheduled by the auto-scheduler. These base operators will be resolved to specialized ones with specialized templates, and we should check the special conditions in the task definition (like in this case, we should assert that the input dtype is fp16).

@yaoyaoding (Member) left a comment

Thanks @Aalanli. Overall looks good to me!

@xinli-git could you also have a look at this PR (especially about the normalization part).

Comment on lines 107 to 108
# unfortunately, when dynamic=True in torch.compile, there may exist other non-tensor parameters
# in example inputs
Member:

For those dynamic shapes, I am wondering whether these scalar parameters act as the shapes of the input tensors. If that's the case, we can ignore those scalar parameters.

Say a torch model gives us

sample_inputs = [tensor(['m', 'n']), 'm', 'n']

We can declare the symbol variables for 'm' and 'n' (when we define the symbolic tensor) and ignore the 'm' and 'n' scalar parameters.
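A rough sketch of that suggestion (the dtype and device are placeholders; hidet.symbol does accept named symbolic dimensions):

import hidet

# Declare the input once with symbolic dims 'm' and 'n'; the scalar 'm' and 'n'
# entries in sample_inputs then add no information and can be dropped, since
# they are recoverable from x.shape.
x = hidet.symbol(['m', 'n'], dtype='float32', device='cuda')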

Member:

Any clue on this?



@register_function(operator.iadd)
def iadd(x: Tensor, y: Tensor):
-    return ops.add(x, y)
+    return x + y
Member:

So the x and y could be DynInt?

@yaoyaoding (Member)

I think there are two ways to achieve type polymorphism in schedule templates right now.

  1. Write the kernel using generic types
  2. Write the Op and Task only using declarative definitions, then write additional Ops and Tasks that implement cuda, and derive the resolve rules. This works since the declarative definitions are polymorphic, while the resolve_variant pass happens after the automatic_mixed_precision pass.

To be more specific, the hidet task and its schedule template should ensure that the schedule template strictly implements what the computation defines. We can take both of the ways you mentioned. For example, our batch_matmul schedule template can support generic types (e.g., fp32, fp16, int32, int16, int8), but it requires the shapes to be [B, M, K] and [B, K, N]. For the second case, we have base operators and their variants. We should make sure that our resolve rules are correct, and also add enough assertions in the task computation definition to reject the unsupported cases (in this case, the fp16 computation did not check the data type?).
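A minimal, self-contained sketch of that convention; the class and attribute names below are illustrative stand-ins, not hidet's actual Task API. The specialized task asserts the only dtype its schedule template implements, so a re-forward after automatic mixed precision fails loudly instead of silently reinterpreting data.

from dataclasses import dataclass

@dataclass
class TensorInput:            # stand-in for a task's tensor input
    dtype: str
    shape: tuple

class NormalizeF16Task:       # hypothetical specialized task
    def __init__(self, x: TensorInput):
        # The fp16 schedule template only implements fp16; enforcing that in
        # the task definition keeps the computation and the implement
        # function consistent.
        assert x.dtype == 'float16', 'expected float16 input, got {}'.format(x.dtype)
        self.x = x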

@xinli-git (Collaborator) left a comment

Thanks for the changes in normalize. In principle, this is the right approach. I initially left two implementations so that I could add vectorized loads for the fp16 case in the future.

But now that there is the vector data type that Yaoyao has recently introduced, keeping op and op_fp16 in a single place is the right way to go, and I intend to do the same for the reduce op.

check_module(model, [x], atol=1e-2, rtol=1e-2)
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True).cuda().eval()
x = torch.randn(*shape).cuda()
check_module(model, [x], atol=1e-2, rtol=1e-2, dynamic=dynamic)
Collaborator:

Have we been using the CPU path before this change?

x: Tensor = op.inputs[0]
if not is_contiguous_norm(dims, len(x.shape)):
    return None
if x.dtype != dtypes.float16 or prod([x.shape[dd] for dd in dims]) % 2 != 0:
@xinli-git (Collaborator) commented Jul 2, 2023:

Removing this is safe for now, but we might need to think about how to handle it when we decide to use 2xfp16 types and the norm size is odd.
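For context, a tiny illustration of why an odd norm size matters for a future 2xfp16 (half2) path; the numbers are made up:

norm_size = 1023              # an odd reduction size
num_half2 = norm_size // 2    # 511 vectorized 2xfp16 loads
tail = norm_size % 2          # 1 leftover fp16 element needing scalar handling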

@@ -32,15 +29,6 @@ class NormalizeResolveRule(ResolveRule):
2) resolve_generic: Default case, return the output of the regular f32 reduce schedule.
Collaborator:

Remove the resolve_fp16 comment above.

@xinli-git (Collaborator)

@yaoyaoding the part about normalize is fine as long as the current CI can pass. Thanks for the notification :)

@xinli-git (Collaborator) commented Jul 4, 2023

Reading this again, I think the problem is that op.reforward used the fp32 implement_cuda after automatic mixed precision had changed the inputs to fp16?

Is my understanding correct that we should not sub-class operators? We should either write them as separate classes or have a generic Operator / Task that works for all input data types?

Basically, this line is causing the problem: https://github.com/hidet-org/hidet/blob/main/python/hidet/graph/operator.py#L166 ?

@yaoyaoding (Member)

Basically this line is causing the problem: https://github.com/hidet-org/hidet/blob/main/python/hidet/graph/operator.py#L166 ?

That line does not have a problem. The reforward will create the task again based on the new inputs and parameters. The problem is that the task did not check the data type. If a task only supports one data type, it should explicitly assert that its input has that data type. If it accepts the inputs, then its implement function SHOULD support them.

We can sub-class operators like ElementwiseBinaryOp, UnaryElementwiseOp, etc.

@yaoyaoding (Member)

The key convention here is: keep the task computation definition and the implement function consistent.

@yaoyaoding (Member) left a comment

I am still not sure what the extra scalar parameters are; let's figure them out before merging this PR.

@yaoyaoding (Member)

Thanks @Aalanli !

@yaoyaoding yaoyaoding merged commit 9d51c74 into hidet-org:main Jul 12, 2023
2 checks passed