In [1]:
# Define the execution environment and run mode for the whole notebook.
# ON_KAGGLE switches between the Kaggle and the offline branches of later cells
# (e.g. the number of DataLoader workers and the data-loading branch).
ON_KAGGLE = False
# TRAIN_PREDICT == 'train' enables the sklearn split imports and the training
# batch size; any other value selects the larger inference batch size.
TRAIN_PREDICT = 'train'

In [2]:
import sys
# if not ON_KAGGLE:
#     sys.path.append('../')

import os
import pickle
import time

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from torch.nn import functional as F
print('torch', torch.__version__)

if TRAIN_PREDICT=='train':
    from sklearn.model_selection import train_test_split, KFold

from utils import seed_everything, set_n_get_device


%matplotlib inline

torch 1.1.0


In [3]:
# Environment-specific setup hook; both branches are currently no-ops.
if ON_KAGGLE:
    # placeholder: utility scripts would be loaded here when running on Kaggle
    pass

else:  # offline
    # nothing extra to set up offline
    pass
    

In [4]:
# Global configuration: everything a reader might tune lives in this cell.
debug = False  # passed to prepare_dataset(); True trains on a random subsample
SEED = 42
# Raw image size used to reshape the flattened pixel columns of the feather files.
IMG_HEIGHT = 137
IMG_WIDTH = 236

# Smaller batches for training, larger ones when only predicting.
if TRAIN_PREDICT=='train':
    BATCH_SIZE = 128
else:
    BATCH_SIZE = 256

# Number of DataLoader worker processes.
if ON_KAGGLE:
    NUM_WORKERS = 2
else:
    NUM_WORKERS = 16

# IMPORTANT (original note): data_device_id should point at a free GPU that will hold the model, e.g. "cuda:1".
# NOTE(review): presumably "2,3" selects which physical GPUs are visible and the
# returned `device` is the one tensors are moved to -- confirm in utils.set_n_get_device.
device = set_n_get_device("2,3", data_device_id="cuda:0")
# device_ids handed to nn.DataParallel in train_and_valid(); None disables multi-GPU wrapping.
multi_gpu = [0,1]

# Separate log file for debug runs.
if debug:
    LOG_PATH = '../logging/v2-debug.log'
else:
    LOG_PATH = '../logging/v2.log'

# Directory for saved checkpoints, and the optional checkpoint to resume from
# (only loaded when warm_start is True).
checkpoint_path = '../checkpoint/v2'
warm_start, last_checkpoint_path = False, '../checkpoint/v2/last.pth.tar'

NUM_EPOCHS = 100
early_stopping_round = 15  # number of non-improving epochs tolerated by early stopping
LearningRate = 5e-3  # initial Adam learning rate

seed_everything(SEED)

In [5]:
#CONSTANTS
# Number of classes of each prediction head; these sizes are also used to split
# the network's single logit vector into three per-head chunks.
n_grapheme=168
n_vowel=11
n_consonant=7
#n_combo = 1295

# Width of the network output: the three heads concatenated.
#num_classes = n_grapheme+n_vowel+n_consonant+n_combo
num_classes = n_grapheme+n_vowel+n_consonant

## Pipeline
1. PyTorch Dataset, data augmentation, DataLoader, train-test split / KFold
2. network
3. training process

In [6]:
# Load the training images and labels into memory.
# NOTE(review): the Kaggle branch is an empty placeholder, so with ON_KAGGLE=True
# the names below are never defined and the last line of this cell raises NameError.
if ON_KAGGLE:
    pass

else:#offline
    # The images are stored in four feather shards; the first column of each shard
    # is skipped and the remaining columns are reshaped into one image per row.
    train_df_list = [pd.read_feather('../data/processed/train_image_data_%d.feather'%i) for i in range(4)]
    train_images_arr = np.concatenate([df.iloc[:, 1:].values.reshape(-1, IMG_HEIGHT, IMG_WIDTH) 
                                       for df in train_df_list], axis=0)
    train_label_df = pd.read_csv('../data/raw/train.csv')
    # One label row per image, columns ordered as (grapheme_root, vowel, consonant).
    train_label_arr = train_label_df[['grapheme_root', 'vowel_diacritic', 'consonant_diacritic']].values

# Displayed as the cell output: sanity check that images and labels line up.
train_images_arr.shape, train_label_arr.shape

((200840, 137, 236), (200840, 3))

