# Прунинг [SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer) (50 баллов)

Будем прунить [SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer) для [задачи сегментации людей](https://www.kaggle.com/datasets/laurentmih/aisegmentcom-matting-human-datasets).

## Скачаем вспомогательный код и чекпоинт бейзлайна (не то же, что в первой домашке)

In [None]:
# !wget -O hw_files_2.zip 'https://www.dropbox.com/scl/fi/66vn2n3p2nb1tjzs2jmog/hw_files_2.zip?rlkey=0je4fwxakn3zb3mqsewhkdjc8&dl=0'
# !unzip hw_files_2.zip

### Скачаем датасет (если остался с первой домашки, можно переиспользовать)

In [None]:
# https://drive.google.com/file/d/1YOEDzZvhLb2DS1Yn7p7MSs41ou3ZBXUq/view?usp=sharing
# !unzip matting_human_dataset.zip

### Установим библиотеки

Эти из прошлой домашки:

In [None]:
!pip install torch transformers datasets tensorboard pillow

А эти новые:

In [None]:
!pip install torch_pruning

In [None]:
import os

import typing
import torch

from copy import deepcopy
from datasets import load_metric
from torch import nn
from torch.nn import functional as F
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm.auto import tqdm

# utils у нас появились при скачивании вспомогательного кода. При желании можно в них провалиться-поизучать
from utils.data import init_dataloaders
from utils.model import evaluate_model
from utils.model import init_model_with_pretrain

from torch import nn
from transformers.models.segformer.modeling_segformer import SegformerLayer, SegformerEfficientSelfAttention

import torch_pruning as tp

In [None]:
# Checkpoint paths: the baseline model and the distilled model (presumably produced by the earlier distillation homework).
baseline_path = 'runs/baseline_ckpt.pth'
distilled_ckpt = 'runs/distillation/ckpt_2.pth'

In [None]:
# Mapping between class indices and class names, plus the inverse lookup.
id2label = dict(enumerate(("background", "human")))
label2id = {label: index for index, label in id2label.items()}

Создадим лоадеры:

In [None]:
# Build the train/validation loaders with the helper from utils.data; the dataset root is the current directory.
train_dataloader, valid_dataloader = init_dataloaders(
    root_dir=".",
    batch_size=8,
    num_workers=8,
)

Создадим baseline модель:

In [None]:
# Build the model from the baseline checkpoint and move it to the GPU.
baseline_model = init_model_with_pretrain(label2id=label2id, id2label=id2label, pretrain_path=baseline_path).cuda()

И сразу отвалидируем:

In [None]:
# Evaluate the baseline on the validation loader (the metrics are shown as the cell output).
evaluate_model(baseline_model, valid_dataloader, id2label)

Создадим модель после дистилляции:

In [None]:
# Build the distilled model from its checkpoint; it is moved to the GPU in the next cell.
distilled_model = init_model_with_pretrain(label2id=label2id, id2label=id2label, pretrain_path=distilled_ckpt)

In [None]:
# Move both models to the GPU; the trailing semicolon suppresses the notebook's cell output.
distilled_model.cuda()
baseline_model.cuda();

Проверим точность:

In [None]:
# Evaluate the distilled model on the same validation loader.
evaluate_model(distilled_model, valid_dataloader, id2label)

Оценим вычислительную сложность и количество параметров моделей:

In [None]:
# Random 1x3x512x512 tensor on the GPU; reused below for op counting and as the pruners' example input.
input_example = torch.rand(1,3,512,512, device="cuda")

In [None]:
# Count operations and parameters of the baseline on the example input; both are printed in millions.
ops, params = tp.utils.count_ops_and_params(baseline_model, input_example)
print(f"Baseline model complexity: {ops/1e6} MMAC, {params/1e6} M params")

In [None]:
# Same measurement for the distilled model, to compare against the baseline numbers.
ops, params = tp.utils.count_ops_and_params(distilled_model, input_example)
print(f"Distilled model complexity: {ops/1e6} MMAC, {params/1e6} M params")

Проверим, что в модели после дистилляции в каждом блоке (block) ровно по одному SegformerLayer:

In [None]:
# Bare expression: the notebook prints the module tree as the cell output.
distilled_model

## Magnitude pruning

In [None]:
# Magnitude criterion with p=2 (L2 norm) and "mean" reduction across each group.
l2_importance = tp.importance.MagnitudeImportance(p=2, group_reduction="mean")

# The final classifier is excluded from pruning.
ignored_layers = [
    layer
    for layer_name, layer in distilled_model.named_modules()
    if layer_name == "decode_head.classifier"
]

pruner = tp.pruner.MagnitudePruner(
    model=distilled_model,
    example_inputs=input_example,
    # Local (non-global) pruning: the same ratio is assigned to every layer.
    global_pruning=False,
    # Importance criterion used to select what gets removed.
    importance=l2_importance,
    # Reach the target ratio in a single step.
    iterative_steps=1,
    # Remove 75% of the channels.
    pruning_ratio=0.75,
    ignored_layers=ignored_layers,
)

In [None]:
# Run the pruning (the pruner above was configured with iterative_steps=1).
pruner.step()

In [None]:
# Проверим, запускается ли наша запруненная сеть
# distilled_model(input_example);

In [None]:
# Analyse the error log and work out why the model stops running after pruning.
# Hint: it is related to the attention layer and the per-head size.

# The attention module keeps `all_head_size` / `attention_head_size` as plain attributes,
# so they go stale once the projections are pruned. Re-derive them from the pruned
# `query` projection instead of dividing by a hard-coded 4: this follows whatever ratio
# the pruner actually applied and stays correct if the cell is executed more than once.
for module in distilled_model.modules():
    if isinstance(module, SegformerEfficientSelfAttention):
        module.all_head_size = module.query.out_features
        module.attention_head_size = module.all_head_size // module.num_attention_heads

In [None]:
# Make sure the pruned model runs on the example input.
distilled_model(input_example)

In [None]:
# Re-measure operations and parameters after channel pruning (printed in millions).
ops, params = tp.utils.count_ops_and_params(distilled_model, input_example)
print(f"Distilled model complexity (After magnitude pruning): {ops/1e6} MMAC, {params/1e6} M params")

In [None]:
# Remove whole attention heads in encoder stages 1-3; each list holds the head indices to drop.
# NOTE(review): presumably this leaves a single head per stage (assuming 2/5/8 heads) — confirm against the model config.
distilled_model.segformer.encoder.block[1][0].attention.prune_heads([1])
distilled_model.segformer.encoder.block[2][0].attention.prune_heads([1,2,3,4])
distilled_model.segformer.encoder.block[3][0].attention.prune_heads([1,2,3,4,5,6,7])

In [None]:
# Re-measure after the attention heads were removed. The label differs from the
# previous measurement so that the two printed lines can be told apart.
ops, params = tp.utils.count_ops_and_params(distilled_model, input_example)
print(f"Distilled model complexity (After magnitude + head pruning): {ops/1e6} MMAC, {params/1e6} M params")

##  Дообучение запруненной модели

In [None]:
import torch
from torch import nn
import typing

from dataclasses import dataclass
from datasets import load_metric
from utils.data import init_dataloaders
from tqdm.auto import tqdm
import torch.nn.functional as F
from utils.model import evaluate_model


# Module-level KL criterion with PyTorch's default "mean" reduction (an average over
# every element, not a per-distribution KL). Kept for code that refers to it by name.
kl_loss = nn.KLDivLoss()

def calc_last_layer_loss(student_logits, teacher_logits, temperature, weight):
    """Return the weighted, temperature-scaled KL loss between student and teacher logits.

    Args:
        student_logits: student logits with the class axis at dim=1
            (assumes the (batch, num_classes, height, width) layout of segmentation logits).
        teacher_logits: teacher logits of the same shape.
        temperature: softmax temperature; the loss is multiplied by ``temperature ** 2``
            to compensate for the softened distributions.
        weight: scalar multiplier applied to the resulting loss.

    Returns:
        A scalar tensor.
    """
    # Normalise over the class axis (dim=1); dim=-1 would normalise over image width.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
    # Same "mean" reduction as `kl_loss`, so the loss scale is unchanged.
    distillation_loss = F.kl_div(student_log_probs, teacher_probs, reduction="mean")
    return weight * distillation_loss * temperature ** 2

def calc_intermediate_layers_loss(student_attentions, teacher_attentions, weights, student_teacher_attention_mapping):
    """Return the weighted sum of KL losses between mapped student/teacher attention maps.

    Args:
        student_attentions: indexable collection of student attention probabilities.
        teacher_attentions: indexable collection of teacher attention probabilities.
        weights: per-pair multipliers, indexed by the position of the pair in the mapping.
        student_teacher_attention_mapping: dict ``{student_index: teacher_index}``.

    Returns:
        A scalar tensor, or the int ``0`` when the mapping is empty.
    """
    intermediate_kl_loss = 0
    for i, (stud_attn_idx, teach_attn_idx) in enumerate(student_teacher_attention_mapping.items()):
        student_probs = student_attentions[stud_attn_idx]
        # A probability of exactly 0 would give log(0) = -inf and an inf/NaN loss,
        # so clamp to the smallest positive normal value of the dtype first.
        min_prob = torch.finfo(student_probs.dtype).tiny
        student_log_probs = torch.log(torch.clamp(student_probs, min=min_prob))
        # Functional equivalent of the default nn.KLDivLoss() ("mean" reduction).
        intermediate_kl_loss += weights[i] * F.kl_div(
            student_log_probs,
            teacher_attentions[teach_attn_idx],
            reduction="mean",
        )
    return intermediate_kl_loss

@dataclass
class TrainParams:
    """Hyperparameters read by ``train``."""

    # Number of passes over the training loader.
    n_epochs: int
    # Learning rate given to the AdamW optimizer.
    lr: float
    # Batch size and worker count forwarded to init_dataloaders.
    batch_size: int
    n_workers: int
    # Device both the teacher and the student are moved to.
    device: torch.device

    # Softmax temperature passed to calc_last_layer_loss.
    temperature: float
    # Multiplier for the student's own loss on the labels.
    loss_weight: float
    # Passed to calc_last_layer_loss as its `weight` argument.
    last_layer_loss_weight: float
    # Passed to calc_intermediate_layers_loss as its `weights` argument.
    intermediate_layers_weights: typing.Tuple[float, float, float, float]

    # you may want to add more fields in your own experiments

def train(
    teacher_model,
    student_model,
    train_params: TrainParams,
    student_teacher_attention_mapping,
    tb_writer,
    save_dir,
):
    """Fine-tune ``student_model`` on the labels while distilling from ``teacher_model``.

    Each step combines the student's own loss on the labels, a KL loss on the output
    logits and a KL loss on the mapped attention maps. After every epoch a checkpoint
    is written to ``save_dir`` and the student is evaluated on the validation loader.

    Args:
        teacher_model: frozen model that provides target logits and attentions.
        student_model: model being trained; updated in place.
        train_params: hyperparameters, see ``TrainParams``.
        student_teacher_attention_mapping: dict ``{student_attention_index: teacher_attention_index}``.
        tb_writer: TensorBoard writer; losses are logged per step under the tags
            ``loss``, ``total_loss``, ``last_layer_loss`` and ``intermediate_loss``,
            evaluation metrics per epoch under ``eval_<metric>``.
        save_dir: directory for the ``ckpt_<epoch>.pth`` files (created if missing).
    """
    teacher_model.to(train_params.device)
    student_model.to(train_params.device)

    # The teacher only produces targets, so it stays in eval mode for the whole run.
    teacher_model.eval()

    train_dataloader, valid_dataloader = init_dataloaders(
        root_dir=".",
        batch_size=train_params.batch_size,
        num_workers=train_params.n_workers,
    )

    # Checkpoints are written after every epoch; do not rely on the directory existing.
    os.makedirs(save_dir, exist_ok=True)

    optimizer = torch.optim.AdamW(student_model.parameters(), lr=train_params.lr)
    step = 0
    for epoch in range(train_params.n_epochs):
        pbar = tqdm(train_dataloader)
        for batch in pbar:
            # Re-enable train mode every step: the end-of-epoch evaluation may have changed it.
            student_model.train()
            pixel_values = batch['pixel_values'].to(train_params.device)
            labels = batch['labels'].to(train_params.device)

            optimizer.zero_grad()

            # Student forward pass: its own loss on the labels plus logits and attentions.
            student_outputs = student_model(pixel_values=pixel_values, labels=labels, output_attentions=True)
            loss, student_logits = student_outputs.loss, student_outputs.logits

            # The teacher is not trained, so no gradients are needed for its forward pass.
            with torch.no_grad():
                teacher_output = teacher_model(pixel_values=pixel_values, labels=labels, output_attentions=True)

            last_layer_loss = calc_last_layer_loss(
                student_logits,
                teacher_output.logits,
                train_params.temperature,
                train_params.last_layer_loss_weight,
            )

            student_attentions, teacher_attentions = student_outputs.attentions, teacher_output.attentions

            intermediate_layer_loss = calc_intermediate_layers_loss(
                student_attentions,
                teacher_attentions,
                train_params.intermediate_layers_weights,
                student_teacher_attention_mapping,
            )

            total_loss = loss * train_params.loss_weight + last_layer_loss
            if intermediate_layer_loss is not None:
                total_loss = total_loss + intermediate_layer_loss

            step += 1

            total_loss.backward()
            optimizer.step()
            pbar.set_description(f'total loss: {total_loss.item():.3f}')

            for loss_name, loss_value in (
                ('loss', loss),
                ('total_loss', total_loss),
                ('last_layer_loss', last_layer_loss),
                ('intermediate_loss', intermediate_layer_loss),
            ):
                if loss_value is None:  # attention distillation switched off
                    continue
                tb_writer.add_scalar(
                    # Use the actual name so that every loss gets its own curve.
                    tag=loss_name,
                    # float() accepts both 0-dim tensors and plain numbers.
                    scalar_value=float(loss_value),
                    global_step=step,
                )

        # After structural changes (pruning) always store the whole model object as well,
        # so that it can be reloaded later if needed.
        torch.save(
            {
                'model': student_model,
                'state_dict': student_model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            },
            f'{save_dir}/ckpt_{epoch}.pth',
        )

        eval_metrics = evaluate_model(student_model, valid_dataloader, id2label)

        for metric_key, metric_value in eval_metrics.items():
            # Only float-valued metrics are logged; everything else is skipped.
            if not isinstance(metric_value, float):
                continue
            tb_writer.add_scalar(
                tag=f'eval_{metric_key}',
                scalar_value=metric_value,
                global_step=epoch,
            )



In [None]:
# Fine-tuning settings for the magnitude-pruned model; falls back to the CPU when CUDA is unavailable.
train_params = TrainParams(
    n_epochs=10,
    lr=1e-4,
    batch_size=16,
    n_workers=8,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    temperature=10,
    loss_weight=0.5,
    last_layer_loss_weight=0.5,
    intermediate_layers_weights=(0.5, 0.5, 0.5, 0.5),
)

In [None]:
# One directory holds both the TensorBoard logs and the checkpoints of this run.
save_dir = 'runs/magnitude_equal_pruning'
tb_writer = SummaryWriter(save_dir)

In [None]:
# Student attention index -> teacher attention index, consumed by calc_intermediate_layers_loss.
student_teacher_attention_mapping = {0: 1, 1: 3, 2: 5, 3: 7}

In [None]:
# Fine-tune a deep copy of the pruned model, so `distilled_model` itself is not modified by training.
train(
    teacher_model=baseline_model,
    student_model=deepcopy(distilled_model),
    train_params=train_params,
    student_teacher_attention_mapping=student_teacher_attention_mapping, # student index -> teacher index
    tb_writer=tb_writer,
    save_dir=save_dir,
)

## Taylor pruning

In [None]:
# Rebuild the distilled model from its checkpoint so that Taylor pruning starts from the unpruned weights.
distilled_model = init_model_with_pretrain(label2id=label2id, id2label=id2label, pretrain_path=distilled_ckpt).cuda()

In [None]:
# Taylor importance criterion with the library's default settings.
taylor_criteria = tp.importance.GroupTaylorImportance()

# The final classifier is excluded from pruning.
ignored_layers = [
    layer
    for layer_name, layer in distilled_model.named_modules()
    if layer_name == "decode_head.classifier"
]

pruner = tp.pruner.MetaPruner(
    distilled_model,
    example_inputs=input_example,
    importance=taylor_criteria,
    pruning_ratio=0.75,
    global_pruning=False,
    ignored_layers=ignored_layers,
)

In [None]:
# The Taylor criterion reads parameter gradients, so accumulate them over the training
# loader before pruning. No optimizer step is taken: the weights stay unchanged.
distilled_model.train()
# Start from clean gradients, otherwise re-running this cell would add to stale ones.
distilled_model.zero_grad()
for batch in tqdm(train_dataloader):
    pixel_values = batch["pixel_values"].to("cuda")
    labels = batch["labels"].to("cuda")

    # Forward with labels so the model returns the loss; backward adds to the .grad buffers.
    outputs = distilled_model(pixel_values=pixel_values, labels=labels)
    loss = outputs.loss

    loss.backward()

In [None]:
# In interactive mode the pruner yields the groups and each one is pruned explicitly.
for group in pruner.step(interactive=True):
    group.prune()

In [None]:
# The attention module keeps `all_head_size` / `attention_head_size` as plain attributes,
# so they go stale once the projections are pruned. Re-derive them from the pruned
# `query` projection instead of dividing by a hard-coded 4: this follows whatever ratio
# the pruner actually applied and stays correct if the cell is executed more than once.
for module in distilled_model.modules():
    if isinstance(module, SegformerEfficientSelfAttention):
        module.all_head_size = module.query.out_features
        module.attention_head_size = module.all_head_size // module.num_attention_heads

In [None]:
# Re-measure operations and parameters after Taylor pruning (printed in millions).
ops, params = tp.utils.count_ops_and_params(distilled_model, input_example)
print(f"Distilled model complexity (After taylor pruning): {ops/1e6} MMAC, {params/1e6} M params")

In [None]:
# Fine-tuning settings for the Taylor-pruned model; falls back to the CPU when CUDA is unavailable.
train_params = TrainParams(
    n_epochs=5,
    lr=1e-4,
    batch_size=16,
    n_workers=8,
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    temperature=10,
    loss_weight=0.5,
    last_layer_loss_weight=0.5,
    intermediate_layers_weights=(0.5, 0.5, 0.5, 0.5),
)

In [None]:
# Separate directory for the TensorBoard logs and checkpoints of the Taylor-pruning run.
save_dir = 'runs/taylor_equal_pruning'
tb_writer = SummaryWriter(save_dir)

In [None]:
# Student attention index -> teacher attention index, consumed by calc_intermediate_layers_loss.
student_teacher_attention_mapping = {0: 1, 1: 3, 2: 5, 3: 7}

In [None]:
# Fine-tune a deep copy of the Taylor-pruned model, so `distilled_model` itself is not modified by training.
train(
    teacher_model=baseline_model,
    student_model=deepcopy(distilled_model),
    train_params=train_params,
    student_teacher_attention_mapping=student_teacher_attention_mapping, # student index -> teacher index
    tb_writer=tb_writer,
    save_dir=save_dir,
)