# 경사 하강법을 이용한 얕은 신경망 학습


In [29]:
import tensorflow as tf
import numpy as np

## 하이퍼 파라미터 설정

In [30]:
EPOCHS=1000

## 네트워크 구조 정의
### 얕은 신경망
#### 입력 계층 : 2, 은닉 계층 : 128 (Sigmoid activation), 출력 계층 : 10 (Softmax activation)

In [39]:
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel,self).__init__()
        self.d1 = tf.keras.layers.Dense(128,input_dim=2,activation='sigmoid')
        self.d2 = tf.keras.layers.Dense(10,activation='softmax')
        
        
    def call(self,x,traning=None,mask=None):
        x = self.d1(x)
        return self.d2(x)

## 학습 루프 정의

In [40]:
@tf.function
def train_step(model,inputs,labels,loss_object,optimizer,train_loss,train_metric):
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss = loss_object(labels,predictions)
    gradients = tape.gradient(loss,model.trainable_variables) # grad(loss)
    
    optimizer.apply_gradients(zip(gradients,model.trainable_variables))
    train_loss(loss)
    train_metric(labels,predictions)

## 데이터셋 생성, 전처리

In [41]:
np.random.seed(0)

pts = list()
labels=list()
center_pts = np.random.uniform(-8.0,8.0,(10,2))
for label,center_pt in enumerate(center_pts):
    for _ in range(100):
        pts.append(center_pt + np.random.randn(*center_pt.shape))
        labels.append(label)
        
pts=np.stack(pts,axis=0).astype(np.float32)
labels=np.stack(labels,axis=0)

train_ds = tf.data.Dataset.from_tensor_slices((pts,labels)).shuffle(1000).batch(32)



## 모델 생성

In [42]:
model = MyModel()

## 손실 함수 및 최적화 알고리즘 설정
### CrossEntropy, Adam Optimizer

In [43]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

## 평가 지표 설정
### Accuracy

In [44]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

## 학습 루프

In [47]:
for epoch in range(EPOCHS):
    for x,label in train_ds:
        train_step(model,x,label,loss_object,optimizer,train_loss,train_accuracy)
        
    template = "Epoch: {}, Loss:{}, Accuracy: {}"
    print(template.format(epoch+1,train_loss.result(),train_accuracy.result() * 100))

Epoch: 1, Loss:0.3196081519126892, Accuracy: 88.39398193359375
Epoch: 2, Loss:0.31945398449897766, Accuracy: 88.39586639404297
Epoch: 3, Loss:0.3192978799343109, Accuracy: 88.39775085449219
Epoch: 4, Loss:0.3191627264022827, Accuracy: 88.40078735351562
Epoch: 5, Loss:0.31900614500045776, Accuracy: 88.40194702148438
Epoch: 6, Loss:0.3188669681549072, Accuracy: 88.40380096435547
Epoch: 7, Loss:0.3187248706817627, Accuracy: 88.40634155273438
Epoch: 8, Loss:0.3185803294181824, Accuracy: 88.40933227539062
Epoch: 9, Loss:0.3184296786785126, Accuracy: 88.41207885742188
Epoch: 10, Loss:0.31828975677490234, Accuracy: 88.41365814208984
Epoch: 11, Loss:0.3181475102901459, Accuracy: 88.4163818359375
Epoch: 12, Loss:0.31799986958503723, Accuracy: 88.4197769165039
Epoch: 13, Loss:0.31786254048347473, Accuracy: 88.42133331298828
Epoch: 14, Loss:0.31770390272140503, Accuracy: 88.42333221435547
Epoch: 15, Loss:0.31755930185317993, Accuracy: 88.42555236816406
Epoch: 16, Loss:0.3174108564853668, Accuracy

Epoch: 132, Loss:0.30442556738853455, Accuracy: 88.63626861572266
Epoch: 133, Loss:0.30433517694473267, Accuracy: 88.63763427734375
Epoch: 134, Loss:0.30423593521118164, Accuracy: 88.63971710205078
Epoch: 135, Loss:0.3041588366031647, Accuracy: 88.6403579711914
Epoch: 136, Loss:0.30406612157821655, Accuracy: 88.64189147949219
Epoch: 137, Loss:0.3039731979370117, Accuracy: 88.64448547363281
Epoch: 138, Loss:0.3038717806339264, Accuracy: 88.64600372314453
Epoch: 139, Loss:0.30378058552742004, Accuracy: 88.64823150634766
Epoch: 140, Loss:0.30368736386299133, Accuracy: 88.64814758300781
Epoch: 141, Loss:0.3036099076271057, Accuracy: 88.64982604980469
Epoch: 142, Loss:0.30352434515953064, Accuracy: 88.65150451660156
Epoch: 143, Loss:0.3034300208091736, Accuracy: 88.65352630615234
Epoch: 144, Loss:0.30334001779556274, Accuracy: 88.65396118164062
Epoch: 145, Loss:0.30324429273605347, Accuracy: 88.65597534179688
Epoch: 146, Loss:0.3031562864780426, Accuracy: 88.65675354003906
Epoch: 147, Loss:

