In [1]:
# Switch the working directory to the project's src/ folder. The next cell imports the
# local modules (dataprocess, models, args) from here, and the relative paths used
# below ('../data', '../logs', '../models', '../results') presumably resolve against it.
# Re-running is harmless: from inside src/, '../src' points back to src/.
%cd ../src

C:\Users\nozoe-tatsuya\dev\ai-ocr-ensemble\src


In [2]:
import torch
import torch.nn as nn
import pytorch_lightning as pl

from pathlib import Path

import dataprocess as dp
import models
from args import Args

In [3]:
# Parameters that override the Args defaults.
# For every other default, see the Args definition.
arg_params = dict(
    # --- training schedule ---
    epochs=200,
    batch_size=512,
    nb_classes=17,
    eval_interval=5,  # passed to the Trainer as check_val_every_n_epoch
    device='cuda',

    # --- logging / run identification ---
    log_dir='../logs',  # root directory for logs
    project='model-ensemble',  # log-file directory for this project
    model='CvT',
    version='depth-13-3x3stem',

    # --- dataset ---
    json_path='../data/dataset_info_1.json',
)

In [4]:
# Build the run configuration from the overrides above, plus the torch device it names.
args = Args(**arg_params)
device = torch.device(args.device)

# Seed the Python, NumPy and torch RNGs before the dataloaders and model are built
# (next cells), so weight initialisation and any random sampling are reproducible.
# NOTE(review): Args has no seed field, so the seed is kept here as a constant.
SEED = 42
pl.seed_everything(SEED)

args, device

(Args(batch_size=512, nb_classes=17, epochs=200, ckpt_interval=20, eval_interval=5, device='cuda', log_dir='../logs', model_dir='../models', project='model-ensemble', model='CvT', version='depth-13-3x3stem', resnet_layers=18, use_se_module=False, activation='ReLU', patch_size=16, num_blocks=30, d_model=128, mlp_ratio=6, json_path='../data/dataset_info_1.json', lr=0.001, weight_decay=0.0001, optimizer='adaberief', final_lr=0.1, scheduler=None, warmup=True, warmup_epoch=5),
 device(type='cuda'))

In [5]:
# Dataset configuration file (JSON); presumably read by dp.build_dataloader via
# args.json_path — confirm in dataprocess.
json_path = Path(args.json_path)
# Fail fast with a readable message instead of a bare AssertionError.
# Asserts are also removed when Python runs with -O.
if not json_path.is_file():
    raise FileNotFoundError(f'dataset config not found: {json_path.resolve()}')

In [6]:
# One DataLoader per split, both built from the same Args (train first, then val).
train_loader, val_loader = (dp.build_dataloader(split, args) for split in ('train', 'val'))

In [7]:
# Build the LightningModule trained below. Per the summary printed by trainer.fit, it wraps
# a CrossEntropyLoss criterion, train/valid Accuracy metrics and the CvT model.
system = models.build_system(args)

In [8]:
# TensorBoard logger handed to the Trainer. It presumably writes under
# args.log_dir / args.project / args.version — confirm in models.get_tensorboard_logger.
logger = models.get_tensorboard_logger(args)

In [9]:
# Keep only the single best checkpoint (lowest validation error) on disk.
# Use the model directory configured in Args rather than a hard-coded '../models',
# so changing args.model_dir also moves the checkpoints.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='valid_err_epoch',  # presumably logged by the LightningModule in validation — confirm in models
    dirpath=args.model_dir,
    # '{epoch...}' and '{valid_err_epoch...}' are placeholders that Lightning fills in at
    # save time, so they must stay literal text (not f-string fields).
    filename=args.model + '-' + args.version + '-{epoch:03d}-{valid_err_epoch:.2f}',
    save_top_k=1,
    mode='min',
)



In [10]:
# Use one GPU when args.device asks for CUDA and fall back to CPU otherwise, so the
# Trainer agrees with `device` (also used for inference) instead of hard-coding gpus=1.
trainer = pl.Trainer(gpus=1 if device.type == 'cuda' else 0,
                     max_epochs=args.epochs, min_epochs=1,
                     logger=logger,
                     # validation only runs every args.eval_interval epochs
                     check_val_every_n_epoch=args.eval_interval,
                     callbacks=[checkpoint_callback])

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


