# 경사 하강법을 이용한 얕은 신경망 학습


In [1]:
import tensorflow as tf
import numpy as np

## 하이퍼 파라미터 설정

In [2]:
# Number of full passes over the training dataset.
EPOCHS = 1_000

## 네트워크 구조 정의
### 얕은 신경망
#### 입력 계층 : 2, 은닉 계층 : 128 (Sigmoid activation), 출력 계층 : 10 (Softmax activation)

In [3]:
class MyModel(tf.keras.Model):
    """Shallow fully connected classifier: 2 inputs -> 128 sigmoid units -> 10-way softmax."""

    def __init__(self):
        super().__init__()
        # Hidden layer: 128 sigmoid units fed by 2-dimensional input points.
        self.d1 = tf.keras.layers.Dense(128, input_dim=2, activation='sigmoid')
        # Output layer: softmax over 10 classes, so the model emits probabilities.
        self.d2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, x, training=None, mask=None):
        """Forward pass; `training` and `mask` are accepted but not used."""
        hidden = self.d1(x)
        probabilities = self.d2(hidden)
        return probabilities

## 학습 루프 정의

In [4]:
@tf.function
def train_step(model, inputs, labels, loss_object, optimizer, train_loss, train_metric):
    """Run one gradient-descent step on a single mini-batch and update running metrics.

    Args:
        model: Keras model; its output on ``inputs`` is passed to ``loss_object``
            and ``train_metric`` as the predictions.
        inputs: Mini-batch of input features.
        labels: Target labels for ``inputs``, in the form ``loss_object`` expects.
        loss_object: Callable ``loss_object(labels, predictions)`` returning the loss.
        optimizer: Keras optimizer that applies the computed gradients.
        train_loss: Metric called with the batch loss to accumulate it.
        train_metric: Metric called with ``(labels, predictions)`` to accumulate it.
    """
    with tf.GradientTape() as tape:
        # training=True puts training-mode-sensitive layers (e.g. Dropout,
        # BatchNormalization) in training mode should they be added to the model.
        predictions = model(inputs, training=True)
        loss = loss_object(labels, predictions)
    # Gradient of the loss with respect to every trainable weight (dL/dW).
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_metric(labels, predictions)

## 데이터셋 생성, 전처리

In [5]:
np.random.seed(0)

# Synthetic 10-class problem: 10 cluster centres drawn uniformly from [-8, 8)
# in each of 2 dimensions, each surrounded by 100 points with standard-normal noise.
num_classes, pts_per_class = 10, 100
center_pts = np.random.uniform(-8.0, 8.0, (num_classes, 2))

# One (pts_per_class, 2) draw per centre consumes the legacy RNG stream in the
# same order as drawing the points one at a time, so the data is reproducible.
pts = np.concatenate(
    [center + np.random.randn(pts_per_class, center.shape[0]) for center in center_pts],
    axis=0,
).astype(np.float32)
labels = np.repeat(np.arange(num_classes), pts_per_class)

train_ds = (
    tf.data.Dataset.from_tensor_slices((pts, labels))
    .shuffle(1000)
    .batch(32)
)

## 모델 생성

In [6]:
# Instantiate the shallow 2-128-10 network defined by MyModel.
model = MyModel()

## 손실 함수 및 최적화 알고리즘 설정
### Sparse Categorical Cross-Entropy, Adam Optimizer

In [7]:
# Adam with its standard default step size.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# Integer class labels; the model ends in softmax, so predictions are
# probabilities rather than logits.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

## 평가 지표 설정
### Accuracy

In [8]:
# Running mean of the per-batch loss and running accuracy against integer labels,
# accumulated across an epoch.
train_loss, train_accuracy = (
    tf.keras.metrics.Mean(name='train_loss'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy'),
)

## 학습 루프

In [9]:
# Format string for the per-epoch progress line (loop-invariant, so built once).
template = 'Epoch {}, Loss: {}, Accuracy: {}'
for epoch in range(EPOCHS):
    # One pass over every mini-batch in train_ds.
    for x, label in train_ds:
        train_step(model, x, label, loss_object, optimizer, train_loss, train_accuracy)

    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100))

    # Clear accumulated statistics so each epoch reports only its own batches.
    # `reset_state` replaced the deprecated `reset_states` (TF 2.5+), and newer
    # Keras releases no longer provide `reset_states`; fall back for old TF.
    for metric in (train_loss, train_accuracy):
        reset = getattr(metric, 'reset_state', None) or metric.reset_states
        reset()

