In [1]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import torch
import wandb

import sys
sys.path.append('/root/taker/src')

from taker import Model
from taker.activations import get_midlayer_data
from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem
from taker.model_repos import test_model_repos
from taker.prune import prune_and_evaluate, run_pruning
from taker.eval import evaluate_all



In [3]:
def get_bucket_peaks(activations, bins=100):
    """Estimate the most common activation value of each neuron via a histogram.

    Every neuron is histogrammed over the same ``bins`` equal-width buckets,
    which span the global minimum and maximum of ``activations``. The centre of
    the most populated bucket is reported as that neuron's peak.

    Args:
        activations: floating-point tensor indexed as ``[layer, neuron, sample]``;
            the last dimension is the one that gets histogrammed.
        bins: number of histogram buckets shared by all neurons (default 100).

    Returns:
        Tensor of shape ``activations.size()[:-1]`` holding the bucket centres.
        float16 / bfloat16 inputs are processed in float32 and the result is
        cast back to the input dtype; any other input yields a float32 result.
    """
    # Half-precision types are up-cast for the histogram computation and the
    # result is converted back at the end so callers get the dtype they passed.
    was_low_precision = activations.dtype in (torch.float16, torch.bfloat16)
    values = activations.float() if was_low_precision else activations

    # Shared histogram range; the bucket width does not depend on the neuron,
    # so it is computed once here rather than inside the loop.
    min_val = values.min()
    max_val = values.max()
    bin_width = (max_val - min_val) / bins
    range_lo, range_hi = min_val.item(), max_val.item()

    # Only the index of the fullest bucket is recorded per neuron; the bucket
    # centres are derived for all neurons at once after the loop.
    peak_bins = torch.empty(values.size()[:-1], device=values.device, dtype=torch.long)

    for i in range(values.size()[0]):  # first dimension: layers
        print(f"getting buckets for layer {i}")
        for j in range(values.size()[1]):  # second dimension: neurons
            hist = torch.histc(values[i, j], bins=bins, min=range_lo, max=range_hi)
            peak_bins[i, j] = hist.argmax()

    # Centre of bucket k is min + width * (k + 0.5).
    peak_values_float32 = (
        min_val + bin_width * (peak_bins.to(values.dtype) + 0.5)
    ).to(torch.float32)

    if was_low_precision:
        return peak_values_float32.to(activations.dtype)
    return peak_values_float32

In [4]:
def print_gpu_memory_usage():
    """Print how much CUDA memory is currently allocated and reserved.

    Defined in this cell because no other visible cell defines or imports it,
    so the calls below would raise NameError on a fresh kernel.
    """
    if not torch.cuda.is_available():
        print("GPU memory: CUDA not available")
        return
    allocated_gib = torch.cuda.memory_allocated() / 1024**3
    reserved_gib = torch.cuda.memory_reserved() / 1024**3
    print(f"GPU memory: {allocated_gib:.2f} GiB allocated, {reserved_gib:.2f} GiB reserved")


# Number of tokens passed to get_midlayer_data for each activation collection.
COLLECTION_TOKENS = 1e4

# Other model repos this cell has been run with:
#   "facebook/galactica-1.3b", "facebook/opt-1.3b", "nickypro/tinyllama-15m"
c = PruningConfig("google/gemma-2b",
    attn_mode="pre-out", do_attn_mean_offset=False, use_accelerator=False,
    ff_frac=0.1, attn_frac=0.1,
    token_limit=1000, focus="pile", cripple="code", wandb_entity="seperability", recalculate_activations=False, dtype="fp32",
    wandb_project="bens-tests", wandb_run_name="gemma-2b peak scoring. ff=0.1 attn=0.1", n_steps=10, scoring_normalization="peak_centered")

# NOTE(review): the config above declares dtype="fp32", but no dtype is passed
# to Model here and the captured log reports the model loading "with bfp16".
# Confirm which precision is intended before comparing runs.
opt: Model = Model(c.model_repo, limit=c.token_limit,
            use_accelerator=c.use_accelerator)

