# **Seminar 4 - Интерпретация Нейросетей**
*Naumov Anton (Any0019)*

*To contact me in telegram: @any0019*

## 1. Почему мы любим PyTorch - Modules

In [None]:
import torch
import torchvision

print(torch.__version__)
print(torchvision.__version__)

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
nn.Module?

In [None]:
nn.Module??

### 1.1. Простой модуль

In [None]:
class MyLinear(nn.Module):
    """Minimal fully connected layer computing ``input @ weight + bias``.

    ``weight`` has shape ``(in_features, out_features)`` and ``bias`` has shape
    ``(out_features,)``; both are drawn from a standard normal distribution.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Wrapping in nn.Parameter tells torch these tensors are trainable weights
        # of the module (they show up in .parameters() / .named_parameters()).
        self.weight = nn.Parameter(torch.randn(in_features, out_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, input):
        projected = torch.matmul(input, self.weight)
        return projected + self.bias

In [None]:
m = MyLinear(4, 3)
sample_input = torch.randn(4)
print(sample_input, m(sample_input), sep="\n")

In [None]:
for name, param in m.named_parameters():
    print(f"Name ~ '{name}'", param, sep="\n", end="\n-----\n")

# # Аналогично, но только сами параметры, когда не нужны имена
# for param in m.parameters():
#     pass

In [None]:
net = nn.Sequential(
  MyLinear(4, 3),
  nn.ReLU(),
  MyLinear(3, 1)
)

sample_input = torch.randn(4)
print(sample_input, net(sample_input), sep="\n")

### 1.2. Нейросеть с SubModule

In [None]:
class Net(nn.Module):
    """Two ``MyLinear`` layers (4 -> 3 -> 1) with a ReLU between them."""

    def __init__(self):
        super().__init__()
        # Modules assigned as attributes become children of this module
        # (visible through .named_children() / .parameters()).
        self.l0 = MyLinear(4, 3)
        self.l1 = MyLinear(3, 1)

    def forward(self, x):
        hidden = F.relu(self.l0(x))
        return self.l1(hidden)

In [None]:
net = Net()
for name, child in net.named_children():
    # list(...) so the parameter tensors themselves are printed,
    # not just a "<generator object ...>" placeholder.
    print(f"Name ~ '{name}'", child, list(child.parameters()), sep="\n", end="\n-----\n")

### 1.3. Сложная нейросеть

In [None]:
class BigNet(nn.Module):
    """Nested model: a 5 -> 4 ``MyLinear`` followed by a whole ``Net`` submodule."""

    def __init__(self):
        super().__init__()
        self.l1 = MyLinear(5, 4)
        self.net = Net()

    def forward(self, x):
        features = self.l1(x)
        return self.net(features)

In [None]:
# pip install termcolor -- если не установлен пакет
from termcolor import colored

In [None]:
big_net = BigNet()


print(colored("Children:\n", color="red", attrs=["bold", "underline"]))
for name, child in big_net.named_children():
    print(f"Name ~ '{name}'", child, sep="\n", end="\n-----\n")
print("\n========\n")


print(colored("Modules:\n", color="red", attrs=["bold", "underline"]))
for name, module in big_net.named_modules():
    print(f"Name ~ '{name}'", module, sep="\n", end="\n-----\n")
print("\n========\n")


print(colored("Parameters:\n", color="red", attrs=["bold", "underline"]))
for name, param in big_net.named_parameters():
    print(f"Name ~ '{name}'", param, sep="\n", end="\n-----\n")

### 1.4. Динамические модули

In [None]:
class DynamicNet(nn.Module):
    """``num_layers`` 4 -> 4 ``MyLinear`` layers, one activation chosen per call, and a 4 -> 1 head.

    The hidden layers are kept in an ``nn.ModuleList`` and the activations in an
    ``nn.ModuleDict`` so they are registered as submodules of this module.
    """

    def __init__(self, num_layers):
        super().__init__()
        # nn.ModuleList - a list of modules
        self.linears = nn.ModuleList(MyLinear(4, 4) for _ in range(num_layers))
        # nn.ModuleDict - a dict of modules, looked up by name in forward()
        self.activations = nn.ModuleDict({"relu": nn.ReLU(), "lrelu": nn.LeakyReLU()})
        self.final = MyLinear(4, 1)

    def forward(self, x, act):
        """Apply all hidden layers, then the activation named ``act``, then the head."""
        hidden = x
        for layer in self.linears:
            hidden = layer(hidden)
        return self.final(self.activations[act](hidden))

dynamic_net = DynamicNet(3)
sample_input = torch.randn(4)
output = dynamic_net(sample_input, "relu")
print(sample_input, output, sep="\n")

### 1.5. Состояние модуля (train vs eval)

In [None]:
class ModalModule(nn.Module):
    """Toy module whose output depends on the train/eval mode.

    While ``self.training`` is True (the state right after construction or after
    ``.train()``) it adds 1.0 to the input; after ``.eval()`` it returns the input
    unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x + 1. if self.training else x

In [None]:
m = ModalModule()
x = torch.randn(4)

print(f"Input:\n{x}\n")

print(f"Training mode output:\n{m(x)}\n")

m.eval()
print(f"Evaluation mode output:\n{m(x)}")

### 1.6. Тип данных и вычислений

In [None]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
dtype = torch.float64

# Переместить все параметры модели на device
dynamic_net.to(device=device)
# dynamic_net.cpu()
# dynamic_net.cuda(int: ...)

# Изменить тип всех параметров модели
dynamic_net.to(dtype=dtype)
# dynamic_net = dynamic_net.double()  # float64
# dynamic_net = dynamic_net.float()  # float32
# dynamic_net = dynamic_net.half() # float16

sample_input = sample_input.to(device, dtype=dtype)

output = dynamic_net(sample_input, "relu")

print(sample_input, output, sep="\n")

### 1.7. Применение функций к модулям нейросети

In [None]:
# Weight-initialization function meant to be applied to every submodule via .apply().
# @torch.no_grad() keeps these in-place writes out of autograd.
@torch.no_grad()
def init_weights(m):
    """Xavier-normal init for linear layers' weights and zero their biases.

    Matches ``nn.Linear`` and any custom linear-like module (such as ``MyLinear``
    in this notebook) that owns a 2-D ``weight`` parameter and a ``bias``
    attribute. Other modules are left untouched. A missing bias
    (``bias=None``) is skipped instead of raising.
    """
    weight = getattr(m, "weight", None)
    linear_like = isinstance(m, nn.Linear) or (
        isinstance(weight, nn.Parameter) and weight.dim() == 2 and hasattr(m, "bias")
    )
    if not linear_like:
        return
    nn.init.xavier_normal_(weight, gain=1.0)  # see the Xavier / Glorot note below
    bias = getattr(m, "bias", None)
    if isinstance(bias, torch.Tensor):
        bias.fill_(0.0)

# Применяем функцию рекурсивно ко всем модулям и подмодулям
dynamic_net.apply(init_weights)

Xavier / Glorot initialization

