Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Adding support for compiled model #247

Merged
merged 3 commits into from May 26, 2023

Conversation

yaoyaoding
Copy link
Member

A major update.

Compiled model

Allow hidet compile a flow graph into a CompiledModel, which contains all the weights, compiled kernels, and the graph execution library, which provides a new way to store the compiled model.

API:

graph: FlowGraph
model: hidet.runtime.CompiledModel = graph.build()
outputs = model(inputs)
model.save("model.hidet")
model_loaded = hidet.load_model("model.hidet")

Usage:

def test_load_save(device: str):
    # construct graph
    x = hidet.symbol([2, 3], device=device)
    w1 = hidet.randn([3, 4], device=device)
    w2 = hidet.randn([4, 5], device=device)
    y = hidet.ops.matmul(hidet.ops.matmul(x, w1), w2)

    # get computation graph
    graph = hidet.trace_from(y)

    # optimize the graph
    graph = hidet.graph.optimize(graph)

    # build the graph
    model = graph.build()

    # save the model
    model.save('./model.hidet')

    # load the model
    loaded_model = hidet.load_model('./model.hidet')

    # compare the results
    xx = hidet.randn([2, 3], device=device)
    y1 = graph(xx)
    y2 = model(xx)
    y3 = loaded_model(xx)

    numpy.testing.assert_allclose(y1.cpu().numpy(), y2.cpu().numpy())
    numpy.testing.assert_allclose(y1.cpu().numpy(), y3.cpu().numpy())

Hidet Script

Function Kind

We redesigned the allowed function kinds:

  • cuda_internal: this is a cuda device function, can only be called by cuda function
  • cuda_kernel: this is a cuda kernel function
  • cpu_kernel: this is a cpu kernel function
  • cpu_internal: this is a cpu function but not a kernel
  • public: this is a packed function that wraps kernel function(s), or any function that can be called from external host (e.g., python or C++ users)

Do not use packed arguments

Originally, we generate the following style of launch function

void launch(int num_args, int* arg_codes, void** args) {...}

We changed it to the normal form like

void launch(void *a, void* b, int c, float d)

Meta programming

import hidet


def test_args():
    from hidet.lang import attrs, meta, printf, int32

    with hidet.script_module() as script_module:

        @hidet.script
        def launch(args: meta.types([int, bool, float, int32]), second_args: int, thrid_args: meta.types([int32])):
            attrs.func_kind = 'public'

            printf("%d\n", args[0])
            printf("%d\n", args[1])
            printf("%f\n", args[2])
            printf("%d\n", args[3])
            printf("%d\n", second_args)
            printf("%d\n", thrid_args[0])

    module = script_module.build()
    module(1, True, 0.1, 2, 3, 4)
    print(module.source(True))

would generate:

void hidet_launch(int32_t args, bool args_1, float args_2, int32_t args_3, int32_t second_args, int32_t thrid_args) {
  printf("%d\n", args);
  printf("%d\n", args_1);
  printf("%f\n", args_2);
  printf("%d\n", args_3);
  printf("%d\n", second_args);
  printf("%d\n", thrid_args);
}

This allows us to declare a function with compilation-time-known-length parameters.

.
lint & format

.

.

.

.

.

.

.

.

.

.

.

.

.

.
@yaoyaoding yaoyaoding merged commit cf68f53 into hidet-org:main May 26, 2023
2 checks passed
@yaoyaoding yaoyaoding deleted the compiled-model branch May 26, 2023 04:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant