In [2]:
from anomalib.models.cfa.cnn.efficientnet import EfficientNet as effnet
from anomalib.models.cfa.cnn.resnet import resnet18 as res18
from anomalib.models.cfa.cnn.resnet import wide_resnet50_2 as wrn50_2
from anomalib.models.cfa.cnn.vgg import vgg19_bn as vgg19

import torch
import timm
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

In [3]:
x = torch.rand((4, 3, 256, 256))

## ResNet - 18

In [4]:
cfa_feature_extractor = res18(pretrained=True, progress=True)
cfa_output = cfa_feature_extractor(x)
for f in cfa_output:
    print(f.min(), f.max(), f.shape)

tensor(-0.0384, grad_fn=<MinBackward1>) tensor(5.5771, grad_fn=<MaxBackward1>) torch.Size([4, 64, 64, 64])
tensor(-0.0322, grad_fn=<MinBackward1>) tensor(5.3179, grad_fn=<MaxBackward1>) torch.Size([4, 128, 32, 32])
tensor(-0.0344, grad_fn=<MinBackward1>) tensor(5.6700, grad_fn=<MaxBackward1>) torch.Size([4, 256, 16, 16])


In [5]:
model = getattr(torchvision.models, "resnet18")(pretrained=True)
torch_feature_extractor = create_feature_extractor(model=model, return_nodes=["layer1", "layer2", "layer3"])
torch_output = torch_feature_extractor(x)

for f in torch_output.values():
    print(f.min(), f.max(), f.shape)

tensor(0., grad_fn=<MinBackward1>) tensor(5.5771, grad_fn=<MaxBackward1>) torch.Size([4, 64, 64, 64])
tensor(0., grad_fn=<MinBackward1>) tensor(5.3179, grad_fn=<MaxBackward1>) torch.Size([4, 128, 32, 32])
tensor(0., grad_fn=<MinBackward1>) tensor(5.6700, grad_fn=<MaxBackward1>) torch.Size([4, 256, 16, 16])


## Wide ResNet - 50 

In [6]:
m1 = wrn50_2(pretrained=True, progress=True)
f1 = m1(x)

for f in f1:
    print(f.min(), f.max(), f.shape)

tensor(-0.0121, grad_fn=<MinBackward1>) tensor(2.9411, grad_fn=<MaxBackward1>) torch.Size([4, 256, 64, 64])
tensor(-0.0116, grad_fn=<MinBackward1>) tensor(2.2952, grad_fn=<MaxBackward1>) torch.Size([4, 512, 32, 32])
tensor(-0.0139, grad_fn=<MinBackward1>) tensor(1.5275, grad_fn=<MaxBackward1>) torch.Size([4, 1024, 16, 16])


In [7]:
m4 = getattr(torchvision.models, "wide_resnet50_2")(pretrained=True)
fe = create_feature_extractor(model=m4, return_nodes=["layer1", "layer2", "layer3"])
f4 = fe(x)

for f in f4.values():
    print(f.min(), f.max(), f.shape)

tensor(0., grad_fn=<MinBackward1>) tensor(2.9411, grad_fn=<MaxBackward1>) torch.Size([4, 256, 64, 64])
tensor(0., grad_fn=<MinBackward1>) tensor(2.2952, grad_fn=<MaxBackward1>) torch.Size([4, 512, 32, 32])
tensor(0., grad_fn=<MinBackward1>) tensor(1.5275, grad_fn=<MaxBackward1>) torch.Size([4, 1024, 16, 16])


## VGG

In [8]:
cfa_feature_extractor = vgg19(pretrained=True, progress=True)
cfa_output = cfa_feature_extractor(x)
for f in cfa_output:
    print(f.min(), f.max(), f.shape)

tensor(0., grad_fn=<MinBackward1>) tensor(1.5991, grad_fn=<MaxBackward1>) torch.Size([4, 256, 64, 64])
tensor(0., grad_fn=<MinBackward1>) tensor(2.6335, grad_fn=<MaxBackward1>) torch.Size([4, 512, 32, 32])
tensor(0., grad_fn=<MinBackward1>) tensor(4.3431, grad_fn=<MaxBackward1>) torch.Size([4, 512, 8, 8])


