In [1]:
import os
import sys

sys.path.append("../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from src.utils.helper import Config, color_print
from src.utils.load import load_model, load_data, save_checkpoint, load_checkpoint
from src.models.evaluate import evaluate_model, get_sparsity, get_similarity
from src.utils.sampling import SamplingDataset


In [3]:
name = "bert-4-128-yahoo"
device = torch.device("cuda:0")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 128
ratio = 0.5
seed = 44

In [4]:
config = Config(name, device)

In [5]:
model = load_model(config=config)

Loading the model.
{'architectures': 'bert',
 'dataset_name': 'YahooAnswersTopics',
 'model_name': 'models/bert-4-128-yahoo',
 'num_labels': 10,
 'tokenizer_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics'}
The model models/bert-4-128-yahoo is loaded.


In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    config,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

Loading cached dataset YahooAnswersTopics.
train.pkl is loaded from cache.
valid.pkl is loaded from cache.
test.pkl is loaded from cache.
The dataset YahooAnswersTopics is loaded
{'config_name': 'yahoo_answers_topics',
 'features': {'first_column': 'question_title', 'second_column': 'topic'},
 'path': 'yahoo_answers_topics'}


In [7]:
import copy
import torch
import numpy as np
from typing import *
import torch.nn as nn
from functools import partial
from transformers.pytorch_utils import (
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)


def calculate_head_importance(
    model,
    config,
    data,
    normalize_scores_by_layer=True,
):
    device = config.device
    from functools import partial

    gradients = {}
    context_layers = {}

    def save_grad(gradients, layer_idx, grad):
        gradients[f"context_layer_{layer_idx}"] = grad

    def forward_hook(module, input, output, gradients, context_layers, layer_idx):
        context_layers[f"context_layer_{layer_idx}"] = output[0]
        output[0].register_hook(partial(save_grad, gradients, layer_idx))

    def reshape(tensors, shape, num_heads):
        batch_size = shape[0]
        seq_len = shape[1]
        head_dim = shape[2] // num_heads
        tensors = tensors.reshape(batch_size, seq_len, num_heads, head_dim)
        tensors = tensors.permute(0, 2, 1, 3)
        return tensors

    forward_handles = []

    for layer_idx in range(model.bert.config.num_hidden_layers):
        self_att = model.bert.encoder.layer[layer_idx].attention.self
        handle = self_att.register_forward_hook(
            partial(
                forward_hook,
                gradients=gradients,
                context_layers=context_layers,
                layer_idx=layer_idx,
            )
        )
        forward_handles.append(handle)

    """Calculate head importance scores"""
    # Disable dropout
    model.eval()
    # Device
    device = device or next(model.parameters()).device

    # Prepare data loader
    # Head importance tensor
    n_layers = model.bert.config.num_hidden_layers
    n_heads = model.bert.config.num_attention_heads
    num_batches = 0
    head_dim = model.bert.config.hidden_size // n_heads
    head_importance = torch.zeros(n_layers, n_heads).to(device)
    tot_tokens = 0
    first_batch = next(iter(data))
    is_embeds = "embeddings" in first_batch
    for step, batch in enumerate(data):
        num_batches += 1
        if is_embeds:
            embeddings = batch["embeddings"].to(device)
        else:
            input_ids = batch["input_ids"].to(device)
        input_mask = batch["attention_mask"].to(device)
        label_ids = batch["labels"].to(device)
        # Compute gradients
        if is_embeds:
            loss = model(
                inputs_embeds=embeddings, attention_mask=input_mask, labels=label_ids
            ).loss
        else:
            loss = model(input_ids, attention_mask=input_mask, labels=label_ids).loss
        loss.backward()

        for layer_idx in range(model.bert.config.num_hidden_layers):
            ctx = context_layers[f"context_layer_{layer_idx}"]
            grad_ctx = gradients[f"context_layer_{layer_idx}"]
            shape = ctx.shape
            ctx = reshape(ctx, shape, n_heads)
            grad_ctx = reshape(grad_ctx, shape, n_heads)

            # Take the dot
            dot = torch.einsum("bhli,bhli->bhl", [grad_ctx, ctx])
            head_importance[layer_idx] += dot.abs().sum(-1).sum(0).detach()
            del ctx, grad_ctx, dot

        tot_tokens += input_mask.float().detach().sum().data

    head_importance[:-1] /= tot_tokens
    head_importance /= num_batches
    for handle in forward_handles:
        handle.remove()
    return head_importance

def calculate_head_loss(model, config, data):
    device = config.device
    model.eval()
    n_layers = model.bert.config.num_hidden_layers
    n_heads = model.bert.config.num_attention_heads
    is_embeds = "embeddings" in next(iter(data))

    head_losses = torch.zeros(n_layers, n_heads).to(device)
    num_samples = 0
    
    for step, batch in enumerate(data):
        if is_embeds:
            embeddings = batch["embeddings"].to(device)
        else:
            input_ids = batch["input_ids"].to(device)
        input_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        batch_size = input_mask.size(0)
        
        for batch_item in range(batch_size):
            num_samples += 1
            for layer in range(n_layers):
                for head in range(n_heads):
                    head_mask = torch.ones(n_layers, n_heads).to(device)
                    head_mask[layer, head] = 0

                    with torch.no_grad():
                        if is_embeds:
                            outputs = model(inputs_embeds=embeddings, attention_mask=input_mask, head_mask=head_mask, output_attentions=True)
                        else:                
                            outputs = model(input_ids=input_ids, attention_mask=input_mask, head_mask=head_mask, output_attentions=True)
                            
                    logits = outputs.logits
                    criterion = torch.nn.CrossEntropyLoss()
                    loss = criterion(logits[batch_item].unsqueeze(0), labels[batch_item].unsqueeze(0))
                    head_losses[layer, head] += loss.item()

    avg_loss = head_losses / num_samples

    return avg_loss

def calculate_layer_loss(model, config, data):
    device = config.device
    model.eval()
    n_layers = model.bert.config.num_hidden_layers
    n_heads = model.bert.config.num_attention_heads
    is_embeds = "embeddings" in next(iter(data))

    layer_losses = torch.zeros(n_layers).to(device)
    num_samples = 0

    for step, batch in enumerate(data):
        if is_embeds:
            embeddings = batch["embeddings"].to(device)
        else:
            input_ids = batch["input_ids"].to(device)
        input_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        batch_size = input_mask.size(0)
        
        for batch_item in range(batch_size):
            num_samples += 1
            for layer in range(n_layers):
            
                head_mask = torch.ones(n_layers).to(device)
                head_mask[layer] = 0

                with torch.no_grad():
                    if is_embeds:
                        outputs = model(inputs_embeds=embeddings, attention_mask=input_mask, head_mask=head_mask, output_attentions=True)
                    else:                
                        outputs = model(input_ids=input_ids, attention_mask=input_mask, head_mask=head_mask, output_attentions=True)
                        
                logits = outputs.logits
                criterion = torch.nn.CrossEntropyLoss()
                loss = criterion(logits[batch_item].unsqueeze(0), labels[batch_item].unsqueeze(0))
                layer_losses[layer] += loss.item()

    avg_loss = layer_losses / num_samples

    return avg_loss


def normalize(tensors):
    exponent = 2
    norm_by_layer = torch.pow(
        torch.pow(tensors, exponent).sum(-1), 1 / exponent
    )
    tensors /= norm_by_layer.unsqueeze(-1) + 1e-20
    return tensors

def head_importance_prunning(
    model, config, dominant_concern, sparsity_ratio, method="unstructed", scheduler=None
):
    num_attention_heads = model.config.num_attention_heads
    num_hidden_layers = model.config.num_hidden_layers
    model = model.to(config.device)
    total_heads_to_prune = int(num_attention_heads * num_hidden_layers * sparsity_ratio)

    total_heads_to_prune = max(total_heads_to_prune, num_hidden_layers)
    print(f"Total heads to prune: {total_heads_to_prune}")
    pruned_heads = set()

    if scheduler is not None:
        steps = scheduler.get_steps()
    else:
        steps = [0.25, 0.25, 0.25, 0.25]
        # steps = [1.0]

    for step_ratio in steps:
        heads_to_prune = int(total_heads_to_prune * step_ratio)
        
        head_importance_list = calculate_head_importance(
            model, config, dominant_concern
        )
        head_loss = calculate_head_loss(
            model, config, dominant_concern
        )
        layer_loss = calculate_layer_loss(
            model, config, dominant_concern
        )
        head_importance_list = head_importance_list.cpu()
        head_loss = head_loss.cpu()
        layer_loss = layer_loss.cpu()
        # head_importance_list = normalize(head_importance_list)
        # head_loss = normalize(head_loss)
        # layer_loss = normalize(layer_loss)
        head_importance_list = torch.log1p(head_importance_list)
        print(f"head_importance_list\n {head_importance_list}")
        print(f"head_loss\n {head_loss}")
        print(f"layer_loss\n {layer_loss}")
        alpha = head_loss * layer_loss.unsqueeze(1)
        print(alpha)

        head_score = head_importance_list * alpha
        print(f"head_score\n {head_score}")
        if method == "unstructed":
            sorted_indices = torch.argsort(head_score.view(-1))
            prune_list = []
            for idx in sorted_indices:
                layer_index = int(idx // num_attention_heads)
                head_index = int(idx % num_attention_heads)

                if (layer_index, head_index) not in pruned_heads:
                    prune_list.append((layer_index, head_index))
                
                if len(prune_list) >= heads_to_prune:
                    break
            
        elif method == "structed":
            heads_per_layer = heads_to_prune // num_hidden_layers
            prune_list = []
            for layer_idx in range(num_hidden_layers):
                sorted_heads = torch.argsort(head_importance_list[layer_idx])
                prune_list.extend(
                    [
                        (layer_idx, head.item())
                        for head in sorted_heads[:heads_per_layer]
                    ]
                )
        for layer_index, head_index in prune_list:
            if (layer_index, head_index) not in pruned_heads:
                prune_heads(
                    model.bert.encoder.layer[layer_index].attention,
                    [head_index],
                    method=method,
                )
                pruned_heads.add((layer_index, head_index))
        print(sorted(pruned_heads))


def prune_heads(layer, heads, method):
    if len(heads) == 0:
        return
    heads, index = find_pruneable_heads_and_indices(
        heads,
        layer.self.num_attention_heads,
        layer.self.attention_head_size,
        layer.pruned_heads,
    )

    # if method == "unstructed":
    layer.self.query = zero_out_head_weights(
        layer.self.query, heads, layer.self.attention_head_size
    )
    layer.self.key = zero_out_head_weights(
        layer.self.key, heads, layer.self.attention_head_size
    )
    layer.self.value = zero_out_head_weights(
        layer.self.value, heads, layer.self.attention_head_size
    )
    layer.output.dense = zero_out_head_weights(
        layer.output.dense, heads, layer.self.attention_head_size, dim=1
    )
    # elif method == "structed":
    #     layer.self.query = prune_linear_layer(layer.self.query, index)
    #     layer.self.key = prune_linear_layer(layer.self.key, index)
    #     layer.self.value = prune_linear_layer(layer.self.value, index)
    #     layer.output.dense = prune_linear_layer(layer.output.dense, index)

    #     layer.self.num_attention_heads = layer.self.num_attention_heads - len(heads)
    #     layer.self.all_head_size = layer.self.attention_head_size *  layer.self.num_attention_heads
    #     layer.pruned_heads = layer.pruned_heads.union(heads)


def zero_out_head_weights(
    layer: nn.Linear, heads: Set[int], head_size: int, dim: int = 0
) -> nn.Linear:
    """
    Zero out the weights of the specified heads in the linear layer.

    Args:
        layer (`torch.nn.Linear`): The layer to modify.
        heads (`Set[int]`): The indices of heads to zero out.
        head_size (`int`): The size of each head.
        dim (`int`, *optional*, defaults to 0): The dimension on which to zero out the weights.

    Returns:
        `torch.nn.Linear`: The modified layer with weights of specified heads zeroed out.
    """
    for head in heads:
        start_index = head * head_size
        end_index = (head + 1) * head_size
        if dim == 0:
            layer.weight.data[start_index:end_index] = 0
        elif dim == 1:
            layer.weight.data[:, start_index:end_index] = 0

    return layer


In [8]:
from src.models.evaluate import (
    evaluate_model,
    get_sparsity,
    get_similarity,
    get_perplexity,
)

In [9]:
result_list = []

for concern in range(config.num_labels):
  config.init_seed()
  positive_samples = SamplingDataset(
      train_dataloader,
      config,
      concern,
      num_samples,
      True,
      4,
      resample=False,
  )
  negative_samples = SamplingDataset(
      train_dataloader,
      config,
      concern,
      num_samples,
      False,
      4,
      resample=False,
  )
  module = copy.deepcopy(model)

  head_importance_prunning(module, config, positive_samples, ratio)

#   print(f"Evaluate the pruned model {concern}")
  result = evaluate_model(module, config, test_dataloader)
#   result_list.append(result)
#   get_similarity(model, module, valid_dataloader, concern, num_samples, config)
#   print("original model's perplexity")
#   get_perplexity(model, valid_dataloader, config)
#   print("pruned model's perplexity")
#   get_perplexity(module, valid_dataloader, config)
  

Total heads to prune: 8




head_importance_list
 tensor([[5.6371e-05, 6.4266e-05, 5.7026e-05, 5.8088e-05],
        [4.8789e-05, 8.0193e-05, 6.4501e-05, 6.9048e-05],
        [5.0231e-05, 7.1948e-05, 5.1353e-05, 5.3254e-05],
        [5.9122e-02, 4.9883e-02, 6.9143e-02, 5.9519e-02]])
head_loss
 tensor([[1.4575, 1.4098, 1.4667, 1.5068],
        [1.4333, 1.4087, 1.4519, 1.4669],
        [1.4812, 1.4692, 1.4888, 1.4394],
        [1.5510, 1.4956, 1.4531, 1.4402]])
layer_loss
 tensor([1.4241, 1.3157, 1.3873, 1.3975])
tensor([[2.0757, 2.0078, 2.0887, 2.1459],
        [1.8858, 1.8534, 1.9101, 1.9299],
        [2.0548, 2.0382, 2.0654, 1.9969],
        [2.1675, 2.0901, 2.0307, 2.0127]])
head_score
 tensor([[1.1701e-04, 1.2903e-04, 1.1911e-04, 1.2465e-04],
        [9.2005e-05, 1.4863e-04, 1.2321e-04, 1.3325e-04],
        [1.0321e-04, 1.4664e-04, 1.0607e-04, 1.0634e-04],
        [1.2815e-01, 1.0426e-01, 1.4041e-01, 1.1980e-01]])
[(1, 0), (2, 0)]
head_importance_list
 tensor([[5.5054e-05, 6.4347e-05, 5.6189e-05, 5.9340e-05],
 

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.4236
Precision: 0.6311, Recall: 0.5462, F1-Score: 0.5581
              precision    recall  f1-score   support

           0     0.4136    0.6120    0.4936      2992
           1     0.6786    0.4241    0.5220      2992
           2     0.6361    0.6361    0.6361      3012
           3     0.2936    0.6594    0.4063      2998
           4     0.7448    0.7302    0.7374      2973
           5     0.8999    0.4682    0.6160      3054
           6     0.6330    0.4049    0.4939      3003
           7     0.4639    0.6122    0.5278      3012
           8     0.7050    0.4624    0.5585      2982
           9     0.8427    0.4527    0.5890      2982

    accuracy                         0.5461     30000
   macro avg     0.6311    0.5462    0.5581     30000
weighted avg     0.6313    0.5461    0.5580     30000

Total heads to prune: 8
head_importance_list
 tensor([[4.9999e-05, 7.6041e-05, 4.8225e-05, 6.2064e-05],
        [3.4291e-05, 6.8904e-05, 4.9083e-05, 5.1879e-05],
        [4.253

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3959
Precision: 0.6355, Recall: 0.5600, F1-Score: 0.5739
              precision    recall  f1-score   support

           0     0.4245    0.6110    0.5010      2992
           1     0.6673    0.4532    0.5398      2992
           2     0.6522    0.6325    0.6422      3012
           3     0.2969    0.6631    0.4102      2998
           4     0.7855    0.6700    0.7232      2973
           5     0.9136    0.4918    0.6394      3054
           6     0.6416    0.4029    0.4950      3003
           7     0.5000    0.5933    0.5427      3012
           8     0.6825    0.5429    0.6048      2982
           9     0.7906    0.5392    0.6411      2982

    accuracy                         0.5598     30000
   macro avg     0.6355    0.5600    0.5739     30000
weighted avg     0.6357    0.5598    0.5739     30000

Total heads to prune: 8
head_importance_list
 tensor([[4.0902e-05, 5.5916e-05, 5.2655e-05, 6.8180e-05],
        [4.5500e-05, 1.2641e-04, 6.4290e-05, 5.4530e-05],
        [5.168

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3753
Precision: 0.6363, Recall: 0.5709, F1-Score: 0.5841
              precision    recall  f1-score   support

           0     0.4276    0.6029    0.5003      2992
           1     0.6797    0.4412    0.5351      2992
           2     0.6489    0.6338    0.6412      3012
           3     0.3040    0.6624    0.4168      2998
           4     0.7667    0.6932    0.7281      2973
           5     0.8861    0.6343    0.7393      3054
           6     0.6333    0.4106    0.4982      3003
           7     0.5291    0.5876    0.5569      3012
           8     0.6917    0.5117    0.5883      2982
           9     0.7953    0.5315    0.6372      2982

    accuracy                         0.5710     30000
   macro avg     0.6363    0.5709    0.5841     30000
weighted avg     0.6365    0.5710    0.5843     30000

Total heads to prune: 8
head_importance_list
 tensor([[3.9856e-05, 5.2533e-05, 3.7867e-05, 4.3479e-05],
        [4.2380e-05, 5.1160e-05, 4.3393e-05, 3.7194e-05],
        [4.036

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7314f88f8900>Traceback (most recent call last):

  File "/home/jieungkim/.cache/pypoetry/virtualenvs/decomposetransformer-UESb9BbT-py3.12/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()  File "/home/jieungkim/.cache/pypoetry/virtualenvs/decomposetransformer-UESb9BbT-py3.12/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/home/jieungkim/anaconda3/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7314f88f8900>
Traceback (most recent call last):
  File "/home/jieungkim/.cache/pypoetry/virtualenvs/decompo

Loss: 1.4013
Precision: 0.6347, Recall: 0.5566, F1-Score: 0.5726
              precision    recall  f1-score   support

           0     0.4434    0.5775    0.5017      2992
           1     0.6404    0.4131    0.5022      2992
           2     0.6568    0.5979    0.6260      3012
           3     0.2771    0.6955    0.3963      2998
           4     0.7593    0.6895    0.7227      2973
           5     0.8773    0.6113    0.7206      3054
           6     0.6293    0.4099    0.4965      3003
           7     0.5424    0.5837    0.5623      3012
           8     0.7074    0.4752    0.5685      2982
           9     0.8136    0.5124    0.6288      2982

    accuracy                         0.5567     30000
   macro avg     0.6347    0.5566    0.5726     30000
weighted avg     0.6349    0.5567    0.5727     30000

Total heads to prune: 8
head_importance_list
 tensor([[3.3910e-05, 5.0382e-05, 6.7213e-05, 5.9223e-05],
        [3.3663e-05, 6.3479e-05, 4.5566e-05, 3.2855e-05],
        [2.628

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.4471
Precision: 0.6199, Recall: 0.5417, F1-Score: 0.5500
              precision    recall  f1-score   support

           0     0.3909    0.6270    0.4816      2992
           1     0.7084    0.3937    0.5061      2992
           2     0.6418    0.6371    0.6395      3012
           3     0.3060    0.6341    0.4128      2998
           4     0.7063    0.7659    0.7349      2973
           5     0.8684    0.4126    0.5594      3054
           6     0.6237    0.4056    0.4915      3003
           7     0.4542    0.5591    0.5012      3012
           8     0.6671    0.5443    0.5994      2982
           9     0.8327    0.4373    0.5734      2982

    accuracy                         0.5413     30000
   macro avg     0.6199    0.5417    0.5500     30000
weighted avg     0.6201    0.5413    0.5498     30000

Total heads to prune: 8
head_importance_list
 tensor([[3.9889e-05, 5.6511e-05, 3.9480e-05, 4.9755e-05],
        [5.3199e-05, 4.3852e-05, 5.0532e-05, 4.2366e-05],
        [3.571

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.4166
Precision: 0.6279, Recall: 0.5461, F1-Score: 0.5572
              precision    recall  f1-score   support

           0     0.4224    0.5648    0.4833      2992
           1     0.6486    0.3690    0.4704      2992
           2     0.6829    0.5292    0.5963      3012
           3     0.2739    0.7041    0.3944      2998
           4     0.7942    0.6424    0.7103      2973
           5     0.8422    0.7377    0.7865      3054
           6     0.6392    0.4006    0.4925      3003
           7     0.5145    0.6524    0.5753      3012
           8     0.7013    0.2921    0.4124      2982
           9     0.7595    0.5687    0.6504      2982

    accuracy                         0.5466     30000
   macro avg     0.6279    0.5461    0.5572     30000
weighted avg     0.6281    0.5466    0.5576     30000

Total heads to prune: 8
head_importance_list
 tensor([[7.1710e-05, 6.3561e-05, 6.9254e-05, 7.1243e-05],
        [5.1929e-05, 8.9838e-05, 7.5928e-05, 9.6699e-05],
        [6.002

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.4292
Precision: 0.6419, Recall: 0.5539, F1-Score: 0.5692
              precision    recall  f1-score   support

           0     0.4178    0.6009    0.4929      2992
           1     0.7021    0.3443    0.4620      2992
           2     0.6712    0.5930    0.6296      3012
           3     0.2799    0.6791    0.3965      2998
           4     0.7853    0.6583    0.7162      2973
           5     0.9235    0.5537    0.6923      3054
           6     0.6356    0.4129    0.5006      3003
           7     0.5169    0.6365    0.5705      3012
           8     0.6723    0.5567    0.6091      2982
           9     0.8140    0.5034    0.6220      2982

    accuracy                         0.5539     30000
   macro avg     0.6419    0.5539    0.5692     30000
weighted avg     0.6422    0.5539    0.5693     30000

Total heads to prune: 8
head_importance_list
 tensor([[4.6345e-05, 5.4424e-05, 4.5405e-05, 5.1008e-05],
        [4.1405e-05, 5.3053e-05, 4.9137e-05, 4.8097e-05],
        [4.488

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3791
Precision: 0.6435, Recall: 0.5654, F1-Score: 0.5801
              precision    recall  f1-score   support

           0     0.3877    0.6273    0.4793      2992
           1     0.6830    0.4141    0.5156      2992
           2     0.6788    0.5843    0.6280      3012
           3     0.2984    0.6681    0.4125      2998
           4     0.7938    0.6680    0.7255      2973
           5     0.9009    0.6785    0.7740      3054
           6     0.6405    0.4099    0.4999      3003
           7     0.5551    0.6421    0.5954      3012
           8     0.6959    0.4467    0.5441      2982
           9     0.8007    0.5148    0.6267      2982

    accuracy                         0.5656     30000
   macro avg     0.6435    0.5654    0.5801     30000
weighted avg     0.6437    0.5656    0.5804     30000

Total heads to prune: 8
head_importance_list
 tensor([[6.4833e-05, 3.1870e-05, 4.2578e-05, 4.4088e-05],
        [2.3413e-05, 6.2936e-05, 3.9190e-05, 4.1970e-05],
        [2.911

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7314f88f8900>
Traceback (most recent call last):
  File "/home/jieungkim/.cache/pypoetry/virtualenvs/decomposetransformer-UESb9BbT-py3.12/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/home/jieungkim/.cache/pypoetry/virtualenvs/decomposetransformer-UESb9BbT-py3.12/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/home/jieungkim/anaconda3/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7314f88f8900>
Traceback (most recent call last):
  File "/home/jieungkim/.cache/pypoetry/virtualenvs/decompo

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.4532
Precision: 0.6171, Recall: 0.5445, F1-Score: 0.5556
              precision    recall  f1-score   support

           0     0.3958    0.6243    0.4845      2992
           1     0.5960    0.4733    0.5276      2992
           2     0.6641    0.6066    0.6340      3012
           3     0.3000    0.6391    0.4083      2998
           4     0.8094    0.5658    0.6660      2973
           5     0.8992    0.3854    0.5395      3054
           6     0.6285    0.3989    0.4881      3003
           7     0.5098    0.5156    0.5127      3012
           8     0.6368    0.6402    0.6385      2982
           9     0.7309    0.5956    0.6563      2982

    accuracy                         0.5441     30000
   macro avg     0.6171    0.5445    0.5556     30000
weighted avg     0.6174    0.5441    0.5554     30000

Total heads to prune: 8
head_importance_list
 tensor([[5.7599e-05, 4.4358e-05, 4.0555e-05, 4.2634e-05],
        [4.4702e-05, 4.6824e-05, 4.9800e-05, 4.5123e-05],
        [3.168

Evaluating the model:   0%|          | 0/1875 [00:00<?, ?it/s]

Loss: 1.3779
Precision: 0.6321, Recall: 0.5727, F1-Score: 0.5846
              precision    recall  f1-score   support

           0     0.4045    0.6126    0.4873      2992
           1     0.6143    0.4499    0.5194      2992
           2     0.7165    0.5378    0.6145      3012
           3     0.3102    0.6481    0.4196      2998
           4     0.8340    0.5291    0.6475      2973
           5     0.8261    0.7449    0.7834      3054
           6     0.6487    0.3966    0.4923      3003
           7     0.5771    0.5631    0.5700      3012
           8     0.6858    0.5879    0.6331      2982
           9     0.7036    0.6566    0.6793      2982

    accuracy                         0.5729     30000
   macro avg     0.6321    0.5727    0.5846     30000
weighted avg     0.6323    0.5729    0.5849     30000

