In [1]:
import os

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from torch.utils import _pytree as pytree
from iree.turbine.aot import *
from iree.compiler.ir import Context


In [2]:
# Configuration values used by the rest of the notebook.

# Hugging Face access token for the gated Llama-2 weights. Credentials must
# never be committed to a notebook, so the token is read from the HF_TOKEN
# environment variable. When the variable is unset this is None, which is
# presumably treated by `from_pretrained` as "use locally cached credentials"
# (e.g. from `huggingface-cli login`) -- confirm for your transformers version.
# NOTE(review): any token that was previously committed in this cell should be
# considered compromised and revoked.
hf_auth_token = os.environ.get("HF_TOKEN")
hf_model_name = "meta-llama/Llama-2-7b-chat-hf"
state_schema_path = "llama2_state_schema.json"
# The schema is only read, so open it read-only ("r+" would additionally
# demand write permission and fail on read-only checkouts).
with open(state_schema_path, "r") as f:
    # Serialized pytree spec used later to rebuild the nested past_key_values
    # structure from a flat list of tensors.
    state_schema = pytree.treespec_loads(f.read())
prompt = """
<s>[INST] <<SYS>>
Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> hi what are you? [/INST]
"""


In [3]:
# Load the pretrained tokenizer and the causal-LM weights from Hugging Face.
# Both calls authenticate with `hf_auth_token` from the config cell.
tokenizer = AutoTokenizer.from_pretrained(
    hf_model_name, use_fast=False, use_auth_token=hf_auth_token
)
# Weights are loaded in single precision (torch.float32).
mod = AutoModelForCausalLM.from_pretrained(
    hf_model_name, torch_dtype=torch.float32, use_auth_token=hf_auth_token
)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [4]:
# Tokenize the prompt to obtain example input ids.
initial_input = tokenizer(prompt, return_tensors="pt")
example_input_id = initial_input.input_ids


# Named dimension sizes that keep the model code below readable.
BATCH_SIZE = 1  # Note: only 1 supported currently.
HEADS = 32
HIDDEN_DIM = 128
MAX_STEP_SEQ = 4095
# Preallocated past-key-value (pkv) tensor sized for the model's full context
# window, so the compiler can reason about a single fixed-size buffer.
# NOTE(review): the leading dimension is written as HEADS * 2; it presumably
# stands for one key and one value tensor per layer, which only coincides with
# HEADS when the layer count equals the head count -- confirm for other models.
global_pkv = torch.zeros(
    HEADS * 2,
    BATCH_SIZE,
    HEADS,
    MAX_STEP_SEQ,
    HIDDEN_DIM,
    dtype=torch.float32,
)
seq_step = AbstractIndex

In [5]:
#define some helper functions for manipulating the pkv state

def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):
    """Extract the cached entries for the first ``seq_step`` positions.

    Args:
        global_pkv: the global pkv tensor.
        seq_step: the current token index of the model.
        heads: the number of attention heads.
        hidden_dim: feature dimension size.

    Returns:
        A list with ``heads * 2`` entries. Entry ``i`` is the slice taken at
        leading index ``i`` of ``global_pkv``, reshaped to
        ``(1, heads, seq_step, hidden_dim)``.
    """

    def _entry(index):
        # The tuple arguments select a range along the head, sequence and
        # feature axes (presumably (start, length) pairs -- confirm against
        # the IREE.tensor_slice documentation).
        window = IREE.tensor_slice(
            global_pkv, index, 0, (0, heads), (0, seq_step), (0, hidden_dim)
        )
        return IREE.tensor_reshape(window, 1, heads, seq_step, hidden_dim)

    return [_entry(index) for index in range(heads * 2)]


def update_state(global_pkv, state_updates, seq_step, heads, hidden_dim):
    """Build one update of ``global_pkv`` per entry of ``state_updates``.

    Args:
        global_pkv: the global pkv tensor.
        state_updates: the state updates output by a forward pass of the model.
        seq_step: the current token index of the model.
        heads: the number of attention heads.
        hidden_dim: feature dimension size.

    Returns:
        A list with ``heads * 2`` entries; entry ``i`` is the result of
        ``IREE.tensor_update`` writing ``state_updates[i]`` into ``global_pkv``
        at leading index ``i`` and sequence position ``seq_step``.

    NOTE(review): every update is applied to the same input ``global_pkv``, so
    no single returned value carries all of the updates -- confirm this is
    what callers expect.
    """

    def _write(index):
        # Expand dims in the state update to match the rank of global_pkv.
        expanded = IREE.tensor_reshape(
            state_updates[index], 1, 1, heads, 1, hidden_dim
        )
        return IREE.tensor_update(
            global_pkv, expanded, index, 0, 0, seq_step, 0
        )

    return [_write(index) for index in range(heads * 2)]

