In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter

In [None]:
# Pre-tokenized data splits. Presumably a dict of numpy token-id arrays: get_batch
# calls `.astype(np.int64)` / `torch.from_numpy` on datasets['train'] — confirm.
# torch.load unpickles the file. On torch >= 2.6 the default became weights_only=True,
# which rejects pickled numpy arrays; weights_only=False restores the pre-2.6 default.
# SECURITY: unpickling can execute arbitrary code — only load a trusted 'datasets.pt'.
datasets = torch.load('datasets.pt', weights_only=False)

In [None]:
# Load the tokenizer and the instrumented causal-LM checkpoint from the same local directory.
# The model is kept in float32. trust_remote_code=True executes Python shipped with the
# checkpoint (the custom layers that record activation states) — only use trusted checkpoints.
model_dir = './custom_llama_statistics/model'
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    torch_dtype=torch.float32,
    trust_remote_code=True,
)

In [None]:
# Forward passes run on `device` (inputs are moved there). The heavy per-layer kthvalue /
# reduction work in the calibration cells is offloaded to `device_2`. Fall back to CPU when
# no GPU is available instead of failing later inside the calibration loop.
device = 'cpu'
device_2 = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
def get_batch(data, batch_size, block_size):
    """Sample `batch_size` random contiguous windows of length `block_size` from `data`.

    `data` is a 1-D numpy array of token ids. One `torch.randint` draw picks the window
    starts, so results follow torch's global RNG. Returns `(x, y)` int64 tensors of shape
    (batch_size, block_size), where `y` is `x` shifted one position to the right in `data`.
    """
    # Largest start is len(data) - block_size - 1, so the shifted target window still fits.
    offsets = torch.randint(len(data) - block_size, (batch_size,)).tolist()

    window_inputs = []
    window_targets = []
    for start in offsets:
        stop = start + block_size
        window_inputs.append(torch.from_numpy(data[start:stop].astype(np.int64)))
        window_targets.append(torch.from_numpy(data[start + 1:stop + 1].astype(np.int64)))

    return torch.stack(window_inputs), torch.stack(window_targets)

In [None]:
# Target fraction of calibration values below each threshold: the calibration cells take the
# k-th smallest value with k = int(numel * sparsity_level). Validate up front so a bad value
# fails here rather than inside kthvalue after all the expensive forward passes.
sparsity_level = 0.5
if not 0.0 < sparsity_level <= 1.0:
    raise ValueError(f"sparsity_level must be in (0, 1], got {sparsity_level}")

In [None]:
# Calibration pass 1: collect per-layer activation statistics over random training windows
# and derive magnitude thresholds at `sparsity_level`.
#
# Each accumulation step stacks the captured activations of `accum_steps` batches into
# preallocated per-layer buffers. For that step, the k-th smallest |activation|
# (k = numel * sparsity_level) is one threshold sample. The per-channel mean of squared
# up_proj activations is computed alongside. Both are averaged over steps at the end.
#
# NOTE(review): relies on the custom model (loaded with trust_remote_code) exposing
# `mlp.gate_proj_states`, `mlp.up_proj_states`, `self_attn.attention_input_states` and
# `self_attn.attention_output_states` after each forward call — confirm against the model code.
# NOTE(review): each buffer list is float32 of shape (accum_steps * batch_size * block_size,
# width) per layer, which can total many GB for large models.
avg_loss = 0.0
n_batch = 64
accum_steps = 4
batch_size = 1
block_size = 2048
torch.manual_seed(42)

n_layers = len(model.model.layers)
n_steps = n_batch // accum_steps
if n_steps == 0:
    raise ValueError(f"n_batch ({n_batch}) must be >= accum_steps ({accum_steps})")
# Batches actually processed. The loss average divides by this so it stays correct when
# accum_steps does not divide n_batch (the remainder batches are never run).
n_processed = n_steps * accum_steps
rows_per_batch = batch_size * block_size


def _abs_kth_threshold(buffer):
    """Return the k-th smallest |value| in `buffer` as a CPU tensor.

    k = int(buffer.numel() * sparsity_level), clamped to at least 1 so very small
    sparsity levels do not make kthvalue fail. The selection runs on device_2.
    """
    k = max(1, int(buffer.numel() * sparsity_level))
    return buffer.to(device_2).abs().flatten().kthvalue(k).values.to('cpu')


gate_proj_states_thresholds = [torch.zeros([1,]) for _ in range(n_layers)]
up_proj_states_mean_squares = [torch.zeros(model.config.intermediate_size) for _ in range(n_layers)]
attention_inputs_thresholds = [torch.zeros([1,]) for _ in range(n_layers)]
attention_outputs_thresholds = [torch.zeros([1,]) for _ in range(n_layers)]

# One accumulation step's worth of activations per layer, flattened to (tokens, width).
gate_proj_states = [torch.zeros([accum_steps * rows_per_batch, model.config.intermediate_size]) for _ in range(n_layers)]
up_proj_states = [torch.zeros([accum_steps * rows_per_batch, model.config.intermediate_size]) for _ in range(n_layers)]
attention_input_states = [torch.zeros([accum_steps * rows_per_batch, model.config.hidden_size]) for _ in range(n_layers)]
attention_output_states = [torch.zeros([accum_steps * rows_per_batch, model.config.hidden_size]) for _ in range(n_layers)]

