In [None]:
import os
import random
import warnings
import copy
from enum import Enum

import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm, tqdm_notebook
from sklearn.model_selection import StratifiedKFold, KFold, train_test_split
from sklearn.metrics import f1_score
import albumentations
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler

from torchvision import transforms

In [None]:
# Set random seeds for reproducibility.
SEED = 2021
random.seed(SEED)
np.random.seed(SEED)
# Only takes effect for Python processes started after this line; the hash seed of
# the already-running kernel was fixed at interpreter startup.
os.environ["PYTHONHASHSEED"] = str(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # type: ignore  # seed every visible GPU, not just the current one
torch.backends.cudnn.deterministic = True  # type: ignore
# benchmark=True lets cuDNN auto-tune and pick different kernels from run to run,
# which works against deterministic=True; keep it off for reproducible runs.
torch.backends.cudnn.benchmark = False  # type: ignore

In [None]:
# Check the runtime: CUDA availability and the torch version in use.
print(f"CUDA: {torch.cuda.is_available()}")
print(f"torch Version: {torch.__version__}")

# Silence non-critical warnings for the rest of the notebook.
warnings.filterwarnings('ignore')

# Render the minus sign as an ASCII hyphen so axis labels do not show a broken
# glyph when the active font lacks the Unicode minus character.
mpl.rcParams.update({'axes.unicode_minus': False})

# Face cropping

In [None]:
%%time
from facenet_pytorch import MTCNN
import cv2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mtcnn = MTCNN(keep_all=True, device=device)
new_img_dir = '../input/data/train/new_imgs/'
train_feat_dir = '../input/data/train'
test_feat_dir = '../input/data/eval'
train_img_dir = '../input/data/train/images/'
test_img_dir = '../input/data/eval/images/'

cnt = 0

for paths in os.listdir(train_img_dir):
    if paths[0] == '.': continue
    os.mkdir(new_img_dir + paths)
    sub_dir = os.path.join(train_img_dir, paths)
    
    for imgs in os.listdir(sub_dir):
        if imgs[0] == '.': continue
        
        img_dir = os.path.join(sub_dir, imgs)
        img = cv2.imread(img_dir)
        img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
        
        #mtcnn 적용
        boxes,probs = mtcnn.detect(img)
        
        # boxes 확인
        if len(probs) > 1:
            pass 
            #print(boxes)
        if not isinstance(boxes, np.ndarray):
            #print('Nope!')
            # 직접 crop
            img = img[100:400, 50:350, :]
        else: # boexes size 확인
            xmin = int(boxes[0, 0])-30
            ymin = int(boxes[0, 1])-30
            xmax = int(boxes[0, 2])+30
            ymax = int(boxes[0, 3])+30
            
            if xmin < 0: xmin = 0
            if ymin < 0: ymin = 0
            if xmax > 384: xmax = 384
            if ymax > 512: ymax = 512
            
            img = img[ymin:ymax, xmin:xmax, :]
            
        tmp = os.path.join(new_img_dir, paths)
        cnt += 1
        imageio.imwrite(os.path.join(tmp, imgs), img)

print(cnt)

# Data preprocessing

In [None]:
%%time
IMG_EXTENSIONS = [
    ".jpg", ".JPG", ".jpeg", ".JPEG", ".png",
    ".PNG", ".ppm", ".PPM", ".bmp", ".BMP",
]

class MaskLabels(int, Enum):
    """Mask label: worn (MASK), worn incorrectly (INCORRECT), not worn (NORMAL)."""
    MASK, INCORRECT, NORMAL = 0, 1, 2


class GenderLabels(int, Enum):
    """Binary gender label (MALE = 0, FEMALE = 1)."""
    MALE = 0
    FEMALE = 1

    @classmethod
    def from_str(cls, value: str) -> int:
        """Map 'male' / 'female' (case-insensitive) to a member.

        Raises:
            ValueError: for any other string.
        """
        lookup = {"male": cls.MALE, "female": cls.FEMALE}
        value = value.lower()
        if value not in lookup:
            raise ValueError(f"Gender value should be either 'male' or 'female', {value}")
        return lookup[value]

class AgeLabels(int, Enum):
    """Age bucket: YOUNG (< 30), MIDDLE (30-59), OLD (60 and above)."""
    YOUNG = 0
    MIDDLE = 1
    OLD = 2

    @classmethod
    def from_number(cls, value: str) -> int:
        """Convert `value` with int() and return its age bucket.

        Raises:
            ValueError: if `value` cannot be converted with int().
        """
        try:
            value = int(value)
        except Exception:
            raise ValueError(f"Age value should be numeric, {value}")

        for upper_bound, bucket in ((30, cls.YOUNG), (60, cls.MIDDLE)):
            if value < upper_bound:
                return bucket
        return cls.OLD
        
def encode_multi_class(mask_label, gender_label, age_label) -> int:
    """Fold (mask, gender, age) labels into one class id: mask*6 + gender*3 + age.

    For labels within their enum ranges (3 x 2 x 3) the result lies in [0, 18).
    """
    return (mask_label * 2 + gender_label) * 3 + age_label

In [None]:
%%time

train_feat_dir = '../input/data/train'
test_feat_dir = '../input/data/eval'
train_img_dir = '../input/data/train/images/'
test_img_dir = '../input/data/eval/images/'

'''
MASK = Wear : 0, Incorrect : 1, Not Wear : 2
GENDER = Male : 0, Female : 1
AGE = ~ 29 : 0, 30 ~ 59 : 1, 60 ~ : 2
'''

# gender labeling outlier handling
female_to_male = ['006359', '006360', '006361', '006362', '006363', '006364'] 
male_to_female = ['001498-1', '004432']

num_classes = 3 * 2 * 3

_file_names = {
    "mask1": MaskLabels.MASK,
    "mask2": MaskLabels.MASK,
    "mask3": MaskLabels.MASK,
    "mask4": MaskLabels.MASK,
    "mask5": MaskLabels.MASK,
    "incorrect_mask": MaskLabels.INCORRECT,
    "normal": MaskLabels.NORMAL
}

img_paths = []
mask_labels = []
gender_labels = []
age_labels = []
multi_labels = []


profiles = os.listdir(train_img_dir)
for profile in profiles:
    if profile.startswith("."):  # "." 로 시작하는 파일은 무시합니다
        continue

    img_folder = os.path.join(train_img_dir, profile)
    for file_name in os.listdir(img_folder):
        _file_name, ext = os.path.splitext(file_name)
        if _file_name not in _file_names:  # "." 로 시작하는 파일 및 invalid 한 파일들은 무시합니다
            continue

        img_path = os.path.join(train_img_dir, profile, file_name)  # (resized_data, 000004_male_Asian_54, mask1.jpg)
        mask_label = _file_names[_file_name]

        id, gender, race, age = profile.split("_")
        ### gender labeling outlier handling
        if id in female_to_male:
            gender = 'male'
        if id in male_to_female:
            gender = 'female'
            
        gender_label = GenderLabels.from_str(gender)
        age_label = AgeLabels.from_number(age)

        img_paths.append(img_path)
        mask_labels.append(mask_label.value)
        gender_labels.append(gender_label.value)
        age_labels.append(age_label.value)
        multi_labels.append(encode_multi_class(mask_label, gender_label, age_label))

# dataset

In [None]:
class MyDataset(Dataset):
    '''Dataset of (image, label) pairs built from parallel path / label sequences.

    MASK = Wear : 0, Incorrect : 1, Not Wear : 2
    GENDER = Male : 0, Female : 1
    AGE = ~ 29 : 0, 30 ~ 59 : 1, 60 ~ : 2

    Args:
        img_paths: sequence of image file paths.
        labels: sequence of integer labels aligned with `img_paths`.
        transform: callable applied to the loaded PIL image (e.g. a torchvision
            Compose ending in ToTensor), or a falsy value to skip it.
    '''
    def __init__(self, img_paths, labels, transform):
        self.img_paths = np.array(img_paths)
        self.labels = np.array(labels)
        self.transform = transform

    def __getitem__(self, idx: int) -> "tuple[torch.Tensor, torch.Tensor]":
        # Decode eagerly inside `with` so the file handle is closed, and force 3 channels
        # so grayscale / RGBA files still match the RGB normalization.
        with Image.open(self.img_paths[idx]) as img:
            X = img.convert("RGB")
        y = self.labels[idx]

        if self.transform:
            X = self.transform(X)

        # After ToTensor, X is already a tensor: torch.tensor() on it would copy and warn.
        # Without a transform, X is still a PIL image, so go through a numpy array.
        if not torch.is_tensor(X):
            X = torch.as_tensor(np.asarray(X))
        return X, torch.tensor(y)

    def __len__(self):
        return len(self.labels)


# Model

In [None]:
from efficientnet_pytorch import EfficientNet
import timm

class _EfficientNetClassifier(nn.Module):
    """Pretrained EfficientNet with a `num_classes`-way output head.

    Shared by the mask / gender / age models, which differ only in their default
    number of classes. The backbone is stored as `self.net`, so state-dict keys
    (`net.*`) match checkpoints saved from the earlier per-task classes.
    """
    def __init__(self, num_classes: int, model_name: str = 'efficientnet-b7'):
        super().__init__()
        self.net = EfficientNet.from_pretrained(model_name, in_channels=3, num_classes=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class MaskModel(_EfficientNetClassifier):
    """Mask-state classifier (3 classes: MASK / INCORRECT / NORMAL)."""
    def __init__(self, num_classes: int = 3, model_name: str = 'efficientnet-b7'):
        super().__init__(num_classes, model_name)


class GenderModel(_EfficientNetClassifier):
    """Gender classifier (2 classes: MALE / FEMALE)."""
    def __init__(self, num_classes: int = 2, model_name: str = 'efficientnet-b7'):
        super().__init__(num_classes, model_name)


class AgeModel(_EfficientNetClassifier):
    """Age-bucket classifier (3 classes: YOUNG / MIDDLE / OLD)."""
    def __init__(self, num_classes: int = 3, model_name: str = 'efficientnet-b7'):
        super().__init__(num_classes, model_name)


# Loss

In [None]:
class FocalLoss(nn.Module):
    """Multi-class focal loss: negative log-likelihood of (1 - p)^gamma * log p.

    Args:
        weight: optional per-class weight tensor, forwarded to `F.nll_loss`.
        gamma: focusing exponent; 0 reduces this to (weighted) cross-entropy.
        reduction: reduction mode forwarded to `F.nll_loss`.
    """
    def __init__(self, weight=None, gamma=2., reduction='mean'):
        super().__init__()
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input_tensor, target_tensor):
        log_p = F.log_softmax(input_tensor, dim=-1)
        modulating = (1 - log_p.exp()) ** self.gamma  # down-weights well-classified examples
        focal_log_p = modulating * log_p
        return F.nll_loss(focal_log_p, target_tensor, weight=self.weight, reduction=self.reduction)

# Early Stopping

In [None]:
#https://quokkas.tistory.com/entry/pytorch%EC%97%90%EC%84%9C-EarlyStop-%EC%9D%B4%EC%9A%A9%ED%95%98%EA%B8%B0

class EarlyStopping:
    """Signal early stopping when the epoch loss has not improved for `patience` epochs.

    Call once per epoch; check `early_stop` afterwards. Checkpointing is left to the caller.
    """
    def __init__(self, patience=7, verbose=False, delta=0):
        """
        Args:
            patience (int): number of non-improving epochs tolerated before stopping.
                            Default: 7
            verbose (bool): if True, print a message whenever the epoch loss improves.
                            Default: False
            delta (float): minimum change in the monitored quantity that counts as
                           an improvement.
                            Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf works in every NumPy version.
        self.epoch_loss_min = np.inf
        self.delta = delta

    def __call__(self, epoch_loss, epoch_acc=None):
        """Record one epoch's loss and update `counter` / `early_stop`.

        `epoch_acc` is accepted for call-site compatibility but is not used.
        """
        score = -epoch_loss

        # Not an improvement (the first call always counts as one).
        if self.best_score is not None and score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        if self.verbose:
            print(f'Epoch loss decreased ({self.epoch_loss_min:.3f} --> {epoch_loss:.3f}).')
        self.best_score = score
        self.epoch_loss_min = epoch_loss
        self.counter = 0


In [None]:
# Drop models left over from a previous training run before training again.
# Names are removed only if they exist, so the cell also works on a fresh kernel.
# The CUDA cache is emptied *after* the references are dropped; emptying it first
# cannot release memory that live tensors still occupy.
import gc

for _name in ('mask_model', 'gender_model', 'age_model'):
    globals().pop(_name, None)
gc.collect()
torch.cuda.empty_cache()

# train

In [None]:
# Alternative normalization values (unused):
#RGB_MEAN = [0.558, 0.512, 0.478]
#RGB_STD = [0.218, 0.238, 0.252]

# Standard ImageNet mean / std.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize takes an InterpolationMode enum; passing a PIL constant such as Image.BILINEAR
# is deprecated in torchvision.
train_transform = transforms.Compose([
    transforms.Resize((224, 224), transforms.InterpolationMode.BILINEAR),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation([-19, +19]),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224), transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

####### hyper-parameter
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LEARNING_RATE = 1e-4
NUM_EPOCH = 15
BATCH_SIZE = 16
PATIENCE = 7
NUM_WORKERS = 8
TEST_SIZE = 0.2


####### train / validation split
# Split by person (the profile folder each image lives in) rather than by image.
# A per-image split puts photos of the same person in both train and validation,
# which leaks identity and inflates validation scores. One shared split is used for
# every task, so all validation loaders iterate the same images in the same order.
person_ids = [os.path.basename(os.path.dirname(p)) for p in img_paths]
train_people = set(train_test_split(sorted(set(person_ids)), test_size=TEST_SIZE, random_state=SEED)[0])
is_train = [pid in train_people for pid in person_ids]


def _split(values):
    """Partition `values` (aligned with img_paths) into (train, val) lists by `is_train`."""
    train = [v for v, keep in zip(values, is_train) if keep]
    val = [v for v, keep in zip(values, is_train) if not keep]
    return train, val


train_mask_X, val_mask_X = _split(img_paths)
train_mask_y, val_mask_y = _split(mask_labels)
train_gender_X, val_gender_X = _split(img_paths)
train_gender_y, val_gender_y = _split(gender_labels)
train_age_X, val_age_X = _split(img_paths)
train_age_y, val_age_y = _split(age_labels)
_, val_X = _split(img_paths)
_, val_y = _split(multi_labels)


####### dataset & dataloader
train_mask_dataset = MyDataset(train_mask_X, train_mask_y, train_transform)
train_mask_dataloader = DataLoader(train_mask_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
train_gender_dataset = MyDataset(train_gender_X, train_gender_y, train_transform)
train_gender_dataloader = DataLoader(train_gender_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
train_age_dataset = MyDataset(train_age_X, train_age_y, train_transform)
train_age_dataloader = DataLoader(train_age_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_mask_dataset = MyDataset(val_mask_X, val_mask_y, val_transform)
val_mask_dataloader = DataLoader(val_mask_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
val_gender_dataset = MyDataset(val_gender_X, val_gender_y, val_transform)
val_gender_dataloader = DataLoader(val_gender_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
val_age_dataset = MyDataset(val_age_X, val_age_y, val_transform)
val_age_dataloader = DataLoader(val_age_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
val_dataset = MyDataset(val_X, val_y, val_transform)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)


def _train_model(model, dataloader, criterion, num_epochs, lr):
    """Train `model` with Adam and snapshot the epoch with the lowest mean training loss.

    Returns:
        (state, losses): `state` holds 'epoch_loss', 'model_state_dict' and
        'optimizer_state_dict' of the best epoch; `losses` is the per-epoch mean
        training loss. Model selection uses *training* loss, not validation loss.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=lr)
    model.train()
    state = {'epoch_loss': float('inf')}
    losses = []

    for epoch in range(num_epochs):
        running_loss = 0.0
        for imgs, lbls in tqdm(dataloader, desc=f'epoch {epoch + 1}/{num_epochs}'):
            imgs, lbls = imgs.to(device), lbls.to(device)

            optimizer.zero_grad()
            loss = criterion(model(imgs), lbls)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch mean is exact even with a short last batch.
            running_loss += loss.item() * imgs.size(0)

        epoch_loss = float(running_loss / len(dataloader.dataset))
        losses.append(epoch_loss)

        if epoch_loss < state['epoch_loss']:
            state['model_state_dict'] = copy.deepcopy(model.state_dict())
            state['optimizer_state_dict'] = copy.deepcopy(optimizer.state_dict())
            state['epoch_loss'] = epoch_loss

    return state, losses


