## Training Logic

In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

In [2]:
np.random.seed(7777)
tf.random.set_seed(7777)

In [3]:
class Cifar10DataLoader():
    
    def __init__(self):
        (self.train_x, self.train_y),(self.test_x, self.test_y) = tf.keras.datasets.cifar10.load_data()
        self.input_shape = self.train_x.shape[1:]
    
    def scale(self, x):
        return (x / 255.0).astype(np.float32)
    
    def preprocess_dataset(self, dataset):
        feature, target = dataset
        
        # scale
        scaled_x = np.array([self.scale(x) for x in feature])
    
        # label encoding
        ohe_y = np.array([tf.keras.utils.to_categorical(y, num_classes=10) for y in target])
        
        return scaled_x, ohe_y.squeeze(1)
    
    def get_train_dataset(self):
        return self.preprocess_dataset((self.train_x, self.train_y))
    
    def get_test_dataset(self):
        return self.preprocess_dataset((self.test_x, self.test_y))
    
cifar10_loader = Cifar10DataLoader()
train_x, train_y = cifar10_loader.get_train_dataset()

print(train_x.shape, train_x.dtype)
print(train_y.shape, train_y.dtype)

test_x, test_y = cifar10_loader.get_test_dataset()

print(test_x.shape, test_x.dtype)
print(test_y.shape, test_y.dtype)

(50000, 32, 32, 3) float32
(50000, 10) float32
(10000, 32, 32, 3) float32
(10000, 10) float32


In [4]:
from tensorflow.keras.layers import Input, Conv2D, MaxPool2D, Flatten, Dense, Add

def build_resnet(input_shape):
    inputs = Input(input_shape)
    
    net = Conv2D(32, kernel_size=3, strides=2, padding='same', activation='relu')(inputs)
    net = MaxPool2D()(net)
    
    net1 = Conv2D(64, kernel_size=1, padding='same', activation='relu')(net)
    net2 = Conv2D(64, kernel_size=3, padding='same', activation='relu')(net1) 
    net3 = Conv2D(64, kernel_size=1, padding='same', activation='relu')(net2)
    
    net1_1 = Conv2D(64, kernel_size=1, padding='same', activation='relu')(net)
    net = Add()([net1_1, net3])
    
    net1 = Conv2D(64, kernel_size=1, padding='same', activation='relu')(net)
    net2 = Conv2D(64, kernel_size=3, padding='same', activation='relu')(net1) 
    net3 = Conv2D(64, kernel_size=1, padding='same', activation='relu')(net2)
    
    net = Add()([net, net3])
    
    net = MaxPool2D()(net)
    
    net = Flatten()(net)
    net = Dense(10, activation='softmax')(net)
    
    model = tf.keras.Model(inputs=inputs, outputs=net, name ='resnet')
    
    return model

model= build_resnet((32,32,3))
model.summary()

Model: "resnet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 16, 16, 32)   896         ['input_1[0][0]']                
                                                                                                  
 max_pooling2d (MaxPooling2D)   (None, 8, 8, 32)     0           ['conv2d[0][0]']                 
                                                                                                  
 conv2d_1 (Conv2D)              (None, 8, 8, 64)     2112        ['max_pooling2d[0][0]']          
                                                                                             

### 학습하는 과정 직접 만들기!

In [5]:
lr = 0.03
batch_size = 64

In [6]:
opt = tf.keras.optimizers.Adam(lr)
loss_fn = tf.keras.losses.categorical_crossentropy
train_loss = tf.keras.metrics.Mean()
train_acc = tf.keras.metrics.CategoricalAccuracy()

In [7]:
def train_step(x, y):
    with tf.GradientTape() as tape:
        pred = model(x)
        loss = loss_fn(y, pred)
        
    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    
    train_loss(y)
    train_acc(y, pred)