with torch.no_grad():
    for step in range(n_steps):
        print(step * accum_steps)  # progress: index of the first batch in this step
        for batch_idx in range(accum_steps):
            inputs, labels = get_batch(datasets['train'], batch_size, block_size)
            inputs = inputs.to(device)
            # labels=inputs assumes the model shifts labels internally for next-token loss,
            # as HF causal-LM heads do — confirm for the custom model.
            outputs = model(inputs, labels=inputs)
            avg_loss = avg_loss + outputs.loss / n_processed

            rows = slice(batch_idx * rows_per_batch, (batch_idx + 1) * rows_per_batch)
            for layer_idx, layer in enumerate(model.model.layers):
                captured = (
                    (gate_proj_states, layer.mlp.gate_proj_states),
                    (up_proj_states, layer.mlp.up_proj_states),
                    (attention_input_states, layer.self_attn.attention_input_states),
                    (attention_output_states, layer.self_attn.attention_output_states),
                )
                for buffers, states in captured:
                    buffers[layer_idx][rows, :] = states.reshape(-1, states.size(-1))

        for layer_idx in range(n_layers):
            gate_proj_states_thresholds[layer_idx] += _abs_kth_threshold(gate_proj_states[layer_idx])
            attention_inputs_thresholds[layer_idx] += _abs_kth_threshold(attention_input_states[layer_idx])
            attention_outputs_thresholds[layer_idx] += _abs_kth_threshold(attention_output_states[layer_idx])
            # Per-channel mean of squared up_proj activations over this step's rows.
            up_sum_sq = torch.sum(up_proj_states[layer_idx].to(device_2) ** 2, dim=0).to('cpu')
            up_proj_states_mean_squares[layer_idx] += up_sum_sq / up_proj_states[layer_idx].size(0)

# Turn the per-step sums into averages over accumulation steps.
for layer_idx in range(n_layers):
    gate_proj_states_thresholds[layer_idx] /= n_steps
    attention_inputs_thresholds[layer_idx] /= n_steps
    attention_outputs_thresholds[layer_idx] /= n_steps
    up_proj_states_mean_squares[layer_idx] /= n_steps

avg_loss

In [None]:
# Calibration pass 2: importance-aware gate_proj thresholds.
#
# Element importance = gate_proj_activation**2 * mean(up_proj**2) of its channel, using the
# mean squares from pass 1. Each accumulation step contributes the k-th smallest importance
# (k = numel * sparsity_level). These are averaged into one scalar per layer, then mapped back
# to a per-channel |gate_proj| threshold: sqrt(importance_threshold / mean(up_proj**2)).
#
# Depends on the previous cell: n_batch, accum_steps, batch_size, block_size, the
# `gate_proj_states` buffers and `up_proj_states_mean_squares`.
# NOTE(review): the RNG is not re-seeded, so batches continue pass 1's RNG stream; re-running
# this cell on its own samples different batches.
importance_thresholds = [torch.zeros([1,]) for _ in range(len(model.model.layers))]
gate_proj_states_thresholds_2 = [torch.zeros(model.config.intermediate_size) for _ in range(len(model.model.layers))]

n_importance_steps = n_batch // accum_steps
if n_importance_steps == 0:
    raise ValueError(f"n_batch ({n_batch}) must be >= accum_steps ({accum_steps})")
rows_per_batch = batch_size * block_size
# Loss for this pass is tracked separately. Adding it into pass 1's avg_loss (as before)
# reported the sum of two averages.
importance_avg_loss = 0.0

with torch.no_grad():
    for step in range(n_importance_steps):
        print(step * accum_steps)  # progress: index of the first batch in this step
        for batch_idx in range(accum_steps):
            inputs, labels = get_batch(datasets['train'], batch_size, block_size)
            inputs = inputs.to(device)
            outputs = model(inputs, labels=inputs)
            importance_avg_loss = importance_avg_loss + outputs.loss / (n_importance_steps * accum_steps)

            rows = slice(batch_idx * rows_per_batch, (batch_idx + 1) * rows_per_batch)
            for layer_idx, layer in enumerate(model.model.layers):
                states = layer.mlp.gate_proj_states
                gate_proj_states[layer_idx][rows, :] = states.reshape(-1, states.size(-1))

        for layer_idx in range(len(model.model.layers)):
            # Broadcasts the per-channel mean squares across all captured token rows.
            importance_scores = gate_proj_states[layer_idx] ** 2 * up_proj_states_mean_squares[layer_idx]
            k = max(1, int(importance_scores.numel() * sparsity_level))  # kthvalue needs k >= 1
            importance_thresholds[layer_idx] += importance_scores.to(device_2).flatten().kthvalue(k).values.to('cpu')

for layer_idx in range(len(model.model.layers)):
    importance_thresholds[layer_idx] /= n_importance_steps
    # [1]-shaped threshold broadcasts against the per-channel mean squares. A zero mean square
    # yields inf (or nan if the threshold is also 0) for that channel.
    gate_proj_states_thresholds_2[layer_idx] = (importance_thresholds[layer_idx] / up_proj_states_mean_squares[layer_idx]) ** 0.5

importance_avg_loss

In [None]:
# Persist the calibrated thresholds (each value is a list with one entry per decoder layer).
# The file name is derived from sparsity_level (0.5 -> 'thresholds_0_5.pt') rather than
# hard-coded, so a run at a different level is not saved under a misleading name.
thresholds = {'gate_proj_states_thresholds': gate_proj_states_thresholds, 'attention_inputs_thresholds': attention_inputs_thresholds, 'attention_outputs_thresholds': attention_outputs_thresholds, 'gate_proj_states_thresholds_2': gate_proj_states_thresholds_2}

thresholds_path = 'thresholds_{}.pt'.format(str(sparsity_level).replace('.', '_'))
torch.save(thresholds, thresholds_path)