In [5]:
import torch
from transformers import AutoModel, AutoTokenizer
from utils import *
from prune import *
from collections import Counter
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster

In [2]:
# Single source of truth for the checkpoint so the model and tokenizer
# can never be loaded from different names by accident.
MODEL_NAME = "bigscience/bloom-560m"

# output_attentions=True asks the model to return its attention weights.
model = AutoModel.from_pretrained(MODEL_NAME, output_attentions=True)
n_layer, n_head = get_model_layers_and_heads(model.config)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model.eval()  # inference mode

prompts = get_prompts_from_file("paws_en")

# Despite the name, this is a fraction in [0, 1), not a percentage.
prune_percent = 0.25
# Heads that remain after pruning; used below as the number of clusters.
n_groups = (n_head * n_layer) - int(n_layer * n_head * prune_percent)
assert 0 < n_groups <= n_layer * n_head, (
    f"n_groups={n_groups} is out of range for {n_layer * n_head} heads; "
    "check prune_percent"
)


In [39]:
# Build the head clustering from the prompts. The three results are kept as
# notebook-level names because later cells read them.
# NOTE(review): the metric is hard-coded to "cosine" here, while a later cell
# defines `group_metric` separately -- confirm the two are meant to agree.
(
    clustering_dict,
    attentions,
    attention_vectors,
) = get_clustering_dict(
    prompts,
    model,
    tokenizer,
    n_groups,
    "cosine",
    n_layer,
    n_head,
    by_layer=False,
)

In [40]:
# Metric used to compare heads inside a group. The special value 'random'
# needs no distances, so `squaref` is deliberately left unassigned then.
group_metric = 'cosine'
if group_metric != 'random':
    # pdist yields the condensed (upper-triangle) pairwise distances ...
    condensed_distances = pdist(attention_vectors, metric=group_metric)
    # ... and squareform expands them into a full symmetric square matrix.
    squaref = squareform(condensed_distances)

In [49]:
def _most_central_head(group, distances, heads_per_layer):
    """Pick the head of ``group`` with the smallest total distance to its peers.

    Args:
        group: sequence of hashable ``(layer, head)`` pairs from one cluster.
        distances: square matrix indexed as ``distances[i, j]``, where the
            index of a head is ``layer * heads_per_layer + head``.
        heads_per_layer: number of heads per layer, used to flatten a pair
            into a matrix index.

    Returns:
        The member of ``group`` with the lowest summed distance to the other
        members. Ties go to the member that appears first in ``group``.
    """
    # Flatten each (layer, head) pair once instead of on every inner iteration.
    flat_index = {head: head[0] * heads_per_layer + head[1] for head in group}

    scores = {}
    for head in group:
        row = flat_index[head]
        total = 0
        for other in group:
            if other == head:
                continue
            total += distances[row, flat_index[other]]
        scores[head] = total
    # min() returns the first minimal key, i.e. the earliest head in the group.
    return min(scores, key=scores.get)


counter = Counter()  # group size -> number of groups of that size
pruning_log = []     # (kept_head, removed_head) pairs, in application order
verbose = True

for group in clustering_dict.values():
    counter[len(group)] += 1
    if len(group) <= 1:
        # A singleton (or empty) group has nothing to merge.
        continue

    if len(group) == 2:
        # with 2 heads just keep the first 1
        head_to_keep = group[0]
    elif group_metric == 'random':
        # NOTE(review): `random` is not imported explicitly in the import cell;
        # presumably it arrives via a star import -- confirm. No seed is set in
        # the cells shown, so this mode is not reproducible across runs.
        head_to_keep = random.choice(group)
    else:
        # `squaref` is only read here, so it may stay undefined in 'random' mode.
        head_to_keep = _most_central_head(group, squaref, n_head)

    # Overwrite every other head in the group with the kept one.
    for head_to_remove in group:
        if head_to_remove == head_to_keep:
            continue
        pruning_log.append((head_to_keep, head_to_remove))
        model = duplicate_prune(
            model,
            source_layer=head_to_keep[0],
            source_head=head_to_keep[1],
            target_layer=head_to_remove[0],
            target_head=head_to_remove[1],
        )

if verbose:
    print("size of groups: ", counter)

size of groups:  Counter({1: 257, 2: 12, 3: 7, 4: 6, 6: 3, 12: 1, 23: 1, 5: 1})


In [51]:
# Sanity check: each log entry should be a (kept_head, removed_head) pair.
first_entry = pruning_log[0]
len(first_entry)

2