## Work
1. 請嘗試寫一個 callback 用來記錄各類別在訓練過程中，對驗證集的 True Positive 與 True Negative

In [1]:
import os

import keras
import numpy as np
from keras.callbacks import Callback
from keras.datasets import cifar10, mnist
from keras.layers import Flatten, BatchNormalization, Activation
from keras.layers import Input, Conv2D, MaxPool2D, Dense
from keras.models import Model, load_model
from keras.optimizers import SGD
from keras.utils import to_categorical
from sklearn.metrics import confusion_matrix

# Hide every GPU from TensorFlow so the notebook runs on the CPU only.
# NOTE(review): this runs after `import keras`; it presumably still works because
# TensorFlow discovers devices lazily, but setting it before the import is safer.
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})

Using TensorFlow backend.


In [2]:
# Load CIFAR-10 through Keras; each split is unpacked into (inputs, labels) below.
(train, test) = cifar10.load_data()

In [3]:
## 資料前處理
def preproc_x(x, flatten=True):
    x = x / 255.
    if flatten:
        x = x.reshape((len(x), -1))
    return x

def preproc_y(y, num_classes=10):
    """One-hot encode integer class labels.

    Args:
        y: label array shaped ``(N,)`` or ``(N, 1)``; any other shape (for
            example labels that are already one-hot, ``(N, num_classes)``) is
            returned unchanged.
        num_classes: width of the one-hot encoding.

    Returns:
        The one-hot encoded labels, or ``y`` itself if no encoding was needed.
    """
    # Checking only shape[-1] == 1 let flat (N,) label vectors (such as MNIST's)
    # through without encoding; treat 1-D input as integer labels as well.
    if y.ndim == 1 or y.shape[-1] == 1:
        y = to_categorical(y, num_classes)
    return y

In [4]:
# Split each dataset pair into inputs and labels
(x_train, y_train), (x_test, y_test) = train, test

# Rescale the inputs, keeping each sample's image shape (no flattening) for the CNN
x_train, x_test = [preproc_x(images, flatten=False) for images in (x_train, x_test)]

# Encode the labels via preproc_y (one-hot for single-column integer labels)
y_train, y_test = [preproc_y(labels) for labels in (y_train, y_test)]

In [5]:
def build_cnn(input_shape, dense_neurno, output_shape):
    """Build a small CNN classifier.

    Architecture: two Conv2D -> MaxPool2D stages, Flatten, then one
    Dense -> BatchNormalization -> ReLU block per entry of ``dense_neurno``
    (stacked in order), and a softmax Dense output layer.

    Args:
        input_shape: shape of a single sample (no batch axis), passed to ``Input``.
        dense_neurno: iterable of unit counts for the hidden Dense blocks. May be
            empty, in which case the flattened conv features feed the output
            layer directly.
        output_shape: number of units of the softmax output layer.

    Returns:
        An uncompiled ``keras.models.Model``.
    """
    input_layer = Input(input_shape)
    # NOTE(review): the conv layers have no activation (they are linear);
    # consider activation='relu' if that was not intentional.
    conv = Conv2D(filters=64, kernel_size=(3, 3), padding='same', name='conv1')(input_layer)
    maxpool = MaxPool2D(name='Max_pool1')(conv)
    conv = Conv2D(filters=32, kernel_size=(3, 3), padding='same', name='conv2')(maxpool)
    maxpool = MaxPool2D(name='Max_pool2')(conv)
    hidden = Flatten()(maxpool)

    # Each block consumes the previous block's output, so the Dense layers are
    # actually stacked (and all of them end up in the model graph).
    for num in dense_neurno:
        hidden = Dense(num)(hidden)
        hidden = BatchNormalization()(hidden)
        hidden = Activation(activation='relu')(hidden)

    output = Dense(output_shape, activation='softmax')(hidden)
    model = Model(inputs=[input_layer], outputs=[output])

    return model

In [6]:
## Hyperparameters
# Optimizer: SGD with Nesterov momentum
LEARNING_RATE = 1e-3
MOMENTUM = 0.95
# Training schedule
EPOCHS = 50
BATCH_SIZE = 256
# Model architecture
DENSE_NEURONS = [256, 128, 64]   # units of each hidden Dense block
INPUT_SHAPE = x_train.shape[1:]  # per-sample shape (batch axis dropped)
OUTPUT_SHAPE = 10                # units of the softmax output layer

