In [1]:
try: # if in google colab, download necessary python files
  import google.colab
  ! git clone https://github.com/pesvut/opt-tools.git && mv ./opt-tools/src/*.py .
except ModuleNotFoundError:
  pass
! pip install -qq transformers datasets evaluate zstandard

In [2]:
from model import Model
from texts import prepare_pile, prepare_code
import torch
import numpy as np

In [3]:
def prepare( dataset_name ):
    if dataset_name == 'pile':
        return prepare_pile()
    if dataset_name == 'code':
        return prepare_code()

opt = Model("13b")

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/441 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/721 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/719 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/51.1k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/9.29G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/9.18G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/5.47G [00:00<?, ?B/s]

- Loaded OPT-None
 - Registered 40 OPT Attention Layers


In [4]:
print( opt.model.decoder.layers[0] )

OPTDecoderLayer(
  (self_attn): OPTAttention(
    (k_proj): Linear(in_features=5120, out_features=5120, bias=True)
    (v_proj): Linear(in_features=5120, out_features=5120, bias=True)
    (q_proj): Linear(in_features=5120, out_features=5120, bias=True)
    (out_proj): Linear(in_features=5120, out_features=5120, bias=True)
  )
  (activation_fn): ReLU()
  (self_attn_layer_norm): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
  (fc1): Linear(in_features=5120, out_features=20480, bias=True)
  (fc2): Linear(in_features=20480, out_features=5120, bias=True)
  (final_layer_norm): LayerNorm((5120,), eps=1e-05, elementwise_affine=True)
)


In [5]:
from random import sample
from tqdm.notebook import tqdm
import os
import datetime

def setup_counter(ff_keys):
    shape = ff_keys.size()
    counter = []
    for i in range(shape[0]):
        counter.append( torch.zeros( shape[-1]))
    return torch.stack(counter).to( opt.device )

def count_keys( dataset_name, limit=1000, sample_size=10000, num_samples=1, 
                check_accuracy=False, k=10, check_skips=False ):
    dataset, label, skip_eval = prepare( dataset_name )
    counters = []
    counter = None
    curr_count = 0
    with tqdm(total=sample_size*num_samples) as pbar:
        for data in dataset:
            text = data[label]
            input_ids = opt.get_ids( text, limit=limit )
            ids = input_ids.squeeze().detach().cpu()

            # Criteria for counting the token activation
            criteria = torch.ones_like( ids, dtype=torch.bool )

            # check if prediction is accurate enough to count
            if check_accuracy:
                residual_stream = opt.get_residual_stream( input_ids=input_ids )
                logits = opt.unembed( residual_stream[-1] ).detach().cpu()
                top_k_tokens = opt.top_k_tokens( logits, k=k ).squeeze()

                for index in range(len(ids)-1):
                    criteria[index] *= (ids[index+1] in top_k_tokens[index])

            # Choose a set of token ids to skip 
            if check_skips:
                skip_ids = set()
                for skip_string in skip_eval:
                    skip_id = int( opt.get_ids( skip_string ).squeeze()[-1] )
                    skip_ids.add( skip_id )

                for index in range(len(ids)-1):
                    criteria[index] *= (ids[index+1] in skip_ids)
                
            num_valid_tokens = criteria.sum()
            curr_count += num_valid_tokens

            ff_keys = opt.get_ff_key_activations(input_ids=input_ids)
            if counter is None:
                counter = setup_counter(ff_keys)
            
            for layer_index, layer in enumerate(ff_keys):
                for token_index, key_activation in enumerate(layer):
                    if not criteria[token_index]:
                        continue
                    counter[layer_index] += ( key_activation != 0 )


            pbar.update( int(num_valid_tokens) )
            if curr_count > sample_size:
                counter = counter / curr_count
                counters.append( counter.detach().cpu() )
                print( f'sample {len(counters)}: {curr_count}' )
                
                counter = setup_counter(ff_keys)
                curr_count = 0
            
            if len( counters ) >= num_samples:
                break
    
    return torch.stack( counters )

def acc_str( acc, pred ):
    percentage = (100*round(acc/pred, 3))
    return "%.1f"%percentage + "% - ( {acc}/{pred} )"

def evaluate( dataset_name, limit : int = 1e6 ):
    dataset, label, skip_eval = prepare( dataset_name )
    out = opt.evaluate_dataset( dataset, token_limit=1000, k=1,
        start_index=1, stopping_index=limit, skip_eval=skip_eval,
        dataset_text_label=label, count_tokens=False )
    print( f'{dataset_name} w/ skip:', 
        acc_str(out['num_skip_accurate'], out['num_skip_predictions']) )
    print( f'{dataset_name} no skip:',
        acc_str( out['num_accurate'], out['num_predictions']) )
    return out

def evaluate_all( limit: int = 1e5 ):
    pile_out = evaluate( 'pile', limit )
    code_out = evaluate( 'code', limit )