####### train
os.makedirs('./saved', exist_ok=True)  # torch.save does not create missing directories

mask_model = MaskModel().to(device)
# Class 2 (NORMAL) is weighted 5x — presumably to offset class imbalance; TODO confirm.
m_criterion = nn.CrossEntropyLoss(weight=torch.tensor([1., 1., 5.]).to(device))
m_state, m_losses = _train_model(mask_model, train_mask_dataloader, m_criterion, NUM_EPOCH, LEARNING_RATE)
torch.save(m_state, "./saved/mask_model_b7.pt")
torch.cuda.empty_cache()

gender_model = GenderModel().to(device)
g_criterion = nn.CrossEntropyLoss()
g_state, g_losses = _train_model(gender_model, train_gender_dataloader, g_criterion, NUM_EPOCH, LEARNING_RATE)
torch.save(g_state, "./saved/gender_model_b7.pt")
torch.cuda.empty_cache()

age_model = AgeModel().to(device)
# Class 2 (OLD) is weighted 7x — presumably to offset class imbalance; TODO confirm.
a_criterion = nn.CrossEntropyLoss(weight=torch.tensor([1., 1., 7.]).to(device))
a_state, a_losses = _train_model(age_model, train_age_dataloader, a_criterion, NUM_EPOCH, LEARNING_RATE)
torch.save(a_state, "./saved/age_model_b7.pt")

        

