### Import packages

In [1]:
import torch
import torch.nn as nn
from torchvision.models.feature_extraction import get_graph_node_names

from models.pim_module.pim_module import PluginMoodel

### Custom model

In [2]:
class Model(nn.Module):
    """Tiny two-stage CNN classifier used as the example backbone.

    Every stage is (Conv 3x3 -> BatchNorm -> ReLU) applied twice, with the
    second convolution using stride 2. The head is global average pooling
    followed by a 10-way linear layer.
    """

    @staticmethod
    def _make_stage(in_channels, out_channels):
        """Return one stage as a 6-layer ``nn.Sequential`` (indices 0..5)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()
        # Attribute names and layer order are kept so that sub-module names
        # remain 'conv1.0' ... 'conv2.5'.
        self.conv1 = self._make_stage(3, 64)
        self.conv2 = self._make_stage(64, 128)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(128, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for the image batch ``x``."""
        features = self.conv2(self.conv1(x))
        pooled = self.pool(features).flatten(1)
        return self.classifier(pooled)

In [3]:
# Instantiate the custom backbone defined above.
model = Model()

### Get the model's layer (node) names

In [4]:
print(model)  # module structure: every sub-module with its name/index

Model(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (conv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (classifier): Linear(in_features=128, out_features=10, bias=True)
)


In [5]:
# get_graph_node_names returns two lists of node names (see the tuple printed below);
# per the torchvision docs these are the train-mode and the eval-mode graph nodes.
# These names are the valid keys for `return_nodes`.
print(get_graph_node_names(model))

(['x', 'conv1.0', 'conv1.1', 'conv1.2', 'conv1.3', 'conv1.4', 'conv1.5', 'conv2.0', 'conv2.1', 'conv2.2', 'conv2.3', 'conv2.4', 'conv2.5', 'pool', 'flatten', 'classifier'], ['x', 'conv1.0', 'conv1.1', 'conv1.2', 'conv1.3', 'conv1.4', 'conv1.5', 'conv2.0', 'conv2.1', 'conv2.2', 'conv2.3', 'conv2.4', 'conv2.5', 'pool', 'flatten', 'classifier'])


### Prepare the arguments needed to build PluginMoodel

In [6]:
# Graph node name -> output name. Here we take the outputs of conv1 and conv2
# (the last layer, index 5, of each Sequential).
return_nodes = {"conv1.5": "layer1", "conv2.5": "layer2"}

In [7]:
# The keys ('layer1', 'layer2') must match the values of return_nodes.
# NOTE(review): 64 is presumably the number of selections kept per layer -- confirm against PluginMoodel.
num_selects = dict(layer1=64, layer2=64)

In [8]:
IMG_SIZE = 224  # passed to PluginMoodel as img_size
USE_FPN = True  # passed to PluginMoodel as use_fpn
FPN_SIZE = 128  # FPN projection size; if FPN is not used, fpn_size can be set to None

In [9]:
# proj_type: choose 'Conv' or 'Linear'.
#   'Conv'   is designed for 4-D image input (ResNet, EfficientNet, VGG, ...).
#   'Linear' is for 3-D image input (ViT, Swin-T, ...).
PROJ_TYPE = "Conv"

In [10]:
# upsample_type: one of ["Bilinear", "Conv", "Fc"].
# For convolutional neural networks (e.g. ResNet, EfficientNet), 'Bilinear' is recommended.
# For ViT use "Fc"; for Swin-T use "Conv".
UPSAMPLE_TYPE = "Bilinear"

In [11]:
# Wrap the backbone with the PIM plug-in, using the settings prepared above.
pim_model = PluginMoodel(
    backbone=model,
    return_nodes=return_nodes,
    img_size=IMG_SIZE,
    use_fpn=USE_FPN,
    fpn_size=FPN_SIZE,
    proj_type=PROJ_TYPE,
    upsample_type=UPSAMPLE_TYPE,
    use_selection=True,
    num_classes=10,
    num_selects=num_selects,
    use_combiner=True,
    comb_proj_size=None,
)

In [12]:
# Smoke-test the wrapped model with one random image. IMG_SIZE is reused
# (instead of a hard-coded 224) so the input always matches the img_size
# the PIM model was built with.
rand_inp = torch.randn(1, 3, IMG_SIZE, IMG_SIZE)
outs = pim_model(rand_inp)



In [13]:
# Show the names of the outputs returned by the PIM model.
print(list(outs))

['layer1', 'layer2', 'preds_1', 'preds_0', 'comb_outs']


In [14]:
# Meaning of each key in `outs`:
# 'layer1'    : logits of 'layer1', size [B, num_classes]
# 'layer2'    : logits of 'layer2', size [B, num_classes]
# 'preds_1'   : (dict) logits of the selected regions, size [B, num_classes]
# 'preds_0'   : (dict) logits of the regions that were NOT selected, size [B, num_classes]
# 'comb_outs' : logits of the Combiner, size [B, num_classes]

### What to do if get_graph_node_names() or create_feature_extractor() raises an error

In [15]:
### Change the model: forward() now returns the intermediate feature maps as a dict

In [16]:
class Model(nn.Module):
    """Variant of the demo backbone whose ``forward`` returns named features.

    The layers are the same as in the classifier version, but ``forward``
    returns the outputs of the two conv stages as a dict instead of class
    logits, so no graph-node extraction is needed.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        # The classification head stays registered so the module's parameters
        # and state_dict keys are unchanged; forward() does not use it.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(128, 10)

    def forward(self, x):
        """Return ``{"layer1": conv1 output, "layer2": conv2 output}``.

        The pooling/classifier head is intentionally not evaluated: its
        result was never part of the returned dict, so computing it was
        wasted work.
        """
        x1 = self.conv1(x)
        x2 = self.conv2(x1)
        return {"layer1": x1, "layer2": x2}

In [17]:
# Rebind `model` to an instance of the redefined Model (its forward returns a dict of features).
model = Model()

In [18]:
## Set return_nodes to None (the backbone's forward() already returns the named features)

In [19]:
# Build the PIM model again, this time with return_nodes=None.
pim_model = PluginMoodel(
    backbone=model,
    return_nodes=None,
    img_size=IMG_SIZE,
    use_fpn=USE_FPN,
    fpn_size=FPN_SIZE,
    proj_type=PROJ_TYPE,
    upsample_type=UPSAMPLE_TYPE,
    use_selection=True,
    num_classes=10,
    num_selects=num_selects,
    use_combiner=True,
    comb_proj_size=None,
)

In [20]:
# Smoke-test the rebuilt model with one random image. IMG_SIZE is reused
# (instead of a hard-coded 224) so the input always matches the img_size
# the PIM model was built with.
rand_inp = torch.randn(1, 3, IMG_SIZE, IMG_SIZE)
outs = pim_model(rand_inp)

In [21]:
# Show the names of the outputs returned by the PIM model.
print(list(outs))

['layer1', 'layer2', 'preds_1', 'preds_0', 'comb_outs']
