<a href="https://colab.research.google.com/github/jayarnim/M1-DeepLearning/blob/main/skills/4_Workflow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
import tensorflow as tf

In [None]:
import tensorflow.keras.datasets.mnist as mnist

# Load MNIST and unpack it into (images, labels) pairs for the train and test splits.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
# Dense classifier: flatten each 28x28 input, pass it through one hidden
# ReLU layer of 512 units, then a 10-unit softmax output layer.
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=512, activation="relu"),
    tf.keras.layers.Dense(units=10, activation="softmax"),
])

In [None]:
# Configure training: RMSprop optimizer, sparse categorical cross-entropy
# loss, and accuracy as the reported metric.
model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# fit

In [None]:
# model.fit(
#     x = train_images,
#     y = train_labels,
#     epochs = 5,
#     batch_size = 128,
#     validation_data = (test_images, test_labels)
#     )

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x7e87605c7550>

## Optimizer

In [None]:
learning_rate = 1e-3
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)


def update_weights(gradients, weights):
    """Apply one optimizer step to ``weights`` using ``gradients``.

    The update is delegated to the module-level ``optimizer``. Written by
    hand, a plain gradient-descent step would instead call
    ``w.assign_sub(g * learning_rate)`` for every (gradient, weight) pair.

    Args:
        gradients: Gradients, in the same order as ``weights``.
        weights: Variables to update in place.
    """
    grads_and_vars = zip(gradients, weights)
    optimizer.apply_gradients(grads_and_vars)

## Backward Pass

```
with tf.GradientTape() as tape:
    tape.watch(x)
    y = x**2

dy_dx = tape.gradient(y, x)
```

In [None]:
loss_function = tf.keras.losses.sparse_categorical_crossentropy


def one_training_step(model, images_batch, labels_batch):
    """Run one forward/backward pass on a batch and update the model.

    Args:
        model: Keras model to train; it is called on the batch and its
            trainable weights are updated in place.
        images_batch: Batch of model inputs.
        labels_batch: Labels for ``images_batch``, passed to ``loss_function``.

    Returns:
        The mean loss over the batch, computed before the weight update.
    """
    with tf.GradientTape() as tape:
        # training=True makes layers that behave differently during training
        # (e.g. dropout, batch normalization) run in training mode.
        predictions = model(images_batch, training=True)
        per_sample_losses = loss_function(labels_batch, predictions)
        average_loss = tf.reduce_mean(per_sample_losses)

    # Differentiate with respect to, and update, only the trainable weights:
    # `model.weights` also contains non-trainable and frozen variables, which
    # must not be touched by the optimizer.
    trainable_weights = model.trainable_weights
    gradients = tape.gradient(average_loss, trainable_weights)
    update_weights(gradients, trainable_weights)
    return average_loss

## Batch Iterator

In [None]:
class BatchGenerator:
    """Hand out consecutive, fixed-size slices of paired images and labels.

    ``num_batches`` is the number of ``next()`` calls needed to walk the data
    once; the last slice may be shorter than ``batch_size``.
    """

    def __init__(self, images, labels, batch_size = 128):
        assert len(images) == len(labels)
        self.images, self.labels = images, labels
        self.batch_size = batch_size
        self.num_batches = math.ceil(len(images) / batch_size)
        # Offset of the first sample of the next batch.
        self.index = 0

    def next(self):
        """Return the next ``(images, labels)`` slice and advance the cursor."""
        start = self.index
        stop = start + self.batch_size
        window = slice(start, stop)
        batch = (self.images[window], self.labels[window])
        self.index = stop
        return batch

## Fit

In [None]:
def fit(model, images, labels, epochs, batch_size = 128):
    """Train ``model`` with a hand-written mini-batch loop.

    For each epoch a fresh ``BatchGenerator`` walks the data in order and
    ``one_training_step`` is applied to every batch. The loss of every 100th
    batch is printed.

    Args:
        model: Model passed through to ``one_training_step``.
        images: Training inputs.
        labels: Training labels, same length as ``images``.
        epochs: Number of passes over the data.
        batch_size: Number of samples per batch.
    """
    for epoch_counter in range(epochs):
        print(f"epoch : {epoch_counter}")
        batch_generator = BatchGenerator(images, labels, batch_size)

        for batch_counter in range(batch_generator.num_batches):
            images_batch, labels_batch = batch_generator.next()
            loss = one_training_step(model, images_batch, labels_batch)
            if batch_counter % 100 == 0:
                # ".2f" prints two decimals; the previous "2f" spec only set
                # a minimum field width of 2 and kept the default precision.
                print(f"{batch_counter} 번째 배치 손실 : {loss:.2f}")

# Callback