Epoch: 260, Loss:0.29465264081954956, Accuracy: 88.81028747558594
Epoch: 261, Loss:0.29458242654800415, Accuracy: 88.81187438964844
Epoch: 262, Loss:0.2945190966129303, Accuracy: 88.81346130371094
Epoch: 263, Loss:0.2944512665271759, Accuracy: 88.81431579589844
Epoch: 264, Loss:0.2943805456161499, Accuracy: 88.81501770019531
Epoch: 265, Loss:0.29430943727493286, Accuracy: 88.81644439697266
Epoch: 266, Loss:0.2942390441894531, Accuracy: 88.81714630126953
Epoch: 267, Loss:0.29417696595191956, Accuracy: 88.81827545166016
Epoch: 268, Loss:0.2941246032714844, Accuracy: 88.81983947753906
Epoch: 269, Loss:0.29406842589378357, Accuracy: 88.82052612304688
Epoch: 270, Loss:0.2940070927143097, Accuracy: 88.82222747802734
Epoch: 271, Loss:0.29393696784973145, Accuracy: 88.82377624511719
Epoch: 272, Loss:0.29387497901916504, Accuracy: 88.8253173828125
Epoch: 273, Loss:0.2938069999217987, Accuracy: 88.82628631591797
Epoch: 274, Loss:0.29374170303344727, Accuracy: 88.82782745361328
Epoch: 275, Loss:0

Epoch: 388, Loss:0.28760117292404175, Accuracy: 88.95275115966797
Epoch: 389, Loss:0.2875581979751587, Accuracy: 88.95391082763672
Epoch: 390, Loss:0.28751441836357117, Accuracy: 88.95494842529297
Epoch: 391, Loss:0.2874665856361389, Accuracy: 88.95635986328125
Epoch: 392, Loss:0.2874143421649933, Accuracy: 88.95690155029297
Epoch: 393, Loss:0.2873600125312805, Accuracy: 88.95744323730469
Epoch: 394, Loss:0.28730714321136475, Accuracy: 88.95834350585938
Epoch: 395, Loss:0.28725624084472656, Accuracy: 88.9587631225586
Epoch: 396, Loss:0.28722086548805237, Accuracy: 88.96039581298828
Epoch: 397, Loss:0.2871691584587097, Accuracy: 88.96117401123047
Epoch: 398, Loss:0.28711628913879395, Accuracy: 88.96243286132812
Epoch: 399, Loss:0.28707605600357056, Accuracy: 88.96296691894531
Epoch: 400, Loss:0.2870256304740906, Accuracy: 88.9639892578125
Epoch: 401, Loss:0.2869834899902344, Accuracy: 88.96487426757812
Epoch: 402, Loss:0.2869308590888977, Accuracy: 88.96588897705078
Epoch: 403, Loss:0.2

Epoch: 516, Loss:0.28222227096557617, Accuracy: 89.0771713256836
Epoch: 517, Loss:0.28217971324920654, Accuracy: 89.0776138305664
Epoch: 518, Loss:0.28213635087013245, Accuracy: 89.07881164550781
Epoch: 519, Loss:0.2821028530597687, Accuracy: 89.07936096191406
Epoch: 520, Loss:0.2820621430873871, Accuracy: 89.08075714111328
Epoch: 521, Loss:0.2820313572883606, Accuracy: 89.08141326904297
Epoch: 522, Loss:0.28199437260627747, Accuracy: 89.082275390625
Epoch: 523, Loss:0.2819633483886719, Accuracy: 89.083251953125
Epoch: 524, Loss:0.28192904591560364, Accuracy: 89.08368682861328
Epoch: 525, Loss:0.2818922698497772, Accuracy: 89.08433532714844
Epoch: 526, Loss:0.2818589508533478, Accuracy: 89.08529663085938
Epoch: 527, Loss:0.2818200886249542, Accuracy: 89.08625793457031
Epoch: 528, Loss:0.28178316354751587, Accuracy: 89.08711242675781
Epoch: 529, Loss:0.281747967004776, Accuracy: 89.08786010742188
Epoch: 530, Loss:0.28170517086982727, Accuracy: 89.08860778808594
Epoch: 531, Loss:0.281670

