In [1]:
try: # if in google colab, download necessary python files
  # The import only succeeds when the google.colab package is available;
  # otherwise ModuleNotFoundError is raised and the clone step is skipped.
  import google.colab
  # Clone the helper repository and move its src/*.py files into the current
  # working directory (presumably so `model` and `texts` can be imported by
  # the next cell -- confirm against the repository layout).
  ! git clone https://github.com/pesvut/opt-tools.git && mv ./opt-tools/src/*.py .
except ModuleNotFoundError:
  pass
# Install third-party packages; this runs whether or not the Colab branch ran.
! pip install -qq transformers datasets evaluate zstandard

In [None]:
from model import Model
from texts import prepare_pile, prepare_code
import torch
import numpy as np

In [None]:
def prepare( dataset_name ):
    """Return the prepared dataset bundle for ``dataset_name``.

    Args:
        dataset_name: Either 'pile' or 'code'.

    Returns:
        Whatever ``prepare_pile()`` / ``prepare_code()`` returns. Callers in
        this notebook unpack it as ``(dataset, label, skip_eval)``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
            (Previously the function silently returned None, which made
            callers fail later with an unrelated unpacking error.)
    """
    if dataset_name == 'pile':
        return prepare_pile()
    if dataset_name == 'code':
        return prepare_code()
    raise ValueError(
        f"Unknown dataset name {dataset_name!r}; expected 'pile' or 'code'" )

# Module-level model wrapper; the helper functions defined in later cells read
# this global directly (opt.device, opt.get_ids, opt.evaluate_dataset, ...).
# NOTE(review): "125m" presumably selects the model size -- confirm in model.Model.
opt = Model("125m")

In [None]:
# Inspect the structure of the first decoder layer of the wrapped model.
print( opt.model.decoder.layers[0] )

In [5]:
from random import sample
from tqdm.notebook import tqdm
import os
import datetime

def setup_counter(ff_keys):
    """Create a zero-initialised activation counter matching ``ff_keys``.

    Args:
        ff_keys: Tensor of feed-forward key activations. Only its size is
            read: the first dimension is iterated as layers by count_keys and
            the last dimension is the number of keys per layer.

    Returns:
        A tensor of zeros with shape ``(ff_keys.size(0), ff_keys.size(-1))``
        in the default dtype, allocated on ``opt.device``.
    """
    shape = ff_keys.size()
    # Allocate the whole counter in one call, directly on the target device,
    # instead of stacking per-layer CPU tensors and copying them over.
    return torch.zeros( (shape[0], shape[-1]), device=opt.device )