In [9]:
model = getattr(torchvision.models, "vgg19_bn")(pretrained=True)
torch_feature_extractor = create_feature_extractor(
    model=model, return_nodes=["features.25", "features.38", "features.52"]
)
# get_graph_node_names(model)

torch_output = torch_feature_extractor(x)
for f in torch_output.values():
    print(f.min(), f.max(), f.shape)

tensor(0., grad_fn=<MinBackward1>) tensor(1.5991, grad_fn=<MaxBackward1>) torch.Size([4, 256, 64, 64])
tensor(0., grad_fn=<MinBackward1>) tensor(2.6335, grad_fn=<MaxBackward1>) torch.Size([4, 512, 32, 32])
tensor(0., grad_fn=<MinBackward1>) tensor(4.3431, grad_fn=<MaxBackward1>) torch.Size([4, 512, 8, 8])


## EfficientNet

In [54]:
cfa_feature_extractor = effnet.from_pretrained("efficientnet-b5")
cfa_output = cfa_feature_extractor(x)
for f in cfa_output:
    print(f.min(), f.max(), f.shape)

Loaded pretrained weights for efficientnet-b5
tensor(-110.1491, grad_fn=<MinBackward1>) tensor(113.3296, grad_fn=<MaxBackward1>) torch.Size([4, 200, 64, 64])
tensor(-62.1684, grad_fn=<MinBackward1>) tensor(45.1688, grad_fn=<MaxBackward1>) torch.Size([4, 192, 32, 32])
tensor(-25.7444, grad_fn=<MinBackward1>) tensor(31.8913, grad_fn=<MaxBackward1>) torch.Size([4, 176, 16, 16])


In [59]:
model = getattr(torchvision.models, "efficientnet_b5")(pretrained=True)
torch_feature_extractor = nodes = ["features.2.2.block.3.1"]
torch_output = torch_feature_extractor(x)

for f in torch_output.values():
    print(f.min(), f.max(), f.shape)

tensor(-19.1508, grad_fn=<MinBackward1>) tensor(19.9623, grad_fn=<MaxBackward1>) torch.Size([4, 40, 64, 64])


In [52]:
model

EfficientNet(
  (features): Sequential(
    (0): ConvNormActivation(
      (0): Conv2d(3, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): ConvNormActivation(
            (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
            (1): BatchNorm2d(48, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(48, 12, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(12, 48, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): ConvNormActivatio

In [51]:
names = get_graph_node_names(model)
for name in names:
    print(name)

['x', 'features.0.0', 'features.0.1', 'features.0.2', 'features.1.0.block.0.0', 'features.1.0.block.0.1', 'features.1.0.block.0.2', 'features.1.0.block.1', 'features.1.0.block.2.0', 'features.1.0.block.2.1', 'features.1.1.block.0.0', 'features.1.1.block.0.1', 'features.1.1.block.0.2', 'features.1.1.block.1', 'features.1.1.block.2.0', 'features.1.1.block.2.1', 'features.1.1.stochastic_depth', 'features.1.1.add', 'features.1.2.block.0.0', 'features.1.2.block.0.1', 'features.1.2.block.0.2', 'features.1.2.block.1', 'features.1.2.block.2.0', 'features.1.2.block.2.1', 'features.1.2.stochastic_depth', 'features.1.2.add', 'features.2.0.block.0.0', 'features.2.0.block.0.1', 'features.2.0.block.0.2', 'features.2.0.block.1.0', 'features.2.0.block.1.1', 'features.2.0.block.1.2', 'features.2.0.block.2', 'features.2.0.block.3.0', 'features.2.0.block.3.1', 'features.2.1.block.0.0', 'features.2.1.block.0.1', 'features.2.1.block.0.2', 'features.2.1.block.1.0', 'features.2.1.block.1.1', 'features.2.1.bl