In [None]:
import sys

# Environment sanity check: print the module search path first, then the path
# of the interpreter binary backing this kernel.
for env_detail in (sys.path, sys.executable):
    print(env_detail)

In [9]:
import torch, json, gc
import numpy as np
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding
from TernaryLLM import TernaryConfig, LlamaTernaryAttention


def benchmark_attn(config: TernaryConfig, ternarization: bool = True, seq_len = 256):
    """Measure forward-pass latency and allocated memory of one attention layer on CUDA.

    Args:
        config: Configuration used to build the attention layer, the rotary
            embedding and to size the random input (``config.hidden_size``).
        ternarization: When True, benchmark ``LlamaTernaryAttention``;
            otherwise benchmark the stock ``LlamaAttention``.
        seq_len: Sequence length of the single-batch random input.

    Returns:
        A tuple ``(attn_spans, attn_mems)``:

        * ``attn_spans`` -- NumPy array with the CUDA-event timings
          (``Event.elapsed_time``, milliseconds) of the measured iterations
          after dropping the first one. Sorted ascending when
          ``ternarization`` is True and descending otherwise.
        * ``attn_mems`` -- NumPy array with one
          ``torch.cuda.memory_allocated(0)`` reading per measured iteration.
    """
    if ternarization:
        attn_layer = LlamaTernaryAttention(config, 0).to('cuda')
    else:
        attn_layer = LlamaAttention(config, 0).to('cuda')

    torch.cuda.init()
    torch.cuda.synchronize()

    # Everything is derived from the `config` argument. The previous version
    # read the notebook-level `llama_config` here, which silently ignored the
    # parameter and failed on a fresh kernel where that global is undefined.
    input_hidden_states = torch.rand((1, seq_len, config.hidden_size)).to('cuda')
    rotary_emb = LlamaRotaryEmbedding(config).to('cuda')
    position_ids = torch.arange(seq_len, dtype=torch.int32).unsqueeze(0).to('cuda')
    cos, sin = rotary_emb(input_hidden_states, position_ids)
    position_embeddings = (cos, sin)

    attn_spans = []
    attn_mems = []

    # warm-up
    for _ in range(10):
        with torch.no_grad():
            attn_layer(
                hidden_states=input_hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=None,
            )

    for _ in range(20):
        start_event = torch.cuda.Event(enable_timing=True, blocking=True)
        end_event = torch.cuda.Event(enable_timing=True, blocking=True)
        start_event.record()

        with torch.no_grad():
            output = attn_layer(
                hidden_states=input_hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=None,
            )
        # Record the end event immediately after the forward pass so that the
        # host-side bookkeeping below cannot delay it and inflate the timing.
        end_event.record()

        # Queried while `output` is still alive, as before.
        alloc_mem = torch.cuda.memory_allocated(0)

        torch.cuda.synchronize()

        # collect stat
        attn_spans.append(start_event.elapsed_time(end_event))
        attn_mems.append(alloc_mem)

        del output
    torch.cuda.empty_cache()

    # Discard the first measured iteration.
    attn_spans.pop(0)
    # NOTE(review): the sort direction depends on `ternarization` (ascending
    # for the ternary layer, descending for the baseline). Callers that average
    # a leading slice therefore compare the ternary layer's fastest runs with
    # the baseline's slowest runs. Kept as-is to preserve the returned data;
    # confirm whether this asymmetry is intended.
    attn_spans = np.array(sorted(attn_spans, reverse=(not ternarization)))

    return attn_spans, np.array(attn_mems)

In [None]:
## Run benchmark
# Sweeps sparsity from 0.70 to 0.99 and, for every value, benchmarks the
# ternary attention layer and the stock attention layer, printing and logging
# the mean of the first ten returned timings and the mean allocated memory.
with open('./config/config.json') as f:
    llama_3_1b_json = json.load(f)

# Every TernaryConfig argument is read from the JSON key of the same name.
attn_config_keys = (
    "vocab_size",
    "hidden_size",
    "intermediate_size",
    "attention_dropout",
    "num_attention_heads",
    "head_dim",
    "num_key_value_heads",
    "max_position_embeddings",
    "rope_theta",
    "rope_scaling",
    "sparsity",
    "uniform_sparsity",
    "uniform_sparsity_block_size",
    "padding",
    "padding_size",
)
llama_config = TernaryConfig(**{key: llama_3_1b_json[key] for key in attn_config_keys})

