In [None]:
# 모듈 임포트
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

## 데이터셋 준비

In [None]:
# Load the three cassava splits in a single tfds.load call; passing a list of
# split names returns the datasets in the same order.
train_datasets, valid_datasets, test_datasets = tfds.load(
    'cassava',
    split=['train', 'validation', 'test'],
)

def preprocessing(data):
    """Turn one dataset example into an ``(image, label)`` training pair.

    Args:
        data: Mapping with an ``'image'`` entry and a ``'label'`` entry.

    Returns:
        Tuple ``(image, label)`` where ``image`` has been cast to float32,
        divided by 255.0 and resized to 224x224, and ``label`` is
        ``data['label']`` unchanged.
    """
    # NOTE(review): the VGG16 ImageNet weights used later are normally paired
    # with tf.keras.applications.vgg16.preprocess_input rather than [0, 1]
    # scaling -- confirm that this scaling is intended.
    pixels = tf.cast(data['image'], dtype=tf.float32)
    scaled = pixels / 255.0
    resized = tf.image.resize(scaled, size=(224, 224))
    return resized, data['label']

# Training pipeline: preprocess, shuffle with a 1000-element buffer, then
# group into batches of 128.
train_data = (
    train_datasets
    .map(preprocessing)
    .shuffle(1000)
    .batch(128)
)

# Validation and test pipelines: preprocess and batch by 32, no shuffling.
valid_data = (
    valid_datasets
    .map(preprocessing)
    .batch(32)
)
test_data = (
    test_datasets
    .map(preprocessing)
    .batch(32)
)