def delete_and_evaluate(
        freq_multiple: float,
        counter_sample_size: int = 5e4,
        eval_sample_size: int = 1e5,
        ):
    # Count activation of MLP middle layers
    pile_counters = count_keys( 'pile', sample_size=counter_sample_size, num_samples=1, check_accuracy=True )
    code_counters = count_keys( 'code', sample_size=counter_sample_size, num_samples=1, check_accuracy=True )
    
    # Delete when the MLP layer activates way more for code than pile
    ff_criterion = ( code_counters[0] > (freq_multiple*pile_counters[0]) )
    sums = [ x.sum() for x in ff_criterion.detach().numpy() ]
    print( "%5d -"%np.sum(sums), sums )
    opt.delete_ff_keys( ff_criterion )
    
    try:
        # Save the indices that were deleted into the timestamped file
        print("saving file...")
        now = datetime.datetime.now().strftime( "%Y-%m-%d_%H:%M:%S" )
        filename = f'tmp/{opt.model_size}-{freq_multiple}x-{now}.npy'
        os.makedirs( 'tmp', exist_ok=True )
        with open(filename, 'wb') as f:
            np.save(f, np.array(ff_criterion) )
        print("saved successfully")
    except Exception:
        print("Did not save sadly :(")

    
    # See the effect this has on performance
    evaluate_all( eval_sample_size )

In [6]:
pre_removals = [
    "/notebooks/src/tmp/13b-3.9x-2022-11-17_14:15:52.npy",
    "/notebooks/src/tmp/13b-3.9x-2022-11-17_16:13:26.npy",
    "/notebooks/src/tmp/13b-3.9x-2022-11-17_17:05:07.npy",
    "/notebooks/src/tmp/13b-3.9x-2022-11-17_15:47:29.npy",
    "/notebooks/src/tmp/13b-3.9x-2022-11-17_16:38:41.npy",
    "/notebooks/src/tmp/13b-3.9x-2022-11-17_17:05:07.npy",
    "/notebooks/src/tmp/13b-3.9x-2022-11-17_18:02:52.npy",
    "/notebooks/src/tmp/13b-3.9x-2022-11-17_18:31:52.npy",
    "/notebooks/src/tmp/13b-3.9x-2022-11-17_19:01:42.npy",
]

