In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to the local `hw_grapheme` package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os 
import numpy as np

from apex import amp

import torch
import torch.optim as optim
from torch import nn
from torchvision import transforms
import torchcontrib
from sklearn.utils.class_weight import compute_class_weight

from hw_grapheme.io.load_data import load_processed_data
from hw_grapheme.model_archs.se_resnext import se_resnext50
from hw_grapheme.models.train import train_model
from hw_grapheme.train_utils.create_dataloader import create_dataloaders_train
from hw_grapheme.train_utils.train_test_split import stratified_split_kfold

In [3]:
# load processed data
# The image size appears both in the directory name and in the loader
# argument; derive both from a single variable so they cannot drift apart.
image_size = 128
n_train_shards = 4  # train_data_0.pickle ... train_data_3.pickle

pickle_paths = [
    f"../data/processed/size_{image_size}/train_data_{shard}.pickle"
    for shard in range(n_train_shards)
]

image_data, name_data, label_data = load_processed_data(
    pickle_paths, image_size=image_size
)

Load data done, shape: (200840, 128, 128), (200840,), (200840, 3)


In [4]:
# Stratified K-fold split: one (train, validation) index pair per fold.
n_splits = 5
random_seed = 2020

fold_indices = stratified_split_kfold(image_data, label_data, n_splits, random_seed)
train_idx_list, test_idx_list = fold_indices

StratifiedKFold(n_splits=5, random_state=2020, shuffle=True)


In [5]:
# create data_transforms
def _base_pipeline():
    """Return a fresh Compose: ToPILImage -> 3-channel Grayscale -> ToTensor."""
    steps = [
        transforms.ToPILImage(),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)


# Both phases currently use the same pipeline. Disabled options, kept for
# reference:
#   train only : transforms.RandomAffine(degrees=10, scale=(1.0, 1.15))  (after ToPILImage)
#   train only : transforms.ToPILImage()                                 (after ToTensor)
#   train / val: transforms.Normalize([0.0692], [0.2051])                (after ToTensor)
data_transforms = {
    "train": _base_pipeline(),
    "val": _base_pipeline(),
}

In [6]:
# ---- default training setting ----
# num_workers / pin_memory are handed to create_dataloaders_train.
num_workers = 6
pin_memory = True
fold = [*range(n_splits)]

# ---- customize training setting ----
n_epoch = 120
batch_size = 512
mixed_precision = True  # route training through apex amp

# model constructor and its keyword arguments
model_arch = se_resnext50
model_parameter = dict()

# wrap the optimizer in torchcontrib's SWA
swa = True

# optimizer class and its keyword arguments
optimizer = optim.SGD
optimizer_parameter = {
    "weight_decay": 1e-4,
    "momentum": 0.9,
    "lr": 0.2,
    "nesterov": True,
}

# whether to use weighted loss for each class
is_weighted_class_loss = True

# ---- lr schedulers (set a *_func to None to disable that scheduler) ----
epoch_scheduler_func = None
epoch_scheduler_func_para = dict()
error_plateau_scheduler_func = optim.lr_scheduler.ReduceLROnPlateau
error_plateau_scheduler_func_para = {
    "mode": "min",
    "factor": 0.1,
    "patience": 10,
    "verbose": True,
    "min_lr": 1e-3,
}

# ---- loss mixing ----
# probability of picking each of ["mixup", "cutmix", "cross_entropy"]
train_loss_prob = [0.5, 0.5, 0.0]
# alpha for mixup / cutmix only
mixup_alpha = 0.4

# loss weighting of the [root, vowel, consonant] heads
head_weights = [0.5, 0.25, 0.25]

wandb_log = True

# ---- checkpoint location ----
# save dir, set None to not save, need to manual create folders first
#   alternative: save_dir = None
save_dir = "../models/all_we_get_20200306/"

