In [2]:
import torch
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor
import numpy as np

import os
os.chdir("../")

import torchvision
from dataloaders.tasks_provider import prepare_classes_list, TaskList


In [3]:
os.listdir()

['.git', '.gitignore', '.idea', 'files', 'myCode', 'outputs', 'wandb']

In [4]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)

Using cache found in C:\Users\QbaSo/.cache\torch\hub\pytorch_vision_v0.10.0


In [5]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [6]:
nodes = get_graph_node_names(model)[1] # 1 mean model.eval layers

nodes_to_extract = []

for node in nodes:
    if "relu" in node:
        nodes_to_extract.append(node)

nodes_to_extract = nodes_to_extract[-5:] # using only last 5 layers
print(nodes_to_extract)

['layer3.1.relu_1', 'layer4.0.relu', 'layer4.0.relu_1', 'layer4.1.relu', 'layer4.1.relu_1']


In [55]:
model_feature_extractor = create_feature_extractor(model, return_nodes=nodes_to_extract)

In [8]:
num_classes, classes_per_task = 10, 2
classes_list = prepare_classes_list(num_classes, classes_per_task)
tasks = TaskList(classes_list, 128, torchvision.datasets.CIFAR10)
tasks_test = TaskList(classes_list, 128, torchvision.datasets.CIFAR10, train=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to files/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting files/cifar-10-python.tar.gz to files/
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [30]:
for x, y in tasks.tasks[0].dataloader:
    break

In [75]:
extracted_features = model_feature_extractor(x)

layers = list(extracted_features.keys())

activations = extracted_features[layers[0]].flatten(1).cpu().detach().numpy()
for layer in layers[1:]:
    activations = np.concatenate((activations, extracted_features[layer].flatten(1).cpu().detach().numpy()), axis=1)

In [83]:
activations.sum(0)

array([ 36.78164 ,  16.19849 ,  23.470175, ..., 108.06283 , 111.987175,
       110.30343 ], dtype=float32)