In [11]:
# Train for up to args.epochs epochs, validating every args.eval_interval epochs.
# checkpoint_callback keeps the checkpoint with the lowest valid_err_epoch.
trainer.fit(system, train_loader, val_dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


[31mPlease check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.
[31mModifications to default arguments:
[31m                           eps  weight_decouple    rectify
-----------------------  -----  -----------------  ---------
adabelief-pytorch=0.0.5  1e-08  False              False
>=0.1.0 (Current 0.2.0)  1e-16  True               True
[34mSGD better than Adam (e.g. CNN for Image Classification)    Adam better than SGD (e.g. Transformer, GAN)
----------------------------------------------------------  ----------------------------------------------
Recommended eps = 1e-8                                      Recommended eps = 1e-16
[34mFor a complete table of recommended hyperparameters, see
[34mhttps://github.com/juntang-zhuang/Adabelief-Optimizer
[32mYou can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.
[0m
Weight decoupling enabled in AdaBelief
Rectification enabled in AdaBelief



  | Name          | Type                           | Params
-----------------------------------------------------------------
0 | criterion     | CrossEntropyLoss               | 0     
1 | train_metrics | Accuracy                       | 0     
2 | valid_metrics | Accuracy                       | 0     
3 | model         | ConvolutionalVisionTransformer | 19.6 M
-----------------------------------------------------------------
19.6 M    Trainable params
0         Non-trainable params
19.6 M    Total params
78.445    Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

## Inference

In [12]:
# Run inference on the validation split.
# NOTE(review): `system` holds the in-memory weights left after trainer.fit (last epoch),
# not the best checkpoint saved by checkpoint_callback. If the best-epoch model is
# intended, reload checkpoint_callback.best_model_path first.
results = models.infer(system, val_loader, device=device)

In [13]:
# Convert raw inference results into a result table (presumably one row per sample, with a
# 'correct' column) and a confusion matrix for the validation split.
# NOTE(review): this call emits a warning from `conf_matrix_df.loc[:-1, 'recall'] = recall`
# inside models.build_inference (label-based `.loc` slicing with -1 looks suspicious);
# investigate there.
result_df, conf_matrix_df = models.build_inference(results, args, datatype='val', return_conf_matrix=True)

  conf_matrix_df.loc[:-1, 'recall'] = recall


In [14]:
# Validation confusion matrix: rows are the true label (正解ラベル), columns are the predicted
# label (推論ラベル), and the last column holds per-class recall.
conf_matrix_df

Unnamed: 0_level_0,Unnamed: 1_level_0,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,推論ラベル,recall
Unnamed: 0_level_1,Unnamed: 1_level_1,0,1,2,3,4,5,6,7,8,9,*,×,-,・,/,字,―,Unnamed: 19_level_1
正解ラベル,0,391.0,0.0,2.0,0.0,0.0,0.0,1.0,2.0,0.0,1.0,0.0,0.0,0.0,0.0,2.0,1.0,0.0,0.9775
正解ラベル,1,0.0,396.0,0.0,0.0,0.0,0.0,0.0,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.99
正解ラベル,2,1.0,0.0,393.0,1.0,0.0,0.0,0.0,3.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.9825
正解ラベル,3,0.0,0.0,1.0,397.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.9925
正解ラベル,4,0.0,1.0,0.0,0.0,391.0,0.0,1.0,0.0,0.0,3.0,0.0,0.0,0.0,0.0,3.0,0.0,0.0,0.97995
正解ラベル,5,0.0,0.0,1.0,0.0,0.0,395.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.9875
正解ラベル,6,1.0,0.0,0.0,0.0,0.0,0.0,400.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.997506
正解ラベル,7,1.0,1.0,0.0,0.0,0.0,0.0,0.0,397.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.9925
正解ラベル,8,1.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,391.0,0.0,0.0,1.0,0.0,0.0,5.0,0.0,0.0,0.9775
正解ラベル,9,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,397.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.9925


In [16]:
# Export the confusion matrix for sharing. Create the output directory first, because
# to_excel does not create missing parent directories.
results_dir = Path('../results')
results_dir.mkdir(parents=True, exist_ok=True)
conf_matrix_df.to_excel(results_dir / 'conf_matrix_cvt_0615.xlsx')

In [15]:
# Overall validation accuracy: the mean of the per-row 'correct' column.
val_accuracy = result_df['correct'].mean()
val_accuracy

0.9893728483759916

In [17]:
# Save the bare network weights (state_dict of system.model) held in memory after
# trainer.fit. This is separate from the Lightning checkpoint written by checkpoint_callback.
# The path is derived from Args instead of being hard-coded; with the current settings it
# resolves to ../models/cvt_epoch200_210615.pth.
model_dir = Path(args.model_dir)
model_dir.mkdir(parents=True, exist_ok=True)
run_tag = '210615'  # looks like a YYMMDD run tag (matches the '0615' in the Excel export)
weights_path = model_dir / f'{args.model.lower()}_epoch{args.epochs}_{run_tag}.pth'
torch.save(system.model.state_dict(), weights_path, pickle_protocol=4)