In [1]:
#define environment
ON_KAGGLE = False
TRAIN_PREDICT = 'train'

In [2]:
import sys
if ON_KAGGLE:
    sys.path.append('../input/bengali-util/script/')
    sys.path.append('../input/bengali-util/')
    from script.utils import seed_everything, set_n_get_device
else:
    from utils import seed_everything, set_n_get_device

import os
import pickle
import time

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from torch.nn import functional as F
print('torch', torch.__version__)

if TRAIN_PREDICT=='train':
    from sklearn.model_selection import train_test_split, KFold


%matplotlib inline

torch 1.3.1+cu100


In [3]:
if ON_KAGGLE:
    #load utility scripts
    pass

else:#offline
    pass
    

In [4]:
#config
debug = False
SEED = 42
IMG_HEIGHT = 137
IMG_WIDTH = 236

if TRAIN_PREDICT=='train':
    BATCH_SIZE = 128
else:
    BATCH_SIZE = 256

if ON_KAGGLE:
    NUM_WORKERS = 2
else:
    NUM_WORKERS = 16

device = set_n_get_device("0,1", data_device_id="cuda:0")#IMPORTANT: data_device_id is set to free gpu for storing the model, e.g."cuda:1"
multi_gpu = [0,1]

if debug:
    LOG_PATH = '../logging/v5-debug.log'
else:
    LOG_PATH = '../logging/v5.log'

checkpoint_path = '../checkpoint/v5'
warm_start, last_checkpoint_path = False, '../checkpoint/v3/best.pth.tar'

NUM_EPOCHS = 60
early_stopping_round = 9999
#LearningRate = 5e-3

seed_everything(SEED)

In [5]:
#CONSTANTS
n_grapheme=168
n_vowel=11
n_consonant=7
#n_combo = 1295

#num_classes = n_grapheme+n_vowel+n_consonant+n_combo
num_classes = n_grapheme+n_vowel+n_consonant

## Pipeline
1. pytorch Dataset, data augmentation, DataLoader, train-test-split/KFold, 
2. network
3. training process

In [6]:
if ON_KAGGLE:
    pass

else:#offline
    train_df_list = [pd.read_feather('../data/processed/train_image_data_%d.feather'%i) for i in range(4)]
    train_images_arr = np.concatenate([df.iloc[:, 1:].values.reshape(-1, IMG_HEIGHT, IMG_WIDTH) 
                                       for df in train_df_list], axis=0)
    train_label_df = pd.read_csv('../data/raw/train.csv')
    train_label_arr = train_label_df[['grapheme_root', 'vowel_diacritic', 'consonant_diacritic']].values

train_images_arr.shape, train_label_arr.shape

((200840, 137, 236), (200840, 3))

In [7]:
# import importlib
# import augmentation
# importlib.reload(augmentation)
# from augmentation import *

In [8]:
# ##experiment a lot of augmentations
# imgs = train_images_arr[np.random.choice(200840, 4), :]
# imgs = np.clip((255-imgs)/255, 0, 1)

# fig,axes = plt.subplots(4,2, figsize=(10,8))
# for i in range(4):
#     image = imgs[i]
#     img_aug = do_random_shift_scale_crop_pad2(image, limit=0.2)
#     axes[i, 0].imshow(image, cmap='binary')
#     axes[i, 1].imshow(img_aug, cmap='binary')

In [9]:
## 1. encode grapheme characters to index 2. use it as sampler
# unique_char = train_label_df['grapheme'].unique()
# char2ind = dict([(char,i) for i,char in enumerate(unique_char)])
# grapheme_ind = [char2ind[char] for char in train_label_df['grapheme']]
# cls_w_dict = pd.value_counts(grapheme_ind)
# cls_w_dict /= 100
# cls_w_dict = cls_w_dict.to_dict()
# cls_w = [cls_w_dict[i] for i in grapheme_ind]


##check onehot correct?
#train_label_df.loc[train_label_arr[:,3]==1, ]

In [10]:
##data augmentation --cutmix, mixup
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    #cut_rat = np.sqrt(1. - lam)
    cut_rat = lam
    cut_w = np.int(W * cut_rat)
    cut_h = np.int(H * cut_rat)

    # uniform, ignore edge area
    cx = np.random.randint(W//4, W*3//4)
    cy = np.random.randint(H//4, H*3//4)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    #print(bbx1, bby1, bbx2, bby2)
    return bbx1, bby1, bbx2, bby2

def cutmix(data, target, alpha=1.0):
    targets1, targets2, targets3 = target[:,0], target[:,1], target[:,2]
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets1 = targets1[indices]
    shuffled_targets2 = targets2[indices]
    shuffled_targets3 = targets3[indices]

    #lam = np.random.beta(alpha, alpha)
    lam = np.sqrt(np.random.rand()/4)
    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
    data[:, :, bbx1:bbx2, bby1:bby2] = data[indices, :, bbx1:bbx2, bby1:bby2]
    #data[:, :, bbx1:bbx2, bby1:bby2] += data[indices, :, bbx1:bbx2, bby1:bby2]
    #data = torch.clamp(data, 0, 1)
    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))

    targets = [targets1, shuffled_targets1, targets2, shuffled_targets2, 
               targets3, shuffled_targets3, lam]
    return data, targets

def mixup(data, target, alpha=0.4):
    targets1, targets2, targets3 = target[:,0], target[:,1], target[:,2]
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets1 = targets1[indices]
    shuffled_targets2 = targets2[indices]
    shuffled_targets3 = targets3[indices]

    lam = np.random.beta(alpha, alpha)
    data = data * lam + shuffled_data * (1 - lam)
    targets = [targets1, shuffled_targets1, targets2, shuffled_targets2, 
               targets3, shuffled_targets3, lam]

    return data, targets

In [11]:
# kf = KFold(n_splits=5, shuffle=True, random_state=SEED).split(X=train_images_arr, y=train_label_arr)

# for fold, (train_idx, valid_idx) in enumerate(kf):
    
#     if fold in [0]:#train 1 fold for testing ideas
#         print('========training fold %d========'%fold)
#         print(train_idx)
#         print(valid_idx)
        
#         #1.1 data
#         train_inputs, valid_inputs = train_images_arr[train_idx], train_images_arr[valid_idx]
#         train_outputs, valid_outputs = train_label_arr[train_idx], train_label_arr[valid_idx]
#         #1.2 Dataset, DataLoader
#         train_dl = prepare_dataset(train_inputs, train_outputs, mode='train', debug=debug)
#         val_dl = prepare_dataset(valid_inputs, valid_outputs, mode='valid', debug=debug)

# ####
# for batch_id, (images, labels) in enumerate(train_dl):
#     inputs = images.to(device=device, dtype=torch.float)
#     truth = labels.to(device=device, dtype=torch.float)
#     if np.random.rand()<-10:
#         inputs, truth = cutmix(inputs, truth, alpha=None)
#     else:
#         inputs, truth = mixup(inputs, truth, alpha=0.4)
#     if batch_id==1:
#         break

