In [1]:
from utils.hiearchy_simple_ppo_env import ParallelEnv
from utils.ppo_model import HiearchyModel as MyModel
from utils.consts import (
    MAX_NUM_STORES_LOADS,
    MAX_NUM_LOOPS,
    MAX_NUM_LOAD_STORE_DIM
)

import torch

# Device used for observation tensors. CPU by default; set USE_CUDA = True to use
# a GPU when one is available. NOTE(review): the models built below are not
# moved with `.to(device)` in every cell, so enabling CUDA may also require that.
USE_CUDA = False
device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")

# Define the environment:

In [2]:
# Run configuration. Only `truncate` and `json_file` are read by the cells in this
# notebook; the remaining keys look like PPO training hyper-parameters kept for
# reference (not used by the cells shown).
CONFIG = dict(
    len_trajectory=64,
    ppo_batch_size=64,
    steps=10000,
    ppo_epochs=4,
    logs=True,
    entropy_coef=0.01,
    lr=0.001,
    truncate=5,
    json_file="generated_data/bigger_input_nn_(32x230x230x3)operations.json",
)

# A single environment instance (num_env=1) built from the operations JSON file;
# later cells index batch results with [0] and rely on this.
env = ParallelEnv(
    json_file=CONFIG["json_file"],
    truncate=CONFIG["truncate"],
    num_env=1,
    reset_repeat=1,
    step_repeat=1,
)

100%|██████████| 19/19 [00:00<00:00, 19.58it/s]


# Define the model:

In [3]:
# Length of the flat observation vector fed to the policy network. It presumably
# has to match the feature layout emitted by ParallelEnv — confirm if the constants
# change. With the current constants it evaluates to 411.
# NOTE(review): the meaning of the literals 5 and 3 is not visible from here.
input_dim = (
    MAX_NUM_LOOPS
    + MAX_NUM_LOOPS * MAX_NUM_LOAD_STORE_DIM * MAX_NUM_STORES_LOADS
    + MAX_NUM_LOOPS * MAX_NUM_LOAD_STORE_DIM
    + 5
    + MAX_NUM_LOOPS * 3 * CONFIG["truncate"]
)
print('input_dim:', input_dim)

# Freshly initialised (untrained) hierarchical policy.
model = MyModel(input_dim=input_dim, num_loops=MAX_NUM_LOOPS)

input_dim: 411


# Random run (untrained model):

In [4]:
# Roll out the untrained model for a few steps to sanity-check the env/model
# wiring. Whenever an environment finishes an episode, print the speedup of its
# final schedule relative to the root (unoptimised) execution time.
NUM_RANDOM_STEPS = 10

batch_state, batch_obs = env.reset()
obs = torch.cat(batch_obs).to(device)

for step in range(NUM_RANDOM_STEPS):
    with torch.no_grad():
        # Only the sampled actions are needed here; the other outputs are unused.
        action_index, _, _, _ = model.sample(obs)

    (batch_next_obs, batch_reward, batch_terminated, batch_truncated,
     batch_next_state, batch_final_state) = env.step(batch_state, action_index, model)

    # Use a distinct index name so the outer step counter is not shadowed.
    for env_idx in range(env.num_env):
        if batch_terminated[env_idx] or batch_truncated[env_idx]:
            final_state = batch_final_state[env_idx]
            speedup_metric = final_state.root_exec_time / final_state.exec_time
            print('-' * 70)
            print(final_state.operation_id)
            print('speedup:', speedup_metric)
            print('-' * 70)

    batch_state = batch_next_state
    obs = torch.cat(batch_next_obs).to(device)

# Inference (trained model):

In [5]:
# Build a fresh model and load the trained PPO weights into it.
model = MyModel(
    input_dim=input_dim,
    num_loops=MAX_NUM_LOOPS
).to(device)
# Inference only: disables dropout / batch-norm statistic updates if the model has any.
model.eval()

# Alternative checkpoint: 'models/demo_model_bigger_nn.pt'
MODEL_PATH = 'models/demo_ppo_model.pt'
# map_location remaps tensors saved on another device (e.g. a GPU) onto `device`,
# so the checkpoint also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))

<All keys matched successfully>

