In [None]:
from DWG.utils import *

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# The tokenizer and the language model are both loaded from the same checkpoint.
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Smoke test: encode a prompt, generate with no extra arguments, decode the result.
input_text = "Hello, how are you?"
input_ids = tokenizer.encode(input_text, return_tensors="pt")
output = model.generate(input_ids)
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)

print(decoded_output)


In [None]:
import torch.onnx

# Destination file for the exported graph.
onnx_file_path = "model.onnx"

# Example input handed to the exporter: a single sequence of six token ids
# (presumably the GPT-2 encoding of the prompt used earlier -- TODO confirm).
# NOTE: this rebinds the notebook-level `input_ids`.
input_ids = torch.tensor([[15496, 11, 703, 389, 345, 30]])

# Write the model out as ONNX; the example input is passed as an explicit
# 1-tuple of positional model arguments.
torch.onnx.export(model, (input_ids,), onnx_file_path)


In [None]:
import onnx
import numpy as np
import onnxruntime as ort
# --- Disabled: ONNX Runtime inference on the exported file -------------------
# The session/OrtValue lines below are kept for reference only; nothing in this
# cell creates `session`, so later cells must not assume it exists.
#session = ort.InferenceSession("model.onnx", 
#                               providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
#                               )
#ortvalue = ort.OrtValue.ortvalue_from_numpy(input_ids.numpy())

# Load the exported file as an ONNX model object; later cells read
# `onnx_model.graph` (nodes and initializers) from it.
onnx_model = onnx.load("model.onnx")

# Disabled: collect the graph output names to request from the session.
#onames = [o.name for o in onnx_model.graph.output]

# Disabled: run the session. "onnx::Reshape_0" is presumably the auto-generated
# name of the graph input -- TODO confirm against onnx_model.graph.input.
#session.run(onames, {"onnx::Reshape_0": ortvalue})

In [None]:
#from DWG.utils import count_operators

def count_operators(model):
    """
    Tally how many nodes of each operator type an ONNX model contains.

    `model` must expose `model.graph.node`, each node carrying an `op_type`.
    Returns a plain dict mapping op_type -> number of nodes, with keys in
    order of first appearance in the graph.
    """
    tally = {}
    for op_type in (node.op_type for node in model.graph.node):
        # dict.get supplies the zero for an operator type seen for the first time
        tally[op_type] = tally.get(op_type, 0) + 1
    return tally

def _value_info_shape(value_info):
    """Return the static shape of a ValueInfo-like object as a list of ints, or None if unknown."""
    dims = []
    for dim in value_info.type.tensor_type.shape.dim:
        # dim_value is 0 (unset) for symbolic/unknown dimensions; MACs cannot
        # be computed from those, so the whole shape is reported as unknown.
        if dim.dim_value <= 0:
            return None
        dims.append(int(dim.dim_value))
    # An empty dim list is treated as "no shape recorded".
    return dims or None


def _collect_static_shapes(graph):
    """Build a tensor name -> shape (list of ints) table for every tensor with a fully known shape."""
    shapes = {}
    # Graph inputs/outputs plus value_info (intermediate tensors, typically
    # filled in by ONNX shape inference).
    for group in ("input", "output", "value_info"):
        for value_info in getattr(graph, group, ()):
            shape = _value_info_shape(value_info)
            if shape is not None:
                shapes[value_info.name] = shape
    # Initializers (weights) always carry concrete dims, so they win on conflict.
    for tensor in graph.initializer:
        shapes[tensor.name] = [int(d) for d in tensor.dims]
    return shapes


def _node_macs(node, tensor_shapes):
    """Return the MAC count of a Conv/MatMul/Gemm node, or None when a required shape is unknown."""
    inputs = list(node.input)
    outputs = list(node.output)
    if len(inputs) < 2 or not outputs:
        return None
    lhs = tensor_shapes.get(inputs[0])
    rhs = tensor_shapes.get(inputs[1])
    out = tensor_shapes.get(outputs[0])
    if not lhs or not rhs or not out:
        return None

    out_elems = int(np.prod(out))
    if node.op_type == "Conv":
        # Weight layout is (C_out, C_in/group, k_1, ..., k_n): every output
        # element accumulates C_in/group * prod(kernel) products.
        return out_elems * int(np.prod(rhs[1:]))

    if node.op_type == "Gemm":
        # Gemm's A operand is 2-D; transA moves the reduction axis to dim 0.
        trans_a = any(
            attr.name == "transA" and attr.i for attr in getattr(node, "attribute", ())
        )
        reduce_dim = lhs[0] if trans_a else lhs[-1]
    else:
        # MatMul always reduces over the last axis of the left operand.
        reduce_dim = lhs[-1]
    # Every output element is a dot product of length `reduce_dim`.
    return out_elems * int(reduce_dim)