sparsities = [i/100 for i in range(70, 100, 1)]

import sys
ternary_benchmark = False
log = True

seq_len = 256

# The log files are managed by a `with` statement so they are flushed and
# closed even when a benchmark iteration raises (e.g. CUDA out of memory).
with open(f'./data/llama_3_1b_ter_attn_nonuniform_time_{seq_len}', "w") as log_ter_time, \
        open(f'./data/llama_3_1b_ter_attn_nonuniform_mem_{seq_len}', "w") as log_ter_mem, \
        open(f'./data/llama_3_1b_vanila_attn_time_{seq_len}', "w") as log_vanila_time, \
        open(f'./data/llama_3_1b_vanila_attn_mem_{seq_len}', "w") as log_vanila_mem:
    for sp in sparsities:
        # The same config object is reused; only the sparsity changes per step.
        llama_config.sparsity = sp

        ter_attn_spans, ter_attn_mems = benchmark_attn(llama_config, True, seq_len)
        print(f"TerSpMM  {sp:.2f} {ter_attn_spans[0:10].mean():.4f} ms {ter_attn_mems.mean()/1024/1024:.4f} MB")
        log_ter_time.write(f"{sp} {ter_attn_spans[0:10].mean()}\n")
        log_ter_mem.write(f"{sp} {ter_attn_mems.mean()}\n")
        torch.cuda.empty_cache()

        vanila_attn_spans, vanila_attn_mems = benchmark_attn(llama_config, False, seq_len)
        print(f"nnLinear {sp:.2f} {vanila_attn_spans[0:10].mean():.4f} ms {vanila_attn_mems.mean()/1024/1024:.4f} MB")
        log_vanila_time.write(f"{sp} {vanila_attn_spans[0:10].mean()}\n")
        log_vanila_mem.write(f"{sp} {vanila_attn_mems.mean()}\n")

        torch.cuda.empty_cache()
        gc.collect()

In [10]:
from transformers.models.llama.modeling_llama import LlamaMLP
from TernaryLLM import LlamaTernaryMLP, TernaryConfig

import torch, json, time, sys
import numpy as np

def benchmark_mlp(config: TernaryConfig, ternarization: bool = True, seq_len = 256):
    """Measure forward-pass latency and allocated memory of one MLP block on CUDA.

    Args:
        config: Configuration used to build the MLP layer and to size the
            random input (``config.hidden_size``).
        ternarization: When True, benchmark ``LlamaTernaryMLP``; otherwise
            benchmark the stock ``LlamaMLP``.
        seq_len: Sequence length of the single-batch random input.

    Returns:
        A tuple ``(mlp_cpu_spans, mlp_spans, mlp_mems)``:

        * ``mlp_cpu_spans`` -- NumPy array of host wall-clock durations in
          milliseconds, one per measured iteration.
        * ``mlp_spans`` -- NumPy array of CUDA-event timings
          (``Event.elapsed_time``, milliseconds) after dropping the first
          measured iteration.
        * ``mlp_mems`` -- NumPy array with one
          ``torch.cuda.memory_allocated(0)`` reading per measured iteration.

        Both timing arrays are sorted ascending when ``ternarization`` is True
        and descending otherwise.
    """
    if ternarization:
        mlp_layer = LlamaTernaryMLP(config).to('cuda')
    else:
        mlp_layer = LlamaMLP(config).to('cuda')

    torch.cuda.init()
    torch.cuda.synchronize()

    # Named `hidden_states` rather than `input` to avoid shadowing the builtin.
    hidden_states = torch.randn((1, seq_len, config.hidden_size)).to('cuda')

    mlp_cpu_spans = []
    mlp_spans = []
    mlp_mems = []

    # warm-up
    for _ in range(10):
        with torch.no_grad():
            mlp_layer(hidden_states)

    # benchmark
    for _ in range(100):
        start_event = torch.cuda.Event(enable_timing=True, blocking=True)
        end_event = torch.cuda.Event(enable_timing=True, blocking=True)
        start_event.record()
        # perf_counter is monotonic and high resolution; time.time() can jump
        # when the system clock is adjusted and is not meant for intervals.
        start_cpu = time.perf_counter()

        with torch.no_grad():
            mlp_layer(hidden_states)
        # Record the end event right after the forward pass so the host-side
        # memory query below cannot delay it.
        end_event.record()
        alloc_mem = torch.cuda.memory_allocated(0)

        torch.cuda.synchronize()
        end_cpu = time.perf_counter()

        # collect stat
        mlp_spans.append(start_event.elapsed_time(end_event))
        mlp_mems.append(alloc_mem)
        mlp_cpu_spans.append((end_cpu - start_cpu) * 1000)  # seconds -> ms

    # NOTE(review): only the CUDA-event series drops its first sample; the CPU
    # series keeps all of them, so the two arrays differ in length by one.
    # Preserved to keep the returned shapes unchanged -- confirm intent.
    mlp_spans.pop(0)
    # NOTE(review): sort direction depends on `ternarization`, so a leading
    # slice holds the fastest runs for the ternary layer but the slowest runs
    # for the baseline. Kept as-is; confirm whether this is intended.
    mlp_spans = np.array(sorted(mlp_spans, reverse=(not ternarization)))
    mlp_cpu_spans = np.array(sorted(mlp_cpu_spans, reverse=(not ternarization)))

    return mlp_cpu_spans, mlp_spans, np.array(mlp_mems)

