In [3]:
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import torch
import wandb
import pynvml

import sys
sys.path.append('/root/taker/src')

from taker import Model
from taker.activations import get_midlayer_data
from taker.data_classes import PruningConfig, RunDataHistory, RunDataItem
from taker.model_repos import test_model_repos
from taker.prune import prune_and_evaluate, run_pruning
from taker.eval import evaluate_all



In [4]:
def print_gpu_memory_usage():
    """Print total / used / free memory and the usage percentage for every GPU.

    NVML is initialised on entry and always shut down again on exit, even if
    one of the device queries raises.
    """
    bytes_per_gb = 1024**3

    pynvml.nvmlInit()
    try:
        for gpu_index in range(pynvml.nvmlDeviceGetCount()):
            mem = pynvml.nvmlDeviceGetMemoryInfo(
                pynvml.nvmlDeviceGetHandleByIndex(gpu_index)
            )
            report_lines = [
                f"GPU {gpu_index}:",
                f"  Total memory: {mem.total / bytes_per_gb:.2f} GB",
                f"  Used memory: {mem.used / bytes_per_gb:.2f} GB",
                f"  Free memory: {mem.free / bytes_per_gb:.2f} GB",
                f"  Memory usage: {mem.used / mem.total * 100:.2f}%",
                "",  # blank separator line after each device
            ]
            print("\n".join(report_lines))
    finally:
        pynvml.nvmlShutdown()

In [5]:
def get_bucket_peaks(activations):
    """Locate the histogram peak of every `activations[i, j]` slice.

    One 100-bin histogram is built per slice, with all histograms sharing the
    global [min, max] range of the whole tensor. The centre of the fullest bin
    is reported as that slice's peak value.

    Args:
        activations: tensor that is iterated over its first two dimensions
            (treated as layer and neuron by the loops below); the remaining
            values of each slice are histogrammed together.

    Returns:
        Tensor of shape `activations.shape[:-1]` holding the peak-bin centres.
        The result is float32, except that float16 / bfloat16 inputs get their
        result cast back to the input dtype.
    """
    # Reduced-precision inputs are histogrammed in float32 and cast back at the
    # end; bfloat16 is treated the same way as float16.
    reduced_precision = activations.dtype in (torch.float16, torch.bfloat16)
    work = activations.float() if reduced_precision else activations

    # Histogram range and bin geometry are shared by every slice.
    min_val = work.min()
    max_val = work.max()
    bins = 100
    bin_width = (max_val - min_val) / bins
    # Plain Python numbers for torch.histc: converted once here instead of
    # handing a 0-dim tensor to every call inside the loop.
    hist_min = min_val.item()
    hist_max = max_val.item()

    # Index of the fullest bin for each slice; turned into values in one step below.
    peak_bins = torch.empty(work.size()[:-1], device=work.device, dtype=torch.long)

    for i in range(work.size()[0]):  # Assuming the first dimension is layers
        print(f"getting buckets for layer {i}")
        for j in range(work.size()[1]):  # Assuming the second dimension is neurons
            hist = torch.histc(work[i, j], bins=bins, min=hist_min, max=hist_max)
            peak_bins[i, j] = hist.argmax()

    # Centre of each peak bin, computed in the working dtype and stored as float32.
    peak_values = (min_val + bin_width * (peak_bins.to(work.dtype) + 0.5)).to(torch.float32)

    if reduced_precision:
        return peak_values.to(activations.dtype)
    return peak_values

In [7]:
# Experiment configuration for this run.
# Other model repos this cell has been pointed at:
#   "facebook/galactica-1.3b", "facebook/opt-1.3b", "nickypro/tinyllama-15m"
# Other wandb run names used (both without scoring_normalization):
#   "OPT-1.3b orig scoring. ff=0.1 attn=0.1"
#   "gemma-2b orig scoring. ff=0.1 attn=0.1"
c = PruningConfig(
    "google/gemma-2b",
    attn_mode="pre-out",
    do_attn_mean_offset=False,
    use_accelerator=False,
    ff_frac=0.1,
    attn_frac=0.1,
    token_limit=1000,
    focus="pile",
    cripple="code",
    wandb_entity="seperability",
    recalculate_activations=False,
    dtype="nf4",
    wandb_project="bens-tests",
    wandb_run_name="gemma-2b peak scoring. ff=0.1 attn=0.1",
    n_steps=10,
    scoring_normalization="peak_centered",
)

# Load the model from the config. The dtype is read from `c` rather than
# repeated as a literal, so the model and the config logged to wandb cannot
# disagree.
opt: Model = Model(c.model_repo, limit=c.token_limit, dtype=c.dtype,
            use_accelerator=c.use_accelerator)