[1mDownloading and preparing dataset cassava/0.1.0 (download: 1.26 GiB, generated: Unknown size, total: 1.26 GiB) to /root/tensorflow_datasets/cassava/0.1.0...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…









HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/cassava/0.1.0.incomplete1QOZI2/cassava-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=5656.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/cassava/0.1.0.incomplete1QOZI2/cassava-test.tfrecord


HBox(children=(FloatProgress(value=0.0, max=1885.0), HTML(value='')))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/cassava/0.1.0.incomplete1QOZI2/cassava-validation.tfrecord


HBox(children=(FloatProgress(value=0.0, max=1889.0), HTML(value='')))

[1mDataset cassava downloaded and prepared to /root/tensorflow_datasets/cassava/0.1.0. Subsequent calls will reuse this data.[0m


## 전이 모델

In [None]:
# VGG16 convolutional base: ImageNet weights, no fully connected head,
# fixed 224x224 RGB input.
# NOTE(review): this module-level variable is not referenced by the cells
# visible below (the Transfer class builds its own VGG16) -- confirm whether
# it is still needed.
transfer_model = tf.keras.applications.VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(224, 224, 3),
)
# Freeze the pretrained VGG16 weights.
transfer_model.trainable = False

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5


## 모델링

### Transfer Layer 구현

In [None]:

class Transfer(tf.keras.models.Model):
    """Frozen VGG16 feature extractor wrapped as a Keras model.

    Builds VGG16 with ImageNet weights, without the fully connected top and
    with a (224, 224, 3) input shape, then marks it non-trainable.
    """

    def __init__(self):
        super().__init__()
        backbone = tf.keras.applications.VGG16(
            include_top=False,
            weights='imagenet',
            input_shape=(224, 224, 3),
        )
        # Keep the pretrained weights fixed during training.
        backbone.trainable = False
        self.transfer_model = backbone

    def call(self, input_):
        """Run ``input_`` through the VGG16 base and return its output."""
        return self.transfer_model(input_)

### 모델 정의

In [None]:
class Mymodel(tf.keras.models.Model):
    """Cassava classifier: frozen VGG16 base followed by a small dense head.

    The forward pass is Transfer (VGG16 features) -> Flatten ->
    Dense(32, relu) -> Dense(num_classes, softmax).

    Args:
        num_classes: Number of units in the softmax output layer. Defaults
            to 5, the number of label classes in the TFDS cassava dataset.
            The output layer previously had 32 units, matching the hidden
            layer rather than the dataset.
    """

    def __init__(self, num_classes=5):
        super(Mymodel,self).__init__()
        self.vgg16 = Transfer()
        self.flatten = tf.keras.layers.Flatten()
        # Hidden layer of the classification head.
        self.Dense = tf.keras.layers.Dense(32,activation='relu')
        # One softmax unit per class, so predictions line up with the integer
        # labels consumed by sparse categorical cross-entropy.
        self.output_ = tf.keras.layers.Dense(num_classes,activation='softmax')

    def call(self,input_):
        """Return class probabilities for a batch of images."""
        x = self.vgg16(input_)
        x = self.flatten(x)
        x = self.Dense(x)
        x = self.output_(x)
        return x

In [None]:
# Instantiate the classifier.
model = Mymodel()

# Call the model once on a symbolic 224x224 RGB input before asking for the
# summary.
input_ = tf.keras.layers.Input(
    shape=(224, 224, 3),
)
model(input_)

model.summary()

Model: "mymodel"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
transfer (Transfer)          multiple                  14714688  
_________________________________________________________________
flatten (Flatten)            multiple                  0         
_________________________________________________________________
dense (Dense)                multiple                  802848    
_________________________________________________________________
dense_1 (Dense)              multiple                  1056      
Total params: 15,518,592
Trainable params: 803,904
Non-trainable params: 14,714,688
_________________________________________________________________


## 모델 컴파일, 체크포인트, 학습

In [None]:
# Compile
# Use the `learning_rate` keyword: `lr` is a deprecated alias that newer
# Keras releases no longer accept.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
# Labels are integer class ids, hence the sparse categorical cross-entropy.
model.compile(optimizer=optimizer,loss='sparse_categorical_crossentropy',metrics=['acc'])

In [None]:
# Checkpoint
# Save weights only, and only when the monitored validation loss improves.
checkpoint_path = 'my_checkpoint.ckpt'
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

In [None]:
# Training
# Fit for 20 epochs, validating on valid_data and checkpointing via the
# ModelCheckpoint callback.
model.fit(
    x=train_data,
    epochs=20,
    validation_data=valid_data,
    callbacks=[checkpoint],
)

Epoch 1/20

Epoch 00001: val_loss improved from inf to 2.17591, saving model to my_checkpoint.ckpt
Epoch 2/20

Epoch 00002: val_loss improved from 2.17591 to 1.88055, saving model to my_checkpoint.ckpt
Epoch 3/20

Epoch 00003: val_loss improved from 1.88055 to 1.63836, saving model to my_checkpoint.ckpt
Epoch 4/20

Epoch 00004: val_loss improved from 1.63836 to 1.40186, saving model to my_checkpoint.ckpt
Epoch 5/20

Epoch 00005: val_loss improved from 1.40186 to 1.21300, saving model to my_checkpoint.ckpt
Epoch 6/20

Epoch 00006: val_loss improved from 1.21300 to 1.02730, saving model to my_checkpoint.ckpt
Epoch 7/20

Epoch 00007: val_loss improved from 1.02730 to 0.90206, saving model to my_checkpoint.ckpt
Epoch 8/20

Epoch 00008: val_loss improved from 0.90206 to 0.84252, saving model to my_checkpoint.ckpt
Epoch 9/20

Epoch 00009: val_loss improved from 0.84252 to 0.81911, saving model to my_checkpoint.ckpt
Epoch 10/20

Epoch 00010: val_loss improved from 0.81911 to 0.79731, saving m

<tensorflow.python.keras.callbacks.History at 0x7fb5eed78910>

In [None]:
# Restore the weights saved at checkpoint_path by the checkpoint callback.
model.load_weights(filepath=checkpoint_path)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fb5eed61890>

In [None]:
# Test: report loss and accuracy on the held-out test split.
model.evaluate(x=test_data)



[0.7520942091941833, 0.7326260209083557]