# #print(inputs.shape, truth.shape)

# show_inputs = inputs.cpu().numpy()[np.random.choice(BATCH_SIZE, 20, replace=False), :, :, :]
# fig,axes = plt.subplots(5,4, figsize=(10,8))
# for i in range(20):
#     axes[i//4, i%4].imshow(show_inputs[i, 0], cmap='binary')

In [12]:
# ##crop to 128x128
# def bbox(img):
#     rows = np.any(img, axis=1)
#     cols = np.any(img, axis=0)
#     rmin, rmax = np.where(rows)[0][[0, -1]]
#     cmin, cmax = np.where(cols)[0][[0, -1]]
#     return rmin, rmax, cmin, cmax

# def crop_resize(img0, size=128, pad=16):
#     #crop a box around pixels large than the threshold 
#     #some images contain line at the sides
#     ymin,ymax,xmin,xmax = bbox(img0[5:-5,5:-5] > 80)
#     #cropping may cut too much, so we need to add it back
#     xmin = xmin - 13 if (xmin > 13) else 0
#     ymin = ymin - 10 if (ymin > 10) else 0
#     xmax = xmax + 13 if (xmax < IMG_WIDTH - 13) else IMG_WIDTH
#     ymax = ymax + 10 if (ymax < IMG_HEIGHT - 10) else IMG_HEIGHT
#     img = img0[ymin:ymax,xmin:xmax]
#     #remove lo intensity pixels as noise
#     img[img < 28] = 0
#     lx, ly = xmax-xmin,ymax-ymin
#     l = max(lx,ly) + pad
#     #make sure that the aspect ratio is kept in rescaling
#     img = np.pad(img, [((l-ly)//2,), ((l-lx)//2,)], mode='constant')
#     return cv2.resize(img,(size,size))

In [13]:
#img0 = train_images_arr[0]
#img0 = 255 - img0
#img0 = (img0*(255.0/img0.max())).astype(np.uint8)
#plt.imshow(crop_resize(rotate(img0, angle=20, reshape=False)))

In [14]:
# #### make up a weighted/balanced data sampler --for mixup/cutmix ####
# #weights for 168 classes--graphene_root
# import torch.utils.data
# _weights = [6.8,6.9,3,3.1,3,5.7,3.2,6.5,6.4,2.3,6.6,6.6,6.8,0.2,1.3,0.9,1.1,1.3,0.6,3.6,3,1.1,0.3,0.2,3,0.9,5.8,3.3,1.3,0.4,2.3,1.3,0.9,7.4,3.6,2.1,1,3.5,0.3,1.6,1.3,3.3,0.5,0.3,0.9,6.9,1.7,2.2,0.7,3.1,1.4,3.1,1.1,0.3,1.7,0.6,0.4,1.6,0.8,0.4,2.3,1.7,1.2,6.7,0.2,0.7,1.3,2.1,1.6,1.3,1,0.3,0.2,7.7,0.7,0.9,0.5,1,3.4,0.3,2.2,0.3,3.4,0.7,2.2,0.7,0.5,6,1.3,0.4,1.6,0.6,0.9,1.6,1,1.4,0.2,2.1,1.6,2.2,2.2,0.9,7.1,0.3,6.2,6.6,1.3,0.2,6.3,1.1,2.9,1.3,1.1,0.2,6.7,0.2,2.3,0.7,0.9,0.7,0.8,2.2,0.4,0.5,0.5,1.2,6.3,1.1,1.1,1,6.9,2.3,1,0.2,1.6,1.6,1,1.8,1.1,0.4,1.1,0.6,0.9,1.6,1.6,3.2,3.3,0.2,0.6,0.4,0.4,0.8,1.6,0.6,1.4,1.1,1.3,3.1,7,0.3,2.1,3.2,2.2,6.1,6.1,0.9,3.3,0.6]
# weights_dict = dict(zip(range(len(_weights)), _weights))

In [15]:
from torch.utils.data import DataLoader, Dataset
from utils import set_logger, save_checkpoint, load_checkpoint
import logging
#import gc
from scipy.ndimage.interpolation import rotate
import cv2
from augmentation import *


