# 경사 하강법을 이용한 얕은 신경망 학습


In [17]:
import tensorflow as tf
import numpy as np

## 하이퍼 파라미터 설정

In [18]:
EPOCH = 1000

## 네트워크 구조 정의
### 얕은 신경망
#### 입력 계층 : 2, 은닉 계층 : 128 (Sigmoid activation), 출력 계층 : 10 (Softmax activation)

In [19]:
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel,self).__init__()
        self.d1 = tf.keras.layers.Dense(128, input_dim =2, activation='sigmoid')
        self.d2 = tf.keras.layers.Dense(10, activation='softmax')
    def call(self,x, traning = None, mask = None):
        x = self.d1(x)
        return self.d2(x)

## 학습 루프 정의

In [20]:
@tf.function
def train_step(model, inputs, labels, loss_object, optimizer, train_loss, train_metric):
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    
    optimizer.apply_gradients(zip(gradients,model.trainable_variables))
    train_loss(loss)
    train_metric(labels, predictions)

## 데이터셋 생성, 전처리

In [21]:
np.random.seed(0)

pts = list()
labels = list()
center_pts = np.random.uniform(-8.0, 8.0, (10,2))
for label , center_pt in enumerate(center_pts):
    for _ in range(100):
        pts.append(center_pt + np.random.randn(*center_pt.shape))
        labels.append(label)
        
pts = np.stack(pts, axis=0).astype(np.float32) 
labels = np.stack(labels, axis=0)

train_ds = tf.data.Dataset.from_tensor_slices((pts, labels)).shuffle(1000).batch(32)

## 모델 생성

In [22]:
model = MyModel()

## 손실 함수 및 최적화 알고리즘 설정
### CrossEntropy, Adam Optimizer

In [23]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

## 평가 지표 설정
### Accuracy

In [24]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

## 학습 루프

In [26]:
for epoch in range(EPOCH):
    for x, label in train_ds:
        train_step(model, x, label, loss_object, optimizer, train_loss, train_accuracy)
        
    template = 'EPoch {}, Loss: {}, Accuracy: {}'
    print(template.format(epoch + 1, train_loss.result(), train_accuracy.result()*100))
        

EPoch 1, Loss: 0.2989155054092407, Accuracy: 88.81913757324219
EPoch 2, Loss: 0.2988480031490326, Accuracy: 88.81974792480469
EPoch 3, Loss: 0.29876944422721863, Accuracy: 88.82083892822266
EPoch 4, Loss: 0.2986966073513031, Accuracy: 88.8219223022461
EPoch 5, Loss: 0.29863452911376953, Accuracy: 88.82316589355469
EPoch 6, Loss: 0.2985582947731018, Accuracy: 88.82503509521484
EPoch 7, Loss: 0.29847726225852966, Accuracy: 88.82626342773438
EPoch 8, Loss: 0.29840192198753357, Accuracy: 88.82733917236328
EPoch 9, Loss: 0.2983386218547821, Accuracy: 88.82950592041016
EPoch 10, Loss: 0.2982668876647949, Accuracy: 88.83087921142578
EPoch 11, Loss: 0.2981865108013153, Accuracy: 88.83256530761719
EPoch 12, Loss: 0.29811447858810425, Accuracy: 88.83409118652344
EPoch 13, Loss: 0.29804185032844543, Accuracy: 88.83451080322266
EPoch 14, Loss: 0.2979723811149597, Accuracy: 88.8357162475586
EPoch 15, Loss: 0.2979041337966919, Accuracy: 88.83769989013672
EPoch 16, Loss: 0.2978273034095764, Accuracy:

EPoch 138, Loss: 0.2903607189655304, Accuracy: 88.98334503173828
EPoch 139, Loss: 0.290304571390152, Accuracy: 88.98455047607422
EPoch 140, Loss: 0.29024407267570496, Accuracy: 88.98548126220703
EPoch 141, Loss: 0.29018649458885193, Accuracy: 88.98667907714844
EPoch 142, Loss: 0.2901311218738556, Accuracy: 88.98760986328125
EPoch 143, Loss: 0.29007789492607117, Accuracy: 88.98880767822266
EPoch 144, Loss: 0.29002630710601807, Accuracy: 88.9896011352539
EPoch 145, Loss: 0.28997603058815, Accuracy: 88.99040222167969
EPoch 146, Loss: 0.28992271423339844, Accuracy: 88.99223327636719
EPoch 147, Loss: 0.2898798882961273, Accuracy: 88.99315643310547
EPoch 148, Loss: 0.2898269593715668, Accuracy: 88.99472045898438
EPoch 149, Loss: 0.28978174924850464, Accuracy: 88.9951171875
EPoch 150, Loss: 0.2897275388240814, Accuracy: 88.99628448486328
EPoch 151, Loss: 0.2896767556667328, Accuracy: 88.99745178222656
EPoch 152, Loss: 0.2896342873573303, Accuracy: 88.99810028076172
EPoch 153, Loss: 0.28958165