# evaluation

In [None]:
####### validation
# Score the combined 18-class prediction of the three saved models on the validation split.
step_acc, step_f1, n_step = 0., 0., 0.
all_val_preds, all_val_labels = [], []

mask_iter = iter(val_mask_dataloader)
gender_iter = iter(val_gender_dataloader)
age_iter = iter(val_age_dataloader)

# map_location lets checkpoints saved on GPU load on a CPU-only machine as well.
mask_model = MaskModel().to(device)
m_chk_pts = torch.load('./saved/mask_model_b7.pt', map_location=device)
mask_model.load_state_dict(m_chk_pts['model_state_dict'])
gender_model = GenderModel().to(device)
g_chk_pts = torch.load('./saved/gender_model_b7.pt', map_location=device)
gender_model.load_state_dict(g_chk_pts['model_state_dict'])
age_model = AgeModel().to(device)
a_chk_pts = torch.load('./saved/age_model_b7.pt', map_location=device)
age_model.load_state_dict(a_chk_pts['model_state_dict'])

mask_model.eval()
gender_model.eval()
age_model.eval()

# Evaluation only: skip building autograd graphs for the forward passes.
with torch.no_grad():
    for val_ims, val_lbs in tqdm(val_dataloader):
        mask_ims, _ = next(mask_iter)
        gen_ims, _ = next(gender_iter)
        age_ims, _ = next(age_iter)

        mask_logits = mask_model(mask_ims.to(device)).cpu().numpy()
        gender_logits = gender_model(gen_ims.to(device)).cpu().numpy()
        age_logits = age_model(age_ims.to(device)).cpu().numpy()

        # Score each (mask, gender, age) combination as the sum of the three logits.
        # Broadcasting to (batch, n_mask, n_gender, n_age) and flattening orders the
        # combinations as mask*6 + gender*3 + age, the same as encode_multi_class.
        combined = (mask_logits[:, :, None, None]
                    + gender_logits[:, None, :, None]
                    + age_logits[:, None, None, :])
        sum_preds = combined.reshape(len(combined), -1).argmax(axis=1)

        batch_labels = val_lbs.numpy()
        step_acc += np.sum(batch_labels == sum_preds)
        step_f1 += f1_score(batch_labels, sum_preds, average='macro')
        n_step += 1
        all_val_labels.extend(batch_labels.tolist())
        all_val_preds.extend(sum_preds.tolist())