def prepare_dataset(img_arr, label_arr, mode='train', debug=False):
    """
    mode: 'train', 'valid', 'test'
    """
    if debug:#for debug, sample 1/10 data
        n = img_arr.shape[0]
        sid = np.random.choice(n, size=n//5, replace=False)
        img_arr = img_arr[sid]
        label_arr = label_arr[sid]

    if mode=='train':
        ds = DatasetV1(img_arr, label_arr, mode='train', augmentation=True)
        dl = DataLoader(ds,
                        batch_size=BATCH_SIZE,
                        shuffle=True,
                        #sampler=sampler,
                        num_workers=NUM_WORKERS,
                        drop_last=True
                       )

#     if mode=='train':
#         ds = DatasetV1(img_arr, label_arr, mode='train', augmentation=True)
#         weights = [weights_dict[c] for c in label_arr[:,0]]
#         sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(ds), replacement=True)
#         dl = DataLoader(ds,
#                         batch_size=BATCH_SIZE,
#                         shuffle=False,
#                         sampler=sampler,
#                         num_workers=NUM_WORKERS,
#                         drop_last=True
#                        )
#         return dl

    elif mode=='valid':
        ds = DatasetV1(img_arr, label_arr, mode='train', augmentation=False)
        dl = DataLoader(ds,
                        batch_size=BATCH_SIZE,
                        shuffle=False,
                        #sampler=sampler,
                        num_workers=NUM_WORKERS,
                        drop_last=True
                       )
    elif mode=='test':
        ds = DatasetV1(img_arr, label_arr, mode='test', augmentation=False)
        dl = DataLoader(ds,
                        batch_size=BATCH_SIZE,
                        shuffle=False,
                        #sampler=sampler,
                        num_workers=NUM_WORKERS,
                        drop_last=False
                       ) 
    return dl

class DatasetV1(Dataset):
    """plain"""
    def __init__(self, inputs, outputs, mode='train', augmentation=False):
        """
        inputs: images, (N, H, W)
        outputs: label, (N, 3)
        """
        self.inputs = inputs
        self.outputs = outputs
        self.mode = mode 
        self.augmentation = augmentation
    
    def __getitem__(self, idx):
        #TODO: augmentation, preprocessing
        inputs, outputs = self.inputs[idx], self.outputs[idx]
        inputs = np.clip((255-inputs)/255.0, 0, 1)

#         inputs = cv2.resize(inputs, (224,224))
#         inputs = np.clip(inputs, 0, 1)

        #crop
#         inputs = 255-inputs
#         inputs = crop_resize(inputs)
#         inputs = np.clip(inputs/255.0, 0, 1)

        if self.augmentation:
            inputs = self.do_augmentation(inputs)
        
        inputs = np.expand_dims(inputs, 0)#(224,224)-->(1,224,224)
        inputs = inputs.astype(np.float32)
        return inputs, outputs

    def __len__(self):
        return self.inputs.shape[0]
    
    def do_augmentation(self, image):
        #rotate
        #if np.random.rand() < 0.5:
        #angle = np.random.randint(0, 40) - 20
        #inputs = rotate(inputs, angle, reshape=False)
        for op in np.random.choice([
            lambda image : do_identity(image),
            lambda image : do_random_projective(image, 0.4),#0.4
            lambda image : do_random_perspective(image, 0.4),#0.4
            lambda image : do_random_scale(image, 0.4),#0.4
            lambda image : do_random_rotate(image, 0.4),#0.4
            lambda image : do_random_shear_x(image, 0.5),#0.5
            lambda image : do_random_shear_y(image, 0.4),#0.4
            lambda image : do_random_stretch_x(image, 0.5),#0.5
            lambda image : do_random_stretch_y(image, 0.5),#0.5
            lambda image : do_random_grid_distortion(image, 0.4),#0.4
            lambda image : do_random_custom_distortion1(image, 0.5),#0.5
        ],1):
            image = op(image)

        for op in np.random.choice([
            lambda image : do_identity(image),
            lambda image : do_random_erode(image, 0.4),#0.4
            lambda image : do_random_dilate(image, 0.4),#0.4
            lambda image : do_random_sprinkle(image, 0.5),#0.5
            #lambda image : do_random_line(image, 0.2),
        ],1):
            image = op(image)

        for op in np.random.choice([
            lambda image : do_identity(image),
            lambda image : do_random_contast(image, 0.5),#0.5
            lambda image : do_random_block_fade(image, 0.5),#0.5
        ],1):
            image = op(image)
        
#         if np.random.rand()<1.1:
#             image = do_random_shift_scale_crop_pad2(image, limit=0.1)
#         else:
#             image = do_shift_scale_rotate2(image, angle=np.random.uniform(0, 10))
        return image

def train_and_valid(net, train_dl, val_dl):
    """train one fold
    
    [settings]...
    
    [Epoch i]
        [Trainset]
            [Batch j]
        [Validset]
            [Batch k]
        [Logging/Checkpoint]
    """
    set_logger(LOG_PATH)
    logging.info('\n\n')
    #1. optim
    train_params = filter(lambda p: p.requires_grad, net.parameters())
#     optimizer = torch.optim.Adam(train_params, lr=LearningRate)
#     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', 
#                                                           factor=0.5, patience=5, 
#                                                           verbose=False, threshold=0.0001, 
#                                                           threshold_mode='rel', cooldown=0, 
#                                                           min_lr=0, eps=1e-08)

#     optimizer = torch.optim.Adam(train_params, lr=5e-2)
#     scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5,10,15,20], gamma=0.1)

    optimizer = torch.optim.Adam(train_params)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=5e-3, 
                                                    steps_per_epoch=len(train_dl),#1255 
                                                    epochs=NUM_EPOCHS)

    #1.1 warm-start
    if warm_start:
        logging.info('warm_start: '+last_checkpoint_path)
        net, _ = load_checkpoint(last_checkpoint_path, net)
    
    #2. using multi GPU
    if multi_gpu is not None:
        net = nn.DataParallel(net, device_ids=multi_gpu)
    #3. train
    diff = 0
    best_val_metric = -1.0#np.inf
    optimizer.zero_grad()
    
    for i_epoch in range(NUM_EPOCHS):
        t0 = time.time()
        print('lr: ', scheduler.get_lr())
        ## trainset -------------------------------------------------------------
        net.train()
        loss_logger = LossLogger()
        for batch_id, (images, labels) in enumerate(train_dl):
            inputs = images.to(device=device, dtype=torch.float)
            truth = labels.to(device=device, dtype=torch.float)
            
            #do cutmix/mixup
            if np.random.rand()<0.5:
                inputs, truth = cutmix(inputs, truth, alpha=None)
            else:
                inputs, truth = mixup(inputs, truth, alpha=0.4)

            #if use ohem loss
            if i_epoch<-1:#80
                mode = 'normal'
                rate = None
            else:
                mode = 'ohem'
                if i_epoch<10:
                    rate = 1.0 #keep all loss, no ohem
                elif i_epoch<50:
                    rate = 0.7
                else:
                    rate = 0.2

            logit = net(inputs)
            logits = torch.split(logit, [n_grapheme, n_vowel, n_consonant], dim=1)
            train_loss = loss_logger.update(logits, truth, mode=mode, rate=rate)
            #grandient accumulation step=2
            acc_step = 1
            if acc_step>1:
                train_loss = train_loss / acc_step
            train_loss.backward()
            if (batch_id+1)%acc_step==0:
                optimizer.step()
                optimizer.zero_grad()
            ##lr scheduler
            scheduler.step()
        ##aggregate loss
        train_loss_total, _, _, _ = loss_logger.aggregate()
        
#         ##check for memory leakage
#         for obj in gc.get_objects():
#             try:
#                 if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
#                     print(type(obj), obj.size())
#             except:
#                 pass
        ## validset -------------------------------------------------------------
        net.eval()
        loss_logger = LossLogger()
        metric_logger = MetricLogger()
        with torch.no_grad():
            for batch_id, (images, labels) in enumerate(val_dl):
                inputs = images.to(device=device, dtype=torch.float)
                truth = labels.to(device=device, dtype=torch.float)
                logit = net(inputs)
                logits = torch.split(logit, [n_grapheme, n_vowel, n_consonant], dim=1)
                _ = loss_logger.update(logits, truth, mode='normal', rate=None)
                metric_logger.update(logits, truth)
        rec, rec_grapheme, rec_vowel, rec_consonant = metric_logger.aggregate()
        loss_total, loss_grapheme, loss_vowel, loss_consonant = loss_logger.aggregate()
        
        ## callbacks -------------------------------------------------------------
        val_metric = rec#loss_total
        #scheduler.step(val_metric)
        ##lr scheduler
        scheduler.step()

        #sometimes too early stop, force to at least train N epochs
        if i_epoch>=-1:
            if val_metric > best_val_metric:
                best_val_metric = val_metric
                is_best = True
                diff = 0
            else:
                is_best = False
                diff += 1
                if diff > early_stopping_round:
                    logging.info('Early Stopping: val_metric does not increase %d rounds'%early_stopping_round)
                    break
        else:
            is_best = False
        
        #save checkpoint
        checkpoint_dict = \
        {
            'epoch': i_epoch,
            'state_dict': net.module.state_dict() if multi_gpu is not None else net.state_dict(),
            'optim_dict' : optimizer.state_dict(),
            'metrics': {'train_loss': train_loss_total, 'val_loss': loss_total, 
                        'val_metric': val_metric}
        }
        save_checkpoint(checkpoint_dict, is_best=is_best, checkpoint=checkpoint_path)
        
        #logging loss/metric
        logging.info('[EPOCH %05d]train_loss: %0.4f; val_loss: %0.4f; time elapsed: %0.1f min'%(i_epoch, 
                    train_loss_total, loss_total, (time.time()-t0)/60))
        logging.info('[valid loss]grapheme: %0.4f, vowel: %0.4f, consonant: %0.4f'%(loss_grapheme, 
                                                                        loss_vowel, loss_consonant))
        logging.info('[valid recall]total: %0.4f, grapheme: %0.4f, vowel: %0.4f, consonant: %0.4f'%(rec, 
                    rec_grapheme, rec_vowel, rec_consonant))
        logging.info('='*80)

