In [1]:
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from PIL import Image
from torchvision.models.feature_extraction import create_feature_extractor
from torchinfo import summary
from tqdm.notebook import tqdm
import torch.nn.functional as F
from collections import OrderedDict

# from ssdg_mobilenetv2_dli14 import DG_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

# from ssdg_t2t_vit import DG_model
from fas_simple_distill.model.mobilenetv2_ssdg.ssdg_mobilenetv2_dli14_wo_normalize_more_lyr_ft import DG_model
# import fusion_maxpool_classifier_more_lyr 

  warn(f"Failed to load image Python extension: {e}")


torchvision version 0.11.2+cu102


In [2]:
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

In [3]:
class SimpleGlobDataset(Dataset):
    def __init__(self, root, glob_patt, transform=None):
        self.root = root
        self.glob_patt = glob_patt
        self.images = sorted(list(Path(root).glob(glob_patt)))
        self.transform = transform
        
    def __getitem__(self, index):
        img_path = self.images[index]
        img = Image.open(img_path)
        if self.transform is not None:
            img = self.transform(img)
            
        return img, img_path
    
    def __len__(self):
        return len(self.images)

In [4]:
# return_nodes = [
#     "embedder.backbone.10.conv.2",
#     "embedder.backbone.14.conv.2",
#     "classifier.classifier_layer",  # original output
# ]

# input_shapes = [
#     [384, 21, 21],
#     [576, 21, 21],
# ]

ckpt_path = "./model/fas19v2-no-border-public-1-1658220595__0__e090a7a6_e090a7a6_step_700.pth"
# classif_ckpt_path = "/home/rayhan/Desktop/git/fas-test-scripts/agi/visualization/weight/best-config/new/ckpt-1648696287-step2600-ep25.pth"

# ------------------------------------------

dg_model = DG_model(
    embedding_size=128, 
    norm_flag=False, 
    use_ssdg_norm=False, 
    last_channel=512, 
    width_mult=1.0, 
    embedder_act='Identity',
    drop_rate=0.0,
    preemb_drop=0.0,
)
dg_model_ckpt = torch.load(ckpt_path, map_location="cpu")
state_dict = dg_model_ckpt['model_ema']

new_state_dict = OrderedDict()
# print(state_dict)
for k, v in state_dict.items():
    name = k[7:] # remove 'module.' of dataparallel
    new_state_dict[name]=v
print(new_state_dict)
dg_model.load_state_dict(new_state_dict)
dg_model.eval()
# print(dg_model)

# -------------------------------------------

# classifier = fusion_maxpool_classifier_more_lyr.ConvFusionClassifierLight(
#     input_shapes=input_shapes,
#     num_classes=2,
# )
# classifier_ckpt = torch.load(classif_ckpt_path, map_location="cpu")
# classifier.load_state_dict(classifier_ckpt["classifier"])
# classifier.eval()

# --------------------------------------------

class ModelWrapper(torch.nn.Module):
    def __init__(self, dg_model_):
        super().__init__()
        self.model = dg_model_
        
    def forward(self, x):
        cls_out,_ = self.model(x)
        
        # main_cls_out = dg_model_out[-1]
        cls_out = F.sigmoid(cls_out)
        
        return cls_out

model = ModelWrapper(dg_model)
model.eval()
print(model)

OrderedDict([('embedder.backbone.0.0.weight', tensor([[[[-1.1107e-02,  1.3914e-01, -1.0090e-01],
          [ 5.0005e-02, -3.8918e-01,  3.7341e-01],
          [-3.5453e-02, -6.3883e-01,  6.6534e-01]],

         [[ 5.5103e-02,  1.3214e-01, -2.3363e-01],
          [ 7.5018e-02, -7.2240e-01,  6.7126e-01],
          [-1.9870e-02, -1.1681e+00,  1.1667e+00]],

         [[-4.0259e-02,  1.4244e-01, -1.4499e-01],
          [ 1.1783e-02, -1.5938e-01,  1.4156e-01],
          [-2.2147e-02, -2.7810e-01,  3.3310e-01]]],


        [[[ 7.5433e-02,  1.4236e-02, -5.3402e-02],
          [ 8.3521e-02,  2.7036e-01,  8.4251e-02],
          [ 9.8369e-02,  6.1380e-01,  9.2975e-02]],

         [[-5.3194e-04, -8.3629e-02, -1.5371e-02],
          [-5.5048e-02,  1.3712e-01, -9.3215e-02],
          [ 1.3328e-02,  5.2336e-01, -2.5347e-01]],

         [[-2.3996e-02, -7.2249e-02,  3.6781e-02],
          [-1.4687e-03,  1.4902e-01, -3.3184e-02],
          [-5.3797e-02,  3.4834e-01, -3.2597e-01]]],


        [[[-7.2016e-

In [5]:
summary(model, input_size=(1, 3, 256, 256), device="cpu")



Layer (type:depth-idx)                             Output Shape              Param #
ModelWrapper                                       [1, 1]                    --
├─DG_model: 1-1                                    [1, 1]                    --
│    └─FeatureEmbedderMobileNetV2Dli14: 2-3        [1, 128]                  (recursive)
│    │    └─AdaptiveAvgPool2d: 3-1                 [1, 512, 1, 1]            --
│    │    └─Identity: 3-6                          [1, 512]                  --
│    └─Classifier: 2-2                             [1, 1]                    --
│    │    └─Linear: 3-3                            [1, 64]                   8,256
│    │    └─Linear: 3-4                            [1, 32]                   2,080
│    │    └─Linear: 3-5                            [1, 1]                    33
│    └─FeatureEmbedderMobileNetV2Dli14: 2          --                        --
│    │    └─Identity: 3-6                          [1, 512]                  --
│    └─FeatureEmbedd

In [6]:
model(torch.randn(1, 3, 256, 256))

tensor([[0.3797]], grad_fn=<SigmoidBackward0>)

In [7]:
torch.onnx.export(
    model, 
    torch.randn(1, 3, 256, 256),
    "./weights/mnv2_public.onnx",
    output_names=["main_out"],
    opset_version=11,
)

In [8]:
cls_out = model(torch.randn(1,3,256,256))

In [9]:
cls_out

tensor([[0.3815]], grad_fn=<SigmoidBackward0>)