In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.optim as optim
from efficientnet_pytorch import EfficientNet
from torch import nn
from torch.optim import Optimizer
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchtools.optim import RangerLars, RAdam
from torchvision import transforms
# from one_cycle import OneCycleLR
# from warmup_scheduler import GradualWarmupScheduler

from hw_grapheme.data_pipeline import create_dataloaders, load_data
from hw_grapheme.loss_func import Loss_combine
from hw_grapheme.model import EfficientNet_0
from hw_grapheme.train import generate_stratified_k_fold_index, train_model
from hw_grapheme.utils import load_model_weight



In [3]:
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
#     mean = np.array([0.485, 0.456, 0.406])
#     std = np.array([0.229, 0.224, 0.225])
#     inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

In [4]:
# load data
# Pre-processed 224x224 training images, stored as four pickle shards.
# NOTE(review): load_data presumably unpickles these files; only point it at
# shards produced by this project's own preprocessing (pickle can run code).
pickle_paths = [
    "../data/processed_data/size_224/train_data_0.pickle",
    "../data/processed_data/size_224/train_data_1.pickle",
    "../data/processed_data/size_224/train_data_2.pickle",
    "../data/processed_data/size_224/train_data_3.pickle",
]

image_data, name_data, label_data = load_data(pickle_paths)

# The fold indices generated later index all three arrays in lockstep, so a
# length mismatch would silently pair images with the wrong labels.
assert len(image_data) == len(name_data) == len(label_data), (
    "load_data returned misaligned image/name/label arrays"
)

Load data done, shape: (200840, 224, 224), (200840,), (200840, 3)


In [5]:
batch_size = 64
num_workers = 6
pin_memory = True
n_epoch = 120

n_splits = 5
random_seed = 2020

mixed_precision = False

# Seed every global RNG training may touch (weight init, DataLoader shuffling
# and worker seeds, any numpy/torch/random-based augmentation) so that a
# Restart & Run All reproduces the run. The fold split below receives the
# same seed explicitly.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

train_idx_list, test_idx_list = generate_stratified_k_fold_index(
    image_data, label_data, n_splits, random_seed
)

# create loss function (shared by every fold).
# nn.CrossEntropyLoss() was the earlier baseline; Loss_combine replaces it.
criterion = Loss_combine()

# NOTE: discriminative learning rates (a separate param group for
# 'module._fc.weight' / 'module._fc.bias') were tried. The 'module.' prefix
# means that code needs the DataParallel-wrapped model, so it belongs in the
# training loop after the model is built, not in this cell.

# create data_transforms
# Both splits currently use the same deterministic pipeline:
# array -> PIL image -> 3-channel grayscale -> float tensor (ToTensor).
# Train-only augmentation (previously tried: RandomAffine(degrees=10,
# scale=(1.0, 1.15))) and Normalize([0.0692], [0.2051]) are disabled; to
# re-enable either, give "train" (or both) an explicit Compose of its own.
data_transforms = {
    split: transforms.Compose([
        transforms.ToPILImage(),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
    ])
    for split in ("train", "val")
}

StratifiedKFold(n_splits=5, random_state=2020, shuffle=True)


In [7]:
pretrain = False

# Which folds of the stratified split to train in this session.
# Use set(range(n_splits)) to train every fold.
folds_to_train = {1}

# train_model's return value for each trained fold, keyed by fold index.
fold_callbacks = {}