def predict(test_dl):
    pass

In [16]:
# import torch
# import torchvision
# print(torch.__version__)

# NUM_EPOCHS = 60

# model = torchvision.models.resnet18(pretrained=False)
# optimizer = torch.optim.Adam(model.parameters())
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=5e-3, 
#                                                 steps_per_epoch=1255, epochs=NUM_EPOCHS)#steps_per_epoch=len(dl)
# # optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
# # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10,30,50], gamma=0.1)

# l = []
# for epoch in range(NUM_EPOCHS):
#     for i,batch in enumerate(range(1255)):
#         #pass
#         l.append(scheduler.get_lr())
#         #train_batch(...)
#     #l.append(scheduler.get_lr()[0])
#         scheduler.step()

# l[0::1255]
# #l

In [17]:
from sklearn.metrics import recall_score

class LossLogger(object):
    """loss for an epoch
    
    [Epoch i]:
        loss_logger = LossLogger()
        
        [Batch j]:
            loss = loss_logger.update(logits, truth)
            loss.backward()
        
        loss, loss_grapheme, loss_vowel, loss_consonant = loss_logger.aggregate()
    """
    def __init__(self):
        self.loss_grapheme = []
        self.loss_vowel = []
        self.loss_consonant = []

    def update(self, logits, truth, mode='normal', rate=None):
        """
        logits: logit splitted to [logit_grapheme, logit_vowel, logit_consonant]
        truth: shape (N, 3)
        """
        truth1, truth2, truth3, truth4, truth5, truth6, lam = \
                truth[0], truth[1], truth[2], truth[3], truth[4], truth[5], truth[6]
        if mode=='normal':
            criterion = nn.CrossEntropyLoss(reduction='mean')
            #loss
            loss_grapheme = F.cross_entropy(logits[0], truth[:,0].long())
            loss_vowel = F.cross_entropy(logits[1], truth[:,1].long())
            loss_consonant = F.cross_entropy(logits[2], truth[:,2].long())
        elif mode=='ohem':
            criterion = ohem_loss
            loss_grapheme = lam * criterion(logits[0], truth1.long(), rate) + \
                (1 - lam) * criterion(logits[0], truth2.long(), rate)
            loss_vowel = lam * criterion(logits[1], truth3.long(), rate) + \
                (1 - lam) * criterion(logits[1], truth4.long(), rate)
            loss_consonant = lam * criterion(logits[2], truth5.long(), rate) + \
                (1 - lam) * criterion(logits[2], truth6.long(), rate)
        loss = 0.5*loss_grapheme + 0.25*loss_vowel + 0.25*loss_consonant
        #
        self.loss_grapheme.append(loss_grapheme.item())
        self.loss_vowel.append(loss_vowel.item())
        self.loss_consonant.append(loss_consonant.item())
        return loss
    
    def aggregate(self):
        """
        for print logging
        """
        loss_grapheme = np.mean(self.loss_grapheme)
        loss_vowel = np.mean(self.loss_vowel)
        loss_consonant = np.mean(self.loss_consonant)
        loss_total = np.mean(
            0.5*np.array(self.loss_grapheme) + \
            0.25*np.array(self.loss_vowel) + \
            0.25*np.array(self.loss_consonant)
        )
        return loss_total, loss_grapheme, loss_vowel, loss_consonant

def ohem_loss(cls_pred, cls_target, rate=0.7):
    """TODO: rate may change per EPOCH"""
    batch_size = cls_pred.size(0) 
    ohem_cls_loss = F.cross_entropy(cls_pred, cls_target, reduction='none', ignore_index=-1)

#     sorted_ohem_loss, idx = torch.sort(ohem_cls_loss, descending=True)
#     keep_num = min(sorted_ohem_loss.size()[0], int(batch_size*rate) )
#     if keep_num < sorted_ohem_loss.size()[0]:
#         keep_idx_cuda = idx[:keep_num]
#         ohem_cls_loss = ohem_cls_loss[keep_idx_cuda]
#     cls_loss = ohem_cls_loss.sum() / keep_num
    ohem_cls_loss_sorted, idx = torch.topk(ohem_cls_loss, k=int(rate * batch_size), 
                                           largest=True, sorted=True, out=None)
    cls_loss = ohem_cls_loss_sorted.mean()
    return cls_loss

class MetricLogger(object):
    """recall, precision for an epoch
    
    [Epoch i]:
        metric_logger = MetricLogger()
        
        [Batch j]:
            metric_logger.update(logits, truth)
        
        rec, rec_grapheme, rec_vowel, rec_consonant = metric_logger.aggregate()
    """
    def __init__(self):
        self.pred_grapheme = torch.tensor([], dtype=torch.long).cuda(device)
        self.pred_vowel = torch.tensor([], dtype=torch.long).cuda(device=device)
        self.pred_consonant = torch.tensor([], dtype=torch.long).cuda(device=device)
        
        self.truth_grapheme = torch.tensor([], dtype=torch.long).cuda(device=device)
        self.truth_vowel = torch.tensor([], dtype=torch.long).cuda(device=device)
        self.truth_consonant = torch.tensor([], dtype=torch.long).cuda(device=device)

    def update(self, logits, truth):
        pred = torch.argmax(logits[0], dim=1)
        self.pred_grapheme = torch.cat([self.pred_grapheme, pred])
        self.truth_grapheme = torch.cat([self.truth_grapheme, truth[:, 0].long()])
        #
        pred = torch.argmax(logits[1], dim=1)
        self.pred_vowel = torch.cat([self.pred_vowel, pred])
        self.truth_vowel = torch.cat([self.truth_vowel, truth[:, 1].long()])
        #
        pred = torch.argmax(logits[2], dim=1)
        self.pred_consonant = torch.cat([self.pred_consonant, pred])
        self.truth_consonant = torch.cat([self.truth_consonant, truth[:, 2].long()])

    def aggregate(self):
        rec_grapheme = recall_score(self.truth_grapheme.cpu().numpy(), 
                                    self.pred_grapheme.cpu().numpy(), 
                                    average='macro')
        rec_vowel = recall_score(self.truth_vowel.cpu().numpy(), 
                                 self.pred_vowel.cpu().numpy(), 
                                 average='macro')
        rec_consonant = recall_score(self.truth_consonant.cpu().numpy(), 
                                     self.pred_consonant.cpu().numpy(), 
                                     average='macro')
        #rec = (2*rec_grapheme + 1*rec_vowel + 1*rec_consonant) / 4
        rec = np.average([rec_grapheme, rec_vowel, rec_consonant], weights=[2,1,1])
        return rec, rec_grapheme, rec_vowel, rec_consonant

