In [1]:
import argparse
import glob
import json
import multiprocessing
import os
import random
import re
from importlib import import_module
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from dataset import MaskBaseDataset, ProfileClassEqualSplitTrainMaskDataset
from loss import create_criterion

In [2]:
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)   

In [3]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']


In [4]:
def grid_image(np_images, gts, preds, n=16, shuffle=False):
    batch_size = np_images.shape[0]
    assert n <= batch_size

    choices = random.choices(range(batch_size), k=n) if shuffle else list(range(n))
    figure = plt.figure(figsize=(12, 18 + 2))  # cautions: hardcoded, 이미지 크기에 따라 figsize 를 조정해야 할 수 있습니다. T.T
    plt.subplots_adjust(top=0.8)               # cautions: hardcoded, 이미지 크기에 따라 top 를 조정해야 할 수 있습니다. T.T
    n_grid = np.ceil(n ** 0.5)
    tasks = ["mask", "gender", "age"]
    for idx, choice in enumerate(choices):
        gt = gts[choice].item()
        pred = preds[choice].item()
        image = np_images[choice]
        # title = f"gt: {gt}, pred: {pred}"
        gt_decoded_labels = ProfileClassEqualSplitTrainMaskDataset.decode_multi_class(gt)
        pred_decoded_labels = ProfileClassEqualSplitTrainMaskDataset.decode_multi_class(pred)
        title = "\n".join([
            f"{task} - gt: {gt_label}, pred: {pred_label}"
            for gt_label, pred_label, task
            in zip(gt_decoded_labels, pred_decoded_labels, tasks)
        ])

        plt.subplot(n_grid, n_grid, idx + 1, title=title)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(image, cmap=plt.cm.binary)

    return figure

In [5]:
def increment_path(path, exist_ok=False):
    """ Automatically increment path, i.e. runs/exp --> runs/exp0, runs/exp1 etc.

    Args:
        path (str or pathlib.Path): f"{model_dir}/{args.name}".
        exist_ok (bool): whether increment path (increment if False).
    """
    path = Path(path)
    if (path.exists() and exist_ok) or (not path.exists()):
        return str(path)
    else:
        dirs = glob.glob(f"{path}*")
        matches = [re.search(rf"%s(\d+)" % path.stem, d) for d in dirs]
        i = [int(m.groups()[0]) for m in matches if m]
        n = max(i) + 1 if i else 2
        return f"{path}{n}"

In [6]:
def mask_label(label):
    if label<6:
        return 0
    elif label<12:
        return 1
    else:
        return 2
    
def gender_label(label):
    if label%6<3:
        return 0
    else:
        return 1

def age_label(label):
    if label%3==0:
        return 0
    elif label%3==1:
        return 1
    else: 
        return 2
    
def under_age_label(label):
    if label%3==0:
        return 0
    else: 
        return 1 

def over_age_label(label):
    if label%3==0:
        return 0
    elif label%3==1:
        return 0
    else: 
        return 1

