In [1]:
# Work from src/ so the imports of dataprocess / models / args and the '../'-relative
# paths used below resolve (presumably these are project-local modules living in src/).
%cd ../src

C:\Users\nozoe-tatsuya\dev\ai-ocr-ensemble\src


In [2]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

from pathlib import Path

import dataprocess as dp
import models
from args import Args

In [3]:
# Overrides for the parameters that differ from their defaults.
# Anything not listed here keeps the default declared in the Args definition.

arg_params = dict(
    # training schedule
    epochs=100,
    batch_size=512,
    nb_classes=17,
    eval_interval=5,
    # params of gMLP
    patch_size=4,
    num_blocks=30,
    d_model=128,
    mlp_ratio=6,
    # runtime, logging and dataset locations
    device='cuda',
    log_dir='../logs',  # root directory for logs
    project='model-ensemble',  # directory for the log files
    model='stacking',
    version='modify-CvT-210615',
    json_path='../data/dataset_info_1.json',
)

In [4]:
# Build the run configuration from the overrides and resolve the torch device it names.
args = Args(**arg_params)
device = torch.device(args.device)
# Echo both so the effective settings (defaults included) are stored with the notebook output.
args, device

(Args(batch_size=512, nb_classes=17, epochs=100, ckpt_interval=20, eval_interval=5, device='cuda', log_dir='../logs', model_dir='../models', project='model-ensemble', model='stacking', version='modify-CvT-210615', resnet_layers=18, use_se_module=False, activation='ReLU', patch_size=4, num_blocks=30, d_model=128, mlp_ratio=6, json_path='../data/dataset_info_1.json', lr=0.001, weight_decay=0.0001, optimizer='adaberief', final_lr=0.1, scheduler=None, warmup=True, warmup_epoch=5),
 device(type='cuda'))

In [5]:
# Dataset configuration file.
# Check the path the run is actually configured with (args.json_path) rather than a
# separately hard-coded copy, so the check cannot drift from the configuration.
# NOTE(review): presumably this is the file dp.build_dataloader reads — confirm.
json_path = Path(args.json_path)
assert json_path.is_file(), f'dataset info file not found: {json_path}'

### DataLoader

In [6]:
# Build the training and validation loaders from the same run configuration
# (train first, then val, exactly as two separate calls would).
train_loader, val_loader = (
    dp.build_dataloader(split, args) for split in ('train', 'val')
)

### Load pretrained models

In [7]:
# Instantiate the three base networks from the shared configuration.
# Their saved weights are loaded in a later cell.
resnet = models.build_resnet(args)
gmlp = models.build_gmlp(args)
cvt = models.build_cvt(args)

In [8]:
# Locate the saved weights of each base model under the configured model
# directory, then stop early if any of the files is missing.
model_dir = Path(args.model_dir)
resnet_weights_path = model_dir.joinpath('resnet_epoch200_210525.pth')
gmlp_weights_path = model_dir.joinpath('gmlp_epoch100_210607.pth')
cvt_weights_path = model_dir.joinpath('cvt_epoch200_210615.pth')
assert resnet_weights_path.is_file()
assert gmlp_weights_path.is_file()
assert cvt_weights_path.is_file()

In [9]:
# Load the saved weights into each base network.
# NOTE(review): the return values are not used, so this presumably updates the
# modules in place — confirm in models.utils.load_model_weights.
models.utils.load_model_weights(resnet, resnet_weights_path)
models.utils.load_model_weights(gmlp, gmlp_weights_path)
models.utils.load_model_weights(cvt, cvt_weights_path)

### Prepare for training

In [10]:
# Combine the three pretrained networks into the stacking system that is trained below.
# NOTE(review): the model summary printed by trainer.fit lists the base models as
# non-trainable and a single trainable Linear layer `fc` (884 params), so presumably
# only the stacking head is fitted — confirm in models.build_stacking_system.
system = models.build_stacking_system([resnet, gmlp, cvt], args)

In [11]:
# TensorBoard logger for this run, built from the run configuration.
# NOTE(review): presumably the output location is derived from args.log_dir /
# args.project / args.version — confirm in models.get_tensorboard_logger.
logger = models.get_tensorboard_logger(args)

In [12]:
# Keep only the single best checkpoint, ranked by the monitored validation metric.
# `mode` has to be 'min' or 'max'; the previous value 'err' is not a valid mode
# (depending on the Lightning version it is either replaced by a guess based on the
# metric name, with a warning, or rejected with an error).
# NOTE(review): 'min' assumes valid_err_epoch is an error rate (lower is better).
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='valid_err_epoch',
    # Same directory the pretrained weights are read from, instead of a second hard-coded copy.
    dirpath=args.model_dir,
    # The {epoch}/{valid_err_epoch} placeholders are filled in by Lightning, not here.
    filename=f'{args.model}-{args.version}' + '-{epoch:03d}-{valid_err_epoch:.2f}',
    save_top_k=1,
    mode='min',
)



