In [1]:
from keras import backend as K
from keras.engine.topology import Layer
import numpy as np
import tensorflow as tf

Using TensorFlow backend.


In [2]:
# model構築の準備
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D,Input,Dropout,Activation
from keras.applications.mobilenet import MobileNet
from keras.applications.mobilenetv2 import MobileNetV2
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.callbacks import CSVLogger,EarlyStopping
#from livelossplot.keras import PlotLossesCallback
import numpy as np
%matplotlib inline

# Experiment configuration.
n_categories=5
# NOTE(review): not passed to the generators below, which hard-code batch_size=25.
batch_size=32
# NOTE(review): reassigned to './zidolegi_data2/train' before the generators are built.
train_dir = './images'
#validation_dir = './images/val'
# NOTE(review): not referenced in any of the visible cells.
file_name='MobileNetV2_shape224'

In [8]:
# Presumably defines Arcfacelayer, which the model-building code uses without importing it.
%run arcface.py 

In [3]:
def create_mobilenet_with_arcface(n_categories, file_path=None):
    """Build the ArcFace training model on top of an ImageNet MobileNetV2.

    The model has two inputs: a (224, 224, 3) image batch and a one-hot label
    batch of shape (n_categories,). Backbone features are globally average
    pooled and passed, together with the label, to ``Arcfacelayer``; a softmax
    is applied to its output.

    Args:
        n_categories: number of classes. Sets the width of the label input and
            the first argument of ``Arcfacelayer``.
        file_path: optional weights file loaded into the model after building.

    Returns:
        The (uncompiled) keras ``Model`` with inputs ``[image, label]``.
    """
    base_model = MobileNetV2(input_shape=(224, 224, 3),
                             weights='imagenet',
                             include_top=False)

    # One-hot label input; Arcfacelayer receives it alongside the features.
    label_input = Input(shape=(n_categories,))
    # Pooled backbone features (the embedding vector of each image).
    hidden = GlobalAveragePooling2D()(base_model.output)
    # Use n_categories rather than a hard-coded 5 so the head always matches the
    # label input. Arguments presumably (output_dim, s, m) -- see arcface.py.
    arcface_out = Arcfacelayer(n_categories, 30, 0.05)([hidden, label_input])
    prediction = Activation('softmax')(arcface_out)
    model = Model(inputs=[base_model.input, label_input], outputs=prediction)

    if file_path:
        model.load_weights(file_path)
        print('weightは{}'.format(file_path))

    return model

In [4]:
def create_predict_model(n_categories, file_path):
    """Build the feature-extraction model from trained ArcFace weights.

    Rebuilds the two-input ArcFace training model, loads ``file_path`` into it,
    and returns a single-input model mapping an image batch to the output of
    its GlobalAveragePooling2D layer (the features fed to Arcfacelayer). The
    label input and ArcFace head are dropped, so no label is needed at
    inference time. Prints the resulting model summary.

    Raises:
        ValueError: if the model contains no GlobalAveragePooling2D layer.
    """
    arcface_model = create_mobilenet_with_arcface(n_categories, file_path)
    # Find the pooling layer by type instead of a fixed negative index,
    # which silently picks the wrong layer if the head gains extra layers.
    pooling_layers = [layer for layer in arcface_model.layers
                      if isinstance(layer, GlobalAveragePooling2D)]
    if not pooling_layers:
        raise ValueError("no GlobalAveragePooling2D layer found in the ArcFace model")
    embedding_layer = pooling_layers[-1]
    # inputs[0] is the image input (inputs=[base_model.input, label_input]).
    predict_model = Model(arcface_model.inputs[0], embedding_layer.output)
    predict_model.summary()
    return predict_model

In [7]:
# Pairwise cosine similarity (numpy).
def cosine_similarity(x1, x2):
    """Cosine similarity between every row of x1 and every row of x2.

    input
    x1 : shape (n_sample, n_features)
    x2 : shape (n_classes, n_features)
    ------
    output
    cos : shape (n_sample, n_classes), cos[i, j] = cos(x1[i], x2[j])
    """
    x1_norm = np.linalg.norm(x1, axis=1)
    x2_norm = np.linalg.norm(x2, axis=1)
    # Outer product so entry (i, j) is divided by |x1[i]| * |x2[j]|; an
    # element-wise product only broadcasts correctly when n_sample == 1.
    # The small epsilon avoids division by zero for all-zero vectors.
    return np.dot(x1, x2.T) / (np.outer(x1_norm, x2_norm) + 1e-10)