In [None]:
# ---------------------------------------------------------------------------
# Build optional per-head class weights, then train the selected fold(s).
#
# Reads the settings defined in the cells above (model / optimizer / scheduler
# factories, loss options, save_dir, ...) and, for each selected fold, builds
# model, optimizer, dataloaders and schedulers and hands them to train_model.
# ---------------------------------------------------------------------------
if is_weighted_class_loss:
    # label_data columns are [root, vowel, consonant].
    root_label = label_data[:, 0]
    vowel_label = label_data[:, 1]
    consonant_label = label_data[:, 2]

    class_weight = "balanced"

    # `classes` and `y` are passed by keyword: current scikit-learn releases
    # make them keyword-only, and older releases accept the keyword form too.
    # NOTE(review): weights are computed over *all* rows of label_data, i.e.
    # including the rows that end up in the validation fold - confirm intended.
    root_cls_weight = compute_class_weight(
        class_weight=class_weight, classes=np.unique(root_label), y=root_label
    )
    vowel_cls_weight = compute_class_weight(
        class_weight=class_weight, classes=np.unique(vowel_label), y=vowel_label
    )
    consonant_cls_weight = compute_class_weight(
        class_weight=class_weight, classes=np.unique(consonant_label), y=consonant_label
    )

    # One weight tensor per head, in [root, vowel, consonant] order, on the GPU.
    class_weights = [
        torch.Tensor(root_cls_weight).cuda(),
        torch.Tensor(vowel_cls_weight).cuda(),
        torch.Tensor(consonant_cls_weight).cuda(),
    ]
else:
    class_weights = None

for i, (train_idx, valid_idx) in enumerate(zip(train_idx_list, test_idx_list)):
    # skip unwanted fold
    # NOTE(review): the `fold` list from the settings cell is not consulted
    # here; only fold 0 is trained. Kept as-is to avoid changing what runs.
    if i not in [0]:
        continue

    print(f"Training fold {i}")

    # Seed the global NumPy / PyTorch RNGs (offset by fold index) so that
    # anything drawing from them - e.g. random weight initialisation - is
    # repeatable after a kernel restart.
    np.random.seed(random_seed + i)
    torch.manual_seed(random_seed + i)

    # create model
    model = model_arch(**model_parameter)

    # create optimizer
    optimizer_ft = optimizer(model.parameters(), **optimizer_parameter)

    if swa:
        optimizer_ft = torchcontrib.optim.SWA(optimizer_ft)

    # Move to GPU first; amp.initialize must see the bare model, so the
    # DataParallel wrapper is applied only afterwards.
    model.to("cuda")
    if mixed_precision:
        model, optimizer_ft = amp.initialize(model, optimizer_ft, opt_level="O1")
    model = nn.DataParallel(model)

    # create data loader
    data_loaders = create_dataloaders_train(
        image_data, name_data, label_data, train_idx, valid_idx,
        data_transforms, batch_size, num_workers, pin_memory
    )

    # create epoch_scheduler
    if epoch_scheduler_func:
        epoch_scheduler = epoch_scheduler_func(optimizer_ft, **epoch_scheduler_func_para)
    else:
        epoch_scheduler = None

    # create error_plateaus_scheduler
    if error_plateau_scheduler_func:
        error_plateau_scheduler = error_plateau_scheduler_func(
            optimizer_ft, **error_plateau_scheduler_func_para
        )
    else:
        error_plateau_scheduler = None

    # Per-fold output directory; created here so it no longer has to be
    # prepared by hand before the run.
    if save_dir:
        full_save_dir = os.path.join(save_dir, f"fold_{i}")
        os.makedirs(full_save_dir, exist_ok=True)
    else:
        full_save_dir = None

    train_input_args = {
        "model": model,
        "optimizer": optimizer_ft,
        "dataloaders": data_loaders,
        "mixed_precision": mixed_precision,
        "train_loss_prob": train_loss_prob,
        "class_weights": class_weights,
        "head_weights": head_weights,
        "mixup_alpha": mixup_alpha,
        "num_epochs": n_epoch,
        "epoch_scheduler": epoch_scheduler,
        "error_plateau_scheduler": error_plateau_scheduler,
        "save_dir": full_save_dir,
        "wandb_log": wandb_log,
        "swa": swa,
    }

    callbacks = train_model(**train_input_args)

Training fold 0
Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Creating train dataloader...
Creating valid dataloader...


wandb: Wandb version 0.8.29 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


Epoch 0/119
----------


HBox(children=(FloatProgress(value=0.0, max=314.0), HTML(value='')))


Finish training
Loss: 3.7524, root_acc: 0.0013, vowel_acc: 0.0919, consonant_acc: 0.2751, combined_acc: 0.0924, root_recall: 0.0056, vowel_recall: 0.1058, consonant_recall: 0.1545, combined_recall: 0.0679



HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Finish validating
Loss: nan, root_acc: 0.0022, vowel_acc: 0.2089, consonant_acc: 0.3845, combined_acc: 0.1495, root_recall: 0.0094, vowel_recall: 0.1427, consonant_recall: 0.1984, combined_recall: 0.0900

Highest val_combined_recall increases from 0.0 to 0.0899857748963445.

Epoch 1/119
----------