Epoch: 645, Loss:0.27797752618789673, Accuracy: 89.1906967163086
Epoch: 646, Loss:0.27794337272644043, Accuracy: 89.19145202636719
Epoch: 647, Loss:0.2779146432876587, Accuracy: 89.19276428222656
Epoch: 648, Loss:0.2778853178024292, Accuracy: 89.19379425048828
Epoch: 649, Loss:0.2778635621070862, Accuracy: 89.19502258300781
Epoch: 650, Loss:0.2778308391571045, Accuracy: 89.19567108154297
Epoch: 651, Loss:0.2778032124042511, Accuracy: 89.19651794433594
Epoch: 652, Loss:0.2777690589427948, Accuracy: 89.19734954833984
Epoch: 653, Loss:0.2777363955974579, Accuracy: 89.19828033447266
Epoch: 654, Loss:0.2777126133441925, Accuracy: 89.19902801513672
Epoch: 655, Loss:0.2776893675327301, Accuracy: 89.19985961914062
Epoch: 656, Loss:0.2776576578617096, Accuracy: 89.20124816894531
Epoch: 657, Loss:0.2776326537132263, Accuracy: 89.2021713256836
Epoch: 658, Loss:0.27760541439056396, Accuracy: 89.20235443115234
Epoch: 659, Loss:0.2775843143463135, Accuracy: 89.2029037475586
Epoch: 660, Loss:0.277552

Epoch: 774, Loss:0.27453047037124634, Accuracy: 89.2917251586914
Epoch: 775, Loss:0.27450063824653625, Accuracy: 89.29248046875
Epoch: 776, Loss:0.274469792842865, Accuracy: 89.29348754882812
Epoch: 777, Loss:0.2744424641132355, Accuracy: 89.29448699951172
Epoch: 778, Loss:0.2744203209877014, Accuracy: 89.29507446289062
Epoch: 779, Loss:0.27439382672309875, Accuracy: 89.29549407958984
Epoch: 780, Loss:0.2743760645389557, Accuracy: 89.29624938964844
Epoch: 781, Loss:0.2743563652038574, Accuracy: 89.29666137695312
Epoch: 782, Loss:0.27432990074157715, Accuracy: 89.29774475097656
Epoch: 783, Loss:0.27430957555770874, Accuracy: 89.29840850830078
Epoch: 784, Loss:0.2742873728275299, Accuracy: 89.299072265625
Epoch: 785, Loss:0.2742636203765869, Accuracy: 89.29924011230469
Epoch: 786, Loss:0.2742382884025574, Accuracy: 89.29998016357422
Epoch: 787, Loss:0.27421510219573975, Accuracy: 89.30097198486328
Epoch: 788, Loss:0.2741906940937042, Accuracy: 89.30220794677734
Epoch: 789, Loss:0.2741665

Epoch: 903, Loss:0.27154338359832764, Accuracy: 89.38259887695312
Epoch: 904, Loss:0.27152398228645325, Accuracy: 89.38291931152344
Epoch: 905, Loss:0.271501362323761, Accuracy: 89.38346099853516
Epoch: 906, Loss:0.27147752046585083, Accuracy: 89.38385009765625
Epoch: 907, Loss:0.2714542746543884, Accuracy: 89.38461303710938
Epoch: 908, Loss:0.271432489156723, Accuracy: 89.38529205322266
Epoch: 909, Loss:0.27140870690345764, Accuracy: 89.38575744628906
Epoch: 910, Loss:0.27138757705688477, Accuracy: 89.38652038574219
Epoch: 911, Loss:0.2713712453842163, Accuracy: 89.38690185546875
Epoch: 912, Loss:0.27134931087493896, Accuracy: 89.38780975341797
Epoch: 913, Loss:0.27132540941238403, Accuracy: 89.38878631591797
Epoch: 914, Loss:0.27130138874053955, Accuracy: 89.38946533203125
Epoch: 915, Loss:0.271278440952301, Accuracy: 89.38984680175781
Epoch: 916, Loss:0.271254301071167, Accuracy: 89.3901596069336
Epoch: 917, Loss:0.2712303698062897, Accuracy: 89.39068603515625
Epoch: 918, Loss:0.271

## 데이터셋 및 학습 파라미터 저장

In [48]:
np.savez_compressed('ch2_dataset.npz',inputs=pts,labels=labels)

w_h,b_h = model.d1.get_weights()
w_o,b_o = model.d2.get_weights()
w_h = np.transpose(w_h)
w_o = np.transpose(w_o)
np.savez_compressed('ch2_parameters.npz',
                   w_h=w_h,
                   b_h=b_h,
                   w_o=w_o,
                   b_o=b_o)