- `tf.keras.callbacks.EarlyStopping` : 검증 손실이 더 이상 향상되지 않을 경우 학습 조기 종료
- `tf.keras.callbacks.ModelCheckpoint` : 훈련하는 동안 어떤 지점에서 모델의 현재 가중치를 저장
- `tf.keras.callbacks.TensorBoard`
- `tf.keras.callbacks.LearningRateScheduler`
- `tf.keras.callbacks.ReduceLROnPlateau`
- `tf.keras.callbacks.CSVLogger`

## early stopping

In [None]:
# Early stopping: monitor validation accuracy and end training once it has
# gone `patience` (2) epochs without improving.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_accuracy",
    patience=2,
)

callback_list = [early_stopping]

In [None]:
# Compile again with the same optimizer, loss and metric as before.
model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Train for up to 10 epochs with the callbacks in `callback_list`;
# the test split is used as the validation data.
model.fit(
    x=train_images,
    y=train_labels,
    epochs=10,
    batch_size=128,
    callbacks=callback_list,
    validation_data=(test_images, test_labels),
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10


<keras.src.callbacks.History at 0x7ffa549038e0>

## model check point

In [None]:
# Checkpoint callback: save the model during training, keeping only the
# version with the best monitored value.
model_check_point = tf.keras.callbacks.ModelCheckpoint(
    # Destination file. The ".keras" suffix already selects the native Keras
    # format; `save_format` is not a ModelCheckpoint argument (Keras 3 rejects
    # it as an unexpected keyword), so it is not passed here.
    filepath="checkpoint_path.keras",
    # Metric used to decide which model is "best".
    monitor="val_loss",
    # Overwrite the file only when the monitored metric improves.
    save_best_only=True,
)

callback_list = [model_check_point]

In [None]:
# Compile again with the same optimizer, loss and metric as before.
model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Train for 10 epochs with the callbacks in `callback_list`;
# the test split is used as the validation data.
model.fit(
    x=train_images,
    y=train_labels,
    epochs=10,
    batch_size=128,
    callbacks=callback_list,
    validation_data=(test_images, test_labels),
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x7ffa54684fd0>

In [None]:
# Save the model to disk.
# NOTE(review): this is the same "checkpoint_path.keras" path that the
# ModelCheckpoint callback writes to, so the best-val_loss checkpoint is
# replaced by the model's current state — confirm that is intended.
model.save('checkpoint_path.keras')

In [None]:
# Load the saved model back from disk.
model = tf.keras.models.load_model("checkpoint_path.keras")

## 사용자 정의 Callback

- `def`
    - `on_train_begin(logs)` : 훈련 시작 시 호출
    - `on_train_end(logs)` : 훈련 종료 시 호출
    - `on_epoch_begin(epoch, logs)` : 각 에포크 시작 시 호출
    - `on_epoch_end(epoch, logs)` : 각 에포크 종료 시 호출
    - `on_batch_begin(batch, logs)` : 각 배치 처리 시작 전 호출
    - `on_batch_end(batch, logs)` : 각 배치 처리 종료 후 호출

- `params`
    - `logs` : 훈련 실행 정보, 이전 배치 정보, 이전 에포크 정보가 담긴 딕셔너리

In [None]:
import os
from matplotlib import pyplot as plt

class LossHistory(tf.keras.callbacks.Callback):
    """Callback that plots the per-batch training loss after every epoch.

    Losses reported to ``on_batch_end`` are collected in
    ``per_batch_losses``. At the end of each epoch they are plotted, the
    figure is written to ``<save_dir>/plot_at_epoch_<epoch>.png``, and the
    list is cleared for the next epoch.

    Args:
        save_dir: Directory the plot images are written to. It is created at
            the start of training if it does not exist.
    """

    def __init__(self, save_dir):
        super().__init__()
        self.save_dir = save_dir
        # Initialised here as well so the hooks are safe to call even if
        # on_train_begin has not run yet.
        self.per_batch_losses = []

    def on_train_begin(self, logs=None):
        # An empty save_dir means "current directory"; makedirs("") would fail.
        if self.save_dir:
            os.makedirs(self.save_dir, exist_ok=True)
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.per_batch_losses.append(logs.get("loss"))

    def on_epoch_end(self, epoch, logs=None):
        # Reuse the current pyplot figure: clear it before drawing this epoch.
        plt.clf()
        plt.plot(
            range(len(self.per_batch_losses)),
            self.per_batch_losses,
            label="Training loss for each batch",
        )
        plt.xlabel(f"Batch (epoch {epoch})")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(os.path.join(self.save_dir, f"plot_at_epoch_{epoch}.png"))
        # Start the next epoch with an empty history.
        self.per_batch_losses = []