EPoch 275, Loss: 0.28404858708381653, Accuracy: 89.13050842285156
EPoch 276, Loss: 0.28400444984436035, Accuracy: 89.13136291503906
EPoch 277, Loss: 0.2839622497558594, Accuracy: 89.13243865966797
EPoch 278, Loss: 0.2839326858520508, Accuracy: 89.13329315185547
EPoch 279, Loss: 0.2838897109031677, Accuracy: 89.13447570800781
EPoch 280, Loss: 0.283848375082016, Accuracy: 89.13565826416016
EPoch 281, Loss: 0.28381046652793884, Accuracy: 89.13705444335938
EPoch 282, Loss: 0.28377246856689453, Accuracy: 89.13778686523438
EPoch 283, Loss: 0.28373393416404724, Accuracy: 89.13929748535156
EPoch 284, Loss: 0.28369632363319397, Accuracy: 89.1399154663086
EPoch 285, Loss: 0.2836596965789795, Accuracy: 89.1409683227539
EPoch 286, Loss: 0.2836238741874695, Accuracy: 89.14213562011719
EPoch 287, Loss: 0.2835780680179596, Accuracy: 89.1429672241211
EPoch 288, Loss: 0.28353220224380493, Accuracy: 89.14424133300781
EPoch 289, Loss: 0.2834983766078949, Accuracy: 89.1455078125
EPoch 290, Loss: 0.2834666

EPoch 412, Loss: 0.27917495369911194, Accuracy: 89.27243041992188
EPoch 413, Loss: 0.27914267778396606, Accuracy: 89.27303314208984
EPoch 414, Loss: 0.2791045308113098, Accuracy: 89.27364349365234
EPoch 415, Loss: 0.2790733575820923, Accuracy: 89.2744369506836
EPoch 416, Loss: 0.2790425717830658, Accuracy: 89.27552032470703
EPoch 417, Loss: 0.27901408076286316, Accuracy: 89.27670288085938
EPoch 418, Loss: 0.27897730469703674, Accuracy: 89.27806854248047
EPoch 419, Loss: 0.2789401412010193, Accuracy: 89.27914428710938
EPoch 420, Loss: 0.2789035439491272, Accuracy: 89.2799301147461
EPoch 421, Loss: 0.27886849641799927, Accuracy: 89.28109741210938
EPoch 422, Loss: 0.2788299322128296, Accuracy: 89.281982421875
EPoch 423, Loss: 0.27880293130874634, Accuracy: 89.2828598022461
EPoch 424, Loss: 0.27876728773117065, Accuracy: 89.28382873535156
EPoch 425, Loss: 0.27873194217681885, Accuracy: 89.28489685058594
EPoch 426, Loss: 0.278699666261673, Accuracy: 89.2857666015625
EPoch 427, Loss: 0.27866

EPoch 551, Loss: 0.27508753538131714, Accuracy: 89.39354705810547
EPoch 552, Loss: 0.2750598192214966, Accuracy: 89.39449310302734
EPoch 553, Loss: 0.27503547072410583, Accuracy: 89.39517211914062
EPoch 554, Loss: 0.2750091850757599, Accuracy: 89.39628601074219
EPoch 555, Loss: 0.27498069405555725, Accuracy: 89.39679718017578
EPoch 556, Loss: 0.27495068311691284, Accuracy: 89.39730834960938
EPoch 557, Loss: 0.2749202251434326, Accuracy: 89.39815521240234
EPoch 558, Loss: 0.2748911380767822, Accuracy: 89.39926147460938
EPoch 559, Loss: 0.2748686671257019, Accuracy: 89.39994049072266
EPoch 560, Loss: 0.2748411297798157, Accuracy: 89.40052795410156
EPoch 561, Loss: 0.2748143672943115, Accuracy: 89.40087127685547
EPoch 562, Loss: 0.2747957706451416, Accuracy: 89.40162658691406
EPoch 563, Loss: 0.2747671902179718, Accuracy: 89.40221405029297
EPoch 564, Loss: 0.27473610639572144, Accuracy: 89.4029769897461
EPoch 565, Loss: 0.27471327781677246, Accuracy: 89.40355682373047
EPoch 566, Loss: 0.2

