In [1]:
import os
import sys
from pathlib import Path

In [2]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so edits to imported source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import torch.nn as nn
from torch.nn import Sequential
from torchinfo import summary

from baard.classifiers import CIFAR10_ResNet18

In [4]:
# Project locations are resolved relative to the parent of the current working
# directory, so the notebook is expected to be launched from a sub-folder of
# the project root.
PATH_ROOT = Path.cwd().parent
# Kept as plain strings (not Path objects), matching what os.path.join yields.
PATH_DATA = str(PATH_ROOT / 'data')
PATH_CHECKPOINT = str(PATH_ROOT / 'pretrained_clf' / 'cifar10_resnet18.ckpt')
print(PATH_CHECKPOINT)

/home/lukec/workspace/baard_v4/pretrained_clf/cifar10_resnet18.ckpt


In [5]:
# Restore the classifier from the checkpoint file configured above.
model = CIFAR10_ResNet18.load_from_checkpoint(PATH_CHECKPOINT)

# The summary runs on CPU with a batch of 3x32x32 inputs; the batch dimension
# is taken from the model's own training dataloader.
device = 'cpu'
batch_size = model.train_dataloader().batch_size
input_size = (batch_size,) + (3, 32, 32)

# Layer-by-layer overview (output shapes and parameter counts) of the full model.
results = summary(model, input_size=input_size, device=device)
print(results)

Files already downloaded and verified
Layer (type:depth-idx)                        Output Shape              Param #
CIFAR10_ResNet18                              [256, 10]                 --
├─ResNet: 1-1                                 [256, 10]                 --
│    └─Conv2d: 2-1                            [256, 64, 32, 32]         1,728
│    └─BatchNorm2d: 2-2                       [256, 64, 32, 32]         128
│    └─ReLU: 2-3                              [256, 64, 32, 32]         --
│    └─Identity: 2-4                          [256, 64, 32, 32]         --
│    └─Sequential: 2-5                        [256, 64, 32, 32]         --
│    │    └─BasicBlock: 3-1                   [256, 64, 32, 32]         73,984
│    │    └─BasicBlock: 3-2                   [256, 64, 32, 32]         73,984
│    └─Sequential: 2-6                        [256, 128, 16, 16]        --
│    │    └─BasicBlock: 3-3                   [256, 128, 16, 16]        230,144
│    │    └─BasicBlock: 3-4             

In [6]:
# Split the wrapped network (`model.model`) into its direct child modules and
# print each one tagged with its position, separated by blank lines.
resnet18_list = [*model.model.children()]
labelled_layers = [(f'Layer: {idx}', layer) for idx, layer in enumerate(resnet18_list)]
print('\n\n'.join(map(str, labelled_layers)))

('Layer: 0', Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))

('Layer: 1', BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))

('Layer: 2', ReLU(inplace=True))

('Layer: 3', Identity())

('Layer: 4', Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),

In [7]:
# Summarise the network with its last child module removed, to inspect the
# shape produced just before that final layer.
headless_layers = resnet18_list[:-1]
results = summary(
    Sequential(*headless_layers),
    input_size=input_size,
    device=device,
)
print(results)

Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [256, 512, 1, 1]          --
├─Conv2d: 1-1                            [256, 64, 32, 32]         1,728
├─BatchNorm2d: 1-2                       [256, 64, 32, 32]         128
├─ReLU: 1-3                              [256, 64, 32, 32]         --
├─Identity: 1-4                          [256, 64, 32, 32]         --
├─Sequential: 1-5                        [256, 64, 32, 32]         --
│    └─BasicBlock: 2-1                   [256, 64, 32, 32]         --
│    │    └─Conv2d: 3-1                  [256, 64, 32, 32]         36,864
│    │    └─BatchNorm2d: 3-2             [256, 64, 32, 32]         128
│    │    └─ReLU: 3-3                    [256, 64, 32, 32]         --
│    │    └─Conv2d: 3-4                  [256, 64, 32, 32]         36,864
│    │    └─BatchNorm2d: 3-5             [256, 64, 32, 32]         128
│    │    └─ReLU: 3-6                    [256, 64, 32, 32]         --
│

In [8]:
# Re-collect the child modules so this cell does not depend on earlier slicing.
resnet18_list = list(model.model.children())

# Feature extractor: every child except the last, followed by a flatten so the
# output is one flat vector per input sample.
feature_layers = resnet18_list[:-1]
latent_net = nn.Sequential(*feature_layers, nn.Flatten(start_dim=1))

# Weight tensor of the last child module (the layer dropped from `latent_net`).
weight = resnet18_list[-1].weight
print(weight.size())

results = summary(latent_net, input_size=input_size, device=device)
print(results)

torch.Size([10, 512])
Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [256, 512]                --
├─Conv2d: 1-1                            [256, 64, 32, 32]         1,728
├─BatchNorm2d: 1-2                       [256, 64, 32, 32]         128
├─ReLU: 1-3                              [256, 64, 32, 32]         --
├─Identity: 1-4                          [256, 64, 32, 32]         --
├─Sequential: 1-5                        [256, 64, 32, 32]         --
│    └─BasicBlock: 2-1                   [256, 64, 32, 32]         --
│    │    └─Conv2d: 3-1                  [256, 64, 32, 32]         36,864
│    │    └─BatchNorm2d: 3-2             [256, 64, 32, 32]         128
│    │    └─ReLU: 3-3                    [256, 64, 32, 32]         --
│    │    └─Conv2d: 3-4                  [256, 64, 32, 32]         36,864
│    │    └─BatchNorm2d: 3-5             [256, 64, 32, 32]         128
│    │    └─ReLU: 3-6                    [256, 64

In [9]:
# Probe the feature-map shape produced by the first seven child modules.
net = Sequential(*resnet18_list[:7])
# The wrapped layers are the same objects used by `model`. Switch them to eval
# mode so BatchNorm layers use their stored running statistics instead of
# updating them from the random probe batch below.
net.eval()
X = torch.rand((10, 3, 32, 32))  # random batch of 10 inputs; only the shape matters
# No gradients are needed for a shape check, so skip building the autograd graph.
with torch.no_grad():
    outputs = net(X)
print(outputs.size())

torch.Size([10, 256, 8, 8])


In [10]:
# Same seven-layer prefix as before, followed by an adaptive max pool that
# brings the output to 512 x 1 x 1 per sample.
# NOTE(review): AdaptiveMaxPool3d is given a 4-D tensor here, which PyTorch
# documents as an unbatched (C, D, H, W) input. The batch axis is therefore
# treated as channels and the 256 feature channels are pooled as the depth
# axis up to 512. Confirm this is the intended way to reach 512 features.
net = Sequential(*resnet18_list[:7], nn.AdaptiveMaxPool3d((512, 1, 1)))
# The wrapped layers are the same objects used by `model`. Switch them to eval
# mode so BatchNorm layers use their stored running statistics instead of
# updating them from the random probe batch below.
net.eval()
X = torch.rand((10, 3, 32, 32))  # random batch of 10 inputs; only the shape matters
# No gradients are needed for a shape check, so skip building the autograd graph.
with torch.no_grad():
    outputs = net(X)
print(outputs.size())

torch.Size([10, 512, 1, 1])