In [None]:
# Sweeps sparsity from 0.70 to 0.99 and, for every value, benchmarks the
# ternary MLP and the stock MLP, printing and logging the mean of the first ten
# returned CUDA timings and the mean allocated memory.
with open('./config/config.json') as f:
    llama_3_1b_json = json.load(f)

# Every TernaryConfig argument is read from the JSON key of the same name.
mlp_config_keys = (
    "vocab_size",
    "hidden_size",
    "intermediate_size",
    "attention_dropout",
    "num_attention_heads",
    "head_dim",
    "num_key_value_heads",
    "max_position_embeddings",
    "rope_theta",
    "rope_scaling",
    "hidden_act",
    "mlp_bias",
    "sparsity",
    "uniform_sparsity",
    "uniform_sparsity_block_size",
    "padding",
    "padding_size",
)
llama_config = TernaryConfig(**{key: llama_3_1b_json[key] for key in mlp_config_keys})

ternary_benchmark = False
log = True
sparsities = [i/100 for i in range(70, 100, 1)]
seq_len = 256

# Previously these four files were opened and never closed, so buffered
# results could be lost. The `with` statement flushes and closes them on
# normal completion and on error.
with open(f'./data/llama_3_1b_ter_mlp_uniform_time_{seq_len}', "w") as log_ter_time, \
        open(f'./data/llama_3_1b_ter_mlp_uniform_mem_{seq_len}', "w") as log_ter_mem, \
        open(f'./data/llama_3_1b_vanila_mlp_time_{seq_len}', "w") as log_vanila_time, \
        open(f'./data/llama_3_1b_vanila_mlp_mem_{seq_len}', "w") as log_vanila_mem:
    for sp in sparsities:
        # The same config object is reused; only the sparsity changes per step.
        llama_config.sparsity = sp

        ter_cpu_spans, ter_spans, ter_mems = benchmark_mlp(llama_config, True, seq_len)
        print(f"TerSpMM MLP  {sp:.2f} {ter_spans[0:10].mean():.4f} ms {ter_mems.mean()/1024/1024:.4f} MB")
        log_ter_time.write(f"{sp} {ter_spans[0:10].mean()}\n")
        log_ter_mem.write(f"{sp} {ter_mems.mean()}\n")
        torch.cuda.empty_cache()

        vanila_cpu_spans, vanila_spans, vanila_mems = benchmark_mlp(llama_config, False, seq_len)
        print(f"nnLinear MLP {sp:.2f} {vanila_spans[0:10].mean():.4f} ms {vanila_mems.mean()/1024/1024:.4f} MB")
        log_vanila_time.write(f"{sp} {vanila_spans[0:10].mean()}\n")
        log_vanila_mem.write(f"{sp} {vanila_mems.mean()}\n")

        torch.cuda.empty_cache()
        gc.collect()

