In [2]:
import tensorflow as tf
import numpy as np
import pathlib
tf.__version__

'2.0.0'

NiN使用1×1卷积层来替代全连接层，从而使空间信息能够自然传递到后面的层中去，下图对比了NiN同AlexNet和VGG等网络在结构上的主要区别

![Image text](http://zh.d2l.ai/_images/nin.svg)

In [7]:
def nin_block(num_chanels,padding,strides,kernel_size):
    """Assemble a single Network-in-Network block as a Keras ``Sequential``.

    The block zero-pads its input, runs one spatial convolution, and then runs
    two 1x1 convolutions. All three convolutions use ReLU and the same number
    of filters.

    Args:
        num_chanels: Number of filters used by every convolution in the block
            (the spelling is kept because callers may pass it by keyword).
        padding: Forwarded to ``ZeroPadding2D(padding=...)``.
        strides: Strides of the first (spatial) convolution.
        kernel_size: Kernel size of the first (spatial) convolution.

    Returns:
        A ``tf.keras.Sequential`` holding the four layers described above.
    """
    conv_1x1_kwargs = dict(kernel_size=1, activation="relu")
    return tf.keras.Sequential([
        tf.keras.layers.ZeroPadding2D(padding=padding),
        tf.keras.layers.Conv2D(
            num_chanels,
            kernel_size=kernel_size,
            strides=strides,
            activation="relu",
        ),
        # The two 1x1 convolutions act per pixel across channels.
        tf.keras.layers.Conv2D(num_chanels, **conv_1x1_kwargs),
        tf.keras.layers.Conv2D(num_chanels, **conv_1x1_kwargs),
    ])

In [21]:
# Full NiN network: three (NiN block -> max-pool) stages, dropout, a final
# NiN block with 10 output channels, then global average pooling and softmax.
model = tf.keras.Sequential()
model.add(tf.keras.layers.InputLayer(input_shape=(224, 224, 3)))
for _stage in (
    nin_block(96, padding=(0, 0), strides=(4, 4), kernel_size=11),
    tf.keras.layers.MaxPooling2D(pool_size=3, strides=2),
    nin_block(256, padding=(2, 2), strides=(1, 1), kernel_size=5),
    tf.keras.layers.MaxPooling2D(pool_size=3, strides=(2, 2)),
    nin_block(384, padding=(1, 1), strides=1, kernel_size=3),
    tf.keras.layers.MaxPooling2D(pool_size=3, strides=2),
    tf.keras.layers.Dropout(0.5),
    # The last block emits one channel per output class (10 here).
    nin_block(10, padding=(1, 1), strides=1, kernel_size=3),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Activation("softmax"),
):
    model.add(_stage)

In [22]:
# Print the per-layer output shapes and parameter counts of the assembled network.
model.summary()

Model: "sequential_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_14 (Sequential)   (None, 54, 54, 96)        53568     
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 26, 26, 96)        0         
_________________________________________________________________
sequential_15 (Sequential)   (None, 26, 26, 256)       746240    
_________________________________________________________________
max_pooling2d_9 (MaxPooling2 (None, 12, 12, 256)       0         
_________________________________________________________________
sequential_16 (Sequential)   (None, 12, 12, 384)       1180800   
_________________________________________________________________
max_pooling2d_10 (MaxPooling (None, 5, 5, 384)         0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 5, 5, 384)       

In [23]:
# One random 224x224 RGB image with a leading batch axis, used to smoke-test
# the forward pass.
x = tf.expand_dims(tf.random.normal(shape=(224, 224, 3)), axis=0)
x.shape

TensorShape([1, 224, 224, 3])

In [24]:
# Forward pass on the dummy batch; the result is one softmax row of 10 scores.
model(x)

<tf.Tensor: id=9524, shape=(1, 10), dtype=float32, numpy=
array([[0.09955324, 0.10052283, 0.09955355, 0.09955144, 0.10119517,
        0.10093878, 0.09955144, 0.0996519 , 0.0997114 , 0.09977026]],
      dtype=float32)>

In [25]:
# Target image size used by decode_img when cropping/padding each picture.
IMG_WIDTH = 224
IMG_HEIGHT = 224
# Number of examples per batch handed to prepare_for_training.
BATCH_SIZE = 32

In [26]:
# Download (and untar) the flower photos archive, then list every image file
# found one directory level below the extracted root.
DATA_URL = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file(fname="flower_photos",origin=DATA_URL,untar=True)
import pathlib
data_dir = pathlib.Path(data_dir)
ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))
# Class names are the sub-directory names. They are sorted so that the
# class -> index assignment is identical on every run and every filesystem
# (glob order is not guaranteed), and restricted to directories so that stray
# files next to LICENSE.txt cannot turn into bogus classes.
class_lable = np.array(sorted(
    item.name
    for item in data_dir.glob("*")
    if item.is_dir() and item.name != "LICENSE.txt"
))
# Map each class name to its integer id.
lable_dic = dict(zip(class_lable, np.arange(len(class_lable))))
@tf.function
def get_lable(x):
    """Return a boolean vector marking which class an image path belongs to.

    The class is taken from the name of the directory that directly contains
    the file (the second-to-last path component).

    Args:
        x: Scalar string tensor holding the path of one image file.

    Returns:
        The element-wise comparison of that directory name with
        ``class_lable``: one boolean per known class.
    """
    # Normalise Windows back-slashes to "/" first, so the same split works
    # whether the paths were produced on Windows or on a POSIX system.
    unified_path = tf.strings.regex_replace(x, r"\\", "/")
    path_parts = tf.strings.split(unified_path, sep="/")
    return path_parts[-2] == class_lable

