[Bug] [Unity] [Metascheduler] cannot tune relax linear/matmul with M > 1 for cuda #227

Open
elvin-n opened this issue May 24, 2023 · 4 comments


@elvin-n
Contributor

elvin-n commented May 24, 2023

Unable to tune linear/matmul with an M value bigger than 1.

The error message is different from the one on the Unity branch, which is why I am submitting this bug here: changes in mlc-ai relax affected this use case, and it seems they should be fixed here as well, not only in Unity.

import tvm
from tvm import meta_schedule as ms
from tvm.relay.backend import Executor
from tvm import relax
from tvm.relax.testing import nn

# -------- Func definition
class Linear(nn.Module):
    def __init__(self, in_features, out_features, dtype: str, bias=False):
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            (out_features, in_features), dtype=dtype, name="linear_weight"
        )
        if bias:
            self.bias = nn.Parameter((out_features,), dtype=dtype, name="linear_bias")
        else:
            self.bias = None

    def forward(self, input: relax.Expr) -> relax.Var:
        return nn.emit(relax.op.linear(input, self.weight, self.bias))

bb = relax.BlockBuilder()
seq_len = 4
with bb.function("func1"):
    model = Linear(2048, 768, "float16")
    input = nn.Placeholder((seq_len, 2048), dtype="float16", name="input")
    with bb.dataflow():
        res = model(input)
        params = [
            input,
        ] + model.parameters()
        gv = bb.emit_output((res,))
    bb.emit_func_output(gv, params)

mod = bb.get()
gv = mod.get_global_var("func1")
bb.update_func(gv, mod[gv].with_attr("func1", 1))

mod = relax.pipeline.get_pipeline()(mod)
mod = relax.transform.LiftTransformParams()(mod)

mod = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)

# ------ Metascheduler starts here
database = None

strategy_name = "evolutionary"
name = f"relax_linear_{seq_len}_2048_2048_768"
work_dir = f"./{name}/"
module_equality_name = "ignore-ndarray"

target = tvm.target.Target("nvidia/geforce-rtx-2060", host="llvm")
executor = Executor("graph")
mod = mod.with_attr("executor", executor)
ndk_builder = ms.builder.LocalBuilder(timeout_sec=60)
evaluator_config = ms.runner.EvaluatorConfig(
    number=3,
    repeat=1,
    min_repeat_ms=100,
    enable_cpu_cache_flush=False,
)
ms_rpc_runner = ms.runner.LocalRunner(
    evaluator_config=evaluator_config,
    alloc_repeat=1,
)
ms.relax_integration.tune_relax(
    mod=mod,
    target=target,
    params={},
    work_dir=work_dir,
    max_trials_global=1024,
    strategy=strategy_name,
    builder=ndk_builder,
    runner=ms_rpc_runner,
    module_equality=module_equality_name,
)
@junrushao
Member

Hey, thanks for reporting! Would you mind elaborating on what the M value is? Is it possible that it's because of the mixed usage of i32 and i64?

@elvin-n
Contributor Author

elvin-n commented May 25, 2023

Would you mind elaborating on what the M value is?

M is, for example, the input sequence length. In the case of dense it is the batch size.
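
For clarity, a minimal sketch of what M refers to, in plain NumPy rather than TVM (assumption: the shapes are the ones from the reproducer above; relax.op.linear(input, weight) computes input @ weight.T):

import numpy as np

# Hedged sketch, plain NumPy: with input (M, K) and weight (N, K),
# relax.op.linear produces an output of shape (M, N).
M, K, N = 4, 2048, 768                  # M = seq_len in the reproducer
x = np.zeros((M, K), dtype="float16")   # input placeholder
w = np.zeros((N, K), dtype="float16")   # linear_weight
y = x @ w.T                             # output shape (M, N) == (4, 768)
assert y.shape == (M, N)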

Is it possible that it's because of the mixed usage of i32 and i64?

Where does it happen? If you are referring to the ForceNarrowIndexToInt32 transformation, then removing that transformation's invocation does not affect the behaviour.
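
For reference, a minimal sketch of the comparison described above (assumption: build_relax_module() is a hypothetical helper wrapping the BlockBuilder code from the reproducer; it is not an existing API):

import tvm
from tvm import relax

mod = build_relax_module()                        # hypothetical helper
mod = relax.pipeline.get_pipeline()(mod)
mod = relax.transform.LiftTransformParams()(mod)

# Variant A: apply the narrowing pass, as in the reproducer.
mod_a = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)

# Variant B: skip ForceNarrowIndexToInt32 and pass `mod` to tune_relax as-is.
# Both variants fail the same way for M > 1 in my runs.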

@elvin-n
Contributor Author

elvin-n commented May 25, 2023

I found that tuning starts to work if I set seq_len = 32, as opposed to Unity, where tuning starts to work if I set this parameter to 16.

@elvin-n
Contributor Author

elvin-n commented May 25, 2023

One more fact: the Metascheduler worked for M == 32 with commit c0e4557, but with the latest commit 5b8db51 it cannot tune for any size of M.