# Sample-size argument handed to every get_midlayer_data call in this cell.
peak_sample_size = 1e4

print("1")
print_gpu_memory_usage()
# First pass: collect activations for both datasets; they are only used to
# locate the per-neuron histogram peaks. Dataset names come from the config.
focus_data = get_midlayer_data( opt, c.focus, peak_sample_size, collect_ff=True, collect_attn=True )
cripple_data = get_midlayer_data( opt, c.cripple, peak_sample_size, collect_ff=True, collect_attn=True )
print("2")
print_gpu_memory_usage()

print(f"Number of steps: {c.n_steps}")

# [token, layer, neuron] -> [layer, neuron, token]
focus_ff_activations   = focus_data.raw["mlp"].permute( (1,2,0) )
cripple_ff_activations = cripple_data.raw["mlp"].permute( (1,2,0) )
# [token, layer, attention head, attention neuron] -> [layer, attention head, attention neuron, token]
# then merge the two head axes -> [layer, attention head * attention neuron, token]
focus_attn_activations   = focus_data.raw["attn"].permute( (1,2,3,0) ).flatten(1, 2)
cripple_attn_activations = cripple_data.raw["attn"].permute( (1,2,3,0) ).flatten(1, 2)

# Per-head layout [layer, attention head, attention neuron] to restore on the
# peaks. Read from the collected tensors so no model-specific dimensions are
# hard-coded here.
focus_attn_peak_shape   = focus_data.raw["attn"].shape[1:]
cripple_attn_peak_shape = cripple_data.raw["attn"].shape[1:]

print("3")
print_gpu_memory_usage()
print("getting ff peaks")
focus_ff_peaks = get_bucket_peaks(focus_ff_activations).cuda()
cripple_ff_peaks = get_bucket_peaks(cripple_ff_activations).cuda()

print("4")
print_gpu_memory_usage()
print("getting attn peaks")
focus_attn_peaks = get_bucket_peaks(focus_attn_activations).reshape(focus_attn_peak_shape).cuda()
cripple_attn_peaks = get_bucket_peaks(cripple_attn_activations).reshape(cripple_attn_peak_shape).cuda()

# test reversing the shapes of the activatons
print("-----------------")
print(f"ff reshaped:  + {focus_ff_activations.shape}")
print(f"ff original:  + {focus_data.raw['mlp'].shape}")
print(f"ff peaks:  + {focus_ff_peaks.shape}")

#same thing with attn
print(f"attn reshaped:  + {focus_attn_activations.shape}")
print(f"attn original:  + {focus_data.raw['attn'].shape}")
print(f"attn peaks:  + {focus_attn_peaks.shape}")

# Now get activation data again with peaks offsets
#This was breaking things before (kernel restart)
# NOTE(review): the first-pass activations are still referenced while this
# second collection runs; if the restart was an out-of-memory kill, releasing
# them first may help — confirm before changing.
print("5")
print_gpu_memory_usage()
focus_data   = get_midlayer_data( opt, c.focus, peak_sample_size, collect_ff=True, collect_attn=True, ff_peak=focus_ff_peaks, attn_peak=focus_attn_peaks )
cripple_data = get_midlayer_data( opt, c.cripple, peak_sample_size, collect_ff=True, collect_attn=True, ff_peak=cripple_ff_peaks,  attn_peak=cripple_attn_peaks )

# Run history keyed on the config's datasets; every evaluation below is added to it.
history = RunDataHistory(c.datasets)
# Start the wandb run using the project / entity / run name from the config,
# then record the full config on the run.
wandb.init(
    project=c.wandb_project,
    entity=c.wandb_entity,
    name=c.wandb_run_name,
    )
wandb.config.update(c.to_dict(), allow_val_change=True)


print("6")
print_gpu_memory_usage()
# Evaluation and pruning run with autograd disabled.
with torch.no_grad(): 
    #evaluate without pruning first
    data = RunDataItem()
    # NOTE(review): dataset_tokens_to_skip presumably keeps evaluation off the
    # tokens used for activation collection — confirm against evaluate_all.
    eval_out = evaluate_all(opt, c.eval_sample_size, c.datasets,
                            dataset_tokens_to_skip=c.collection_sample_size)
    data.update(eval_out)
    history.add(data)
    # One prune-and-evaluate round per configured step, using the
    # peak-offset activation data collected above.
    for i in range(c.n_steps):
        print (f"step {i}")
        data = prune_and_evaluate(opt, c, focus_data, cripple_data, i)
        history.add(data)

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

KeyboardInterrupt: 