In [None]:
%load_ext autoreload
%autoreload 2

import sys

# Make the parent of the current working directory importable so the project
# `utils.*` imports resolve. Guarded so that re-running this cell does not keep
# appending duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import init_empty_weights
from tqdm.auto import tqdm
import numpy as np
import pickle
import torch
import json
import os

## Custom Imports
from utils.config_utils import GenerationParams, PathConfig

from utils.torch_utils import (
    save_quant,
    load_quant,
    destruct_module_optimized,
    memory_cleanup,
    get_nonreasoning_dataset,
    load_weight,
    rsetattr,
    rgetattr,
    load_weights,
    rhasattr,
)

In [None]:
def load_saved_hidden_states(layer_idx, batch_idx, root=None, target_device=None, dtype=None, loader=None):
    """Load the saved hidden states that feed decoder layer ``layer_idx`` for one batch.

    Debugging helper. It replaces an ad-hoc cell that used ``path_config``, ``batch_idx``,
    ``device`` and ``dtype`` before (or without) defining them, which raised ``NameError``
    on Restart & Run All. Call it after the processing cells have written their outputs,
    e.g. ``load_saved_hidden_states(9, 0)``.

    The input of layer ``i`` is the output stored under ``layer_{i-1}``. The embedding
    cell stores its output as ``layer_-1``, so ``layer_idx=0`` reads the embeddings.

    Args:
        layer_idx: Decoder layer whose input should be loaded.
        batch_idx: Batch number used when the states were saved.
        root: Directory containing the ``layer_*/batch_*`` files. Defaults to
            ``path_config.intermediate_states`` from the config cell.
        target_device: Device to move the tensor to. Defaults to the notebook's ``device``.
        dtype: Optional dtype to cast to; ``None`` keeps the stored dtype.
        loader: Callable ``path -> tensor``. Defaults to ``load_quant``.

    Returns:
        The loaded tensor on ``target_device`` (cast to ``dtype`` when given).
    """
    # Notebook globals are looked up at call time, so defining this before the
    # config cell has run is safe.
    if root is None:
        root = path_config.intermediate_states
    if target_device is None:
        target_device = device
    if loader is None:
        loader = load_quant

    path = os.path.join(root, f"layer_{layer_idx - 1}", f"batch_{batch_idx}")
    tensor = loader(path)
    if dtype is None:
        return tensor.to(target_device)
    return tensor.to(target_device, dtype=dtype)

In [None]:
# --- Run configuration -------------------------------------------------------
device = "cuda:0"
model_name = "../deepseek_v2_lite_awq"  # checkpoint dir; must contain model.safetensors.index.json
n_batch = 16       # number of batches pushed through the model
batch_size = 8     # prompts per batch
max_length = 512   # prompts are padded/truncated to exactly this many tokens

# Build the generation config from the constants above so there is a single
# source of truth (the literals used to be duplicated and could drift apart).
generation_config = GenerationParams(
    n_batch=n_batch,
    batch_size=batch_size,
    max_length=max_length,
)

path_config = PathConfig(
    model_name=model_name,
    intermediate_states="data/intermediate_states",  # layer outputs: layer_{i}/batch_{b}
    expert_states="data/expert_states",              # post-attention-norm inputs of MoE layers
    expert_activations="data/expert_activations",    # pickled (topk_idx, topk_weight) per MoE layer
)

# Shared by every decoder layer: positions 0..max_length-1 with a leading
# batch dimension of 1 that broadcasts over each batch.
position_ids = torch.arange(0, generation_config.max_length, dtype=torch.long, device=device).unsqueeze(0)

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)

train_dataset = get_nonreasoning_dataset(tokenizer, generation_config)

In [None]:
# Read the checkpoint's safetensors index. Its "weight_map" is handed to
# load_weights below. (Presumably this maps parameter name -> shard file, per the
# standard index layout; confirm against load_weights.)
with open(f"{model_name}/model.safetensors.index.json", "r") as index_file:
    safetensors_index = json.load(index_file)
weight_map = safetensors_index["weight_map"]

# Instantiate the architecture without allocating real weights. Layers are
# materialized later, cell by cell, with to_empty() + load_weights().
with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
        low_cpu_mem_usage=True,
    )

# This notebook only runs forward passes, so freeze every parameter. This has the
# same effect as setting requires_grad = False on each one individually.
model.requires_grad_(False)

# NOTE(review): train() puts every submodule in training mode, which can change
# forward behaviour (dropout, or model code branching on `self.training`).
# Confirm this is intended for collecting activations; eval() is the usual choice.
model.train()
destruct_module_optimized(model)
memory_cleanup()

In [None]:
# Stage 0: token embeddings.
# Only the embedding matrix is loaded. Outputs are written to "layer_-1" so that
# decoder layer i can always read its input from layer_{i-1}.
target_modules = [
    "model.embed_tokens.weight"
]

model = load_weights(model, model_name, weight_map, target_modules, device)

embedding_dir = os.path.join(path_config.intermediate_states, "layer_-1")
os.makedirs(embedding_dir, exist_ok=True)  # create once, not on every batch

# Forward-only: no_grad guarantees no autograd graph is built, even if load_weights
# materializes the embedding weight with requires_grad=True.
with torch.no_grad():
    for batch_idx in tqdm(range(generation_config.n_batch), desc="Processing embeddings"):
        start = generation_config.batch_size * batch_idx
        batch = train_dataset[start : start + generation_config.batch_size]
        inputs = tokenizer(
            batch,
            max_length=generation_config.max_length,
            padding="max_length",  # fixed length so every batch lines up with position_ids
            truncation=True,
            return_tensors="pt",
        ).to(device)

        hidden_states = model.model.embed_tokens(inputs["input_ids"])
        save_quant(hidden_states, os.path.join(embedding_dir, f"batch_{batch_idx}"))