In [20]:
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
from TernaryLLM.TernaryLlama import (
    add_padding_to_token,
    prepare_ternary_model
)
from TernaryLLM.configuration_ternary import TernaryConfig
from huggingface_hub import login
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch, time, sys, json
import numpy as np
from accelerate import init_empty_weights
import subprocess, threading


# State shared between the benchmarking thread and the power-sampling thread.
walt_trends = []                       # one power reading appended per sample
total_energy_joules = 0                # running energy total kept by the sampler
power_monitor = False                  # sampler keeps looping while this is True
power_monitor_lock = threading.Lock()  # held while reading/writing `power_monitor`
def collect_power(interval: float = 0.1):
    """Sample GPU power draw until ``power_monitor`` is cleared.

    Intended to run in a background thread. Each loop iteration checks the
    ``power_monitor`` flag under ``power_monitor_lock``, sleeps ``interval``
    seconds, runs ``nvidia-smi`` to read ``power.draw``, appends the reading to
    the global ``walt_trends`` and adds the energy for that period to the
    global ``total_energy_joules``.

    Args:
        interval: Seconds to sleep between two ``nvidia-smi`` queries.

    Raises:
        IndexError, ValueError: If the ``nvidia-smi`` output does not contain a
            numeric value on its second line.
    """
    global total_energy_joules
    global power_monitor
    global power_monitor_lock
    global walt_trends

    # Integrate with the time that actually elapsed between samples. Each
    # period is the sleep plus the nvidia-smi subprocess run time; multiplying
    # by the nominal `interval` alone under-counts the energy.
    last_sample_ts = time.monotonic()
    while True:
        with power_monitor_lock:
            if not power_monitor:
                break
        time.sleep(interval)
        output = subprocess.run(
            ["nvidia-smi", "--query-gpu=power.draw", "--format=csv,nounits"],
            capture_output=True,
            text=True,
        )
        # Only the second output line is parsed -- presumably the first GPU's
        # value following the CSV header row; further lines are ignored.
        power_w = float(output.stdout.split("\n")[1])
        now = time.monotonic()
        walt_trends.append(power_w)
        total_energy_joules += power_w * (now - last_sample_ts)
        last_sample_ts = now

def benchmark_llama_raw(
    model, 
    tokenizer, 
    config: TernaryConfig, 
    text: str, 
    interval: float = 0.1,
    ternarization: bool = True
):
    """Benchmark greedy text generation while sampling GPU power in the background.

    Args:
        model: Causal language model exposing ``generate``. It is passed
            through ``prepare_ternary_model`` when ``ternarization`` is True
            and then moved to CUDA.
        tokenizer: Tokenizer used to encode ``text``.
        config: Configuration handed to ``prepare_ternary_model``.
        text: Prompt to generate from.
        interval: Sampling period in seconds passed to ``collect_power``.
        ternarization: Whether to convert the model with
            ``prepare_ternary_model`` before benchmarking.

    Returns:
        A 6-tuple ``(total_tokens, output_throughput, input_throughput,
        gen_mems, power_samples, energy)``:

        * ``total_tokens`` -- number of newly generated tokens.
        * ``output_throughput`` -- generated tokens per second.
        * ``input_throughput`` -- prompt tokens per second
          (runs x prompt length / elapsed time).
        * ``gen_mems`` -- NumPy array with one
          ``torch.cuda.memory_allocated(0)`` reading per generation run.
        * ``power_samples`` -- NumPy array of the readings gathered by
          ``collect_power`` during this call.
        * ``energy`` -- value of ``total_energy_joules`` accumulated during
          this call.
    """
    global total_energy_joules
    global power_monitor
    global walt_trends

    model = prepare_ternary_model(model, config) if ternarization else model
    model.to('cuda')    # move to gpu

    # Tokenize the input string and add paddings to tokens
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    # NOTE(review): the meaning of the literal 4 is defined by
    # add_padding_to_token, which is not visible here -- confirm it should
    # track the configured padding size.
    input_ids = add_padding_to_token(input_ids, tokenizer, 4, 'cuda')
    print(f"Sequence length: {input_ids.size(1)}")

    # warm-up
    with torch.no_grad():
        for _ in range(5):
            model(input_ids)
    torch.cuda.empty_cache()

    # Start every call from clean accumulators. Without this, energy and
    # samples from a previous call leak into the new result; and because the
    # old code rebound `walt_trends` to an ndarray, a second call crashed in
    # the sampler thread on `.append`.
    total_energy_joules = 0
    walt_trends = []

    # start power measurement
    with power_monitor_lock:
        power_monitor = True
    power_monitor_thread = threading.Thread(target=collect_power, args=(interval,), daemon=True)
    power_monitor_thread.start()

    # prepare time benchmark
    num_runs = 1
    gen_mems = []
    total_tokens = 0
    cpu_ts = time.monotonic()

    try:
        # start generation
        with torch.no_grad():
            for _ in range(num_runs):
                output = model.generate(
                    input_ids=input_ids,
                    max_new_tokens=500,
                    do_sample=False,
                    use_cache=False,
                )
                alloc_mem = torch.cuda.memory_allocated(0)

                # Total output length (prompt + generated tokens).
                print(output.shape[-1])
                total_tokens += output.shape[-1] - input_ids.size(1)
                gen_mems.append(alloc_mem)
                torch.cuda.empty_cache()
        torch.cuda.synchronize()
        elapsed_time_sec = time.monotonic() - cpu_ts
    finally:
        # end power measurement -- also on failure, so the sampler thread does
        # not keep polling nvidia-smi after an exception.
        with power_monitor_lock:
            power_monitor = False
        power_monitor_thread.join()

    # The prompt throughput previously used a hard-coded factor of 25 although
    # the generation loop runs once; it is now tied to the real run count.
    return (
        total_tokens,
        total_tokens / elapsed_time_sec,
        (num_runs * input_ids.size(1)) / elapsed_time_sec,
        np.array(gen_mems),
        np.array(walt_trends),
        total_energy_joules,
    )

