diff --git a/docs/ComputationGraphOptimization.md b/docs/ComputationGraphOptimization.md
new file mode 100644
index 000000000..87615e919
--- /dev/null
+++ b/docs/ComputationGraphOptimization.md
@@ -0,0 +1,163 @@
+# Computation Graph Optimization
+
+Currently, the buddy compiler frontend performs a direct translation from Aten IR ops, through the ops in buddy compiler's Dynamo importer, to MLIR ops. Since the actual execution does not happen on the Python side, some optimizations must be done on both the frontend and the MLIR side.
+
+![image-20240522111515951](./Images/Buddy-mlirFrontLowerStruct.png)
+
+## Operator Fusion
+
+**Background:** Computation graphs face two problems during execution. ① The memory wall: some memory-bound operators perform I/O so intensively that frequent memory swapping occurs, and the upper bound of compute performance becomes the memory bandwidth. ② The parallelism wall: the operators in the graph do not expose enough parallelism to fully exploit the many cores of GPUs and NPUs.
+
+Early hand-written operator fusion, such as loop fusion of producer-consumer operators, improves the memory locality of intermediate tensors and addresses the memory wall. To overcome the limited generality of manual fusion, AI compiler frameworks such as XLA and TVM introduced automatic loop fusion, which deeply fuses the loop nests of adjacent data-dependent operators; whether two adjacent operators can be loop-fused, however, depends on whether their loops can be legally merged. With the release of Rammer, it became clear that operator fusion can also address the parallelism wall: Rammer reschedules the operator nodes of the computation graph to raise overall parallelism, which works especially well when the network contains parallel branches. Operator fusion therefore serves two distinct goals, and the methods used for the two goals differ accordingly.
+
+**Approach:** For the memory wall: the lowering granularity in the buddy compiler frontend is coarse (an operator may be lowered directly to conv2d, for example), so fine-grained loop fusion cannot be done there; we can only cut some I/O through small-operator fusion, and whether GPU kernel launch overhead shrinks depends on the backend implementation. What the frontend can do is identify which operators are fusible, put them in one group, and connect the fused operators' inputs and outputs directly, while the lowering itself stays unchanged. Loop-level optimization has to happen in MLIR.
+
+For the parallelism wall: we can borrow the Rammer algorithm to reschedule buddy compiler ops.
+
+The schemes listed below target the memory wall.
+
+Scheme ①: Use the dominator-tree approach taken by TVM and implement rule-based operator fusion. Its benefit is that fusible operators are matched automatically, so no fused operators need to be hand-written as the number of operator kinds grows.
+
+Scheme ②: Traversal. Traverse the entire computation DAG and implement the typical compute-bound + memory-bound fused operators. Its benefit is simplicity: matching during the DAG traversal only needs operator names. The fused operators fall into the following categories: kernel fusion, vertical layer fusion, and horizontal layer fusion.
+
+![img](./Images/opfusion.png)
+
+
+
+For example, in the figure above, enlarge conv + fuse conv performs inter-kernel fusion: the 1\*1\*256 conv2d is first zero-padded to 3\*3\*256 and then fused with the other conv. Although this increases the amount of computation, it cuts two kernel launches down to one, and memory only needs to be accessed once.
+
+In the fuse conv + add horizontal fusion, conv and add are fused. "Horizontal" means that both add and conv consume the data produced by the split, so fusing them into one operator reduces what would have been two memory accesses to one.
+
+In the fuse conv + relu vertical fusion, conv and relu are fused. "Vertical" means that relu consumes the result produced by conv, so the two operators are merged into one, removing the write-back and re-read of the intermediate result.
+
+Each deep learning or inference framework has its own set of hand-written fusion patterns; for example, in torch:
+
+![image-20240523134420016](./Images/SomeFixedOpFusionCombination.png)
+
+Both schemes group operators in the Python frontend. The difference is that the traversal scheme is pattern matching with hand-written patterns, while the rule-based scheme generates the groups automatically. Either way, after lowering to MLIR, the fused operators still have to be supported and accelerated at the MLIR op level. For example, once Conv+Relu is assigned to one group, conv is lowered to tosa.conv2d and relu is lowered to tosa.ConstOp and tosa.MaximumOp, and the group number is lowered onto the corresponding MLIR ops as well (as an attribute?). After lowering to the loop level, the loops are then optimized.
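The frontend grouping step described above can be sketched as a toy pass that walks the edge list and merges vertically fusible producer-consumer pairs into one group. The node/edge encoding and the `FUSIBLE_PAIRS` table are hypothetical, not the actual buddy compiler graph API:

```python
# Toy grouping pass: merge vertically fusible producer-consumer pairs.
FUSIBLE_PAIRS = {("Conv2d", "ReLU"), ("Conv2d", "Add"), ("MatMul", "Add")}

def group_fusible(nodes, edges):
    """nodes: {name: op_type}; edges: [(producer, consumer)].
    Returns {name: group_id}: a producer and its consumer share a group
    only when the pair is fusible and the producer has one consumer."""
    consumers = {}
    for src, dst in edges:
        consumers.setdefault(src, []).append(dst)
    group = {name: i for i, name in enumerate(nodes)}
    for src, dst in edges:
        if (nodes[src], nodes[dst]) in FUSIBLE_PAIRS and len(consumers[src]) == 1:
            old, new = group[dst], group[src]
            for name in group:  # merge the consumer's group into the producer's
                if group[name] == old:
                    group[name] = new
    return group

nodes = {"conv0": "Conv2d", "relu0": "ReLU", "conv1": "Conv2d", "split0": "Split"}
edges = [("conv0", "relu0"), ("relu0", "conv1"), ("conv1", "split0")]
groups = group_fusible(nodes, edges)
```

The lowering of each op is untouched; only the group ids would be carried down to MLIR.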
+
+## Layout Transformation
+
+**Background:** In deep learning, multi-dimensional data is stored in multi-dimensional arrays. For example, the feature maps of a convolutional neural network are usually stored as four-dimensional arrays, in NCHW or NHWC order, where N is the batch size, C is the number of feature-map channels, H is the feature-map height, and W is the feature-map width. Deep learning frameworks generally use the NCHW or NHWC layout.
+
+In the NCHW layout, the values of a single channel are laid out contiguously, which suits operations that process each channel independently, such as MaxPooling. NCHW needs more storage during computation and suits GPU execution, exploiting the GPU's large memory bandwidth and strong parallelism.
+
+![image-20240521155427877](./Images/NCHWLayoutExample.png)
+
+In the NHWC layout, elements at the same position across different channels are stored consecutively, so it suits operations that combine the same position of different channels, such as "Conv1x1". NHWC suits multi-core CPUs better: CPU memory bandwidth is comparatively small, the latency of each element's computation is low, and the temporary space is small; sometimes the machine overlaps reads with computation asynchronously to reduce memory access time, so the compute control is flexible but complex.
+
+![image-20240521155723083](./Images/NHWCLayoutExample.png)
+
+**Approach:** In the frontend, decide the friendliest layout format for each operator based on its device and computation type. Measure the operator's runtime after adding a layout transform pass: if the compute time saved by the pass exceeds the time spent on the layout transform itself, apply the layout transform pass to that operator; otherwise skip it.
+
+After the computation finishes, if the result's layout matches the consumer's friendliest layout format, pass it on directly; otherwise convert it back to the input layout.
+
+Note that not only the feature map but also the weights can be layout-transformed: transforming the weights turns column-major accesses during computation into row-major accesses, although this is usually an operator-level optimization.
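As a minimal, dependency-free illustration of what the transform itself does (helper name made up for this document), the sketch below permutes a nested-list NCHW tensor into NHWC, placing the channel values of one spatial position next to each other:

```python
def nchw_to_nhwc(t):
    """Convert a nested-list tensor t[n][c][h][w] to NHWC order."""
    return [[[[t[n][c][h][w] for c in range(len(t[0]))]
              for w in range(len(t[0][0][0]))]
             for h in range(len(t[0][0]))]
            for n in range(len(t))]

# N=1, C=2, H=2, W=2
x = [[[[1, 2], [3, 4]],    # channel 0
      [[5, 6], [7, 8]]]]   # channel 1
y = nchw_to_nhwc(x)
# After the transform, the two channel values of position (0, 0) are adjacent.
```

In a real pass this permutation has a measurable cost, which is exactly what the decision rule above weighs against the compute time it saves.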
+
+## Memory Allocation
+
+**Background:** Memory allocation during inference roughly falls into four categories: allocation for the input, allocation for each operator's output tensor, allocation for operator attributes such as weight and bias, and allocation of extra scratch buffers that operators request at run time.
+
+**Approach:** The memory allocation algorithm splits into two parts: tensor lifetime analysis and the allocation itself. First, given a computation graph, the only thing that determines tensor lifetimes is the execution order of the nodes (operators). At inference time we already have the model's complete DAG and its exact execution order, so traversing the graph in some topological order yields a dependency-correct node execution order, which in turn fixes each tensor's lifetime, i.e. its allocation and release points. These can be obtained by attaching a time step to each node.
+
+At the inference stage the computation graph is fully fixed, so all the memory it uses is known. We can virtually allocate the graph's memory while importing it, and leave the real allocation to MLIR. During virtual allocation, following the node execution order, we analyze the lifetimes of the tensors associated with each node and compute the memory-block information for each tensor, such as the block size, the block id, and the allocation order.
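The lifetime analysis can be sketched in a few lines; the representation (one output tensor per node, named after the node) is an assumption made for brevity:

```python
def tensor_lifetimes(order, uses):
    """order: node execution order (each node is assumed to produce one
    tensor named after it); uses: {node: [input tensor names]}.
    Returns {tensor: (alloc_step, free_step)}: a tensor is allocated when
    its producer runs and freed after its last consumer runs."""
    step = {node: i for i, node in enumerate(order)}
    life = {node: [i, i] for i, node in enumerate(order)}
    for node in order:
        for t in uses.get(node, []):
            life[t][1] = max(life[t][1], step[node])
    return {t: tuple(se) for t, se in life.items()}

order = ["A", "B", "C", "E"]
uses = {"B": ["A"], "C": ["B"], "E": ["C"]}
life = tensor_lifetimes(order, uses)
```

These (alloc, free) intervals are exactly the input of the virtual allocation described above.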
+
+**Concretely:** how do we save memory?
+
+① In-place update and replacement:
+
+![Inplace op](./Images/alloc_inline.png)
+
+When the operator that follows some operator is an element-wise function, it can compute directly in the previous operator's memory, with no extra allocation. In the figure, the whole chain from A to B to C to E allocates no extra space. Note, however, as the figure below shows, that when operator B has an additional consumer F performing a non-element-wise operation, we can no longer compute directly in A's memory and must handle the case separately.
+
+![Inplace trap](./Images/alloc_inline_trap.png)
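One way to encode the rule above, including the trap in the second figure, is an eligibility check: an element-wise op may overwrite an input buffer only when it is that input's sole consumer. The names below are illustrative only:

```python
ELEMENTWISE = {"ReLU", "Add", "Mul", "Sigmoid"}

def can_inplace(op, inputs_of, op_types, consumer_count):
    """op may overwrite its input buffers only if op itself is
    element-wise and it is the sole consumer of each of its inputs."""
    if op_types[op] not in ELEMENTWISE:
        return False
    return all(consumer_count[t] == 1 for t in inputs_of[op])

op_types = {"A": "Conv2d", "B": "ReLU", "F": "Concat"}
# A -> B only: B (element-wise) may reuse A's buffer.
ok = can_inplace("B", {"B": ["A"]}, op_types, {"A": 1})
# A also feeds the non-element-wise consumer F: A's buffer stays live.
trap = can_inplace("B", {"B": ["A"]}, op_types, {"A": 2})
```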
+
+② The basis of memory sharing: a memory-pool algorithm
+
+Use a memory-pool algorithm and maintain a memory allocator. After one pass over the DAG, precompute each operator's lifetime, the names and counts of its children and parents, the intermediate-result buffers it needs, and the sizes of its activations. Then, following each tensor's allocation and release order, request and release the corresponding amount of memory from the pool and record each tensor's address offset. When a tensor is released back to the pool, later tensor allocations automatically reuse its space (see the register-allocation-style algorithm below for details). Once all tensors have been allocated, the maximum memory the pool ever used is the minimum memory needed to execute the graph. At real run time, we only need to allocate one block of that size and place each tensor at its recorded offset. This both optimizes total memory footprint and avoids the bookkeeping overhead of run-time allocation.
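The virtual allocation can be simulated as below: process alloc/free events in step order, reuse freed blocks first-fit, and record each tensor's static offset plus the peak arena size. This is a greatly simplified sketch (no alignment, no block splitting), not the actual allocator:

```python
def plan_memory(lifetimes, sizes):
    """lifetimes: {tensor: (alloc_step, free_step)}; sizes: {tensor: bytes}.
    Returns ({tensor: offset}, peak_bytes) for a single static arena."""
    last = max(end for _, end in lifetimes.values())
    free, offsets, top = [], {}, 0      # free: [(offset, size)]
    for step in range(last + 1):
        for t, (start, end) in lifetimes.items():
            if start == step:
                # First-fit: reuse a freed block if large enough, else grow.
                for i, (off, sz) in enumerate(free):
                    if sz >= sizes[t]:
                        offsets[t] = free.pop(i)[0]
                        break
                else:
                    offsets[t], top = top, top + sizes[t]
        for t, (start, end) in lifetimes.items():
            if end == step:             # release after the last consumer ran
                free.append((offsets[t], sizes[t]))
    return offsets, top

life = {"A": (0, 1), "B": (1, 2), "C": (2, 3), "E": (3, 3)}
sizes = {"A": 64, "B": 64, "C": 64, "E": 64}
offsets, peak = plan_memory(life, sizes)
# Two 64-byte blocks suffice for the whole A -> B -> C -> E chain.
```

At run time one block of `peak` bytes is allocated once, and every tensor lives at its recorded offset.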
+
+③ A memory-sharing algorithm: register-allocation-style sharing
+
+Take each variable as a node and each pair of overlapping variable lifetimes as an edge, build the interference graph, and run a graph-coloring algorithm, illustrated in the figure below. Traverse the whole graph in topological order, using a counter to record how many nodes depend on a node and have not yet been computed, i.e. the node's remaining lifetime. If the current operation's input variable is not referenced by any other operation, that is, the input's counter is 1, the current operation's output can reuse the input's storage in place. Temporary tags (the red, green, and blue tags in the figure) indicate memory sharing. When a node's counter drops to 0, its tag can be recycled; the rectangle in the upper-right corner holds the recycled memory available for reuse, and when two nodes' lifetimes do not overlap, a node that takes a recycled tag shares memory with the earlier node. Note, however, that graph coloring has O(n^2^) complexity and optimal coloring is NP-complete, so graph optimization time may grow.
+
+![img](./Images/MemoryAllocExample.png)
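The counter-and-tag mechanism in the figure can be sketched directly (in-place reuse of a node's own input is left out, as the text notes it is handled separately):

```python
def assign_tags(order, uses):
    """Topological-order tag assignment: each node's output gets a memory
    tag; when a node's remaining-consumer counter drops to 0, its tag is
    recycled for later nodes with non-overlapping lifetimes."""
    counter = {n: 0 for n in order}
    for n in order:
        for t in uses.get(n, []):
            counter[t] += 1
    recycled, tags, next_tag = [], {}, 0
    for n in order:
        if recycled:                    # reuse a recycled tag first
            tags[n] = recycled.pop()
        else:                           # otherwise open a new tag (color)
            tags[n], next_tag = next_tag, next_tag + 1
        for t in uses.get(n, []):
            counter[t] -= 1
            if counter[t] == 0:         # lifetime over: recycle the tag
                recycled.append(tags[t])
    return tags

order = ["A", "B", "C", "D"]
uses = {"B": ["A"], "C": ["B"], "D": ["C"]}
tags = assign_tags(order, uses)
# The linear chain needs only two tags: A/C share one, B/D share the other.
```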
+
+Alternatively, use a simpler greedy heuristic:
+
+When all children of an operator have finished computing, its lifetime can be considered over, and the memory occupied by its weights and its output activation becomes reusable. As shown in the figure, once node C, the child of B, finishes, the memory holding B's output activation is marked free. When node E starts, we scan the operators whose lifetimes have ended, and if one of them allocated a block at least as large as E needs, E reuses that operator's block.
+
+![Normal Sharing](./Images/MemoryAllocNormalExample.png)
+
+
+
+The algorithms above run during frontend optimization: they use dynamic allocation to simulate a static allocation plan, so at run time memory can be requested directly from the assigned memory ids or address offsets.
+
+## Constant Folding
+
+Constant folding and propagation in the frontend is coarse-grained, so only simple op-level folding or propagation can be done.
+
+**Approach:**
+
+① Constant folding in the classic compiler sense
+
+For example, if an operator has no variable inputs, its result can be emitted directly as a constant to its consumers.
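A minimal op-level folding pass, on a hypothetical graph encoding (not the buddy compiler one), repeatedly turns any operator whose inputs are all constants into a const node:

```python
import operator

# Evaluators for the toy ops used in this sketch.
EVAL = {"add": operator.add, "mul": operator.mul, "neg": operator.neg}

def fold_constants(graph):
    """graph: {name: ("const", value) | (op, [input names])}.
    Fold until no operator with all-constant inputs remains."""
    changed = True
    while changed:
        changed = False
        for name, node in graph.items():
            if node[0] == "const":
                continue
            op, ins = node
            vals = [graph[i] for i in ins]
            if all(v[0] == "const" for v in vals):
                graph[name] = ("const", EVAL[op](*[v[1] for v in vals]))
                changed = True
    return graph

g = {"a": ("const", 2), "b": ("const", 3),
     "c": ("add", ["a", "b"]), "d": ("mul", ["c", "c"])}
fold_constants(g)
# Both c and d collapse to constants, so the consumer sees 25 directly.
```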
+
+② Operators that depend only on the data's shape
+
+For example, the size, shape, and rank operators in TensorFlow do not depend on the input data, only on its shape. At the inference stage the graphs are static and the shapes are already fixed, so these values can be computed ahead of time and replaced with const. (The current frontend in fact already implements this when writing the buddy compiler ops.)
+
+③ Optimizations for operators with special rules
+
+The classic example is BN folding. At the inference stage the parameters of each BN layer are fixed, so BN is just a simple linear transform of the previous layer's result. Since convolution is also a linear transform, the two operations can be merged into a single linear transform, which removes many unnecessary parameters and operations. The merged conv operation is shown in the figure.
+
+![image-20240522220406922](./Images/BNFolder.png)
+
+This differs somewhat from operator fusion: here the conv and BN operations are merged directly during graph optimization and the BN operation is deleted, so no BN operator exists in the actual inference computation at all. It does, however, have to be written for every conv variant, such as group conv and transpose conv.
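The fold itself is a per-output-channel rescaling. The sketch below reduces the "conv" to a dot product for brevity and checks that the folded conv matches conv followed by BN (helper names are made up):

```python
import math

def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold an inference-time BatchNorm into the preceding conv for one
    output channel: y = gamma*(conv(x) - mean)/sqrt(var + eps) + beta
    becomes a conv with rescaled weights and a new bias."""
    scale = gamma / math.sqrt(var + eps)
    w_folded = [wi * scale for wi in w]
    b_folded = (b - mean) * scale + beta
    return w_folded, b_folded

def conv(x, w, b):  # a "conv" reduced to a dot product for brevity
    return sum(xi * wi for xi, wi in zip(x, w)) + b

x = [1.0, -2.0, 0.5]
w, b = [0.3, -0.1, 0.8], 0.2
gamma, beta, mean, var = 1.5, -0.4, 0.1, 2.0
eps = 1e-5

y_ref = gamma * (conv(x, w, b) - mean) / math.sqrt(var + eps) + beta
w_f, b_f = fold_bn_into_conv(w, b, gamma, beta, mean, var, eps)
# conv with folded parameters reproduces conv + BN exactly.
```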
+
+## Dead Code Elimination
+
+Dead code elimination in the frontend is coarse-grained, so only simple op-level operator elimination can be done.
+
+**Approach:** ① In the frontend, traverse the DAG to find nodes with no child and delete them, or delete ops that are never used at all, such as PrintOp, NoOp, and DropoutOp.
+
+② Some operators become meaningless once certain parameters are set, e.g. Add a 0, Mul a 1, Transpose, and TensorConverter. While traversing the DAG, check the parameters of such operators to make sure they are meaningful.
+
+③ Meaninglessly placed ops. For example, when the input of an Unsqueeze op is a const op, the const op can be unsqueezed ahead of time and the Unsqueeze op deleted directly; likewise, a memory-layout op attached after the model output is meaningless.
+
+④ Repeated operators of the same kind. For example, two consecutive reshapes can be merged into one, dropping the other.
+
+⑤ Mutually inverse adjacent ops. For example, if an expand op is followed by a squeeze op, both can be deleted; similarly, when a merge is immediately followed by a split, both ops can likewise be deleted.
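Point ① can be sketched as a reachability pass on a hypothetical graph encoding: no-effect ops (Dropout in inference mode, Print, NoOp) are bypassed, and anything not reachable from the model outputs is dropped:

```python
def eliminate_dead_ops(graph, outputs):
    """graph: {name: (op, [inputs])}. Keep only nodes transitively
    reachable from the outputs, bypassing no-effect ops."""
    NO_EFFECT = {"Dropout", "Print", "NoOp"}
    # Rewire consumers of a no-effect op straight to its input.
    alias = {n: ins[0] for n, (op, ins) in graph.items() if op in NO_EFFECT}

    def resolve(n):
        while n in alias:
            n = alias[n]
        return n

    live, stack = set(), [resolve(o) for o in outputs]
    while stack:
        n = stack.pop()
        if n in live:
            continue
        live.add(n)
        stack.extend(resolve(i) for i in graph[n][1])
    return {n: (op, [resolve(i) for i in ins])
            for n, (op, ins) in graph.items() if n in live}

g = {"x": ("Input", []), "c": ("Conv2d", ["x"]),
     "d": ("Dropout", ["c"]), "r": ("ReLU", ["d"]),
     "dbg": ("Print", ["c"])}
g2 = eliminate_dead_ops(g, ["r"])
# Dropout is bypassed (ReLU now reads conv directly) and Print is dropped.
```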
+
+## Algebraic Simplification
+
+**Background:** Use commutativity, associativity, and similar laws to adjust the execution order of the operators in the graph, or to delete unnecessary operators. This splits into the following parts. ① Arithmetic simplification: use the laws of arithmetic to fix an optimizable execution order of operators in the computation graph, replacing a complex combination of operators with a new one. ② Run simplification: remove operators or operator pairs that are redundant during computation or execution. ③ Broadcast simplification: when several tensors of different shapes must be broadcast to the same shape before the computation, simplify down to the minimum number of broadcast operations the computation requires.
+
+**Approach:** Algebraic simplification is essentially a set of subgraph-substitution rules, and we need to write these rules into passes that perform the replacement.
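Such a rule set can be sketched as a table of (consumer op, producer op) patterns with their replacements, applied bottom-up over an expression; the encoding below is purely illustrative:

```python
# Substitution rules: (outer_op, inner_op) -> replacement op,
# where None means "replace the pair with the inner argument".
RULES = {
    ("Neg", "Neg"): None,    # involution: Neg(Neg(x)) -> x
    ("Abs", "Abs"): "Abs",   # idempotence: Abs(Abs(x)) -> Abs(x)
}

def simplify(expr):
    """expr is a nested tuple such as ("Neg", ("Neg", "x")).
    Apply the substitution rules bottom-up."""
    if isinstance(expr, str):
        return expr
    op, arg = expr[0], simplify(expr[1])
    if not isinstance(arg, str) and (op, arg[0]) in RULES:
        repl = RULES[(op, arg[0])]
        return arg[1] if repl is None else (repl, arg[1])
    return (op, arg)

double_neg = simplify(("Neg", ("Neg", "x")))
triple_abs = simplify(("Abs", ("Abs", ("Abs", "x"))))
```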
+
+A few examples:
+
+① Arithmetic simplification
+
+![image-20240523140405325](./Images/ArithmeticSimRule1.png)
+
+
+![img](./Images/ArithmeticSimRule2.png)
+
+
+
+
+② Run simplification
+
+• Involution operator simplification:
+
+![img](./Images/InvolutionOperatorSim.png)
+
+• Idempotent operator simplification:
+
+![img](./Images/IdempotentOperatorSim.png)
+
+③ Broadcast simplification
+
+![img](./Images/BroadcastSim.png)
+
+
\ No newline at end of file
diff --git a/docs/Images/ArithmeticSimRule1.png b/docs/Images/ArithmeticSimRule1.png
new file mode 100644
index 000000000..d1b9326a1
Binary files /dev/null and b/docs/Images/ArithmeticSimRule1.png differ
diff --git a/docs/Images/ArithmeticSimRule2.png b/docs/Images/ArithmeticSimRule2.png
new file mode 100644
index 000000000..450419808
Binary files /dev/null and b/docs/Images/ArithmeticSimRule2.png differ
diff --git a/docs/Images/ArithmeticSimplification.png b/docs/Images/ArithmeticSimplification.png
new file mode 100644
index 000000000..e439cb90f
Binary files /dev/null and b/docs/Images/ArithmeticSimplification.png differ
diff --git a/docs/Images/ArithmeticSimplification2.png b/docs/Images/ArithmeticSimplification2.png
new file mode 100644
index 000000000..75618e262
Binary files /dev/null and b/docs/Images/ArithmeticSimplification2.png differ
diff --git a/docs/Images/BNFolder.png b/docs/Images/BNFolder.png
new file mode 100644
index 000000000..982ecd1f7
Binary files /dev/null and b/docs/Images/BNFolder.png differ
diff --git a/docs/Images/BroadcastSim.png b/docs/Images/BroadcastSim.png
new file mode 100644
index 000000000..357c04b92
Binary files /dev/null and b/docs/Images/BroadcastSim.png differ
diff --git a/docs/Images/Buddy-mlirFrontLowerStruct.png b/docs/Images/Buddy-mlirFrontLowerStruct.png
new file mode 100644
index 000000000..aa3a90b84
Binary files /dev/null and b/docs/Images/Buddy-mlirFrontLowerStruct.png differ
diff --git a/docs/Images/IdempotentOperatorSim.png b/docs/Images/IdempotentOperatorSim.png
new file mode 100644
index 000000000..3b62c6573
Binary files /dev/null and b/docs/Images/IdempotentOperatorSim.png differ
diff --git a/docs/Images/InvolutionOperatorSim.png b/docs/Images/InvolutionOperatorSim.png
new file mode 100644
index 000000000..6431ab770
Binary files /dev/null and b/docs/Images/InvolutionOperatorSim.png differ
diff --git a/docs/Images/MemoryAllocExample.png b/docs/Images/MemoryAllocExample.png
new file mode 100644
index 000000000..036bf8987
Binary files /dev/null and b/docs/Images/MemoryAllocExample.png differ
diff --git a/docs/Images/MemoryAllocNormalExample.png b/docs/Images/MemoryAllocNormalExample.png
new file mode 100644
index 000000000..373717be5
Binary files /dev/null and b/docs/Images/MemoryAllocNormalExample.png differ
diff --git a/docs/Images/NCHWLayoutExample.png b/docs/Images/NCHWLayoutExample.png
new file mode 100644
index 000000000..e3528c871
Binary files /dev/null and b/docs/Images/NCHWLayoutExample.png differ
diff --git a/docs/Images/NHWCLayoutExample.png b/docs/Images/NHWCLayoutExample.png
new file mode 100644
index 000000000..477e0dc30
Binary files /dev/null and b/docs/Images/NHWCLayoutExample.png differ
diff --git a/docs/Images/SomeFixedOpFusionCombination.png b/docs/Images/SomeFixedOpFusionCombination.png
new file mode 100644
index 000000000..a667a2382
Binary files /dev/null and b/docs/Images/SomeFixedOpFusionCombination.png differ
diff --git a/docs/Images/alloc_inline.png b/docs/Images/alloc_inline.png
new file mode 100644
index 000000000..5d0ffc7c7
Binary files /dev/null and b/docs/Images/alloc_inline.png differ
diff --git a/docs/Images/alloc_inline_trap.png b/docs/Images/alloc_inline_trap.png
new file mode 100644
index 000000000..99f0acc64
Binary files /dev/null and b/docs/Images/alloc_inline_trap.png differ
diff --git a/docs/Images/image-20240523140737403.png b/docs/Images/image-20240523140737403.png
new file mode 100644
index 000000000..7df036f75
Binary files /dev/null and b/docs/Images/image-20240523140737403.png differ
diff --git a/docs/Images/opfusion.png b/docs/Images/opfusion.png
new file mode 100644
index 000000000..617b2b527
Binary files /dev/null and b/docs/Images/opfusion.png differ
diff --git a/examples/BuddyVGG/CMakeLists.txt b/examples/BuddyVGG/CMakeLists.txt
new file mode 100644
index 000000000..5d0dd918f
--- /dev/null
+++ b/examples/BuddyVGG/CMakeLists.txt
@@ -0,0 +1,60 @@
+add_custom_command(
+ OUTPUT ${BUDDY_EXAMPLES_DIR}/BuddyVGG/forward.mlir ${BUDDY_EXAMPLES_DIR}/BuddyVGG/subgraph0.mlir ${BUDDY_EXAMPLES_DIR}/BuddyVGG/arg0.data
+ COMMAND python3 ${BUDDY_EXAMPLES_DIR}/BuddyVGG/buddy-vgg-import.py
+ COMMENT "Generating forward.mlir, subgraph0.mlir and parameter files"
+)
+
+add_custom_command(
+ OUTPUT forward.o
+ COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyVGG/forward.mlir
+ -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith), empty-tensor-to-alloc-tensor, convert-elementwise-to-linalg, arith-bufferize, func.func(linalg-bufferize, tensor-bufferize), func-bufferize)" |
+ ${LLVM_MLIR_BINARY_DIR}/mlir-opt
+ -pass-pipeline "builtin.module(func.func(buffer-deallocation-simplification, convert-linalg-to-loops), eliminate-empty-tensors, func.func(llvm-request-c-wrappers),convert-math-to-llvm, convert-math-to-libm, convert-scf-to-cf, convert-arith-to-llvm, expand-strided-metadata, finalize-memref-to-llvm, convert-func-to-llvm, reconcile-unrealized-casts)" |
+ ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
+ ${LLVM_MLIR_BINARY_DIR}/llvm-as |
+ ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyVGG/forward.o
+ DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyVGG/forward.mlir
+ COMMENT "Building forward.o"
+ VERBATIM)
+
+add_custom_command(
+ OUTPUT subgraph0.o
+ COMMAND ${LLVM_MLIR_BINARY_DIR}/mlir-opt ${BUDDY_EXAMPLES_DIR}/BuddyVGG/subgraph0.mlir
+ -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named, tosa-to-linalg, tosa-to-tensor, tosa-to-arith))" |
+ ${BUDDY_BINARY_DIR}/buddy-opt
+ -eliminate-empty-tensors
+ -convert-tensor-to-linalg
+ -linalg-bufferize
+ -convert-linalg-to-affine-loops
+ -lower-affine
+ -func-bufferize-dynamic-offset
+ -arith-bufferize
+ -tensor-bufferize
+ -buffer-deallocation
+ -finalizing-bufferize
+ -convert-vector-to-scf
+ -expand-strided-metadata
+ -convert-vector-to-llvm
+ -convert-arith-to-llvm
+ -finalize-memref-to-llvm
+ -convert-scf-to-cf
+ -llvm-request-c-wrappers
+ -convert-arith-to-llvm
+ -convert-func-to-llvm
+ -reconcile-unrealized-casts |
+ ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
+ ${LLVM_MLIR_BINARY_DIR}/llvm-as |
+ ${LLVM_MLIR_BINARY_DIR}/llc -filetype=obj -relocation-model=pic -O0 -o ${BUDDY_BINARY_DIR}/../examples/BuddyVGG/subgraph0.o
+ DEPENDS ${BUDDY_EXAMPLES_DIR}/BuddyVGG/subgraph0.mlir
+ COMMENT "Building subgraph0.o"
+ VERBATIM)
+
+add_library(VGG STATIC subgraph0.o forward.o)
+
+SET_TARGET_PROPERTIES(VGG PROPERTIES LINKER_LANGUAGE C)
+
+add_executable(buddy-vgg-run buddy-vgg-main.cpp)
+target_link_directories(buddy-vgg-run PRIVATE ${LLVM_MLIR_LIBRARY_DIR})
+
+set(BUDDY_VGG_LIBS VGG mlir_c_runner_utils ${OpenCV_LIBS})
+target_link_libraries(buddy-vgg-run ${BUDDY_VGG_LIBS})
diff --git a/examples/BuddyVGG/README.md b/examples/BuddyVGG/README.md
new file mode 100644
index 000000000..2b7889d82
--- /dev/null
+++ b/examples/BuddyVGG/README.md
@@ -0,0 +1,28 @@
+# Buddy Compiler VGG Example
+
+## Introduction
+This example shows how to use Buddy Compiler to compile a VGG model to MLIR code and then run it.
+## How to run
+1. Ensure that LLVM, Buddy Compiler and the Buddy Compiler Python packages are installed properly. You can refer to [here](https://github.com/buddy-compiler/buddy-mlir) for more information and double-check. (Note: 1. You should build llvm and buddy-mlir with the Python binding option enabled, and you'd better create a virtual env for that. 2. In your env, the Python version must be below 3.12, because torch.compile() currently seems to support only Python <= 3.11.)
+
+2. Set the `PYTHONPATH` environment variable.
+```bash
+$ export PYTHONPATH=/path-to-buddy-mlir/llvm/build/tools/mlir/python_packages/mlir_core:/path-to-buddy-mlir/build/python_packages:${PYTHONPATH}
+```
+
+3. Activate your python environment.
+
+```bash
+$ conda activate your-env
+```
+
+4. Build and run the VGG example
+```bash
+$ cd buddy-mlir/build
+$ cmake -G Ninja .. -DBUDDY_VGG_EXAMPLES=ON
+$ ninja buddy-vgg-run
+$ cd bin
+$ ./buddy-vgg-run
+```
+
+5. Enjoy it!
diff --git a/examples/BuddyVGG/buddy-vgg-import.py b/examples/BuddyVGG/buddy-vgg-import.py
new file mode 100644
index 000000000..5dda772c8
--- /dev/null
+++ b/examples/BuddyVGG/buddy-vgg-import.py
@@ -0,0 +1,74 @@
+# ===- buddy-vgg-import.py -----------------------------------------------------
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# ===---------------------------------------------------------------------------
+#
+# This is the VGG model AOT importer.
+#
+# ===---------------------------------------------------------------------------
+
+import os
+from pathlib import Path
+
+import torchvision
+import numpy as np
+import torch
+from torch._inductor.decomposition import decompositions as inductor_decomp
+
+from buddy.compiler.frontend import DynamoCompiler
+from buddy.compiler.graph import GraphDriver
+from buddy.compiler.graph.transform import simply_fuse
+from buddy.compiler.ops import tosa
+
+model = torchvision.models.vgg16()
+model = model.eval()
+# Initialize Dynamo Compiler with specific configurations as an importer.
+dynamo_compiler = DynamoCompiler(
+ primary_registry=tosa.ops_registry,
+ aot_autograd_decomposition=inductor_decomp,
+)
+
+data = torch.randn([1, 3, 224, 224])
+# Import the model into MLIR module and parameters.
+with torch.no_grad():
+ graphs = dynamo_compiler.importer(model, data)
+
+assert len(graphs) == 1
+graph = graphs[0]
+params = dynamo_compiler.imported_params[graph]
+pattern_list = [simply_fuse]
+graphs[0].fuse_ops(pattern_list)
+driver = GraphDriver(graphs[0])
+driver.subgraphs[0].lower_to_top_level_ir()
+path_prefix = os.path.dirname(os.path.abspath(__file__))
+with open(os.path.join(path_prefix, "subgraph0.mlir"), "w") as module_file:
+ print(driver.subgraphs[0]._imported_module, file=module_file)
+with open(os.path.join(path_prefix, "forward.mlir"), "w") as module_file:
+ print(driver.construct_main_graph(True), file=module_file)
+
+current_path = os.path.dirname(os.path.abspath(__file__))
+
+float32_param = np.concatenate(
+ [param.detach().numpy().reshape([-1]) for param in params]
+)
+
+float32_param.tofile(Path(current_path) / "arg0.data")
diff --git a/examples/BuddyVGG/buddy-vgg-main.cpp b/examples/BuddyVGG/buddy-vgg-main.cpp
new file mode 100644
index 000000000..2c195faad
--- /dev/null
+++ b/examples/BuddyVGG/buddy-vgg-main.cpp
@@ -0,0 +1,153 @@
+//===- buddy-vgg-main.cpp -----------------------------------------------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+
+#include <buddy/Core/Container.h>
+#include <buddy/DIP/ImgContainer.h>
+#include <chrono>
+#include <cmath>
+#include <cstdlib>
+#include <filesystem>
+#include <fstream>
+#include <iostream>
+#include <opencv2/imgcodecs.hpp>
+#include <opencv2/imgproc.hpp>
+#include <string>
+
+constexpr size_t ParamsSize = 138357544;
+const std::string ImgName = "YuTu.png";
+
+/// Declare vgg forward function.
+extern "C" void _mlir_ciface_forward(MemRef<float, 2> *output,
+                                     MemRef<float, 1> *arg0,
+                                     Img<float, 4> *input);
+
+/// Function for preprocessing the image to match model input requirements.
+const cv::Mat imagePreprocessing() {
+ // Get the directory of the vgg example and construct the image path.
+ std::string vggDir = getenv("VGG_EXAMPLE_PATH");
+ std::string imgPath = vggDir + "/images/" + ImgName;
+ // Read the image in grayscale mode.
+ cv::Mat inputImage = cv::imread(imgPath, cv::IMREAD_GRAYSCALE);
+ assert(!inputImage.empty() && "Could not read the image.");
+ cv::Mat resizedImage;
+ int imageWidth = 224;
+ int imageHeight = 224;
+  // Resize the image to 224x224 pixels.
+ cv::resize(inputImage, resizedImage, cv::Size(imageWidth, imageHeight),
+ cv::INTER_LINEAR);
+ return resizedImage;
+}
+
+/// Print [Log] label in bold blue format.
+void printLogLabel() { std::cout << "\033[34;1m[Log] \033[0m"; }
+
+/// Load parameters into data container.
+void loadParameters(const std::string &paramFilePath,
+                    MemRef<float, 1> &params) {
+ const auto loadStart = std::chrono::high_resolution_clock::now();
+ // Open the parameter file in binary mode.
+ std::ifstream paramFile(paramFilePath, std::ios::in | std::ios::binary);
+ if (!paramFile.is_open()) {
+ throw std::runtime_error("[Error] Failed to open params file!");
+ }
+ printLogLabel();
+ std::cout << "Loading params..." << std::endl;
+ printLogLabel();
+ // Print the canonical path of the parameter file.
+ std::cout << "Params file: " << std::filesystem::canonical(paramFilePath)
+ << std::endl;
+ // Read the parameter data into the provided memory reference.
+  paramFile.read(reinterpret_cast<char *>(params.getData()),
+ sizeof(float) * (params.getSize()));
+ if (paramFile.fail()) {
+ throw std::runtime_error("Error occurred while reading params file!");
+ }
+ paramFile.close();
+ const auto loadEnd = std::chrono::high_resolution_clock::now();
+  const std::chrono::duration<double, std::milli> loadTime =
+      loadEnd - loadStart;
+ printLogLabel();
+ std::cout << "Params load time: " << (double)(loadTime.count()) / 1000
+ << "s\n"
+ << std::endl;
+}
+
+/// Softmax function to convert logits to probabilities.
+void softmax(float *input, size_t size) {
+ size_t i;
+ float max_value = -INFINITY;
+ double sum = 0.0;
+ // Find the maximum value in the input array for numerical stability.
+ for (i = 0; i < size; ++i) {
+ if (max_value < input[i]) {
+ max_value = input[i];
+ }
+ }
+ // Calculate the sum of the exponentials of the input elements, normalized by
+ // the max value.
+ for (i = 0; i < size; ++i) {
+ sum += exp(input[i] - max_value);
+ }
+ // Normalize the input array with the softmax calculation.
+ for (i = 0; i < size; ++i) {
+ input[i] = exp(input[i] - max_value) / sum;
+ }
+}
+
+int main() {
+ // Print the title of this example.
+ const std::string title = "VGG Inference Powered by Buddy Compiler";
+ std::cout << "\033[33;1m" << title << "\033[0m" << std::endl;
+
+ // Preprocess the image to match the input requirements of the model.
+ cv::Mat image = imagePreprocessing();
+
+ // Define the sizes of the input and output tensors.
+ intptr_t sizesInput[4] = {1, 3, 224, 224};
+ intptr_t sizesOutput[2] = {1, 1000};
+
+ // Create input and output containers for the image and model output.
+  Img<float, 4> input(image, sizesInput, true);
+  MemRef<float, 2> output(sizesOutput);
+
+ // Load model parameters from the specified file.
+ std::string vggDir = getenv("VGG_EXAMPLE_PATH");
+ std::string paramsDir = vggDir + "/arg0.data";
+  MemRef<float, 1> paramsContainer({ParamsSize});
+ loadParameters(paramsDir, paramsContainer);
+
+ // Call the forward function of the model.
+  _mlir_ciface_forward(&output, &paramsContainer, &input);
+
+ // Apply softmax to the output logits to get probabilities.
+ auto out = output.getData();
+ softmax(out, 1000);
+
+ // Find the classification and print the result.
+ float maxVal = 0;
+  int maxIdx = 0;
+ for (int i = 0; i < 1000; ++i) {
+ if (out[i] > maxVal) {
+ maxVal = out[i];
+ maxIdx = i;
+ }
+ }
+
+ std::cout << "Classification: " << maxIdx << std::endl;
+ std::cout << "Probability: " << maxVal << std::endl;
+
+ return 0;
+}
diff --git a/examples/BuddyVGG/images/YuTu.png b/examples/BuddyVGG/images/YuTu.png
new file mode 100644
index 000000000..91868fc8f
Binary files /dev/null and b/examples/BuddyVGG/images/YuTu.png differ
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 7ec0d3b4f..805523ee5 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -16,6 +16,10 @@ if (BUDDY_LENET_EXAMPLES)
add_subdirectory(BuddyLeNet)
endif()
+if (BUDDY_VGG_EXAMPLES)
+ add_subdirectory(BuddyVGG)
+endif()
+
if(BUDDY_DSL_EXAMPLES)
add_subdirectory(ToyDSL)
endif()
diff --git a/frontend/Python/frontend.py b/frontend/Python/frontend.py
index bd92a8074..1f7c08bec 100644
--- a/frontend/Python/frontend.py
+++ b/frontend/Python/frontend.py
@@ -158,6 +158,7 @@ def __init__(
"where.self": WhereOp,
"sqrt.default": SqrtOp,
"reciprocal.default": ReciprocalOp,
+ "_adaptive_avg_pool2d.default": AdaptiveAvgPool2dOP,
}
@property
diff --git a/frontend/Python/graph/operation.py b/frontend/Python/graph/operation.py
index 903b12865..ab12b7384 100644
--- a/frontend/Python/graph/operation.py
+++ b/frontend/Python/graph/operation.py
@@ -447,7 +447,13 @@ def __init__(self) -> None:
super().__init__()
self._op_type = OpType.ReduceType
self._layout = "NCHW"
-
+
+class AdaptiveAvgPool2dOP(Op):
+ def __init__(self) -> None:
+ super().__init__()
+ self._op_type = OpType.ReduceType
+ self._layout = "NCHW"
+
class CallOp(Op):
def __init__(self) -> None:
super().__init__()
diff --git a/frontend/Python/ops/tosa.py b/frontend/Python/ops/tosa.py
index e5fe9a4e3..08e487ac7 100644
--- a/frontend/Python/ops/tosa.py
+++ b/frontend/Python/ops/tosa.py
@@ -21,7 +21,7 @@
import array
from typing import Dict, List, Tuple, Union
import numpy
-
+import math
import mlir.ir as ir
from mlir.dialects import tensor, tosa
@@ -57,6 +57,7 @@
SigmoidOp,
ReciprocalOp,
MeanOp,
+ AdaptiveAvgPool2dOP,
)
from .utils import *
@@ -961,7 +962,35 @@ def maxpool2d_op(node: MaxPool2dOp, symbol_table):
permute_result_type, op.result, perm_const_op.results[0]
)
return op
-
+def adaptive_avg_pool2d_op(node: AdaptiveAvgPool2dOP, symbol_table):
+    assert len(node.args) == 2
+    # Only the NCHW layout is supported for now.
+    assert node._layout.find("NCHW") != -1
+
+    input1 = symbol_table.get((str(node.args[0]), 0))
+    input_size = list(ir.RankedTensorType(input1.type).shape)[-1]
+    out_shape = node.tensor_meta["shape"]
+    out_size = out_shape[-1]
+    # Derive a fixed kernel and stride that reproduce the adaptive pooling
+    # when input_size is divisible by out_size.
+    stride = math.floor(input_size / out_size)
+    kernel = input_size - (out_size - 1) * stride
+    pad = [0] * 4
+    stride = [stride] * 2
+    kernel = [kernel] * 2
+    pad_attr = ir._denseI64ArrayAttr(pad, None)
+    kernel_attr = ir._denseI64ArrayAttr(kernel, None)
+    stride_attr = ir._denseI64ArrayAttr(stride, None)
+
+    dtype = node.tensor_meta["dtype"]
+    result_element_type = mlir_element_type_get(dtype)
+
+    output = ir.RankedTensorType.get(out_shape, result_element_type)
+    # acc_type: the accumulator element type of the average pooling.
+    acc_type_attr = ir._typeAttr(result_element_type, None)
+    op = tosa.AvgPool2dOp(
+        output, input1, kernel_attr, stride_attr, pad_attr, acc_type_attr
+    )
+    return op
+
+
def convolution2d_op(node: Conv2dOp, symbol_table):
"""
Import the convolution operation.
@@ -1246,4 +1275,5 @@ def mean_op(node: MeanOp, symbol_table):
"SigmoidOp": sigmoid_op,
"ReciprocalOp": reciprocal_op,
"MeanOp": mean_op,
+    "AdaptiveAvgPool2dOP": adaptive_avg_pool2d_op,
}