# `step_f1 / n_step` (mean of per-batch macro F1) is kept for the next cell. Macro F1
# does not average over batches, though: a batch of BATCH_SIZE images sees only a few
# of the 18 classes. The F1 over the whole split below is the meaningful score.
if all_val_labels:
    val_acc = step_acc / len(all_val_labels)
    val_f1 = f1_score(all_val_labels, all_val_preds, average='macro')
    print(f'val accuracy: {val_acc:.4f}, val macro F1: {val_f1:.4f}')

In [None]:
# Mean of the per-batch macro F1 scores accumulated by the validation cell.
mean_batch_f1 = step_f1 / n_step
mean_batch_f1

# Inference

In [None]:
# Evaluation data root and the per-task checkpoints written by the training cell.
test_dir = '/opt/ml/input/data/eval'
m_chk_pts_dir, g_chk_pts_dir, a_chk_pts_dir = (
    f'./saved/{task}_model_b7.pt' for task in ('mask', 'gender', 'age')
)

class TestDataset(Dataset):
    """Unlabeled dataset over image file paths, used for test-time inference.

    Args:
        img_paths: sequence of image file paths.
        transform: callable applied to each loaded PIL image, or a falsy value to skip it.
    """
    def __init__(self, img_paths, transform):
        self.img_paths = img_paths
        self.transform = transform

    def __getitem__(self, index):
        # Decode eagerly inside `with` so the file handle is closed, and force 3 channels
        # so the result matches the RGB normalization used in training.
        with Image.open(self.img_paths[index]) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.img_paths)

