# 作成したモデルでの正誤を一覧にする

一旦モデル設計のコードと分離したかったため、保存済みの.h5ファイルを読み込んで使用する。

In [1]:
import os
import re
# CPUの利用を強制する場合
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import shutil
import random
import tensorflow as tf
import keras
from keras import Sequential
from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, Input, Dense, GlobalAveragePooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import TensorBoard
from keras.optimizers import SGD

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Let notebook DataFrame displays show up to 200 rows before truncating.
pd.options.display.max_rows = 200

Using TensorFlow backend.


In [7]:
# Settings and paths.
# os.path.join keeps the paths portable: on Windows these evaluate to the same
# '.\\models\\model.h5'-style strings as before, while on POSIX systems they no
# longer contain literal backslashes (which would become part of the file name).

# Trained model (HDF5) to load
model_path = os.path.join('.', 'models', 'model.h5')

# Where the source data is stored
dataset_base_path = os.path.join('.', 'splat-scene-dataset')
dataset_split_base_path = os.path.join('.', 'dataset')
tensorboard_log_path = os.path.join('.', 'tflog')

# Image settings: target size for loading, plus 3 channels for the model input
input_size = (80, 45)
input_shape = input_size + (3,)  # derived so the two can never disagree

# Data settings
batch_size = 1  # a single pass over the data is all that is needed here
categories_n = 17

dataset_train_path = os.path.join(dataset_split_base_path, 'train')
dataset_val_path = os.path.join(dataset_split_base_path, 'val')
dataset_test_path = os.path.join(dataset_split_base_path, 'test')
pathes = [dataset_train_path, dataset_val_path, dataset_test_path]

In [8]:
def load_model(path):
    """Load a saved Keras model from the file at *path* and return it."""
    loaded_model = keras.models.load_model(path)
    return loaded_model


# Load the trained model and print its layer-by-layer summary.
model = load_model(model_path)
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 80, 45, 3)         0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 80, 45, 64)        1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 80, 45, 64)        36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 40, 22, 64)        0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 40, 22, 128)       73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 40, 22, 128)       147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 20, 11, 128)       0         
__________

In [9]:
# Build an ImageDataGenerator plus a directory iterator over it.
def create_generator(path,
                     target_size,
                     batch_size=1,
                     class_mode='categorical'):
    """Create a rescaling ImageDataGenerator and iterate *path* with it.

    Pixel values are multiplied by 1/255. Shuffling is disabled, so the
    iterator yields images in the same order as its ``filenames`` list. That
    ordering is what lets predictions be matched back to files afterwards.

    Args:
        path: directory whose sub-directories are the classes.
        target_size: size every image is resized to.
        batch_size: images per batch.
        class_mode: label format passed to ``flow_from_directory``.

    Returns:
        tuple: ``(data_generator, iterator)``.
    """
    print(path)
    datagen = ImageDataGenerator(rescale=1/255.0)
    flow_options = dict(target_size=target_size,
                        batch_size=batch_size,
                        class_mode=class_mode,
                        shuffle=False)
    iterator = datagen.flow_from_directory(path, **flow_options)
    return datagen, iterator

# One non-shuffled generator per split, built in train -> val -> test order.
(train_dg, train_gen), (val_dg, val_gen), (test_dg, test_gen) = [
    create_generator(split_path, target_size=input_size, batch_size=batch_size)
    for split_path in (dataset_train_path, dataset_val_path, dataset_test_path)
]

generators = [train_gen, val_gen, test_gen]

.\dataset\train
Found 10231 images belonging to 17 classes.
.\dataset\val
Found 3406 images belonging to 17 classes.
.\dataset\test
Found 3406 images belonging to 17 classes.


In [10]:
# Predict every image of each split.
# steps is a number of *batches*: len(g) is the batch count, so every file is
# predicted exactly once whatever batch_size is. With batch_size == 1 this
# equals len(g.filenames), as before. The generators are built with
# shuffle=False, so row i of each result corresponds to g.filenames[i].
predicts = [model.predict_generator(g, steps=len(g), verbose=1) for g in generators]
predicts