# #debug MetricLogger
# #Epoch 0
# loss_logger = LossLogger()
# metric_logger = MetricLogger()
# for batch_id, (images, labels) in enumerate(val_dl):
#     inputs = images.to(device=device, dtype=torch.float)
#     truth = labels.to(device=device, dtype=torch.float)
#     if batch_id==10:
#         break
#     logit = net(inputs)
#     logits = torch.split(logit, [n_grapheme, n_vowel, n_consonant], dim=1)
#     loss = loss_logger.update(logits, truth)
#     metric_logger.update(logits, truth)
# rec, rec_grapheme, rec_vowel, rec_consonant = metric_logger.aggregate()
# loss_total, loss_grapheme, loss_vowel, loss_consonant = loss_logger.aggregate()
# print(rec, rec_grapheme, rec_vowel, rec_consonant)
# print(loss_total, loss_grapheme, loss_vowel, loss_consonant)

In [18]:
# from senet import se_resnext50_32x4d
# from senet_v2 import se_resnext50_32x4d
from efficientnet import EffiNet

In [19]:
#### training for 5 folds here ####
kf = KFold(n_splits=5, shuffle=True, random_state=SEED).split(X=train_images_arr, y=train_label_arr)

for fold, (train_idx, valid_idx) in enumerate(kf):
    
    if fold in [0]:#train 1 fold for testing ideas
        print('========training fold %d========'%fold)
        print(train_idx)
        print(valid_idx)
        
        #checkpoint_path = '../checkpoint/v3-fold'+str(fold)
        #print('checkpoint_path: ', checkpoint_path)
        
        #1.1 data
        train_inputs, valid_inputs = train_images_arr[train_idx], train_images_arr[valid_idx]
        train_outputs, valid_outputs = train_label_arr[train_idx], train_label_arr[valid_idx]
        #1.2 Dataset, DataLoader
        train_dl = prepare_dataset(train_inputs, train_outputs, mode='train', debug=debug)
        val_dl = prepare_dataset(valid_inputs, valid_outputs, mode='valid', debug=debug)
        
        #2. model
        #net = se_resnext50_32x4d(num_classes=num_classes, pretrained=None).cuda(device=device)
        #net = se_resnext50_32x4d(num_classes=num_classes, pretrained='imagenet').cuda(device=device)
        net = EffiNet(model='b3').cuda(device=device)

        #3. train session
        train_and_valid(net, train_dl, val_dl)


[     0      1      2 ... 200835 200837 200839]
[     4      6     12 ... 200832 200836 200838]
Loaded pretrained weights for efficientnet-b3







lr:  [0.00019999999999999966]


[EPOCH 00000]train_loss: 2.7047; val_loss: 0.8431; time elapsed: 7.7 min
[valid loss]grapheme: 1.0416, vowel: 0.6096, consonant: 0.6796
[valid recall]total: 0.7330, grapheme: 0.6436, vowel: 0.8879, consonant: 0.7570


lr:  [0.00023652259907873396]


[EPOCH 00001]train_loss: 1.8154; val_loss: 0.4404; time elapsed: 7.9 min
[valid loss]grapheme: 0.5120, vowel: 0.3318, consonant: 0.4057
[valid recall]total: 0.8477, grapheme: 0.8433, vowel: 0.9396, consonant: 0.7645


lr:  [0.0003449788127787151]


[EPOCH 00002]train_loss: 1.5375; val_loss: 0.3814; time elapsed: 8.1 min
[valid loss]grapheme: 0.5104, vowel: 0.2314, consonant: 0.2734
[valid recall]total: 0.8739, grapheme: 0.8528, vowel: 0.9448, consonant: 0.8453


lr:  [0.0005220677220911405]


[EPOCH 00003]train_loss: 1.3944; val_loss: 0.3121; time elapsed: 8.1 min
[valid loss]grapheme: 0.3967, vowel: 0.2345, consonant: 0.2205
[valid recall]total: 0.9060, grapheme: 0.8958, vowel: 0.9603, consonant: 0.8721


lr:  [0.0007623995376525878]


[EPOCH 00004]train_loss: 1.3127; val_loss: 0.3388; time elapsed: 8.0 min
[valid loss]grapheme: 0.4679, vowel: 0.2332, consonant: 0.1861
[valid recall]total: 0.9064, grapheme: 0.8718, vowel: 0.9614, consonant: 0.9207


lr:  [0.0010586596406750451]


[EPOCH 00005]train_loss: 1.2585; val_loss: 0.3073; time elapsed: 7.9 min
[valid loss]grapheme: 0.4290, vowel: 0.2157, consonant: 0.1558
[valid recall]total: 0.9222, grapheme: 0.8788, vowel: 0.9724, consonant: 0.9587


lr:  [0.001401831207020416]


[EPOCH 00006]train_loss: 1.2229; val_loss: 0.2954; time elapsed: 8.0 min
[valid loss]grapheme: 0.3792, vowel: 0.2409, consonant: 0.1824
[valid recall]total: 0.9194, grapheme: 0.9056, vowel: 0.9586, consonant: 0.9077


lr:  [0.0017814696387446574]


[EPOCH 00007]train_loss: 1.2047; val_loss: 0.2488; time elapsed: 8.1 min
[valid loss]grapheme: 0.3454, vowel: 0.1649, consonant: 0.1397
[valid recall]total: 0.9315, grapheme: 0.9086, vowel: 0.9704, consonant: 0.9386


lr:  [0.002186020450650482]


[EPOCH 00008]train_loss: 1.1873; val_loss: 0.5396; time elapsed: 8.0 min
[valid loss]grapheme: 0.7570, vowel: 0.3948, consonant: 0.2496
[valid recall]total: 0.8330, grapheme: 0.7918, vowel: 0.8643, consonant: 0.8839


lr:  [0.0026031709368127135]


[EPOCH 00009]train_loss: 1.1756; val_loss: 0.2418; time elapsed: 8.0 min
[valid loss]grapheme: 0.3462, vowel: 0.1379, consonant: 0.1367
[valid recall]total: 0.9326, grapheme: 0.9016, vowel: 0.9667, consonant: 0.9604