In [7]:
## 1. encode grapheme characters to index 2. use it as sampler
# Map every distinct grapheme string to an integer id (in order of first appearance).
unique_char = train_label_df['grapheme'].unique()
char2ind = {char: i for i, char in enumerate(unique_char)}
grapheme_ind = [char2ind[char] for char in train_label_df['grapheme']]
# Per-grapheme weight = (number of samples of that grapheme) / 100.
# Series.value_counts() replaces the deprecated top-level pd.value_counts().
cls_w_dict = (pd.Series(grapheme_ind).value_counts() / 100).to_dict()
# Per-sample weight array, aligned with the rows of train_label_df; later fed to
# WeightedRandomSampler.
# NOTE(review): the weight grows with the class frequency, so frequent graphemes
# are drawn even more often; a balancing sampler would use the inverse -- confirm intent.
cls_w = np.array([cls_w_dict[i] for i in grapheme_ind])

##check onehot correct?
#train_label_df.loc[train_label_arr[:,3]==1, ]

In [8]:
##data augmentation --cutmix, mixup
def rand_bbox(size, lam):
    """Draw a random box for cutmix.

    size: shape of the batch; only size[2] and size[3] are read.
    lam: side-length ratio of the box relative to each of those two dimensions.

    Returns (bbx1, bby1, bbx2, bby2): bbx* bound dimension 2 and bby* bound
    dimension 3, both clipped to the valid range.
    """
    # The names W/H are kept from the original code; they are simply the extents
    # of dimensions 2 and 3 of `size`.
    W = size[2]
    H = size[3]
    #cut_rat = np.sqrt(1. - lam)
    cut_rat = lam
    # builtin int instead of np.int: the alias was removed from NumPy
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform, ignore edge area: the box centre is drawn from the middle half
    cx = np.random.randint(W//4, W*3//4)
    cy = np.random.randint(H//4, H*3//4)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

def cutmix(data, target, alpha=1.0):
    """Apply cutmix to a batch: paste a random box from shuffled samples.

    data: batch tensor indexed as data[:, :, d2, d3]; it is modified IN PLACE
        and also returned.
    target: tensor whose columns 0, 1, 2 hold the three labels of each sample.
    alpha: unused (kept for interface compatibility; the Beta draw it belonged
        to is commented out below).

    Returns (data, targets) with
        targets = [t1, t1_shuffled, t2, t2_shuffled, t3, t3_shuffled, lam]
    where lam is the fraction of each image that still comes from the original
    sample, i.e. the weight of the un-shuffled labels in the loss.
    """
    targets1, targets2, targets3 = target[:,0], target[:,1], target[:,2]
    indices = torch.randperm(data.size(0))
    shuffled_targets1 = targets1[indices]
    shuffled_targets2 = targets2[indices]
    shuffled_targets3 = targets3[indices]

    #lam = np.random.beta(alpha, alpha)
    # side ratio in [0, 0.5): the pasted box covers less than a quarter of the image
    lam = np.sqrt(np.random.rand()/4)
    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
    # The right-hand side is indexed with a permutation tensor, which yields a
    # copy, so reading and writing the same region of `data` is safe. Only the
    # box is gathered; no shuffled copy of the whole batch is materialised.
    data[:, :, bbx1:bbx2, bby1:bby2] = data[indices, :, bbx1:bbx2, bby1:bby2]
    #data[:, :, bbx1:bbx2, bby1:bby2] += data[indices, :, bbx1:bbx2, bby1:bby2]
    #data = torch.clamp(data, 0, 1)
    # adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2]))

    targets = [targets1, shuffled_targets1, targets2, shuffled_targets2, 
               targets3, shuffled_targets3, lam]
    return data, targets

def mixup(data, target, alpha=0.4):
    """Apply mixup to a batch: blend every sample with a randomly chosen partner.

    data: batch tensor, permuted along dimension 0 to pick the partners.
    target: tensor whose columns 0, 1, 2 hold the three labels of each sample.
    alpha: both shape parameters of the Beta distribution the blend weight is
        drawn from.

    Returns (mixed, targets) with
        targets = [t1, t1_partner, t2, t2_partner, t3, t3_partner, lam]
    where lam is the weight given to the original sample.
    """
    perm = torch.randperm(data.size(0))
    partner = data[perm]

    # interleave each label column with its permuted counterpart
    targets = []
    for col in range(3):
        own = target[:, col]
        targets.extend((own, own[perm]))

    lam = np.random.beta(alpha, alpha)
    targets.append(lam)

    mixed = data * lam + partner * (1 - lam)
    return mixed, targets

In [9]:
# kf = KFold(n_splits=5, shuffle=True, random_state=SEED).split(X=train_images_arr, y=train_label_arr)

# for fold, (train_idx, valid_idx) in enumerate(kf):
    
#     if fold in [0]:#train 1 fold for testing ideas
#         print('========training fold %d========'%fold)
#         print(train_idx)
#         print(valid_idx)
        
#         #1.1 data
#         train_inputs, valid_inputs = train_images_arr[train_idx], train_images_arr[valid_idx]
#         train_outputs, valid_outputs = train_label_arr[train_idx], train_label_arr[valid_idx]
#         #1.2 Dataset, DataLoader
#         train_dl = prepare_dataset(train_inputs, train_outputs, mode='train', debug=debug)
#         val_dl = prepare_dataset(valid_inputs, valid_outputs, mode='valid', debug=debug)

# ####
# for batch_id, (images, labels) in enumerate(train_dl):
#     inputs = images.to(device=device, dtype=torch.float)
#     truth = labels.to(device=device, dtype=torch.float)
#     if np.random.rand()<1.0:
#         inputs, truth = cutmix(inputs, truth, alpha=None)
#     else:
#         inputs, truth = mixup(inputs, truth, alpha=0.4)
#     if batch_id==1:
#         break

# #print(inputs.shape, truth.shape)

# show_inputs = inputs.cpu().numpy()[np.random.choice(BATCH_SIZE, 20, replace=False), :, :, :]
# fig,axes = plt.subplots(5,4, figsize=(10,8))
# for i in range(20):
#     axes[i//4, i%4].imshow(show_inputs[i, 0], cmap='binary')

In [10]:
##crop to 128x128
def bbox(img):
    """Return (rmin, rmax, cmin, cmax), the inclusive bounds of the truthy region.

    img: 2-D array; a row/column counts as occupied when any of its entries is
    truthy. Raises IndexError when nothing in `img` is truthy.
    """
    occupied_rows = np.nonzero(np.any(img, axis=1))[0]
    occupied_cols = np.nonzero(np.any(img, axis=0))[0]
    first_row, last_row = occupied_rows[[0, -1]]
    first_col, last_col = occupied_cols[[0, -1]]
    return first_row, last_row, first_col, last_col

def crop_resize(img0, size=128, pad=16):
    """Crop the image around its bright content, pad it to a square and resize.

    img0: 2-D image array; bounds are checked against IMG_WIDTH / IMG_HEIGHT.
        It is NOT modified (the crop is copied before de-noising).
    size: side length of the returned square image.
    pad: extra pixels added to the square side before resizing.

    Returns the result of cv2.resize(..., (size, size)).
    Raises IndexError (from bbox) when no pixel inside the 5-pixel border
    exceeds the threshold of 80.
    """
    #crop a box around pixels large than the threshold 
    #some images contain line at the sides
    # The bounds are found on the image with a 5-pixel border removed, so they
    # are expressed in that trimmed image's coordinates.
    ymin,ymax,xmin,xmax = bbox(img0[5:-5,5:-5] > 80)
    #cropping may cut too much, so we need to add it back
    xmin = xmin - 13 if (xmin > 13) else 0
    ymin = ymin - 10 if (ymin > 10) else 0
    xmax = xmax + 13 if (xmax < IMG_WIDTH - 13) else IMG_WIDTH
    ymax = ymax + 10 if (ymax < IMG_HEIGHT - 10) else IMG_HEIGHT
    # copy: the slice is a view, and zeroing it would silently edit the caller's array
    img = img0[ymin:ymax,xmin:xmax].copy()
    #remove lo intensity pixels as noise
    img[img < 28] = 0
    lx, ly = xmax-xmin,ymax-ymin
    l = max(lx,ly) + pad
    #make sure that the aspect ratio is kept in rescaling
    img = np.pad(img, [((l-ly)//2,), ((l-lx)//2,)], mode='constant')
    return cv2.resize(img,(size,size))

In [11]:
#img0 = train_images_arr[0]
#img0 = 255 - img0
#img0 = (img0*(255.0/img0.max())).astype(np.uint8)
#plt.imshow(crop_resize(rotate(img0, angle=20, reshape=False)))

In [12]:
# #### make up a weighted/balanced data sampler --for mixup/cutmix ####
# #weights for 168 classes--graphene_root
# import torch.utils.data
# _weights = [6.8,6.9,3,3.1,3,5.7,3.2,6.5,6.4,2.3,6.6,6.6,6.8,0.2,1.3,0.9,1.1,1.3,0.6,3.6,3,1.1,0.3,0.2,3,0.9,5.8,3.3,1.3,0.4,2.3,1.3,0.9,7.4,3.6,2.1,1,3.5,0.3,1.6,1.3,3.3,0.5,0.3,0.9,6.9,1.7,2.2,0.7,3.1,1.4,3.1,1.1,0.3,1.7,0.6,0.4,1.6,0.8,0.4,2.3,1.7,1.2,6.7,0.2,0.7,1.3,2.1,1.6,1.3,1,0.3,0.2,7.7,0.7,0.9,0.5,1,3.4,0.3,2.2,0.3,3.4,0.7,2.2,0.7,0.5,6,1.3,0.4,1.6,0.6,0.9,1.6,1,1.4,0.2,2.1,1.6,2.2,2.2,0.9,7.1,0.3,6.2,6.6,1.3,0.2,6.3,1.1,2.9,1.3,1.1,0.2,6.7,0.2,2.3,0.7,0.9,0.7,0.8,2.2,0.4,0.5,0.5,1.2,6.3,1.1,1.1,1,6.9,2.3,1,0.2,1.6,1.6,1,1.8,1.1,0.4,1.1,0.6,0.9,1.6,1.6,3.2,3.3,0.2,0.6,0.4,0.4,0.8,1.6,0.6,1.4,1.1,1.3,3.1,7,0.3,2.1,3.2,2.2,6.1,6.1,0.9,3.3,0.6]
# weights_dict = dict(zip(range(len(_weights)), _weights))

In [13]:
from torch.utils.data import DataLoader, Dataset
from utils import set_logger, save_checkpoint, load_checkpoint
import logging
#import gc
from scipy.ndimage.interpolation import rotate
import cv2


def prepare_dataset(img_arr, label_arr, cls_w=None, mode='train', debug=False):
    """Wrap image/label arrays in a DatasetV1 and return a DataLoader for it.

    img_arr: images, indexed along axis 0.
    label_arr: labels aligned with img_arr along axis 0.
    cls_w: per-sample sampling weights aligned with img_arr; required for
        mode='train', ignored otherwise.
    mode: 'train' (augmentation + weighted sampling with replacement, incomplete
        last batch dropped), 'valid' (no augmentation, incomplete last batch
        dropped) or 'test' (no augmentation, every sample kept).
    debug: when True, work on a random fifth of the data.

    Raises ValueError for an unknown mode or when mode='train' without cls_w.
    """
    if mode not in ('train', 'valid', 'test'):
        raise ValueError("mode must be 'train', 'valid' or 'test', got %r" % (mode,))

    if debug:#for debug, sample 1/5 data
        n = img_arr.shape[0]
        sid = np.random.choice(n, size=n//5, replace=False)
        img_arr = img_arr[sid]
        label_arr = label_arr[sid]
        # keep the weights aligned with the subsample; otherwise the sampler
        # would emit indices beyond the end of the reduced dataset
        if cls_w is not None:
            cls_w = np.asarray(cls_w)[sid]

    if mode=='train':
        if cls_w is None:
            raise ValueError("cls_w (per-sample weights) is required when mode='train'")
        ds = DatasetV1(img_arr, label_arr, mode='train', augmentation=True)
        sampler = torch.utils.data.WeightedRandomSampler(cls_w, num_samples=len(ds), replacement=True)
        return DataLoader(ds,
                          batch_size=BATCH_SIZE,
                          shuffle=False,#order comes from the sampler
                          sampler=sampler,
                          num_workers=NUM_WORKERS,
                          drop_last=True
                         )

    # 'valid' and 'test' differ only in the dataset mode tag and in drop_last.
    # NOTE(review): drop_last=True for 'valid' leaves the final incomplete batch
    # out of the validation loss/recall -- confirm this is intended.
    is_test = mode=='test'
    ds = DatasetV1(img_arr, label_arr, mode='test' if is_test else 'train', augmentation=False)
    return DataLoader(ds,
                      batch_size=BATCH_SIZE,
                      shuffle=False,
                      num_workers=NUM_WORKERS,
                      drop_last=not is_test
                     )

class DatasetV1(Dataset):
    """Plain in-memory dataset: inverts, rescales, optionally rotates and crops each image."""

    def __init__(self, inputs, outputs, mode='train', augmentation=False):
        """
        inputs: images, (N, H, W)
        outputs: label, (N, 3)
        mode: stored on the instance; not read by any method of this class
        augmentation: when True, __getitem__ applies do_augmentation
        """
        self.inputs, self.outputs = inputs, outputs
        self.mode, self.augmentation = mode, augmentation

    def __len__(self):
        """Number of samples (length of the first axis of `inputs`)."""
        return self.inputs.shape[0]

    def __getitem__(self, idx):
        """Return (image, label); the image gains a leading axis of size 1 and lies in [0, 1]."""
        image, label = self.inputs[idx], self.outputs[idx]

        # invert the intensities, then stretch so the brightest pixel becomes 255
        image = 255 - image
        stretch = 255.0/image.max()
        image = (image*stretch).astype(np.uint8)

        if self.augmentation:
            image = self.do_augmentation(image)

        # crop around the content and resize (see crop_resize)
        image = crop_resize(image)

        image = np.clip(image/255.0, 0, 1)
        return image[None, ...], label

    def do_augmentation(self, inputs):
        """Rotate the image by a random whole number of degrees in [-20, 19], keeping its shape.

        NOTE(review): draws from the global NumPy RNG; when this runs inside
        DataLoader worker processes, confirm the workers are seeded differently.
        """
        degrees = np.random.randint(0, 40) - 20
        return rotate(inputs, degrees, reshape=False)

def train_and_valid(net, train_dl, val_dl):
    """Train `net` on one fold, validating, checkpointing and logging every epoch.

    net: model whose single output is split into the three heads with sizes
        [n_grapheme, n_vowel, n_consonant] along dim 1.
    train_dl, val_dl: loaders yielding (images, labels) batches.

    Reads the notebook-level settings LOG_PATH, LearningRate, warm_start,
    last_checkpoint_path, multi_gpu, NUM_EPOCHS, early_stopping_round, device
    and checkpoint_path. Returns None.

    [settings]...

    [Epoch i]
        [Trainset]
            [Batch j]
        [Validset]
            [Batch k]
        [Logging/Checkpoint]
    """
    set_logger(LOG_PATH)
    logging.info('\n\n')
    #1. optim
    train_params = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = torch.optim.Adam(train_params, lr=LearningRate)
    # Halve the learning rate after 5 epochs without improvement of the validation
    # loss. Every other scheduler option is left at its library default (the
    # original spelled those defaults out, including the deprecated `verbose`).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                          factor=0.5, patience=5)
    #1.1 warm-start
    if warm_start:
        logging.info('warm_start: '+last_checkpoint_path)
        net, _ = load_checkpoint(last_checkpoint_path, net)

    #2. using multi GPU
    if multi_gpu is not None:
        net = nn.DataParallel(net, device_ids=multi_gpu)

    #3. train
    acc_step = 1#gradient accumulation steps (1 = step on every batch)
    mix_start_epoch = 0#epochs before this one train without cutmix/mixup (0 = always mix)
    min_epochs = 40#sometimes too early stop, force to at least train N epochs
    diff = 0#consecutive epochs without improvement
    best_val_metric = np.inf
    optimizer.zero_grad()

    for i_epoch in range(NUM_EPOCHS):
        t0 = time.time()
        ## trainset -------------------------------------------------------------
        net.train()
        loss_logger = LossLogger()
        for batch_id, (images, labels) in enumerate(train_dl):
            inputs = images.to(device=device, dtype=torch.float)
            truth = labels.to(device=device, dtype=torch.float)

            #do cutmix/mixup
            if i_epoch < mix_start_epoch:
                mode = 0#plain cross entropy on the raw labels
            else:
                mode = 1#mixed loss; `truth` becomes the 7-element list from cutmix/mixup
                if np.random.rand()<0.5:
                    inputs, truth = cutmix(inputs, truth, alpha=None)
                else:
                    inputs, truth = mixup(inputs, truth, alpha=0.4)

            logit = net(inputs)
            logits = torch.split(logit, [n_grapheme, n_vowel, n_consonant], dim=1)
            train_loss = loss_logger.update(logits, truth, mode=mode)
            #scale so that accumulated gradients average over acc_step batches
            if acc_step>1:
                train_loss = train_loss / acc_step
            train_loss.backward()
            if (batch_id+1)%acc_step==0:
                optimizer.step()
                optimizer.zero_grad()
        train_loss_total, _, _, _ = loss_logger.aggregate()

        ## validset -------------------------------------------------------------
        net.eval()
        loss_logger = LossLogger()
        metric_logger = MetricLogger()
        with torch.no_grad():
            for batch_id, (images, labels) in enumerate(val_dl):
                inputs = images.to(device=device, dtype=torch.float)
                truth = labels.to(device=device, dtype=torch.float)
                logit = net(inputs)
                logits = torch.split(logit, [n_grapheme, n_vowel, n_consonant], dim=1)
                _ = loss_logger.update(logits, truth, mode=0)
                metric_logger.update(logits, truth)
        rec, rec_grapheme, rec_vowel, rec_consonant = metric_logger.aggregate()
        loss_total, loss_grapheme, loss_vowel, loss_consonant = loss_logger.aggregate()

        ## callbacks -------------------------------------------------------------
        val_metric = loss_total#rec
        scheduler.step(val_metric)

        # Best-model tracking and early stopping only start after min_epochs, so
        # no checkpoint is flagged as best before then.
        if i_epoch>=min_epochs:
            if val_metric < best_val_metric:
                best_val_metric = val_metric
                is_best = True
                diff = 0
            else:
                is_best = False
                diff += 1
                if diff > early_stopping_round:
                    #the metric is a loss (lower is better), hence "improve"
                    logging.info('Early Stopping: val_metric did not improve for more than %d rounds'%early_stopping_round)
                    break
        else:
            is_best = False

        #save checkpoint (unwrap DataParallel so the weights load into a bare model)
        checkpoint_dict = \
        {
            'epoch': i_epoch,
            'state_dict': net.module.state_dict() if multi_gpu is not None else net.state_dict(),
            'optim_dict' : optimizer.state_dict(),
            'metrics': {'train_loss': train_loss_total, 'val_loss': loss_total, 
                        'val_metric': val_metric}
        }
        save_checkpoint(checkpoint_dict, is_best=is_best, checkpoint=checkpoint_path)

        #logging loss/metric
        logging.info('[EPOCH %05d]train_loss: %0.4f; val_loss: %0.4f; time elapsed: %0.1f min'%(i_epoch, 
                    train_loss_total, loss_total, (time.time()-t0)/60))
        logging.info('[valid loss]grapheme: %0.4f, vowel: %0.4f, consonant: %0.4f'%(loss_grapheme, 
                                                                        loss_vowel, loss_consonant))
        logging.info('[valid recall]total: %0.4f, grapheme: %0.4f, vowel: %0.4f, consonant: %0.4f'%(rec, 
                    rec_grapheme, rec_vowel, rec_consonant))
        logging.info('='*80)

def predict(test_dl):
    """Placeholder for test-time inference: not implemented, always returns None."""
    pass

In [14]:
from sklearn.metrics import recall_score

class LossLogger(object):
    """loss for an epoch

    Collects the per-batch loss of each head so that epoch averages can be
    reported.

    [Epoch i]:
        loss_logger = LossLogger()

        [Batch j]:
            loss = loss_logger.update(logits, truth)
            loss.backward()

        loss, loss_grapheme, loss_vowel, loss_consonant = loss_logger.aggregate()
    """
    def __init__(self):
        #one float per processed batch
        self.loss_grapheme = []
        self.loss_vowel = []
        self.loss_consonant = []

    def update(self, logits, truth, mode=0):
        """Compute the batch loss, record its per-head parts and return it.

        logits: logit splitted to [logit_grapheme, logit_vowel, logit_consonant]
        truth:
            mode=0: tensor of shape (N, 3), one label column per head
            mode=1: the 7-element list produced by cutmix/mixup,
                [t1, t1_shuffled, t2, t2_shuffled, t3, t3_shuffled, lam]
        mode: 0 = plain cross entropy, 1 = cutmix/mixup loss
            lam * CE(original labels) + (1 - lam) * CE(shuffled labels)

        Returns the weighted total 0.5*grapheme + 0.25*vowel + 0.25*consonant
        as a tensor (suitable for .backward()).
        Raises ValueError for any other mode.
        """
        if mode==0:
            head_losses = [F.cross_entropy(logits[k], truth[:,k].long())
                           for k in range(3)]
        elif mode==1:#cutmix/mixup loss
            lam = truth[6]
            head_losses = [
                lam * F.cross_entropy(logits[k], truth[2*k].long()) +
                (1 - lam) * F.cross_entropy(logits[k], truth[2*k+1].long())
                for k in range(3)
            ]
        else:
            #previously fell through and returned None, failing later at .backward()
            raise ValueError('mode must be 0 (plain) or 1 (cutmix/mixup), got %r' % (mode,))

        loss_grapheme, loss_vowel, loss_consonant = head_losses
        loss = 0.5*loss_grapheme + 0.25*loss_vowel + 0.25*loss_consonant
        #
        self.loss_grapheme.append(loss_grapheme.item())
        self.loss_vowel.append(loss_vowel.item())
        self.loss_consonant.append(loss_consonant.item())
        return loss

    def aggregate(self):
        """
        for print logging

        Returns (loss_total, loss_grapheme, loss_vowel, loss_consonant), each the
        mean over the recorded batches; loss_total uses the same 0.5/0.25/0.25
        head weights as update().
        """
        loss_grapheme = np.mean(self.loss_grapheme)
        loss_vowel = np.mean(self.loss_vowel)
        loss_consonant = np.mean(self.loss_consonant)
        loss_total = np.mean(
            0.5*np.array(self.loss_grapheme) + \
            0.25*np.array(self.loss_vowel) + \
            0.25*np.array(self.loss_consonant)
        )
        return loss_total, loss_grapheme, loss_vowel, loss_consonant

class MetricLogger(object):
    """recall for an epoch

    Accumulates predictions and labels of every batch on the CPU, so it works
    whatever device the logits live on and needs no notebook-level `device`.

    [Epoch i]:
        metric_logger = MetricLogger()

        [Batch j]:
            metric_logger.update(logits, truth)

        rec, rec_grapheme, rec_vowel, rec_consonant = metric_logger.aggregate()
    """
    def __init__(self):
        #1-D long tensors on the CPU, extended by every update() call
        self.pred_grapheme = torch.tensor([], dtype=torch.long)
        self.pred_vowel = torch.tensor([], dtype=torch.long)
        self.pred_consonant = torch.tensor([], dtype=torch.long)

        self.truth_grapheme = torch.tensor([], dtype=torch.long)
        self.truth_vowel = torch.tensor([], dtype=torch.long)
        self.truth_consonant = torch.tensor([], dtype=torch.long)

    @staticmethod
    def _extend(store, values):
        """Return `store` with `values` appended, as a CPU long tensor."""
        return torch.cat([store, values.detach().long().cpu()])

    def update(self, logits, truth):
        """Record the argmax prediction and the label of every sample in the batch.

        logits: [logit_grapheme, logit_vowel, logit_consonant], each (N, n_classes)
        truth: tensor of shape (N, 3), one label column per head
        """
        self.pred_grapheme = self._extend(self.pred_grapheme, torch.argmax(logits[0], dim=1))
        self.truth_grapheme = self._extend(self.truth_grapheme, truth[:, 0])
        #
        self.pred_vowel = self._extend(self.pred_vowel, torch.argmax(logits[1], dim=1))
        self.truth_vowel = self._extend(self.truth_vowel, truth[:, 1])
        #
        self.pred_consonant = self._extend(self.pred_consonant, torch.argmax(logits[2], dim=1))
        self.truth_consonant = self._extend(self.truth_consonant, truth[:, 2])

    def aggregate(self):
        """Return (rec, rec_grapheme, rec_vowel, rec_consonant).

        Each per-head value is the macro-averaged recall over everything
        recorded so far; `rec` is their weighted average with weights 2:1:1.
        """
        rec_grapheme = recall_score(self.truth_grapheme.numpy(),
                                    self.pred_grapheme.numpy(),
                                    average='macro')
        rec_vowel = recall_score(self.truth_vowel.numpy(),
                                 self.pred_vowel.numpy(),
                                 average='macro')
        rec_consonant = recall_score(self.truth_consonant.numpy(),
                                     self.pred_consonant.numpy(),
                                     average='macro')
        #rec = (2*rec_grapheme + 1*rec_vowel + 1*rec_consonant) / 4
        rec = np.average([rec_grapheme, rec_vowel, rec_consonant], weights=[2,1,1])
        return rec, rec_grapheme, rec_vowel, rec_consonant

# #debug MetricLogger
# #Epoch 0
# loss_logger = LossLogger()
# metric_logger = MetricLogger()
# for batch_id, (images, labels) in enumerate(val_dl):
#     inputs = images.to(device=device, dtype=torch.float)
#     truth = labels.to(device=device, dtype=torch.float)
#     if batch_id==10:
#         break
#     logit = net(inputs)
#     logits = torch.split(logit, [n_grapheme, n_vowel, n_consonant], dim=1)
#     loss = loss_logger.update(logits, truth)
#     metric_logger.update(logits, truth)
# rec, rec_grapheme, rec_vowel, rec_consonant = metric_logger.aggregate()
# loss_total, loss_grapheme, loss_vowel, loss_consonant = loss_logger.aggregate()
# print(rec, rec_grapheme, rec_vowel, rec_consonant)
# print(loss_total, loss_grapheme, loss_vowel, loss_consonant)

In [15]:
# from senet import se_resnext50_32x4d
from senet_v2 import se_resnext50_32x4d

In [16]:
#### training for 5 folds here ####
# The data is split into 5 shuffled folds, but only the folds listed in the
# `fold in [...]` check below are actually trained (currently fold 0 only).
kf = KFold(n_splits=5, shuffle=True, random_state=SEED).split(X=train_images_arr, y=train_label_arr)

for fold, (train_idx, valid_idx) in enumerate(kf):
    
    if fold in [0]:#train 1 fold for testing ideas
        print('========training fold %d========'%fold)
        print(train_idx)
        print(valid_idx)
        
        #1.1 data: slice images, labels and per-sample sampling weights for this fold
        train_inputs, valid_inputs = train_images_arr[train_idx], train_images_arr[valid_idx]
        train_outputs, valid_outputs = train_label_arr[train_idx], train_label_arr[valid_idx]
        train_cls_w = cls_w[train_idx]
        #1.2 Dataset, DataLoader (only the training loader uses the sampling weights)
        train_dl = prepare_dataset(train_inputs, train_outputs, cls_w=train_cls_w, mode='train', debug=debug)
        val_dl = prepare_dataset(valid_inputs, valid_outputs, mode='valid', debug=debug)
        
        #2. model: one output per class of all three heads, started from 'imagenet' weights
        #net = se_resnext50_32x4d(num_classes=num_classes, pretrained=None).cuda(device=device)
        net = se_resnext50_32x4d(num_classes=num_classes, pretrained='imagenet').cuda(device=device)

        #3. train session: trains, validates, checkpoints and logs every epoch
        train_and_valid(net, train_dl, val_dl)


[     0      1      2 ... 200835 200837 200839]
[     4      6     12 ... 200832 200836 200838]
model state_dict loaded.





[EPOCH 00000]train_loss: 2.1490; val_loss: 0.5532; time elapsed: 5.9 min
[valid loss]grapheme: 0.8221, vowel: 0.3107, consonant: 0.2579
[valid recall]total: 0.8203, grapheme: 0.7796, vowel: 0.9055, consonant: 0.8167
[EPOCH 00001]train_loss: 1.4241; val_loss: 0.3270; time elapsed: 5.8 min
[valid loss]grapheme: 0.4488, vowel: 0.2250, consonant: 0.1854
[valid recall]total: 0.9028, grapheme: 0.8672, vowel: 0.9526, consonant: 0.9243
[EPOCH 00002]train_loss: 1.2960; val_loss: 0.2591; time elapsed: 5.8 min
[valid loss]grapheme: 0.3548, vowel: 0.1828, consonant: 0.1438
[valid recall]total: 0.9270, grapheme: 0.8981, vowel: 0.9613, consonant: 0.9507


KeyboardInterrupt: 

In [None]:
#best CV: .9757 100 epochs, .9754 60 epochs

In [None]:
# fold 0
# [     0      1      2 ... 200835 200837 200839]
# [     4      6     12 ... 200832 200836 200838]

### check dataset/network/etc.

In [None]:
# kf = KFold(n_splits=5, shuffle=True, random_state=SEED).split(X=train_images_arr, y=train_label_arr)

# for fold, (train_idx, valid_idx) in enumerate(kf):
    
#     if fold in [0]:#train 1 fold for testing ideas
#         print('========training fold %d========'%fold)
#         print(train_idx)
#         print(valid_idx)
        
#         #1.1 data
#         train_inputs, valid_inputs = train_images_arr[train_idx], train_images_arr[valid_idx]
#         train_outputs, valid_outputs = train_label_arr[train_idx], train_label_arr[valid_idx]
#         #1.2 Dataset, DataLoader
#         train_dl = prepare_dataset(train_inputs, train_outputs, mode='train', debug=debug)
#         val_dl = prepare_dataset(valid_inputs, valid_outputs, mode='valid', debug=debug)

# ####
# for batch_id, (images, labels) in enumerate(train_dl):
#     inputs = images.to(device=device, dtype=torch.float)
#     truth = labels.to(device=device, dtype=torch.float)
#     if np.random.rand()<1.0:
#         inputs, truth = cutmix(inputs, truth, alpha=None)
#     else:
#         inputs, truth = mixup(inputs, truth, alpha=0.4)
#     if batch_id==1:
#         break

# #print(inputs.shape, truth.shape)

In [None]:
# show_inputs = inputs.cpu().numpy()[np.random.choice(BATCH_SIZE, 20), :, :, :]
# fig,axes = plt.subplots(5,4, figsize=(10,8))
# for i in range(20):
#     axes[i//4, i%4].imshow(show_inputs[i, 0], cmap='binary')

In [None]:
# import importlib
# import senet_v2
# importlib.reload(senet_v2)

# import torch

# from senet_v2 import se_resnext50_32x4d

# net = se_resnext50_32x4d(num_classes=186, pretrained='imagenet', debug=True).cuda(device='cuda:2')

# inputs = torch.rand((128, 1, 137, 236), dtype=torch.float).cuda(device='cuda:2')
# print(inputs.size())

# logit = net(inputs)
# print(logit.shape)