In [10]:
# Callback recording per-class validation TP / TN (and FP) after every epoch
class Record_tp_tn(Callback):
    """Log per-class confusion counts on the validation set at each epoch end.

    After every epoch the model predicts the validation inputs and a
    ``num_classes x num_classes`` confusion matrix is built (rows = true class,
    columns = predicted class). Three integer arrays indexed by class id are
    written into the epoch ``logs``:

    * ``val_tp`` -- samples of the class that were predicted as that class
    * ``val_tn`` -- samples of other classes not predicted as that class
    * ``val_fp`` -- samples of other classes predicted as that class

    Keras' ``History`` callback then collects them into
    ``model.history.history`` as one array per epoch.
    """

    # Keys this callback adds to the epoch logs.
    record_metrics = ('val_tp', 'val_tn', 'val_fp')

    def on_train_begin(self, logs=None):
        # Announce the extra keys to the progress logger when this Keras version
        # exposes params['metrics']; nothing to do when it does not.
        params = getattr(self, 'params', None) or {}
        metrics = params.get('metrics')
        if metrics is None:
            return
        for name in self.record_metrics:
            if name not in metrics:
                metrics.append(name)

    def on_epoch_end(self, epoch, logs=None):
        # Write into the dict Keras passed in (never a replacement), otherwise
        # the values would not reach History.
        if logs is None:
            logs = {}
        validation_data = getattr(self, 'validation_data', None)
        if validation_data is None or len(validation_data) < 2:
            # No validation set was given to fit(); nothing to measure.
            return
        x_val, y_val = validation_data[0], validation_data[1]
        num_classes = y_val.shape[-1]  # labels are one-hot encoded
        y_true = y_val.argmax(axis=-1)
        y_pred = self.model.predict(x_val).argmax(axis=-1)
        # Fixing `labels` keeps the matrix num_classes x num_classes even when a
        # class is absent from both y_true and y_pred in this epoch.
        matrix = confusion_matrix(y_true, y_pred, labels=np.arange(num_classes))
        tp = np.diag(matrix)
        fp = matrix.sum(axis=0) - tp   # predicted as the class, but wrong
        fn = matrix.sum(axis=1) - tp   # belong to the class, predicted otherwise
        tn = matrix.sum() - (tp + fp + fn)
        logs['val_tp'] = tp
        logs['val_tn'] = tn
        logs['val_fp'] = fp


rec_tptn = Record_tp_tn()

In [1]:
optimizer = SGD(lr=LEARNING_RATE, nesterov=True, momentum=MOMENTUM)

model = build_cnn(INPUT_SHAPE, DENSE_NEURONS, OUTPUT_SHAPE)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
# fit() returns a History callback whose .history dict maps every logged key
# to a list with one entry per epoch.
history = model.fit(x_train, y_train, epochs=EPOCHS, batch_size=BATCH_SIZE,
                    validation_split=0.2, shuffle=True, callbacks=[rec_tptn]).history

# Collect results. Keras < 2.3 logs accuracy as "acc"/"val_acc", newer
# versions as "accuracy"/"val_accuracy".
train_loss = history["loss"]
valid_loss = history["val_loss"]
train_acc = history.get("acc", history.get("accuracy"))
valid_acc = history.get("val_acc", history.get("val_accuracy"))

# Per-class validation counts logged by the Record_tp_tn callback
# (one array indexed by class id per epoch).
valid_tp = history["val_tp"]
valid_tn = history.get("val_tn", [])  # empty if the callback did not log TN
valid_fp = history["val_fp"]

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.plot(range(len(train_loss)), train_loss, label="train loss")
plt.plot(range(len(valid_loss)), valid_loss, label="valid loss")
plt.legend()
plt.title("Loss")
plt.show()

plt.plot(range(len(train_acc)), train_acc, label="train accuracy")
plt.plot(range(len(valid_acc)), valid_acc, label="valid accuracy")
plt.legend()
plt.title("Accuracy")
plt.show()

plt.plot(range(len(valid_tp)), valid_tp, label="valid tp", color="navy")
plt.plot(range(len(valid_tn)), valid_tn, label="valid tn", color="red")
plt.legend()
plt.title("True positives and True Negatives")
plt.show()