In [5]:
# Turn new images into feature vectors.
def predict_vector(predict_model, img_array):
    """Embed a batch of preprocessed images with ``predict_model``.

    Thin wrapper: returns exactly what ``predict_model.predict(img_array)`` returns.
    """
    embed = predict_model.predict
    return embed(img_array)

In [9]:
# Compare the new image's cosine similarities, take the highest-scoring index,
# and return it if its score exceeds the threshold, otherwise return None.
def judgment(predict_vector, hold_vector, thresh):
    """Index of the registered vector most similar to the query, or None.

    predict_vector : array of shape (1, n_features); only row 0 is used
    hold_vector    : array of shape (n_registered, n_features)
    thresh         : the best similarity must be strictly greater than this

    Prints the row of similarities as a side effect.
    """
    scores = cosine_similarity(predict_vector, hold_vector)[0]
    print(scores)
    best = np.argmax(scores)
    return best if scores[best] > thresh else None

In [10]:
# Build the ArcFace training model (ImageNet backbone, no weights file loaded).
model = create_mobilenet_with_arcface(n_categories)

In [11]:
# train_test_split folder版
#%run gazo_sprit_many_class.py

In [12]:
class _XYGenerator(object):
    """Iterator adapter yielding ``([X, Y], Y)`` batches from a Keras iterator.

    The ArcFace model in this notebook takes the one-hot label as a second
    input, so each label batch is fed both as an input and as the target.
    Subclasses must set ``self.gene`` to an object with a ``next()`` method
    returning ``(X, Y)``.
    """

    def __iter__(self):
        # __next__ is implemented on self, so the object is its own iterator.
        return self

    def __next__(self):
        X, Y = self.gene.next()
        return [X, Y], Y


class train_Generator_xandy(_XYGenerator):
    """Augmented training batches read from ``directory``.

    directory: image folder with one sub-folder per class; defaults to the
        module-level ``train_dir`` at construction time.
    batch_size, target_size: forwarded to ``flow_from_directory``.
    """

    def __init__(self, directory=None, batch_size=25, target_size=(224, 224)):
        datagen = ImageDataGenerator(
            vertical_flip=False,
            width_shift_range=0.1,
            height_shift_range=0.1,
            rescale=1.0 / 255.,
            zoom_range=0.2,
            fill_mode="constant",
            cval=0)
        self.gene = datagen.flow_from_directory(
            train_dir if directory is None else directory,
            target_size=target_size,
            batch_size=batch_size,
            class_mode='categorical',
            shuffle=True)


class val_Generator_xandy(_XYGenerator):
    """Un-augmented validation batches (pixels rescaled by 1/255).

    directory: image folder with one sub-folder per class; defaults to the
        module-level ``validation_dir`` at construction time.
    batch_size, target_size: forwarded to ``flow_from_directory``.
    """

    def __init__(self, directory=None, batch_size=25, target_size=(224, 224)):
        validation_datagen = ImageDataGenerator(rescale=1.0 / 255.)
        self.gene = validation_datagen.flow_from_directory(
            validation_dir if directory is None else directory,
            target_size=target_size,
            batch_size=batch_size,
            class_mode='categorical',
            shuffle=True)

# Point the generator constructors at the real dataset; they read these
# module-level names when instantiated, so assign them first.
train_dir, validation_dir = './zidolegi_data2/train', './zidolegi_data2/validation'
train_gene = train_Generator_xandy()
# (sic) "gane": the training cell passes this object as validation_data.
val_gane = val_Generator_xandy()

Found 769 images belonging to 5 classes.
Found 165 images belonging to 5 classes.


In [13]:
# Show the training model's layer table; the layer names listed here are the
# ones referenced when choosing unfreeze points further down.
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 225, 225, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 112, 112, 32) 864         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 112, 112, 32) 128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu

In [18]:
#for i, layer in enumerate(model.layers):
    #print(i, layer)

