In [13]:
import torch
from transformers import AutoModel, AutoTokenizer
from utils import *
from prune import *
from collections import Counter, defaultdict
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
import random

In [2]:
model = AutoModel.from_pretrained("bigscience/bloom-560m", output_attentions=True)
n_layer, n_head = get_model_layers_and_heads(model.config)
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model.eval()

prompts = get_prompts_from_file("paws_en")

prune_percent = 0.25
n_groups = (n_head * n_layer) - int(n_layer * n_head * prune_percent)


  return torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count


In [3]:
n_head*n_layer - n_groups

96

In [4]:
print(n_head*n_layer)

384


In [4]:
# clustering_dict, attentions, attention_vectors = get_clustering_dict(prompts, model, tokenizer, n_groups, "cosine",  n_layer, n_head, by_layer=False)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [3]:
path = "/home/data_shares/mapillary/thesis_models/pruned_models/"
duplicate_prune_model(prompts,path,model, model_name="bigscience/bloom-560m", tokenizer=tokenizer, prune_method="imbalanced", prune_task="paws_en", prune_percent=0.25, metric="euclidean",group_metric="euclidean", verbose=True)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Clustering
Clustering Done
Counter({1: 244, 2: 23, 3: 8, 4: 7, 6: 2, 9: 2, 7: 1, 5: 1})


'/home/data_shares/mapillary/thesis_models/pruned_models/bloom-560m/imbalanced/paws_en/euclidean_euclidean/0.25'

In [40]:
group_metric = 'cosine'
if group_metric != 'random':
    squaref = squareform(pdist(attention_vectors, metric=group_metric))

In [49]:
counter = Counter()
pruning_log = []
verbose = True
for group in clustering_dict.values():
    counter.update([len(group)])
    group_scores = defaultdict(int)
    if len(group) <= 1:
        continue
    if len(group) == 2:
        # with 2 heads just keep the first 1
        head_to_keep = group[0]
    else:
        if group_metric == 'random':
            head_to_keep = random.choice(group)
        else:
            for head_id in group:
                for head_id_2 in group:
                    if head_id == head_id_2:
                        continue
                    head1 = head_id[0]*n_head + head_id[1]
                    head2 = head_id_2[0]*n_head + head_id_2[1]
                    group_scores[head_id] += squaref[head1, head2]
            head_to_keep = min(group_scores, key=lambda k: group_scores[k])
            
    for head in group:
        if head == head_to_keep:
            continue
        head_to_remove = head
        pruning_log.append((head_to_keep, head_to_remove))
        model = duplicate_prune(model, source_layer=head_to_keep[0], source_head=head_to_keep[1], target_layer=head_to_remove[0], target_head=head_to_remove[1])
if verbose:
    print("size of groups: ", counter)

size of groups:  Counter({1: 257, 2: 12, 3: 7, 4: 6, 6: 3, 12: 1, 23: 1, 5: 1})


In [51]:
len(pruning_log[0])

2

In [7]:
heads = []
for head in range(n_head):
    for layer in range(n_layer):
        heads.append((layer, head))

        

In [9]:
k=0.5*n_head*n_layer
to_prune = random.choices(heads, k=int(k))
pruning_dict = defaultdict(list)
for head in to_prune:
    pruning_dict[head[0]].append(head[1])



In [12]:
to_prune

[(15, 5),
 (18, 15),
 (2, 9),
 (11, 14),
 (4, 12),
 (0, 5),
 (23, 5),
 (12, 3),
 (13, 12),
 (8, 6),
 (2, 6),
 (17, 3),
 (23, 1),
 (12, 12),
 (9, 14),
 (23, 15),
 (18, 6),
 (2, 3),
 (11, 3),
 (0, 1),
 (3, 3),
 (2, 1),
 (17, 3),
 (8, 6),
 (8, 3),
 (5, 0),
 (4, 10),
 (20, 9),
 (3, 14),
 (15, 2),
 (17, 1),
 (0, 2),
 (4, 14),
 (21, 10),
 (2, 1),
 (14, 10),
 (18, 14),
 (15, 15),
 (21, 10),
 (19, 4),
 (15, 1),
 (7, 2),
 (8, 2),
 (16, 12),
 (18, 3),
 (3, 6),
 (22, 5),
 (23, 10),
 (14, 11),
 (10, 11),
 (22, 4),
 (8, 0),
 (9, 11),
 (8, 11),
 (22, 11),
 (5, 14),
 (3, 1),
 (14, 14),
 (0, 5),
 (2, 9),
 (9, 10),
 (13, 15),
 (19, 7),
 (22, 5),
 (5, 6),
 (18, 7),
 (0, 3),
 (3, 1),
 (22, 14),
 (1, 7),
 (9, 8),
 (9, 4),
 (14, 1),
 (7, 6),
 (14, 15),
 (10, 12),
 (17, 13),
 (0, 8),
 (16, 15),
 (5, 10),
 (6, 11),
 (7, 4),
 (18, 1),
 (5, 8),
 (21, 5),
 (19, 11),
 (7, 13),
 (8, 5),
 (8, 10),
 (5, 7),
 (21, 8),
 (19, 8),
 (14, 12),
 (21, 14),
 (21, 15),
 (0, 6),
 (9, 10),
 (22, 8),
 (23, 12),
 (16, 8),
 (9, 1