In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns

import torch
import math
import random
from torch.nn.utils import prune
import tqdm.notebook as tqdm
from functools import partial
import torchvision
from collections import OrderedDict
import pandas as pd
import torch.nn.functional as F

import matplotlib.pyplot as plt
%matplotlib inline

def set_global_seed(seed: int) -> None:
    """
    Seed every random number generator this notebook relies on.

    Calls ``random.seed``, ``np.random.seed``, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all`` with the same value, then configures cuDNN
    with ``deterministic = True`` and ``benchmark = False``.

    :param seed: Value handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # cuDNN flags: request deterministic behaviour and disable benchmark mode.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_global_seed(42)

In [None]:
# Mini-batch size shared by the training and test data loaders.
BATCH_SIZE = 128

# Root directory that checkpoint paths are built from by plain string
# concatenation (GLOBAL_PATH + 'Models/...'), so it must end with a path
# separator. The notebook used this name without defining it anywhere, which
# raises NameError on a fresh kernel.
# NOTE(review): './' is a placeholder default — point it at the directory that
# actually contains 'Models/'.
GLOBAL_PATH = './'

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')

print(type(DEVICE), DEVICE)

## Загрузка и обработка данных

In [None]:
# Training-time preprocessing: random horizontal flip and a random 32x32 crop
# (second RandomCrop argument = 4), followed by tensor conversion and
# per-channel normalisation.
augmentation_steps = [
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomCrop(32, 4),
]
tensor_steps = [
    torchvision.transforms.ToTensor(),
    # Must stay last: a later cell reads mean/std from transform.transforms[-1].
    torchvision.transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
]
transform = torchvision.transforms.Compose(augmentation_steps + tensor_steps)

In [None]:
# Evaluation-time preprocessing: tensor conversion and normalisation only,
# without the random augmentations used for training.
eval_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
])

# Both CIFAR-10 splits use the current directory as their root.
ds_train = torchvision.datasets.CIFAR10(root='./', train=True, download=True, transform=transform)
ds_test = torchvision.datasets.CIFAR10(root='./', train=False, download=True, transform=eval_transform)

# Last expression of the cell: display both dataset summaries.
ds_train, ds_test

In [None]:
# Recover the normalisation statistics from the last step of the training pipeline.
normalize_step = transform.transforms[-1]
tmean = np.array(normalize_step.mean)
tstd = np.array(normalize_step.std)

# Normalize computes (x - mean) / std, so calling it with mean = -m/s and
# std = 1/s maps a normalised tensor y back to y*s + m.
undo_normalize = torchvision.transforms.Normalize(mean=-tmean / tstd, std=1.0 / tstd)
inverse_transform = torchvision.transforms.Compose([
    undo_normalize,
    torchvision.transforms.ToPILImage(),
])

In [None]:
fig, axes = plt.subplots(2, 5, figsize=(13, 6))

# One row of axes per split, five randomly drawn images in each row.
for row_axes, ds in zip(axes, (ds_train, ds_test)):
    sample_indices = np.random.randint(0, len(ds), size=5)
    for ax, kdx in zip(row_axes, sample_indices):
        image, label = ds[kdx]
        # Undo the normalisation so the picture is shown in its original colours.
        ax.imshow(inverse_transform(image))
        ax.set_title(f'Метка: {label} -> {ds.classes[label]}')

axes[0, 0].set_ylabel('Обучающая выборка')
axes[1, 0].set_ylabel('Тестовая выборка')

fig.tight_layout()
plt.show()

In [None]:
# Settings shared by both loaders; only the training loader shuffles its data.
loader_options = dict(batch_size=BATCH_SIZE, num_workers=2)
dl_train = torch.utils.data.DataLoader(dataset=ds_train, shuffle=True, **loader_options)
dl_test = torch.utils.data.DataLoader(dataset=ds_test, shuffle=False, **loader_options)

In [None]:
class BasicBlock(torch.nn.Module):
    """
    Residual block built from two 3x3 convolutions, each followed by batch norm.

    The block input is added to the output of the second batch norm through
    ``shortcut``. The shortcut is an empty ``Sequential`` unless the stride is
    not 1 or the channel count changes, in which case it is a strided 1x1
    convolution followed by batch norm.
    """

    # Channel multiplier: the shortcut and ResNet's bookkeeping treat the block
    # output as expansion * planes channels.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """
        :param int in_planes: Number of input channels.
        :param int planes: Number of channels produced by both convolutions.
        :param int stride: Stride of the first convolution and of the projection shortcut.
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = torch.nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = torch.nn.BatchNorm2d(planes)
        self.conv2 = torch.nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = torch.nn.BatchNorm2d(planes)

        if stride == 1 and in_planes == out_planes:
            # Shapes already match: an empty Sequential returns its input unchanged.
            self.shortcut = torch.nn.Sequential()
        else:
            self.shortcut = torch.nn.Sequential(
                torch.nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                torch.nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Run both conv/BN stages, add the shortcut branch and apply the final ReLU."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        return F.relu(residual + self.shortcut(x))


class Bottleneck(torch.nn.Module):
    """
    Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions, each followed by batch norm.

    The stride is applied by the 3x3 convolution and the last convolution
    produces ``expansion * planes`` channels. The block input is added through
    ``shortcut``, which is an empty ``Sequential`` unless the stride is not 1
    or the channel count changes; then it is a strided 1x1 convolution plus
    batch norm.
    """

    # The block outputs expansion * planes channels.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        """
        :param int in_planes: Number of input channels.
        :param int planes: Width of the two inner convolutions.
        :param int stride: Stride of the 3x3 convolution and of the projection shortcut.
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = torch.nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(planes)
        self.conv2 = torch.nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = torch.nn.BatchNorm2d(planes)
        self.conv3 = torch.nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = torch.nn.BatchNorm2d(out_planes)

        if stride == 1 and in_planes == out_planes:
            # Shapes already match: an empty Sequential returns its input unchanged.
            self.shortcut = torch.nn.Sequential()
        else:
            self.shortcut = torch.nn.Sequential(
                torch.nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                torch.nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Run the three conv/BN stages, add the shortcut branch and apply the final ReLU."""
        out = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = F.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out))
        return F.relu(out + self.shortcut(x))


class ResNet(torch.nn.Module):
    """
    ResNet with a 3x3 stem convolution and four stages of residual blocks.

    The stem keeps the input resolution (stride 1); stages 2-4 start with a
    stride-2 block. The final feature map is average-pooled with a 4x4 window,
    flattened and passed to a linear classifier.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        """
        :param block: Block class with an ``expansion`` attribute, called as ``block(in_planes, planes, stride)``.
        :param num_blocks: Indexable with the number of blocks for each of the four stages.
        :param int num_classes: Size of the classifier output.
        """
        super().__init__()
        # Channel count fed to the next block; _make_layer updates it as blocks are added.
        self.in_planes = 64

        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(64)

        # (planes, stride of the first block) for layer1 .. layer4, registered in order.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, num_blocks[index], stride=stride)
            setattr(self, f'layer{index + 1}', stage)

        self.linear = torch.nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, every following block uses stride 1."""
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return torch.nn.Sequential(*blocks)

    def forward(self, x):
        """Return the classifier output for a batch of images."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = F.avg_pool2d(out, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)


def ResNet18():
    """Build a ResNet of BasicBlock units with 2, 2, 2, 2 blocks in its four stages."""
    return ResNet(block=BasicBlock, num_blocks=[2, 2, 2, 2])


def ResNet34():
    """Build a ResNet of BasicBlock units with 3, 4, 6, 3 blocks in its four stages."""
    return ResNet(block=BasicBlock, num_blocks=[3, 4, 6, 3])


def ResNet50():
    """Build a ResNet of Bottleneck units with 3, 4, 6, 3 blocks in its four stages."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 6, 3])


def ResNet101():
    """Build a ResNet of Bottleneck units with 3, 4, 23, 3 blocks in its four stages."""
    return ResNet(block=Bottleneck, num_blocks=[3, 4, 23, 3])


def ResNet152():
    """Build a ResNet of Bottleneck units with 3, 8, 36, 3 blocks in its four stages."""
    return ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3])


def test():
    """Smoke test: feed one random 3x32x32 input through a new ResNet18 and print the output size."""
    net = ResNet18()
    dummy_batch = torch.randn(1, 3, 32, 32)
    print(net(dummy_batch).size())

test()

In [None]:
# Build a new ResNet18 and overwrite its weights with the saved checkpoint;
# map_location=DEVICE tells torch.load where to place the stored tensors.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
conv_net = ResNet18()
conv_net.load_state_dict(torch.load(GLOBAL_PATH + 'Models/ResNet18_27.11.2022-21_53.pth', map_location=DEVICE))

In [None]:
def test_model(model, train_dataloader) -> float:
    model.to(DEVICE)
    model.eval()
    with torch.no_grad():
        train_accuracies = []
        for images, labels in tqdm.tqdm(train_dataloader, total=len(train_dataloader), leave=False):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            outputs  = model(images)
            train_accuracies.append(torch.sum(outputs.argmax(dim=1) == labels)/labels.shape[0])
        return float((sum(train_accuracies) / len(train_accuracies)).cpu() * 100)

In [None]:
# Accuracy (in percent) of the loaded checkpoint on the test loader, before any pruning.
test_model(conv_net, dl_test)

In [None]:
def count_parameters(net) -> int:
    """
    Print the weight sparsity of every Conv2d/Linear layer and of the network as a whole.

    Only the ``weight`` tensors of those layers are inspected; a weight counts
    as pruned when it is exactly zero.

    :param torch.nn.Module net: Network to inspect.
    :returns: Total number of weights in the inspected layers (0 if there are none).
    """
    count_params = 0
    count_zero_params = 0
    for name, module in net.named_modules():
        if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            continue
        layer_params = module.weight.numel()
        # Count the zeros once and reuse the value for the layer and the global totals.
        layer_zeros = int(torch.sum(module.weight == 0))
        count_params += layer_params
        count_zero_params += layer_zeros
        sparsity = 100. * layer_zeros / layer_params if layer_params else 0.0
        print(f'Sparsity in {name}.weight with {layer_params} parameters: {sparsity:0.3f}%')

    # A network without Conv2d/Linear layers has nothing to report; avoid dividing by zero.
    global_sparsity = 100. * count_zero_params / count_params if count_params else 0.0
    print(f'Global sparsity: {global_sparsity:0.3f}%')
    return count_params

In [None]:
# Per-layer and global weight sparsity of the unpruned checkpoint; the cell value
# is the total number of Conv2d/Linear weights.
count_parameters(conv_net)

## Euclidean norm prune unstructured




In [None]:
# Reload the checkpoint into a new ResNet18 so pruning starts from the saved weights.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
conv_net = ResNet18()
conv_net.load_state_dict(torch.load(GLOBAL_PATH + 'Models/ResNet18_27.11.2022-21_53.pth', map_location=DEVICE))

In [None]:
def prune_euclidean_norm_unstructured(net, amount) -> None:
    """
    Prune the most redundant filters of every Conv2d layer in ``net``.

    For each convolution the pairwise Euclidean distances between its
    flattened filters are computed. Every filter except the last is scored by
    the distance to its nearest *later* filter, and the
    ``round(n_filters * amount)`` filters with the smallest score are zeroed
    through ``prune.custom_from_mask``. Whole filters are masked, so despite
    the function name the resulting sparsity is structured along the first
    (output-filter) dimension of the weight. A heatmap of the distance matrix
    is shown for every layer.

    :param torch.nn.Module net: Network pruned in place.
    :param float amount: Fraction of filters to prune in each layer, in [0, 1].
    :returns: None. The masks live on the modules as ``weight_mask`` buffers.
    :raises ValueError: If ``amount`` is outside [0, 1].
    """
    if not 0.0 <= amount <= 1.0:
        raise ValueError(f'amount must be in [0, 1], got {amount}')

    # Walk the network that was passed in, not a notebook-level global.
    for name, module in net.named_modules():
        if not isinstance(module, torch.nn.Conv2d):
            continue
        dist_matrix, keep_filters = _redundant_filter_mask(module.weight, amount)

        # Zero every weight of the filters selected for pruning.
        mask = torch.ones_like(module.weight)
        mask[~keep_filters.to(mask.device)] = 0
        prune.custom_from_mask(module=module, name='weight', mask=mask)

        _plot_filter_distances(dist_matrix)


def _redundant_filter_mask(weight, amount):
    """Return the upper-triangular distance matrix and a boolean keep-mask (False = prune) for one conv weight."""
    flat_filters = weight.flatten(start_dim=1)
    dist_matrix = torch.cdist(flat_filters, flat_filters).cpu().detach().numpy()
    # Blank the diagonal and lower triangle so each pair of filters is considered once.
    dist_matrix[np.tril_indices(len(dist_matrix))] = np.nan

    # The last row is all-NaN after masking, so the last filter gets no score
    # and is never pruned.
    nearest_later = np.nanmin(dist_matrix[:-1], axis=1)

    keep = np.full(len(dist_matrix), True)
    n_prune = round(len(keep) * amount)
    keep[np.argsort(nearest_later)[:n_prune]] = False
    return dist_matrix, torch.tensor(keep)


def _plot_filter_distances(dist_matrix):
    """Show the filter distance matrix of one layer as a heatmap."""
    filters_count = len(dist_matrix)
    cur_module_df = pd.DataFrame(dist_matrix, columns=np.arange(filters_count))
    plt.rc('axes', labelsize=18)    # fontsize of the x and y labels
    plt.rc('xtick', labelsize=14)   # fontsize of the tick labels
    plt.rc('ytick', labelsize=14)   # fontsize of the tick labels
    fig, ax = plt.subplots(figsize=(17, 10))
    sns.heatmap(ax=ax, data=cur_module_df,
                square=True,
                cmap='GnBu_r',
                cbar_kws={'label': 'Евклидово расстояние'})
    ax.set_xlabel('Номер фильтра')
    ax.set_ylabel('Номер фильтра')
    # Layers with fewer than four filters would otherwise give range() a zero step.
    tick_step = max(1, filters_count // 4)
    new_ticks = range(0, filters_count + 1, tick_step)
    plt.xticks(new_ticks, new_ticks, rotation='horizontal')
    plt.yticks(new_ticks, new_ticks, rotation='horizontal')
    plt.show()
    # Release the figure: one is created per convolution layer.
    plt.close(fig)


In [None]:
# Prune with amount=0.2 (round(0.2 * n_filters) filters per Conv2d layer) and show
# one distance heatmap per layer.
prune_euclidean_norm_unstructured(conv_net, 0.2)

In [None]:
# Sparsity report after pruning.
count_parameters(conv_net)

In [None]:
# Test accuracy right after pruning, without any fine-tuning.
test_model(conv_net, dl_test)

In [None]:
def apply_mask(net, buffer):
    """
    Re-apply stored pruning masks to the Conv2d layers of ``net``.

    :param torch.nn.Module net: Network whose convolution weights are masked.
    :param dict buffer: Mapping from buffer name to tensor, e.g.
        ``dict(net.named_buffers())``. The mask of a layer is looked up under
        ``'<module name>.weight_mask'``; layers without an entry are skipped.
    """
    # Walk the network that was passed in, not a notebook-level global.
    # NOTE(review): prune.custom_from_mask on an already-pruned parameter appears
    # to stack the new mask on top of the existing one (PruningContainer), so
    # calling this after every training step looks redundant — confirm.
    for name, module in net.named_modules():
        if not isinstance(module, torch.nn.Conv2d):
            continue
        mask = buffer.get(name + '.weight_mask')
        if mask is None:
            continue
        prune.custom_from_mask(module=module, name='weight', mask=mask)
    torch.cuda.empty_cache()

In [None]:
# Re-apply the masks currently stored in the network's buffers (keys '<module name>.weight_mask').
apply_mask(conv_net, dict(conv_net.named_buffers()))

In [None]:
def training_loop(n_epochs, network, loss_fn, optimizer, scheduler, dl_train, dl_test, device):
    '''
    Fine-tune ``network`` and evaluate it on the test loader after every epoch.

    The scheduler is stepped once per batch (not once per epoch), and the
    pruning masks captured from the network's buffers at the start are
    re-applied after every optimizer step. Only the batches and ``loss_fn``
    are moved to ``device`` here; ``network`` is assumed to be there already.

    Reported losses and accuracies are per-sample averages over the epoch.
    The loss average assumes ``loss_fn`` returns the mean over the batch.

    :param int n_epochs: Number of epochs.
    :param torch.nn.Module network: Network to train.
    :param Callable loss_fn: Loss function; must support ``.to(device)``.
    :param torch.optim.Optimizer optimizer: Optimizer.
    :param scheduler: Learning-rate scheduler, stepped after every batch.
    :param torch.utils.data.DataLoader dl_train: Loader for the training set.
    :param torch.utils.data.DataLoader dl_test: Loader for the test set.
    :param torch.device device: Device the computations run on.
    :returns: Four lists with one 0-dim CPU tensor per epoch: train losses,
        test losses, train accuracies (%) and test accuracies (%).
    '''
    buffer = dict(network.named_buffers())
    loss_fn.to(device)
    train_losses, test_losses, train_accuracies, test_accuracies = [], [], [], []
    pbar = tqdm.tqdm(range(n_epochs), total=n_epochs, leave=False)
    for epoch in pbar:
        network.train()

        # Training pass over the whole training loader.
        loss_sum = torch.zeros((), device=device)
        correct = torch.zeros((), device=device)
        seen = 0
        for images, labels in tqdm.tqdm(dl_train, total=len(dl_train), leave=False):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = network(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            apply_mask(network, buffer)

            batch_size = labels.shape[0]
            # detach(): keeping the attached loss would hold every batch's
            # autograd graph in memory until the end of the epoch.
            loss_sum += loss.detach() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum()
            seen += batch_size

        train_losses.append((loss_sum / max(seen, 1)).cpu())
        train_accuracies.append((correct / max(seen, 1)).cpu() * 100)

        # The model is evaluated after every epoch. Switch to inference mode and
        # disable autograd, since gradients are not needed for testing.
        network.eval()
        with torch.no_grad():
            loss_sum = torch.zeros((), device=device)
            correct = torch.zeros((), device=device)
            seen = 0
            for images, labels in tqdm.tqdm(dl_test, total=len(dl_test), leave=False):
                images = images.to(device)
                labels = labels.to(device)

                outputs = network(images)

                batch_size = labels.shape[0]
                loss_sum += loss_fn(outputs, labels) * batch_size
                correct += (outputs.argmax(dim=1) == labels).sum()
                seen += batch_size

            test_losses.append((loss_sum / max(seen, 1)).cpu())
            test_accuracies.append((correct / max(seen, 1)).cpu() * 100)

        pbar.set_description(
            'Loss (Train/Test): {0:.3f}/{1:.3f}. Accuracy, % (Train/Test): {2:.2f}/{3:.2f}\n'.format(
                train_losses[-1], test_losses[-1], train_accuracies[-1], test_accuracies[-1]
            )
        )

    return train_losses, test_losses, train_accuracies, test_accuracies

## 20% Euclidean norm prune unstructured


In [None]:
# Start the 20% experiment from the saved checkpoint loaded into a new ResNet18.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
conv_net = ResNet18()
conv_net.load_state_dict(torch.load(GLOBAL_PATH + 'Models/ResNet18_27.11.2022-21_53.pth', map_location=DEVICE))

In [None]:
# Prune with amount=0.2 (round(0.2 * n_filters) filters per Conv2d layer).
prune_euclidean_norm_unstructured(conv_net, 0.2)

In [None]:
# Sparsity report after pruning, before fine-tuning.
count_parameters(conv_net)

In [None]:
# Test accuracy after pruning, before fine-tuning.
test_model(conv_net, dl_test)

In [None]:
# Pre-bind the arguments that stay fixed for this run (2 epochs, cross-entropy loss,
# loaders, device); network, optimizer and scheduler are supplied at call time.
train_func = partial(
    training_loop, n_epochs=2, loss_fn=torch.nn.CrossEntropyLoss(),
    dl_train=dl_train, dl_test=dl_test, device=DEVICE
)

In [None]:
# Move the pruned network to DEVICE; the trailing ';' suppresses the module repr in the cell output.
conv_net.to(DEVICE);

In [None]:
# SGD with momentum; initial learning rate 0.05 * 2e-1 = 0.01.
optimizer = torch.optim.SGD(conv_net.parameters(), lr=0.05 * 2e-1, momentum=0.9, weight_decay=5e-5)
# training_loop steps the scheduler once per batch, so step_size=150 counts batches, not epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.80)

In [None]:
# Fine-tune the pruned network; the result holds per-epoch train/test losses and accuracies.
train_losses, test_losses, train_accs, test_accs = train_func(
    network=conv_net,
    optimizer=optimizer,
    scheduler=scheduler
)

In [None]:
# Accuracy per epoch: training set as a dash-dot line with '.' markers,
# test set as a solid line with '*' markers.
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.plot(train_accs, label="Точность на обучении", color='red', marker='.', linestyle='-.')
ax.plot(test_accs, label="Точность на тесте", color='red', marker='*')

ax.set_xlabel("Номер эпохи")
# Raw string: "\%" is not a valid Python escape sequence and triggers a warning
# in a normal string literal. The resulting text is unchanged.
ax.set_ylabel(r"$\%$")

ax.grid(True)
ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# Sparsity report after fine-tuning.
count_parameters(conv_net)

In [None]:
# Test accuracy after fine-tuning.
test_model(conv_net, dl_test)

## 40% Euclidean norm prune unstructured


In [None]:
# Start the 40% experiment from the saved checkpoint loaded into a new ResNet18.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
conv_net = ResNet18()
conv_net.load_state_dict(torch.load(GLOBAL_PATH + 'Models/ResNet18_27.11.2022-21_53.pth', map_location=DEVICE))

In [None]:
# Prune with amount=0.4 (round(0.4 * n_filters) filters per Conv2d layer).
prune_euclidean_norm_unstructured(conv_net, 0.4)

In [None]:
# Sparsity report after pruning, before fine-tuning.
count_parameters(conv_net)

In [None]:
# Test accuracy after pruning, before fine-tuning.
test_model(conv_net, dl_test)

In [None]:
# Pre-bind the arguments that stay fixed for this run (2 epochs, cross-entropy loss,
# loaders, device); network, optimizer and scheduler are supplied at call time.
train_func = partial(
    training_loop, n_epochs=2, loss_fn=torch.nn.CrossEntropyLoss(),
    dl_train=dl_train, dl_test=dl_test, device=DEVICE
)

In [None]:
# Move the pruned network to DEVICE; the trailing ';' suppresses the module repr in the cell output.
conv_net.to(DEVICE);

In [None]:
# SGD with momentum; initial learning rate 0.05 * 3e-1 = 0.015.
optimizer = torch.optim.SGD(conv_net.parameters(), lr=0.05 * 3e-1, momentum=0.9, weight_decay=5e-5)
# training_loop steps the scheduler once per batch, so step_size=150 counts batches, not epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.8)

In [None]:
# Fine-tune the pruned network; the result holds per-epoch train/test losses and accuracies.
train_losses, test_losses, train_accs, test_accs = train_func(
    network=conv_net,
    optimizer=optimizer,
    scheduler=scheduler
)

In [None]:
# Accuracy per epoch: training set as a dash-dot line with '.' markers,
# test set as a solid line with '*' markers.
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.plot(train_accs, label="Точность на обучении", color='red', marker='.', linestyle='-.')
ax.plot(test_accs, label="Точность на тесте", color='red', marker='*')

ax.set_xlabel("Номер эпохи")
# Raw string: "\%" is not a valid Python escape sequence and triggers a warning
# in a normal string literal. The resulting text is unchanged.
ax.set_ylabel(r"$\%$")

ax.grid(True)
ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# Sparsity report after fine-tuning.
count_parameters(conv_net)

In [None]:
# Test accuracy after fine-tuning.
test_model(conv_net, dl_test)

## 60% Euclidean norm prune unstructured


In [None]:
# Start the 60% experiment from the saved checkpoint loaded into a new ResNet18.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
conv_net = ResNet18()
conv_net.load_state_dict(torch.load(GLOBAL_PATH + 'Models/ResNet18_27.11.2022-21_53.pth', map_location=DEVICE))

In [None]:
# Prune with amount=0.6 (round(0.6 * n_filters) filters per Conv2d layer).
prune_euclidean_norm_unstructured(conv_net, 0.6)

In [None]:
# Sparsity report after pruning, before fine-tuning.
count_parameters(conv_net)

In [None]:
# Test accuracy after pruning, before fine-tuning.
test_model(conv_net, dl_test)

In [None]:
# Pre-bind the arguments that stay fixed for this run (2 epochs, cross-entropy loss,
# loaders, device); network, optimizer and scheduler are supplied at call time.
train_func = partial(
    training_loop, n_epochs=2, loss_fn=torch.nn.CrossEntropyLoss(),
    dl_train=dl_train, dl_test=dl_test, device=DEVICE
)

In [None]:
# Move the pruned network to DEVICE; the trailing ';' suppresses the module repr in the cell output.
conv_net.to(DEVICE);

In [None]:
# SGD with momentum; initial learning rate 0.05 * 5e-1 = 0.025.
optimizer = torch.optim.SGD(conv_net.parameters(), lr=0.05 * 5e-1, momentum=0.9, weight_decay=5e-5)
# training_loop steps the scheduler once per batch, so step_size=100 counts batches, not epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.8)

In [None]:
# Fine-tune the pruned network; the result holds per-epoch train/test losses and accuracies.
train_losses, test_losses, train_accs, test_accs = train_func(
    network=conv_net,
    optimizer=optimizer,
    scheduler=scheduler
)

In [None]:
# Accuracy per epoch: training set as a dash-dot line with '.' markers,
# test set as a solid line with '*' markers.
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.plot(train_accs, label="Точность на обучении", color='red', marker='.', linestyle='-.')
ax.plot(test_accs, label="Точность на тесте", color='red', marker='*')

ax.set_xlabel("Номер эпохи")
# Raw string: "\%" is not a valid Python escape sequence and triggers a warning
# in a normal string literal. The resulting text is unchanged.
ax.set_ylabel(r"$\%$")

ax.grid(True)
ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# Sparsity report after fine-tuning.
count_parameters(conv_net)

In [None]:
# Test accuracy after fine-tuning.
test_model(conv_net, dl_test)