lr:  [0.0030202249139300156]


[EPOCH 00010]train_loss: 1.5542; val_loss: 0.4341; time elapsed: 8.1 min
[valid loss]grapheme: 0.5945, vowel: 0.3471, consonant: 0.2003
[valid recall]total: 0.8927, grapheme: 0.8654, vowel: 0.9020, consonant: 0.9379


lr:  [0.0034244891360020163]


[EPOCH 00011]train_loss: 1.4959; val_loss: 0.3425; time elapsed: 8.1 min
[valid loss]grapheme: 0.4466, vowel: 0.2689, consonant: 0.2080
[valid recall]total: 0.9292, grapheme: 0.9066, vowel: 0.9671, consonant: 0.9364


lr:  [0.0038036596196082064]


[EPOCH 00012]train_loss: 1.4486; val_loss: 0.3226; time elapsed: 7.9 min
[valid loss]grapheme: 0.4102, vowel: 0.2596, consonant: 0.2101
[valid recall]total: 0.9321, grapheme: 0.9056, vowel: 0.9591, consonant: 0.9583


lr:  [0.004146196121785886]


[EPOCH 00013]train_loss: 1.4374; val_loss: 0.4298; time elapsed: 8.0 min
[valid loss]grapheme: 0.5876, vowel: 0.3228, consonant: 0.2211
[valid recall]total: 0.9191, grapheme: 0.8848, vowel: 0.9531, consonant: 0.9536


lr:  [0.004441673373085999]


[EPOCH 00014]train_loss: 1.4220; val_loss: 0.5024; time elapsed: 8.1 min
[valid loss]grapheme: 0.7238, vowel: 0.3380, consonant: 0.2239
[valid recall]total: 0.8876, grapheme: 0.8393, vowel: 0.9644, consonant: 0.9076


lr:  [0.0046810983758534454]


[EPOCH 00015]train_loss: 1.3530; val_loss: 0.2793; time elapsed: 8.0 min
[valid loss]grapheme: 0.3808, vowel: 0.1974, consonant: 0.1580
[valid recall]total: 0.9384, grapheme: 0.9085, vowel: 0.9728, consonant: 0.9637


lr:  [0.004857184110600309]


[EPOCH 00016]train_loss: 1.4052; val_loss: 0.3096; time elapsed: 7.9 min
[valid loss]grapheme: 0.4035, vowel: 0.2373, consonant: 0.1940
[valid recall]total: 0.9427, grapheme: 0.9242, vowel: 0.9703, consonant: 0.9519


lr:  [0.0049645713200818856]


[EPOCH 00017]train_loss: 1.3767; val_loss: 0.2488; time elapsed: 8.0 min
[valid loss]grapheme: 0.3227, vowel: 0.1985, consonant: 0.1513
[valid recall]total: 0.9484, grapheme: 0.9291, vowel: 0.9705, consonant: 0.9649


lr:  [0.004999998397016159]


[EPOCH 00018]train_loss: 1.3508; val_loss: 0.2883; time elapsed: 8.2 min
[valid loss]grapheme: 0.3721, vowel: 0.2274, consonant: 0.1817
[valid recall]total: 0.9464, grapheme: 0.9352, vowel: 0.9727, consonant: 0.9426


lr:  [0.004992785049552961]


[EPOCH 00019]train_loss: 1.3278; val_loss: 0.2584; time elapsed: 8.1 min
[valid loss]grapheme: 0.3352, vowel: 0.2082, consonant: 0.1551
[valid recall]total: 0.9497, grapheme: 0.9309, vowel: 0.9762, consonant: 0.9607


lr:  [0.004971608878267791]


[EPOCH 00020]train_loss: 1.3161; val_loss: 0.2173; time elapsed: 8.0 min
[valid loss]grapheme: 0.2859, vowel: 0.1567, consonant: 0.1406
[valid recall]total: 0.9506, grapheme: 0.9341, vowel: 0.9734, consonant: 0.9608


lr:  [0.004936588497613618]


[EPOCH 00021]train_loss: 1.3088; val_loss: 0.4414; time elapsed: 8.1 min
[valid loss]grapheme: 0.6429, vowel: 0.2646, consonant: 0.2151
[valid recall]total: 0.8728, grapheme: 0.8250, vowel: 0.9587, consonant: 0.8823


lr:  [0.0048879200678518114]


[EPOCH 00022]train_loss: 1.2768; val_loss: 0.2507; time elapsed: 8.1 min
[valid loss]grapheme: 0.3231, vowel: 0.2093, consonant: 0.1472
[valid recall]total: 0.9475, grapheme: 0.9295, vowel: 0.9789, consonant: 0.9522


lr:  [0.004825876196296287]


[EPOCH 00023]train_loss: 1.2724; val_loss: 0.2112; time elapsed: 8.1 min
[valid loss]grapheme: 0.2819, vowel: 0.1525, consonant: 0.1283
[valid recall]total: 0.9505, grapheme: 0.9347, vowel: 0.9705, consonant: 0.9619


lr:  [0.0047508044103534664]


[EPOCH 00024]train_loss: 1.2732; val_loss: 0.2078; time elapsed: 8.0 min
[valid loss]grapheme: 0.2885, vowel: 0.1332, consonant: 0.1208
[valid recall]total: 0.9521, grapheme: 0.9318, vowel: 0.9756, consonant: 0.9692


lr:  [0.004663125210911027]


[EPOCH 00025]train_loss: 1.2396; val_loss: 0.4008; time elapsed: 8.1 min
[valid loss]grapheme: 0.6193, vowel: 0.2164, consonant: 0.1483
[valid recall]total: 0.8942, grapheme: 0.8338, vowel: 0.9688, consonant: 0.9403


lr:  [0.004563329716979042]


[EPOCH 00026]train_loss: 1.2409; val_loss: 0.3875; time elapsed: 8.2 min
[valid loss]grapheme: 0.5164, vowel: 0.3149, consonant: 0.2023
[valid recall]total: 0.9321, grapheme: 0.9078, vowel: 0.9673, consonant: 0.9456


lr:  [0.004451976914776603]


[EPOCH 00027]train_loss: 1.2290; val_loss: 0.2129; time elapsed: 8.0 min
[valid loss]grapheme: 0.2917, vowel: 0.1436, consonant: 0.1244
[valid recall]total: 0.9512, grapheme: 0.9279, vowel: 0.9818, consonant: 0.9675


lr:  [0.004329690526672687]


[EPOCH 00028]train_loss: 1.2023; val_loss: 0.2949; time elapsed: 8.0 min
[valid loss]grapheme: 0.3703, vowel: 0.2752, consonant: 0.1637
[valid recall]total: 0.9502, grapheme: 0.9356, vowel: 0.9708, consonant: 0.9589