Epoch 1, Loss: 2.2251696586608887, Accuracy: 23.399999618530273
Epoch 2, Loss: 1.8348826169967651, Accuracy: 47.20000076293945
Epoch 3, Loss: 1.588505744934082, Accuracy: 57.70000076293945
Epoch 4, Loss: 1.4419457912445068, Accuracy: 59.89999771118164
Epoch 5, Loss: 1.3077824115753174, Accuracy: 69.19999694824219
Epoch 6, Loss: 1.2149763107299805, Accuracy: 76.0
Epoch 7, Loss: 1.1193238496780396, Accuracy: 77.0
Epoch 8, Loss: 1.0501651763916016, Accuracy: 78.89999389648438
Epoch 9, Loss: 0.9818891286849976, Accuracy: 80.5
Epoch 10, Loss: 0.9312219023704529, Accuracy: 81.5
Epoch 11, Loss: 0.8826591372489929, Accuracy: 82.5
Epoch 12, Loss: 0.8393110632896423, Accuracy: 82.5
Epoch 13, Loss: 0.80780428647995, Accuracy: 83.70000457763672
Epoch 14, Loss: 0.7712287306785583, Accuracy: 84.30000305175781
Epoch 15, Loss: 0.737052857875824, Accuracy: 85.5999984741211
Epoch 16, Loss: 0.7099438309669495, Accuracy: 84.69999694824219
Epoch 17, Loss: 0.6829002499580383, Accuracy: 86.0999984741211
Epoc

Epoch 138, Loss: 0.27681639790534973, Accuracy: 89.5
Epoch 139, Loss: 0.2774837017059326, Accuracy: 89.0999984741211
Epoch 140, Loss: 0.27728068828582764, Accuracy: 89.80000305175781
Epoch 141, Loss: 0.27675238251686096, Accuracy: 89.20000457763672
Epoch 142, Loss: 0.2851892411708832, Accuracy: 89.60000610351562
Epoch 143, Loss: 0.2715218961238861, Accuracy: 89.5
Epoch 144, Loss: 0.27047300338745117, Accuracy: 89.5
Epoch 145, Loss: 0.2762073278427124, Accuracy: 89.9000015258789
Epoch 146, Loss: 0.2776285707950592, Accuracy: 89.4000015258789
Epoch 147, Loss: 0.2684745788574219, Accuracy: 89.80000305175781
Epoch 148, Loss: 0.27494150400161743, Accuracy: 89.5
Epoch 149, Loss: 0.2800346612930298, Accuracy: 89.4000015258789
Epoch 150, Loss: 0.2729274332523346, Accuracy: 89.4000015258789
Epoch 151, Loss: 0.27144935727119446, Accuracy: 89.30000305175781
Epoch 152, Loss: 0.27532121539115906, Accuracy: 88.80000305175781
Epoch 153, Loss: 0.2702390253543854, Accuracy: 89.20000457763672
Epoch 154,

Epoch 273, Loss: 0.2783152759075165, Accuracy: 89.30000305175781
Epoch 274, Loss: 0.258018434047699, Accuracy: 89.60000610351562
Epoch 275, Loss: 0.2629663348197937, Accuracy: 90.0
Epoch 276, Loss: 0.25765326619148254, Accuracy: 88.80000305175781
Epoch 277, Loss: 0.26148754358291626, Accuracy: 89.9000015258789
Epoch 278, Loss: 0.2577957212924957, Accuracy: 89.5
Epoch 279, Loss: 0.2619819939136505, Accuracy: 90.10000610351562
Epoch 280, Loss: 0.25875288248062134, Accuracy: 89.5
Epoch 281, Loss: 0.2610906958580017, Accuracy: 89.30000305175781
Epoch 282, Loss: 0.2553490996360779, Accuracy: 89.80000305175781
Epoch 283, Loss: 0.2560213506221771, Accuracy: 89.60000610351562
Epoch 284, Loss: 0.26075226068496704, Accuracy: 89.0999984741211
Epoch 285, Loss: 0.25531405210494995, Accuracy: 89.60000610351562
Epoch 286, Loss: 0.25827452540397644, Accuracy: 89.0
Epoch 287, Loss: 0.2796210050582886, Accuracy: 89.5
Epoch 288, Loss: 0.25861552357673645, Accuracy: 89.70000457763672
Epoch 289, Loss: 0.25

