In [None]:
%load_ext autoreload
%autoreload 2

In [99]:
import cv2
from collections import OrderedDict
import torch
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from matplotlib import pyplot as plt
from anomalib.pre_processing.pre_process import get_transforms, PreProcessor
import torchvision.models.detection as detection
from anomalib.data import InferenceDataset
from torchvision.ops import RoIAlign
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage
import torchvision.models as models
import torch.nn.functional as F
from torch import nn, Tensor

from anomalib.models.rbad.region_extractor import RegionExtractor as RegionExtractor2
from anomalib.models.rbad.region import RegionExtractor as RegionExtractor1
from anomalib.models.rbad.feature import FeatureExtractor as FeatureExtractor1
from anomalib.models.rbad.feature_extractor import FeatureExtractor as FeatureExtractor2

In [100]:
filename = "150.tif"
image = cv2.imread(filename)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transformations.
transforms = get_transforms(config=A.Compose([A.Normalize(mean=0.0, std=1.0), ToTensorV2()]))
pre_process = PreProcessor(config=transforms)

# Get the data via dataloader
dataset = InferenceDataset(path=filename, pre_process=pre_process)
dataloader = DataLoader(dataset)
i, data = next(enumerate(dataloader))

# Create the region extractor.
stage="rpn"
use_original = True
region_extractor1 = RegionExtractor1(stage=stage, use_original=use_original).eval().cuda()
region_extractor2 = RegionExtractor2(stage=stage, use_original=use_original).eval().cuda()
region_extractor3 = RegionExtractor2(stage=stage, use_original=use_original).eval()

# Forward-Pass the input
boxes1 = region_extractor1([image])
boxes2 = region_extractor2(data["image"].cuda())
boxes3 = region_extractor3(data["image"])

# Feature Extractor
feature_extractor1 = FeatureExtractor1().eval().cuda()
feature_extractor2 = FeatureExtractor2().eval().cuda()
feature_extractor3 = FeatureExtractor2().eval()
features1 = feature_extractor1(image, boxes1[0])
features2 = feature_extractor2(data["image"].cuda(), boxes2)
features3 = feature_extractor3(data["image"], boxes3)

In [None]:
base_model_checkpoint = torch.load("combined_head_1_2-a9f83242.pth")
base_model_checkpoint["module_state"].keys()
rcnn_checkpoint = torch.load("rcnn_1_2-31296d99.pth")

feature_extractor2.backbone.load_state_dict(base_model_checkpoint["module_state"])
feature_extractor2.rcnn_module.load_state_dict(rcnn_checkpoint["module_state"], strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['RCNN_cls_score.weight', 'RCNN_cls_score.bias', 'RCNN_adjective_score.weight', 'RCNN_adjective_score.bias', 'RCNN_verb_score.weight', 'RCNN_verb_score.bias'])

In [56]:
checkpoint = torch.load(f="rcnn_feature_extractor.pth")
checkpoint["backbone"]["module_state"]