for filename in pre_removals:
    ff_criterion = np.load(filename)
    sums = [ x.sum() for x in ff_criterion ]
    print( "%5d -"%np.sum(sums), sums )
    opt.delete_ff_keys( ff_criterion )

 8985 - [720, 437, 1103, 621, 621, 662, 670, 466, 310, 382, 401, 328, 298, 209, 172, 130, 27, 24, 22, 34, 77, 83, 58, 35, 43, 27, 41, 52, 39, 44, 59, 49, 102, 127, 90, 79, 121, 112, 110]
 6518 - [42, 65, 176, 634, 551, 1162, 802, 644, 277, 336, 336, 254, 447, 119, 94, 98, 23, 34, 19, 19, 44, 28, 43, 25, 12, 11, 7, 12, 9, 11, 16, 19, 26, 21, 15, 28, 22, 21, 16]
 5523 - [103, 70, 100, 211, 439, 1068, 447, 294, 383, 174, 255, 245, 277, 286, 128, 78, 32, 138, 90, 88, 124, 86, 85, 59, 25, 17, 8, 9, 14, 17, 9, 11, 6, 11, 17, 20, 22, 33, 44]
 8985 - [720, 437, 1103, 621, 621, 662, 670, 466, 310, 382, 401, 328, 298, 209, 172, 130, 27, 24, 22, 34, 77, 83, 58, 35, 43, 27, 41, 52, 39, 44, 59, 49, 102, 127, 90, 79, 121, 112, 110]
 4745 - [21, 22, 40, 208, 1081, 540, 399, 574, 155, 80, 179, 164, 353, 277, 104, 104, 17, 81, 93, 13, 37, 27, 23, 14, 9, 10, 8, 4, 3, 5, 5, 6, 6, 9, 11, 16, 14, 14, 19]
 5523 - [103, 70, 100, 211, 439, 1068, 447, 294, 383, 174, 255, 245, 277, 286, 128, 78, 32, 138, 90, 88

In [7]:
# Evaluate model before removal of any neurons
evaluate_all( 1e5 )

Downloading builder script:   0%|          | 0.00/2.65k [00:00<?, ?B/s]

No config specified, defaulting to: the_pile/all
accuracy 45.5% (curr 37.0%): : 100396it [04:45, 351.32it/s]                           


pile w/ skip: 45.5% - ( {acc}/{pred} )
pile no skip: 51.0% - ( {acc}/{pred} )


Using custom data configuration codeparrot--codeparrot-clean-valid-d84b00ddcc43747c


Downloading and preparing dataset json/codeparrot--codeparrot-clean-valid to /root/.cache/huggingface/datasets/codeparrot___json/codeparrot--codeparrot-clean-valid-d84b00ddcc43747c/0.0.0/da492aad5680612e4028e7f6ddc04b1dfcec4b64db470ed7cc5f2bb265b9b6b5...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/142M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

0 tables [00:00, ? tables/s]

Dataset json downloaded and prepared to /root/.cache/huggingface/datasets/codeparrot___json/codeparrot--codeparrot-clean-valid-d84b00ddcc43747c/0.0.0/da492aad5680612e4028e7f6ddc04b1dfcec4b64db470ed7cc5f2bb265b9b6b5. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

accuracy 42.0% (curr 52.7%): : 100420it [06:49, 245.41it/s]                           

code w/ skip: 42.0% - ( {acc}/{pred} )
code no skip: 32.1% - ( {acc}/{pred} )





In [None]:
FREQ_MULTIPLE = 3.9

for i in range(4):
    print('\n\n- RUNNING RUN No', i )
    delete_and_evaluate( FREQ_MULTIPLE )



- RUNNING RUN No 0


No config specified, defaulting to: the_pile/all


  0%|          | 0/50000.0 [00:00<?, ?it/s]

sample 1: 50184


Using custom data configuration codeparrot--codeparrot-clean-valid-d84b00ddcc43747c
Reusing dataset json (/root/.cache/huggingface/datasets/codeparrot___json/codeparrot--codeparrot-clean-valid-d84b00ddcc43747c/0.0.0/da492aad5680612e4028e7f6ddc04b1dfcec4b64db470ed7cc5f2bb265b9b6b5)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/50000.0 [00:00<?, ?it/s]

sample 1: 50215
 1018 - [2, 6, 18, 30, 97, 256, 139, 49, 20, 34, 19, 34, 25, 43, 0, 2, 2, 11, 29, 21, 15, 38, 37, 14, 49, 5, 4, 3, 6, 3, 2, 0, 0, 1, 1, 0, 1, 0, 2]
saving file...
saved successfully


No config specified, defaulting to: the_pile/all
accuracy 45.5% (curr 37.0%): : 100396it [04:44, 352.38it/s]                           


pile w/ skip: 45.5% - ( {acc}/{pred} )
pile no skip: 51.0% - ( {acc}/{pred} )


Using custom data configuration codeparrot--codeparrot-clean-valid-d84b00ddcc43747c
Reusing dataset json (/root/.cache/huggingface/datasets/codeparrot___json/codeparrot--codeparrot-clean-valid-d84b00ddcc43747c/0.0.0/da492aad5680612e4028e7f6ddc04b1dfcec4b64db470ed7cc5f2bb265b9b6b5)


  0%|          | 0/1 [00:00<?, ?it/s]

accuracy 42.0% (curr 52.4%): : 100420it [06:51, 244.17it/s]                           


code w/ skip: 42.0% - ( {acc}/{pred} )
code no skip: 32.1% - ( {acc}/{pred} )


- RUNNING RUN No 1


No config specified, defaulting to: the_pile/all


  0%|          | 0/50000.0 [00:00<?, ?it/s]

sample 1: 50173


Using custom data configuration codeparrot--codeparrot-clean-valid-d84b00ddcc43747c
Reusing dataset json (/root/.cache/huggingface/datasets/codeparrot___json/codeparrot--codeparrot-clean-valid-d84b00ddcc43747c/0.0.0/da492aad5680612e4028e7f6ddc04b1dfcec4b64db470ed7cc5f2bb265b9b6b5)


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/50000.0 [00:00<?, ?it/s]

In [None]:
for i in range(4):
    print('\n\n- RUNNING RUN No', i )
    delete_and_evaluate( FREQ_MULTIPLE )

In [None]:
#import matplotlib.pyplot as plt
#for layer in range(len(code_counters[0])):
#    plt.figure()
#    subsample = 1000
#    plt.plot(code_counters[0][layer][:subsample], color='red', linewidth=0.2)
#    plt.plot(pile_counters[0][layer][:subsample], color='blue', linewidth=0.2)

for 1e4 steps
```
pile w/ skip: 39.0% - ( 4022/10381 )
pile no skip: 49.0% - ( 8109/16445 )

code w/ skip: 43.0% - ( 4381/10209 )
code no skip: 60.0% - ( 12974/21528)
```

After 1 iteration:
```
pile w/ skip: 36.0% - ( 3702/10381 )
pile no skip: 46.0% - ( 7573/16445 )

code w/ skip: 35.0% - ( 3620/10209 )
code no skip: 50.0% - ( 10818/21528 )
```

After 3 iterations:
```
pile w/ skip: 31.0% - ( 3249/10381 )
pile no skip: 43.0% - ( 7007/16445 )

code w/ skip: 21.0% - ( 2108/10209 )
code no skip: 37.0% - ( 7970/21528 )
```