<a href="https://colab.research.google.com/github/miinkang/Taeguekgi_Classifier/blob/main/Taegeukgi_UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install TensorFlow into the notebook runtime.
# NOTE(review): the standalone `tensorflow-gpu` package on PyPI appears to have
# been retired in favour of `tensorflow` — confirm this install still works
# (or is still needed) on the target runtime.
!pip install tensorflow-gpu

In [None]:
# Flag recording whether the notebook runs inside Google Colab.
use_colab = True
# Guard against the flag being set to something other than a boolean value.
assert use_colab in (True, False)

In [None]:
# Mount Google Drive at /content/drive so that files stored under it
# (the dataset directories used below) are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
from __future__ import absolute_import, division
from __future__ import print_function, unicode_literals

import os
import time
import shutil
import functools

import numpy as np
# import matplotlib.pyplot as plt
# %matplotlib inline
# import matplotlib as mpl
# mpl.rcParams['axes.grid'] = False
# mpl.rcParams['figure.figsize'] = (12,12)

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
import matplotlib.image as mpimg
import pandas as pd
from PIL import Image
from IPython.display import clear_output

import tensorflow as tf
import tensorflow_addons as tfa
print(tf.__version__)

from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import models
from tensorflow.keras import preprocessing



In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# One generator per split. Each one only multiplies pixel values by 1/255;
# no augmentation is configured.
train = ImageDataGenerator(rescale=1 / 255)
val = ImageDataGenerator(rescale=1 / 255)

In [None]:
# Directory iterators for the training and validation splits. Images are
# resized to 200x200 and served in batches of 32 with binary class labels
# (class_mode='binary').
train_dataset = train.flow_from_directory('/content/drive/MyDrive/[바로알리기공모전]3456/data/train',
                                          target_size=(200, 200),
                                          batch_size=32,
                                          shuffle=True,
                                          class_mode='binary')

# Use the dedicated `val` generator for the validation split (the `train`
# generator was used here before, leaving `val` unused).
# NOTE(review): shuffle=True is kept as-is; if predictions are later compared
# against `val_dataset.classes`, the shuffled order will not line up — confirm.
val_dataset = val.flow_from_directory('/content/drive/MyDrive/[바로알리기공모전]3456/data/val',
                                      target_size=(200, 200),
                                      batch_size=32,
                                      shuffle=True,
                                      class_mode='binary')

In [None]:
# print(train_dataset[0].shape, val_dataset[0].shape)

In [None]:
class Conv(tf.keras.Model):
    """Conv2D -> BatchNormalization -> ReLU unit.

    Args:
        num_filters: Number of output channels of the convolution.
        kernel_size: Kernel size passed to ``Conv2D`` (padding is 'same').
        input_shape: Not used by this class; kept so existing callers that
            pass it keep working.
    """

    def __init__(self, num_filters, kernel_size, input_shape=(200, 200, 3)):
        super(Conv, self).__init__()
        self.conv_1 = layers.Conv2D(num_filters, kernel_size, padding='same')
        self.bn_1 = layers.BatchNormalization()
        self.relu = layers.ReLU()
        # Created but not applied in call().
        self.dropout_1 = layers.Dropout(0.5)

    def call(self, inputs, training=True):
        """Apply convolution, batch normalisation and ReLU to `inputs`.

        Args:
            inputs: Input tensor.
            training: Forwarded to batch normalisation so that it behaves
                differently during training and inference instead of always
                running in training mode.

        Returns:
            The activated feature map.
        """
        x = self.conv_1(inputs)
        # Honour the caller's flag; hard-coding True here would make
        # validation/inference use per-batch statistics as well.
        x = self.bn_1(x, training=training)
        return self.relu(x)

In [None]:
class ConvBlock(tf.keras.Model):
    """Two stacked `Conv` units (3x3 kernels) with the same filter count.

    Args:
        num_filters: Number of output channels of each `Conv` unit.
    """

    def __init__(self, num_filters):
        super(ConvBlock, self).__init__()
        self.conv1 = Conv(num_filters, kernel_size=3)
        self.conv2 = Conv(num_filters, kernel_size=3)

    def call(self, inputs, training=True):
        """Run `inputs` through both conv units, forwarding `training`."""
        encoder = self.conv1(inputs, training=training)
        encoder = self.conv2(encoder, training=training)
        return encoder
    