EPoch 677, Loss: 0.27196207642555237, Accuracy: 89.4866714477539
EPoch 678, Loss: 0.27194491028785706, Accuracy: 89.48744201660156
EPoch 679, Loss: 0.27192097902297974, Accuracy: 89.48853302001953
EPoch 680, Loss: 0.27190300822257996, Accuracy: 89.48892211914062
EPoch 681, Loss: 0.2718813121318817, Accuracy: 89.49000549316406
EPoch 682, Loss: 0.27185943722724915, Accuracy: 89.49070739746094
EPoch 683, Loss: 0.27183404564857483, Accuracy: 89.49156188964844
EPoch 684, Loss: 0.27180859446525574, Accuracy: 89.49217224121094
EPoch 685, Loss: 0.2717824876308441, Accuracy: 89.49317169189453
EPoch 686, Loss: 0.27175968885421753, Accuracy: 89.4938735961914
EPoch 687, Loss: 0.27174821496009827, Accuracy: 89.49417877197266
EPoch 688, Loss: 0.2717244327068329, Accuracy: 89.4950180053711
EPoch 689, Loss: 0.27170467376708984, Accuracy: 89.49555969238281
EPoch 690, Loss: 0.27167969942092896, Accuracy: 89.49639892578125
EPoch 691, Loss: 0.2716551721096039, Accuracy: 89.49700927734375
EPoch 692, Loss: 

EPoch 815, Loss: 0.2691337764263153, Accuracy: 89.57330322265625
EPoch 816, Loss: 0.26911458373069763, Accuracy: 89.57373809814453
EPoch 817, Loss: 0.2690919041633606, Accuracy: 89.57438659667969
EPoch 818, Loss: 0.2690703868865967, Accuracy: 89.57516479492188
EPoch 819, Loss: 0.26904869079589844, Accuracy: 89.57573699951172
EPoch 820, Loss: 0.2690289318561554, Accuracy: 89.57624053955078
EPoch 821, Loss: 0.2690091133117676, Accuracy: 89.57667541503906
EPoch 822, Loss: 0.2689877450466156, Accuracy: 89.57723999023438
EPoch 823, Loss: 0.26897522807121277, Accuracy: 89.57746887207031
EPoch 824, Loss: 0.2689516246318817, Accuracy: 89.57831573486328
EPoch 825, Loss: 0.26893216371536255, Accuracy: 89.57874298095703
EPoch 826, Loss: 0.2689095139503479, Accuracy: 89.57917022705078
EPoch 827, Loss: 0.2688864469528198, Accuracy: 89.58000946044922
EPoch 828, Loss: 0.2688671946525574, Accuracy: 89.58092498779297
EPoch 829, Loss: 0.26884493231773376, Accuracy: 89.58155822753906
EPoch 830, Loss: 0.2

EPoch 954, Loss: 0.2666472792625427, Accuracy: 89.64689636230469
EPoch 955, Loss: 0.26662713289260864, Accuracy: 89.64749908447266
EPoch 956, Loss: 0.26660963892936707, Accuracy: 89.64790344238281
EPoch 957, Loss: 0.26659122109413147, Accuracy: 89.64857482910156
EPoch 958, Loss: 0.2665799558162689, Accuracy: 89.64898681640625
EPoch 959, Loss: 0.2665631175041199, Accuracy: 89.64945983886719
EPoch 960, Loss: 0.26654693484306335, Accuracy: 89.64981079101562
EPoch 961, Loss: 0.26653149724006653, Accuracy: 89.6506576538086
EPoch 962, Loss: 0.2665177285671234, Accuracy: 89.65094757080078
EPoch 963, Loss: 0.26649871468544006, Accuracy: 89.65166473388672
EPoch 964, Loss: 0.26648423075675964, Accuracy: 89.6520767211914
EPoch 965, Loss: 0.2664696276187897, Accuracy: 89.6524887084961
EPoch 966, Loss: 0.26645562052726746, Accuracy: 89.65258026123047
EPoch 967, Loss: 0.2664383351802826, Accuracy: 89.65311431884766
EPoch 968, Loss: 0.26641935110092163, Accuracy: 89.6537094116211
EPoch 969, Loss: 0.2

## 데이터셋 및 학습 파라미터 저장

In [27]:
np.savez_compressed('ch2_dataset.npz',inputs=pts, labels = labels)

W_h, b_h = model.d1.get_weights()
W_o, b_o = model.d2.get_weights()

W_h = np.transpose(W_h)
W_o = np.transpose(W_o)
np.savez_compressed('ch2_parameters.npz',
                   W_h = W_h,b_h = b_h,
                    W_o = W_o, b_o=b_o
                   )