In [10]:
from typing import List
import torch

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print(">>> my_compiler() invoked:")
    print(">>> FX graph:")
    gm.graph.print_tabular()
    print(f">>> Code:\n{gm.code}")
    return gm.forward  # return a python callable

@torch.compile(backend=my_compiler)
def foo(x, y):
    return (x + y) * x

if __name__ == "__main__":
    a, b = torch.randn(10), torch.ones(10)
    foo(a, b)

>>> my_compiler() invoked:
>>> FX graph:
opcode         name    target                   args          kwargs
-------------  ------  -----------------------  ------------  --------
placeholder    l_x_    L_x_                     ()            {}
placeholder    l_y_    L_y_                     ()            {}
call_function  add     <built-in function add>  (l_x_, l_y_)  {}
call_function  mul     <built-in function mul>  (add, l_x_)   {}
output         output  output                   ((mul,),)     {}
>>> Code:



def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
    l_x_ = L_x_
    l_y_ = L_y_
    add = l_x_ + l_y_;  l_y_ = None
    mul = add * l_x_;  add = l_x_ = None
    return (mul,)
    


In [11]:
import dis

def hello():
    print("Hello, world!")

for k in ["co_names", "co_varnames", "co_consts"]:
    print(k, getattr(hello.__code__, k))
print(dis.dis(hello))

co_names ('print',)
co_varnames ()
co_consts (None, 'Hello, world!')
  4           0 LOAD_GLOBAL              0 (print)
              2 LOAD_CONST               1 ('Hello, world!')
              4 CALL_FUNCTION            1
              6 POP_TOP
              8 LOAD_CONST               0 (None)
             10 RETURN_VALUE
None


In [12]:
import torch
def fn(x, y):
    a = torch.sin(x).cuda()
    b = torch.sin(y).cuda()
    return a + b
new_fn = torch.compile(fn, backend="inductor")
input_tensor = torch.randn(10000).to(device="cuda:0")
a = new_fn(input_tensor,input_tensor)

In [13]:
import torch 
import time
from transformers import BertTokenizer,BertModel
import torch._dynamo as dynamo

In [22]:
device = torch.device("cuda:0" if torch.cuda.is_available else "cpu")
tokenizer = BertTokenizer.from_pretrained('/dataset/crosspipe/bert-base-uncased')
model = BertModel.from_pretrained("/dataset/crosspipe/bert-base-uncased").to(device)
model = dynamo.optimize("inductor")(model)
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0")

In [86]:
with torch.no_grad():
    torch.cuda.synchronize()
    start = time.time()
    output = model(**encoded_input)
    end = time.time()
    torch.cuda.synchronize()
    print(end -start)    

0.009456157684326172


In [81]:
num_iteration = 10
total_time = 0
for _ in range(num_iteration):
    with torch.no_grad():
        torch.cuda.synchronize()
        start = time.time()
        output = model(**encoded_input)
        end = time.time()
        torch.cuda.synchronize()
        print(end-start)
        total_time += (end -start)
        
average_time = total_time / num_iteration
print(f"averge time over {num_iteration} iterations is {average_time}")

averge time over 10 iterations is 0.004197549819946289