class ConvBlock_R(tf.keras.Model):
    """Two stacked `Conv` units (3x3 kernels), used on the decoder side.

    Args:
        num_filters: Number of output channels of each `Conv` unit.
    """

    def __init__(self, num_filters):
        super(ConvBlock_R, self).__init__()
        self.conv1 = Conv(num_filters, kernel_size=3)
        self.conv2 = Conv(num_filters, kernel_size=3)

    def call(self, inputs, training=True):
        """Run `inputs` through both conv units, forwarding `training`."""
        decoder = self.conv1(inputs, training=training)
        decoder = self.conv2(decoder, training=training)
        return decoder


class EncoderBlock(tf.keras.Model):
    """`ConvBlock` followed by max pooling.

    Args:
        num_filters: Number of output channels of the conv block.
    """

    def __init__(self, num_filters):
        super(EncoderBlock, self).__init__()
        self.conv_block = ConvBlock(num_filters)
        self.encoder_pool = layers.MaxPooling2D()

    def call(self, inputs, training=True):
        """Return ``(pooled, features)`` for `inputs`.

        `features` is the conv-block output before pooling (used by the
        decoder as the skip connection); `pooled` is its max-pooled version.
        """
        encoder = self.conv_block(inputs, training=training)
        encoder_pool = self.encoder_pool(encoder)
        return encoder_pool, encoder


class DecoderBlock(tf.keras.Model):
    """Upsample, merge with a skip tensor, then refine.

    Pipeline: Conv2DTranspose (stride 2) -> BatchNormalization -> ReLU ->
    concat with the skip tensor -> `ConvBlock_R`.

    Args:
        num_filters: Number of output channels of the transposed convolution
            and of the refining conv block.
    """

    def __init__(self, num_filters):
        super(DecoderBlock, self).__init__()
        self.convT = layers.Conv2DTranspose(num_filters, kernel_size=3, strides=2, padding='same')
        self.bn = layers.BatchNormalization()
        self.conv_block_r = ConvBlock_R(num_filters)
        self.relu = layers.ReLU()
        # Created but not applied in call().
        self.dropout_1 = layers.Dropout(0.5)

    def call(self, input_tensor, concat_tensor, training=True):
        """Decode `input_tensor` using `concat_tensor` as the skip connection.

        Args:
            input_tensor: Lower-resolution tensor to upsample.
            concat_tensor: Encoder tensor concatenated on the channel axis.
            training: Forwarded to batch normalisation and the conv block.

        Returns:
            The refined decoder feature map.
        """
        decoder = self.convT(input_tensor)
        decoder = self.bn(decoder, training=training)
        decoder = self.relu(decoder)

        # Max pooling rounds odd heights/widths down, so the 2x upsampled
        # tensor can be one pixel smaller than the skip tensor (e.g. 24 vs 25),
        # which would make the concat fail. Resize only when sizes differ.
        skip_hw = concat_tensor.shape[1:3]
        if skip_hw.is_fully_defined():
            if decoder.shape[1:3].as_list() != skip_hw.as_list():
                decoder = tf.image.resize(decoder, skip_hw.as_list(), method='nearest')
        else:
            # Spatial size only known at run time: always match dynamically.
            decoder = tf.image.resize(decoder, tf.shape(concat_tensor)[1:3], method='nearest')

        decoder = tf.concat([concat_tensor, decoder], -1)
        return self.conv_block_r(decoder, training=training)

