In [1]:
import warnings
warnings.filterwarnings(action='ignore')

import os
import gc
import cv2
import random
import argparse
import numpy as np
import pandas as pd
from glob import glob

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from tqdm.autonotebook import tqdm
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

import wandb

from utils_.set_path import *
from utils_.set_seed import seed_everything
from utils_.loss import FocalLoss, LabelSmoothingLoss, F1Loss
from utils_.get_class_weight import calc_class_weight
from runner.pytorch_timm import TimmModel
from runner.train_runner import CustomTrainer
from data.dataset import CustomTrainDataset

In [31]:
def main(args, model, model_name, device, data_dir=None, num_workers=8):
    """Run `model` over the validation split of the training CSV and collect its outputs.

    The split is rebuilt with ``train_test_split(test_size=0.2, random_state=args.seed)``,
    so calls made with the same seed score the same validation rows.

    Args:
        args: parsed options; ``seed``, ``resize`` and ``batch_size`` are read.
        model: classifier to evaluate; moved to ``device`` and wrapped in ``nn.DataParallel``.
        model_name: task name forwarded to ``CustomTrainDataset`` (e.g. "Mask", "Gender", "Age").
        device: torch device the images are sent to.
        data_dir: image root handed to the dataset; defaults to ``TRAIN_IMG_FOLDER_PATH``.
            Passed explicitly so the function does not depend on a global defined in a later cell.
        num_workers: number of DataLoader worker processes.

    Returns:
        (pred_list, prob_list, label_list): per-sample argmax class index, the maximum
        *raw logit* of that sample (not a softmax probability), and the dataset label.
    """
    if data_dir is None:
        data_dir = TRAIN_IMG_FOLDER_PATH

    train_df = pd.read_csv(TRAIN_CSV_PATH)
    # Only the validation part of the split is evaluated here.
    _, val_ = train_test_split(train_df, test_size=0.2, random_state=args.seed)

    val_dataset = CustomTrainDataset(model_name, data_dir, val_['path'].values, args.resize, transforms=False)
    valid_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=num_workers)

    model = nn.DataParallel(model.to(device))
    model.eval()

    pred_list, prob_list, label_list = [], [], []
    with torch.no_grad():
        for imgs, labels in tqdm(valid_dataloader):
            logit = model(imgs.float().to(device))

            pred_list += logit.argmax(1).cpu().tolist()
            # Max raw logit per sample; used downstream as a confidence score.
            prob_list += logit.max(1)[0].cpu().tolist()
            # Labels are only collected, never used on the device, so keep them on the CPU.
            label_list += labels.tolist()

    return pred_list, prob_list, label_list

In [3]:
# Notebook options. They are parsed from an empty argv, so only the defaults below apply;
# later cells override args.model / args.resize per task.
_OPTION_SPECS = (
    ('--seed', {'type': int, 'default': 909}),  # seeds the RNGs and the train/val split
    ("--resize", {'nargs': "+", 'type': int, 'default': [256, 192]}),
    ('--batch_size', {'type': int, 'default': 32}),
    ('--model', {'type': str, 'default': 'resnet34'}),
)

parser = argparse.ArgumentParser()
for _flag, _spec in _OPTION_SPECS:
    parser.add_argument(_flag, **_spec)

args = parser.parse_args(args=[])

In [4]:
# Seed the RNGs from the parsed options, then prefer the GPU when one is visible.
seed_everything(args.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# Validation configuration: image root, checkpoint root and the checkpoint project to score.
data_dir = TRAIN_IMG_FOLDER_PATH
model_dir = '/workspace/models_all'
# Checkpoint project evaluated below, pinned explicitly (checkpoints live in {model_dir}/Project{idx}).
project_idx = 9
print(f'Project {project_idx} Valid Inference Start')

Project 9 Valid Inference Start


In [15]:
# Mask: 3-class model (mask / incorrect / normal) scored on the validation split.
model_name = "Mask"
args.model = 'convnext_base'
args.resize = [256, 192]

# sorted(): glob order is filesystem-dependent, so pick deterministically when several files match.
mask_ckpt_path = sorted(glob(f'{model_dir}/Project{project_idx}/Mask*'))[0]
# map_location: lets a checkpoint saved from CUDA load on the CPU fallback device.
mask_model_weights = torch.load(mask_ckpt_path, map_location=device)
mask_model = TimmModel(args, num_classes=3, pretrained=True).to(device)
mask_model.load_state_dict(mask_model_weights)

mask_preds, _, mask_labels = main(args, mask_model, model_name, device)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))




