In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import math
from torch.nn.utils import prune
import tqdm.notebook as tqdm
from functools import partial
import torchvision
from collections import OrderedDict
import pandas as pd
from scipy.stats import norm
import random

import matplotlib.pyplot as plt
%matplotlib inline

def set_global_seed(seed: int) -> None:
    """
    Seed every random number generator this notebook relies on.

    Covers Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices). cuDNN autotuning is switched off and its deterministic
    mode switched on.

    :param int seed: Value handed to each generator.
    """
    # cuDNN flags first; they are independent of the seeding calls below.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)


set_global_seed(42)

In [None]:
# Mini-batch size shared by the train and test data loaders.
BATCH_SIZE = 128
# Prefer the first CUDA device when one is available; otherwise run on the CPU.
DEVICE = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')

print(type(DEVICE), DEVICE)

## Загрузка и обработка данных

In [None]:
# Training-time preprocessing: two random augmentations followed by tensor
# conversion and per-channel normalization. Normalize must stay last, because a
# later cell reads its mean/std through `transform.transforms[-1]`.
train_augmentations = [
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomCrop(size=32, padding=4),
]
# NOTE(review): these look like the usual ImageNet channel statistics rather
# than CIFAR-10 ones — confirm this is intended.
to_normalized_tensor = [
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = torchvision.transforms.Compose(train_augmentations + to_normalized_tensor)

In [None]:
# Evaluation preprocessing: no augmentation, only tensor conversion and the
# same normalization constants as the training pipeline.
eval_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Both splits are stored in (and downloaded to) the current directory.
cifar_location = dict(root='./', download=True)
ds_train = torchvision.datasets.CIFAR10(train=True, transform=transform, **cifar_location)
ds_test = torchvision.datasets.CIFAR10(train=False, transform=eval_transform, **cifar_location)
ds_train, ds_test

In [None]:
# Read the normalization constants back from the last step of the training pipeline.
normalize_step = transform.transforms[-1]
tmean = np.array(normalize_step.mean)
tstd = np.array(normalize_step.std)

# Undo Normalize with another Normalize: if y = (x - mean) / std, then
# x = (y - (-mean / std)) / (1 / std). The result is converted to a PIL image
# so it can be passed straight to `imshow`.
denormalize = torchvision.transforms.Normalize(mean=-tmean / tstd, std=1.0 / tstd)
inverse_transform = torchvision.transforms.Compose([
    denormalize,
    torchvision.transforms.ToPILImage(),
])

In [None]:

# Show five random samples per split: top row is the training set (with
# augmentation applied), bottom row is the test set.
fig, axes = plt.subplots(2, 5, figsize=(13, 6))

for row, dataset in enumerate((ds_train, ds_test)):
    sample_ids = np.random.randint(0, len(dataset), size=5)
    for col, sample_id in enumerate(sample_ids):
        image, label = dataset[sample_id]
        cell = axes[row, col]
        cell.imshow(inverse_transform(image))
        cell.set_title(f'Метка: {label} -> {dataset.classes[label]}')

for row, caption in enumerate(('Обучающая выборка', 'Тестовая выборка')):
    axes[row, 0].set_ylabel(caption)

fig.tight_layout()
plt.show()

In [None]:
# Settings common to both loaders; only the training loader reshuffles its data.
loader_settings = dict(batch_size=BATCH_SIZE, num_workers=2)

dl_train = torch.utils.data.DataLoader(dataset=ds_train, shuffle=True, **loader_settings)
dl_test = torch.utils.data.DataLoader(dataset=ds_test, shuffle=False, **loader_settings)

In [None]:
class ConvNet(torch.nn.Module):
    """
    VGG-style convolutional classifier.

    The feature extractor is built from ``cfg``; the classifier is a three-layer
    MLP that expects exactly 512 flattened features. The five stride-2 poolings
    in ``cfg`` must therefore reduce the input to 1x1 (e.g. 32x32 images).
    """
    # Integers are the output channels of a 3x3 convolution (followed by an
    # optional BatchNorm and a ReLU); 'M' is a 2x2 max pooling with stride 2.
    cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']

    def __init__(self, n_classes=10, use_batchnorm=False, dropout_p=0.0):
        '''
        :param int n_classes: Number of output logits (classes)
        :param bool use_batchnorm: Whether to insert BatchNorm after every convolution
        :param float dropout_p: Probability used by the Dropout layers of the classifier
        '''
        super().__init__()

        self.n_classes = n_classes

        # Running "current number of channels" while the feature stack is built;
        # after construction it holds the channel count of the last convolution.
        self.in_channels = 3
        self.features = torch.nn.Sequential()
        for cfg_item in self.cfg:
            if isinstance(cfg_item, int):
                self.features.append(torch.nn.Conv2d(self.in_channels,
                                                     out_channels=int(cfg_item),
                                                     kernel_size=3,
                                                     padding=1))
                if use_batchnorm:
                    self.features.append(torch.nn.BatchNorm2d(int(cfg_item)))
                self.features.append(torch.nn.ReLU(inplace=True))
                self.in_channels = int(cfg_item)
            elif cfg_item == "M":
                self.features.append(torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2))

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=512, out_features=512),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(p=dropout_p),
            torch.nn.Linear(in_features=512, out_features=512),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(p=dropout_p),
            # Fix: honour `n_classes` (the output size used to be hard-coded to 10).
            torch.nn.Linear(in_features=512, out_features=n_classes)
        )
        # Re-initialise the convolutions only: zero bias and normally distributed
        # weights with std = sqrt(2 / (k_h * k_w * out_channels)).
        # Linear layers keep the initialisation PyTorch gave them.
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                m.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute class logits.

        :param torch.Tensor x: Image batch; it must flatten to 512 features after
            the convolutional stack
        :returns: Tensor with one row per input image and ``n_classes`` columns
        """
        x = self.features(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [None]:
def test_model(model, train_dataloader) -> float:
    """
    Measure the classification accuracy of ``model`` over a data loader.

    The model is moved to the notebook-level ``DEVICE`` and left in eval mode.
    Accuracy is the number of correct predictions divided by the number of
    samples. The previous mean of per-batch accuracies over-weighted a shorter
    final batch.

    :param torch.nn.Module model: Network to evaluate
    :param train_dataloader: Loader yielding ``(images, labels)`` batches; despite
        the name, any split can be passed
    :returns: Accuracy in percent
    :raises ZeroDivisionError: If the loader yields no samples
    """
    model.to(DEVICE)
    model.eval()
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for images, labels in tqdm.tqdm(train_dataloader, total=len(train_dataloader), leave=False):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            predictions = model(images).argmax(dim=1)
            n_correct += int((predictions == labels).sum())
            n_total += labels.shape[0]
    return 100.0 * n_correct / n_total

In [None]:
def count_parameters(net) -> int:
    """
    Print weight sparsity per layer and overall, and count the weights.

    Only the ``weight`` tensors of ``Conv2d`` and ``Linear`` modules are
    considered; biases and other layer types are ignored. A weight counts as
    sparse when it is exactly zero.

    :param torch.nn.Module net: Network to inspect
    :returns: Total number of Conv2d/Linear weight elements (zeros included)
    """
    count_params = 0
    count_zero_params = 0
    for name, module in net.named_modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            weight = module.weight
            n_weights = weight.numel()
            # Count the zeros once and reuse the value for both statistics.
            n_zeros = int(torch.sum(weight == 0))
            count_params += n_weights
            count_zero_params += n_zeros
            sparsity = 100. * float(n_zeros) / float(n_weights)
            print(f'Module len: {len(weight)}; Sparsity in {name}.weight with {n_weights} parameters: {sparsity:0.3f}%')
    # A model without Conv2d/Linear layers has nothing to be sparse: report 0%
    # instead of dividing by zero.
    global_sparsity = 100. * (float(count_zero_params) / float(count_params)) if count_params else 0.0
    print(f'Global sparsity: {global_sparsity:0.3f}%')
    return count_params

In [None]:
# Default pruning hyper-parameters; the experiment cells below reassign the last
# three before building their `Relevance` object.
# NOTE(review): PRUN_PERCENT is not referenced anywhere in the visible notebook — confirm it is still needed.
PRUN_PERCENT = 10
# Passed to Relevance as `n_percent`: share (%) of each conv layer's filters
# that is masked per relevance step and removed per cut.
N_PERCENT = 5
# Passed to Relevance as `m_percent`: share (%) of the dummy features masked per relevance step.
M_PERCENT = 10
# Passed to Relevance as `dummy_percent`: number of dummy features, as a
# percentage of the layer's filter count.
DUMMY_PARAMS_PERCENT = 90

In [None]:
class Relevance():
    """
    Per-filter relevance bookkeeping for structured pruning of ``Conv2d`` layers.

    Workflow, as driven by the training loop of this notebook:

    1. ``update_mask`` draws a random subset of the kept filters (and of the
       dummy features) to switch off for the next step.
    2. ``apply_mask`` installs the matching weight masks through
       ``torch.nn.utils.prune.custom_from_mask``.
    3. ``update_relevance`` adds the change of the criterion to the score of
       every filter / dummy feature that stayed switched on.
    4. ``cut_by_probability`` permanently drops, per layer, the kept filters
       that rank lowest under a normal distribution fitted to the dummy scores.

    Conv layers are always visited in ``model.named_modules()`` order and every
    ``*_list`` attribute is indexed by that layer position.
    """

    def __init__(self, model, dummy_percent, n_percent, m_percent) -> None:
        """
        :param torch.nn.Module model: Network whose ``Conv2d`` layers are tracked
        :param dummy_percent: Number of dummy features per layer, in percent of
            the layer's filter count
        :param n_percent: Percent of a layer's filters masked per step and
            removed per cut
        :param m_percent: Percent of a layer's dummy features masked per step
        """
        self.n_percent = n_percent
        self.module_size_list = []   # number of output filters per conv layer
        self.n_list = []             # filters masked per step / removed per cut
        self.dummy_size_list = []    # number of dummy features per layer
        self.m_list = []             # dummy features masked per step
        self.r_list = []             # accumulated relevance score per filter
        self.q_list = []             # accumulated score per dummy feature

        for module in self._conv_layers(model):
            module_size = module.weight.shape[0]
            self.module_size_list.append(module_size)
            self.n_list.append(round(module_size * (n_percent / 100)))
            dummy_size = round(module_size * (dummy_percent / 100))
            self.dummy_size_list.append(dummy_size)
            self.m_list.append(round(dummy_size * (m_percent / 100)))

            self.r_list.append(torch.zeros(module_size))
            self.q_list.append(torch.zeros(dummy_size))

        # History of criterion values; seeded with 0 so the first delta is defined.
        self.criterion_value_list = [0]

        # Indices of the filters that are still kept, per layer.
        self.feature_indexes_list = [np.arange(n) for n in self.module_size_list]
        # Per-step 0/1 masks over filters / dummy features (empty between steps).
        self.feature_mask_list = []
        self.dummy_feature_mask_list = []

    @staticmethod
    def _conv_layers(model):
        """Yield the ``Conv2d`` modules of ``model`` in ``named_modules()`` order."""
        for _, module in model.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                yield module

    @staticmethod
    def _as_index(values):
        """Turn a sequence of integer positions into an int64 index tensor."""
        return torch.as_tensor(np.asarray(values, dtype=np.int64))

    def _del_mask(self, model):
        """
        Undo ``apply_mask`` without making the pruning permanent.

        For every pruned conv layer the masked ``weight`` attribute is dropped
        and the original ``weight_orig`` parameter is registered as ``weight``
        again, so masked weights get their values back. The per-step filter
        masks are cleared as well.
        """
        tensor_name = 'weight'
        for module in self._conv_layers(model):
            if not prune.is_pruned(module):
                continue
            with torch.no_grad():
                if hasattr(module, tensor_name):
                    delattr(module, tensor_name)
                orig = module._parameters[tensor_name + "_orig"]
                del module._parameters[tensor_name + "_orig"]
                del module._buffers[tensor_name + "_mask"]
                # Remove only the pruning hook of this tensor; unrelated
                # forward pre-hooks registered on the module are left in place.
                for hook_id, hook in list(module._forward_pre_hooks.items()):
                    if isinstance(hook, prune.BasePruningMethod) and hook._tensor_name == tensor_name:
                        del module._forward_pre_hooks[hook_id]
                setattr(module, tensor_name, orig)
        self.feature_mask_list = []
        torch.cuda.empty_cache()

    def update_mask(self, model):
        """
        Draw new random per-step masks for filters and dummy features.

        Any installed mask is removed first. Per layer, ``n`` of the kept filters
        and ``m`` of the dummy features are sampled with ``random.sample`` and
        set to 0; other kept filters / dummy features are 1. Filters that were
        already cut stay 0.
        """
        feature_mask_list = []
        dummy_feature_mask_list = []
        self._del_mask(model)
        for layer_num, _ in enumerate(self._conv_layers(model)):
            kept = list(self.feature_indexes_list[layer_num])
            # Never ask for more filters than are left (random.sample would raise).
            n_off = min(self.n_list[layer_num], len(kept))
            switched_off = random.sample(kept, n_off)
            feature_mask = torch.zeros(self.module_size_list[layer_num])
            feature_mask[self._as_index(kept)] = 1
            feature_mask[self._as_index(switched_off)] = 0
            feature_mask_list.append(feature_mask)

            dummy_indexes = list(np.arange(self.dummy_size_list[layer_num]))
            m_off = min(self.m_list[layer_num], len(dummy_indexes))
            dummy_switched_off = random.sample(dummy_indexes, m_off)
            dummy_feature_mask = torch.ones(self.dummy_size_list[layer_num])
            dummy_feature_mask[self._as_index(dummy_switched_off)] = 0
            dummy_feature_mask_list.append(dummy_feature_mask)

        self.feature_mask_list = feature_mask_list
        self.dummy_feature_mask_list = dummy_feature_mask_list
        torch.cuda.empty_cache()

    def apply_mask(self, model):
        """
        Install a whole-filter weight mask on every conv layer.

        Uses the per-step masks from ``update_mask`` when they exist, otherwise
        the persistent set of kept filters. All weights of a switched-off filter
        are masked to zero via ``prune.custom_from_mask``.
        """
        for layer_num, module in enumerate(self._conv_layers(model)):
            if self.feature_mask_list:
                keep = self.feature_mask_list[layer_num].bool()
            elif self.feature_indexes_list:
                keep = torch.zeros(self.module_size_list[layer_num], dtype=torch.bool)
                keep[self._as_index(self.feature_indexes_list[layer_num])] = True
            else:
                continue
            weight_mask = torch.ones_like(module.weight)
            # Index along dim 0: zero every weight of a switched-off output filter.
            weight_mask[~keep.to(module.weight.device)] = 0
            prune.custom_from_mask(module=module, name='weight', mask=weight_mask)
        torch.cuda.empty_cache()

    def update_relevance(self, criterion_value):
        """
        Accumulate the criterion change into the relevance scores.

        The delta between ``criterion_value`` and the mean of all previously
        recorded values is added to every filter / dummy feature whose current
        per-step mask entry is 1. Requires masks from a preceding
        ``update_mask`` call.

        :param float criterion_value: Criterion (e.g. loss) observed this step
        """
        # The history is only extended after the loop, so the delta is the same
        # for every layer and can be computed once.
        cur_delta = float(criterion_value - np.mean(self.criterion_value_list))
        for i in range(len(self.r_list)):
            self.r_list[i] += cur_delta * self.feature_mask_list[i].flatten()
            self.q_list[i] += cur_delta * self.dummy_feature_mask_list[i].flatten()

        self.criterion_value_list.append(criterion_value)
        torch.cuda.empty_cache()

    def cut_by_probability(self):
        """
        Permanently drop the lowest-ranked kept filters of every layer.

        Per layer, the scores of the kept filters are passed through the CDF of
        a normal distribution whose mean and std are those of the dummy scores.
        The ``n`` filters with the smallest values are removed from
        ``feature_indexes_list``. Both score accumulators are reset to zero.
        """
        for i in range(len(self.r_list)):
            kept = np.asarray(self.feature_indexes_list[i], dtype=np.int64)
            scores = self.r_list[i][self._as_index(kept)].double().numpy()
            mu = float(torch.mean(self.q_list[i]))
            sigma = float(torch.std(self.q_list[i]))
            if np.isfinite(sigma) and sigma > 0:
                ranking = norm.cdf(scores, mu, sigma)
            else:
                # Degenerate dummy distribution (e.g. no relevance updates yet):
                # the CDF would be NaN, so rank by the raw scores instead.
                ranking = scores
            prune_indexes = np.argsort(ranking)[:self.n_list[i]]
            print(len(prune_indexes))
            # `prune_indexes` are positions inside `kept`; kept indices are unique,
            # so deleting by position equals filtering out the pruned values.
            self.feature_indexes_list[i] = np.delete(kept, prune_indexes)

            self.r_list[i] = torch.zeros(self.module_size_list[i])
            self.q_list[i] = torch.zeros(self.dummy_size_list[i])
        torch.cuda.empty_cache()




In [None]:
def training_loop(n_epochs, network, loss_fn, optimizer,scheduler, dl_train, dl_test, device,
                  relevance=None, relevance_steps=21896, cut_every=5474):
    '''
    Train ``network`` while estimating filter relevance and pruning periodically.

    :param int n_epochs: Number of training epochs
    :param torch.nn.Module network: Network to train (already on ``device``)
    :param Callable loss_fn: Loss function
    :param torch.nn.Optimizer optimizer: Optimizer
    :param scheduler: Learning-rate scheduler; stepped once per batch
    :param torch.utils.data.DataLoader dl_train: Data loader of the training split
    :param torch.utils.data.DataLoader dl_test: Data loader of the test split
    :param torch.Device device: Device the computations run on
    :param relevance: Relevance tracker; defaults to the notebook-level
        ``test_relevance`` object
    :param int relevance_steps: Relevance masks are only drawn for global batch
        indices below this value
    :param int cut_every: Filters are cut whenever the global batch index is a
        non-zero multiple of this value
    :returns: Lists of loss and accuracy values on the training and test splits,
        one entry per evaluated epoch (every third epoch and the last one)
    '''
    if relevance is None:
        # Backward compatible default: the notebook-level tracker.
        relevance = test_relevance

    loss_fn.to(device)
    train_losses, test_losses, train_accuracies, test_accuracies = [], [], [], []
    pbar = tqdm.tqdm(range(n_epochs), total=n_epochs, leave=False)
    for epoch in (pbar):
        # Evaluation below switches to eval mode; switch back for every epoch so
        # Dropout / BatchNorm behave correctly during training.
        network.train()

        # Training iteration
        for batch_idx, (images, labels) in enumerate(tqdm.tqdm(dl_train, total=len(dl_train), leave=False)):
            images = images.to(device)
            labels = labels.to(device)

            global_idx = (epoch * len(dl_train) + batch_idx)

            # NOTE(review): `n_percent` doubles as the step period here — confirm intended.
            is_relevance_step = global_idx < relevance_steps and global_idx % relevance.n_percent == 0

            # Relevance steps train with a fresh random filter mask; all other
            # steps train with the persistent mask of the kept filters only.
            if is_relevance_step:
                relevance.update_mask(network)
            else:
                relevance._del_mask(network)
            relevance.apply_mask(network)

            optimizer.zero_grad()
            loss = loss_fn(network(images), labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            if is_relevance_step:
                relevance.update_relevance(float(loss))

            if global_idx != 0 and global_idx % cut_every == 0:
                # Cut filters and print the sparsity before and after.
                print('!' * 100)
                relevance._del_mask(network)
                count_parameters(network)
                relevance.cut_by_probability()
                relevance.apply_mask(network)
                count_parameters(network)
                print('!' * 100)

            torch.cuda.empty_cache()

        # Evaluate the model every 3 epochs and after the last one
        if epoch % 3 == 0 or epoch == n_epochs - 1:
            # Switch the network to inference mode
            network.eval()

            # No gradients are needed for evaluation, so autograd is disabled to
            # speed things up
            with torch.no_grad():
                train_loss, train_accuracy = _evaluate_masked(network, loss_fn, dl_train, device, relevance)
                train_losses.append(train_loss)
                train_accuracies.append(train_accuracy)

                test_loss, test_accuracy = _evaluate_masked(network, loss_fn, dl_test, device, relevance)
                test_losses.append(test_loss)
                test_accuracies.append(test_accuracy)

            pbar.set_description(
                'Loss (Train/Test): {0:.3f}/{1:.3f}. Accuracy, % (Train/Test): {2:.2f}/{3:.2f}\n'.format(
                    train_losses[-1], test_losses[-1], train_accuracies[-1], test_accuracies[-1]
                )
            )

    return train_losses, test_losses, train_accuracies, test_accuracies


def _evaluate_masked(network, loss_fn, dataloader, device, relevance):
    '''
    Evaluate ``network`` on one split with only the persistent pruning mask applied.

    Any mask left over from the last training step is removed first. This also
    clears a per-step random mask, so every batch is scored with the same mask.
    The mask is removed again before returning. The caller is expected to have
    set eval mode and disabled autograd.

    :returns: ``(loss, accuracy)`` as 0-d CPU tensors: the mean of the per-batch
        losses, and the share of correctly classified samples in percent
    '''
    relevance._del_mask(network)
    relevance.apply_mask(network)

    batch_losses = []
    n_correct, n_total = 0, 0
    for images, labels in tqdm.tqdm(dataloader, total=len(dataloader), leave=False):
        images = images.to(device)
        labels = labels.to(device)

        outputs = network(images)

        batch_losses.append(loss_fn(outputs, labels))
        # Count samples rather than averaging per-batch accuracies, so a shorter
        # final batch is not over-weighted.
        n_correct += int((outputs.argmax(dim=1) == labels).sum())
        n_total += labels.shape[0]

    relevance._del_mask(network)

    mean_loss = (sum(batch_losses) / len(batch_losses)).cpu()
    accuracy = torch.tensor(100.0 * n_correct / n_total)
    return mean_loss, accuracy

## 20%

In [None]:
# Settings of this experiment; they are handed to `Relevance` in the next cell.
N_PERCENT = 5
M_PERCENT = 10
DUMMY_PARAMS_PERCENT = 90

# Everything the training loop needs except the model, optimizer and scheduler,
# which are created per run.
fixed_training_args = dict(
    n_epochs=70,
    loss_fn=torch.nn.CrossEntropyLoss(),
    dl_train=dl_train,
    dl_test=dl_test,
    device=DEVICE,
)
train_func = partial(training_loop, **fixed_training_args)

In [None]:
# Fresh network on the target device, SGD with momentum and weight decay, and a
# scheduler that multiplies the learning rate by 0.9 every 1000 scheduler steps.
conv_net = ConvNet().to(DEVICE)
optimizer = torch.optim.SGD(
    params=conv_net.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=1000, gamma=0.9)

# `training_loop` reads this notebook-level object by name.
test_relevance = Relevance(
    model=conv_net,
    dummy_percent=DUMMY_PARAMS_PERCENT,
    n_percent=N_PERCENT,
    m_percent=M_PERCENT,
)

In [None]:
# Run the training loop; the remaining arguments (epochs, loss, data loaders,
# device) were bound in `train_func`. Each returned list holds one value per
# evaluated epoch.
train_losses, test_losses, train_accs, test_accs = train_func(
    network=conv_net,
    optimizer=optimizer,
    scheduler=scheduler
)

In [None]:
# Install the pruning mask on the conv layers, then display the accuracy (%)
# of the masked network on the test loader.
test_relevance.apply_mask(conv_net)
test_model(conv_net, dl_test)

In [None]:
# Apply the pruning mask and print per-layer and global weight sparsity; the
# cell displays the total number of Conv2d/Linear weights.
# NOTE(review): the previous cell already called apply_mask — this call looks redundant if that cell was run.
test_relevance.apply_mask(conv_net)
count_parameters(conv_net)

In [None]:
import time

# NOTE(review): GLOBAL_PATH is not defined anywhere in the visible notebook — it
# must be set (e.g. in the configuration cell) before this cell can run.
# Build the timestamp once so both files of this run share the same suffix, even
# if the minute changes between the two saves.
run_stamp = time.strftime("%d.%m.%Y-%H:%M")
np.save(GLOBAL_PATH + f'/VGG_train_accuracy_prun_20_{run_stamp}.npy', train_accs)
np.save(GLOBAL_PATH + f'/VGG_test_accuracy_prun_20_{run_stamp}.npy', test_accs)

In [None]:
# Font sizes in points. Despite the names, SMALL_SIZE is the largest value and
# is the one used for everything except the figure title.
SMALL_SIZE = 16
MEDIUM_SIZE = 10
BIGGER_SIZE = 8
plt.rc('font', size=SMALL_SIZE)                                # default text size
plt.rc('axes', titlesize=SMALL_SIZE, labelsize=SMALL_SIZE)     # axes title and x/y labels
for tick_group in ('xtick', 'ytick'):
    plt.rc(tick_group, labelsize=SMALL_SIZE)                   # tick labels
plt.rc('legend', fontsize=SMALL_SIZE)                          # legend
plt.rc('figure', titlesize=BIGGER_SIZE)                        # figure title

fig, ax = plt.subplots(1, 1, figsize=(7, 4))
accuracy_curves = (
    (train_accs, "Точность на обучении", dict(marker='.', linestyle='-.')),
    (test_accs, "Точность на тесте", dict(marker='*')),
)
for curve_values, curve_label, curve_style in accuracy_curves:
    # Accuracy is recorded every third epoch, hence the stride on the x axis.
    epochs = np.arange(len(curve_values)) * 3
    ax.plot(epochs, curve_values, label=curve_label, color='red', **curve_style)

ax.set_xlabel("Номер эпохи")
ax.set_ylabel("Точность, %")

ax.grid(True)
ax.legend(loc='lower right')

fig.tight_layout()
plt.show()

## 40%

In [None]:
# Settings of this experiment; only N_PERCENT differs from the previous run.
N_PERCENT = 10
M_PERCENT = 10
DUMMY_PARAMS_PERCENT = 90

# Arguments of the training loop that do not depend on the model.
fixed_training_args = dict(
    n_epochs=70,
    loss_fn=torch.nn.CrossEntropyLoss(),
    dl_train=dl_train,
    dl_test=dl_test,
    device=DEVICE,
)
train_func = partial(training_loop, **fixed_training_args)

# Fresh network, optimizer and scheduler for this run.
conv_net = ConvNet().to(DEVICE)
optimizer = torch.optim.SGD(
    params=conv_net.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=1000, gamma=0.9)

# `training_loop` reads this notebook-level object by name.
test_relevance = Relevance(
    model=conv_net,
    dummy_percent=DUMMY_PARAMS_PERCENT,
    n_percent=N_PERCENT,
    m_percent=M_PERCENT,
)
train_losses, test_losses, train_accs, test_accs = train_func(
    network=conv_net, optimizer=optimizer, scheduler=scheduler
)

In [None]:
# Install the pruning mask on the conv layers and print per-layer and global
# weight sparsity; the cell displays the total number of Conv2d/Linear weights.
test_relevance.apply_mask(conv_net)
count_parameters(conv_net)

In [None]:
import time

# NOTE(review): GLOBAL_PATH is not defined anywhere in the visible notebook — it
# must be set (e.g. in the configuration cell) before this cell can run.
# Build the timestamp once so both files of this run share the same suffix, even
# if the minute changes between the two saves.
run_stamp = time.strftime("%d.%m.%Y-%H:%M")
np.save(GLOBAL_PATH + f'/VGG_train_accuracy_prun_40_{run_stamp}.npy', train_accs)
np.save(GLOBAL_PATH + f'/VGG_test_accuracy_prun_40_{run_stamp}.npy', test_accs)

In [None]:
# Font sizes in points. Despite the names, SMALL_SIZE is the largest value and
# is the one used for everything except the figure title.
SMALL_SIZE = 16
MEDIUM_SIZE = 10
BIGGER_SIZE = 8
plt.rc('font', size=SMALL_SIZE)                                # default text size
plt.rc('axes', titlesize=SMALL_SIZE, labelsize=SMALL_SIZE)     # axes title and x/y labels
for tick_group in ('xtick', 'ytick'):
    plt.rc(tick_group, labelsize=SMALL_SIZE)                   # tick labels
plt.rc('legend', fontsize=SMALL_SIZE)                          # legend
plt.rc('figure', titlesize=BIGGER_SIZE)                        # figure title

fig, ax = plt.subplots(1, 1, figsize=(7, 4))
accuracy_curves = (
    (train_accs, "Точность на обучении", dict(marker='.', linestyle='-.')),
    (test_accs, "Точность на тесте", dict(marker='*')),
)
for curve_values, curve_label, curve_style in accuracy_curves:
    # Accuracy is recorded every third epoch, hence the stride on the x axis.
    epochs = np.arange(len(curve_values)) * 3
    ax.plot(epochs, curve_values, label=curve_label, color='red', **curve_style)

ax.set_xlabel("Номер эпохи")
ax.set_ylabel("Точность, %")

ax.grid(True)
ax.legend(loc='lower right')

fig.tight_layout()
plt.show()

## 60%

In [None]:
# Settings of this experiment; only N_PERCENT differs from the previous runs.
N_PERCENT = 15
M_PERCENT = 10
DUMMY_PARAMS_PERCENT = 90

# Arguments of the training loop that do not depend on the model.
fixed_training_args = dict(
    n_epochs=70,
    loss_fn=torch.nn.CrossEntropyLoss(),
    dl_train=dl_train,
    dl_test=dl_test,
    device=DEVICE,
)
train_func = partial(training_loop, **fixed_training_args)

# Fresh network, optimizer and scheduler for this run.
conv_net = ConvNet().to(DEVICE)
optimizer = torch.optim.SGD(
    params=conv_net.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=1000, gamma=0.9)

# `training_loop` reads this notebook-level object by name.
test_relevance = Relevance(
    model=conv_net,
    dummy_percent=DUMMY_PARAMS_PERCENT,
    n_percent=N_PERCENT,
    m_percent=M_PERCENT,
)
train_losses, test_losses, train_accs, test_accs = train_func(
    network=conv_net, optimizer=optimizer, scheduler=scheduler
)

In [None]:
# Install the pruning mask on the conv layers, then display the accuracy (%)
# of the masked network on the test loader.
test_relevance.apply_mask(conv_net)
test_model(conv_net, dl_test)

In [None]:
# Apply the pruning mask and print per-layer and global weight sparsity; the
# cell displays the total number of Conv2d/Linear weights.
# NOTE(review): the previous cell already called apply_mask — this call looks redundant if that cell was run.
test_relevance.apply_mask(conv_net)
count_parameters(conv_net)

In [None]:
import time

# NOTE(review): GLOBAL_PATH is not defined anywhere in the visible notebook — it
# must be set (e.g. in the configuration cell) before this cell can run.
# Build the timestamp once so both files of this run share the same suffix, even
# if the minute changes between the two saves.
run_stamp = time.strftime("%d.%m.%Y-%H:%M")
np.save(GLOBAL_PATH + f'/VGG_train_accuracy_prun_60_{run_stamp}.npy', train_accs)
np.save(GLOBAL_PATH + f'/VGG_test_accuracy_prun_60_{run_stamp}.npy', test_accs)

In [None]:
# Font sizes in points. Despite the names, SMALL_SIZE is the largest value and
# is the one used for everything except the figure title.
SMALL_SIZE = 16
MEDIUM_SIZE = 10
BIGGER_SIZE = 8
plt.rc('font', size=SMALL_SIZE)                                # default text size
plt.rc('axes', titlesize=SMALL_SIZE, labelsize=SMALL_SIZE)     # axes title and x/y labels
for tick_group in ('xtick', 'ytick'):
    plt.rc(tick_group, labelsize=SMALL_SIZE)                   # tick labels
plt.rc('legend', fontsize=SMALL_SIZE)                          # legend
plt.rc('figure', titlesize=BIGGER_SIZE)                        # figure title

fig, ax = plt.subplots(1, 1, figsize=(7, 4))
accuracy_curves = (
    (train_accs, "Точность на обучении", dict(marker='.', linestyle='-.')),
    (test_accs, "Точность на тесте", dict(marker='*')),
)
for curve_values, curve_label, curve_style in accuracy_curves:
    # Accuracy is recorded every third epoch, hence the stride on the x axis.
    epochs = np.arange(len(curve_values)) * 3
    ax.plot(epochs, curve_values, label=curve_label, color='red', **curve_style)

ax.set_xlabel("Номер эпохи")
ax.set_ylabel("Точность, %")

ax.grid(True)
ax.legend(loc='lower right')

fig.tight_layout()
plt.show()