In [None]:
class UNet(tf.keras.Model):
    """U-Net style encoder/decoder.

    Four encoder stages (32, 64, 128, 256 filters), a 512-filter centre
    block, four decoder stages (256, 128, 64, 32 filters) that each consume
    the matching encoder output as a skip connection, and a final 1x1
    convolution with a single sigmoid-activated channel.
    """

    def __init__(self):
        super(UNet, self).__init__()
        # Not consumed by call(); kept so the attribute remains available.
        self.encoder_input = tf.keras.layers.Input((200, 200, 3))
        self.encoder_block1 = EncoderBlock(32)
        self.encoder_block2 = EncoderBlock(64)
        self.encoder_block3 = EncoderBlock(128)
        self.encoder_block4 = EncoderBlock(256)

        self.center = ConvBlock(512)

        self.decoder_block4 = DecoderBlock(256)
        self.decoder_block3 = DecoderBlock(128)
        self.decoder_block2 = DecoderBlock(64)
        self.decoder_block1 = DecoderBlock(32)

        # No padding='same' here, so the kernel size is 1 (a 1x1 conv).
        # A 1x1 conv plays the same role as a dense layer.
        self.output_conv = layers.Conv2D(1, 1, activation='sigmoid')

    def call(self, inputs, training=True):
        """Run the full encoder/decoder and return the sigmoid output map.

        Args:
            inputs: Batch of input images.
            training: Forwarded to every encoder, centre and decoder block.

        Returns:
            Output of the final 1x1 sigmoid convolution (one channel).
        """
        # Each encoder returns the pooled tensor (fed to the next stage) and
        # the pre-pool tensor (kept for the decoder's concat).
        encoder1_pool, encoder1 = self.encoder_block1(inputs, training=training)
        encoder2_pool, encoder2 = self.encoder_block2(encoder1_pool, training=training)
        encoder3_pool, encoder3 = self.encoder_block3(encoder2_pool, training=training)
        encoder4_pool, encoder4 = self.encoder_block4(encoder3_pool, training=training)

        center = self.center(encoder4_pool, training=training)

        decoder4 = self.decoder_block4(center, encoder4, training=training)
        decoder3 = self.decoder_block3(decoder4, encoder3, training=training)
        decoder2 = self.decoder_block2(decoder3, encoder2, training=training)
        decoder1 = self.decoder_block1(decoder2, encoder1, training=training)

        outputs = self.output_conv(decoder1)

        return outputs

In [None]:
# Training hyper-parameters.
optimizer = tf.keras.optimizers.Adam(0.001)
# compile() needs a loss instance (or function). Passing the class itself
# would make Keras invoke the constructor with (y_true, y_pred).
loss = tf.keras.losses.BinaryCrossentropy()
max_epochs = 50
# NOTE(review): this value is only used to derive steps_per_epoch in fit();
# the directory iterators are created with batch_size=32 — confirm which one
# is intended.
batch_size = 128

In [None]:
# Instantiate the U-Net defined in this notebook.
model = UNet()

In [None]:
# Attach the optimizer, the loss and an accuracy metric to the model.
model.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=['acc'],
)

In [None]:
# Checkpoint callback: writes weights only, and only when `val_loss` improves.
# NOTE(review): `checkpoint_dir` is passed as the checkpoint *filepath* itself,
# not as a folder to put files in — confirm this is the intended location.
checkpoint_dir = 'drive/MyDrive'
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_dir,
    monitor='val_loss',
    mode='auto',
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)

In [None]:
# Cosine decay starting at 1e-4 over `max_epochs` decay steps. The scheduler
# callback queries it with the epoch index, so one decay step is one epoch.
cos_decay = tf.keras.experimental.CosineDecay(0.0001,
                                              max_epochs)


def _cosine_lr(epoch, lr=None):
    """Return the cosine-decayed learning rate for `epoch` as a Python float.

    LearningRateScheduler expects a plain function of the epoch index (and
    optionally the current rate) rather than a schedule object; `lr` is
    accepted for that signature and ignored.
    """
    return float(cos_decay(epoch))


lr_callback = tf.keras.callbacks.LearningRateScheduler(_cosine_lr, verbose=1)

In [None]:
# A subclassed model has no weights until it has been called once, and
# summary() refuses to run before that. Push one dummy batch with the
# generators' image size (200x200 RGB) through the model first.
model(tf.zeros((1, 200, 200, 3)), training=False)
model.summary()

In [None]:
# Train with checkpointing and the per-epoch learning-rate callback.
# `cos_decay` is a learning-rate schedule object, not a Keras callback, so it
# must not appear in `callbacks`; `lr_callback` is what applies it.
# NOTE(review): steps_per_epoch is derived from `batch_size` (128) while the
# iterators yield batches of 32 — confirm the intended epoch length.
# NOTE(review): the iterators use class_mode='binary' while UNet ends in a
# 1x1 sigmoid convolution over the feature map — confirm targets and model
# output have compatible shapes for BinaryCrossentropy.
history = model.fit(train_dataset,
                    epochs=max_epochs,
                    steps_per_epoch=1000//batch_size,
                    validation_data=val_dataset,
                    callbacks=[cp_callback, lr_callback]
                    )