OrderedDict([('0.weight',
              tensor([[[[ 2.9598e-02,  2.4799e-02,  1.9259e-02,  ...,  3.7651e-03,
                          5.9639e-03,  8.4870e-03],
                        [ 2.6165e-02,  2.2026e-02,  1.7773e-02,  ...,  5.1145e-03,
                          6.0359e-03,  7.2886e-03],
                        [ 2.2031e-02,  1.9807e-02,  1.7018e-02,  ...,  1.4474e-03,
                          2.6974e-03,  4.3055e-03],
                        ...,
                        [-5.4552e-03, -2.1305e-03, -5.2299e-03,  ..., -8.2159e-03,
                         -8.6463e-03, -9.2161e-03],
                        [-5.8815e-03, -2.3393e-03, -3.6423e-03,  ..., -1.1838e-02,
                         -1.1842e-02, -1.2679e-02],
                        [-5.3374e-03, -9.4002e-04, -4.7980e-04,  ..., -9.2019e-03,
                         -1.1577e-02, -1.2648e-02]],
              
                       [[-2.0982e-02, -2.5610e-02, -3.2641e-02,  ..., -2.5352e-02,
                         -1.4694e-02

In [85]:
new_weights = OrderedDict()

for key, value in checkpoint["rcnn"]["module_state"].items():
    if key.startswith("RCNN_tail"):
        new_weights[key.replace("RCNN_tail.", "")] = value

new_weights

OrderedDict([('1.weight',
              tensor([[ 3.9885e-04,  1.2097e-03,  1.1138e-03,  ...,  4.2721e-05,
                        6.3370e-04,  1.1545e-03],
                      [ 2.1056e-03,  3.0797e-03,  1.8302e-03,  ..., -1.3055e-03,
                       -4.9129e-04, -6.8891e-04],
                      [-6.2976e-04,  7.8751e-04,  5.3508e-04,  ..., -2.4154e-04,
                       -7.4725e-04, -2.3559e-05],
                      ...,
                      [-1.9887e-03, -1.1254e-03,  5.4735e-04,  ..., -2.4652e-04,
                        5.3090e-04,  5.6090e-04],
                      [-1.5325e-03, -1.8999e-03, -1.7625e-03,  ..., -2.3537e-05,
                       -1.3358e-03, -1.5884e-03],
                      [-1.1570e-03, -9.9317e-04,  1.2536e-04,  ...,  8.8544e-05,
                        9.2516e-06, -1.3223e-03]], device='cuda:0')),
             ('1.bias',
              tensor([ 0.0488,  0.0949, -0.0757,  ..., -0.0306,  0.0422,  0.0110],
                     device='cuda:

In [86]:
feature_extractor_checkpoint = {
    "backbone": base_model_checkpoint["module_state"],
    "classifier": new_weights,
}
torch.save(feature_extractor_checkpoint, "rcnn_feature_extractor.pth")

# Feature Extractor Implementation

In [101]:
filename = "150.tif"
image = cv2.imread(filename)

boxes1 = region_extractor1([image])
features1 = feature_extractor1(image, boxes1[0])

image, scale = feature_extractor1.transform(image)
boxes_pt = torch.tensor(boxes1[0])
rois = torch.cat((torch.zeros(boxes_pt.size(0), 1), boxes_pt), 1).unsqueeze(0).to(feature_extractor1.device)
rois *= scale
base_feats1 = feature_extractor1.head_module(image)
rcnn_feats1 = feature_extractor1.rcnn_module(base_feats1, rois)
rcnn_feats1.shape

torch.Size([24, 4096])

In [102]:
class RegionBasedFeatureExtractor(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        
        self.__model = models.alexnet(pretrained=False)

        # TODO: Load this via torch url.
        state_dict = torch.load("rcnn_feature_extractor.pth", map_location="cpu")
        
        self.backbone = self.__model.features[:-1]
        self.backbone.load_state_dict(state_dict=state_dict["backbone"])

        # Create RoI Align Network.
        self.roi_align = RoIAlign(output_size=(6, 6), spatial_scale=1/16, sampling_ratio=0)
        
        # Classifier network to extract the features.
        self.classifer = self.__model.classifier[:-1]
        self.classifer.load_state_dict(state_dict=state_dict["classifier"])
    
    @torch.no_grad()
    def forward(self, input: Tensor, rois: Tensor):
        features = self.backbone(input)
        # n_rois x 256 x 6 x 6 (AlexNet)
        features = self.roi_align(features, rois.view(-1, 5))
        # n_rois x 4096
        features = self.classifer(features.view(features.size(0), -1))
        
        return features

In [104]:
region_based_feature_extractor = RegionBasedFeatureExtractor().eval().cuda()
rcnn_feats2 = region_based_feature_extractor(image, rois)

In [105]:
rcnn_feats1.shape, rcnn_feats2.shape

(torch.Size([24, 4096]), torch.Size([24, 4096]))

In [106]:
torch.allclose(rcnn_feats1, rcnn_feats2)

True

In [107]:
rcnn_feats1

tensor([[0.2982, 0.0000, 0.0923,  ..., 0.0938, 0.0503, 0.4384],
        [0.2867, 0.0000, 0.0000,  ..., 0.1181, 0.0000, 0.5956],
        [0.3487, 0.0000, 0.0414,  ..., 0.1001, 0.0000, 0.5548],
        ...,
        [0.6233, 0.0000, 0.0329,  ..., 0.1152, 0.0000, 0.8032],
        [0.4534, 0.0000, 0.0037,  ..., 0.1920, 0.0000, 0.8502],
        [0.6094, 0.0000, 0.0053,  ..., 0.1168, 0.0000, 0.8830]],
       device='cuda:0', grad_fn=<ReluBackward0>)

In [108]:
rcnn_feats2

tensor([[0.2982, 0.0000, 0.0923,  ..., 0.0938, 0.0503, 0.4384],
        [0.2867, 0.0000, 0.0000,  ..., 0.1181, 0.0000, 0.5956],
        [0.3487, 0.0000, 0.0414,  ..., 0.1001, 0.0000, 0.5548],
        ...,
        [0.6233, 0.0000, 0.0329,  ..., 0.1152, 0.0000, 0.8032],
        [0.4534, 0.0000, 0.0037,  ..., 0.1920, 0.0000, 0.8502],
        [0.6094, 0.0000, 0.0053,  ..., 0.1168, 0.0000, 0.8830]],
       device='cuda:0')