# Load the evaluation metadata and the image directory.
submission = pd.read_csv(os.path.join(test_dir, 'info.csv'))
image_dir = os.path.join(test_dir, 'images')

# Build the test Dataset and DataLoader. shuffle=False keeps predictions aligned with
# the rows of `submission`; the models run in eval mode, so batching does not change outputs.
image_paths = [os.path.join(image_dir, img_id) for img_id in submission.ImageID]

dataset = TestDataset(image_paths, val_transform)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS
)

# map_location lets GPU-saved checkpoints load on a CPU-only machine as well.
m_chk_pts = torch.load(m_chk_pts_dir, map_location=device)
mask_model = MaskModel().to(device)
mask_model.load_state_dict(m_chk_pts['model_state_dict'])
g_chk_pts = torch.load(g_chk_pts_dir, map_location=device)
gender_model = GenderModel().to(device)
gender_model.load_state_dict(g_chk_pts['model_state_dict'])
a_chk_pts = torch.load(a_chk_pts_dir, map_location=device)
age_model = AgeModel().to(device)
age_model.load_state_dict(a_chk_pts['model_state_dict'])

mask_model.eval()
gender_model.eval()
age_model.eval()

# Predict the test set and collect one combined class id per image.
all_predictions = []
with torch.no_grad():
    for images in tqdm(loader):
        images = images.to(device)

        m_logits = mask_model(images).cpu().numpy()
        g_logits = gender_model(images).cpu().numpy()
        a_logits = age_model(images).cpu().numpy()

        # Sum the logits for every (mask, gender, age) combination; flattening the
        # broadcast (batch, n_mask, n_gender, n_age) array gives index mask*6 + gender*3 + age.
        combined = m_logits[:, :, None, None] + g_logits[:, None, :, None] + a_logits[:, None, None, :]
        all_predictions.extend(combined.reshape(len(combined), -1).argmax(axis=1).tolist())

