In [None]:
import sys
from pathlib import Path

# add the utils directory to the path
sys.path.append(str(Path().resolve().parent / "utils"))

import numpy as np
from generator import ORTGenerator
from transformers import AutoTokenizer


base_model_name = "microsoft/Phi-3-mini-4k-instruct"

# Execution provider to run on; must be one of the keys of MODEL_DIRS below.
ep = "CPUExecutionProvider"

# Exported model directory for each execution provider. The directory names suggest an
# fp16 export for CUDA and an fp32 export for CPU (both "4bit"); each directory holds the
# ONNX model and the matching "tiny-codes" adapter weights.
MODEL_DIRS = {
    "CUDAExecutionProvider": "models/phi3-qlora-cuda/qlora-conversion-optimization_fp16-4bit-extract/gpu-cuda_model",
    "CPUExecutionProvider": "models/phi3-qlora-cpu/qlora-conversion-optimization_fp32-4bit-extract/cpu-cpu_model",
}

# Fail fast on an unsupported/misspelled provider instead of leaving model_path undefined
# and hitting a NameError further down the notebook.
if ep not in MODEL_DIRS:
    raise ValueError(f"Unsupported execution provider {ep!r}; expected one of {sorted(MODEL_DIRS)}")

model_dir = MODEL_DIRS[ep]
model_path = f"{model_dir}/model.onnx"
tiny_codes_path = f"{model_dir}/adapter_weights.npz"

# Seed for the random control adapter so its generated output is reproducible across runs.
SEED = 0

# Load the fine-tuned "tiny-codes" adapter weights. np.load on an .npz returns a lazy
# NpzFile that keeps the archive open and re-reads an array on every access; copying it
# into a plain dict loads each array once and lets the `with` block close the file.
with np.load(tiny_codes_path) as npz_file:
    tiny_codes_weights = dict(npz_file)

# Zero-valued copy of every adapter tensor (same shape and dtype), used as the "base"
# adapter. Assumes a zeroed adapter contributes nothing on top of the base model —
# TODO confirm for this export.
base_zero_weights = {name: np.zeros_like(tensor) for name, tensor in tiny_codes_weights.items()}

# Random control adapter: uniform values in [0, 1) with the same shape/dtype as the real
# adapter. Shows that the fine-tuned weights, not just any weights, change the output.
rng = np.random.default_rng(SEED)
random_weights = {
    name: rng.random(tensor.shape).astype(tensor.dtype)
    for name, tensor in tiny_codes_weights.items()
}

# Adapters available to the generator, keyed by name. "template" (tiny-codes only) is the
# prompt format the adapter expects; {prompt} is filled in with the user prompt.
adapters = {
    "base": {
        "weights": base_zero_weights
    },
    "tiny-codes": {
        "weights": tiny_codes_weights,
        "template": "### Question: {prompt} \n### Answer:"
    },
    "random": {
        "weights": random_weights
    }
}

# Load the tokenizer of the base model the adapters were trained on.
# NOTE(review): trust_remote_code=True executes Python code downloaded from the model
# repository. Recent transformers releases support Phi-3 natively, so this flag may be
# droppable — confirm against the installed transformers version.
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

# Device index handed to ORTGenerator. Presumably only meaningful for the CUDA provider;
# 6 assumes a machine with at least 7 GPUs — adjust to your hardware.
DEVICE_ID = 6

# Build the generator with all adapters registered. adapter_mode="inputs" presumably feeds
# the selected adapter's weights as model inputs at generation time — verify in ORTGenerator.
generator = ORTGenerator(
    model_path,
    tokenizer,
    execution_provider=ep,
    device_id=DEVICE_ID,
    adapters=adapters,
    adapter_mode="inputs",
)

In [None]:
# Same prompt for every adapter, so the base / fine-tuned / random outputs can be compared.
prompt = "Calculate the sum of a list of integers."

# Generate up to 100 tokens per adapter, printing a 100-character rule between results.
for adapter in adapters.keys():
    print(f"Using adapter: {adapter}")
    response = generator.generate(
        prompt,
        adapter=adapter,
        max_gen_len=100,
        use_io_binding=True,
    )
    print(response)
    print("=" * 100)