# Pruning [SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer) (50 points)

We will prune [SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer) for the [human segmentation task](https://www.kaggle.com/datasets/laurentmih/aisegmentcom-matting-human-datasets).

## Download the helper code and the baseline checkpoint (not the same as in the first homework)

In [2]:
# !wget -O hw_files_2.zip 'https://www.dropbox.com/scl/fi/66vn2n3p2nb1tjzs2jmog/hw_files_2.zip?rlkey=0je4fwxakn3zb3mqsewhkdjc8&dl=0'
# !unzip hw_files_2.zip

### Download the dataset (if you still have it from the first homework, you can reuse it)

In [None]:
# https://drive.google.com/file/d/1YOEDzZvhLb2DS1Yn7p7MSs41ou3ZBXUq/view?usp=sharing
# !unzip matting_human_dataset.zip

### Install the libraries

These are from the previous homework:

In [None]:
!pip install torch transformers datasets tensorboard pillow

And these are new:

In [3]:
!pip install torch_pruning

Defaulting to user installation because normal site-packages is not writeable
Collecting torch_pruning
  Downloading torch_pruning-1.3.7-py3-none-any.whl (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.5/56.5 KB[0m [31m627.6 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: torch_pruning
Successfully installed torch_pruning-1.3.7


In [9]:
import sys
from pathlib import Path
import os

# Calculate the absolute path to the parent directory (two levels up)
parent_dir = os.path.abspath(os.path.join(os.getcwd(), '../../'))

# Add the parent directory to sys.path so that the `utils` package can be imported
sys.path.append(str(parent_dir))


import typing
import torch

from copy import deepcopy
from datasets import load_metric
from torch import nn
from torch.nn import functional as F
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm.auto import tqdm

# The utils package comes with the downloaded helper code. Feel free to dig into it.
from utils.data import init_dataloaders
from utils.model import evaluate_model
from utils.model import init_model_with_pretrain, init_model_student_with_pretrain

from transformers.models.segformer.modeling_segformer import SegformerLayer, SegformerEfficientSelfAttention

import torch_pruning as tp

In [20]:
distilled_ckpt = '../../assignment1/runs/distillation/ckpt_6.pth'
save_dir = '../../assignment2/runs/magnitude_equal_pruning'
baseline_path = "../../assignment2/runs/baseline_ckpt.pth"

In [11]:
# mapping between class indices and class names
id2label = {
    0: "background",
    1: "human",
}
label2id = {v: k for k, v in id2label.items()}

Create the data loaders:

In [13]:
dataset_dir = "/home/gvasserm/data/matting_human_dataset/"
train_dataloader, valid_dataloader = init_dataloaders(
    root_dir=dataset_dir,
    batch_size=8,
    num_workers=8,
)

Create the baseline model:

In [21]:
baseline_model = init_model_with_pretrain(label2id=label2id, id2label=id2label, pretrain_path=baseline_path).cuda()

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


And validate it right away:

In [22]:
evaluate_model(baseline_model, valid_dataloader, id2label)

  state_dict = torch.load(pretrain_path, devices)['state_dict']
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Mean_iou: 0.9825463313870122
Mean accuracy: 0.9910212045247916


{'mean_iou': 0.9825463313870122,
 'mean_accuracy': 0.9910212045247916,
 'overall_accuracy': 0.9912375183105469,
 'per_category_iou': array([0.98135781, 0.98373485]),
 'per_category_accuracy': array([0.98774413, 0.99429828])}

Create the distilled model (you can use the model obtained in the first homework):

In [23]:
distilled_model = init_model_student_with_pretrain(distilled_ckpt).cuda()

Check its accuracy:

In [24]:
evaluate_model(distilled_model, valid_dataloader, id2label)



Mean_iou: 0.9646906777392876
Mean accuracy: 0.9816622249066476


{'mean_iou': 0.9646906777392876,
 'mean_accuracy': 0.9816622249066476,
 'overall_accuracy': 0.9821188583374023,
 'per_category_iou': array([0.96220293, 0.96717842]),
 'per_category_accuracy': array([0.97474439, 0.98858006])}

Estimate the computational cost and the number of parameters of the models:

In [25]:
input_example = torch.rand(1,3,512,512, device="cuda")

In [26]:
ops, params = tp.utils.count_ops_and_params(baseline_model, input_example)
print(f"Baseline model complexity: {ops/1e6} MMAC, {params/1e6} M params")

Baseline model complexity: 6761.228288 MMAC, 3.714658 M params


In [27]:
ops, params = tp.utils.count_ops_and_params(distilled_model, input_example)
print(f"Distilled model complexity: {ops/1e6} MMAC, {params/1e6} M params")

Distilled model complexity: 5841.819136 MMAC, 2.29821 M params


Check that the distilled model has a single SegformerLayer in each block:

In [28]:
distilled_model

SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(160, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

## Magnitude pruning (10 points)

Perform one-shot pruning of the model by the L2 norm of the weights in uniform mode, so that every layer gets the same pruning ratio. Keep in mind that it is best not to prune the last layer. Set pruning_ratio=0.75.
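
To make the criterion concrete, here is a minimal pure-Python sketch of uniform L2 magnitude selection for a single layer. It is not how torch_pruning implements it internally; the helper names and the toy weights are made up for illustration.

```python
import math

def l2_channel_importance(weight):
    """L2 norm of each output channel; `weight` is a list of rows, one row per channel."""
    return [math.sqrt(sum(w * w for w in row)) for row in weight]

def channels_to_prune(weight, pruning_ratio):
    """Indices of the lowest-norm channels. In uniform mode the same ratio is used for every layer."""
    scores = l2_channel_importance(weight)
    n_prune = int(len(scores) * pruning_ratio)
    order = sorted(range(len(scores)), key=lambda i: scores[i])  # ascending importance
    return sorted(order[:n_prune])

# Hypothetical layer: 4 output channels, 2 input features.
toy_weight = [[3.0, 4.0], [0.1, 0.0], [1.0, 1.0], [0.0, 0.2]]
pruned = channels_to_prune(toy_weight, pruning_ratio=0.75)
print(pruned)
```

With a ratio of 0.75, three of the four toy channels are dropped and only the channel with the largest norm survives.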

In [29]:
def prune_model_l2(model):
    # Use the torch_pruning library here.

    example_inputs = torch.randn(1, 3, 512, 512, device="cuda")
    # L2 magnitude importance
    imp = tp.importance.MagnitudeImportance(p=2)

    # Do not prune the decode head, which contains the output layer
    ignored_layers = [model.decode_head]

    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs,
        imp,
        pruning_ratio = 0.75,
        pruning_ratio_dict = {},
        ignored_layers=ignored_layers,
    )
    pruner.step()

    return model

pruned_model = prune_model_l2(deepcopy(distilled_model))

In [30]:
# Check whether the pruned network still runs
pruned_model(input_example)

RuntimeError: shape '[1, 16384, 1, 32]' is invalid for input of size 131072

## Fix the model (20 points)

In [31]:
# Analyze the error log and work out why the model stopped running after pruning.
# Hint: it is related to the attention layer and the head size.

def fix_attention_layer(pruned_model):
    # Pruning shrinks the query/key/value projections, but each attention module still
    # stores the old head sizes, so they are recomputed from the pruned layers.
    for m in pruned_model.modules():
        if isinstance(m, SegformerEfficientSelfAttention):
            print(m)
            print("num_heads:", m.num_attention_heads, 'head_dims:', m.attention_head_size, 'all_head_size:', m.all_head_size, '=>')
            # The number of heads is unchanged; only the per-head size shrinks.
            m.attention_head_size = m.query.out_features // m.num_attention_heads
            m.all_head_size = m.query.out_features
            print("num_heads:", m.num_attention_heads, 'head_dims:', m.attention_head_size, 'all_head_size:', m.all_head_size)

    return pruned_model
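
The error `shape '[1, 16384, 1, 32]' is invalid for input of size 131072` can be reproduced with plain shape arithmetic. The sketch below assumes the first encoder stage (patch stride 4 on a 512x512 input) and that pruning with ratio 0.75 left 8 of the 32 query channels. That matches the printed `Linear(in_features=8, out_features=8)`, but the variable names are mine.

```python
from math import prod

batch, height, width = 1, 512, 512
patch_stride = 4                                               # stride of the first patch embedding
seq_len = (height // patch_stride) * (width // patch_stride)   # 128 * 128 tokens

original_dim, pruning_ratio = 32, 0.75
pruned_dim = int(original_dim * (1 - pruning_ratio))           # channels left after pruning

actual_elements = batch * seq_len * pruned_dim        # size of the query tensor after pruning
stale_shape = (batch, seq_len, 1, 32)                 # num_heads=1, head size still 32
fixed_shape = (batch, seq_len, 1, pruned_dim // 1)    # head size recomputed from out_features

print(actual_elements, prod(stale_shape), prod(fixed_shape))
```

The stale shape needs four times more elements than the pruned tensor has, so the reshape fails until `attention_head_size` and `all_head_size` are updated.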

In [32]:
# Make sure the model runs after the fix
pruned_model = fix_attention_layer(pruned_model)
pruned_model(input_example)

SegformerEfficientSelfAttention(
  (query): Linear(in_features=8, out_features=8, bias=True)
  (key): Linear(in_features=8, out_features=8, bias=True)
  (value): Linear(in_features=8, out_features=8, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (sr): Conv2d(8, 8, kernel_size=(8, 8), stride=(8, 8))
  (layer_norm): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
)
num_heads: 1 head_dims: 32 all_head_size: 32 =>
num_heads: 1 head_dims: 8 all_head_size: 8
SegformerEfficientSelfAttention(
  (query): Linear(in_features=16, out_features=16, bias=True)
  (key): Linear(in_features=16, out_features=16, bias=True)
  (value): Linear(in_features=16, out_features=16, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (sr): Conv2d(16, 16, kernel_size=(4, 4), stride=(4, 4))
  (layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
num_heads: 2 head_dims: 32 all_head_size: 64 =>
num_heads: 2 head_dims: 8 all_head_size: 16
SegformerEfficientSelfAttention(
  (query): Line

SemanticSegmenterOutput(loss=None, logits=tensor([[[[-0.9120, -0.9594, -1.2067,  ..., -0.8537, -0.9097, -0.8198],
          [-0.8806, -0.8767, -1.2150,  ..., -1.0846, -0.9317, -1.0418],
          [-0.8034, -0.9083, -1.1288,  ..., -0.6273, -0.8092, -0.8748],
          ...,
          [-0.5124, -0.5545, -0.7156,  ..., -0.6038, -1.2535, -0.7603],
          [-0.5662, -0.6462, -0.9241,  ..., -0.6759, -0.8009, -0.7472],
          [-0.5649, -0.6868, -0.7955,  ..., -0.5920, -0.9182, -0.7973]],

         [[ 0.9185,  0.9346,  1.1449,  ...,  0.7656,  0.8672,  0.7434],
          [ 0.8215,  0.8072,  1.2135,  ...,  0.9763,  0.9648,  1.0666],
          [ 0.7870,  0.9248,  1.0535,  ...,  0.6151,  0.7583,  0.9625],
          ...,
          [ 0.5162,  0.5852,  0.6627,  ...,  0.6102,  1.2155,  0.6626],
          [ 0.5043,  0.6721,  0.8944,  ...,  0.7548,  0.7050,  0.7460],
          [ 0.5812,  0.7473,  0.9994,  ...,  0.5661,  0.8649,  0.8011]]]],
       device='cuda:0', grad_fn=<ConvolutionBackward0>), hi

In [33]:
# Estimate the computational cost of the resulting model
ops, params = tp.utils.count_ops_and_params(pruned_model, input_example)
print(f"Distilled model complexity (After magnitude pruning): {ops/1e6} MMAC, {params/1e6} M params")

Distilled model complexity (After magnitude pruning): 4483.917952 MMAC, 0.42249 M params


In [34]:
# Let's try to shrink the model further by pruning attention heads.
# torch_pruning does not support this, but transformers does.
# The L2 norm of the weights can be used to pick the least useful heads.
# Here we simply drop every head except head 0.

pruned_model.segformer.encoder.block[1][0].attention.prune_heads([1])
pruned_model.segformer.encoder.block[2][0].attention.prune_heads([1,2,3,4])
pruned_model.segformer.encoder.block[3][0].attention.prune_heads([1,2,3,4,5,6,7])
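
The cell above keeps head 0 unconditionally. As a rough sketch of the L2-based alternative mentioned in the comment, heads can be ranked by the norm of their slice of the query weights. This is an illustration on plain lists with hypothetical helper names. It looks only at the query projection, which is a simplification, and it was not used for the results in this notebook.

```python
import math

def head_l2_norms(query_weight, num_heads):
    """L2 norm of the rows of the query weight that belong to each head."""
    head_size = len(query_weight) // num_heads
    norms = []
    for h in range(num_heads):
        rows = query_weight[h * head_size:(h + 1) * head_size]
        norms.append(math.sqrt(sum(w * w for row in rows for w in row)))
    return norms

def heads_to_prune(query_weight, num_heads, keep=1):
    """Indices of all heads except the `keep` heads with the largest norm."""
    norms = head_l2_norms(query_weight, num_heads)
    ranked = sorted(range(num_heads), key=lambda h: norms[h], reverse=True)
    return sorted(ranked[keep:])

# Hypothetical query weight: 2 heads with head size 2, 2 input features.
toy_query = [[0.1, 0.0], [0.0, 0.1], [2.0, 1.0], [1.0, 2.0]]
to_prune = heads_to_prune(toy_query, num_heads=2, keep=1)
print(to_prune)
```

The resulting index list has the same form as the arguments passed to `prune_heads` above. In this toy case it would drop head 0 rather than keep it.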

In [35]:
# Estimate the computational cost once more
ops, params = tp.utils.count_ops_and_params(pruned_model, input_example)
print(f"Distilled model complexity (After magnitude pruning): {ops/1e6} MMAC, {params/1e6} M params")

Distilled model complexity (After magnitude pruning): 4475.856736 MMAC, 0.402234 M params
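
To see how much each step saves relative to the baseline, the numbers printed so far can be put side by side. The values below are copied from the cell outputs in this notebook; the `reduction` helper is just for illustration.

```python
# (MMACs, millions of parameters) copied from the cell outputs
stages = {
    "baseline": (6761.228288, 3.714658),
    "distilled": (5841.819136, 2.29821),
    "magnitude-pruned": (4483.917952, 0.42249),
    "magnitude-pruned + head pruning": (4475.856736, 0.402234),
}

def reduction(before, after):
    """Relative reduction in percent."""
    return 100.0 * (1.0 - after / before)

base_macs, base_params = stages["baseline"]
for name, (macs, params) in stages.items():
    print(f"{name:32s} MACs -{reduction(base_macs, macs):5.1f}%  "
          f"params -{reduction(base_params, params):5.1f}%")
```

The parameter count drops far more than the MACs. One possible reason is that the decode head was excluded from pruning, though this was not measured separately here.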


## Fine-tuning the pruned model (5 points)

In [36]:
# Move your training pipeline from the previous homework into a separate file and use it here
from utils.train import TrainParams, train

In [38]:
train_params = TrainParams(
    n_epochs=1,
    lr=12e-5,
    batch_size=24,
    n_workers=8,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    temperature=3,
    loss_weight=0.5,
    last_layer_loss_weight=0.5,
    intermediate_attn_layers_weights=(0.5, 0.5, 0.5, 0.5),
    intermediate_feat_layers_weights=(0.5, 0.5, 0.5, 0.5)
)

with torch.no_grad():
    teacher_attentions = baseline_model(pixel_values=torch.ones(1, 3, 512, 512).to(train_params.device), output_attentions=True).attentions
    student_attentions = pruned_model(pixel_values=torch.ones(1, 3, 512, 512).to(train_params.device), output_attentions=True).attentions

    assert len(teacher_attentions) == 8
    assert len(student_attentions) == 4

student_teacher_attention_mapping = {i: i*2 + 1 for i in range(4)}

tb_writer = SummaryWriter(save_dir)
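
The mapping `{i: i*2 + 1}` pairs each student layer with the last teacher layer of the same encoder block. A small sketch, assuming 4 blocks with 2 teacher layers each (consistent with the 8 and 4 attention maps asserted above):

```python
num_student_layers = 4          # one SegformerLayer per encoder block in the student
teacher_layers_per_block = 2    # the teacher returns 8 attention maps for 4 blocks

# Student layer i is matched with the last teacher layer of block i.
mapping = {
    i: i * teacher_layers_per_block + (teacher_layers_per_block - 1)
    for i in range(num_student_layers)
}
print(mapping)
```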

In [40]:
train(
    teacher_model=baseline_model,
    student_model=pruned_model,
    train_params=train_params,
    student_teacher_attention_mapping=student_teacher_attention_mapping,
    tb_writer=tb_writer,
    save_dir=save_dir,
    id2label=id2label
)

train_params.n_epochs=5
train_params.lr=18e-5

train(
    teacher_model=baseline_model,
    student_model=pruned_model,
    train_params=train_params,
    student_teacher_attention_mapping=student_teacher_attention_mapping,
    tb_writer=tb_writer,
    save_dir=save_dir,
    id2label=id2label
)

total loss: -61848.336: 100%|██████████| 1148/1148 [10:43<00:00,  1.78it/s]

Mean_iou: 0.8653744293493053
Mean accuracy: 0.9268792786380723

total loss: -68095.578: 100%|██████████| 1148/1148 [11:24<00:00,  1.68it/s]

Mean_iou: 0.8938826308531549
Mean accuracy: 0.9431564493025576

total loss: -65177.156: 100%|██████████| 1148/1148 [11:25<00:00,  1.67it/s]

Mean_iou: 0.9135316849328331
Mean accuracy: 0.9543163906238968

total loss: -66662.406: 100%|██████████| 1148/1148 [11:08<00:00,  1.72it/s]

Mean_iou: 0.9156998973348219
Mean accuracy: 0.9564662425854906

total loss: -65914.555: 100%|██████████| 1148/1148 [11:02<00:00,  1.73it/s]

Mean_iou: 0.9211714091611245
Mean accuracy: 0.9577187579337156

total loss: -67122.961: 100%|██████████| 1148/1148 [11:19<00:00,  1.69it/s]

Mean_iou: 0.931900305332622
Mean accuracy: 0.9641197057345005


## Taylor pruning (15 points)

In [42]:
import gc
gc.collect()
torch.cuda.empty_cache()

# Print current GPU memory usage
t = torch.cuda.get_device_properties(0).total_memory
r = torch.cuda.memory_reserved(0)
a = torch.cuda.memory_allocated(0)
f = r-a  # free inside reserved

print(f'Total: {t}, Reserved: {r}, Allocated: {a}, Free: {f}')

Total: 16899571712, Reserved: 4955570176, Allocated: 2996960768, Free: 1958609408


Next, prune the model using the Taylor importance criterion and compare the accuracy of the resulting models after fine-tuning. Keep the pruning ratio and the structure (uniform) the same as for L2.

In [43]:
baseline_model = init_model_with_pretrain(label2id=label2id, id2label=id2label, pretrain_path=baseline_path).cuda()
distilled_model = init_model_student_with_pretrain(distilled_ckpt).cuda()
pruned_model = deepcopy(distilled_model)



In [44]:
def prune_model_taylor(model):
    # Use the torch_pruning library here.
    # Note: this function returns the pruner, not the model.

    example_inputs = torch.randn(1, 3, 512, 512, device="cuda")

    # Do not prune the decode head, which contains the output layer
    ignored_layers = [model.decode_head]

    # Taylor importance
    imp = tp.importance.TaylorImportance()

    pruner = tp.pruner.MetaPruner(
                model, 
                example_inputs, 
                global_pruning=False, # If False, a uniform pruning ratio will be assigned to different layers.
                importance=imp, # importance criterion for parameter selection
                pruning_ratio=0.75, # target pruning ratio
                ignored_layers=ignored_layers,
                output_transform=lambda out: out.logits.sum())

    return pruner

pruner = prune_model_taylor(pruned_model)

In [None]:
# Define the importance criterion and the pruning ratio

Taylor pruning requires gradients accumulated on the weights; they are used to estimate the importance of each channel.
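
As a minimal sketch of the idea, a first-order Taylor score for an output channel can be taken as the sum of |w * dL/dw| over that channel's weights. This is the textbook form on plain lists, not necessarily the exact reduction that `tp.importance.TaylorImportance` applies; the toy weights and gradients are made up.

```python
def taylor_channel_importance(weight, grad):
    """First-order Taylor score per output channel: sum of |w * dL/dw| over the channel."""
    return [
        sum(abs(w * g) for w, g in zip(w_row, g_row))
        for w_row, g_row in zip(weight, grad)
    ]

# Hypothetical layer with 3 output channels and 2 input features.
toy_weight = [[2.0, -1.0], [0.5, 0.5], [3.0, 0.0]]
toy_grad = [[0.0, 0.1], [1.0, -2.0], [0.01, 5.0]]

scores = taylor_channel_importance(toy_weight, toy_grad)
least_important = min(range(len(scores)), key=lambda i: scores[i])
print(scores, least_important)
```

In this toy example the channel with the largest weight norm gets the lowest Taylor score, because its gradients are tiny where its weights are large. A pure magnitude criterion would rank it as the most important one.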

In [45]:
def calibrate_model(model, train_loader, device):

    model.zero_grad()
    print("Accumulating gradients for taylor pruning...")
    pbar = tqdm(enumerate(train_loader), total=len(train_loader))
    for idx, batch in pbar:
        imgs = batch['pixel_values'].to(device)
        lbls = batch['labels'].to(device)
        loss = model(
                pixel_values=imgs, 
                labels=lbls
            ).loss
        
        loss.backward()
    return model

In [46]:
# Note that creating the pruner and applying the pruning are split into separate functions.
def apply_taylor_pruning(pruner):
    for g in pruner.step(interactive=True):
        g.prune()
    return None

In [47]:
pruned_model = calibrate_model(pruned_model, train_dataloader, "cuda")
apply_taylor_pruning(pruner)

Accumulating gradients for taylor pruning...


100%|██████████| 3443/3443 [09:45<00:00,  5.88it/s]


In [48]:
# Check whether the pruned network still runs
pruned_model(input_example)

RuntimeError: shape '[1, 16384, 1, 32]' is invalid for input of size 131072

Try the same fix as for L2 pruning.

In [49]:
# Make sure the model runs after the fix
pruned_model = fix_attention_layer(pruned_model)
pruned_model(input_example)

SegformerEfficientSelfAttention(
  (query): Linear(in_features=8, out_features=8, bias=True)
  (key): Linear(in_features=8, out_features=8, bias=True)
  (value): Linear(in_features=8, out_features=8, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (sr): Conv2d(8, 8, kernel_size=(8, 8), stride=(8, 8))
  (layer_norm): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
)
num_heads: 1 head_dims: 32 all_head_size: 32 =>
num_heads: 1 head_dims: 8 all_head_size: 8
SegformerEfficientSelfAttention(
  (query): Linear(in_features=16, out_features=16, bias=True)
  (key): Linear(in_features=16, out_features=16, bias=True)
  (value): Linear(in_features=16, out_features=16, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (sr): Conv2d(16, 16, kernel_size=(4, 4), stride=(4, 4))
  (layer_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
)
num_heads: 2 head_dims: 32 all_head_size: 64 =>
num_heads: 2 head_dims: 8 all_head_size: 16
SegformerEfficientSelfAttention(
  (query): Line

SemanticSegmenterOutput(loss=None, logits=tensor([[[[-1.2204, -1.2700, -1.3748,  ..., -1.0825, -1.0776, -1.0892],
          [-1.2702, -1.1465, -1.2638,  ..., -1.2320, -0.9118, -0.9954],
          [-1.1736, -1.1875, -1.3741,  ..., -0.6371, -0.8822, -0.8666],
          ...,
          [-0.8053, -0.8706, -1.1042,  ..., -0.5968, -1.2919, -0.6706],
          [-0.9656, -0.9587, -1.3761,  ..., -0.7648, -0.9936, -0.6254],
          [-0.9083, -1.0362, -1.3280,  ..., -0.7713, -1.0960, -0.7790]],

         [[ 1.2403,  1.2960,  1.3943,  ...,  1.0853,  1.1161,  1.0943],
          [ 1.2749,  1.1636,  1.3368,  ...,  1.2758,  1.0029,  1.2188],
          [ 1.1994,  1.2574,  1.4077,  ...,  0.7134,  0.9193,  0.9919],
          ...,
          [ 0.8889,  0.9950,  1.1544,  ...,  0.7167,  1.4545,  0.7343],
          [ 1.0120,  1.0776,  1.4508,  ...,  0.9096,  1.0147,  0.7238],
          [ 0.9933,  1.1896,  1.5637,  ...,  0.8353,  1.1530,  0.8827]]]],
       device='cuda:0', grad_fn=<ConvolutionBackward0>), hi

In [50]:
# Estimate the computational cost of the resulting model
ops, params = tp.utils.count_ops_and_params(pruned_model, input_example)
print(f"Distilled model complexity (After taylor pruning): {ops/1e6} MMAC, {params/1e6} M params")

Distilled model complexity (After taylor pruning): 4483.917952 MMAC, 0.42249 M params


Fine-tune the model and compare the accuracies.

In [51]:
train_params.n_epochs=1
train_params.lr=12e-5

In [52]:
save_dir = '../../assignment2/runs/taylor_equal_pruning'
tb_writer = SummaryWriter(save_dir)

In [53]:
train(
    teacher_model=baseline_model,
    student_model=pruned_model,
    train_params=train_params,
    student_teacher_attention_mapping=student_teacher_attention_mapping,
    tb_writer=tb_writer,
    save_dir=save_dir,
    id2label=id2label
)

train_params.n_epochs=5
train_params.lr=18e-5

train(
    teacher_model=baseline_model,
    student_model=pruned_model,
    train_params=train_params,
    student_teacher_attention_mapping=student_teacher_attention_mapping,
    tb_writer=tb_writer,
    save_dir=save_dir,
    id2label=id2label
)

total loss: -67187.289: 100%|██████████| 1148/1148 [10:58<00:00,  1.74it/s]

Mean_iou: 0.8826069606563256
Mean accuracy: 0.936557109570126

total loss: -67636.344: 100%|██████████| 1148/1148 [11:04<00:00,  1.73it/s]

Mean_iou: 0.9000582493453515
Mean accuracy: 0.9458899839746695

total loss: -66479.633: 100%|██████████| 1148/1148 [10:54<00:00,  1.76it/s]

Mean_iou: 0.9254577826363342
Mean accuracy: 0.9600806486478971

total loss: -65710.164: 100%|██████████| 1148/1148 [10:24<00:00,  1.84it/s]

Mean_iou: 0.9379602108055853
Mean accuracy: 0.9675411037730601

total loss: -66448.703: 100%|██████████| 1148/1148 [10:22<00:00,  1.85it/s]

Mean_iou: 0.9464933973183418
Mean accuracy: 0.9722341138546797

total loss: -67671.445: 100%|██████████| 1148/1148 [11:11<00:00,  1.71it/s]

Mean_iou: 0.9437714826851367
Mean accuracy: 0.9703885253869773


In [None]:
'''
Experiments Hyperparameter Tuning Notebook

Model          Params (M)  Mean IoU  Mean acc
Baseline       3.715       0.9825    0.9910
Distilled      2.298       0.9647    0.9817
MagPruned      0.402       0.9319    0.9641
TaylorPruned   0.422       0.9438    0.9704
'''
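
For the comparison itself, a short sketch that turns the final mean IoU values from the logs into gaps relative to the baseline, in percentage points. The numbers are copied from the outputs in this notebook; the dictionary keys are just labels.

```python
baseline_iou = 0.9825463313870122
final_iou = {
    "distilled": 0.9646906777392876,
    "magnitude-pruned (fine-tuned)": 0.931900305332622,
    "taylor-pruned (fine-tuned)": 0.9437714826851367,
}

# Gap to the baseline for each model, as a fraction of 1.0
gaps = {name: baseline_iou - iou for name, iou in final_iou.items()}
for name, gap in gaps.items():
    print(f"{name:32s} mean IoU gap to baseline: {100 * gap:.2f} pp")
```

Note that the comparison is not strictly like-for-like: the magnitude-pruned model also had its attention heads pruned (0.402234 M parameters), while the Taylor-pruned model did not (0.42249 M parameters).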