In [21]:
# Gradually unfreeze layers: each stage trains every layer from the given index
# onward, starting from the best weights saved by the previous stage.
import os
from keras import callbacks

# Unfreeze points, from the ArcFace head back through MobileNetV2's
# expansion convolutions (block_16_expand ... block_6_expand).
# NOTE(review): "arcfacelayer_1" is an auto-generated name; it changes if the
# model is built more than once in the same session.
unfreeze_layer_names = ["arcfacelayer_1"] + [
    "block_{}_expand".format(i) for i in range(16, 5, -1)
]
touketulayerlists = [model.layers.index(model.get_layer(name))
                     for name in unfreeze_layer_names]

# Saving weights fails if the checkpoint directory does not exist yet.
os.makedirs("zidolege_model", exist_ok=True)

maenosavepath = None
for touketu in touketulayerlists:
    print('touketu{}'.format(touketu))

    modelsavepath = "zidolege_model/m02_fine{}kara_weights".format(touketu)
    # Continue from the best checkpoint of the previous stage.
    if maenosavepath:
        model.load_weights(maenosavepath)
    maenosavepath = modelsavepath

    # Freeze every layer before `touketu`, train every layer from it onward.
    for i, layer in enumerate(model.layers):
        layer.trainable = i >= touketu

    # Changes to `trainable` only take effect after (re)compiling.
    model.compile(optimizer=Adam(lr=0.001),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    callbacks_list = [
        # Save weights whenever validation loss improves.
        callbacks.ModelCheckpoint(
            filepath=modelsavepath,
            monitor="val_loss",
            save_weights_only=True,
            save_best_only=True),

        # Lower the learning rate when validation loss stops improving.
        callbacks.ReduceLROnPlateau(
            monitor="val_loss",
            factor=0.8,
            patience=5,
            verbose=1)]

    model.fit_generator(train_gene, steps_per_epoch=80, epochs=30, validation_steps=20, validation_data=val_gane, callbacks=callbacks_list)

# Leave the in-memory model at the final stage's best checkpoint rather than
# at its last-epoch weights.
if maenosavepath:
    model.load_weights(maenosavepath)

touketu157
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30

Epoch 00021: ReduceLROnPlateau reducing learning rate to 0.000800000037997961.
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30

Epoch 00026: ReduceLROnPlateau reducing learning rate to 0.0006400000303983689.
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
touketu144
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30

Epoch 00009: ReduceLROnPlateau reducing learning rate to 0.000800000037997961.
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30

Epoch 00023: ReduceLROnPlateau reducing learning rate to 0.0006400000303983689.
Epoch 24/30
Epoch 25/30
Epoch 26/30
Ep


Epoch 00025: ReduceLROnPlateau reducing learning rate to 0.0006400000303983689.
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
touketu117
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30

Epoch 00006: ReduceLROnPlateau reducing learning rate to 0.000800000037997961.
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30

Epoch 00016: ReduceLROnPlateau reducing learning rate to 0.0006400000303983689.
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30

Epoch 00021: ReduceLROnPlateau reducing learning rate to 0.0005120000336319208.
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30

Epoch 00026: ReduceLROnPlateau reducing learning rate to 0.00040960004553198815.
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
touketu108
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epo

Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30

Epoch 00024: ReduceLROnPlateau reducing learning rate to 0.0006400000303983689.
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30

Epoch 00029: ReduceLROnPlateau reducing learning rate to 0.0005120000336319208.
Epoch 30/30
touketu82
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30

Epoch 00006: ReduceLROnPlateau reducing learning rate to 0.000800000037997961.
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30

Epoch 00021: ReduceLROnPlateau reducing learning rate to 0.0006400000303983689.
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30

Epoch 00026: ReduceLROnPlateau reducing learning rate to 0.0005120000336319208.
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
touketu73
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30

Epoch 00006: ReduceLROnPla

Epoch 16/30
Epoch 17/30
Epoch 18/30

Epoch 00018: ReduceLROnPlateau reducing learning rate to 0.0005120000336319208.
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30

Epoch 00023: ReduceLROnPlateau reducing learning rate to 0.00040960004553198815.
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30

Epoch 00028: ReduceLROnPlateau reducing learning rate to 0.00032768002711236477.
Epoch 29/30
Epoch 30/30
touketu64
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30

Epoch 00015: ReduceLROnPlateau reducing learning rate to 0.000800000037997961.
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30

Epoch 00027: ReduceLROnPlateau reducing learning rate to 0.0006400000303983689.
Epoch 28/30
Epoch 29/30
Epoch 30/30
touketu55
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30

Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30

Epoch 00016: ReduceLROnPlateau reducing learning rate to 0.0005120000336319208.
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30

Epoch 00021: ReduceLROnPlateau reducing learning rate to 0.00040960004553198815.
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30

Epoch 00026: ReduceLROnPlateau reducing learning rate to 0.00032768002711236477.
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


In [6]:
import numpy as np
from keras.preprocessing.image import load_img, img_to_array
import os

def test_acc(model, test_dir, hold_dir, classes, thresh=0, sample=100):
    """
    テスト用
    model: 特徴抽出用モデル
    X: array
    test_dir: str 画像入ってるフォルダ
    hold_dir:str 登録データのフォルダ　ファイル名はclass名.jpgにしてください
    classes:　フォルダ名のリスト 
    """
    correct = 0
    hold_vector = get_hold_vector(model, classes, hold_dir)
    
    test_datagen=ImageDataGenerator(rescale=1.0/255.)
    test_generator=test_datagen.flow_from_directory(
            test_dir,
            target_size=(224,224),
            batch_size=1,
            class_mode='categorical',
            classes=classes)
    
    for i in range(sample):
        X, Y = test_generator.next()
        Y = np.argmax(Y, axis=1)
        predict_vector = model.predict(X)
        index = judgment(predict_vector,hold_vector, thresh)
        label_index = index // 4
        if Y == label_index:
            correct += 1
        
        print('label_index{}'.format(label_index))
        print('Y{}'.format(Y))
    acc = correct / sample
    print("acc: {}".format(acc))
    return acc

def cosine_similarity(x1, x2):
    """Cosine similarity between every row of x1 and every row of x2.

    x1: array of shape (n_sample, n_features)
    x2: array of shape (n_classes, n_features)
    Returns an array of shape (n_sample, n_classes) with
    cosine_sim[i, j] = cos(x1[i], x2[j]).
    """
    x1_norm = np.linalg.norm(x1, axis=1)
    x2_norm = np.linalg.norm(x2, axis=1)
    # Outer product so entry (i, j) is divided by |x1[i]| * |x2[j]|; an
    # element-wise product only broadcasts correctly when n_sample == 1.
    # The small epsilon avoids division by zero for all-zero vectors.
    cosine_sim = np.dot(x1, x2.T) / (np.outer(x1_norm, x2_norm) + 1e-10)
    return cosine_sim

# Compare the new image's cosine similarities, take the highest-scoring index,
# and return it if its score exceeds the threshold, otherwise return None.
def judgment(predict_vector, hold_vector, thresh):
    """Index of the registered vector most similar to the query, or None.

    predict_vector : array of shape (1, n_features); only row 0 is used
    hold_vector    : array of shape (n_registered, n_features)
    thresh         : the best similarity must be strictly greater than this

    Prints the row of similarities as a side effect.
    """
    row = cosine_similarity(predict_vector, hold_vector)[0]
    print('cos_similarity{}'.format(row))
    best = np.argmax(row)
    # Guard clause: no registered vector is similar enough.
    if not row[best] > thresh:
        return None
    return best

def get_hold_vector(model, classes, hold_dir, num_image=4):
    """Embed the registered ("hold") reference images of every class.

    For each class ``clas``, the images ``{hold_dir}/{clas}{i}.jpg`` for
    ``i in range(num_image)`` are loaded at 224x224, divided by 255 and run
    through ``model`` as one batch.

    classes: list of class names; the image file names are built from them
    hold_dir: str, path of the folder containing the images
    num_image: number of reference images per class (default 4)

    Returns ``model.predict`` of the len(classes) * num_image images in
    class-major order (all images of classes[0], then classes[1], ...).
    """
    images = [
        img_to_array(load_img(os.path.join(hold_dir, clas + str(i) + ".jpg"),
                              target_size=(224, 224)))
        for clas in classes
        for i in range(num_image)
    ]
    # Stack once instead of np.vstack per image (which re-copies the whole
    # buffer each time). float64 matches the original np.empty-based buffer;
    # the reshape keeps an empty input at shape (0, 224, 224, 3).
    img_array = np.array(images, dtype=np.float64).reshape(-1, 224, 224, 3) / 255.0
    return model.predict(img_array)

In [9]:
# Build the feature-extraction model from fine-tuned weights.
# NOTE(review): "m005_..." does not match the "m02_..." names written by the
# training cell; it presumably comes from a different training run.
file_path = "./zidolege_model/m005_fine91kara_weights"
load_model = create_predict_model(n_categories, file_path)

weightは./zidolege_model/m005_fine91kara_weights
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 225, 225, 3)  0           input_3[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 112, 112, 32) 864         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 112, 112, 32) 128         Conv1[0][0]                      
_____________________________________________________________