lr:  [0.00419715551751931]


[EPOCH 00029]train_loss: 1.2285; val_loss: 0.2246; time elapsed: 8.1 min
[valid loss]grapheme: 0.2885, vowel: 0.1907, consonant: 0.1309
[valid recall]total: 0.9602, grapheme: 0.9431, vowel: 0.9791, consonant: 0.9756


lr:  [0.004055114257946095]


[EPOCH 00030]train_loss: 1.1897; val_loss: 0.2322; time elapsed: 8.2 min
[valid loss]grapheme: 0.3143, vowel: 0.1644, consonant: 0.1360
[valid recall]total: 0.9472, grapheme: 0.9229, vowel: 0.9788, consonant: 0.9642


lr:  [0.0039043623661068868]


[EPOCH 00031]train_loss: 1.1956; val_loss: 0.2130; time elapsed: 8.0 min
[valid loss]grapheme: 0.2719, vowel: 0.1703, consonant: 0.1379
[valid recall]total: 0.9640, grapheme: 0.9500, vowel: 0.9785, consonant: 0.9774


lr:  [0.003745744251170077]


[EPOCH 00032]train_loss: 1.1385; val_loss: 0.1484; time elapsed: 8.1 min
[valid loss]grapheme: 0.1894, vowel: 0.1141, consonant: 0.1006
[valid recall]total: 0.9703, grapheme: 0.9582, vowel: 0.9848, consonant: 0.9799


lr:  [0.003580148383514985]


[EPOCH 00033]train_loss: 1.1370; val_loss: 0.2445; time elapsed: 8.2 min
[valid loss]grapheme: 0.3113, vowel: 0.2161, consonant: 0.1393
[valid recall]total: 0.9640, grapheme: 0.9477, vowel: 0.9840, consonant: 0.9766


lr:  [0.0034085023181274206]


[EPOCH 00034]train_loss: 1.1090; val_loss: 0.2192; time elapsed: 8.1 min
[valid loss]grapheme: 0.2811, vowel: 0.1821, consonant: 0.1323
[valid recall]total: 0.9685, grapheme: 0.9539, vowel: 0.9854, consonant: 0.9808


lr:  [0.003231767499069964]


[EPOCH 00035]train_loss: 1.1342; val_loss: 0.1463; time elapsed: 8.0 min
[valid loss]grapheme: 0.1894, vowel: 0.1117, consonant: 0.0949
[valid recall]total: 0.9693, grapheme: 0.9549, vowel: 0.9854, consonant: 0.9820


lr:  [0.003050933874128778]


[EPOCH 00036]train_loss: 1.1365; val_loss: 0.1883; time elapsed: 8.0 min
[valid loss]grapheme: 0.2458, vowel: 0.1451, consonant: 0.1166
[valid recall]total: 0.9598, grapheme: 0.9442, vowel: 0.9771, consonant: 0.9737


lr:  [0.002867014349802023]


[EPOCH 00037]train_loss: 1.0960; val_loss: 0.1455; time elapsed: 8.1 min
[valid loss]grapheme: 0.1888, vowel: 0.1094, consonant: 0.0951
[valid recall]total: 0.9677, grapheme: 0.9531, vowel: 0.9862, consonant: 0.9785


lr:  [0.00268103911768924]


[EPOCH 00038]train_loss: 1.0967; val_loss: 0.1793; time elapsed: 8.0 min
[valid loss]grapheme: 0.2263, vowel: 0.1541, consonant: 0.1105
[valid recall]total: 0.9718, grapheme: 0.9586, vowel: 0.9860, consonant: 0.9841


lr:  [0.0024940498840614004]


[EPOCH 00039]train_loss: 1.0847; val_loss: 0.1820; time elapsed: 8.1 min
[valid loss]grapheme: 0.2318, vowel: 0.1452, consonant: 0.1193
[valid recall]total: 0.9676, grapheme: 0.9557, vowel: 0.9840, consonant: 0.9752


lr:  [0.002307094034933627]


[EPOCH 00040]train_loss: 1.0884; val_loss: 0.1991; time elapsed: 8.2 min
[valid loss]grapheme: 0.2399, vowel: 0.1595, consonant: 0.1573
[valid recall]total: 0.9598, grapheme: 0.9417, vowel: 0.9801, consonant: 0.9756


lr:  [0.0021212187693238705]


[EPOCH 00041]train_loss: 1.0380; val_loss: 0.1602; time elapsed: 8.2 min
[valid loss]grapheme: 0.2011, vowel: 0.1373, consonant: 0.1010
[valid recall]total: 0.9739, grapheme: 0.9621, vowel: 0.9862, consonant: 0.9852


lr:  [0.001937465233559005]


[EPOCH 00042]train_loss: 1.0406; val_loss: 0.1327; time elapsed: 8.0 min
[valid loss]grapheme: 0.1683, vowel: 0.1066, consonant: 0.0874
[valid recall]total: 0.9742, grapheme: 0.9619, vowel: 0.9867, consonant: 0.9865


lr:  [0.0017568626894839828]


[EPOCH 00043]train_loss: 1.0379; val_loss: 0.1918; time elapsed: 8.0 min
[valid loss]grapheme: 0.2518, vowel: 0.1526, consonant: 0.1109
[valid recall]total: 0.9593, grapheme: 0.9386, vowel: 0.9818, consonant: 0.9783


lr:  [0.0015804227492397258]


[EPOCH 00044]train_loss: 1.0217; val_loss: 0.1324; time elapsed: 8.1 min
[valid loss]grapheme: 0.1637, vowel: 0.1103, consonant: 0.0918
[valid recall]total: 0.9753, grapheme: 0.9648, vowel: 0.9866, consonant: 0.9850


lr:  [0.001409133708902623]


[EPOCH 00045]train_loss: 1.0258; val_loss: 0.2102; time elapsed: 8.1 min
[valid loss]grapheme: 0.2604, vowel: 0.1910, consonant: 0.1287
[valid recall]total: 0.9703, grapheme: 0.9576, vowel: 0.9857, consonant: 0.9803


lr:  [0.0012439550127246893]


[EPOCH 00046]train_loss: 1.0446; val_loss: 0.1031; time elapsed: 8.0 min
[valid loss]grapheme: 0.1243, vowel: 0.0874, consonant: 0.0763
[valid recall]total: 0.9784, grapheme: 0.9694, vowel: 0.9880, consonant: 0.9866


lr:  [0.001085811878981954]


[EPOCH 00047]train_loss: 1.0429; val_loss: 0.1773; time elapsed: 8.1 min
[valid loss]grapheme: 0.2191, vowel: 0.1509, consonant: 0.1200
[valid recall]total: 0.9698, grapheme: 0.9561, vowel: 0.9843, consonant: 0.9826


lr:  [0.0009355901175333936]