`Understanding the difficulty of training deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010) ([ссылка](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf))

$$\text{std} = \text{gain} \times \sqrt{\frac{2}{\text{n}_{in} + \text{n}_{out}}}$$

, где $n_{in}$ и $n_{out}$ - число входов и выходов слоя соответственно

Веса получаются следующим образом: $w \sim \mathcal{N}(0, std^2)$

### 1.8. Сохранение и загрузка модели

In [None]:
big_net = BigNet()

In [None]:
# Словарь со всеми весами модели
big_net.state_dict()

In [None]:
# Сохраняем state_dict нашей модели
torch.save(
    big_net.state_dict(),
    "net.pt",
)

In [None]:
# Инициализируем модель с таким же набором параметров
new_big_net = BigNet()

# Загружаем state_dict сохранённой модели в память
state_dict = torch.load("net.pt")
print(state_dict)

# Подгружаем state_dict в инициализированную модель
new_big_net.load_state_dict(state_dict)

**[Важно!] В общем случае - сохраняйте всю необходимую информацию о модели, оптимайзере, этапе обучения, ...**

Best practice $\longrightarrow$ подумайте, что вам будет нужно, если вы захотите, загрузив назад модель, продолжить её обучение:
* обучаемые параметры
* оптимизатор
* шедулер
* какие графики рисуете
* на какой эпохе находитесь
* ...

In [None]:
# К примеру:

# torch.save(
#     {
#         "epoch": epoch,
#         "model_state_dict": model.state_dict(),
#         "optimizer_state_dict": optimizer.state_dict(),
#         "scheduler_state_dict": scheduler.state_dict(),
#         "losses": losses,
#     },
#     chkp_path,
# )

### 1.9. Буферы

In [None]:
nn.Module.register_buffer?

In [None]:
class RunningMean(nn.Module):
    """Exponential moving average of the inputs, kept in a buffer named ``mean``.

    ``mean`` is registered with ``register_buffer``, so it is saved in
    ``state_dict`` (``persistent=True``) but is not listed by
    ``named_parameters()``, i.e. it is not trained.

    Args:
        num_features: length of the averaged vector.
        momentum: weight of the previous average in each update.
    """

    def __init__(self, num_features, momentum=0.9):
        super().__init__()
        self.momentum = momentum
        # Register a buffer: part of the module state, but not a trainable parameter.
        self.register_buffer(
            "mean",
            torch.zeros(num_features),
            persistent=True,  # whether the buffer is included in the module's state_dict
        )

    def forward(self, x):
        """Update the running mean with ``x`` and return it (detached from autograd)."""
        # Detach the input: otherwise, when x requires grad, every call would chain
        # a new autograd graph onto the buffer. Memory would then grow across calls,
        # and a later backward would traverse graphs that were already freed.
        self.mean = self.momentum * self.mean + (1.0 - self.momentum) * x.detach()
        return self.mean

In [None]:
m = RunningMean(4)
for _ in range(10):
    input = torch.randn(4)
    m(input)

print(m.state_dict())

In [None]:
print(colored("Parameters:\n", color="red", attrs=["bold", "underline"]))
for name, param in m.named_parameters():
    print(f"Name ~ '{name}'", param, sep="\n", end="\n-----\n")
print("\n========\n")


print(colored("Buffers:\n", color="red", attrs=["bold", "underline"]))
for name, buffer in m.named_buffers():
    print(f"Name ~ '{name}'", buffer, sep="\n", end="\n-----\n")

### 1.10. Инициализация

Все параметры и floating point буферы инициализируются при создании модуля:
* Тип ~ `param.float()` или `param.to(dtype=torch.float32)`.
* Устройство ~ `param.cpu()`, или `param.to("cpu")`, или `param.to(torch.device("cpu"))`.
* Инициализация значений ~ схема, исторически предпочитаемая для данного вида слоёв.

In [None]:
# Create the module directly on another device
# (falls back to CPU so the cell also runs on machines without CUDA)
m = nn.Linear(5, 3, device="cuda" if torch.cuda.is_available() else "cpu")

# Create the module directly with another dtype
m = nn.Linear(5, 3, dtype=torch.half)

# Skip the default initialization and apply a custom one (orthogonal, for example).
# skip_init allocates the parameters without running the default initializer,
# so every parameter (including the bias) must be initialized explicitly.
m = torch.nn.utils.skip_init(nn.Linear, 5, 3)
nn.init.orthogonal_(m.weight)
nn.init.zeros_(m.bias)

## 2. Forward / Backward hooks

Hooks - способ взаимодействия с `torch.Tensor` и/или `nn.Module` для получения и модификации входов / выходов / градиентов в момент их прохода через forward / backward pass.

### 2.1. Module level hooks

In [None]:
def forward_pre_hook(m, inputs):
    """Forward pre-hook: print the inputs of module ``m`` before its forward runs.

    Registered via ``register_forward_pre_hook``. Returning ``None`` keeps the
    inputs unchanged; returning new inputs would replace them.
    """
    # Only tensors have a shape; skip None and non-tensor arguments (e.g. strings).
    sizes = ", ".join(str(el.shape) for el in inputs if isinstance(el, torch.Tensor))
    print(
        colored("Forward pre hook", color="red", attrs=["bold", "underline"]),
        inputs,
        " - Sizes ~ [" + sizes + "]",
        sep="\n",
        end="\n\n",
    )
    return None  # None - keep inputs; otherwise return new_inputs

def forward_hook(m, inputs, output):
    """Forward hook: print the inputs and output of module ``m`` after its forward.

    Registered via ``register_forward_hook``. Returning ``None`` keeps ``output``;
    returning a new value would replace it.
    """
    # Only tensors have a shape; skip None and non-tensor arguments.
    input_sizes = ", ".join(str(el.shape) for el in inputs if isinstance(el, torch.Tensor))
    if isinstance(output, torch.Tensor):
        output_size = f" - Size ~ {output.shape}"
    elif isinstance(output, (tuple, list)):
        # Some modules return several tensors at once.
        output_size = " - Sizes ~ [" + ", ".join(
            str(el.shape) for el in output if isinstance(el, torch.Tensor)
        ) + "]"
    else:
        output_size = f" - Type ~ {type(output).__name__}"
    print(
        colored("Forward hook", color="red", attrs=["bold", "underline"]),
        inputs,
        " - Sizes ~ [" + input_sizes + "]",
        output,
        output_size,
        sep="\n",
        end="\n\n",
    )
    return None  # None - keep output; otherwise return new_output

def backward_hook(m, grad_inputs, grad_outputs):
    """Module backward hook: print the gradients flowing through module ``m``.

    Runs during backward for the corresponding module. Returning ``None`` keeps
    ``grad_inputs``; returning new gradients would replace them (they are what
    this module passes further back).
    """
    def sizes_line(grads):
        # Gradients may be None (e.g. for inputs that do not require grad).
        shapes = ", ".join(str(g.shape) for g in grads if g is not None)
        return " - Sizes ~ [" + shapes + "]"

    header = colored("Backward hook", color="red", attrs=["bold", "underline"])
    print(
        header,
        grad_inputs,
        sizes_line(grad_inputs),
        grad_outputs,
        sizes_line(grad_outputs),
        sep="\n",
        end="\n\n",
    )
    return None  # None - keep grad_inputs; otherwise return new_grad_inputs

In [None]:
m = nn.Linear(4, 1)

fp_handle = m.register_forward_pre_hook(forward_pre_hook)
f_handle = m.register_forward_hook(forward_hook)
# b_handle = m.register_backward_hook(backward_hook)  # --> deprecated
b_handle = m.register_full_backward_hook(backward_hook)

In [None]:
sample_input = torch.randn(3, 4)
sample_input.requires_grad = True

print("Input", sample_input, sep="\n", end="\n\n")

out = m(sample_input)

In [None]:
out.backward(torch.ones_like(out))

In [None]:
fp_handle.remove()
f_handle.remove()
b_handle.remove()

In [None]:
out = m(sample_input)
out.backward(torch.ones_like(out))

### 2.2. Tensor level hooks

In [None]:
def tensor_hook(grad):
    """Tensor-level backward hook: print the gradient computed for a tensor.

    Returning ``None`` keeps ``grad``; returning a tensor would replace it.
    """
    header = colored("Tensor backward hook", color="red", attrs=["bold", "underline"])
    size_line = f" - Size ~ {grad.shape}"
    print(header, grad, size_line, sep="\n", end="\n\n")
    return None  # None - keep grad; otherwise return new_grad

In [None]:
m = nn.Linear(4, 1)

w_t_handle = m.weight.register_hook(tensor_hook)
b_t_handle = m.bias.register_hook(tensor_hook)

In [None]:
sample_input = torch.randn(3, 4)
sample_input.requires_grad = True

print("Input", sample_input, sep="\n", end="\n\n")

out = m(sample_input)
out.backward(torch.ones_like(out))

In [None]:
m.bias.grad, m.weight.grad

In [None]:
w_t_handle.remove()
b_t_handle.remove()

In [None]:
out = m(sample_input)
out.backward(torch.ones_like(out))

### 2.3. Как достать и сохранить что-то из модели?

```python
from collections import defaultdict

hook_data = defaultdict(list)

def hook(*args):
    # берём из global scope-а перменную, созданную раньше
    global hook_data
    
    ...
    for key, value in ...:
        hook_data[key].append(value)
    
    return None