In [7]:
def train(data_dir, model_dir, args):
    seed_everything(args.seed)

    save_dir = increment_path(os.path.join(model_dir, args.name))

    # -- settings
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # -- dataset
    dataset_module = getattr(import_module("dataset"), args.dataset)  # default: BaseAugmentation
    dataset = dataset_module(
        data_dir=data_dir,
    )
    mask_num_classes = dataset.mask_num_classes #3
    gender_num_classes=dataset.gender_num_classes #2
    age_num_classes=dataset.age_num_classes #3
    

    # -- augmentation
    transform_module = getattr(import_module("dataset"), args.augmentation)  # default: BaseAugmentation
    transform = transform_module(
        resize=args.resize,
        mean=dataset.mean,
        std=dataset.std,
    )
    dataset.set_transform(transform)

    # -- data_loader
    train_set, val_set = dataset.split_dataset()

    train_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        num_workers=multiprocessing.cpu_count()//2,
        shuffle=True,
        pin_memory=use_cuda,
        drop_last=True,
    )

    val_loader = DataLoader(
        val_set,
        batch_size=args.valid_batch_size,
        num_workers=multiprocessing.cpu_count()//2,
        shuffle=False,
        pin_memory=use_cuda,
        drop_last=True,
    )

    # _model
    model_module = getattr(import_module("model"), args.model)  # default: BaseModel
    model = model_module(
        num_classes=mask_num_classes
    ).to(device)
    model = torch.nn.DataParallel(model)

        
        
    # -- loss & metric
    binary_criterion=create_criterion(args.binary_criterion)
    cross_criterion = create_criterion(args.cross_criterion)  # default: cross_entropy
    opt_module = getattr(import_module("torch.optim"), args.optimizer)  # default: SGD
    optimizer = opt_module(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr,
        weight_decay=5e-4
    )


    scheduler = StepLR(optimizer, args.lr_decay_step, gamma=0.5)  
    
    
    
    # -- logging
    logger = SummaryWriter(log_dir=save_dir)
    with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(vars(args), f, ensure_ascii=False, indent=4)

    best_val_acc = 0
    best_val_loss = np.inf
    for epoch in range(args.epochs):
        # train loop
        model.train()
        model.train()
        model.train()
        loss_value = 0
        matches = 0
        for idx, train_batch in enumerate(train_loader):
            inputs, labels = train_batch
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            mask_labels=torch.tensor(list(map(mask_label,labels)))
            gender_labels=torch.tensor(list(map(gender_label,labels)))
            under_labels=torch.tensor(list(map(under_age_label,labels)))
            over_labels=torch.tensor(list(map(over_age_label,labels)))

            mask_labels = mask_labels.to(device)
            gender_labels = gender_labels.to(device)
            under_labels= under_labels.to(device)
            over_labels= over_labels.to(device)
            
            optimizer.zero_grad()

            mask_outs, gender_outs, under_outs, over_outs = model(inputs)
            
            mask_preds = torch.argmax(mask_outs, dim=-1)
            mask_loss = cross_criterion(mask_outs, mask_labels)

            gender_preds = torch.argmax(gender_outs, dim=-1)
            gender_loss = cross_criterion(gender_outs, gender_labels)

            under_preds = torch.argmax(under_outs, dim=-1)
            under_loss = cross_criterion(under_outs, under_labels)           
            
            over_preds = torch.argmax(over_outs, dim=-1)
            over_loss = cross_criterion(over_outs, under_labels)  
            
            mask_loss.backward()
            
            gender_loss.backward()

            under_loss.backward()
            
            over_loss.backward()

            optimizer.step()
            
            loss_value =loss_value + mask_loss.item() + gender_loss.item() + under_loss.item() + over_loss.item()
            
            age_preds=under_preds+over_preds
            
            preds=dataset.encode_multi_class(mask_preds,gender_preds,age_preds)
            
            matches +=(preds==labels).sum().item()
            if (idx + 1) % args.log_interval == 0:
                train_loss = loss_value / args.log_interval
                train_acc = matches / args.batch_size / args.log_interval
                current_lr = get_lr(optimizer)
                print(
                    f"Epoch[{epoch}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
                    f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
                )
                logger.add_scalar("Train/loss", train_loss, epoch * len(train_loader) + idx)
                logger.add_scalar("Train/accuracy", train_acc, epoch * len(train_loader) + idx)

                loss_value = 0
                matches = 0

        scheduler.step()

        # val loop
        with torch.no_grad():
            print("Calculating validation results...")
            model.eval()
            
            val_loss_items = []
            val_acc_items = []
            figure = None
            
            for val_batch in val_loader:
                inputs,labels= val_batch
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                mask_labels=torch.tensor(list(map(mask_label,labels)))
                gender_labels=torch.tensor(list(map(gender_label,labels)))
                under_labels=torch.tensor(list(map(under_age_label,labels)))
                over_labels=torch.tensor(list(map(over_age_label,labels)))
                
                mask_labels = mask_labels.to(device)
                gender_labels = gender_labels.to(device)
                under_labels= under_labels.to(device)
                over_labels= over_labels.to(device)

                mask_outs, gender_outs, under_outs, over_outs = model(inputs)
 
                mask_preds = torch.argmax(mask_outs, dim=-1)
                
                gender_preds = torch.argmax(gender_outs, dim=-1)
                
                under_preds = torch.argmax(under_outs, dim=-1)
                
                over_preds = torch.argmax(over_outs, dim=-1)
                
                age_preds = under_preds+over_preds
                
                
                loss_item = cross_criterion(mask_outs, mask_labels).item()+cross_criterion(gender_outs,gender_labels).item()+cross_criterion(under_outs, under_labels).item() + cross_criterion(over_outs, over_labels).item()
            
                preds=dataset.encode_multi_class(mask_preds,gender_preds,age_preds)
                
                acc_item = (labels==preds).sum().item()
                val_loss_items.append(loss_item)
                val_acc_items.append(acc_item)
                
                if figure is None:
                    inputs_np = torch.clone(inputs).detach().cpu().permute(0, 2, 3, 1).numpy()
                    inputs_np = dataset_module.denormalize_image(inputs_np, dataset.mean, dataset.std)
                    figure = grid_image(
                        inputs_np, labels, preds, n=16, shuffle=args.dataset != "MaskSplitByProfileDataset"
                    )
                
                
    
            val_loss = np.sum(val_loss_items) / len(val_loader)
            val_acc = np.sum(val_acc_items) / len(val_set)
            best_val_loss = min(best_val_loss, val_loss)
            if val_acc > best_val_acc:
                print(f"New best model for val accuracy : {val_acc:4.2%}! saving the best model..")
                torch.save(model.module.state_dict(), f"{save_dir}/best.pth")
                best_val_acc = val_acc
            torch.save(model.module.state_dict(), f"{save_dir}/last.pth")
            print(
                f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.2} || "
                f"best acc : {best_val_acc:4.2%}, best loss: {best_val_loss:4.2}"
            )
            logger.add_scalar("Val/loss", val_loss, epoch)
            logger.add_scalar("Val/accuracy", val_acc, epoch)
            logger.add_figure("results", figure, epoch)
            print()

