In [1]:
import os

import torch
import torch_mlir

from transformers import BertForMaskedLM

# Wrap the bert model to avoid multiple returns problem


class BertTinyWrapper(torch.nn.Module):
    """Thin module around the "prajjwal1/bert-tiny" masked-LM checkpoint.

    The wrapped model is loaded with ``return_dict=False`` and ``forward``
    hands back only element 0 of whatever the wrapped model returns, so the
    wrapper exposes a single return value.
    """

    def __init__(self) -> None:
        super().__init__()
        self.bert = BertForMaskedLM.from_pretrained(
            "prajjwal1/bert-tiny",
            return_dict=False,
        )

    def forward(self, data):
        """Run the wrapped model on ``data`` and return its first output."""
        outputs = self.bert(data)
        # Presumably the prediction logits — confirm against the
        # BertForMaskedLM output documentation.
        first_output = outputs[0]
        return first_output


# Build the wrapper and put it in evaluation mode before tracing.
model = BertTinyWrapper()
model.eval()
# Example input for tracing: a (2, 128) batch of random integer token ids in
# [0, 30522). 30522 is presumably the checkpoint's vocabulary size — confirm.
data = torch.randint(30522, (2, 128))
out_mlir_path = "../output/tinybert-linalg_on_tensors.mlir"

# for MHLO
#module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)

# for linalg on tensors
module = torch_mlir.compile(
    model, data, output_type="linalg-on-tensors", use_tracing=True)

# Create the output directory up front; open(..., "w") does not create missing
# parent directories and would raise FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(out_mlir_path), exist_ok=True)
with open(out_mlir_path, "w", encoding="utf-8") as outf:
    outf.write(str(module))

print(f"tiny bert successfully written into {out_mlir_path}")


  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at prajjwal1/bert-tiny were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tiny bert successfully written into ../output/tinybert-linalg_on_tensors.mlir


In [54]:
# Edit a copy of the generated file so the original compiler output is kept.
!cp ../output/tinybert-linalg_on_tensors.mlir ../output/tinybert-linalg_on_tensors-edit.mlir

# The file has a line that contains "ml_program.global";
# delete every such line from the copy, in place.
# Note: the option must be written "-i" with no space. "sed - i" passes "-" and
# "i" as separate arguments, so the in-place delete is never performed.
!sed -i '/ml_program.global/d' ../output/tinybert-linalg_on_tensors-edit.mlir


In [51]:
# Lowering to LLVM
# !mlir-opt -convert-linalg-to-loops -convert-scf-to-std
#
# Runs mlir-opt on the hand-edited linalg-on-tensors file
# (../output/tinybert-linalg_on_tensors-edit.mlir) and writes the result to
# ../output/tinybert-llvm.mlir. Going by the pass names, in order, the
# pipeline roughly does the following:
#   1. canonicalize, convert tensor ops to linalg, handle empty tensors;
#   2. bufferize (linalg/arith/tensor/func + finalizing), insert buffer
#      deallocation, and turn buffer results into out-params;
#   3. lower linalg to loops, scf to cf, and affine;
#   4. convert memref/arith/math/func/cf to the LLVM dialect and reconcile
#      unrealized casts.
# NOTE(review): several passes are listed more than once
# (-convert-linalg-to-loops, -convert-scf-to-cf, -convert-linalg-to-llvm,
# -convert-memref-to-llvm, --canonicalize). Presumably leftovers from
# experimentation — confirm before pruning, since pass order matters.
# NOTE(review): the mlir-opt path below is an absolute, machine-specific
# location (LLVM 16.0.0 release build); adjust it for other environments.
# fmt: off
# !/working_dir/builds/llvm-project/build-x86/bin/mlir-opt \
!/home/developer/mlir/clang+llvm-16.0.0-x86_64-linux-gnu-ubuntu-18.04/bin/mlir-opt \
    --canonicalize \
    -convert-tensor-to-linalg \
    -empty-tensor-to-alloc-tensor \
    -eliminate-empty-tensors \
    -linalg-bufferize -arith-bufferize \
    -tensor-bufferize -func-bufferize \
    -finalizing-bufferize -buffer-deallocation \
    --buffer-results-to-out-params \
    --canonicalize -cse \
    -convert-linalg-to-loops \
    -convert-scf-to-cf \
    -convert-linalg-to-llvm \
    -lower-affine \
    -convert-scf-to-cf \
    --convert-memref-to-llvm \
    -convert-linalg-to-loops \
    -convert-scf-to-cf \
    -convert-linalg-to-llvm \
    -convert-memref-to-llvm \
    -convert-arith-to-llvm \
    -convert-math-to-llvm \
    --convert-math-to-libm \
    --canonicalize \
    -convert-func-to-llvm \
    -convert-cf-to-llvm \
    --test-lower-to-llvm \
    -reconcile-unrealized-casts \
    -o ../output/tinybert-llvm.mlir \
    ../output/tinybert-linalg_on_tensors-edit.mlir


In [55]:
# Translate the LLVM-dialect MLIR file into textual LLVM IR (tinybert.ll).
# fmt: off
!/home/developer/mlir/clang+llvm-16.0.0-x86_64-linux-gnu-ubuntu-18.04/bin/mlir-translate --mlir-to-llvmir \
    -o ../output/tinybert.ll \
    ../output/tinybert-llvm.mlir
# fmt: on


In [56]:
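# The generated LLVM file has a function called forward(), which is the entry
# point for the model.

# Before lowering, in the linalg-on-tensors MLIR, the function has a signature like this:
# func.func @forward(%arg0: tensor<2x128xi64>) -> tensor<2x128x30522xf32> {

# It can be called from C/C++ code by passing the 2x128 input as its expanded
# memref descriptor fields (allocated pointer, aligned pointer, offset,
# sizes 2 and 128, strides 128 and 1), like this:
#   forward((int *)arg0, (int *)arg0, 0, 2, 128, 128, 1);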


In [59]:
%%writefile ../output/tinybert.c

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>

#include "tinybert.h"

/*
 * Times a single call to the compiled model entry point forward() and prints
 * the elapsed CPU time in seconds.
 *
 * Returns 0 on success, 1 if the input buffer cannot be allocated.
 */
int main(int argc, char **argv) {
    (void)argc;
    (void)argv;

    /*
     * The model input is a 2x128 tensor of i64 elements, so the buffer is
     * sized for 64-bit values. calloc zero-fills it, so every token id is 0
     * rather than uninitialised memory.
     * NOTE(review): the pointer type is kept as int * to match the existing
     * call; tinybert.h is not visible here — confirm the declared type.
     */
    int *arg0 = (int *)calloc(2 * 128, sizeof(int64_t));
    if (arg0 == NULL) {
        fprintf(stderr, "Failed to allocate input buffer\n");
        return 1;
    }

    // time the execution
    clock_t start, end;
    double cpu_time_used;
    start = clock();
    /*
     * Arguments: allocated pointer, aligned pointer, offset, sizes (2, 128),
     * strides (128, 1) of the input buffer.
     * NOTE(review): the lowering pipeline uses --buffer-results-to-out-params,
     * so forward() presumably also expects arguments describing the result
     * buffer — confirm against tinybert.h.
     */
    forward(arg0, arg0, 0, 2, 128, 128, 1);
    end = clock();
    cpu_time_used = ((double)(end - start)) / CLOCKS_PER_SEC;
    printf("Time taken: %f seconds\n", cpu_time_used);

    free(arg0);
    return 0;
}


Overwriting ../output/tinybert.c