In [6]:
# (index, MLIR operation text) pairs for the linalg ops of the benchmark network.
# The integer presumably identifies the operation within its source program —
# confirm against the data generator. Not referenced by any executed code in this
# notebook; kept for reference.
benchmark_operations = [
    (1, 'linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>} ins(%0, %1 : tensor<32x3x230x230xf32>, tensor<64x3x7x7xf32>) outs(%12 : tensor<32x64x112x112xf32>) -> tensor<32x64x112x112xf32>'),
    (4, 'linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%18, %3 : tensor<32x64x56x56xf32>, tensor<16x64x5x5xf32>) outs(%20 : tensor<32x16x52x52xf32>) -> tensor<32x16x52x52xf32>'),
    (15, 'linalg.matmul ins(%39, %41 : tensor<32x84xf32>, tensor<84x10xf32>) outs(%43 : tensor<32x10xf32>) -> tensor<32x10xf32>'),
    (11, 'linalg.matmul ins(%32, %34 : tensor<32x120xf32>, tensor<120x84xf32>) outs(%36 : tensor<32x84xf32>) -> tensor<32x84xf32>'),
    (0, 'linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<64xf32>) outs(%11 : tensor<32x64x112x112xf32>) {\n^bb0(%in: f32, %out: f32):\n  linalg.yield %in : f32\n} -> tensor<32x64x112x112xf32>'),
    (2, 'linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%13 : tensor<32x64x112x112xf32>) outs(%11 : tensor<32x64x112x112xf32>) {\n^bb0(%in: f32, %out: f32):\n  %cst_1 = arith.constant 0.000000e+00 : f32\n  %46 = arith.cmpf ugt, %in, %cst_1 : f32\n  %47 = arith.select %46, %in, %cst_1 : f32\n  linalg.yield %47 : f32\n} -> tensor<32x64x112x112xf32>'),
    (3, 'linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%4 : tensor<16xf32>) outs(%19 : tensor<32x16x52x52xf32>) {\n^bb0(%in: f32, %out: f32):\n  linalg.yield %in : f32\n} -> tensor<32x16x52x52xf32>'),
    (6, 'linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<120x10816xf32>) outs(%26 : tensor<10816x120xf32>) {\n^bb0(%in: f32, %out: f32):\n  linalg.yield %in : f32\n} -> tensor<10816x120xf32>'),
    (8, 'linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%30, %6 : tensor<32x120xf32>, tensor<120xf32>) outs(%28 : tensor<32x120xf32>) {\n^bb0(%in: f32, %in_1: f32, %out: f32):\n  %46 = arith.addf %in, %in_1 : f32\n  linalg.yield %46 : f32\n} -> tensor<32x120xf32>'),
    (9, 'linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%31 : tensor<32x120xf32>) outs(%28 : tensor<32x120xf32>) {\n^bb0(%in: f32, %out: f32):\n  %cst_1 = arith.constant 0.000000e+00 : f32\n  %46 = arith.cmpf ugt, %in, %cst_1 : f32\n  %47 = arith.select %46, %in, %cst_1 : f32\n  linalg.yield %47 : f32\n} -> tensor<32x120xf32>'),
    (10, 'linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%7 : tensor<84x120xf32>) outs(%33 : tensor<120x84xf32>) {\n^bb0(%in: f32, %out: f32):\n  linalg.yield %in : f32\n} -> tensor<120x84xf32>'),
    (12, 'linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%37, %8 : tensor<32x84xf32>, tensor<84xf32>) outs(%35 : tensor<32x84xf32>) {\n^bb0(%in: f32, %in_1: f32, %out: f32):\n  %46 = arith.addf %in, %in_1 : f32\n  linalg.yield %46 : f32\n} -> tensor<32x84xf32>'),
    (13, 'linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%38 : tensor<32x84xf32>) outs(%35 : tensor<32x84xf32>) {\n^bb0(%in: f32, %out: f32):\n  %cst_1 = arith.constant 0.000000e+00 : f32\n  %46 = arith.cmpf ugt, %in, %cst_1 : f32\n  %47 = arith.select %46, %in, %cst_1 : f32\n  linalg.yield %47 : f32\n} -> tensor<32x84xf32>'),
    (14, 'linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%9 : tensor<10x84xf32>) outs(%40 : tensor<84x10xf32>) {\n^bb0(%in: f32, %out: f32):\n  linalg.yield %in : f32\n} -> tensor<84x10xf32>'),
    (16, 'linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%44, %10 : tensor<32x10xf32>, tensor<10xf32>) outs(%42 : tensor<32x10xf32>) {\n^bb0(%in: f32, %in_1: f32, %out: f32):\n  %46 = arith.addf %in, %in_1 : f32\n  linalg.yield %46 : f32\n} -> tensor<32x10xf32>')
]

In [None]:
# Run the trained policy for one episode on every operation of the dataset and
# report the speedup reached when the episode ends. Results are also collected in
# `speedups` (operation index -> speedup) for later inspection.
speedups = {}
for op_idx in range(len(env.env.operations_files)):

    print(f'Operation ({op_idx}):', env.env.operations_files[op_idx][0])

    # Reset the environment on this specific operation.
    batch_state, batch_obs = env.reset(op_idx)
    obs = torch.cat(batch_obs).to(device)

    # Episodes end via termination or truncation (CONFIG["truncate"]).
    while True:
        with torch.no_grad():
            # Select the action using the model; the other outputs are unused here.
            action_index, _, _, _ = model.sample(obs)

        # Apply the action and get the next state.
        (batch_next_obs, _, batch_terminated, batch_truncated,
         batch_next_state, batch_final_state) = env.step(batch_state, action_index, model)

        # The env was built with num_env=1, so only environment 0 is inspected.
        if batch_terminated[0] or batch_truncated[0]:
            final_state = batch_final_state[0]
            speedup_metric = final_state.root_exec_time / final_state.exec_time
            print('Base execution time:', final_state.root_exec_time)
            print('New execution time:', final_state.exec_time)
            print('speedup:', speedup_metric)
            speedups[op_idx] = speedup_metric
            break

        batch_state = batch_next_state
        obs = torch.cat(batch_next_obs).to(device)

    print('\n\n\n')

