In [1]:
%matplotlib inline

import os
import pickle
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras import Input
from tensorflow.keras.layers import Dense, Convolution2D, MaxPooling2D, Flatten, Input, Activation, Add, Dropout, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.layers import Convolution3D, MaxPooling3D, GlobalAveragePooling3D
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split

from IPython.display import SVG
from tensorflow.python.keras.utils.vis_utils import model_to_dot

from tensorflow.examples.tutorials.mnist import input_data

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def Shortcut3D(input_data, fx, data_format = "channels_first"):
  """Add a block's input to its residual branch.

  The input is first passed through a 1x1x1 convolution whose filter count
  equals the channel count of ``fx``, so both tensors agree on the channel
  axis before they are summed.

  Args:
    input_data: Tensor entering the residual block (the skip path).
    fx: Tensor produced by the residual branch.
    data_format: "channels_first" reads the channel count from axis 1 of
      ``fx``; any other value reads it from axis 4.

  Returns:
    The result of ``Add`` applied to the projected input and ``fx``.
  """
  # Position of the channel axis in ``fx`` for the requested layout.
  channel_axis = 1 if data_format == "channels_first" else 4
  n_channels = int(fx.shape[channel_axis])

  # The input and the residual may have different channel counts and could not
  # be added as they are, so a 1x1x1 convolution maps the input onto the
  # residual branch's filter count.
  projected = Convolution3D(
      n_channels,
      (1, 1, 1),
      strides=(1, 1, 1),
      padding='valid',
      data_format=data_format,
  )(input_data)

  # Sum the two paths.
  return Add()([projected, fx])

In [3]:
def Resblock3D(n_filters, strides=(1,1,1), data_format = "channels_first", axis=1):
  """Build a 3-D residual block and return it as a callable.

  The residual branch is Conv3D(3x3x3) -> BatchNorm -> ReLU -> Conv3D(3x3x3)
  -> BatchNorm. Its output is merged with the block input by ``Shortcut3D``.

  Only the first convolution uses ``strides``; the second always uses stride
  1. When any stride differs from 1, the skip path is subsampled by the same
  stride so that both tensors passed to ``Shortcut3D`` have equal spatial
  size. Previously both convolutions were strided while the skip path was
  not, so any non-unit stride produced tensors that could not be added.
  With the default ``strides=(1,1,1)`` the layers created are the same as
  before.

  Args:
    n_filters: Number of filters of both 3x3x3 convolutions.
    strides: Downsampling stride of the block, either an int or a sequence
      of three ints.
    data_format: "channels_first" or "channels_last", forwarded to every
      convolution/pooling layer and to ``Shortcut3D``.
    axis: Channel axis handed to ``BatchNormalization``.

  Returns:
    A function taking the block's input tensor and returning its output.
  """
  # Normalise the stride so it can be inspected element by element.
  if isinstance(strides, int):
    stride_tuple = (strides,) * 3
  else:
    stride_tuple = tuple(strides)
  downsample = any(step != 1 for step in stride_tuple)

  def f(input_data):
    # Residual branch: the first convolution carries the block's stride.
    fx = Convolution3D(n_filters, (3,3,3), strides=stride_tuple, kernel_initializer='he_normal', padding='same', data_format=data_format)(input_data)
    fx = BatchNormalization(axis=axis)(fx)
    fx = Activation('relu')(fx)
    # The second convolution keeps the resolution; striding it as well would
    # downsample the branch twice.
    fx = Convolution3D(n_filters, (3,3,3), strides=(1,1,1), kernel_initializer='he_normal', padding='same', data_format=data_format)(fx)
    fx = BatchNormalization(axis=axis)(fx)

    # Skip path: a 1x1x1 pooling window with the block's stride simply picks
    # every stride-th voxel, giving the same spatial size as the strided
    # 'same'-padded convolution above.
    shortcut_input = input_data
    if downsample:
      shortcut_input = MaxPooling3D(pool_size=(1,1,1), strides=stride_tuple, padding='same', data_format=data_format)(input_data)

    return Shortcut3D(shortcut_input, fx, data_format)

  return f