destruct_module_optimized(model)
memory_cleanup()

In [None]:
# Stages 1..N: decoder layers, one at a time.
# For each layer:
#   1. materialize its parameters on `device` and load its weights;
#   2. run every batch, reading the input from layer_{i-1} and writing the output to layer_{i};
#   3. tear the model back down.
# MoE layers (those with `mlp.gate`) are unrolled by hand so the post-attention
# hidden states (the expert inputs) and the router's top-k choices are saved too.


def _batch_path(root, layer_idx, batch_idx):
    """Return the file path ``root/layer_{layer_idx}/batch_{batch_idx}``."""
    return os.path.join(root, f"layer_{layer_idx}", f"batch_{batch_idx}")


def _save_batch(tensor, root, layer_idx, batch_idx):
    """Save ``tensor`` at ``_batch_path(...)``, creating the layer directory if needed."""
    os.makedirs(os.path.join(root, f"layer_{layer_idx}"), exist_ok=True)
    save_quant(tensor, _batch_path(root, layer_idx, batch_idx))


def _load_layer_input(layer_idx, batch_idx):
    """Load the saved output of layer ``layer_idx - 1`` for one batch onto ``device``."""
    return load_quant(_batch_path(path_config.intermediate_states, layer_idx - 1, batch_idx)).to(device)


def _run_moe_layer(decoder_layer, layer_idx):
    """Run an MoE decoder layer step by step over all batches.

    For each batch this writes two tensors:
      * the post-attention-norm hidden states, to ``expert_states``;
      * the layer output, to ``intermediate_states``.

    The router's top-k indices and weights are concatenated over batches along
    dim 0 and pickled as CPU tensors to
    ``expert_activations/layer_{layer_idx}.pickle``.
    """
    top_k_output = []
    top_k_weight = []

    for batch_idx in tqdm(range(generation_config.n_batch), desc=f"Processing MLP Layer {layer_idx}"):
        hidden_states = _load_layer_input(layer_idx, batch_idx)

        # Attention sub-block: pre-norm, then residual add.
        residual = hidden_states
        hidden_states = decoder_layer.input_layernorm(hidden_states)
        # NOTE(review): assumes self_attn returns a 3-tuple (output, attn_weights, present_kv),
        # as the original unpacking did. Other remote-code versions may return a different arity.
        hidden_states, _, _ = decoder_layer.self_attn(
            hidden_states=hidden_states,
            attention_mask=None,
            position_ids=position_ids,
            past_key_value=None,
            output_attentions=False,
            use_cache=False,
        )
        hidden_states = residual + hidden_states

        # MoE sub-block: record the expert inputs and the router decision, then run it.
        residual = hidden_states
        hidden_states = decoder_layer.post_attention_layernorm(hidden_states)
        _save_batch(hidden_states, path_config.expert_states, layer_idx, batch_idx)

        topk_idx, topk_weight, _ = decoder_layer.mlp.gate(hidden_states)
        # Keep router outputs on the CPU. This avoids holding them on the GPU across
        # batches, and lets the pickle load on machines without CUDA.
        top_k_output.append(topk_idx.cpu())
        top_k_weight.append(topk_weight.cpu())

        hidden_states = decoder_layer.mlp(hidden_states)
        hidden_states = residual + hidden_states
        _save_batch(hidden_states, path_config.intermediate_states, layer_idx, batch_idx)

    top_k_output = torch.cat(top_k_output, dim=0)
    top_k_weight = torch.cat(top_k_weight, dim=0)

    os.makedirs(path_config.expert_activations, exist_ok=True)
    with open(os.path.join(path_config.expert_activations, f"layer_{layer_idx}.pickle"), "wb") as f:
        pickle.dump((top_k_output, top_k_weight), f)


def _run_dense_layer(decoder_layer, layer_idx):
    """Run a dense (non-MoE) decoder layer as a single module call over all batches."""
    for batch_idx in tqdm(range(generation_config.n_batch), desc=f"Processing MLP Layer {layer_idx}"):
        hidden_states = _load_layer_input(layer_idx, batch_idx)
        hidden_states = decoder_layer(
            hidden_states,
            position_ids=position_ids,
        )[0]
        _save_batch(hidden_states, path_config.intermediate_states, layer_idx, batch_idx)


# Forward-only: no_grad ensures no autograd graph is kept alive across batches,
# regardless of how load_weights materializes the parameters.
with torch.no_grad():
    for layer_idx in range(len(model.model.layers)):
        model.model.layers[layer_idx].to_empty(device=device)
        target_modules = [f".layers.{layer_idx}."]
        model = load_weights(model, model_name, weight_map, target_modules, device)

        # Fetch the layer only after load_weights, which returns the model to use.
        decoder_layer = model.model.layers[layer_idx]
        if rhasattr(decoder_layer, "mlp.gate"):
            _run_moe_layer(decoder_layer, layer_idx)
        else:
            _run_dense_layer(decoder_layer, layer_idx)

        # Project helpers; presumably drop the loaded weights and free device memory
        # before the next layer is materialized.
        destruct_module_optimized(model)
        memory_cleanup()