In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to the project package
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os 

import torch.optim as optim

from torch import nn
from torchvision import transforms

from hw_grapheme.io.load_data import load_processed_data
from hw_grapheme.train_utils.train_test_split import stratified_split_kfold
from hw_grapheme.train_utils.create_dataloader import create_dataloaders_train
from hw_grapheme.model_archs.efficient_net_0 import EfficientNet_0
# from hw_grapheme.models.train import train_model
from hw_grapheme.models.train_with_new_callback import train_model

In [None]:
# Pickle shards to read. Only shard 0 is active; re-enable the commented
# entries below to load the remaining shards as well.
pickle_paths = [
    "../data/processed/size_128/train_data_0.pickle",
    # "../data/processed/size_128/train_data_1.pickle",
    # "../data/processed/size_128/train_data_2.pickle",
    # "../data/processed/size_128/train_data_3.pickle",
]

# The loader hands back three parallel collections (images, names, labels);
# image_size is kept in step with the "size_128" directory being read.
image_data, name_data, label_data = load_processed_data(
    pickle_paths,
    image_size=128,
)

In [None]:
# Cross-validation setup: number of folds and the seed handed to the splitter.
n_splits, random_seed = 5, 2020

# One entry per fold in each list; the training loop walks the two lists
# in lockstep to obtain (train indices, validation indices) for every fold.
train_idx_list, test_idx_list = stratified_split_kfold(
    image_data,
    label_data,
    n_splits,
    random_seed,
)

In [None]:
# Per-phase preprocessing pipelines, keyed by phase name.
# Both phases currently apply the same steps: convert to a PIL image,
# replicate the grayscale channel to 3 channels, then convert to a tensor.
train_steps = [
    transforms.ToPILImage(),
    # Optional augmentation (disabled):
    # transforms.RandomAffine(degrees=10, scale=(1.0, 1.15)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    # Optional normalisation (disabled):
    # transforms.Normalize([0.0692], [0.2051]),
    # transforms.ToPILImage(),
]

val_steps = [
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    # Optional normalisation (disabled):
    # transforms.Normalize([0.0692], [0.2051])
]

data_transforms = {
    "train": transforms.Compose(train_steps),
    "val": transforms.Compose(val_steps),
}

In [None]:
# ---- Data-loading defaults -------------------------------------------------
num_workers = 6
pin_memory = True
fold = list(range(n_splits))  # indices of every available fold

# ---- Run-specific settings -------------------------------------------------
n_epoch = 2
batch_size = 64
pretrain = False
mixed_precision = False
wandb_log = True

# Optimizer class plus the keyword arguments it is constructed with.
optimizer = optim.Adam
optimizer_parameter = {"weight_decay": 1e-5}

# ---- Loss configuration ----------------------------------------------------
# Probability of picking each loss, in the order
# ["mixup", "cutmix", "cross_entropy"].
train_loss_prob = [0.0, 0.0, 1.0]

# Alpha used by mixup/cutmix only.
mixup_alpha = 0.4

# Relative weight of each prediction head, in the order
# [root, vowel, consonant].
head_loss_weight = [0.5, 0.25, 0.25]

# ---- Output ----------------------------------------------------------------
# Directory for saved models; use None to disable saving.
# The folders have to be created manually beforehand.
save_dir = "../models/testing"

In [None]:
# Train one model per selected cross-validation fold.
#
# For each fold this builds a fresh model and optimizer, creates the data
# loaders for that fold's train/validation indices, and runs train_model.
for i, (train_idx, valid_idx) in enumerate(zip(train_idx_list, test_idx_list)):
    # skip unwanted fold
    if i not in [0]:
        continue

    print(f"Training fold {i}")

    # create model and move it to the GPU *before* building the optimizer,
    # so the optimizer is bound to the CUDA parameters.
    model = EfficientNet_0(pretrain=pretrain)
    model.to("cuda")

    # create optimizer
    # (must exist before amp.initialize, which takes it as an argument;
    # previously it was only created after the mixed-precision branch and
    # that branch failed with a NameError.)
    optimizer_ft = optimizer(model.parameters(), **optimizer_parameter)

    if mixed_precision:
        # Imported lazily so apex is only required when mixed precision is on.
        from apex import amp

        # amp has to patch the bare model/optimizer pair, i.e. it runs after
        # both exist and before any parallel wrapper is applied.
        model, optimizer_ft = amp.initialize(model, optimizer_ft, opt_level="O1")

    # Single-process multi-GPU wrapper, used for both precision modes.
    # NOTE(review): the old mixed-precision path wrapped the model in
    # DistributedDataParallel (twice), which needs an initialised
    # torch.distributed process group that none of the cells shown here set
    # up. Switch back to DDP only together with a proper distributed launch.
    model = nn.DataParallel(model)

    # create data loader
    data_loaders = create_dataloaders_train(
        image_data, name_data, label_data, train_idx, valid_idx,
        data_transforms, batch_size, num_workers, pin_memory
    )

    # create lr scheduler (disabled; train_model receives epoch_scheduler=None)
    # exp_lr_scheduler = lr_scheduler.ReduceLROnPlateau(
    #     optimizer_ft, factor=0.5, patience=5,
    # )

    # Each fold saves into its own sub-directory; None disables saving.
    full_save_dir = os.path.join(save_dir, f"fold_{i}") if save_dir else None

    callbacks = train_model(
        model, optimizer_ft, data_loaders,
        mixed_precision, train_loss_prob, head_loss_weight,
        mixup_alpha=mixup_alpha, num_epochs=n_epoch,
        epoch_scheduler=None, save_dir=full_save_dir,
        wandb_log=wandb_log
    )

In [None]:
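# Unused draft of a run-configuration dict, kept for reference only.
# NOTE: "input_size" below says 224X224, but this notebook loads the
# size-128 data (image_size=128) -- update it before re-enabling.
# configs = {
#     "model": "efficient 0",
#     "pretrain": pretrain,
#     "head_info": "1 fc",
#     "input_size": "224X224",
#     "optimizer": "adam",
#     "n_fold": n_splits,
#     "split_seed": random_seed,
#     "batch_size": batch_size,
#     "epoch": n_epoch,
#     "mixed_precision": mixed_precision
# }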