In [None]:
# Runs the end-to-end generation benchmark once (ternary or vanilla, selected
# by `ternary_benchmark`) and logs "<sparsity> <output token throughput>".
ternary_benchmark = False

# load ternary model config
with open('./config/config.json') as f:
    llama_3_1b_json = json.load(f)
ternary_llama_config = TernaryConfig(
    ternary_attn_linear=llama_3_1b_json["ternary_attn_linear"],
    ternary_mlp=llama_3_1b_json["ternary_mlp"],
    sparsity=llama_3_1b_json["sparsity"],
    uniform_sparsity=llama_3_1b_json["uniform_sparsity"],
    uniform_sparsity_block_size=llama_3_1b_json["uniform_sparsity_block_size"],
    padding=llama_3_1b_json["padding"],
    padding_size=llama_3_1b_json["padding_size"]
)

# load and prepare float model
model_name = "meta-llama/Llama-3.2-1B"
with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
model.tie_weights()
model.to_empty(device="cuda")  # Allocates empty tensors on GPU
# NOTE(review): this loads the model's own state dict back into itself, so it
# does not bring in any pretrained checkpoint values. Presumably acceptable
# because only throughput/power is measured -- confirm the weights are meant
# to be arbitrary here.
model.load_state_dict(model.state_dict())
tokenizer = AutoTokenizer.from_pretrained(model_name)

# load the dataset from ag news
dataset = load_dataset("fancyzhx/ag_news")

# prepare log file: the ternary log is appended to, the vanilla log is rewritten
if ternary_benchmark:
    log_path, log_mode = './data/llama_3_1b_ter_llama_generation', "a"
else:
    log_path, log_mode = './data/llama_3_1b_vanilla_llama_generation', "w"

# `with` guarantees the log is closed even if the benchmark raises.
with open(log_path, log_mode) as log_time:
    total_tokens, output_token_throughput, input_token_throughput, mems, walts, total_engry = \
        benchmark_llama_raw(
            model, tokenizer, ternary_llama_config, 
            text=dataset["train"][0]['text'], 
            interval=0.1, ternarization=ternary_benchmark)
    print(f"Total Generated Token: {total_tokens} Throughput: {output_token_throughput} {input_token_throughput} | Memory: {mems.mean()/1024/1024}MB")
    print(f"Power consumption: {total_engry} joules; Avg Watt: {walts.mean():.2f}W")
    # Single quotes inside the f-string: reusing the enclosing double quotes is
    # a SyntaxError on Python versions before 3.12.
    log_time.write(f"{llama_3_1b_json['sparsity']} {output_token_throughput}\n")
    torch.cuda.empty_cache()