In [42]:
def Resnet3D(channel = 3, frame = 10, height = 64, width = 64, num_classes = 2, is_channels_first = True):
  """Assemble the 3-D ResNet classifier and return it uncompiled.

  Layout: a 7x7x7 convolution with 32 filters, batch normalisation, ReLU and
  a 3x3x3 max pooling with stride 2; three residual blocks with 64 filters;
  a max pooling with stride 2; three residual blocks with 128 filters;
  global average pooling; and a softmax ``Dense`` layer with ``num_classes``
  units.

  Args:
    channel: Number of channels of each input frame.
    frame: Number of frames per sample.
    height: Frame height.
    width: Frame width.
    num_classes: Number of units of the final softmax layer.
    is_channels_first: When true the input shape is
      (channel, frame, height, width) and the channel axis is 1; otherwise
      the input shape is (frame, height, width, channel) and the axis is -1.

  Returns:
    A ``Model`` mapping the input tensor to the softmax output.

  Side effects:
    Prints the shape of the input tensor.
  """
  # Resolve the layout-dependent settings in one place.
  if is_channels_first:
    data_format, axis = "channels_first", 1
    input_shape = (channel, frame, height, width)
  else:
    data_format, axis = "channels_last", -1
    input_shape = (frame, height, width, channel)
  input_data = Input(shape=input_shape)

  print(input_data.shape)

  def residual_stage(tensor, n_filters, n_blocks=3):
    # Chain ``n_blocks`` residual blocks that share the same filter count.
    for _ in range(n_blocks):
      tensor = Resblock3D(n_filters=n_filters, data_format = data_format, axis=axis)(tensor)
    return tensor

  # Stem: wide convolution followed by normalisation, ReLU and pooling.
  net = Convolution3D(32, (7,7,7), strides=(1,1,1), kernel_initializer='he_normal', padding='same', data_format=data_format)(input_data)
  net = BatchNormalization(axis=axis)(net)
  net = Activation('relu')(net)
  net = MaxPooling3D((3, 3, 3), strides=(2,2,2), padding='same', data_format=data_format)(net)

  # Two residual stages separated by a stride-2 pooling.
  net = residual_stage(net, 64)
  net = MaxPooling3D(strides=(2,2,2), data_format=data_format)(net)
  net = residual_stage(net, 128)

  # Classification head.
  net = GlobalAveragePooling3D(data_format=data_format)(net)
  net = Dense(num_classes, kernel_initializer='he_normal', activation='softmax')(net)

  return Model(inputs=input_data, outputs=net)

In [43]:
# Selects the axis layout of both the pickled dataset and the model input.
is_channels_first = False

# The dataset files for the two layouts differ only in this filename suffix,
# so one loading path serves both.
layout_suffix = 'channel_first' if is_channels_first else 'channel_last'

# SECURITY: pickle.load can execute arbitrary code while deserialising.
# Only load these files when they come from a trusted source.
with open('resnet3d_data_x_{}.pkl'.format(layout_suffix), 'rb') as fx:
    X = pickle.load(fx)

with open('resnet3d_data_y_{}.pkl'.format(layout_suffix), 'rb') as fy:
    Y = pickle.load(fy)

# Build the network with the same layout as the loaded data, then compile it
# with a default-configured Adam optimiser and categorical cross-entropy loss.
model = Resnet3D(channel=3, frame=10, height=64, width=64, num_classes=2, is_channels_first=is_channels_first)
adam = keras.optimizers.Adam()
model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])

(?, 10, 64, 64, 3)


In [40]:
# Positional hold-out split: the first 1200 samples are used for training and
# everything after them is held out.
X_train, X_test = X[:1200], X[1200:]
Y_train, Y_test = Y[:1200], Y[1200:]

# One-hot encode both label sets into 2 classes.
y_train, y_test = (keras.utils.to_categorical(labels, 2) for labels in (Y_train, Y_test))



In [41]:
# Train for 10 epochs with mini-batches of 2 samples, passing the held-out
# split as validation data; the value returned by fit is kept in `history`.
history = model.fit(
    x=X_train,
    y=y_train,
    validation_data=(X_test, y_test),
    epochs=10,
    batch_size=2,
    verbose=1,
)

Train on 1200 samples, validate on 417 samples
Epoch 1/10
 190/1200 [===>..........................] - ETA: 47:26 - loss: 0.8700 - acc: 0.5000

KeyboardInterrupt: 

In [34]:
# Print the layer-by-layer table of the model (output shapes, parameter
# counts and layer connections).
model.summary()

Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 10, 64, 64,  0                                            
__________________________________________________________________________________________________
conv3d_57 (Conv3D)              (None, 10, 64, 64, 3 32960       input_4[0][0]                    
__________________________________________________________________________________________________
batch_normalization_39 (BatchNo (None, 10, 64, 64, 3 128         conv3d_57[0][0]                  
__________________________________________________________________________________________________
activation_21 (Activation)      (None, 10, 64, 64, 3 0           batch_normalization_39[0][0]     
____________________________________________________________________________________________