[EPOCH 00048]train_loss: 1.0299; val_loss: 0.1578; time elapsed: 8.2 min
[valid loss]grapheme: 0.1991, vowel: 0.1319, consonant: 0.1012
[valid recall]total: 0.9730, grapheme: 0.9600, vowel: 0.9872, consonant: 0.9849


lr:  [0.0007941311681189094]


[EPOCH 00049]train_loss: 0.9988; val_loss: 0.1505; time elapsed: 8.0 min
[valid loss]grapheme: 0.1812, vowel: 0.1387, consonant: 0.1010
[valid recall]total: 0.9766, grapheme: 0.9679, vowel: 0.9872, consonant: 0.9833


lr:  [0.0006622273871884255]


[EPOCH 00050]train_loss: 1.6882; val_loss: 0.3013; time elapsed: 7.9 min
[valid loss]grapheme: 0.3789, vowel: 0.2656, consonant: 0.1817
[valid recall]total: 0.9776, grapheme: 0.9670, vowel: 0.9887, consonant: 0.9879


lr:  [0.0005406176096620473]


[EPOCH 00051]train_loss: 1.6759; val_loss: 0.2369; time elapsed: 8.0 min
[valid loss]grapheme: 0.2797, vowel: 0.2012, consonant: 0.1869
[valid recall]total: 0.9793, grapheme: 0.9706, vowel: 0.9890, consonant: 0.9872


lr:  [0.0004299830104813085]


[EPOCH 00052]train_loss: 1.6094; val_loss: 0.2129; time elapsed: 8.0 min
[valid loss]grapheme: 0.2543, vowel: 0.1687, consonant: 0.1743
[valid recall]total: 0.9780, grapheme: 0.9682, vowel: 0.9886, consonant: 0.9870


lr:  [0.00033094328913222813]


[EPOCH 00053]train_loss: 1.6551; val_loss: 0.4936; time elapsed: 7.9 min
[valid loss]grapheme: 0.6502, vowel: 0.3868, consonant: 0.2872
[valid recall]total: 0.9792, grapheme: 0.9701, vowel: 0.9893, consonant: 0.9873


lr:  [0.0002440531985119155]


[EPOCH 00054]train_loss: 1.5661; val_loss: 0.2845; time elapsed: 7.9 min
[valid loss]grapheme: 0.3386, vowel: 0.2428, consonant: 0.2180
[valid recall]total: 0.9802, grapheme: 0.9718, vowel: 0.9900, consonant: 0.9873


lr:  [0.0001697994375816017]


[EPOCH 00055]train_loss: 1.6198; val_loss: 0.5078; time elapsed: 8.0 min
[valid loss]grapheme: 0.6655, vowel: 0.4021, consonant: 0.2981
[valid recall]total: 0.9800, grapheme: 0.9721, vowel: 0.9901, consonant: 0.9856


lr:  [0.0001085979252113749]


[EPOCH 00056]train_loss: 1.5648; val_loss: 0.3110; time elapsed: 8.1 min
[valid loss]grapheme: 0.3657, vowel: 0.2719, consonant: 0.2405
[valid recall]total: 0.9805, grapheme: 0.9719, vowel: 0.9896, consonant: 0.9888


lr:  [6.079147048664102e-05]


[EPOCH 00057]train_loss: 1.6319; val_loss: 0.1894; time elapsed: 7.9 min
[valid loss]grapheme: 0.2113, vowel: 0.1666, consonant: 0.1682
[valid recall]total: 0.9810, grapheme: 0.9725, vowel: 0.9899, consonant: 0.9889


lr:  [2.664785252569453e-05]


[EPOCH 00058]train_loss: 1.6272; val_loss: 0.3597; time elapsed: 7.9 min
[valid loss]grapheme: 0.4372, vowel: 0.3170, consonant: 0.2472
[valid recall]total: 0.9806, grapheme: 0.9722, vowel: 0.9900, consonant: 0.9880


lr:  [6.358320563928247e-06]


ValueError: Tried to step 75302 times. The specified number of total steps is 75300

In [None]:
#fold0 - val_loss: 0.0621; CV: 0.9874, LB=0.9801
#fold1 - val_loss: 0.0627; CV: 0.9861, LB=0.9796
#fold2 - val_loss: 0.0674; CV: 0.9875, LB=0.9789
#fold3 - val_loss: 0.0664; CV: 0.9853, LB=0.9787
#fold4 - val_loss: 0.0670; CV: 0.9854, LB=

In [None]:
# fold 0
# [     0      1      2 ... 200835 200837 200839]
# [     4      6     12 ... 200832 200836 200838]

### check dataset/network/etc.

In [None]:
# kf = KFold(n_splits=5, shuffle=True, random_state=SEED).split(X=train_images_arr, y=train_label_arr)

# for fold, (train_idx, valid_idx) in enumerate(kf):
    
#     if fold in [0]:#train 1 fold for testing ideas
#         print('========training fold %d========'%fold)
#         print(train_idx)
#         print(valid_idx)
        
#         #1.1 data
#         train_inputs, valid_inputs = train_images_arr[train_idx], train_images_arr[valid_idx]
#         train_outputs, valid_outputs = train_label_arr[train_idx], train_label_arr[valid_idx]
#         #1.2 Dataset, DataLoader
#         train_dl = prepare_dataset(train_inputs, train_outputs, mode='train', debug=debug)
#         val_dl = prepare_dataset(valid_inputs, valid_outputs, mode='valid', debug=debug)

# ####
# for batch_id, (images, labels) in enumerate(train_dl):
#     inputs = images.to(device=device, dtype=torch.float)
#     truth = labels.to(device=device, dtype=torch.float)
#     if np.random.rand()<1.0:
#         print('cutmix')
#         inputs, truth = cutmix(inputs, truth, alpha=None)
#     else:
#         print('mixup')
#         inputs, truth = mixup(inputs, truth, alpha=0.4)
#     if batch_id==1:
#         break

# #print(inputs.shape, truth.shape)

In [None]:
# show_inputs = inputs.cpu().numpy()[np.random.choice(BATCH_SIZE, 20), :, :, :]
# fig,axes = plt.subplots(5,4, figsize=(10,8))
# for i in range(20):
#     axes[i//4, i%4].imshow(show_inputs[i, 0], cmap='binary')

In [None]:
# import importlib
# import efficientnet
# importlib.reload(efficientnet)

# import torch

# #from senet_v2 import se_resnext50_32x4d
# from efficientnet import EffiNet

# #net = se_resnext50_32x4d(num_classes=186, pretrained='imagenet', debug=True).cuda(device='cuda:2')
# net = EffiNet(model='b3', debug=True).cuda(device='cuda:3')

# inputs = torch.rand((128, 1, 137, 236), dtype=torch.float).cuda(device='cuda:3')
# print(inputs.size())

# logit = net(inputs)
# print(logit.shape)