Epoch 405, Loss: 0.2656220495700836, Accuracy: 89.0999984741211
Epoch 406, Loss: 0.2519455552101135, Accuracy: 89.70000457763672
Epoch 407, Loss: 0.2633967101573944, Accuracy: 89.5
Epoch 408, Loss: 0.25232142210006714, Accuracy: 89.60000610351562
Epoch 409, Loss: 0.25076499581336975, Accuracy: 89.5
Epoch 410, Loss: 0.255124032497406, Accuracy: 89.20000457763672
Epoch 411, Loss: 0.26122596859931946, Accuracy: 89.30000305175781
Epoch 412, Loss: 0.25289878249168396, Accuracy: 89.4000015258789
Epoch 413, Loss: 0.2568817734718323, Accuracy: 88.80000305175781
Epoch 414, Loss: 0.26908862590789795, Accuracy: 89.5
Epoch 415, Loss: 0.25708842277526855, Accuracy: 89.0
Epoch 416, Loss: 0.2565416097640991, Accuracy: 89.5
Epoch 417, Loss: 0.27228230237960815, Accuracy: 89.0999984741211
Epoch 418, Loss: 0.25938376784324646, Accuracy: 89.60000610351562
Epoch 419, Loss: 0.25909432768821716, Accuracy: 90.20000457763672
Epoch 420, Loss: 0.253929078578949, Accuracy: 89.30000305175781
Epoch 421, Loss: 0.26

Epoch 543, Loss: 0.2582841217517853, Accuracy: 88.9000015258789
Epoch 544, Loss: 0.25082927942276, Accuracy: 89.60000610351562
Epoch 545, Loss: 0.25685420632362366, Accuracy: 89.70000457763672
Epoch 546, Loss: 0.25051456689834595, Accuracy: 89.60000610351562
Epoch 547, Loss: 0.25266948342323303, Accuracy: 90.0
Epoch 548, Loss: 0.25296497344970703, Accuracy: 89.0
Epoch 549, Loss: 0.24923594295978546, Accuracy: 89.60000610351562
Epoch 550, Loss: 0.24972796440124512, Accuracy: 89.5
Epoch 551, Loss: 0.252422034740448, Accuracy: 89.5
Epoch 552, Loss: 0.2607933282852173, Accuracy: 89.9000015258789
Epoch 553, Loss: 0.2561110854148865, Accuracy: 90.0
Epoch 554, Loss: 0.2497560828924179, Accuracy: 89.0
Epoch 555, Loss: 0.25175195932388306, Accuracy: 89.60000610351562
Epoch 556, Loss: 0.2580801248550415, Accuracy: 89.4000015258789
Epoch 557, Loss: 0.2547471225261688, Accuracy: 90.0
Epoch 558, Loss: 0.2570285499095917, Accuracy: 90.10000610351562
Epoch 559, Loss: 0.2510332763195038, Accuracy: 89.

Epoch 675, Loss: 0.25192761421203613, Accuracy: 89.60000610351562
Epoch 676, Loss: 0.2556580603122711, Accuracy: 89.20000457763672
Epoch 677, Loss: 0.24969500303268433, Accuracy: 89.60000610351562
Epoch 678, Loss: 0.25199031829833984, Accuracy: 89.70000457763672
Epoch 679, Loss: 0.24843038618564606, Accuracy: 90.0
Epoch 680, Loss: 0.2651914954185486, Accuracy: 90.0
Epoch 681, Loss: 0.25097012519836426, Accuracy: 89.9000015258789
Epoch 682, Loss: 0.2522733211517334, Accuracy: 90.0
Epoch 683, Loss: 0.2548852562904358, Accuracy: 89.80000305175781
Epoch 684, Loss: 0.25323307514190674, Accuracy: 90.0
Epoch 685, Loss: 0.24832095205783844, Accuracy: 90.30000305175781
Epoch 686, Loss: 0.2589169144630432, Accuracy: 89.80000305175781
Epoch 687, Loss: 0.2578844428062439, Accuracy: 89.30000305175781
Epoch 688, Loss: 0.2550066411495209, Accuracy: 89.9000015258789
Epoch 689, Loss: 0.2592141032218933, Accuracy: 89.60000610351562
Epoch 690, Loss: 0.2567923665046692, Accuracy: 89.70000457763672
Epoch 6