for i, (train_idx, valid_idx) in enumerate(zip(train_idx_list, test_idx_list)):
    if i not in folds_to_train:
        continue

    print(f"Training fold {i}")

    # create model and move it to the GPU before building the optimizer
    eff_b0 = EfficientNet_0(pretrain)
    eff_b0.to("cuda")

    # create optimizer from the bare model. DataParallel shares the same
    # parameter tensors, so wrapping afterwards optimises exactly the same
    # parameters.
    # optimizer_ft = RangerLars(eff_b0.parameters())
    optimizer_ft = optim.Adam(eff_b0.parameters())

    if mixed_precision:
        # apex amp.initialize must receive the bare model and an already
        # constructed optimizer, and must run before any parallel wrapper.
        # NOTE(review): `amp` (from apex) is not imported in this notebook;
        # import it before enabling mixed_precision.
        eff_b0, optimizer_ft = amp.initialize(eff_b0, optimizer_ft, opt_level="O1")
    eff_b0 = nn.DataParallel(eff_b0)

    # create data loader
    data_loaders = create_dataloaders(
        image_data, name_data, label_data, train_idx, valid_idx,
        data_transforms, batch_size, num_workers, pin_memory
    )

    # LR scheduling is currently disabled (epoch_scheduler=None). Previously
    # tried: ReduceLROnPlateau(factor=0.5, patience=5),
    # CosineAnnealingLR(T_max=n_epoch), OneCycleLR(max_lr=0.01), and
    # GradualWarmupScheduler(multiplier=1, total_epoch=10) in front of the
    # cosine schedule.

    callbacks = {}

    callbacks = train_model(
        eff_b0, criterion, optimizer_ft, data_loaders,
        mixed_precision, callbacks, num_epochs=n_epoch,
        epoch_scheduler=None, save_dir=f"../model_weights/eff_0_with_mixup_cutmix/fold_{i}"
    )
    fold_callbacks[i] = callbacks


Training fold 0
Epoch 0/31
----------


HBox(children=(FloatProgress(value=0.0, max=2511.0), HTML(value='')))


Train Loss: 1.1969, root_acc: 0.6243, vowel_acc: 0.6913, consonant_acc: 0.8236, combined_acc: 0.6908


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


Val Loss: 0.5670, root_acc: 0.8097, vowel_acc: 0.8814, consonant_acc: 0.9106, combined_acc: 0.8528
In epoch 0, highest val accuracy increases from 0.0 to 0.8528492830113523.
In epoch 0, lowest val loss decreases from 999 to 0.5670398714959776.

Epoch 1/31
----------


HBox(children=(FloatProgress(value=0.0, max=2511.0), HTML(value='')))


Train Loss: 0.4002, root_acc: 0.8605, vowel_acc: 0.9204, consonant_acc: 0.9425, combined_acc: 0.8960


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


Val Loss: 0.3670, root_acc: 0.8724, vowel_acc: 0.9444, consonant_acc: 0.9547, combined_acc: 0.9109
In epoch 1, highest val accuracy increases from 0.8528492830113523 to 0.9109490141406094.
In epoch 1, lowest val loss decreases from 0.5670398714959776 to 0.3670184142061069.

Epoch 2/31
----------


HBox(children=(FloatProgress(value=0.0, max=2511.0), HTML(value='')))


Train Loss: 0.2960, root_acc: 0.8961, vowel_acc: 0.9449, consonant_acc: 0.9570, combined_acc: 0.9235


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


Val Loss: 0.2930, root_acc: 0.8987, vowel_acc: 0.9583, consonant_acc: 0.9643, combined_acc: 0.9300
In epoch 2, highest val accuracy increases from 0.9109490141406094 to 0.9300002489543915.
In epoch 2, lowest val loss decreases from 0.3670184142061069 to 0.29301492901015536.

Epoch 3/31
----------


HBox(children=(FloatProgress(value=0.0, max=2511.0), HTML(value='')))


Train Loss: 0.2404, root_acc: 0.9154, vowel_acc: 0.9556, consonant_acc: 0.9641, combined_acc: 0.9376


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


Val Loss: 0.2805, root_acc: 0.9026, vowel_acc: 0.9604, consonant_acc: 0.9672, combined_acc: 0.9332
In epoch 3, highest val accuracy increases from 0.9300002489543915 to 0.9331930890260904.
In epoch 3, lowest val loss decreases from 0.29301492901015536 to 0.28050896451188806.

Epoch 4/31
----------


HBox(children=(FloatProgress(value=0.0, max=2511.0), HTML(value='')))


Train Loss: 0.2043, root_acc: 0.9276, vowel_acc: 0.9618, consonant_acc: 0.9685, combined_acc: 0.9464


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


Val Loss: 0.2336, root_acc: 0.9202, vowel_acc: 0.9694, consonant_acc: 0.9722, combined_acc: 0.9455
In epoch 4, highest val accuracy increases from 0.9331930890260904 to 0.9455038836885084.
In epoch 4, lowest val loss decreases from 0.28050896451188806 to 0.2335794798776083.

Epoch 5/31
----------


HBox(children=(FloatProgress(value=0.0, max=2511.0), HTML(value='')))


Train Loss: 0.1776, root_acc: 0.9367, vowel_acc: 0.9659, consonant_acc: 0.9718, combined_acc: 0.9528


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


