In [1]:
from torchinfo import summary
import torch

from going_modular.model.MTLFaceRecognition import MTLFaceRecognition

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so random weight initialisation and any
# torch.randn dummy inputs are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Experiment configuration.
#   backbone    - backbone name passed to MTLFaceRecognition
#   image_size  - side length of the square input images (used for both H and W)
#   num_classes - number of classes passed to MTLFaceRecognition
CONFIGURATION = {
    'backbone': 'miresnet18',
    'image_size': 256,
    'num_classes': 123
}

In [2]:
# Build the multi-task face-recognition model from the configuration and move
# it to the selected device.
model = MTLFaceRecognition(
    backbone=CONFIGURATION['backbone'],
    num_classes=CONFIGURATION['num_classes'],
).to(device)

# Summarize the model on a dummy batch of 16 three-channel images. The spatial
# size is taken from CONFIGURATION so the summary stays in sync with image_size.
summary_input_size = (16, 3, CONFIGURATION['image_size'], CONFIGURATION['image_size'])
summary(
    model=model,
    input_size=summary_input_size,
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
    device=device,
)

Layer (type (var_name))                                 Input Shape          Output Shape         Param #              Trainable
MTLFaceRecognition (MTLFaceRecognition)                 [16, 3, 256, 256]    [16, 2]              --                   True
├─MIResNet (backbone)                                   [16, 3, 256, 256]    [16, 512, 8, 8]      3,954,764            True
│    └─Conv2d (conv1)                                   [16, 3, 256, 256]    [16, 64, 128, 128]   9,408                True
│    └─BatchNorm2d (bn1)                                [16, 64, 128, 128]   [16, 64, 128, 128]   128                  True
│    └─PReLU (prelu)                                    [16, 64, 128, 128]   [16, 64, 128, 128]   1                    True
│    └─Sequential (layer1)                              [16, 64, 128, 128]   [16, 64, 64, 64]     --                   True
│    │    └─BasicBlock (0)                              [16, 64, 128, 128]   [16, 64, 64, 64]     73,985               True
│  

In [3]:
# Put the model in eval mode (its BatchNorm layers then use running statistics)
# and push a single random image through it. No gradients are needed for a
# shape check, so skip building the autograd graph.
model.eval()
dumpy_input = torch.randn(1, 3, CONFIGURATION['image_size'], CONFIGURATION['image_size']).to(device)
with torch.no_grad():
    features = model(dumpy_input)

# The forward pass returns seven outputs; unpack them in this order.
x_id, x_gender, x_pose, x_emotion, x_facial_hair, x_occlusion, x_spectacles = features

x_gender.shape

torch.Size([1, 512, 8, 8])

In [4]:
from going_modular.model.head.gender import GenderDetectModule

# Instantiate the gender head on the same device as the backbone and switch it
# to eval mode, matching model.eval() above, so any dropout / batch-norm layers
# behave as they would at inference time.
# NOTE(review): the variable is named `age_estimate` but holds a
# GenderDetectModule, and the head's output is 101-wide — confirm whether the
# name or the head class is the intended one. The name is kept because the
# next cell refers to it.
# NOTE(review): the meaning of the constructor argument 256 is not visible here
# (x_gender has 512 channels) — confirm against GenderDetectModule's signature.
age_estimate = GenderDetectModule(256).to(device).eval()

In [5]:
# Run the head on the gender feature map; this is only a shape check, so
# gradients are not tracked.
with torch.no_grad():
    x = age_estimate(x_gender)

x.shape

torch.Size([1, 101])