submission['ans'] = all_predictions

# Save the submission file.
submission.to_csv(os.path.join(test_dir, 'submission.csv'), index=False)
print('test inference is done!')

In [None]:
# Re-run the combined 18-class validation with the models currently in memory.
# Fresh dataloaders are zipped together: the iterators created by the validation cell
# are already exhausted after one pass, so calling next() on them would raise StopIteration.
mask_model.eval()
gender_model.eval()
age_model.eval()

rerun_preds, rerun_labels = [], []
with torch.no_grad():
    for (val_ims, val_lbs), (mask_ims, _), (gen_ims, _), (age_ims, _) in tqdm(
            zip(val_dataloader, val_mask_dataloader, val_gender_dataloader, val_age_dataloader),
            total=len(val_dataloader)):
        mask_logits = mask_model(mask_ims.to(device)).cpu().numpy()
        gender_logits = gender_model(gen_ims.to(device)).cpu().numpy()
        age_logits = age_model(age_ims.to(device)).cpu().numpy()

        # Flattened index of the broadcast sum is mask*6 + gender*3 + age.
        combined = (mask_logits[:, :, None, None]
                    + gender_logits[:, None, :, None]
                    + age_logits[:, None, None, :])
        sum_preds = combined.reshape(len(combined), -1).argmax(axis=1)

        rerun_labels.extend(val_lbs.tolist())
        rerun_preds.extend(sum_preds.tolist())

# Macro F1 over the whole validation split.
f1_score(rerun_labels, rerun_preds, average='macro')