In [None]:
from typing import *
import torch
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
from torch import nn

In [None]:
from python.configuration_llama import CustomLlamaConfig
from python.modeling_llama import CustomLlamaForCausalLM

In [None]:
# Load the base checkpoint: read its config first so the weights can be
# loaded in the dtype that config declares.
base_model_dir = '../../base_model/Meta-Llama-3-8B-hf'
original_config = AutoConfig.from_pretrained(base_model_dir)
original_model = AutoModelForCausalLM.from_pretrained(
    base_model_dir,
    torch_dtype=original_config.torch_dtype,
)

In [None]:
# Load the custom config from the current directory. from_pretrained is
# called on the class itself, so no throwaway default-constructed instance
# is created first.
# NOTE(review): assumes CustomLlamaConfig inherits the standard classmethod
# from_pretrained from transformers' PretrainedConfig — confirm.
config = CustomLlamaConfig.from_pretrained('.')

In [None]:
# Clear the activation sparsity type and switch off all four use_* flags
# (presumably toggles for custom sparse kernels — confirm in CustomLlamaConfig).
config.activation_sparsity_type = None
for kernel_flag in ('use_spvmm', 'use_vmmsp', 'use_spvmm_cpu', 'use_vmmsp_cpu'):
    setattr(config, kernel_flag, False)

In [None]:
# Display the config so its current field values can be checked by eye.
config

In [None]:
# Build the custom model from the config; checkpoint tensors are copied in
# later via load_state_dict.
model = CustomLlamaForCausalLM(config)

In [None]:
# Cast the model to the dtype named in the config. The bare expression is the
# cell's last line, so the returned module is also displayed.
model.to(config.torch_dtype)

In [None]:
# Copy the base checkpoint's tensors into the custom model. strict=False
# tolerates keys that exist in only one of the two state dicts; the returned
# missing/unexpected key report is shown as the cell output and should be
# reviewed before continuing.
model.load_state_dict(original_model.state_dict(), strict=False)

In [None]:
# Populate the custom model's pre-transposed projection weights
# (<proj>_weight_t) from the base model's q_proj / o_proj / down_proj weights.

# zip() stops at the shorter sequence, so a layer-count mismatch would
# silently leave trailing layers without their transposed weights.
if len(original_model.model.layers) != len(model.model.layers):
    raise ValueError(
        f'layer count mismatch: base model has {len(original_model.model.layers)} layers, '
        f'custom model has {len(model.model.layers)}'
    )

transposed_projections = (
    ('self_attn', 'q_proj'),
    ('self_attn', 'o_proj'),
    ('mlp', 'down_proj'),
)
for original_layer, layer in zip(original_model.model.layers, model.model.layers):
    for submodule_name, proj_name in transposed_projections:
        source_weight = getattr(getattr(original_layer, submodule_name), proj_name).weight
        # Cast to the custom model's dtype: unlike load_state_dict, building a
        # fresh Parameter would otherwise keep the base model's dtype.
        transposed = source_weight.detach().transpose(0, 1).contiguous().to(config.torch_dtype)
        setattr(getattr(layer, submodule_name), f'{proj_name}_weight_t', nn.Parameter(transposed))

In [None]:
# Load the saved thresholds. map_location='cpu' makes the load independent of
# the device the file was saved from and keeps the tensors on the same device
# as the CPU-resident model they are installed into below.
# NOTE(review): torch.load unpickles the file, so only load threshold files
# from a trusted source (consider weights_only=True if the installed torch
# version supports it and the file holds only tensors/containers).
thresholds = torch.load('../thresholds_0_5.pt', map_location='cpu')

In [None]:
# Show which threshold groups the loaded file contains.
thresholds.keys()

In [None]:
# Install the per-layer thresholds into the model. Each entry names the layer
# submodule that owns the threshold tensor and the threshold's name, which is
# also its key in the loaded `thresholds` mapping.
threshold_targets = (
    ('mlp', 'gate_proj_states_thresholds'),
    ('mlp', 'gate_proj_states_thresholds_2'),
    ('self_attn', 'attention_inputs_thresholds'),
    ('self_attn', 'attention_outputs_thresholds'),
)
for layer_idx, layer in enumerate(model.model.layers):
    for submodule_name, threshold_name in threshold_targets:
        target = getattr(getattr(layer, submodule_name), threshold_name)
        # Replace the underlying tensor with this layer's values, cast to the model dtype.
        target.data = thresholds[threshold_name][layer_idx].to(config.torch_dtype)

In [None]:
# Inspect the assembled model without dumping every tensor's values into the
# cell output: map each state-dict entry to its (shape, dtype).
{name: (tuple(tensor.shape), tensor.dtype) for name, tensor in model.state_dict().items()}

In [None]:
# Register the custom config and model classes with the transformers Auto*
# machinery before saving; the model is registered under AutoModelForCausalLM.
# NOTE(review): presumably this is what lets the saved directory be loaded
# through AutoConfig / AutoModelForCausalLM — confirm against transformers docs.
config.register_for_auto_class()
model.register_for_auto_class(AutoModelForCausalLM)

In [None]:
# Write the config, then the model, into the same export directory.
output_dir = './model'
for artifact in (config, model):
    artifact.save_pretrained(output_dir)

In [None]:
# Copy the base model's tokenizer next to the exported model so the output
# directory is self-contained. AutoTokenizer is already imported in the
# notebook's first cell, so it is not re-imported here.
tokenizer = AutoTokenizer.from_pretrained('../../base_model/Meta-Llama-3-8B-hf')
tokenizer.save_pretrained('./model')