HBox(children=(FloatProgress(value=0.0, max=314.0), HTML(value='')))


Finish training
Loss: 3.4832, root_acc: 0.0031, vowel_acc: 0.1948, consonant_acc: 0.2080, combined_acc: 0.1022, root_recall: 0.0143, vowel_recall: 0.2350, consonant_recall: 0.2650, combined_recall: 0.1321



HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Finish validating
Loss: 2.8203, root_acc: 0.0057, vowel_acc: 0.3878, consonant_acc: 0.2811, combined_acc: 0.1701, root_recall: 0.0324, vowel_recall: 0.4262, consonant_recall: 0.4323, combined_recall: 0.2308

Lowest val_loss decreases from 999 to 2.8203238013540455.
Highest val_combined_recall increases from 0.0899857748963445 to 0.2308107320069687.

Epoch 2/119
----------


HBox(children=(FloatProgress(value=0.0, max=314.0), HTML(value='')))


Finish training
Loss: 2.9522, root_acc: 0.0260, vowel_acc: 0.4911, consonant_acc: 0.4706, combined_acc: 0.2535, root_recall: 0.0764, vowel_recall: 0.5273, consonant_recall: 0.5566, combined_recall: 0.3092



HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Finish validating
Loss: 2.0311, root_acc: 0.0873, vowel_acc: 0.6966, consonant_acc: 0.5008, combined_acc: 0.3430, root_recall: 0.1842, vowel_recall: 0.7608, consonant_recall: 0.7300, combined_recall: 0.4648

Lowest val_loss decreases from 2.8203238013540455 to 2.0310798702285573.
Highest val_combined_recall increases from 0.2308107320069687 to 0.464790903016563.

Epoch 3/119
----------


HBox(children=(FloatProgress(value=0.0, max=314.0), HTML(value='')))


Finish training
Loss: 2.2341, root_acc: 0.2060, vowel_acc: 0.6235, consonant_acc: 0.6319, combined_acc: 0.4169, root_recall: 0.3117, vowel_recall: 0.6484, consonant_recall: 0.6823, combined_recall: 0.4885



HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Finish validating
Loss: 1.2693, root_acc: 0.5081, vowel_acc: 0.8612, consonant_acc: 0.7627, combined_acc: 0.6600, root_recall: 0.6205, vowel_recall: 0.8916, consonant_recall: 0.8877, combined_recall: 0.7551

Lowest val_loss decreases from 2.0310798702285573 to 1.269290113563059.
Highest val_combined_recall increases from 0.464790903016563 to 0.7550507290866558.

Epoch 4/119
----------


HBox(children=(FloatProgress(value=0.0, max=314.0), HTML(value='')))

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0

Finish training
Loss: 1.7724, root_acc: 0.4326, vowel_acc: 0.6893, consonant_acc: 0.6989, combined_acc: 0.5633, root_recall: 0.5073, vowel_recall: 0.7127, consonant_recall: 0.7377, combined_recall: 0.6162



HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Finish validating
Loss: 0.7649, root_acc: 0.7245, vowel_acc: 0.8989, consonant_acc: 0.8717, combined_acc: 0.8049, root_recall: 0.7900, vowel_recall: 0.9220, consonant_recall: 0.8920, combined_recall: 0.8485

Lowest val_loss decreases from 1.269290113563059 to 0.7648522963777431.
Highest val_combined_recall increases from 0.7550507290866558 to 0.8484700899442742.

Epoch 5/119
----------


HBox(children=(FloatProgress(value=0.0, max=314.0), HTML(value='')))


Finish training
Loss: 1.5402, root_acc: 0.5310, vowel_acc: 0.7092, consonant_acc: 0.7280, combined_acc: 0.6248, root_recall: 0.5813, vowel_recall: 0.7280, consonant_recall: 0.7417, combined_recall: 0.6581



HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Finish validating
Loss: nan, root_acc: 0.8089, vowel_acc: 0.9305, consonant_acc: 0.9201, combined_acc: 0.8671, root_recall: 0.8487, vowel_recall: 0.9338, consonant_recall: 0.9351, combined_recall: 0.8915

Highest val_combined_recall increases from 0.8484700899442742 to 0.8915290168376079.

Epoch 6/119
----------


HBox(children=(FloatProgress(value=0.0, max=314.0), HTML(value='')))


Finish training
Loss: 1.5230, root_acc: 0.5615, vowel_acc: 0.7243, consonant_acc: 0.7389, combined_acc: 0.6466, root_recall: 0.6106, vowel_recall: 0.7442, consonant_recall: 0.7631, combined_recall: 0.6821



HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Finish validating
Loss: 1.2709, root_acc: 0.7339, vowel_acc: 0.9030, consonant_acc: 0.8410, combined_acc: 0.8029, root_recall: 0.7927, vowel_recall: 0.9156, consonant_recall: 0.8812, combined_recall: 0.8456


Epoch 7/119
----------


HBox(children=(FloatProgress(value=0.0, max=314.0), HTML(value='')))


Finish training
Loss: 1.3757, root_acc: 0.6044, vowel_acc: 0.7428, consonant_acc: 0.7577, combined_acc: 0.6773, root_recall: 0.6483, vowel_recall: 0.7634, consonant_recall: 0.7804, combined_recall: 0.7101



HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Finish validating
Loss: 0.6678, root_acc: 0.8264, vowel_acc: 0.9275, consonant_acc: 0.8898, combined_acc: 0.8675, root_recall: 0.8699, vowel_recall: 0.9454, consonant_recall: 0.9216, combined_recall: 0.9017

Lowest val_loss decreases from 0.7648522963777431 to 0.6677719850721551.
Highest val_combined_recall increases from 0.8915290168376079 to 0.9016905833427961.

Epoch 8/119
----------


HBox(children=(FloatProgress(value=0.0, max=314.0), HTML(value='')))


Finish training
Loss: 1.4703, root_acc: 0.5654, vowel_acc: 0.7121, consonant_acc: 0.7279, combined_acc: 0.6427, root_recall: 0.6089, vowel_recall: 0.7327, consonant_recall: 0.7513, combined_recall: 0.6755



HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Finish validating
Loss: 0.5509, root_acc: 0.8544, vowel_acc: 0.9506, consonant_acc: 0.8808, combined_acc: 0.8850, root_recall: 0.8870, vowel_recall: 0.9598, consonant_recall: 0.9470, combined_recall: 0.9202

Lowest val_loss decreases from 0.6677719850721551 to 0.5509159371829607.
Highest val_combined_recall increases from 0.9016905833427961 to 0.9201958938794779.

Epoch 9/119
----------


HBox(children=(FloatProgress(value=0.0, max=314.0), HTML(value='')))


Finish training
Loss: 1.2912, root_acc: 0.6398, vowel_acc: 0.7544, consonant_acc: 0.7781, combined_acc: 0.7030, root_recall: 0.6802, vowel_recall: 0.7745, consonant_recall: 0.7987, combined_recall: 0.7334



HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Finish validating
Loss: 0.5529, root_acc: 0.8720, vowel_acc: 0.9473, consonant_acc: 0.8979, combined_acc: 0.8973, root_recall: 0.8953, vowel_recall: 0.9582, consonant_recall: 0.9329, combined_recall: 0.9204

Highest val_combined_recall increases from 0.9201958938794779 to 0.9204195132822145.

Epoch 10/119
----------


HBox(children=(FloatProgress(value=0.0, max=314.0), HTML(value='')))


Finish training
Loss: 1.2507, root_acc: 0.6585, vowel_acc: 0.7624, consonant_acc: 0.7803, combined_acc: 0.7149, root_recall: 0.7002, vowel_recall: 0.7869, consonant_recall: 0.8118, combined_recall: 0.7498



HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Finish validating
Loss: 0.4596, root_acc: 0.8739, vowel_acc: 0.9682, consonant_acc: 0.9600, combined_acc: 0.9190, root_recall: 0.9006, vowel_recall: 0.9647, consonant_recall: 0.9525, combined_recall: 0.9296

Lowest val_loss decreases from 0.5509159371829607 to 0.4596363641759405.
Highest val_combined_recall increases from 0.9204195132822145 to 0.9296155274101184.

Epoch 11/119
----------


HBox(children=(FloatProgress(value=0.0, max=314.0), HTML(value='')))

In [None]:
# NOTE(review): stale, disabled template. It references `pretrain`, which is
# not defined in the cells above, and its values ("efficient 0", "adam",
# "224X224") do not match this notebook's settings (se_resnext50, optim.SGD,
# size_128 data). Update before re-enabling.
# configs = {
#     "model": "efficient 0",
#     "pretrain": pretrain,
#     "head_info": "1 fc",
#     "input_size": "224X224",
#     "optimizer": "adam",
#     "n_fold": n_splits,
#     "split_seed": random_seed,
#     "batch_size": batch_size,
#     "epoch": n_epoch,
#     "mixed_precision": mixed_precision
# }