In [13]:
hold_dir = "zidolegi_data2/hold"
test_dir = "zidolegi_data2/test"
# Class names are the sub-directories of test_dir. os.listdir order is
# arbitrary and may include stray entries (e.g. .DS_Store, checkpoint dirs),
# so keep visible directories only and sort them alphanumerically, the same
# order Keras uses when it infers classes itself.
classes = sorted(
    name for name in os.listdir(test_dir)
    if not name.startswith('.') and os.path.isdir(os.path.join(test_dir, name))
)
hold_vector = get_hold_vector(load_model, classes, hold_dir)
test_acc(load_model, test_dir, hold_dir, classes, sample=150)

Found 145 images belonging to 5 classes.
cos_similarity[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.999998   0.9999979  0.9999991  0.99999833 0.         0.
 0.         0.        ]
label_index3
Y[3]
cos_similarity[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
 0.0000000e+00 0.0000000e+00 4.2262673e-07 9.9998790e-01 9.9999815e-01
 9.9999696e-01 9.9994230e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00
 1.9743919e-07 0.0000000e+00 0.0000000e+00 4.8567466e-07 0.0000000e+00]
label_index2
Y[2]
cos_similarity[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
 0.0000000e+00 0.0000000e+00 1.4141209e-07 9.9998313e-01 9.9999976e-01
 9.9999905e-01 9.9993300e-01 0.0000000e+00 5.9497109e-11 0.0000000e+00
 7.9422037e-08 0.0000000e+00 0.0000000e+00 1.6250812e-07 0.0000000e+00]