def count_keys( dataset_name, limit=1000, sample_size=10000, num_samples=1,
                check_accuracy=False, k=10, check_skips=False ):
    """Measure how often each feed-forward key is non-zero on a dataset.

    Texts are streamed from the dataset; for every token position that passes
    the selected criteria, each key with a non-zero activation has its counter
    incremented. Once more than ``sample_size`` positions have been counted,
    the counters are divided by the number of counted positions and stored as
    one sample.

    Args:
        dataset_name: Name passed to ``prepare`` ('pile' or 'code').
        limit: Forwarded to ``opt.get_ids`` as ``limit``.
        sample_size: Number of counted token positions that must be exceeded
            before a sample is finalised.
        num_samples: Number of samples to collect before stopping.
        check_accuracy: If True, only count positions where the next token is
            among the model's top-``k`` predictions.
        k: Top-k width used when ``check_accuracy`` is True.
        check_skips: If True, do not count positions whose next token is one
            of the dataset's ``skip_eval`` tokens.

    Returns:
        CPU tensor stacking the collected samples; each sample has the shape
        produced by ``setup_counter`` and holds activation frequencies.

    Raises:
        RuntimeError: From ``torch.stack`` if the dataset is exhausted before
            a single sample is completed.
    """
    dataset, label, skip_eval = prepare( dataset_name )

    # The ids to skip do not depend on the current text, so build the set
    # once up front instead of re-tokenising the skip strings per document.
    skip_ids = set()
    if check_skips:
        for skip_string in skip_eval:
            skip_ids.add( int( opt.get_ids( skip_string ).squeeze()[-1] ) )

    counters = []
    counter = None
    curr_count = 0
    with tqdm(total=sample_size*num_samples) as pbar:
        for data in dataset:
            text = data[label]
            input_ids = opt.get_ids( text, limit=limit )
            ids = input_ids.squeeze().detach().cpu()

            # Criteria for counting the token activation
            # NOTE(review): the final position has no next token, so neither
            # check below ever filters it and it is always counted -- confirm
            # that this is intended.
            criteria = torch.ones_like( ids, dtype=torch.bool )

            # check if prediction is accurate enough to count
            if check_accuracy:
                residual_stream = opt.get_residual_stream( input_ids=input_ids )
                logits = opt.unembed( residual_stream[-1] ).detach().cpu()
                top_k_tokens = opt.top_k_tokens( logits, k=k ).squeeze()

                for index in range(len(ids)-1):
                    criteria[index] *= (ids[index+1] in top_k_tokens[index])

            # Drop positions whose next token is in the skip set.
            if check_skips:
                for index in range(len(ids)-1):
                    # Convert to int first: a tensor hashes by identity, so a
                    # raw tensor element would never be found in a set of ints.
                    criteria[index] *= (int( ids[index+1] ) not in skip_ids)

            # Keep the running total as a plain int so the progress print and
            # the normalisation below do not carry a tensor around.
            num_valid_tokens = int( criteria.sum() )
            curr_count += num_valid_tokens

            ff_keys = opt.get_ff_key_activations(input_ids=input_ids)
            if counter is None:
                counter = setup_counter(ff_keys)

            # Each layer is indexed by token position along its first axis;
            # select the accepted positions with a boolean mask and add the
            # per-key count of non-zero activations in one step.
            for layer_index, layer in enumerate(ff_keys):
                mask = criteria[:len(layer)].to( layer.device )
                counter[layer_index] += ( layer[mask] != 0 ).sum( dim=0 )

            pbar.update( num_valid_tokens )
            if curr_count > sample_size:
                # Convert raw counts into activation frequencies.
                counter = counter / curr_count
                counters.append( counter.detach().cpu() )
                print( f'sample {len(counters)}: {curr_count}' )

                counter = setup_counter(ff_keys)
                curr_count = 0

            if len( counters ) >= num_samples:
                break

    return torch.stack( counters )

def acc_str( acc, pred ):
    """Format an accuracy as ``"<percent>% - ( <acc>/<pred> )"``.

    Args:
        acc: Number of accurate predictions.
        pred: Total number of predictions.

    Returns:
        A string with the percentage (ratio rounded to three decimals, shown
        with one decimal place) followed by the raw counts. When ``pred`` is
        zero the percentage is reported as 0.0 instead of raising
        ZeroDivisionError.
    """
    percentage = (100*round(acc/pred, 3)) if pred != 0 else 0.0
    # The counts part must be an f-string; as a plain literal it printed the
    # placeholders "{acc}/{pred}" verbatim.
    return "%.1f"%percentage + f"% - ( {acc}/{pred} )"

def evaluate( dataset_name, limit : int = 1e6 ):
    """Evaluate the global model on one dataset and print its accuracy.

    Args:
        dataset_name: Name passed to ``prepare``.
        limit: Passed to ``opt.evaluate_dataset`` as ``stopping_index``.

    Returns:
        The result mapping returned by ``opt.evaluate_dataset``.
    """
    dataset, label, skip_eval = prepare( dataset_name )

    # Collect the fixed evaluation settings in one place.
    eval_settings = dict(
        token_limit=1000,
        k=1,
        start_index=1,
        stopping_index=limit,
        skip_eval=skip_eval,
        dataset_text_label=label,
        count_tokens=False,
    )
    results = opt.evaluate_dataset( dataset, **eval_settings )

    # Accuracy over the "skip" prediction counts.
    skip_summary = acc_str( results['num_skip_accurate'],
                            results['num_skip_predictions'] )
    print( f'{dataset_name} w/ skip:', skip_summary )

    # Accuracy over all prediction counts.
    full_summary = acc_str( results['num_accurate'],
                            results['num_predictions'] )
    print( f'{dataset_name} no skip:', full_summary )

    return results

def evaluate_all( limit: int = 1e5 ):
    """Run ``evaluate`` on 'pile' and then on 'code' with the same limit.

    The per-dataset results are printed by ``evaluate``; nothing is returned.
    """
    for dataset_name in ( 'pile', 'code' ):
        evaluate( dataset_name, limit )