In [8]:
import easydict
args=easydict.EasyDict({
    "seed": 42,
    'epochs':10,
    'dataset':'ProfileClassEqualSplitTrainMaskDataset',
    'augmentation':'BaseAugmentation',
    'resize':[256,192],
    'batch_size':64,
    'valid_batch_size':1000,
    'model':'MyModel',
    'optimizer':'Adam',
    'lr':1e-3,
    'val_ratio':0.2,
    'cross_criterion':'cross_entropy',
    'binary_criterion':'binary_cross_entropy',
    'lr_decay_step':20,
    'log_interval':20,
    'name':'exp',
    'data_dir':os.environ.get('SM_CHANNEL_TRAIN', '/opt/ml/input/data/train/images'),
    'model_dir':os.environ.get('SM_MODEL_DIR', './model')
})

In [9]:
train('/opt/ml/input/data','./model',args)

Epoch[0/10](20/236) || training loss 4.026 || training accuracy 8.59% || lr 0.001
Epoch[0/10](40/236) || training loss 2.761 || training accuracy 16.72% || lr 0.001
Epoch[0/10](60/236) || training loss 2.142 || training accuracy 23.83% || lr 0.001
Epoch[0/10](80/236) || training loss 1.753 || training accuracy 30.31% || lr 0.001
Epoch[0/10](100/236) || training loss 1.666 || training accuracy 31.56% || lr 0.001
Epoch[0/10](120/236) || training loss 1.432 || training accuracy 36.80% || lr 0.001
Epoch[0/10](140/236) || training loss 1.388 || training accuracy 36.56% || lr 0.001
Epoch[0/10](160/236) || training loss 1.294 || training accuracy 38.52% || lr 0.001
Epoch[0/10](180/236) || training loss 1.262 || training accuracy 35.47% || lr 0.001
Epoch[0/10](200/236) || training loss 1.235 || training accuracy 37.81% || lr 0.001
Epoch[0/10](220/236) || training loss 1.079 || training accuracy 40.62% || lr 0.001
Calculating validation results...
New best model for val accuracy : 36.31%! savin