In [6]:
# CompiledModule that exports the model with its past key/values (pkv) held in
# a mutable module-level global. It exposes two entry points: `run_initialize`
# (prompt pass) and `run_forward` (single decode step).
class StateUpdateModule(CompiledModule):
    # export_parameters makes the params of the model global, so that multiple
    # exported functions can access them without duplicating the constants.
    # We set external parameters mode to increase readability of the IR.
    params = export_parameters(mod, external=True)
    # Export our global pkv tensor, making sure to set mutable=True so it
    # can be modified by our exported functions.
    global_state = export_global(abstractify(global_pkv), mutable=True)
    # This is our sequence step: the current token index for the model.
    global_seq_step = export_global(AbstractIndex, mutable=True)

    # Stateless jittable functions; all state handling lives in the run_*
    # entry points below.
    @jittable
    def initialize(input_ids):
        """Run the model on the prompt.

        Returns the greedily selected next token (argmax over the logits of
        the last position, with a leading dimension added) followed by the
        flattened past_key_values tensors.
        """
        result = mod.forward(input_ids)
        state1_flat, _ = pytree.tree_flatten(result.past_key_values)
        token = torch.argmax(result.logits[:, -1, :], dim=1)
        token = token[None, :]
        return token, *state1_flat

    @jittable
    def forward(token0: torch.Tensor, *state0_flat):
        """Run one decode step given the previous token and the flat pkv list.

        Returns the greedily selected next token followed by, for every pkv
        tensor, only its newest entry along axis 2.
        """
        # Rebuild the nested past_key_values structure from the flat list.
        state0 = pytree.tree_unflatten(state0_flat, state_schema)
        result = mod.forward(token0, past_key_values=state0)
        state1_flat, _ = pytree.tree_flatten(result.past_key_values)
        # Extract only the newest pkvs for each head.
        state1_flat = [x[:, :, -1:, :] for x in state1_flat]
        token = torch.argmax(result.logits[:, -1, :], dim=1)
        token = token[None, :]
        return token, *state1_flat

    def run_initialize(
        self, x=AbstractTensor(BATCH_SIZE, None, dtype=torch.int64)
    ):
        """Do the "first forward" pass of the llama model.

        Computes the initial past_key_values for the input tokens (prompt),
        stores them at the start of the global pkv tensor, records the prompt
        length in the global sequence step and returns the first token.

        We need to set constraints for the dynamic dimensions of the input.
        This may look like magic, but if you try and export without it,
        torch compile will actually try to calculate the constraints for you
        and tell you what to add.
        """
        init_const = [x.dynamic_dim(1) < MAX_STEP_SEQ]
        token, *state = self.initialize(x, constraints=init_const)
        # Initialize the token index from the sequence axis of the pkv output.
        # The sequence axis is index 2: `forward` slices `x[:, :, -1:, :]` and
        # `run_forward` constrains `dynamic_dim(2)`. (Index 3 is the feature
        # axis.)
        self.global_seq_step = IREE.tensor_dim(state[0], 2)
        # Iterate through the past key values and update the global state.
        # The result of each update is assigned back to the mutable global so
        # the writes persist and accumulate; tensor_update presumably returns
        # the updated tensor rather than mutating its target -- confirm
        # against the IREE procedural-ops documentation.
        for i in range(HEADS * 2):
            slice_of_state = IREE.tensor_reshape(
                state[i], 1, 1, HEADS, self.global_seq_step, HIDDEN_DIM
            )
            self.global_state = IREE.tensor_update(
                self.global_state, slice_of_state, i, 0, 0, 0, 0
            )
        return token

    def run_forward(self, x=AbstractTensor(1, None, dtype=torch.int64)):
        """Decode one token using the pkv entries stored so far.

        Slices the stored state up to the current sequence step, runs
        `forward`, writes the newest pkv entries back into the global state
        at the current step, advances the step by one and returns the token.
        """
        state_arg = slice_up_to_step(
            self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
        )
        # All pkv inputs share the same dynamic sequence length, bounded by
        # the preallocated window.
        forw_const = [state_arg[0].dynamic_dim(2) < MAX_STEP_SEQ] + [
            rest.dynamic_dim(2) == (state_arg[0].dynamic_dim(2))
            for rest in state_arg[1:]
        ]
        token, *state_update = self.forward(
            x, *state_arg, constraints=forw_const
        )
        # Write the new entries at the current step *before* advancing it:
        # positions [0, global_seq_step) are already filled, so the new entry
        # belongs at index global_seq_step. Each update is chained through
        # the mutable global so that all of them persist.
        for i in range(HEADS * 2):
            # Expand dims in the state update to match the rank of global_pkv.
            update = IREE.tensor_reshape(
                state_update[i], 1, 1, HEADS, 1, HIDDEN_DIM
            )
            self.global_state = IREE.tensor_update(
                self.global_state, update, i, 0, 0, self.global_seq_step, 0
            )
        self.global_seq_step = self.global_seq_step + 1
        return token

In [7]:
# Run the export pipeline.
# Instantiate the compiled module with a fresh MLIR Context; the
# symbolic-shape log lines captured below are emitted while this cell runs.
# NOTE(review): import_to="IMPORT" presumably stops after importing to MLIR
# without running further compilation -- confirm against the turbine docs.
inst = StateUpdateModule(context=Context(), import_to="IMPORT")
# Render the resulting MLIR module as text for the next cell to write out.
module_str = str(CompiledModule.get_mlir_module(inst))

[2023-10-09 18:48:30,657] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2023-10-09 18:48:30,660] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s0 = 2 for input0.size()[1]
  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,
[2023-10-09 18:48:30,910] torch.fx.experimental.symbolic_shapes: [INFO] eval s0 <= 4096 [guard added] (transformers/models/llama/modeling_llama.py:119 in forward)
[2023-10-09 18:48:30,918] torch.fx.experimental.symbolic_shapes: [INFO] eval Ne(s0, 4096) [guard added] (transformers/models/llama/modeling_llama.py:123 in forward)
[2023-10-09 18:48:43,109] [0/0] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2023-10-09 18:48:43,124] [0/0] torch.fx.experimental.symbolic_shapes: [INFO] create_symbol s0 = 2 for L['arg0_1'].size()[1]
[2023-10-09 18:49:00,136] [0/0] torch.fx.experimental.symbolic_shapes: [INFO] produce_guards
[2023-10-09 18:49:15,183] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2023-10-09

In [8]:
# Save the torch-ir MLIR text produced by the export cell to disk.
with open("llama2_torch.mlir", "w+") as mlir_file:
    mlir_file.write(module_str)
# TODO: run the rest of the compile pipeline and do inference