```

## 3. Интерпретация нейросетей

Главный вопрос интерпретации - **почему нейросеть повела себя так, как повела?**

* Почему ответ был именно такой в конкретном случае?
* Что нужно подать на вход, чтобы получить подобный ответ?
* На что сильнее всего смотрит нейросеть, принимая решение?
* ...

Все стандартные подходы интерпретации из классического ML также будут работать и с нейросетями (если рассматривать нейросеть как функцию от входов как от отдельных переменных). Однако многие из них ориентированы на оценку важности одной или нескольких фичей, а у нейросетей на входе часто тысячи или даже миллионы фичей (к пр., картинка 1920 x 1080 x 3 пикселей), поэтому такие подходы работают не очень хорошо:
* значимость одной фичи маленькая
* виды данных имеют свои особенности и зависимости
* виды моделей имеют свои подходы к их анализу

Сегодня мы будем рассматривать подходы для анализа нейросетей (на примере свёрточных нейросетей), подходы глобально можно разделить на группы по нескольким признакам:
1. Анализ данных и зоны видимости при активации нейросетей
2. Attribution - изучаем, какая часть входа(ов) отвечает за активацию нейросети
3. Feature visualization - подбираем изображения, наиболее соответствующие ожиданиям нейросети

### 3.1. Подготовим модель и данные

In [None]:
%pip install --upgrade torchvision==0.14.0

In [None]:
#!:bash
python3 --version
python3 -c "import torchvision; import torch; print(f'torch ~ {torch.__version__}\ntorchvision ~ {torchvision.__version__}')"

In [None]:
import torch
import torchvision

from importlib import reload
reload(torch)
reload(torchvision)
print(torch.__version__, torchvision.__version__, sep="\n")

In [None]:
from torchvision import datasets
import os

In [None]:
dataset_path = "../Sem3 - DL tricks/data"
original_train_ds = datasets.STL10(root=dataset_path, split="train", download=True)
original_val_ds = datasets.STL10(root=dataset_path, split="test", download=True)
classes = original_train_ds.classes

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

h = 4
w = 8
fig, ax = plt.subplots(h, w, figsize=(30, 15))

fig.suptitle(f"all classes ~ [{', '.join(classes)}]", y=0.85 + 0.02*h)
for i in range(h * w):
    plt.subplot(h, w, i+1)
    img, cl = original_train_ds[i]
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])
    plt.title(f"{cl} ~ {classes[cl]}")
plt.show()

In [None]:
from torchvision import models
from torchvision.models import ResNet18_Weights
from torch import nn

def get_model_and_transforms():
    """Build a pretrained ResNet-18 adapted to ``classes`` plus its preprocessing.

    Every parameter outside ``layer4`` is frozen, and the final ``fc`` layer is
    replaced by a fresh ``nn.Linear`` with ``len(classes)`` outputs, so only
    ``layer4`` and the new head remain trainable.

    Returns:
        ``(model, transforms)``, where ``transforms`` are the preprocessing
        transforms bundled with the pretrained weights.
    """
    weights = models.ResNet18_Weights.DEFAULT
    model = models.resnet18(weights=weights, progress=True)

    # Freeze the early layers: only parameters under "layer4" keep requires_grad.
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith("layer4")

    # New classification head for our classes (trainable by default).
    model.fc = nn.Linear(model.fc.in_features, len(classes))

    return model, weights.transforms()

model, transforms = get_model_and_transforms()

print(model, transforms, sep="\n=========\n")

In [None]:
import numpy as np
from termcolor import colored

def beautiful_int(i):
    """Format ``i`` with '.' as a thousands separator, e.g. 1234567 -> '1.234.567'.

    A leading sign is kept outside the digit groups, so -123 -> '-123' rather
    than '-.123'.
    """
    digits = str(i)
    sign = ""
    if digits[:1] in ("-", "+"):
        sign, digits = digits[0], digits[1:]
    # Slice groups of three from the right; the leftmost group may be shorter.
    groups = [digits[max(j, 0):j + 3] for j in range(len(digits) - 3, -3, -3)]
    return sign + ".".join(reversed(groups))

# Counting how many parameters does our model have
def model_num_params(model, verbose_all=True, verbose_only_learnable=False):
    """Count the parameters of ``model`` and print a per-tensor report.

    Args:
        model: module whose ``named_parameters()`` are counted.
        verbose_all: print a line for every parameter tensor
            (green = trainable, red = frozen).
        verbose_only_learnable: when ``verbose_all`` is False, print lines only
            for trainable tensors.

    Returns:
        ``(sum_params, sum_learnable_params)`` as Python ints.
    """
    sum_params = 0
    sum_learnable_params = 0
    for name, param in model.named_parameters():
        # numel() is an exact int; np.prod(()) of a 0-d parameter would be the float 1.0.
        num_params = param.numel()
        learnable = param.requires_grad
        if verbose_all or (verbose_only_learnable and learnable):
            print(
                colored(
                    '{: <42} ~  {: <9} params ~ grad: {}'.format(
                        name,
                        beautiful_int(num_params),
                        learnable,
                    ),
                    "green" if learnable else "red",
                )
            )
        sum_params += num_params
        if learnable:
            sum_learnable_params += num_params
    print(
        f'\nIn total:\n  - {beautiful_int(sum_params)} params\n  - {beautiful_int(sum_learnable_params)} learnable params'
    )
    return sum_params, sum_learnable_params


sum_params, sum_learnable_params = model_num_params(model)

In [None]:
from torchvision import transforms as tr

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

train_transform = tr.Compose([
    tr.Resize(size=(256, 256)),
    tr.RandomRotation(degrees=(-10, 10)),
    tr.RandomCrop(size=(224, 224)),
    tr.RandomHorizontalFlip(),
    tr.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 0.5)),
    tr.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    tr.ToTensor(),
    tr.Normalize(mean=mean, std=std),
])

val_transform = torchvision.transforms.Compose([
    tr.Resize(size=(224, 224)),
    tr.ToTensor(),
    tr.Normalize(mean=mean, std=std),
])

In [None]:
def de_normalize(img, img_mean=None, img_std=None):
    """Undo ``Normalize`` on a channel-first image tensor; return a channel-last numpy array.

    Args:
        img: tensor laid out as (C, H, W), e.g. the output of ``ToTensor`` +
            ``Normalize``; it may live on any device and may require grad.
        img_mean, img_std: per-channel statistics used for normalization;
            default to the notebook's ``mean`` / ``std``.
    """
    img_mean = mean if img_mean is None else img_mean
    img_std = std if img_std is None else img_std
    # .cpu() so GPU tensors can be converted to numpy as well.
    hwc = img.detach().cpu().numpy().transpose((1, 2, 0))
    return hwc * img_std + img_mean


img_ind = 0

fig, ax = plt.subplots(1, 3, figsize=(20, 7))

plt.subplot(131)
plt.imshow(original_train_ds[img_ind][0])
plt.title("Before")

plt.subplot(132)
plt.imshow(de_normalize(train_transform(original_train_ds[img_ind][0])))
plt.title("After train transform")

plt.subplot(133)
plt.imshow(de_normalize(val_transform(original_train_ds[img_ind][0])))
plt.title("After val transform")

plt.show()

In [None]:
from torch.utils.data import Dataset, DataLoader

class ImageDataset(Dataset):
    """Wrap an ``(image, label)`` dataset and apply ``transforms`` to each image on access."""

    def __init__(self, dataset, transforms):
        # Zero-argument super() is bound to this instance, so the parent
        # initializer actually runs.
        super().__init__()
        self.dataset = dataset
        self.transforms = transforms

    def __getitem__(self, index):
        img, cl = self.dataset[index]
        return self.transforms(img), cl

    def __len__(self):
        return len(self.dataset)

In [None]:
train_ds = ImageDataset(original_train_ds, train_transform)
val_ds = ImageDataset(original_val_ds, val_transform)

In [None]:
print(len(train_ds), len(val_ds))

In [None]:
batch_size = 128

train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)

### 3.2. Дообучим модельку на задачу классификации STL10 датасета

In [None]:
import torch.nn.functional as F
from tqdm.notebook import tqdm, trange
from IPython.display import clear_output


def create_model_and_optimizer(lr=1e-3, beta1=0.9, beta2=0.999, device="cpu"):
    """Build the fine-tuning model on ``device`` and an Adam optimizer.

    The optimizer only receives parameters with ``requires_grad=True``, i.e. the
    ones left unfrozen by ``get_model_and_transforms``.

    Returns:
        ``(model, optimizer)``.
    """
    model = get_model_and_transforms()[0].to(device)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=lr, betas=[beta1, beta2])
    return model, optimizer


def train(model, optimizer, loader, criterion, device=None):
    """Run one training epoch.

    Args:
        model: network to train; switched to ``train()`` mode.
        optimizer: optimizer updating the model's trainable parameters.
        loader: iterable of ``(images, targets)`` batches.
        criterion: loss function called as ``criterion(outputs, targets)``.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if it has none), so the function does not
            depend on a notebook-global ``device``.

    Returns:
        ``(model, optimizer, mean_batch_loss)``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    losses_tr = []
    for images, targets in tqdm(loader):
        images = images.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        out = model(images)
        loss = criterion(out, targets)

        loss.backward()
        optimizer.step()
        losses_tr.append(loss.item())

    return model, optimizer, np.mean(losses_tr)


def val(model, loader, criterion, metric_names=None, device=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network to evaluate; switched to ``eval()`` mode.
        loader: iterable of ``(images, targets)`` batches.
        criterion: loss called as ``criterion(outputs, targets)``. It is assumed
            to average over the batch (``reduction='mean'``, the
            ``nn.CrossEntropyLoss`` default), so it is re-weighted by batch size.
        metric_names: optional collection of metric names. Supported names are
            ``'accuracy'``, ``'top2accuracy'`` and ``'top3accuracy'``; any other
            name is reported as ``nan``.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if it has none).

    Returns:
        ``(mean_loss, metrics)``, both averaged per sample over the whole loader.
        ``metrics`` is a ``{name: value}`` dict, or ``None`` when no
        ``metric_names`` are given. An empty loader yields ``nan`` values.
    """
    # metric name -> k of the top-k accuracy it measures
    topk_by_name = {'accuracy': 1, 'top2accuracy': 2, 'top3accuracy': 3}
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    loss_sum = 0.0
    n_samples = 0
    hits = {name: 0 for name in (metric_names or ()) if name in topk_by_name}
    with torch.no_grad():
        for images, targets in tqdm(loader):
            images = images.to(device)
            targets = targets.to(device)
            out = model(images)
            batch_size = targets.shape[0]
            # Weight by batch size so a shorter last batch (drop_last=False) does
            # not count as much as a full one.
            loss_sum += criterion(out, targets).item() * batch_size
            n_samples += batch_size

            if hits:
                # Rank classes once per batch and reuse the ranking for every top-k metric.
                ranked = torch.argsort(out, dim=-1, descending=True)
                for name in hits:
                    k = topk_by_name[name]
                    in_top_k = (ranked[:, :k] == targets[:, None]).any(dim=-1)
                    hits[name] += in_top_k.sum().item()

    nan = float('nan')
    mean_loss = loss_sum / n_samples if n_samples else nan
    if not metric_names:
        return mean_loss, None
    metrics = {
        name: (hits[name] / n_samples if name in hits and n_samples else nan)
        for name in metric_names
    }
    return mean_loss, metrics


def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group (``None`` if there is none)."""
    return next((group['lr'] for group in optimizer.param_groups), None)
    

def learning_loop(model, optimizer, train_loader, val_loader, criterion, scheduler=None, min_lr=None, epochs=10, val_every=1, draw_every=1, metric_names=None):
    """Train for up to ``epochs`` epochs, validating and plotting periodically.

    Args:
        model, optimizer, criterion: passed to ``train`` / ``val``.
        train_loader, val_loader: batch iterables for training / validation.
        scheduler: optional LR scheduler, stepped after each validation.
            ``ReduceLROnPlateau`` receives the validation loss; any other
            scheduler is stepped without arguments.
        min_lr: if set, stop early once the learning rate drops to or below it.
        epochs: maximum number of epochs.
        val_every: validate (and step the scheduler) every ``val_every`` epochs.
        draw_every: redraw the plots every ``draw_every`` epochs.
        metric_names: extra metric names forwarded to ``val``.

    Returns:
        ``(model, optimizer, losses, lrs, metrics)``. ``losses`` has ``'train'``
        and ``'val'`` lists, ``lrs`` holds the learning rate at each validation,
        and ``metrics`` is ``{name: [values]}`` or ``None`` without
        ``metric_names``.
    """
    losses = {'train': [], 'val': []}
    lrs = []
    metrics = {name: [] for name in metric_names} if metric_names else None
    # Epochs at which validation ran: the x-axis for the val / lr / metric curves.
    val_epochs = []

    for epoch in range(1, epochs + 1):
        print(f'#{epoch}/{epochs}:')
        model, optimizer, loss = train(model, optimizer, train_loader, criterion)
        losses['train'].append(loss)

        if epoch % val_every == 0:
            loss, metrics_ = val(model, val_loader, criterion, metric_names)
            losses['val'].append(loss)
            val_epochs.append(epoch)
            if metric_names:
                for name in metrics_:
                    metrics[name].append(metrics_[name])

            lrs.append(get_lr(optimizer))
            if scheduler is not None:
                # Dispatch explicitly instead of a bare try/except, which would also
                # hide genuine errors raised inside step().
                if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(loss)
                else:
                    scheduler.step()

        if epoch % draw_every == 0:
            clear_output(True)
            ww = 3 if metric_names else 2
            fig, ax = plt.subplots(1, ww, figsize=(20, 10))
            fig.suptitle(f'#{epoch}/{epochs}:')
            train_epochs = list(range(1, len(losses['train']) + 1))

            plt.subplot(1, ww, 1)
            plt.title('losses')
            plt.plot(train_epochs, losses['train'], 'r.-', label='train')
            plt.plot(val_epochs, losses['val'], 'g.-', label='val')
            plt.xlabel('epoch')
            plt.legend()

            plt.subplot(1, ww, 2)
            plt.title('learning rate')
            plt.plot(val_epochs, lrs, '.-', label='lr')
            plt.xlabel('epoch')
            plt.legend()

            if metric_names:
                plt.subplot(1, ww, 3)
                plt.title('additional metrics')
                for name in metric_names:
                    plt.plot(val_epochs, metrics[name], '.-', label=name)
                plt.xlabel('epoch')
                plt.legend()

            plt.show()

        if min_lr and get_lr(optimizer) <= min_lr:
            print(f'Learning process ended with early stop after epoch {epoch}')
            break

    return model, optimizer, losses, lrs, metrics

In [None]:
%%time
NUM_EPOCHS = 30

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

model, optimizer = create_model_and_optimizer(
    lr = 1e-4,
    device = device,
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, NUM_EPOCHS, eta_min=1e-6)

criterion = nn.CrossEntropyLoss()

model, optimizer, losses, lrs, metrics = learning_loop(
    model = model,
    optimizer = optimizer,
    train_loader = train_loader,
    val_loader = val_loader,
    criterion = criterion,
    scheduler = scheduler,
    epochs = NUM_EPOCHS,
    min_lr = None,
    metric_names = {'accuracy', 'top3accuracy'},
)

In [None]:
chkp_path = "./model.pt"

# Save
torch.save(
    {
        'epoch': NUM_EPOCHS,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'losses': losses,
    },
    chkp_path,
)


# Load
# checkpoint = torch.load(chkp_path)

# model, optimizer = create_model_and_optimizer(
#     lr = 1e-4,
#     device = device,
# )

# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, NUM_EPOCHS, eta_min=1e-6)

# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
# epoch = checkpoint['epoch']
# losses = checkpoint['losses']

In [None]:
def real_confusion_matrix(model, val_loader, class_labels, use_probs=False, normalize=True, round_size=4, device=None):
    """Compute and plot the confusion matrix of ``model`` on ``val_loader``.

    Rows of the matrix are true classes and columns are predicted classes. It is
    drawn with predicted classes on the x-axis and true classes on the y-axis.

    Args:
        model: classifier assumed to return unnormalized logits of shape
            ``(batch, len(class_labels))`` (as the ``nn.Linear`` head built in
            ``get_model_and_transforms`` does); softmax is applied here.
        val_loader: iterable of ``(images, labels)`` batches.
        class_labels: class names; index ``i`` names class ``i``.
        use_probs: if True, accumulate softmax probabilities (a "soft" matrix);
            otherwise count argmax predictions.
        normalize: if True, divide each row by its sum so rows sum to 1. Rows
            with no samples stay all-zero.
        round_size: number of decimals shown in each cell.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if it has none).

    Returns:
        The ``(n_classes, n_classes)`` numpy matrix that was plotted.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    n_classes = len(class_labels)
    conf_matrix = np.zeros((n_classes, n_classes))
    model.eval()
    with torch.no_grad():
        for img, cl in tqdm(val_loader):
            # softmax, not exp(): the head outputs logits, and exp(logits) is not a distribution.
            probs = torch.softmax(model(img.to(device)), dim=1).cpu().numpy()
            true = cl.cpu().numpy()
            if use_probs:
                # Add each sample's probability vector to the row of its true class.
                np.add.at(conf_matrix, true, probs)
            else:
                np.add.at(conf_matrix, (true, probs.argmax(axis=1)), 1.0)

    if normalize:
        # keepdims -> each ROW is divided by its own sum (plain sum(1) would
        # broadcast over columns instead).
        row_sums = conf_matrix.sum(axis=1, keepdims=True)
        conf_matrix = np.divide(
            conf_matrix, row_sums, out=np.zeros_like(conf_matrix), where=row_sums > 0
        )

    fig = plt.figure(figsize=(18, 10))
    fig.suptitle(f'Confusion matrix (norm={normalize}, use_probs={use_probs})')
    ax = fig.add_subplot(111)
    # matshow puts matrix rows on the y-axis (true) and columns on the x-axis (predicted).
    cax = ax.matshow(conf_matrix)
    fig.colorbar(cax)

    ticks = list(range(n_classes))
    ax.xaxis.set_ticks(ticks)
    ax.xaxis.set_ticklabels(class_labels)
    ax.set_xlabel('predicted class')

    ax.yaxis.set_ticks(ticks)
    ax.yaxis.set_ticklabels(class_labels)
    ax.set_ylabel('true class')

    for true_ind in range(n_classes):
        for pred_ind in range(n_classes):
            ax.text(
                pred_ind,
                true_ind,
                round(conf_matrix[true_ind, pred_ind], round_size),
                va='center',
                ha='center',
            )

    plt.show()

    return conf_matrix

In [None]:
pcm = real_confusion_matrix(
    model,
    val_loader,
    classes,
    use_probs=True,
    normalize=True,
    round_size=5,
)

In [None]:
pcm = real_confusion_matrix(
    model,
    val_loader,
    classes,
    use_probs=False,
    normalize=True,
    round_size=5,
)

### 3.3. Что видит каждый слой нейросети?

Посмотрим, как *примерно* выглядят входы каждого соответствующего слоя нейросети

In [None]:
# # Убрать hooks без handle:
# from collections import OrderedDict
# from typing import Dict, Callable

# module._forward_hooks: Dict[int, Callable] = OrderedDict()

In [None]:
# Filled by the pre-hook below with [module, first input] pairs, in call order.
layer_inputs = []
# Hook handles, kept so that the hooks can be removed later.
handles = []

def extract_input_pre_hook(m, inputs):
    """Forward pre-hook: record module ``m`` together with its first input."""
    # append mutates the module-level list in place, so no rebinding is needed.
    record = [m, inputs[0]]
    layer_inputs.append(record)
    return None

conv_layers = []
model = model.to("cpu")

for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        if "downsample" not in name:
            conv_layers.append(module)
            handles.append(module.register_forward_pre_hook(extract_input_pre_hook))
    elif isinstance(module, (nn.MaxPool2d, nn.AvgPool2d)):
        conv_layers.append(module)
        handles.append(module.register_forward_pre_hook(extract_input_pre_hook))

In [None]:
conv_layers

In [None]:
def min_max_scale(img):
    """Linearly rescale ``img`` (numpy array or torch tensor) into [0, 1].

    A constant image has zero range; it is returned as all zeros instead of
    dividing 0 by 0 (which would produce NaNs).
    """
    img = img - img.min()
    peak = img.max()
    if peak == 0:
        return img
    return img / peak

img_ind = 4

img = val_ds[img_ind][0]
plt.imshow(min_max_scale(de_normalize(img)))
plt.show()

In [None]:
import numpy as np

def get_hw(s):
    """Return a near-square grid shape ``(h, w)`` with ``h * w >= s`` for ``s`` plots.

    ``h`` is ``sqrt(s)`` rounded to the nearest integer, and ``w`` is
    ``ceil(s / h)``.
    """
    rows = int(np.round(np.sqrt(s)))
    full, remainder = divmod(s, rows)
    cols = full + (1 if remainder else 0)
    return rows, cols

h, w = get_hw(conv_layers[0].weight.shape[0])

fig, ax = plt.subplots(figsize=(35, 35))
fig.suptitle("Layer 1 kernels", y=0.9)
for i, kernel in enumerate(conv_layers[0].weight):
    plt.subplot(h, w, i + 1)
    plt.imshow(
        min_max_scale(
            kernel.detach().numpy().transpose(1, 2, 0)
        )
    )
    plt.axis('off')

plt.show()

In [None]:
model.eval()
out = model(img[None, :, :, :])

In [None]:
for handle in handles:
    handle.remove()

In [None]:
model.eval()
with torch.no_grad():
    fmaps = [conv_layers[0](img.unsqueeze(0))]
    for module in conv_layers[1:]:
        fmaps.append(module(fmaps[-1]))

In [None]:
def visualize_layer_fmap(fmaps, layer_ind, show_first=None, h=None):
    """Plot the feature maps of one layer as a grid of grayscale images.

    Args:
        fmaps: either a list of feature-map tensors (one per layer, first
            sample taken via ``[0]``), or the ``layer_inputs`` list of
            ``[module, input]`` pairs recorded by the forward pre-hooks. In the
            latter case, the "actual" maps of layer ``layer_ind`` are the
            recorded input of entry ``layer_ind + 1``.
        layer_ind: index of the layer to show.
        show_first: show at most this many (non-all-zero) channels; all by default.
        h: number of grid rows; chosen near-square when omitted.
    """
    actual = isinstance(fmaps[layer_ind], list)
    if actual:
        fmap = fmaps[layer_ind + 1][1][0]
    else:
        fmap = fmaps[layer_ind][0]
    n_channels = fmap.shape[0]
    # Never allocate more cells than there are channels to show.
    s = min(show_first, n_channels) if show_first else n_channels

    if h is None:
        h, w = get_hw(s)
    else:
        w = s // h + (s % h > 0)
    # squeeze=False -> always a 2-D array of axes, even for a 1x1 grid.
    fig, ax = plt.subplots(h, w, figsize=(w * 5, h * 5), squeeze=False)
    fig.suptitle(f"Layer #{layer_ind} {'actual' if actual else 'approximate'} feature maps", y=0.9)
    axes = ax.ravel()
    for cell in axes:
        cell.axis('off')  # also hides cells left empty

    shown = 0
    for fmap_img in fmap:
        if shown >= s:
            break
        if fmap_img.sum() == 0:
            continue  # skip dead (all-zero) channels
        axes[shown].imshow(min_max_scale(fmap_img.detach()), cmap="gray")
        shown += 1

    plt.show()

In [None]:
visualize_layer_fmap(fmaps, 0, show_first=10, h=2)

In [None]:
visualize_layer_fmap(layer_inputs, 0, show_first=10, h=2)

In [None]:
visualize_layer_fmap(fmaps, 5, show_first=10, h=2)

In [None]:
visualize_layer_fmap(layer_inputs, 5, show_first=10, h=2)

In [None]:
visualize_layer_fmap(fmaps, 10, show_first=10, h=2)

In [None]:
visualize_layer_fmap(layer_inputs, 10, show_first=10, h=2)

Не слишком информативно, хотя и даёт понимание о сложности интерпретации

### 3.4. Анализ относительно датасета

#### 3.4.1. Картинки, сильнее всего активирующиеся на фичи

In [None]:
# Accumulators filled while iterating over the validation set.
val_preds = []
val_conv_activations = []
val_pre_fc_preds = []
targets = []

def extract_input_hook(m, inputs, output):
    """Forward hook: store the module's first input and its output as CPU numpy arrays.

    Registered on ``model.avgpool`` in this cell, so the input is the last conv
    activation and the output is the pooled vector fed to ``fc``.
    """
    first_input = inputs[0]
    val_conv_activations.append(first_input.detach().cpu().numpy())
    val_pre_fc_preds.append(output.detach().cpu().numpy())
    return None

handle = model.avgpool.register_forward_hook(extract_input_hook)

model.eval()
model.to(device)
with torch.no_grad():
    for img, cl in tqdm(val_loader):
        out = model(img.to(device))
        val_preds.append(out.cpu().detach().numpy())
        targets.append(cl.cpu().detach().numpy())
model.cpu()

val_preds = np.concatenate(val_preds, axis=0)
val_pre_fc_preds = np.concatenate(val_pre_fc_preds, axis=0)
val_conv_activations = np.concatenate(val_conv_activations, axis=0)
targets = np.concatenate(targets, axis=0)

handle.remove()

In [None]:
val_preds.shape, val_pre_fc_preds.shape, val_conv_activations.shape, targets.shape

In [None]:
cl_to_ind = {cl: i for i, cl in enumerate(classes)}

In [None]:
selected_class = "cat"
ind = cl_to_ind[selected_class]

mask = (targets == ind).astype(bool)
pre_fc_weights = model.fc.weight.detach().numpy()[ind, :]

In [None]:
fig, ax = plt.subplots(figsize=(25, 5))
plt.plot(pre_fc_weights)
plt.show()

In [None]:
select_top = 4
best_features = np.argsort(-pre_fc_weights)
worst_features = best_features[-select_top:]
best_features = best_features[:select_top]

In [None]:
show_top = 5

fig, ax = plt.subplots(select_top * 2, show_top, figsize=(25, 10 * select_top))

for i, feature_ind in enumerate(best_features):
    best_img_inds = np.argsort(-val_pre_fc_preds[:, feature_ind, 0, 0])[:show_top]
    for j, img_ind in enumerate(best_img_inds):
        plt.subplot(select_top * 2, show_top, i * show_top + j + 1)
        plt.imshow(original_val_ds[img_ind][0])
        plt.title(f"Top #{i+1} best feature\nTop #{j+1} best image")

for i, feature_ind in enumerate(worst_features):
    worst_img_inds = np.argsort(-val_pre_fc_preds[:, feature_ind, 0, 0])[:show_top]
    for j, img_ind in enumerate(worst_img_inds):
        plt.subplot(select_top * 2, show_top, (select_top + i) * show_top + j + 1)
        plt.imshow(original_val_ds[img_ind][0])
        plt.title(f"Top #{select_top-i} worst feature\nTop #{j+1} best image")
        
plt.show()

#### 3.4.2. Как посчитать receptive_field нейрона на каком-то слое?

Картинка для понимания:

![receptive_field](receptive_field.jpg "Receptive Field")

Посчитаем относительно только одной из размерностей (H или W), для второй аналогично.

$$k_t - \text{Kernel size on layer t}$$
$$s_t - \text{Stride on layer t}$$
$$d_t - \text{Dilation on layer t}$$
$$r_t - \text{Receptive field of neuron on layer t}$$

1. $r_0 = 1$ - на 0-м слое каждый "нейрон" (пиксель) хранит в себе информацию только о самом себе
2. $r_1 = k_1$ - на 1-м слое каждый нейрон смотрит ровно на $k_1$ пикселей
3. $r_2 = (k_2 - 1) * s_1 + k_1$ - на 1-м слое один нейрон смотрит на $k_1$, а каждый следующий сдвинут от него на $s_1$, таких следующих нейронов $k_2 - 1$ штук, поэтому получаем такую формулу для 2-го слоя
4. Как перейти к общему случаю? $r_{t+1} = (k_{t+1} - 1) \cdot j_t + r_t \longrightarrow$ на t-ом слое один нейрон смотрит на $r_t$, а каждый следующий смещён от него на jump $j_t$ пикселей оригинального изображения, где $j_t = \prod_{i=1}^{t} s_i$
5. Как здесь участвует dilation? $r_{t+1} = ((k_{t+1} - 1) \cdot d_{t+1} + 1 - 1) \cdot j_t + r_t = (k_{t+1} - 1) \cdot d_{t+1} \cdot j_t + r_t\longrightarrow$ фактически увеличивает kernel size $k_{t}^* = (k_{t} - 1) \cdot d_{t} + 1$

**Итоговая формула:**
$$ r_{t+1} = (k_{t+1} - 1) \cdot d_{t+1} \cdot j_t + r_t = \sum_{i=1}^{t+1} \Big( (k_{i} - 1) \cdot d_{i} \cdot \prod_{j=1}^{i-1} s_j \Big) + 1$$


In [None]:
conv_layers

In [None]:
def get_receptive_fields(conv_layers):
    """Compute the receptive field size (along one spatial axis) after every layer.

    Uses the recurrence r_{t+1} = r_t + (k_{t+1} - 1) * d_{t+1} * j_t, where j_t is the
    product of the strides of the first t layers ("jump" in input pixels between
    neighbouring neurons). Tuple-valued hyper-parameters contribute their first (H)
    element; the W axis gives the same result only for square kernels/strides/dilations.

    Args:
        conv_layers: iterable of nn.Conv2d / nn.MaxPool2d / nn.AvgPool2d layers in
            forward order.

    Returns:
        List with one more entry than there are layers: entry 0 is the single input
        pixel (1), entry t + 1 is the receptive field of a neuron right after layer t.

    Raises:
        ValueError: if a layer of any other type is encountered.
    """
    def first(value):
        # Conv layers store (h, w) tuples; pooling layers usually keep plain ints.
        return value[0] if isinstance(value, (tuple, list)) else value

    res = [1]
    receptive_field = 1
    jump = 1
    for layer in conv_layers:
        if not isinstance(layer, (nn.Conv2d, nn.MaxPool2d, nn.AvgPool2d)):
            raise ValueError(f"Unknown layer type {type(layer)}")

        k = first(layer.kernel_size)
        s = first(layer.stride)
        # nn.AvgPool2d has no `dilation` attribute: it is always an undilated window.
        d = first(getattr(layer, "dilation", 1))

        # A dilated kernel of size k spans (k - 1) * d + 1 positions of the previous layer,
        # each step of which is `jump` input pixels apart.
        receptive_field += (k - 1) * d * jump
        jump *= s
        res.append(receptive_field)

    return res


receptive_fields = get_receptive_fields(conv_layers)

# Print, for every layer, its (H-axis) hyper-parameters and how it grows the receptive field.
for i, (layer_, from_, to_) in enumerate(zip(conv_layers, receptive_fields[:-1], receptive_fields[1:])):
    layer_type = type(layer_).__name__
    # Hyper-parameters may be ints (pooling) or (h, w) tuples (conv); show the H value.
    # nn.AvgPool2d has no `dilation` attribute, so it defaults to 1.
    k, s, d = (
        value[0] if isinstance(value, tuple) else value
        for value in (layer_.kernel_size, layer_.stride, getattr(layer_, "dilation", 1))
    )
    print(f"#{i: <2}: {layer_type: <9} - k={k} , s={s} , d={d}  ~  changed receptive field  ~  {from_: >3} --> {to_: <3}")

#### 3.4.3. Посмотрим, на какие части изображений активируются какие слои

In [None]:
from functools import partial
from collections import defaultdict, OrderedDict

chosen_layers = [3, 6, 9, 12]
max_num_channels = 10

# layer index -> list of per-batch activation arrays, filled by the forward hooks below
layer_outputs = defaultdict(list)


def extract_output_hook(m, inputs, output, layer_ind):
    """Forward hook: keep the first `max_num_channels` channels of the layer output as numpy."""
    first_channels = output[:, :max_num_channels, :, :]
    layer_outputs[layer_ind].append(first_channels.cpu().detach().numpy())


handles = []

for i, module in enumerate(conv_layers):
    if i not in chosen_layers:
        continue
    # Refuse to register if the layer already carries any forward hook (e.g. this cell was re-run).
    assert module._forward_hooks == OrderedDict(), f"Delete previous hooks first\n{i}\n{module}\n{module._forward_hooks}"
    hook = partial(extract_output_hook, layer_ind=i)
    handles.append(module.register_forward_hook(hook))

In [None]:
model.eval()
model.to(device)
try:
    with torch.no_grad():
        for img, _ in tqdm(val_loader):
            # Forward pass only for its side effect: the registered hooks store the activations.
            out = model(img.to(device))
finally:
    # Detach the hooks and move the model back even if the pass is interrupted, so later
    # forward passes do not keep appending to `layer_outputs`. Removing twice is harmless.
    for handle in handles:
        handle.remove()
    model.cpu()

# Merge the per-batch arrays into one array per layer. Entries that are already arrays
# (this cell re-run without re-registering hooks) are left alone: concatenating a 4D
# array along axis 0 would silently reshape it to (N * C, H, W).
for key, values in layer_outputs.items():
    if isinstance(values, list):
        layer_outputs[key] = np.concatenate(values, axis=0)

In [None]:
# Indices (into conv_layers) of the layers whose activations were collected.
layer_outputs.keys()

In [None]:
# (num_images, stored channels, H, W) activations of layer #3.
layer_outputs[3].shape

In [None]:
# # Force clear everything up
# for module in conv_layers:
#     module._forward_hooks = OrderedDict()

# del layer_outputs

In [None]:
# pad_size[t]: how many input pixels the receptive field of the top-left neuron after layer t
# sticks out before the image border, i.e. sum over layers i <= t of padding_i * jump_{i-1}.
pad_size = []
offset = 0
j = 1

for module in conv_layers:
    p, s = module.padding, module.stride
    if isinstance(p, tuple):
        # conv layers store (h, w) tuples; use the H value
        p, s = p[0], s[0]
    offset += p * j
    pad_size.append(offset)
    j *= s

print(pad_size)

In [None]:
def find_best_examples(layer_ind, num_channels=5, num_examples=5):
    """Show the image patches that most strongly activate channels of a hooked layer.

    For each of the first `num_channels` channels stored in `layer_outputs[layer_ind]`, take the
    `num_examples` validation images with the highest peak activation in that channel and plot
    the part of each image covered by the receptive field of the peak neuron.

    Args:
        layer_ind: index into `conv_layers` of a layer whose outputs were collected by the hooks.
        num_channels: number of leading channels to show; at most `max_num_channels`, since
            only that many channels were stored.
        num_examples: number of top images to show per channel.
    """
    assert num_channels <= max_num_channels, "This amount of channels wasn't computed"

    # Collected activations: (num_images, stored channels, H, W).
    activations = layer_outputs[layer_ind]
    # Input pixels by which the receptive field of neuron (0, 0) starts before the image border.
    p = pad_size[layer_ind]
    # Receptive field of a neuron of this layer (entry 0 of receptive_fields is the input pixel).
    r = receptive_fields[layer_ind + 1]
    # Distance in input pixels between neighbouring neurons of this layer: product of the
    # strides of all layers up to and including `layer_ind`. Without it, positions on a
    # strided (downsampled) feature map would be mistaken for input-pixel coordinates.
    jump = 1
    for ind, module in enumerate(conv_layers):
        if ind > layer_ind:
            break
        s = module.stride
        jump *= s[0] if isinstance(s, tuple) else s

    # res[c, e] = (val image index, top row, left column) of the receptive field to crop.
    res = np.zeros((num_channels, num_examples, 3), dtype=int)
    for channel_ind in range(num_channels):
        channel_acts = activations[:, channel_ind]
        best_images_inds = np.argsort(-channel_acts.max(axis=(1, 2)))[:num_examples]

        for j, image_ind in enumerate(best_images_inds):
            # (row, col) of the strongest activation on this layer's feature map.
            y, x = np.unravel_index(np.argmax(channel_acts[image_ind]), channel_acts[image_ind].shape)
            # Map the neuron position to the top-left corner of its receptive field in the input.
            res[channel_ind, j] = (image_ind, y * jump - p, x * jump - p)

    fig, ax = plt.subplots(num_channels, num_examples, figsize=(5 * num_examples, 5 * num_channels))
    fig.suptitle(f"Layer #{layer_ind} activations examples", y=0.9)

    for i in range(num_channels):
        for j in range(num_examples):
            image_ind, top, left = res[i, j]
            plt.subplot(num_channels, num_examples, i * num_examples + j + 1)
            img = np.clip(de_normalize(val_ds[image_ind][0]), 0, 1)
            # A negative top/left means the field starts in the padding: clamp the start to 0.
            img = img[max(top, 0): top + r, max(left, 0): left + r, :]
            plt.imshow(img)
            plt.title(f"Channel #{i+1}\nImage #{j+1} with top activation")
            plt.axis('off')

    plt.show()

In [None]:
# Top-activating patches for layer #3.
find_best_examples(layer_ind=3, num_channels=10, num_examples=8)

In [None]:
# Top-activating patches for layer #6.
find_best_examples(layer_ind=6, num_channels=10, num_examples=8)

In [None]:
# Top-activating patches for layer #9.
find_best_examples(layer_ind=9, num_channels=10, num_examples=8)

In [None]:
# Top-activating patches for layer #12.
find_best_examples(layer_ind=12, num_channels=10, num_examples=8)

In [None]:
# Drop the collected activations to free memory.
del layer_outputs

#### 3.4.4. Какая часть изображения ответственна за результат классификации?

In [None]:
# Validation image used in the class-activation-map walkthrough below.
img_ind = 4

img = val_ds[img_ind][0]
# Display copy (de-normalized and min-max scaled); `img` itself is what gets fed to the model.
img_vis = min_max_scale(de_normalize(img))
plt.imshow(img_vis)
plt.show()

In [None]:
from torch.nn import functional as F

# Class probabilities of the chosen image, printed one class per line.
model.eval()
with torch.no_grad():
    logits = model(img.unsqueeze(0))
    out = logits[0].detach()
    probs = F.softmax(out, dim=0).numpy()

pred_cl_ind = np.argmax(probs)

lines = [f"{name: <10} ~ {str(round(prob, 4))}" for name, prob in zip(classes, probs)]
print("\n".join(lines))

In [None]:
import cv2

# Sanity check: the probabilities just computed match the cached validation predictions.
assert np.allclose(
    probs,
    F.softmax(torch.tensor(val_preds[img_ind]), dim=0).numpy(),
)

# Per-channel weight for the predicted class: fc weight times the pooled feature value.
relevance = model.fc.weight[pred_cl_ind].detach().numpy() * val_pre_fc_preds[img_ind, :, 0, 0]

# class activation map: relevance-weighted sum of the feature maps over channels
cam = (val_conv_activations[img_ind] * relevance[:, None, None]).sum(0)

# Upsample to exactly the size of the displayed image so the overlay lines up.
# cv2.resize takes dsize as (width, height), while numpy images are (height, width, ...).
cam = cv2.resize(cam, (img_vis.shape[1], img_vis.shape[0]))

cam = min_max_scale(cam)

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(20, 5))

# (subplot position, title, pictures drawn on top of each other as (array, alpha))
panels = (
    (131, "Image", ((img_vis, None),)),
    (132, "Class Activation Map", ((cam, None),)),
    (133, "Both", ((img_vis, 0.5), (cam, 0.5))),
)
for position, title, pictures in panels:
    plt.subplot(position)
    plt.title(title)
    for picture, alpha in pictures:
        plt.imshow(picture, alpha=alpha)
    plt.axis("off")

plt.show()

In [None]:
def get_cam(img_ind):
    """Print the class probabilities of a validation image and plot its class activation map.

    Uses the cached validation results (`val_preds`, `val_pre_fc_preds`,
    `val_conv_activations`) plus the weights of `model.fc`; the model is not re-evaluated.
    In the printed list the true class is coloured green and the predicted class red.

    Args:
        img_ind: index of the image in `val_ds` and in the cached arrays.
    """
    img = val_ds[img_ind][0]
    img = min_max_scale(de_normalize(img))

    probs = F.softmax(torch.tensor(val_preds[img_ind]), dim=0).numpy()

    pred_cl_ind = np.argmax(probs)

    to_print = list(map(lambda el: f"{el[0]: <10} ~ {str(round(el[1], 4))}", zip(classes, probs)))
    to_print[targets[img_ind]] = colored(to_print[targets[img_ind]], "green")
    to_print[pred_cl_ind] = colored(to_print[pred_cl_ind], "red")
    print("\n".join(to_print))

    # Per-channel weight for the predicted class: fc weight times the pooled feature value.
    relevance = model.fc.weight[pred_cl_ind].detach().numpy() * val_pre_fc_preds[img_ind, :, 0, 0]

    # class activation map: relevance-weighted sum of the feature maps over channels
    cam = (val_conv_activations[img_ind] * relevance[:, None, None]).sum(0)

    # Upsample to exactly the size of the displayed image so the overlay lines up.
    # cv2.resize takes dsize as (width, height), while numpy images are (height, width, ...).
    cam = cv2.resize(cam, (img.shape[1], img.shape[0]))

    cam = min_max_scale(cam)

    fig, ax = plt.subplots(1, 3, figsize=(20, 5))

    plt.subplot(131)
    plt.title("Image")
    plt.imshow(img)
    plt.axis("off")

    plt.subplot(132)
    plt.title("Class Activation Map")
    plt.imshow(cam)
    plt.axis("off")

    plt.subplot(133)
    plt.title("Both")
    plt.imshow(img, alpha=0.5)
    plt.imshow(cam, alpha=0.5)
    plt.axis("off")

    plt.show()

In [None]:
# CAM for validation image #0.
get_cam(0)

In [None]:
# CAM for validation image #1.
get_cam(1)

In [None]:
# CAM for validation image #2.
get_cam(2)

In [None]:
# CAM for validation image #3.
get_cam(3)

In [None]:
# Row-wise softmax of the validation logits. Subtract each row's own max (not the global max):
# with the global max, a row whose logits are all far below it underflows to 0 / 0 = nan.
val_exp = np.exp(val_preds - val_preds.max(axis=1, keepdims=True))
val_probs = val_exp / val_exp.sum(axis=1, keepdims=True)

In [None]:
# Indices of the 5 validation images whose top-class probability is lowest (least confident).
# (The variable name is misspelled but is kept, because the cells below reference it.)
least_configdent = np.argsort(val_probs.max(1))[:5]
least_configdent

In [None]:
# CAM for the least confident validation image.
get_cam(least_configdent[0])

In [None]:
# CAM for the 2nd least confident validation image.
get_cam(least_configdent[1])

In [None]:
# CAM for the 4th least confident validation image (index 2 is not shown in these cells).
get_cam(least_configdent[3])

In [None]:
# CAM for the 5th least confident validation image.
get_cam(least_configdent[4])

### 3.5. Как посмотреть, что ищет слой нейросети, не по данным, а по самой модели?

**Feature Visualization by Optimization:**

Формально говоря, модель дифференцируема и по своим входам, так почему бы нам не "обучить" идеальный вход для конкретного нейрона / канала / слоя / выхода / ...?

In [None]:
# Freeze every weight: from now on only the synthetic input images are optimized.
model.requires_grad_(False)

#### 3.5.1. Представление класса (Gradient based)

In [None]:
from IPython.display import clear_output

# Gradient ascent on the input: start every synthetic image from noise and repeatedly step it
# in the direction that increases one class logit (image #k is optimized for class #k).
max_generation_epochs = 4000
alpha = 2.  # gradient-ascent step size
alpha_reduce_each = 500  # every this many epochs the step size is multiplied by alpha_reduce_mul
alpha_reduce_mul = 0.5
draw_each = 100  # redraw the progress plots every this many epochs
chosen_class = "cat"  # class whose image is shown in the before/after comparison

class_ind = cl_to_ind[chosen_class]
img_shape = val_ds[0][0].shape
# One image per class, initialised with uniform noise from torch.rand.
synthetic_image = torch.nn.Parameter(torch.rand(len(classes), *img_shape), requires_grad=True)
starting_image = synthetic_image[class_ind].detach().numpy().transpose(1, 2, 0)

# class index -> history of that image's target probability / target logit
generated_probs = defaultdict(list)
optimized_values = defaultdict(list)
xx = np.arange(len(classes))

model.eval()
model.to(device)
synthetic_image = synthetic_image.to(device)
# presumably needed because `.to(device)` returns a non-leaf copy when the device changes
synthetic_image.retain_grad()
for epoch in trange(max_generation_epochs):
    # clear the previous step's gradient (backward() accumulates into .grad)
    synthetic_image.grad = None
    out = model(synthetic_image)
    # logit of class k for image k
    pred_class_logit = out[xx, xx]
    pred_class_logit.sum().backward()
    # plain gradient-ascent step, applied through .data so it is not tracked by autograd
    synthetic_image.data = synthetic_image.data + alpha * synthetic_image.grad
    
    with torch.no_grad():
        probs = F.softmax(out, dim=1)
        pred_class_prob = probs[xx, xx]
        for ind in xx:
            generated_probs[ind].append(pred_class_prob[ind].item())
            optimized_values[ind].append(pred_class_logit[ind].item())
        if (epoch + 1) % draw_each == 0:
            clear_output(True)
            fig = plt.subplots(2, 1, figsize=(20, 8))
            
            plt.subplot(211)
            for ind in xx:
                plt.plot(generated_probs[ind])
            plt.title("Probability of target class")
            plt.xlabel("epoch")
            plt.ylabel("probability")
            plt.yscale("log")
            
            plt.subplot(212)
            for ind in xx:
                plt.plot(optimized_values[ind])
            plt.title("Logit of target class")
            plt.xlabel("epoch")
            
            plt.show()
        # step-size schedule
        if (epoch + 1) % alpha_reduce_each == 0:
            alpha *= alpha_reduce_mul

model.cpu()
synthetic_image = synthetic_image.cpu()

# Before/after comparison for the chosen class.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

plt.subplot(121)
plt.imshow(starting_image)
plt.title("Starting image")
plt.axis("off")

plt.subplot(122)
plt.imshow(min_max_scale(synthetic_image[class_ind].detach().numpy().transpose(1, 2, 0)))
plt.title("Resulting image")
plt.axis("off")

plt.show()

In [None]:
# One panel per class showing the synthetic image optimized for it, 5 panels per row.
# The grid size follows len(classes) instead of a hard-coded 10 classes / 2x5 layout.
n_images = len(classes)
n_cols = 5
n_rows = max(1, (n_images + n_cols - 1) // n_cols)
fig, ax = plt.subplots(n_rows, n_cols, figsize=(25, 6 * n_rows), squeeze=False)

for ind, axis in enumerate(ax.flat):
    if ind < n_images:
        axis.imshow(min_max_scale(synthetic_image[ind].detach().numpy().transpose(1, 2, 0)))
        axis.set_title(classes[ind])
    # also hides unused trailing panels when the class count is not a multiple of n_cols
    axis.axis("off")

plt.show()

Не особенно на что-то похоже...

Потому что мы никак не ограничиваем изображение: ничто не заставляет его быть хоть сколько-то похожим на изображение в нашем понимании.

Добавим член, отвечающий за L2-регуляризацию изображения.

In [None]:
from IPython.display import clear_output

# Same gradient ascent as before, but the objective also penalizes the L2 norm of every
# synthetic image (weight l2_coef), which keeps pixel values from growing without bound.
max_generation_epochs = 4000
alpha = 4.  # gradient-ascent step size
alpha_reduce_each = 500  # every this many epochs the step size is multiplied by alpha_reduce_mul
alpha_reduce_mul = 0.5
draw_each = 100  # redraw the progress plots every this many epochs
chosen_class = "cat"  # class whose image is shown in the before/after comparison
l2_coef = 4.  # weight of the L2 penalty

class_ind = cl_to_ind[chosen_class]
img_shape = val_ds[0][0].shape
# One image per class, initialised with uniform noise from torch.rand.
synthetic_image = torch.nn.Parameter(torch.rand(len(classes), *img_shape), requires_grad=True)
starting_image = synthetic_image[class_ind].detach().numpy().transpose(1, 2, 0)

# class index -> history of that image's target probability / target logit
generated_probs = defaultdict(list)
optimized_values = defaultdict(list)
xx = np.arange(len(classes))

model.eval()
model.to(device)
synthetic_image = synthetic_image.to(device)
# presumably needed because `.to(device)` returns a non-leaf copy when the device changes
synthetic_image.retain_grad()
for epoch in trange(max_generation_epochs):
    # clear the previous step's gradient (backward() accumulates into .grad)
    synthetic_image.grad = None
    out = model(synthetic_image)
    # logit of class k for image k
    pred_class_logit = out[xx, xx]
    # per-image L2 norm over (C, H, W), summed over images
    l2_norm = (synthetic_image**2).sum(dim=(1,2,3)).sqrt().sum(0)
    (pred_class_logit.sum() - l2_coef * l2_norm).backward()
    # plain gradient-ascent step, applied through .data so it is not tracked by autograd
    synthetic_image.data = synthetic_image.data + alpha * synthetic_image.grad
    
    with torch.no_grad():
        probs = F.softmax(out, dim=1)
        pred_class_prob = probs[xx, xx]
        for ind in xx:
            generated_probs[ind].append(pred_class_prob[ind].item())
            optimized_values[ind].append(pred_class_logit[ind].item())
        if (epoch + 1) % draw_each == 0:
            clear_output(True)
            fig = plt.subplots(2, 1, figsize=(20, 8))
            
            plt.subplot(211)
            for ind in xx:
                plt.plot(generated_probs[ind])
            plt.title("Probability of target class")
            plt.xlabel("epoch")
            plt.ylabel("probability")
            plt.yscale("log")
            
            plt.subplot(212)
            for ind in xx:
                plt.plot(optimized_values[ind])
            plt.title("Logit of target class")
            plt.xlabel("epoch")
            
            plt.show()
        # step-size schedule
        if (epoch + 1) % alpha_reduce_each == 0:
            alpha *= alpha_reduce_mul

model.cpu()
synthetic_image = synthetic_image.cpu()


# Before/after comparison for the chosen class.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

plt.subplot(121)
plt.imshow(starting_image)
plt.title("Starting image")
plt.axis("off")

plt.subplot(122)
plt.imshow(min_max_scale(synthetic_image[class_ind].detach().numpy().transpose(1, 2, 0)))
plt.title("Resulting image")
plt.axis("off")

plt.show()

In [None]:
# One panel per class showing the synthetic image optimized for it, 5 panels per row.
# The grid size follows len(classes) instead of a hard-coded 10 classes / 2x5 layout.
n_images = len(classes)
n_cols = 5
n_rows = max(1, (n_images + n_cols - 1) // n_cols)
fig, ax = plt.subplots(n_rows, n_cols, figsize=(30, 7.5 * n_rows), squeeze=False)

for ind, axis in enumerate(ax.flat):
    if ind < n_images:
        axis.imshow(min_max_scale(synthetic_image[ind].detach().numpy().transpose(1, 2, 0)))
        axis.set_title(classes[ind])
    # also hides unused trailing panels when the class count is not a multiple of n_cols
    axis.axis("off")

plt.show()

In [None]:
# Per-channel mean/std of the last transform in val_transform (presumably the Normalize step),
# each shaped (1, C) and placed on the optimization device.
_last_tf = val_transform.transforms[-1]
mean_tensor = torch.tensor(_last_tf.mean)[None].to(device)
std_tensor = torch.tensor(_last_tf.std)[None].to(device)

In [None]:
from IPython.display import clear_output

# Same gradient ascent as before, but the objective penalizes the squared difference between the
# per-channel mean/std of the synthetic batch and mean_tensor/std_tensor (weight reg_coef).
max_generation_epochs = 4000
alpha = 4.  # gradient-ascent step size
alpha_reduce_each = 500  # every this many epochs the step size is multiplied by alpha_reduce_mul
alpha_reduce_mul = 0.5
draw_each = 100  # redraw the progress plots every this many epochs
chosen_class = "cat"  # class whose image is shown in the before/after comparison
reg_coef = 10.  # weight of the statistics regularizer

class_ind = cl_to_ind[chosen_class]
img_shape = val_ds[0][0].shape
# One image per class, initialised with uniform noise from torch.rand.
synthetic_image = torch.nn.Parameter(torch.rand(len(classes), *img_shape), requires_grad=True)
starting_image = synthetic_image[class_ind].detach().numpy().transpose(1, 2, 0)

# class index -> history of that image's target probability / target logit
generated_probs = defaultdict(list)
optimized_values = defaultdict(list)
xx = np.arange(len(classes))

model.eval()
model.to(device)
synthetic_image = synthetic_image.to(device)
# presumably needed because `.to(device)` returns a non-leaf copy when the device changes
synthetic_image.retain_grad()
for epoch in trange(max_generation_epochs):
    # clear the previous step's gradient (backward() accumulates into .grad)
    synthetic_image.grad = None
    out = model(synthetic_image)
    # logit of class k for image k
    pred_class_logit = out[xx, xx]
    # Per-channel mean/std over all images and pixels, pulled towards mean_tensor / std_tensor.
    # NOTE(review): synthetic_image is fed to the model directly, i.e. it lives in the
    # post-Normalize input space, where natural images have per-channel mean ~0 and std ~1;
    # targeting the raw Normalize constants here looks like a space mismatch — confirm intent.
    regularization = ((synthetic_image.mean((0, 2, 3)) - mean_tensor)**2).sum() + ((synthetic_image.std((0, 2, 3)) - std_tensor)**2).sum()
    (pred_class_logit.sum() - reg_coef * regularization).backward()
    # plain gradient-ascent step, applied through .data so it is not tracked by autograd
    synthetic_image.data = synthetic_image.data + alpha * synthetic_image.grad
    
    with torch.no_grad():
        probs = F.softmax(out, dim=1)
        pred_class_prob = probs[xx, xx]
        for ind in xx:
            generated_probs[ind].append(pred_class_prob[ind].item())
            optimized_values[ind].append(pred_class_logit[ind].item())
        if (epoch + 1) % draw_each == 0:
            clear_output(True)
            fig = plt.subplots(2, 1, figsize=(20, 8))
            
            plt.subplot(211)
            for ind in xx:
                plt.plot(generated_probs[ind])
            plt.title("Probability of target class")
            plt.xlabel("epoch")
            plt.ylabel("probability")
            plt.yscale("log")
            
            plt.subplot(212)
            for ind in xx:
                plt.plot(optimized_values[ind])
            plt.title("Logit of target class")
            plt.xlabel("epoch")
            
            plt.show()
        # step-size schedule
        if (epoch + 1) % alpha_reduce_each == 0:
            alpha *= alpha_reduce_mul

model.cpu()
synthetic_image = synthetic_image.cpu()


# Before/after comparison for the chosen class.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

plt.subplot(121)
plt.imshow(starting_image)
plt.title("Starting image")
plt.axis("off")

plt.subplot(122)
plt.imshow(min_max_scale(synthetic_image[class_ind].detach().numpy().transpose(1, 2, 0)))
plt.title("Resulting image")
plt.axis("off")

plt.show()

In [None]:
# One panel per class showing the synthetic image optimized for it, 5 panels per row.
# The grid size follows len(classes) instead of a hard-coded 10 classes / 2x5 layout.
n_images = len(classes)
n_cols = 5
n_rows = max(1, (n_images + n_cols - 1) // n_cols)
fig, ax = plt.subplots(n_rows, n_cols, figsize=(30, 7.5 * n_rows), squeeze=False)

for ind, axis in enumerate(ax.flat):
    if ind < n_images:
        axis.imshow(min_max_scale(synthetic_image[ind].detach().numpy().transpose(1, 2, 0)))
        axis.set_title(classes[ind])
    # also hides unused trailing panels when the class count is not a multiple of n_cols
    axis.axis("off")

plt.show()

Чтобы получить лучшие результаты оптимизационными и генеративными методами, нужно ещё много всего, поэтому сейчас остановимся здесь. Больше найдёте в специализированных курсах по CV, NLP, генеративным моделям и т. д.

Полезные ссылки, если хочется посмотреть / попробовать ещё:

* https://distill.pub/2017/feature-visualization
* https://yosinski.com/deepvis
* https://github.com/utkuozbulak/pytorch-cnn-visualizations
* https://github.com/yosinski/deep-visualization-toolbox
* https://github.com/tensorflow/lucid
