In [None]:
import os
import sys
sys.path.append('..')
import numpy as np
import matplotlib.pyplot as plt
from hardwares.hardware_params import hardware_params
from model_analyzer import ModelAnalyzer
%load_ext autoreload
%autoreload 2

In [None]:
# Model / hardware pair analysed throughout this notebook; every later cell
# queries the shared `analyzer` built here.
model_id, hardware = "meta-llama/Llama-2-13b-hf", "nvidia_A6000"
analyzer = ModelAnalyzer(model_id, hardware)

In [None]:
# Decode (top) and prefill (bottom) at batch size 1: inference time (solid,
# left y-axis) and weight + KV-cache memory access (dashed, right y-axis)
# versus sequence length, for several quantization settings.
QUANT_CONFIGS = [  # (w, a, kv) bit-widths passed to analyzer.analyze, legend label
    (16, 16, 16, "FP16"),
    (4, 16, 16, "W4"),
    (4, 16, 4, "W4KV4"),
    (4, 4, 4, "W4A4"),
]
seqlens = [2**p for p in range(8, 13)]  # 256, 512, ..., 4096
batchsize = 1

fig, axes = plt.subplots(2, 1, figsize=(5, 5))
for ax, step in zip(axes, ['decode', 'prefill']):
    # Secondary y-axis for memory access. twinx() hides the twin's own x-axis,
    # so x-axis settings (scale, label) must be applied to `ax`, not the twin.
    ax2 = ax.twinx()
    for i, (w, a, kv, quantization) in enumerate(QUANT_CONFIGS):
        color = f"C{i}"  # pair each config's time curve with its memory curve
        step_results = [
            analyzer.analyze(seqlen, batchsize, w, a, kv)["total_results"][step]
            for seqlen in seqlens
        ]
        inference_times = [r["inference_time"] for r in step_results]
        weight_kv_memory_access = [
            r["load_weight"] + r["load_kv_cache"] + r["store_kv_cache"]
            for r in step_results
        ]
        ax.plot(seqlens, inference_times, color=color,
                label=f"{quantization} time cost")
        ax2.plot(seqlens, weight_kv_memory_access, '--', color=color,
                 label=f"{quantization} W+KV memory access")

    ax.set_xscale('log', base=2)
    ax.set_title(step, fontsize=9)
    ax.set_xlabel('Sequence Length', fontsize=9)
    ax.set_ylabel('Time Cost (s)', fontsize=9)
    ax2.set_ylabel('W+KV Memory Access', fontsize=9)
    # Single merged legend. It is attached to the twin, which is drawn on top,
    # so the twin's dashed lines cannot paint over it.
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc='upper right', fontsize=9)

fig.tight_layout()
# fig.savefig("../output/quantization_memory_consumption.pdf", bbox_inches='tight')

In [None]:
# Decode-step inference time versus batch size at a fixed sequence length,
# for several weight bit-widths (activation and KV-cache bits fixed at 16).
DECODE_SEQLEN = 1024
batchsizes = range(1, 64)  # 1 ... 63

fig, ax = plt.subplots(figsize=(5, 2))
for wbit in [16, 8, 4, 2, 1]:
    ys = [
        analyzer.analyze(DECODE_SEQLEN, batchsize, wbit, 16, 16)["total_results"]["decode"]["inference_time"]
        for batchsize in batchsizes
    ]
    ax.plot(batchsizes, ys, label=f"W{wbit}" if wbit != 16 else "FP16")
ax.legend()
ax.set_ylabel("Inference Time (s)")
ax.set_xlabel("Batch Size")
# savefig does not create missing directories; make sure the target exists.
os.makedirs("../output", exist_ok=True)
fig.savefig("../output/quantization_memory_access_batch.pdf", bbox_inches='tight')

In [None]:
# Prefill-step inference time versus sequence length at batch size 1,
# for several weight bit-widths (activation and KV-cache bits fixed at 16).
seqlens = range(16, 1024)  # every length 16 ... 1023, plotted on a log2 axis

fig, ax = plt.subplots(figsize=(5, 2))
for wbit in [16, 8, 4, 2, 1]:
    ys = [
        analyzer.analyze(seqlen, 1, wbit, 16, 16)["total_results"]["prefill"]["inference_time"]
        for seqlen in seqlens
    ]
    ax.plot(seqlens, ys, label=f"W{wbit}" if wbit != 16 else "FP16")
ax.legend()
ax.set_xscale('log', base=2)
ax.set_ylabel("Inference Time (s)")
ax.set_xlabel("Sequence Length")
# savefig does not create missing directories; make sure the target exists.
os.makedirs("../output", exist_ok=True)
fig.savefig("../output/quantization_memory_access_seq_len.pdf", bbox_inches='tight')