In [16]:
# Gender: 2-class model (male / female) scored on the validation split.
model_name = "Gender"
args.model = 'convnext_base'
args.resize = [256, 192]

# sorted(): glob order is filesystem-dependent, so pick deterministically when several files match.
gender_ckpt_path = sorted(glob(f'{model_dir}/Project{project_idx}/Gender*'))[0]
# map_location: lets a checkpoint saved from CUDA load on the CPU fallback device.
gender_model_weights = torch.load(gender_ckpt_path, map_location=device)
gender_model = TimmModel(args, num_classes=2, pretrained=True).to(device)
gender_model.load_state_dict(gender_model_weights)

gender_preds, _, gender_labels = main(args, gender_model, model_name, device)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))




In [32]:
# Age: three 2-class models loaded from the Age0*, Age1* and Age2* checkpoints, each scored
# on the same validation split. Their outputs are combined into one age class in a later cell.
model_name = "Age"
args.model = 'resnet34'
args.resize = [512, 384]

age_models, age_outputs = [], []
for age_idx in range(3):
    # sorted(): glob order is filesystem-dependent, so pick deterministically when several match.
    age_ckpt_path = sorted(glob(f'{model_dir}/Project{project_idx}/Age{age_idx}*'))[0]
    # map_location: lets a checkpoint saved from CUDA load on the CPU fallback device.
    age_model_weights = torch.load(age_ckpt_path, map_location=device)
    age_model = TimmModel(args, num_classes=2, pretrained=True).to(device)
    age_model.load_state_dict(age_model_weights)
    age_models.append(age_model)
    age_outputs.append(main(args, age_model, model_name, device))

# Keep the per-model names the downstream cells use.
age_model0, age_model1, age_model2 = age_models
(
    (age_preds_labels0, age_preds_prob0, age_labels0),
    (age_preds_labels1, age_preds_prob1, age_labels1),
    (age_preds_labels2, age_preds_prob2, age_labels2),
) = age_outputs

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=119.0), HTML(value='')))




In [33]:
"""
Mask               Gender          Age
- mask      : 0    - male   : 0    - 30 미만         : 0 
- incorrect : 1    - female : 1    - 30 이상 60 미만 : 1
- normal    : 2                    - 60 이상         : 2
"""
label_dict = {(0, 0, 0): 0, (0, 0, 1): 1, (0, 0, 2): 2, (0, 1, 0): 3, (0, 1, 1): 4, 
              (0, 1, 2): 5, (1, 0, 0): 6, (1, 0, 1): 7, (1, 0, 2): 8, (1, 1, 0): 9, 
              (1, 1, 1): 10, (1, 1, 2): 11, (2, 0, 0): 12, (2, 0, 1): 13, (2, 0, 2): 14, 
              (2, 1, 0): 15, (2, 1, 1): 16, (2, 1, 2): 17}

def _resolve_age(votes, scores):
    """Combine the three Age-model outputs for one sample into an age class (0, 1 or 2).

    Model k is treated as voting for age class k when it predicts label 0.
    `scores` are the per-model max raw logits returned by `main`.
      * exactly one voter -> that model's class;
      * two voters        -> the first voter only if its score is strictly higher, else the second;
      * three voters      -> np.argmax over all scores;
      * no voters         -> np.argmin over all scores.
    """
    voters = [k for k, vote in enumerate(votes) if vote == 0]
    if len(voters) == 1:
        return voters[0]
    if len(voters) == 2:
        first, second = voters
        return first if scores[first] > scores[second] else second
    if len(voters) == 3:
        return np.argmax(np.array(scores))
    return np.argmin(np.array(scores))