Epoch 811, Loss: 0.244619220495224, Accuracy: 89.9000015258789
Epoch 812, Loss: 0.24550016224384308, Accuracy: 90.20000457763672
Epoch 813, Loss: 0.2472037822008133, Accuracy: 89.4000015258789
Epoch 814, Loss: 0.24512287974357605, Accuracy: 90.0
Epoch 815, Loss: 0.24675634503364563, Accuracy: 89.9000015258789
Epoch 816, Loss: 0.25812703371047974, Accuracy: 89.9000015258789
Epoch 817, Loss: 0.2512072026729584, Accuracy: 89.9000015258789
Epoch 818, Loss: 0.24463611841201782, Accuracy: 90.80000305175781
Epoch 819, Loss: 0.2567143738269806, Accuracy: 90.0
Epoch 820, Loss: 0.2440369427204132, Accuracy: 89.70000457763672
Epoch 821, Loss: 0.24872317910194397, Accuracy: 90.30000305175781
Epoch 822, Loss: 0.2529122233390808, Accuracy: 89.70000457763672
Epoch 823, Loss: 0.2531664967536926, Accuracy: 89.80000305175781
Epoch 824, Loss: 0.25154778361320496, Accuracy: 90.5
Epoch 825, Loss: 0.24884545803070068, Accuracy: 89.9000015258789
Epoch 826, Loss: 0.25010770559310913, Accuracy: 90.0
Epoch 827,

Epoch 943, Loss: 0.24946895241737366, Accuracy: 89.80000305175781
Epoch 944, Loss: 0.24338634312152863, Accuracy: 89.5
Epoch 945, Loss: 0.2416282445192337, Accuracy: 90.30000305175781
Epoch 946, Loss: 0.2539368271827698, Accuracy: 90.30000305175781
Epoch 947, Loss: 0.2467852383852005, Accuracy: 90.5
Epoch 948, Loss: 0.24637189507484436, Accuracy: 90.0
Epoch 949, Loss: 0.24351133406162262, Accuracy: 90.30000305175781
Epoch 950, Loss: 0.2426651120185852, Accuracy: 90.20000457763672
Epoch 951, Loss: 0.24671362340450287, Accuracy: 90.10000610351562
Epoch 952, Loss: 0.2435694932937622, Accuracy: 90.20000457763672
Epoch 953, Loss: 0.2590559720993042, Accuracy: 90.30000305175781
Epoch 954, Loss: 0.2435169667005539, Accuracy: 89.5
Epoch 955, Loss: 0.25162503123283386, Accuracy: 90.4000015258789
Epoch 956, Loss: 0.25222334265708923, Accuracy: 90.0
Epoch 957, Loss: 0.24862423539161682, Accuracy: 90.0
Epoch 958, Loss: 0.2449776828289032, Accuracy: 90.20000457763672
Epoch 959, Loss: 0.250814825296

## 데이터셋 및 학습 파라미터 저장

In [10]:
# Persist the generated dataset so it can be reused without regenerating it.
np.savez_compressed('ch2_dataset.npz', inputs=pts, labels=labels)

# Keras Dense computes dot(inputs, kernel) + bias, so kernels are stored as
# (in_features, out_features). Transpose them to (out_features, in_features);
# presumably the loader expects a W @ x convention (TODO confirm with consumer).
W_h, b_h = model.d1.get_weights()
W_o, b_o = model.d2.get_weights()
W_h, W_o = W_h.T, W_o.T
np.savez_compressed('ch2_parameters.npz',
                    W_h=W_h, b_h=b_h,
                    W_o=W_o, b_o=b_o)