# Demo 2: hand-written transformation schedules

In [8]:
from fusion_utils.transforms import *
from utils.observation_utils import *
from data_generation.data_generation_from_model import transform_wrapper
import numpy as np

In [9]:
operation = """linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"], tag = "operation_0"} ins(%4 : tensor<16xf32>) outs(%19 : tensor<32x16x52x52xf32>) {\n^bb0(%in: f32, %out: f32):\n  linalg.yield %in : f32\n} -> tensor<32x16x52x52xf32>"""

code = transform_wrapper(operation)

# NOTE(review): dividing by 1e9 implies evaluate_code returns nanoseconds, so the
# printed value is in seconds (the label previously said 'ms') — confirm the unit.
old_exec_time = evaluate_code(code)
print('Base execution time:', old_exec_time / 1e9, 's')

# Hand-written schedule on the tagged op. Each list appears to hold one entry per
# loop of the 4-D op.
new_code = transform_dialect_TP(code, 'operation_0', [4, 4, 0, 0])
new_code = transform_dialect_tile(new_code, 'operation_0', [2, 16, 4, 4])
new_code = transform_dialect_vectorise(new_code, 'operation_0')

new_exec_time = evaluate_code(new_code)
print('New execution time:', new_exec_time / 1e9, 's')

# NOTE(review): the recorded new time (~200 ns if the unit is ns) for filling a
# 32x16x52x52 tensor looks too fast to be real work. Check that the transformed
# code is not being optimised away before trusting this speedup.
speedup = old_exec_time / new_exec_time
print('Speedup:', speedup)

Base execution time: 0.00132489 ms


New execution time: 2e-07 ms
Speedup: 6624.45


In [10]:
operation = """linalg.matmul {tag = "operation_0"} ins(%39, %41 : tensor<32x84xf32>, tensor<84x10xf32>) outs(%43 : tensor<32x10xf32>) -> tensor<32x10xf32>"""

code = transform_wrapper(operation)

# NOTE(review): dividing by 1e9 implies evaluate_code returns nanoseconds, so the
# printed value is in seconds (the label previously said 'ms') — confirm the unit.
old_exec_time = evaluate_code(code)
print('Base execution time:', old_exec_time / 1e9, 's')

# Hand-written schedule: TP, loop interchange, then vectorisation. Each list
# appears to hold one entry per loop of the 3-loop matmul.
new_code = transform_dialect_TP(code, 'operation_0', [4, 5, 0])
new_code = transform_dialect_interchange(new_code, 'operation_0', [1, 0, 2])
new_code = transform_dialect_vectorise(new_code, 'operation_0')

new_exec_time = evaluate_code(new_code)
print('New execution time:', new_exec_time / 1e9, 's')

# NOTE(review): as in the previous demo, the recorded new time is suspiciously
# small. Verify that the measured kernel still performs the full computation.
speedup = old_exec_time / new_exec_time
print('Speedup:', speedup)

Base execution time: 0.000108674 ms
New execution time: 1.3e-07 ms
Speedup: 835.9538461538461


In [11]:
operation = """linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>, tag = "operation_0"} ins(%0, %1 : tensor<32x3x230x230xf32>, tensor<64x3x7x7xf32>) outs(%12 : tensor<32x64x112x112xf32>) -> tensor<32x64x112x112xf32>"""

code = transform_wrapper(operation)

# This demo times with evaluate_code_2 (not evaluate_code); how the two differ is
# not visible here.
# NOTE(review): dividing by 1e9 implies evaluate_code_2 returns nanoseconds, so the
# printed value is in seconds (the label previously said 'ms') — confirm the unit.
old_exec_time = evaluate_code_2(code)
print('Base execution time:', old_exec_time / 1e9, 's')

# Hand-written schedule: TP, two tiling levels, conv2d decomposition, then
# vectorisation. Each list appears to hold one entry per loop of the 7-loop conv.
new_code = transform_dialect_TP(code, 'operation_0', [4, 8, 0, 0, 0, 0, 0])
new_code = transform_dialect_tile(new_code, 'operation_0', [2, 8, 0, 0, 3, 7, 7])
new_code = transform_dialect_tile(new_code, 'operation_0', [2, 0, 1, 0, 0, 1, 0])
new_code = apply_conv2d_decomposition(new_code, 'operation_0')
new_code = transform_dialect_vectorise(new_code, 'operation_0')

new_exec_time = evaluate_code_2(new_code)
print('New execution time:', new_exec_time / 1e9, 's')

speedup = old_exec_time / new_exec_time
print('Speedup:', speedup)

Base execution time: 17.573441457 ms
New execution time: 2.590573458 ms
Speedup: 6.78361055647008