In [10]:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    from dotenv import load_dotenv
    import os
    load_dotenv(verbose=True)

    # Data and model checkpoints directories
    parser.add_argument('--seed', type=int, default=42, help='random seed (default: 42)')
    parser.add_argument('--epochs', type=int, default=1, help='number of epochs to train (default: 1)')
    parser.add_argument('--dataset', type=str, default='MaskBaseDataset', help='dataset augmentation type (default: MaskBaseDataset)')
    parser.add_argument('--augmentation', type=str, default='BaseAugmentation', help='data augmentation type (default: BaseAugmentation)')
    parser.add_argument("--resize", nargs="+", type=list, default=[128, 96], help='resize size for image when training')
    parser.add_argument('--batch_size', type=int, default=64, help='input batch size for training (default: 64)')
    parser.add_argument('--valid_batch_size', type=int, default=1000, help='input batch size for validing (default: 1000)')
    parser.add_argument('--model', type=str, default='BaseModel', help='model type (default: BaseModel)')
    parser.add_argument('--optimizer', type=str, default='SGD', help='optimizer type (default: SGD)')
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate (default: 1e-3)')
    parser.add_argument('--val_ratio', type=float, default=0.2, help='ratio for validaton (default: 0.2)')
    parser.add_argument('--criterion', type=str, default='cross_entropy', help='criterion type (default: cross_entropy)')
    parser.add_argument('--lr_decay_step', type=int, default=20, help='learning rate scheduler deacy step (default: 20)')
    parser.add_argument('--log_interval', type=int, default=20, help='how many batches to wait before logging training status')
    parser.add_argument('--name', default='exp', help='model save at {SM_MODEL_DIR}/{name}')

    # Container environment
    parser.add_argument('--data_dir', type=str, default=os.environ.get('SM_CHANNEL_TRAIN', '/opt/ml/input/data/train/images'))
    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR', './model'))

    args = parser.parse_args()
    print(args)

    data_dir = args.data_dir
    model_dir = args.model_dir

    train(data_dir, model_dir, args)

usage: ipykernel_launcher.py [-h] [--seed SEED] [--epochs EPOCHS]
                             [--dataset DATASET] [--augmentation AUGMENTATION]
                             [--resize RESIZE [RESIZE ...]]
                             [--batch_size BATCH_SIZE]
                             [--valid_batch_size VALID_BATCH_SIZE]
                             [--model MODEL] [--optimizer OPTIMIZER] [--lr LR]
                             [--val_ratio VAL_RATIO] [--criterion CRITERION]
                             [--lr_decay_step LR_DECAY_STEP]
                             [--log_interval LOG_INTERVAL] [--name NAME]
                             [--data_dir DATA_DIR] [--model_dir MODEL_DIR]
ipykernel_launcher.py: error: unrecognized arguments: -f /opt/ml/.local/share/jupyter/runtime/kernel-5474faf1-fedb-4981-9f19-4adfbad91c74.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
a=torch.tensor([False, False, False, False, False,  True,  True,  True,  True,  True,
        False,  True, False, False, False,  True,  True, False, False, False,
        False, False,  True,  True, False,  True, False, False,  True,  True,
         True,  True, False,  True,  True,  True,  True, False, False, False,
        False, False, False,  True, False,  True,  True, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
         True, False, False, False], device='cuda:0')
b=torch.tensor([True, False, False, False, False,  True,  True,  True,  True,  True,
        False,  True, False, False, False,  True,  True, False, False, False,
        False, False,  True,  True, False,  True, False, False,  True,  True,
         True,  True, False,  True,  True,  True,  True, False, False, False,
        False, False, False,  True, False,  True,  True, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
         True, False, False, False], device='cuda:0')

In [None]:
a*b