In [8]:
for epoch in range(1):
    for i in range(train_x.shape[0] // batch_size):
        idx = i * batch_size
        x, y = train_x[idx:idx+batch_size], train_y[idx:idx+batch_size]
        train_step(x, y)
        print('{} / {}'.format(i, train_x.shape[0] // batch_size), end='\r')
        fmt = 'epoch {} loss : {} acc : {}'
        print(fmt.format(
            epoch+1,
            train_loss.result(),
            train_acc.result()
        ))
        train_loss.reset_states()
        train_acc.reset_states()

epoch 1 loss : 0.10000000149011612 acc : 0.109375
epoch 1 loss : 0.10000000149011612 acc : 0.078125
epoch 1 loss : 0.10000000149011612 acc : 0.09375
epoch 1 loss : 0.10000000149011612 acc : 0.09375
epoch 1 loss : 0.10000000149011612 acc : 0.1875
epoch 1 loss : 0.10000000149011612 acc : 0.0625
epoch 1 loss : 0.10000000149011612 acc : 0.171875
epoch 1 loss : 0.10000000149011612 acc : 0.140625
epoch 1 loss : 0.10000000149011612 acc : 0.171875
epoch 1 loss : 0.10000000149011612 acc : 0.078125
epoch 1 loss : 0.10000000149011612 acc : 0.015625
epoch 1 loss : 0.10000000149011612 acc : 0.125
epoch 1 loss : 0.10000000149011612 acc : 0.140625
epoch 1 loss : 0.10000000149011612 acc : 0.09375
epoch 1 loss : 0.10000000149011612 acc : 0.078125
epoch 1 loss : 0.10000000149011612 acc : 0.09375
epoch 1 loss : 0.10000000149011612 acc : 0.078125
epoch 1 loss : 0.10000000149011612 acc : 0.1875
epoch 1 loss : 0.10000000149011612 acc : 0.15625
epoch 1 loss : 0.10000000149011612 acc : 0.140625
epoch 1 loss :

epoch 1 loss : 0.10000000149011612 acc : 0.28125
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.3125
epoch 1 loss : 0.10000000149011612 acc : 0.359375
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.171875
epoch 1 loss : 0.10000000149011612 acc : 0.234375
epoch 1 loss : 0.10000000149011612 acc : 0.25
epoch 1 loss : 0.10000000149011612 acc : 0.1875
epoch 1 loss : 0.10000000149011612 acc : 0.234375
epoch 1 loss : 0.10000000149011612 acc : 0.21875
epoch 1 loss : 0.10000000149011612 acc : 0.28125
epoch 1 loss : 0.10000000149011612 acc : 0.203125
epoch 1 loss : 0.10000000149011612 acc : 0.359375
epoch 1 loss : 0.10000000149011612 acc : 0.3125
epoch 1 loss : 0.10000000149011612 acc : 0.265625
epoch 1 loss : 0.10000000149011612 acc : 0.3125
epoch 1 loss : 0.10000000149011612 acc : 0.140625
epoch 1 loss : 0.10000000149011612 acc : 0.265625
epoch 1 loss : 0.10000000149011612 acc : 0.234375
epoch 1 loss : 0.

epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.359375
epoch 1 loss : 0.10000000149011612 acc : 0.359375
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.25
epoch 1 loss : 0.10000000149011612 acc : 0.3125
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.453125
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.375
epoch 1 loss : 0.10000000149011612 acc : 0.421875
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.359375
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.28125
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 

epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.3125
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.28125
epoch 1 loss : 0.10000000149011612 acc : 0.359375
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.375
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.375
epoch 1 loss : 0.10000000149011612 acc : 0.265625
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.21875
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.421875
epoch 1 loss :

epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.421875
epoch 1 loss : 0.10000000149011612 acc : 0.3125
epoch 1 loss : 0.10000000149011612 acc : 0.359375
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.265625
epoch 1 loss : 0.10000000149011612 acc : 0.5625
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.265625
epoch 1 

In [9]:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        pred = model(x)
        loss = loss_fn(y, pred)
        
    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    
    train_loss(y)
    train_acc(y, pred)

In [10]:
for epoch in range(1):
    for i in range(train_x.shape[0] // batch_size):
        idx = i * batch_size
        x, y = train_x[idx:idx+batch_size], train_y[idx:idx+batch_size]
        train_step(x, y)
        print('{} / {}'.format(i, train_x.shape[0] // batch_size), end='\r')
        fmt = 'epoch {} loss : {} acc : {}'
        print(fmt.format(
            epoch+1,
            train_loss.result(),
            train_acc.result()
        ))
        train_loss.reset_states()
        train_acc.reset_states()

epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.375
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.421875
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.359375
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.375
epoch 1 loss : 0.10000000149011612 acc : 0.265625
epoch 1 loss : 0.10000000149011612 acc : 0.3125
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.25
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.375
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.5625
epoch 1 loss : 0.1000

epoch 1 loss : 0.10000000149011612 acc : 0.234375
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.3125
epoch 1 loss : 0.10000000149011612 acc : 0.25
epoch 1 loss : 0.10000000149011612 acc : 0.453125
epoch 1 loss : 0.10000000149011612 acc : 0.375
epoch 1 loss : 0.10000000149011612 acc : 0.375
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.359375
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.453125
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.375
epoch 1 loss : 0.10000000149011612 acc : 0.453125
epoch 1 loss : 0.10000000149011612 acc : 0.25
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.265625
epoch 1 loss : 0.100000

epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.453125
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.28125
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.453125
epoch 1 loss : 0.10000000149011612 acc : 0.421875
epoch 1 loss : 0.10000000149011612 acc : 0.375
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.265625
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.421875
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.25
epoch 1 loss : 0.10000000149011612 acc : 0.3125
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss :

epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.3125
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.265625
epoch 1 loss : 0.10000000149011612 acc : 0.359375
epoch 1 loss : 0.10000000149011612 acc : 0.28125
epoch 1 loss : 0.10000000149011612 acc : 0.421875
epoch 1 loss : 0.10000000149011612 acc : 0.3125
epoch 1 loss : 0.10000000149011612 acc : 0.375
epoch 1 loss : 0.10000000149011612 acc : 0.28125
epoch 1 loss : 0.10000000149011612 acc : 0.390625
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.484375
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.265625
epoch 1 loss : 0.10000000149011612 acc : 0.421875
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.375
epoch 1 loss : 0

epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.359375
epoch 1 loss : 0.10000000149011612 acc : 0.234375
epoch 1 loss : 0.10000000149011612 acc : 0.46875
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.375
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.296875
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.421875
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss : 0.10000000149011612 acc : 0.34375
epoch 1 loss : 0.10000000149011612 acc : 0.28125
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.28125
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.40625
epoch 1 loss : 0.10000000149011612 acc : 0.3125
epoch 1 loss : 0.10000000149011612 acc : 0.453125
epoch 1 loss : 0.10000000149011612 acc : 0.328125
epoch 1 loss :