def delete_and_evaluate(
        freq_multiple: float,
        counter_sample_size: int = 5e4,
        eval_sample_size: int = 1e5,
        ):
    """Delete code-specific feed-forward keys, then re-evaluate the model.

    Activation frequencies are measured on 'pile' and 'code'. Every key whose
    code frequency exceeds ``freq_multiple`` times its pile frequency is
    passed to ``opt.delete_ff_keys``. Both datasets are then evaluated and
    the deletion mask is saved to a timestamped file under ``tmp/``.

    Args:
        freq_multiple: Ratio threshold between code and pile frequencies.
        counter_sample_size: ``sample_size`` forwarded to ``count_keys``.
        eval_sample_size: Limit forwarded to ``evaluate_all``.

    Returns:
        None. Saving the mask is best-effort: a failure is reported and does
        not abort the run.
    """
    # Count activation of MLP middle layers
    pile_counters = count_keys( 'pile', sample_size=counter_sample_size, num_samples=1, check_accuracy=True )
    code_counters = count_keys( 'code', sample_size=counter_sample_size, num_samples=1, check_accuracy=True )

    # Delete when the MLP layer activates way more for code than pile
    ff_criterion = ( code_counters[0] > (freq_multiple*pile_counters[0]) )
    criterion_array = ff_criterion.detach().numpy()
    sums = [ x.sum() for x in criterion_array ]
    print( "%5d -"%np.sum(sums), sums )
    opt.delete_ff_keys( ff_criterion )

    # See the effect this has on performance
    evaluate_all( eval_sample_size )

    try:
        # Save the indices that were deleted into the timestamped file
        now = datetime.datetime.now().strftime( "%Y-%m-%d_%H:%M:%S" )
        filename = f'tmp/{opt.model_size}-{freq_multiple}x-{now}.npy'
        os.makedirs( 'tmp', exist_ok=True )
        with open(filename, 'wb') as f:
            np.save(f, criterion_array)
    except Exception as err:
        # Still best-effort, but do not swallow KeyboardInterrupt/SystemExit
        # and do show why the save failed.
        print("Did not save sadly :(", repr(err))


In [6]:

# Scratch check: write a small dummy mask to tmp/ using the same filename
# scheme as delete_and_evaluate (with a fixed "1000x" multiple).
# dtype=int stores the booleans as 0/1 integers.
criteria = np.array( [[True, False, True],[True,False,False]], dtype=int )
now = datetime.datetime.now().strftime( "%Y-%m-%d_%H:%M:%S" )
filename = f'tmp/{opt.model_size}-{1000}x-{now}.npy'
# Create the output directory if it does not exist yet.
os.makedirs( 'tmp', exist_ok=True )
with open(filename, 'wb') as f:
    np.save(f, criteria)

In [None]:
# Evaluate model before removal of any neurons
# (runs evaluate() on both 'pile' and 'code' with 1e5 as the limit).
evaluate_all( 1e5 )

In [None]:
# Run the count -> delete -> evaluate cycle ten times with freq_multiple=3.
# NOTE(review): presumably each call modifies the shared `opt` model, so the
# deletions accumulate across runs -- confirm in opt.delete_ff_keys.
for i in range(10):
    print('\n\n- RUNNING RUN No', i )
    delete_and_evaluate(3)

In [None]:
#import matplotlib.pyplot as plt
#for layer in range(len(code_counters[0])):
#    plt.figure()
#    subsample = 1000
#    plt.plot(code_counters[0][layer][:subsample], color='red', linewidth=0.2)
#    plt.plot(pile_counters[0][layer][:subsample], color='blue', linewidth=0.2)

Baseline results for 1e4 steps (before any deletion iterations):
```
pile w/ skip: 39.0% - ( 4022/10381 )
pile no skip: 49.0% - ( 8109/16445 )

code w/ skip: 43.0% - ( 4381/10209 )
code no skip: 60.0% - ( 12974/21528)
```

After 1 iteration:
```
pile w/ skip: 36.0% - ( 3702/10381 )
pile no skip: 46.0% - ( 7573/16445 )

code w/ skip: 35.0% - ( 3620/10209 )
code no skip: 50.0% - ( 10818/21528 )
```

After 3 iterations:
```
pile w/ skip: 31.0% - ( 3249/10381 )
pile no skip: 43.0% - ( 7007/16445 )

code w/ skip: 21.0% - ( 2108/10209 )
code no skip: 37.0% - ( 7970/21528 )
```