# --- Pass 1: collect raw activations for the focus and cripple datasets ---
print_gpu_memory_usage()
focus_data = get_midlayer_data( opt, c.focus, COLLECTION_TOKENS, collect_ff=True, collect_attn=True )
cripple_data = get_midlayer_data( opt, c.cripple, COLLECTION_TOKENS, collect_ff=True, collect_attn=True )
print_gpu_memory_usage()

# Per-model attention layout, taken from the raw [token, layer, head, head_dim]
# tensors so the peak reshape below needs no hard-coded model dimensions.
focus_attn_peak_shape = tuple(focus_data.raw["attn"].shape[1:])
cripple_attn_peak_shape = tuple(cripple_data.raw["attn"].shape[1:])

# [token, layer, neuron] -> [layer, neuron, token]
focus_ff_activations   = focus_data.raw["mlp"].permute( (1,2,0) )
cripple_ff_activations = cripple_data.raw["mlp"].permute( (1,2,0) )
# [token, layer, attention head, attention neuron] -> [layer, attention head, attention neuron, token]
# then heads and head neurons are flattened into a single d_model-sized axis
focus_attn_activations   = focus_data.raw["attn"].permute( (1,2,3,0) ).reshape( (opt.cfg.n_layers, opt.cfg.d_model, -1) )
cripple_attn_activations = cripple_data.raw["attn"].permute( (1,2,3,0) ).reshape( (opt.cfg.n_layers, opt.cfg.d_model, -1) )

# --- Histogram peaks per neuron ---
print("3")
print_gpu_memory_usage()
print("getting ff peaks")
focus_ff_peaks = get_bucket_peaks(focus_ff_activations).cuda()
cripple_ff_peaks = get_bucket_peaks(cripple_ff_activations).cuda()

print("peak shapes: ")
print(f"ff:  + {focus_ff_peaks.shape}")

print("4")
print_gpu_memory_usage()
print("getting attn peaks")
# Peaks come back as [layer, d_model]; restore the [layer, head, head_dim] layout.
focus_attn_peaks = get_bucket_peaks(focus_attn_activations).reshape(focus_attn_peak_shape).cuda()
cripple_attn_peaks = get_bucket_peaks(cripple_attn_activations).reshape(cripple_attn_peak_shape).cuda()

# Sanity check: compare the reshaped activations, the raw tensors and the peaks
print("-----------------")
print(f"ff reshaped:  + {focus_ff_activations.shape}")
print(f"ff original:  + {focus_data.raw['mlp'].shape}")
print(f"ff peaks:  + {focus_ff_peaks.shape}")

print(f"attn reshaped:  + {focus_attn_activations.shape}")
print(f"attn original:  + {focus_data.raw['attn'].shape}")
print(f"attn peaks:  + {focus_attn_peaks.shape}")

# --- Pass 2: collect activations again, this time supplying the peak offsets ---
# This was breaking things before (kernel restart)
focus_data   = get_midlayer_data( opt, c.focus, COLLECTION_TOKENS, collect_ff=True, collect_attn=True, ff_peak=focus_ff_peaks, attn_peak=focus_attn_peaks )
cripple_data = get_midlayer_data( opt, c.cripple, COLLECTION_TOKENS, collect_ff=True, collect_attn=True, ff_peak=cripple_ff_peaks,  attn_peak=cripple_attn_peaks )

# --- Logging setup ---
history = RunDataHistory(c.datasets)
wandb.init(
    project=c.wandb_project,
    entity=c.wandb_entity,
    name=c.wandb_run_name,
    )
wandb.config.update(c.to_dict(), allow_val_change=True)


# --- Baseline evaluation followed by the iterative pruning loop ---
print("6")
print_gpu_memory_usage()
with torch.no_grad():
    # evaluate without pruning first
    data = RunDataItem()
    eval_out = evaluate_all(opt, c.eval_sample_size, c.datasets,
                            dataset_tokens_to_skip=c.collection_sample_size)
    data.update(eval_out)
    history.add(data)
    for i in range(c.n_steps):
        print (f"step {i}")
        data = prune_and_evaluate(opt, c, focus_data, cripple_data, i)
        history.add(data)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded model 'google/gemma-2b' with bfp16:
- Added 288 hooks across 18 layers
1
hi!


pile: 10199it [00:17, 581.17it/s]                            
code: 10742it [00:15, 698.71it/s]                            


2
hi!
Number of steps: 10
3
hi!
getting ff peaks
getting buckets for layer 0
getting buckets for layer 1
getting buckets for layer 2
getting buckets for layer 3
getting buckets for layer 4
getting buckets for layer 5
getting buckets for layer 6
getting buckets for layer 7
getting buckets for layer 8
getting buckets for layer 9
getting buckets for layer 10
getting buckets for layer 11
getting buckets for layer 12
getting buckets for layer 13
getting buckets for layer 14
getting buckets for layer 15
getting buckets for layer 16
getting buckets for layer 17
getting buckets for layer 0
getting buckets for layer 1
getting buckets for layer 2
getting buckets for layer 3
getting buckets for layer 4
getting buckets for layer 5
getting buckets for layer 6
getting buckets for layer 7
getting buckets for layer 8
getting buckets for layer 9
getting buckets for layer 10
getting buckets for layer 11
getting buckets for layer 12
getting buckets for layer 13
getting buckets for layer 14
getting bucket

pile: 10199it [00:22, 457.15it/s]                            
code: 10742it [00:25, 424.00it/s]                            
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mbpas992[0m ([33mseperability[0m). Use [1m`wandb login --relogin`[0m to force relogin


6
hi!


  0%|          | 0/100000.0 [00:00<?, ?it/s]



code     Acc: 93.36|77.83 (Skip: 90.40|72.59): : 100033it [00:43, 2304.44it/s]                           
pile     Acc: 85.16|57.19 (Skip: 81.72|52.72): : 100116it [00:38, 2599.02it/s]                           


step 0


  0%|          | 0/100000.0 [00:00<?, ?it/s]



code     Acc: 90.51|66.09 (Skip: 86.35|58.88): : 100033it [00:44, 2264.37it/s]                           
pile     Acc: 81.47|49.04 (Skip: 77.24|44.24): : 100116it [00:37, 2702.26it/s]                           


step 1


  0%|          | 0/100000.0 [00:00<?, ?it/s]



code     Acc: 75.37|39.34 (Skip: 66.68|30.08): : 100033it [00:43, 2281.74it/s]                           
pile     Acc: 68.19|32.17 (Skip: 62.52|28.17): : 100116it [00:36, 2729.82it/s]                           


step 2


  0%|          | 0/100000.0 [00:00<?, ?it/s]



code     Acc: 40.06|15.39 (Skip: 27.63|8.57): : 100033it [00:43, 2299.77it/s]                           
pile     Acc: 46.57|18.08 (Skip: 39.13|15.35): : 100116it [00:36, 2744.59it/s]                           


step 3


  0%|          | 0/100000.0 [00:00<?, ?it/s]



code     Acc: 17.57|5.48 (Skip: 9.04|2.40): : 100033it [00:43, 2277.54it/s]                           
pile     Acc: 23.34|6.83 (Skip: 16.28|5.27): : 100116it [00:36, 2713.14it/s]                           


step 4


  0%|          | 0/100000.0 [00:00<?, ?it/s]



code     Acc: 7.86|2.04 (Skip: 3.05|0.66): : 100033it [00:43, 2303.78it/s]                           
pile     Acc: 10.61|2.63 (Skip: 5.91|1.64): : 100116it [00:36, 2739.89it/s]                           


step 5


  0%|          | 0/100000.0 [00:00<?, ?it/s]



code     Acc: 5.13|1.20 (Skip: 1.30|0.22):  86%|████████▌ | 85736/100000.0 [00:37<00:05, 2720.20it/s]

KeyboardInterrupt: 

code     Acc: 5.13|1.20 (Skip: 1.30|0.22):  86%|████████▌ | 85736/100000.0 [00:50<00:05, 2720.20it/s]