error_index = []
age_preds, age_labels = [], []
# age_labels1 / age_labels2 are zipped in so the row count matches all nine lists.
age_rows = zip(
    zip(age_preds_labels0, age_preds_labels1, age_preds_labels2),
    zip(age_preds_prob0, age_preds_prob1, age_preds_prob2),
    age_labels0,
    age_labels1,
    age_labels2,
)
for idx, (votes, scores, age_label, _, _) in enumerate(age_rows):
    age_pred = _resolve_age(votes, scores)
    if age_pred != age_label:
        error_index.append(idx)
    age_preds.append(age_pred)
    age_labels.append(age_label)

print(len(error_index))

189


In [34]:
# Sanity check: distinct age classes predicted, then per-task prediction and label counts.
print(set(age_preds))
for task_lists in ((mask_preds, gender_preds, age_preds), (mask_labels, gender_labels, age_labels)):
    print(*(len(values) for values in task_lists))

{0, 1, 2}
3780 3780 3780
3780 3780 3780


In [36]:
"""
Mask               Gender          Age
- mask      : 0    - male   : 0    - 30 미만         : 0 
- incorrect : 1    - female : 1    - 30 이상 60 미만 : 1
- normal    : 2                    - 60 이상         : 2
"""
label_dict = {(0, 0, 0): 0, (0, 0, 1): 1, (0, 0, 2): 2, (0, 1, 0): 3, (0, 1, 1): 4, 
              (0, 1, 2): 5, (1, 0, 0): 6, (1, 0, 1): 7, (1, 0, 2): 8, (1, 1, 0): 9, 
              (1, 1, 1): 10, (1, 1, 2): 11, (2, 0, 0): 12, (2, 0, 1): 13, (2, 0, 2): 14, 
              (2, 1, 0): 15, (2, 1, 1): 16, (2, 1, 2): 17}

# Macro F1 of the combined 18-class validation predictions.
# zip() would silently truncate mismatched lists and report a misleading score, so check first.
task_lengths = [len(seq) for seq in (mask_preds, gender_preds, age_preds,
                                     mask_labels, gender_labels, age_labels)]
if len(set(task_lengths)) != 1:
    raise ValueError(f'Prediction/label lists have different lengths: {task_lengths}')

preds = [label_dict[triple] for triple in zip(mask_preds, gender_preds, age_preds)]
labels = [label_dict[triple] for triple in zip(mask_labels, gender_labels, age_labels)]

print(f1_score(labels, preds, average='macro'))

0.9020256178815712


In [44]:
# Test-time configuration: test image root, checkpoint root and the checkpoint project.
test_dir = TEST_IMG_PATH
model_dir = '/workspace/models_all'
# Pinned explicitly (int, matching the validation config).
# NOTE(review): no visible cell below loads the Project 12 checkpoints; the predictions
# written out are the ones computed in the validation cells.
project_idx = 12
print(f'Project {project_idx} Inference Start !')

Project 12 Inference Start !


In [45]:
test_df = pd.read_csv(TEST_CSV_PATH)
# Full path of every test image, in test_df row order.
image_paths = [os.path.join(test_dir, image_id) for image_id in test_df['ImageID']]

In [47]:
"""
Mask               Gender          Age
- mask      : 0    - male   : 0    - 30 미만         : 0 
- incorrect : 1    - female : 1    - 30 이상 60 미만 : 1
- normal    : 2                    - 60 이상         : 2
"""
preds = []
for mask_pred, gender_pred, age_pred in zip(mask_preds, gender_preds, age_preds):
    temp = (mask_pred, gender_pred, age_pred)
    preds.append(label_dict[temp])

test_df['ans'] = preds
test_df.to_csv(f'/workspace/submits/{args.model}.csv', index=False)
print(f'Project {project_idx} inference is done!')

3780


ValueError: Length of values (3780) does not match length of index (12600)