[array([[9.9992669e-01, 9.4142285e-09, 6.7294813e-13, ..., 1.9919972e-11,
         9.9119775e-12, 7.7481570e-13],
        [9.9999952e-01, 1.0768900e-11, 6.5192194e-09, ..., 6.7203036e-12,
         2.3322017e-15, 2.2510327e-12],
        [9.9999976e-01, 1.4061428e-08, 3.7634201e-12, ..., 2.3442753e-11,
         1.1582858e-09, 4.9358295e-12],
        ...,
        [1.3502877e-05, 1.9559741e-06, 3.5771166e-07, ..., 1.3140851e-07,
         4.2098411e-10, 9.9726534e-01],
        [1.9689092e-05, 1.7569361e-06, 1.4526995e-07, ..., 1.7556890e-07,
         4.1650114e-10, 9.9706763e-01],
        [9.0788726e-06, 2.4718563e-06, 2.6627191e-07, ..., 1.6819654e-07,
         3.1636482e-10, 9.9869376e-01]], dtype=float32),
 array([[9.9440837e-01, 1.1689134e-06, 4.8308745e-08, ..., 1.4952057e-06,
         7.7833789e-08, 5.4410997e-08],
        [9.9996352e-01, 8.9950788e-08, 1.2139269e-08, ..., 2.7310804e-08,
         1.3904674e-06, 1.3642547e-10],
        [9.9999714e-01, 9.6358059e-09, 2.0086785e-10, ...,

In [11]:
# Produce per-file prediction results.
def create_result(generators, predicts):
    """Return a DataFrame with one row per image and its prediction details.

    Args:
        generators: directory iterators (one per split). Each must expose
            ``class_indices``, ``filenames`` and ``directory``. Class names
            are taken from ``generators[0]``; this presumes its class order
            matches the model's output columns.
        predicts: one 2-D score array per generator. Row i belongs to
            ``filenames[i]`` of the matching generator.

    Returns:
        pandas.DataFrame with, per file: ``base_dir``, ``file``, ``path``,
        ``predict_<label>`` for every class, ``label_predict``,
        ``label_correct`` and ``is_correct``.

    Raises:
        ValueError: if a score array has fewer rows than its generator has
            files (those files would otherwise be silently dropped).
    """
    # class_indices maps label name -> output index; sort by index so that
    # labels[i] names column i of a prediction row.
    labels = [x[0] for x in sorted(generators[0].class_indices.items(), key=lambda x: x[1])]
    print(labels)
    # File names look like "<class_dir><sep><image>". Accept both Windows and
    # POSIX separators so the true label is parsed correctly on any OS.
    separator = re.compile(r'[\\/]')

    def generate(generators, predicts):
        for g, ps in zip(generators, predicts):
            files = g.filenames
            if len(ps) < len(files):
                raise ValueError(
                    '{} predictions for {} files in {}'.format(len(ps), len(files), g.directory))
            for file, predict in zip(files, ps):
                # Basic file information
                dst = {'base_dir': g.directory, 'file': file, 'path': os.path.join(g.directory, file), }
                # Score for each class
                for label, p in zip(labels, predict):
                    dst['predict_{}'.format(label)] = p
                # Whether the top-scoring class matches the directory label
                label_correct = separator.split(file, maxsplit=1)[0]
                label_predict = labels[predict.argmax()]

                dst['label_predict'] = label_predict
                dst['label_correct'] = label_correct
                dst['is_correct'] = label_predict == label_correct
                yield dst
    return pd.DataFrame(generate(generators, predicts))

result = create_result(generators, predicts)

# Save the results as CSV: every file, and only the misclassified ones.
result.to_csv('predict_all.csv')
# is_correct is a boolean column, so ~ selects the incorrect rows directly
# (idiomatic replacement for the `== False` comparison).
result[~result['is_correct']].to_csv('predict_error.csv')

result

['battle', 'battle_finish', 'battle_loby', 'battle_matching', 'battle_result', 'battle_rule', 'battle_start', 'loading', 'menu', 'other', 'salmon', 'salmon_lobby', 'salmon_matching', 'salmon_miss', 'salmon_result', 'salmon_start', 'weapon_select']


Unnamed: 0,base_dir,file,is_correct,label_correct,label_predict,path,predict_battle,predict_battle_finish,predict_battle_loby,predict_battle_matching,...,predict_loading,predict_menu,predict_other,predict_salmon,predict_salmon_lobby,predict_salmon_matching,predict_salmon_miss,predict_salmon_result,predict_salmon_start,predict_weapon_select
0,.\dataset\train,battle\WIN_20181022_15_12_23_Pro-00000044.jpg,True,battle,battle,.\dataset\train\battle\WIN_20181022_15_12_23_P...,9.999267e-01,9.414229e-09,6.729481e-13,1.499654e-15,...,6.018176e-14,8.489932e-11,3.121440e-09,7.100549e-05,5.053809e-09,1.075816e-11,7.281211e-08,1.991997e-11,9.911977e-12,7.748157e-13
1,.\dataset\train,battle\WIN_20181022_15_12_23_Pro-00000046.jpg,True,battle,battle,.\dataset\train\battle\WIN_20181022_15_12_23_P...,9.999995e-01,1.076890e-11,6.519219e-09,2.895360e-16,...,1.301098e-13,7.617067e-11,2.499345e-11,3.576956e-07,6.036873e-10,3.141283e-12,5.302428e-11,6.720304e-12,2.332202e-15,2.251033e-12
2,.\dataset\train,battle\WIN_20181022_15_12_23_Pro-00000049.jpg,True,battle,battle,.\dataset\train\battle\WIN_20181022_15_12_23_P...,9.999998e-01,1.406143e-08,3.763420e-12,1.634683e-12,...,1.893649e-11,3.580359e-11,1.164471e-08,7.180495e-08,7.315780e-08,8.239632e-10,2.232417e-08,2.344275e-11,1.158286e-09,4.935830e-12
3,.\dataset\train,battle\WIN_20181022_15_12_23_Pro-00000050.jpg,True,battle,battle,.\dataset\train\battle\WIN_20181022_15_12_23_P...,9.999993e-01,3.817827e-10,1.329851e-10,5.104198e-11,...,3.973827e-11,3.795608e-11,6.509688e-09,6.584544e-07,1.440981e-09,9.285323e-11,1.809324e-09,6.431028e-12,1.291296e-09,1.419188e-12
4,.\dataset\train,battle\WIN_20181022_15_12_23_Pro-00000051.jpg,True,battle,battle,.\dataset\train\battle\WIN_20181022_15_12_23_P...,9.999303e-01,1.825522e-09,2.494112e-09,2.593592e-11,...,7.907250e-12,2.879001e-10,1.943941e-08,6.968671e-05,1.780186e-08,3.679756e-10,1.728875e-08,3.768250e-09,4.157435e-09,1.811141e-11
5,.\dataset\train,battle\WIN_20181022_15_12_23_Pro-00000053.jpg,True,battle,battle,.\dataset\train\battle\WIN_20181022_15_12_23_P...,9.999962e-01,1.240528e-06,1.482103e-09,3.349563e-11,...,8.628395e-10,2.583452e-09,6.619464e-08,7.111892e-09,1.620676e-07,1.254874e-08,1.078379e-07,2.138351e-10,5.208441e-10,1.463128e-09
6,.\dataset\train,battle\WIN_20181022_15_12_23_Pro-00000058.jpg,True,battle,battle,.\dataset\train\battle\WIN_20181022_15_12_23_P...,1.000000e+00,9.694789e-09,1.059954e-11,3.750767e-15,...,3.602935e-13,2.593140e-11,1.039325e-09,1.729835e-10,1.556481e-09,9.354626e-10,1.147784e-10,2.529812e-10,5.450948e-11,2.399627e-10
7,.\dataset\train,battle\WIN_20181022_15_12_23_Pro-00000059.jpg,True,battle,battle,.\dataset\train\battle\WIN_20181022_15_12_23_P...,9.993590e-01,1.319989e-05,2.161693e-09,1.233674e-11,...,3.401690e-09,3.079563e-08,3.297565e-06,1.505417e-05,1.710389e-06,1.699872e-07,4.521536e-07,1.327162e-07,3.039477e-09,4.227877e-07
8,.\dataset\train,battle\WIN_20181022_15_12_23_Pro-00000061.jpg,True,battle,battle,.\dataset\train\battle\WIN_20181022_15_12_23_P...,9.999347e-01,2.431324e-09,1.170379e-10,3.922145e-13,...,4.956559e-14,2.123759e-09,7.970632e-10,6.532844e-05,6.909727e-10,7.837128e-10,9.393236e-11,7.589263e-12,1.168252e-13,2.430062e-10
9,.\dataset\train,battle\WIN_20181022_15_12_23_Pro-00000063.jpg,True,battle,battle,.\dataset\train\battle\WIN_20181022_15_12_23_P...,9.999995e-01,2.538418e-11,2.008587e-14,5.891333e-16,...,1.156873e-15,1.444023e-12,3.135403e-10,4.992285e-07,6.166536e-11,3.383347e-12,1.724999e-11,6.279212e-13,1.344524e-13,9.486357e-15


In [12]:
# Tally files per (true label, predicted label, correctness, split directory);
# the tally column keeps the name 'is_correct' in the saved CSV.
group_keys = ['label_correct', 'label_predict', 'is_correct', 'base_dir']
counts = result.groupby(group_keys)['is_correct'].count().to_frame()
counts.to_csv('correct_count.csv')
counts

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,is_correct
label_correct,label_predict,is_correct,base_dir,Unnamed: 4_level_1
battle,battle,True,.\dataset\test,1911
battle,battle,True,.\dataset\train,5733
battle,battle,True,.\dataset\val,1866
battle,battle_result,False,.\dataset\val,5
battle,battle_rule,False,.\dataset\val,1
battle,battle_start,False,.\dataset\val,8
battle,salmon,False,.\dataset\train,1
battle,salmon,False,.\dataset\val,31
battle_finish,battle,False,.\dataset\val,2
battle_finish,battle_finish,True,.\dataset\test,43


In [13]:
# Export the loaded model for TensorFlow.js.
# NOTE(review): quantization_dtype is the argument name of older tensorflowjs
# releases; newer ones changed the quantization API. Confirm against the
# installed version.
import tensorflowjs as tfjs

js_model_dir = 'models/js-model'
tfjs.converters.save_keras_model(
    model,
    js_model_dir,
    quantization_dtype=np.uint8,
)