def count_macs(session, model=None):
    """
    Count multiply-accumulate operations (MACs) for the supported operators.

    Supported operators are Conv (any number of spatial dims), MatMul and
    Gemm, matched by exact op type. Look-alike operators such as
    ConvTranspose or QLinearConv use different input layouts/formulas and are
    reported as unsupported instead of being counted with the wrong formula.

    MACs per node:
      * Conv:         prod(output shape) * prod(weight shape[1:])
      * MatMul/Gemm:  prod(output shape) * length of the reduced dimension

    Parameters
    ----------
    session : object
        Either an ONNX model (anything exposing ``.graph``) or, for backward
        compatibility, an ONNX Runtime session. A runtime session does not
        expose the graph nodes, so in that case the graph is read from
        ``model`` or, failing that, from the notebook-level ``onnx_model``.
    model : optional
        ONNX model to analyse. When given it takes precedence over ``session``.

    Returns
    -------
    dict
        Maps each counted op type to its summed MACs, plus a ``'total'`` key.

    Notes
    -----
    Shapes are read from the graph inputs/outputs, ``value_info`` and the
    initializers. Nodes whose shapes are not fully known are skipped and
    reported in a single summary warning; running
    ``onnx.shape_inference.infer_shapes`` on the model beforehand fills in
    ``value_info`` for intermediate tensors.
    """
    supported_ops = ("Conv", "MatMul", "Gemm")

    if model is None:
        # The notebook-global `onnx_model` is kept as the last-resort graph
        # source because that is where this function always read shapes from.
        model = session if hasattr(session, "graph") else onnx_model
    graph = model.graph

    tensor_shapes = _collect_static_shapes(graph)

    count = {}
    warned_ops = set()
    skipped = 0
    for node in graph.node:
        op_type = node.op_type
        if op_type not in supported_ops:
            # Warn once per distinct op type so no unsupported operator goes unreported.
            if op_type not in warned_ops:
                print(f"Warning: {op_type} is not supported for MACs calculation.")
                warned_ops.add(op_type)
            continue

        macs = _node_macs(node, tensor_shapes)
        if macs is None:
            skipped += 1
            continue
        count[op_type] = count.get(op_type, 0) + macs

    if skipped:
        print(
            f"Warning: skipped {skipped} supported node(s) whose tensor shapes are unknown; "
            "run onnx.shape_inference.infer_shapes on the model first."
        )

    # Also return the total MACs
    count['total'] = sum(count.values())
    return count
        

# Load the exported graph and run ONNX shape inference on it, so intermediate
# tensors carry shape metadata (graph.value_info) for shape-based analysis.
# The node list itself is not changed by shape inference.
onnx_model = onnx.shape_inference.infer_shapes(onnx.load("model.onnx"))

# Operator histogram: op_type -> number of nodes.
count_dict = count_operators(onnx_model)

# `session` is not created anywhere earlier in the notebook (its construction
# only exists as commented-out code), so a fresh kernel would hit a NameError
# on the call below. Build it here; the CPU provider is available in every
# ONNX Runtime build.
session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# MACs per supported operator type, plus a 'total' entry.
macs_dict = count_macs(session)


In [None]:
# Run ONNX shape inference on the loaded model and keep the returned model in
# `inferred_shapes`; `onnx_model` itself is not rebound here.
# NOTE(review): none of the later cells visible here read `inferred_shapes` --
# confirm whether the cells below were meant to use it instead of `onnx_model`.
inferred_shapes = onnx.shape_inference.infer_shapes(onnx_model)

In [None]:
# Lookup table: initializer name -> its dims, for every initializer stored in the graph.
tensor_shapes = dict(
    (initializer.name, initializer.dims)
    for initializer in onnx_model.graph.initializer
)

In [None]:
# Dump every Reshape node of the graph, in graph order.
for node in onnx_model.graph.node:
    if node.op_type != 'Reshape':
        continue
    print(node)
        

In [None]:
# Names of all initializers, in the insertion order of `tensor_shapes`.
tnames = [*tensor_shapes]

In [None]:
# Print every tensor name referenced by a node that is not an initializer:
# for each node, first its inputs, then its outputs. A name is printed once
# per reference, so the same tensor can appear several times.
for node in onnx_model.graph.node:
    for tensor_name in [*node.input, *node.output]:
        if tensor_name not in tnames:
            print(tensor_name)