label_index2
Y[2]
cos_similarity[0.         0.         0.         0.         0.         0.
 0.         0.     

 0.99999815 0.99999946]
label_index4
Y[4]
cos_similarity[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         1.0000001  0.9999992
 0.99999887 0.9999999 ]
label_index4
Y[4]
cos_similarity[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.9999969  0.99999934 0.9999997  0.99999917 0.         0.
 0.         0.        ]
label_index3
Y[3]
cos_similarity[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.99999905 0.9999997
 0.99999815 0.99999946]
label_index4
Y[4]
cos_similarity[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
 0.0000000e+00 0.0000000e+00 2.0990581e-07 9.9998373e-01 9.9999905e-01
 9.9999744e-01 9.9993807e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00
 9.8062046e-08 0.0

cos_similarity[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 9.9999738e-01
 9.9999887e-01 9.9999833e-01 9.9999434e-01 1.3318920e-05 0.0000000e+00
 0.0000000e+00 8.4972894e-03 0.0000000e+00 0.0000000e+00 0.0000000e+00
 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00]
label_index1
Y[1]
cos_similarity[0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 2.23544205e-07
 9.99983251e-01 9.99998152e-01 9.99996245e-01 9.99940693e-01
 0.00000000e+00 0.00000000e+00 0.00000000e+00 1.04433504e-07
 0.00000000e+00 0.00000000e+00 2.56892804e-07 0.00000000e+00]
label_index2
Y[2]
cos_similarity[0.        0.        0.        0.        0.        0.        0.
 0.        0.        0.        0.        0.        0.9999995 0.9999959
 0.9999963 0.9999948 0.        0.        0.        0.       ]
label_index3
Y[3]
cos_similarity[0.9999961  0.9999946  0.99999774 0.9999972  0.         0.
 0.         0.         0.         0.       

 3.6649112e-07 0.0000000e+00 0.0000000e+00 9.0152031e-07 0.0000000e+00]
label_index2
Y[2]
cos_similarity[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
 0.0000000e+00 0.0000000e+00 5.7936916e-07 9.9998987e-01 9.9999666e-01
 9.9999696e-01 9.9994051e-01 0.0000000e+00 0.0000000e+00 0.0000000e+00
 2.7066483e-07 0.0000000e+00 0.0000000e+00 6.6580014e-07 0.0000000e+00]
label_index2
Y[2]
cos_similarity[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
 0.0000000e+00 0.0000000e+00 1.1584870e-08 2.6217097e-07 0.0000000e+00
 2.3902139e-08 2.5954569e-07 0.0000000e+00 0.0000000e+00 0.0000000e+00
 5.4121227e-09 9.9999970e-01 9.9999934e-01 9.9999821e-01 9.9999982e-01]
label_index4
Y[4]
cos_similarity[0.99999636 0.99999917 0.99999857 0.99999714 0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.        ]
label_index0
Y[0]
cos_similarity[0.0000000e+00 0.0000000e+00 

 0.        0.        0.        0.        0.        0.       ]
label_index0
Y[0]
cos_similarity[0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
 0.0000000e+00 0.0000000e+00 3.6549846e-08 8.2713990e-07 0.0000000e+00
 7.5410384e-08 8.1885725e-07 9.9999738e-01 9.9999833e-01 9.9999940e-01
 9.9999869e-01 0.0000000e+00 0.0000000e+00 4.2002398e-08 0.0000000e+00]
label_index3
Y[3]
cos_similarity[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.99999905 0.9999998
 0.9999982  0.99999946]
label_index4
Y[4]
cos_similarity[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.9999999  0.9999995
 0.99999845 0.9999998 ]
label_index4
Y[4]
cos_similarity[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.   

1.0

In [11]:
# Class order in use: index i is label i in test_acc and sets the file order in pickel_hold_vector.
classes

['calpis', 'ilohas_normal', 'ilohas_peach', 'mitsuya', 'pocari']

In [15]:
# Rows = len(classes) * 4 (get_hold_vector loads 4 hold images per class); columns = model output size.
hold_vector.shape

(20, 1280)

In [17]:
# Class folder name -> numeric ID used as the prefix of exported feature
# files (the meaning of the IDs is defined outside this notebook).
drink_dict = dict(
    calpis=5,
    ilohas_peach=4,
    mitsuya=3,
    ilohas_normal=2,
    pocari=1,
)

In [21]:
import pickle
def pickel_hold_vector(hold_vector, classes, num_image=4, out_dir="zidolegi_data2/feature"):
    """Pickle each hold feature vector to its own file.

    Files are named ``{out_dir}/{drink_dict[clas]}_{i}_feature.dump`` for
    ``i in range(num_image)``, in the class-major order that get_hold_vector
    uses to build hold_vector.

    hold_vector: sequence of len(classes) * num_image feature vectors
    classes: class names; each must be a key of the global drink_dict
    num_image: number of hold images per class
    out_dir: output directory, created if missing

    Raises:
        ValueError: if hold_vector does not have exactly len(classes) * num_image
            rows (zip would otherwise silently drop or skip vectors).
        KeyError: if a class is missing from drink_dict.
    """
    import os

    name_list = [
        os.path.join(out_dir, "{}_{}_feature.dump".format(drink_dict[clas], i))
        for clas in classes
        for i in range(num_image)
    ]
    if len(hold_vector) != len(name_list):
        raise ValueError(
            "expected {} hold vectors ({} classes x {} images), got {}".format(
                len(name_list), len(classes), num_image, len(hold_vector)))

    os.makedirs(out_dir, exist_ok=True)
    # Write one file per vector.
    for vec, name in zip(hold_vector, name_list):
        with open(name, 'wb') as f:
            pickle.dump(vec, f)


In [22]:
# Export each hold vector to its own pickle file (4 reference images per class).
pickel_hold_vector(hold_vector, classes, num_image=4)