In [13]:
# Single-GPU trainer: runs for at most args.epochs epochs, validates every
# args.eval_interval epochs, and reports to the logger and checkpoint callback
# created in the preceding cells.
trainer = pl.Trainer(
    gpus=1,
    min_epochs=1,
    max_epochs=args.epochs,
    check_val_every_n_epoch=args.eval_interval,
    logger=logger,
    callbacks=[checkpoint_callback],
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


In [22]:
# Fit the stacking system on train_loader; val_loader is used for the periodic
# validation runs configured on the trainer.
trainer.fit(system, train_loader, val_dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type             | Params
-------------------------------------------------------
0 | criterion         | CrossEntropyLoss | 0     
1 | train_metrics     | Accuracy         | 0     
2 | valid_metrics     | Accuracy         | 0     
3 | pretrained_models | ModuleList       | 35.4 M
4 | fc                | Linear           | 884   
-------------------------------------------------------
884       Trainable params
35.4 M    Non-trainable params
35.4 M    Total params
141.722   Total estimated model params size (MB)


[31mPlease check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.
[31mModifications to default arguments:
[31m                           eps  weight_decouple    rectify
-----------------------  -----  -----------------  ---------
adabelief-pytorch=0.0.5  1e-08  False              False
>=0.1.0 (Current 0.2.0)  1e-16  True               True
[34mSGD better than Adam (e.g. CNN for Image Classification)    Adam better than SGD (e.g. Transformer, GAN)
----------------------------------------------------------  ----------------------------------------------
Recommended eps = 1e-8                                      Recommended eps = 1e-16
[34mFor a complete table of recommended hyperparameters, see
[34mhttps://github.com/juntang-zhuang/Adabelief-Optimizer
[32mYou can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.
[0m
Weight decoupling enabled in AdaBelief
Rectification enabled in AdaBelief


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

1

## Inference

In [15]:
# Run the fitted system over the validation loader and collect its predictions.
# The returned mapping is read further down via the keys 'pred_labels',
# 'true_labels', 'confidence' and 'correct'.
results = models.infer(system, val_loader, device=device)

In [16]:
# Turn the raw inference results into a result table (it has a 'correct' column,
# used below) and a confusion matrix; the second return value is requested via
# return_conf_matrix=True.
result_df, conf_matrix_df = models.build_inference(results, args, datatype='val', return_conf_matrix=True)

  conf_matrix_df.loc[:-1, 'recall'] = recall


In [17]:
# Display the confusion matrix: true labels in rows, predicted labels in columns,
# plus a per-row recall column.
conf_matrix_df

Unnamed: 0_level_0,Unnamed: 1_level_0,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,recall
Unnamed: 0_level_1,Unnamed: 1_level_1,0,1,2,3,4,5,6,7,8,9,*,×,-,・,/,字,―,Unnamed: 19_level_1
正解ラベル,0,396.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.99
正解ラベル,1,0.0,398.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.995
正解ラベル,2,1.0,0.0,397.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.9925
正解ラベル,3,0.0,0.0,0.0,399.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.9975
正解ラベル,4,0.0,1.0,0.0,0.0,396.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.992481
正解ラベル,5,0.0,0.0,1.0,0.0,1.0,396.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.99
正解ラベル,6,0.0,0.0,0.0,0.0,0.0,0.0,401.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
正解ラベル,7,0.0,0.0,0.0,1.0,0.0,0.0,0.0,399.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.9975
正解ラベル,8,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,399.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.9975
正解ラベル,9,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,398.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.995


In [36]:
# Export the confusion matrix to Excel.
# The results directory is created first so the export also works on a fresh
# checkout where ../results does not exist yet.
# NOTE(review): the file name contains 'modigy' (likely meant 'modify'); it is kept
# unchanged here so existing references to the file keep working.
conf_matrix_path = Path('../results/conf_matrix_stacking_modigy_cvt_0615.xlsx')
conf_matrix_path.parent.mkdir(parents=True, exist_ok=True)
conf_matrix_df.to_excel(conf_matrix_path)

In [23]:
# Mean of the 'correct' column, i.e. the overall accuracy on the evaluated split
# (assuming 'correct' is a per-sample hit flag — TODO confirm).
result_df['correct'].mean()

0.9952102978596019

In [25]:
# Pull the individual result arrays out of the inference output in one step;
# the key order matches the order of the target names.
pred_labels, true_labels, confidences, corrects = (
    results[key]
    for key in ('pred_labels', 'true_labels', 'confidence', 'correct')
)

In [33]:
# Keep only the predictions whose confidence is strictly above the threshold.
# NOTE(review): 0.9996 is a hand-picked cutoff — consider moving it to the
# configuration cell with a short justification.
# NOTE(review): the saved outputs of the cells below are mutually inconsistent
# (6125 kept + 307 rejected != 6681 total) and their execution counts are out of
# order, so they were produced with different threshold values. Restart the
# kernel and run all cells before quoting any of these numbers.
threshold = 0.9996
filtering = confidences > threshold

In [30]:
# Baseline: mean of `corrects` over all samples, before any confidence filtering.
corrects.mean()

0.9952102978596019

In [37]:
# Mean of `corrects` restricted to the predictions that pass the confidence threshold.
corrects[filtering].mean()

0.9996734693877551

In [34]:
# Mean of `corrects` restricted to the predictions that pass the confidence threshold.
# NOTE(review): same expression as the preceding cell; the saved outputs differ, so
# `filtering` changed between the two runs. One of the two cells can be dropped.
corrects[filtering].mean()

0.9996785599485696

In [35]:
# Coverage: fraction of all samples that pass the confidence threshold.
filtering.sum() / len(corrects)

0.9312977099236641

In [39]:
# Number of samples that pass the threshold, alongside the total sample count.
filtering.sum(), len(corrects)

(6125, 6681)

In [29]:
# Number of samples rejected by the confidence threshold.
(~filtering).sum()

307