def decode_img(img):
    """Decode raw image bytes into a fixed-size, randomly flipped float image.

    Args:
        img: Encoded image file contents (as returned by ``tf.io.read_file``).

    Returns:
        A float32 image tensor with 3 channels, centrally cropped or
        zero-padded to ``IMG_HEIGHT`` x ``IMG_WIDTH`` and randomly flipped
        horizontally and vertically.
    """
    img = tf.image.decode_image(img,channels=3)
    img = tf.image.convert_image_dtype(img, tf.float32)
    # resize_with_crop_or_pad expects (target_height, target_width); pass the
    # constants by keyword so height and width can never be swapped.
    img = tf.image.resize_with_crop_or_pad(
        img, target_height=IMG_HEIGHT, target_width=IMG_WIDTH
    )
    # NOTE: these flips run inside the dataset map stage; if the mapped
    # dataset is cached afterwards, the flipped images are what gets stored.
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_flip_up_down(img)
    return img

def img_map_fun(x):
    """Convert one image file path into an ``(image, label)`` pair.

    Args:
        x: Path of the image file.

    Returns:
        A tuple ``(image, label)`` where ``image`` is the result of
        ``decode_img`` on the file contents and ``label`` is the result of
        ``get_lable`` on the path.
    """
    class_mask = get_lable(x)
    pixels = decode_img(tf.io.read_file(x))
    return pixels, class_mask

# Integer class ids, positionally aligned with the entries of `class_lable`
# (derived from its length instead of hard-coding five classes).
lable = tf.range(len(class_lable), dtype=tf.int32)
@tf.function
def map_lable_fun(x,y):
    """Replace the boolean class mask of an example with its integer class id.

    Args:
        x: The image tensor; passed through unchanged.
        y: Boolean vector with one entry per class, as produced by ``get_lable``.

    Returns:
        A tuple ``(x, class_id)`` where ``class_id`` is the first id in
        ``lable`` whose mask entry is true.
    """
    # Explicit boolean masking (what `lable[y]` does implicitly); the leftover
    # trace-time debug print of `y` has been removed.
    return x, tf.boolean_mask(lable, y)[0]

# Path -> (image, boolean class mask); decoding is parallelised by tf.data.
ds = ds.map(map_func=img_map_fun,num_parallel_calls=tf.data.experimental.AUTOTUNE)
# (image, boolean class mask) -> (image, integer class id).
ds = ds.map(map_func=map_lable_fun)

def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000,batch_size = 256):
    """Cache, shuffle, repeat, batch and prefetch a dataset for training.

    Args:
        ds: The dataset of individual examples.
        cache: ``True`` caches the dataset in memory, a string caches it to
            the file with that name, and any falsy value disables caching.
        shuffle_buffer_size: Buffer size passed to ``ds.shuffle``.
        batch_size: Number of examples per batch.

    Returns:
        The transformed dataset. It repeats forever, so callers must bound
        iteration themselves (e.g. with ``steps_per_epoch``).
    """
    # This is a small dataset, only load it once, and keep it in memory.
    # Use `.cache(filename)` to cache preprocessing work for datasets that
    # don't fit in memory.
    # The in-memory branch belongs to the *inner* `if`: previously the `else`
    # was attached to `if cache:`, so cache=True cached nothing and
    # cache=False cached everything.
    if cache:
        if isinstance(cache, str):
            ds = ds.cache(cache)
        else:
            ds = ds.cache()

    ds = ds.shuffle(buffer_size=shuffle_buffer_size)

    # Repeat forever
    ds = ds.repeat()

    ds = ds.batch(batch_size)

    # `prefetch` lets the dataset fetch batches in the background while the
    # model is training.
    ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    return ds

# Build the training pipeline with the notebook-wide batch size.
ds = prepare_for_training(ds,batch_size = BATCH_SIZE)

# Pull a single batch to sanity-check the pipeline output.
image_batch, label_batch = next(iter(ds))

Tensor("y:0", shape=(5,), dtype=bool)


In [27]:
# Forward pass on one real batch: one row of 10 softmax scores per image.
model(image_batch)

<tf.Tensor: id=10110, shape=(32, 10), dtype=float32, numpy=
array([[0.09979424, 0.10017834, 0.09979573, 0.09978852, 0.10048943,
        0.10053866, 0.09978912, 0.09985548, 0.09991376, 0.09985679],
       [0.09982231, 0.10014329, 0.09982099, 0.09981655, 0.10037722,
        0.10054319, 0.09981655, 0.09986108, 0.09994427, 0.09985459],
       [0.09987631, 0.1000984 , 0.09986973, 0.09986672, 0.10026178,
        0.10038055, 0.09986751, 0.09990907, 0.09998167, 0.09988821],
       [0.09986027, 0.10012738, 0.09985249, 0.09985249, 0.10023209,
        0.10041685, 0.09986018, 0.09992839, 0.09997012, 0.09989975],
       [0.09978769, 0.10018132, 0.09979094, 0.09978241, 0.10048405,
        0.10057083, 0.09978608, 0.09985456, 0.09992738, 0.09983474],
       [0.09983917, 0.10015063, 0.09983138, 0.09982822, 0.10036184,
        0.10044729, 0.09982993, 0.09988883, 0.0999561 , 0.09986661],
       [0.09986003, 0.10010026, 0.0998689 , 0.09985211, 0.10031845,
        0.10041818, 0.09985344, 0.09989356, 0.0999

In [28]:
# No optimizer is passed, so compile() falls back to its default
# (NOTE(review): 'rmsprop' in TF 2.0 -- confirm for the installed version).
# Labels are integer class ids, hence the sparse categorical cross-entropy loss.
model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy,metrics=["acc"])
# `ds` repeats forever, so steps_per_epoch is what ends each epoch.
model.fit(ds,steps_per_epoch=50,epochs=5)

Train for 50 steps
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x4bc9f668>