Val Loss: 0.2281, root_acc: 0.9227, vowel_acc: 0.9713, consonant_acc: 0.9734, combined_acc: 0.9475
In epoch 5, highest val accuracy increases from 0.9455038836885084 to 0.9475453096992631.
In epoch 5, lowest val loss decreases from 0.2335794798776083 to 0.22805180701689995.

Epoch 6/31
----------


HBox(children=(FloatProgress(value=0.0, max=2511.0), HTML(value='')))


Train Loss: 0.1551, root_acc: 0.9449, vowel_acc: 0.9686, consonant_acc: 0.9742, combined_acc: 0.9581


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


Val Loss: 0.2226, root_acc: 0.9267, vowel_acc: 0.9709, consonant_acc: 0.9745, combined_acc: 0.9497
In epoch 6, highest val accuracy increases from 0.9475453096992631 to 0.9496738697470624.
In epoch 6, lowest val loss decreases from 0.22805180701689995 to 0.22257971257309686.

Epoch 7/31
----------


HBox(children=(FloatProgress(value=0.0, max=2511.0), HTML(value='')))


Train Loss: 0.1380, root_acc: 0.9506, vowel_acc: 0.9710, consonant_acc: 0.9766, combined_acc: 0.9622


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


Val Loss: 0.2338, root_acc: 0.9219, vowel_acc: 0.9730, consonant_acc: 0.9754, combined_acc: 0.9480

Epoch 8/31
----------


HBox(children=(FloatProgress(value=0.0, max=2511.0), HTML(value='')))


Train Loss: 0.1237, root_acc: 0.9555, vowel_acc: 0.9729, consonant_acc: 0.9783, combined_acc: 0.9655


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


Val Loss: 0.2509, root_acc: 0.9201, vowel_acc: 0.9708, consonant_acc: 0.9746, combined_acc: 0.9464

Epoch 9/31
----------


HBox(children=(FloatProgress(value=0.0, max=2511.0), HTML(value='')))


Train Loss: 0.1118, root_acc: 0.9591, vowel_acc: 0.9745, consonant_acc: 0.9797, combined_acc: 0.9681


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


Val Loss: 0.2274, root_acc: 0.9305, vowel_acc: 0.9746, consonant_acc: 0.9773, combined_acc: 0.9533
In epoch 9, highest val accuracy increases from 0.9496738697470624 to 0.9532525891256722.

Epoch 10/31
----------


HBox(children=(FloatProgress(value=0.0, max=2511.0), HTML(value='')))


Train Loss: 0.1019, root_acc: 0.9628, vowel_acc: 0.9755, consonant_acc: 0.9811, combined_acc: 0.9705


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


Val Loss: 0.2376, root_acc: 0.9241, vowel_acc: 0.9773, consonant_acc: 0.9776, combined_acc: 0.9508

Epoch 11/31
----------


HBox(children=(FloatProgress(value=0.0, max=2511.0), HTML(value='')))


Train Loss: 0.0930, root_acc: 0.9660, vowel_acc: 0.9769, consonant_acc: 0.9825, combined_acc: 0.9728


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))


Val Loss: 0.2294, root_acc: 0.9310, vowel_acc: 0.9751, consonant_acc: 0.9777, combined_acc: 0.9537
In epoch 11, highest val accuracy increases from 0.9532525891256722 to 0.9537131547500498.

Epoch 12/31
----------


HBox(children=(FloatProgress(value=0.0, max=2511.0), HTML(value='')))

KeyboardInterrupt: 

In [None]:
# Summary of this experiment's settings, gathered from the config variables
# defined in the earlier cells (key order is preserved).
configs = dict(
    model="efficient 0",
    pretrain=pretrain,
    head_info="1 fc",
    input_size="224X224",
    optimizer="adam",
    n_fold=n_splits,
    split_seed=random_seed,
    batch_size=batch_size,
    epoch=n_epoch,
    mixed_precision=mixed_precision,
)

In [None]:
# Get a batch of training data for demo

# visual_loader = DataLoader(
#     train_dataset, batch_size=4,
#     num_workers=num_workers, pin_memory=True,
# )

# inputs, a,b,c = next(iter(visual_loader))

# # Make a grid from batch
# out = torchvision.utils.make_grid(inputs)

# imshow(out)