In [8]:
# System Libs
import multiprocessing as mp
import sys
import os
import argparse
import json
from importlib import import_module
from pathlib import Path
from glob import glob
from time import time

# Other Libs
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from sklearn.metrics import f1_score
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import f1_score

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models
from torchvision import transforms
from torchvision.transforms import CenterCrop, Resize, ToTensor, Normalize
# Report the compute device early. Using a torch.device (not a bare string) keeps this
# consistent with the `device` rebuilt in the settings cell; it prints as "cuda"/"cpu".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Local Libs
from dataset import TrainInfo, MaskBaseDataset, TestDataset
from model import BaseModel, ResNet18Pretrained
from loss import get_criterion
import settings
import logger

cuda


In [9]:
import argparse
import os
from importlib import import_module

import pandas as pd
import torch
from torch.utils.data import DataLoader

from dataset import MaskBaseDataset, TestDataset

In [10]:
def load_model(model_dir, device):
    """Deserialize a saved checkpoint file and map its tensors onto ``device``.

    Args:
        model_dir: path (str or Path) to the checkpoint *file*. Despite the name,
            it is handed straight to ``torch.load`` and is not joined with a filename.
        device: value passed as ``map_location`` (e.g. ``"cpu"`` or a ``torch.device``).

    Returns:
        The object stored in the file. Callers in this notebook treat it as a
        model and call ``.to()`` / ``.eval()`` on it.

    NOTE(review): ``torch.load`` unpickles arbitrary objects, so only load trusted
    checkpoints. Newer torch releases default to ``weights_only=True``, which may
    refuse whole pickled modules — confirm against the installed torch version.
    """
    return torch.load(model_dir, map_location=device)

In [11]:
    # Script-style options. Inside the notebook `parse_args([])` means only the
    # defaults apply, but the types are correct so the same cell works as a CLI.
    parser = argparse.ArgumentParser()

    # Container environment
    parser.add_argument('--data_dir', type=str, default=os.environ.get('SM_CHANNEL_EVAL', '/opt/ml/input/data/eval'))
    # `type=bool` would turn any non-empty string (including "False") into True,
    # so interpret the text explicitly.
    parser.add_argument('--new_dataset', type=lambda s: str(s).strip().lower() in ('1', 'true', 'yes', 'y'), default=False)
    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_CHANNEL_MODEL', './model'))
    parser.add_argument('--name', type=str, default='exp')
    parser.add_argument('--output_dir', type=str, default=os.environ.get('SM_OUTPUT_DATA_DIR', './output'))
    parser.add_argument('--model_name', type=str, default='best.pt')

    parser.add_argument('--batch_size', type=int, default=1000, help='input batch size for validing (default: 1000)')
    # `type=tuple` would split "512,384" into single characters; take two ints
    # instead (e.g. `--resize 512 384`). The default stays the same tuple.
    parser.add_argument('--resize', type=int, nargs=2, default=(512, 384), help='resize size for image when you trained (default: (512, 384))')
    parser.add_argument('--mode', type=str, default='all', help='choose all or ensemble')
    args = parser.parse_args([])

    os.makedirs(args.output_dir, exist_ok=True)

In [23]:
# Settings
data_dir = args.data_dir
output_dir = args.output_dir

is_cuda = torch.cuda.is_available()
device = torch.device("cuda" if is_cuda else "cpu")

# Models
model_dir = '/opt/ml/image-classification-level1-05/model'

######################### ADD ##########################
model_mask_jw = Path(model_dir).joinpath('mask_jw.pt')
model_gender_jw = Path(model_dir).joinpath('gender_jw.pt')
model_age_jw = Path(model_dir).joinpath('age_jw.pt')

model_mask_jw = load_model(model_mask_jw, device).to(device)
model_gender_jw = load_model(model_gender_jw, device).to(device)
model_age_jw = load_model(model_age_jw, device).to(device)

model_mask_jw.eval()
model_gender_jw.eval()
model_age_jw.eval()
########################################################

model_mask = Path(model_dir).joinpath('model_mask/model_mask_f1.pt')
model_mask_age = Path(model_dir).joinpath('model_mask_age/model_mask_age_f1.pt')
model_mask_gender = Path(model_dir).joinpath('model_mask_gender/model_mask_gender_f1.pt')
model_nomask_age = Path(model_dir).joinpath('model_nomask_age/model_nomask_age_f1.pt')
model_nomask_gender = Path(model_dir).joinpath('model_nomask_gender/model_nomask_gender_f1.pt')

model_mask = load_model(model_mask, device).to(device)
model_mask_age = load_model(model_mask_age, device).to(device)
model_mask_gender = load_model(model_mask_gender, device).to(device)
model_nomask_age = load_model(model_nomask_age, device).to(device)
model_nomask_gender = load_model(model_nomask_gender, device).to(device)

model_mask.eval()
model_mask_age.eval()
model_mask_gender.eval()
model_nomask_age.eval()
model_nomask_gender.eval()

# Image Files & DataLoader
img_root = os.path.join(data_dir, 'images')
info_path = os.path.join(data_dir, 'info.csv')
info = pd.read_csv(info_path)

img_paths = [os.path.join(img_root, img_id) for img_id in info.ImageID]
dataset = TestDataset(img_paths, args.resize)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    num_workers=2,
    shuffle=False,
    pin_memory=is_cuda,
    drop_last=False,
)

In [29]:
print("Calculating inference results..")
preds = []
with torch.no_grad():
    for images in tqdm(loader):
        images = images.to(device)

        # Mask prediction - 0: Wear, 1: Incorrect, 2: Not Wear
        pred_mask = model_mask_jw(images).argmax(dim=-1)
        pred_gender = model_gender_jw(images).argmax(dim=-1)
        pred_age = model_age_jw(images).argmax(dim=-1)

        # Mask-wearing cases (mask != 2) whose age prediction is 1 get a second
        # opinion from `model_mask_age`. Its prediction replaces the age only
        # when it says 2 (NOTE(review): presumed intent of the refinement).
        # Masks keep this correct for any batch size, not just 1.
        refine = (pred_mask != 2) & (pred_age == 1)
        if refine.any():
            cand = model_mask_age(images[refine]).argmax(dim=-1)
            pred_age[refine] = torch.where(cand == 2, cand, pred_age[refine])

        # Combined class code: mask * 6 + gender * 3 + age
        result = pred_mask * 6 + pred_gender * 3 + pred_age
        preds.extend(result.cpu().numpy())

# One prediction per row of info.csv (loader is unshuffled and keeps the last batch).
assert len(preds) == len(info), 'prediction count does not match info.csv rows'
info['ans'] = preds
info.to_csv(os.path.join(output_dir, f'{args.name}_output.csv'), index=False)
print(f'Inference Done!')

Calculating inference results..


100%|██████████| 12600/12600 [01:46<00:00, 118.28it/s]

Inference Done!





<br><br><br>

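## 분포 확인 (Prediction distribution check)
---

Sanity check that the predictions cover every combined class code (`mask * 6 + gender * 3 + age`, so 18 codes, 0–17, are expected).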

In [27]:
# Distinct predicted class codes in ascending order (a sorted copy, not an in-place sort).
# Codes are presumably 0-17 given the mask*6 + gender*3 + age encoding.
ans_list = np.sort(info.ans.unique())

print(f'length : {len(ans_list)} | {ans_list}')

length : 18 | [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17]
