In [1]:
import torch
import wandb

import sys
sys.path.append('/home/ubuntu/taker/src')

from taker import Model
from taker.activations import get_midlayer_activations
from taker.data_classes import PruningConfig, RunDataHistory
from taker.model_repos import test_model_repos
from taker.prune import prune_and_evaluate, run_pruning



In [3]:
def get_bucket_peaks(activations, bins=100):
    """Estimate the histogram peak (mode) of every ``activations[i, j]`` slice.

    A histogram with ``bins`` equal-width buckets is computed for each
    ``activations[i, j]`` slice. Every slice shares the same bucket edges,
    which span the global minimum and maximum of the whole tensor, so the
    returned peaks are directly comparable with each other. The value
    reported for a slice is the centre of its fullest bucket (ties are
    resolved by ``torch.argmax``).

    Args:
        activations: Floating-point tensor indexed as ``[i, j, ...]``.
            Presumably laid out as (layers, neurons, samples) - TODO confirm
            against the caller.
        bins: Number of histogram buckets. Defaults to 100.

    Returns:
        A tensor of shape ``activations.size()[:-1]`` holding the peak value
        of each slice. It is float16 when the input is float16 and float32
        otherwise.

    Raises:
        ValueError: If ``bins`` is smaller than 1.
    """
    if bins < 1:
        raise ValueError(f"bins must be a positive integer, got {bins}")

    # The histogram is always computed outside float16: half-precision input
    # is upcast here and the result is cast back just before returning.
    is_half = activations.dtype == torch.float16
    values = activations.float() if is_half else activations

    # Global range shared by every histogram.
    min_val = values.min()
    max_val = values.max()

    # Loop-invariant quantities, computed once instead of once per slice.
    bin_width = (max_val - min_val) / bins
    # Plain Python scalars for histc, so the 0-dim tensors are not converted
    # again on every call.
    lo, hi = min_val.item(), max_val.item()

    peak_values = torch.empty(values.size()[:-1], device=values.device, dtype=torch.float32)

    for i in range(values.size(0)):
        for j in range(values.size(1)):
            hist = torch.histc(values[i, j], bins=bins, min=lo, max=hi)
            peak_bin = hist.argmax()
            # Centre of the fullest bucket.
            peak_values[i, j] = min_val + bin_width * (peak_bin.float() + 0.5)

    return peak_values.half() if is_half else peak_values

In [4]:

# Configuration for this pruning experiment: which model to prune, how much of
# the feed-forward / attention components to remove per step, which datasets
# act as "focus" and "cripple", and where the run is logged on Weights & Biases.
c = PruningConfig("nickypro/tinyllama-15m",
    attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False,
    ff_frac=0.5, attn_frac=0.5,
    token_limit=1000, focus="pile", cripple="code", wandb_entity="seperability",
    wandb_project="bens-tests", wandb_run_name="test notebook2", n_steps=10,)

# Third positional argument of get_midlayer_activations. Judging by the
# progress bar total in the cell output this is a sample count - TODO confirm.
# Kept as an integer (the original passed the float literal 1e4, which is
# what made the progress bar read "0/10000.0").
N_SAMPLES = 10_000

opt: Model = Model(c.model_repo, limit=c.token_limit, dtype="fp32",
            use_accelerator=c.use_accelerator)

# Dataset names are read from the config instead of being repeated as string
# literals, so the collected activations always match the configured
# focus/cripple datasets. Assumes PruningConfig exposes its keyword arguments
# as attributes, as it does for token_limit, use_accelerator and n_steps.
focus_data = get_midlayer_activations(opt, c.focus, N_SAMPLES, collect_ff=True, collect_attn=True)
cripple_data = get_midlayer_activations(opt, c.cripple, N_SAMPLES, collect_ff=True, collect_attn=True)

history = RunDataHistory(c.datasets)
wandb.init(
    project=c.wandb_project,
    entity=c.wandb_entity,
    name=c.wandb_run_name,
    )
wandb.config.update(c.to_dict(), allow_val_change=True)

try:
    with torch.no_grad():
        for step in range(c.n_steps):
            data = prune_and_evaluate(opt, c, focus_data, cripple_data, step)
            history.add(data)
except BaseException:
    # Covers KeyboardInterrupt as well: close the wandb run and mark it as
    # failed instead of leaving it open, then let the error propagate.
    # On success the run is left open, exactly as before.
    wandb.finish(exit_code=1)
    raise

- Loaded nickypro/tinyllama-15m
 - Registered 6 Attention Layers


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/10000.0 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (3324 > 2048). Running this sequence through the model will result in indexing errors
 90%|█████████ | 9048/10000.0 [00:16<00:01, 543.62it/s]


KeyboardInterrupt: 