# 경사 하강법을 이용한 얕은 신경망 학습


In [1]:
import tensorflow as tf
import numpy as np

## 하이퍼 파라미터 설정

In [2]:
EPOCHS = 1000

## 네트워크 구조 정의
### 얕은 신경망
#### 입력 계층 : 2, 은닉 계층 : 128 (Sigmoid activation), 출력 계층 : 10 (Softmax activation)

In [3]:
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.d1 = tf.keras.layers.Dense(128, input_dim=2, activation='sigmoid')
        self.d2 = tf.keras.layers.Dense(10, input_dim=2, activation='softmax')
        
    def call(self, x, training=None, mask=None):
        x = self.d1(x)
        return self.d2(x)

## 학습 루프 정의

In [4]:
@tf.function
def train_step(model, inputs, labels, loss_object, optimizer, train_loss, train_metric):
    with tf.GradientTape() as tape: # 내부에서 계산되는 Gradient는 계산되고 된 내용을 tape로 저장한다.
        predictions = model(inputs)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables) #df(x)/dx
    #loss는  각 model.trainable_variables 편미분해서 gradent를 계산한다.
    
    #편미분한 결과를 gradients에 적용이되서 optimizer에 학습이 되는 과정
    optimizer.apply_gradients(zip(gradients, model.trainable_variables)) 
    train_loss(loss)
    train_metric(labels, predictions) # 학습한 결과(acc, 정확도)
        

## 데이터셋 생성, 전처리

In [5]:
np.random.seed(0) # 랜덤값을 동일하게 나옴

pts = list() #입력값
labels = list() # 출력값
center_pts = np.random.uniform(-8.0, 8.0, (10, 2))
for label, center_pt in enumerate(center_pts):
    for _ in range(100):
        pts.append(center_pt + np.random.randn(*center_pt.shape))
        labels.append(label)

pts = np.stack(pts, axis=0).astype(np.float32)
labels = np.stack(labels, axis=0)

train_ds = tf.data.Dataset.from_tensor_slices((pts, labels)).shuffle(1000).batch(32)

## 모델 생성

In [6]:
model = MyModel()

## 손실 함수 및 최적화 알고리즘 설정
### CrossEntropy, Adam Optimizer

In [7]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy() #몇번째 1인가 인덱스 번호를 나타낸다.
optimizer = tf.keras.optimizers.Adam()

## 평가 지표 설정
### Accuracy

In [8]:
train_loss = tf.keras.metrics.Mean(name ='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

## 학습 루프

In [9]:
for epoch in range(EPOCHS):
    for x, label in train_ds:
        train_step(model, x, label, loss_object, optimizer, train_loss, train_accuracy)
        
    template = 'Epoch: {}, Loss: {}, Accuracy: {}'
    print(template.format( epoch +1,
                            train_loss.result(),
                            train_accuracy.result() * 100))

Epoch: 1, Loss: 2.2107772827148438, Accuracy: 17.5
Epoch: 2, Loss: 2.0359303951263428, Accuracy: 31.25
Epoch: 3, Loss: 1.9011949300765991, Accuracy: 41.80000305175781
Epoch: 4, Loss: 1.7886520624160767, Accuracy: 45.375
Epoch: 5, Loss: 1.6955524682998657, Accuracy: 49.86000061035156
Epoch: 6, Loss: 1.6181648969650269, Accuracy: 52.79999923706055
Epoch: 7, Loss: 1.5507185459136963, Accuracy: 55.9714241027832
Epoch: 8, Loss: 1.4894002676010132, Accuracy: 58.912498474121094
Epoch: 9, Loss: 1.434766173362732, Accuracy: 61.333335876464844
Epoch: 10, Loss: 1.3851691484451294, Accuracy: 63.44000244140625
Epoch: 11, Loss: 1.339597225189209, Accuracy: 65.30000305175781
Epoch: 12, Loss: 1.2982627153396606, Accuracy: 66.82500457763672
Epoch: 13, Loss: 1.2603131532669067, Accuracy: 68.26154327392578
Epoch: 14, Loss: 1.2255849838256836, Accuracy: 69.50714111328125
Epoch: 15, Loss: 1.1930230855941772, Accuracy: 70.61332702636719
Epoch: 16, Loss: 1.1621571779251099, Accuracy: 71.55000305175781
Epoch:

Epoch: 138, Loss: 0.44521045684814453, Accuracy: 86.5927505493164
Epoch: 139, Loss: 0.4440793991088867, Accuracy: 86.61151123046875
Epoch: 140, Loss: 0.442891925573349, Accuracy: 86.62786102294922
Epoch: 141, Loss: 0.44167694449424744, Accuracy: 86.64396667480469
Epoch: 142, Loss: 0.44047486782073975, Accuracy: 86.66338348388672
Epoch: 143, Loss: 0.43928438425064087, Accuracy: 86.68181610107422
Epoch: 144, Loss: 0.4381304979324341, Accuracy: 86.69999694824219
Epoch: 145, Loss: 0.4369945526123047, Accuracy: 86.71793365478516
Epoch: 146, Loss: 0.43584275245666504, Accuracy: 86.73287963867188
Epoch: 147, Loss: 0.4347047209739685, Accuracy: 86.75102233886719
Epoch: 148, Loss: 0.4336088001728058, Accuracy: 86.7641830444336
Epoch: 149, Loss: 0.4325118660926819, Accuracy: 86.7832260131836
Epoch: 150, Loss: 0.43142715096473694, Accuracy: 86.79800415039062
Epoch: 151, Loss: 0.43037712574005127, Accuracy: 86.8099365234375
Epoch: 152, Loss: 0.429314523935318, Accuracy: 86.82500457763672
Epoch: 15

Epoch: 267, Loss: 0.3593278229236603, Accuracy: 87.83370971679688
Epoch: 268, Loss: 0.35899344086647034, Accuracy: 87.83880615234375
Epoch: 269, Loss: 0.3586233854293823, Accuracy: 87.84461212158203
Epoch: 270, Loss: 0.35828450322151184, Accuracy: 87.84925842285156
Epoch: 271, Loss: 0.3579125702381134, Accuracy: 87.85387420654297
Epoch: 272, Loss: 0.3575493097305298, Accuracy: 87.85698699951172
Epoch: 273, Loss: 0.3572101593017578, Accuracy: 87.86336517333984
Epoch: 274, Loss: 0.3568713068962097, Accuracy: 87.86751556396484
Epoch: 275, Loss: 0.3565399944782257, Accuracy: 87.87200164794922
Epoch: 276, Loss: 0.35621553659439087, Accuracy: 87.87572479248047
Epoch: 277, Loss: 0.3558514714241028, Accuracy: 87.8801498413086
Epoch: 278, Loss: 0.35552483797073364, Accuracy: 87.88633728027344
Epoch: 279, Loss: 0.35518893599510193, Accuracy: 87.89282989501953
Epoch: 280, Loss: 0.3548983037471771, Accuracy: 87.89785766601562
Epoch: 281, Loss: 0.35457882285118103, Accuracy: 87.90142822265625
Epoch

Epoch: 397, Loss: 0.3269120156764984, Accuracy: 88.29244232177734
Epoch: 398, Loss: 0.3267514109611511, Accuracy: 88.29573059082031
Epoch: 399, Loss: 0.32657545804977417, Accuracy: 88.29774475097656
Epoch: 400, Loss: 0.32640379667282104, Accuracy: 88.29900360107422
Epoch: 401, Loss: 0.3262465000152588, Accuracy: 88.30175018310547
Epoch: 402, Loss: 0.32607564330101013, Accuracy: 88.30496978759766
Epoch: 403, Loss: 0.325899600982666, Accuracy: 88.30644989013672
Epoch: 404, Loss: 0.32573065161705017, Accuracy: 88.30816650390625
Epoch: 405, Loss: 0.3255542814731598, Accuracy: 88.31135559082031
Epoch: 406, Loss: 0.3253750503063202, Accuracy: 88.31281280517578
Epoch: 407, Loss: 0.3252173066139221, Accuracy: 88.31498718261719
Epoch: 408, Loss: 0.325057715177536, Accuracy: 88.31764221191406
Epoch: 409, Loss: 0.3248845636844635, Accuracy: 88.3207778930664
Epoch: 410, Loss: 0.3247022032737732, Accuracy: 88.32366180419922
Epoch: 411, Loss: 0.3245413601398468, Accuracy: 88.32627868652344
Epoch: 41

Epoch: 522, Loss: 0.30992960929870605, Accuracy: 88.55459594726562
Epoch: 523, Loss: 0.3098350465297699, Accuracy: 88.55679321289062
Epoch: 524, Loss: 0.3097262680530548, Accuracy: 88.55802154541016
Epoch: 525, Loss: 0.3096121549606323, Accuracy: 88.55905151367188
Epoch: 526, Loss: 0.3095170855522156, Accuracy: 88.5606460571289
Epoch: 527, Loss: 0.3094134032726288, Accuracy: 88.56128692626953
Epoch: 528, Loss: 0.3093212842941284, Accuracy: 88.56382751464844
Epoch: 529, Loss: 0.3092156946659088, Accuracy: 88.56427001953125
Epoch: 530, Loss: 0.30912140011787415, Accuracy: 88.56584930419922
Epoch: 531, Loss: 0.30902084708213806, Accuracy: 88.56761169433594
Epoch: 532, Loss: 0.3089268207550049, Accuracy: 88.57048797607422
Epoch: 533, Loss: 0.3088375926017761, Accuracy: 88.57241821289062
Epoch: 534, Loss: 0.30873027443885803, Accuracy: 88.57396697998047
Epoch: 535, Loss: 0.3086338937282562, Accuracy: 88.57588958740234
Epoch: 536, Loss: 0.3085242211818695, Accuracy: 88.5766830444336
Epoch: 5

Epoch: 648, Loss: 0.29907411336898804, Accuracy: 88.73981475830078
Epoch: 649, Loss: 0.29900139570236206, Accuracy: 88.74129486083984
Epoch: 650, Loss: 0.2989378571510315, Accuracy: 88.74246215820312
Epoch: 651, Loss: 0.2988603413105011, Accuracy: 88.74408721923828
Epoch: 652, Loss: 0.2987823784351349, Accuracy: 88.74524688720703
Epoch: 653, Loss: 0.2987108826637268, Accuracy: 88.74609375
Epoch: 654, Loss: 0.29865026473999023, Accuracy: 88.74785614013672
Epoch: 655, Loss: 0.2985716760158539, Accuracy: 88.74900817871094
Epoch: 656, Loss: 0.2984974682331085, Accuracy: 88.75091552734375
Epoch: 657, Loss: 0.29842495918273926, Accuracy: 88.75235748291016
Epoch: 658, Loss: 0.2983741760253906, Accuracy: 88.75379943847656
Epoch: 659, Loss: 0.2982981503009796, Accuracy: 88.75493621826172
Epoch: 660, Loss: 0.29822081327438354, Accuracy: 88.75666809082031
Epoch: 661, Loss: 0.2981449365615845, Accuracy: 88.75763702392578
Epoch: 662, Loss: 0.29807132482528687, Accuracy: 88.75785827636719
Epoch: 663

Epoch: 773, Loss: 0.2913610637187958, Accuracy: 88.89146423339844
Epoch: 774, Loss: 0.29130426049232483, Accuracy: 88.89224243164062
Epoch: 775, Loss: 0.29125142097473145, Accuracy: 88.89432525634766
Epoch: 776, Loss: 0.2912086546421051, Accuracy: 88.8953628540039
Epoch: 777, Loss: 0.29115724563598633, Accuracy: 88.89601135253906
Epoch: 778, Loss: 0.2910979688167572, Accuracy: 88.89665222167969
Epoch: 779, Loss: 0.29103884100914, Accuracy: 88.89768981933594
Epoch: 780, Loss: 0.2909887731075287, Accuracy: 88.89910125732422
Epoch: 781, Loss: 0.2909320294857025, Accuracy: 88.9000015258789
Epoch: 782, Loss: 0.2908782362937927, Accuracy: 88.9010238647461
Epoch: 783, Loss: 0.2908230125904083, Accuracy: 88.90204620361328
Epoch: 784, Loss: 0.29078081250190735, Accuracy: 88.90306091308594
Epoch: 785, Loss: 0.29072681069374084, Accuracy: 88.90433502197266
Epoch: 786, Loss: 0.2906731963157654, Accuracy: 88.90547180175781
Epoch: 787, Loss: 0.29062145948410034, Accuracy: 88.90647888183594
Epoch: 78

Epoch: 901, Loss: 0.28546103835105896, Accuracy: 89.02464294433594
Epoch: 902, Loss: 0.28542062640190125, Accuracy: 89.02505493164062
Epoch: 903, Loss: 0.2853851616382599, Accuracy: 89.02558135986328
Epoch: 904, Loss: 0.2853495180606842, Accuracy: 89.0268783569336
Epoch: 905, Loss: 0.28530454635620117, Accuracy: 89.02806854248047
Epoch: 906, Loss: 0.28527137637138367, Accuracy: 89.02924346923828
Epoch: 907, Loss: 0.2852277159690857, Accuracy: 89.03031921386719
Epoch: 908, Loss: 0.2851821184158325, Accuracy: 89.03128051757812
Epoch: 909, Loss: 0.2851462662220001, Accuracy: 89.03223419189453
Epoch: 910, Loss: 0.2851056158542633, Accuracy: 89.03318786621094
Epoch: 911, Loss: 0.2850622534751892, Accuracy: 89.03392028808594
Epoch: 912, Loss: 0.28503522276878357, Accuracy: 89.03508758544922
Epoch: 913, Loss: 0.2849985361099243, Accuracy: 89.03636169433594
Epoch: 914, Loss: 0.28495678305625916, Accuracy: 89.03741455078125
Epoch: 915, Loss: 0.28491470217704773, Accuracy: 89.03825378417969
Epoc

## 데이터셋 및 학습 파라미터 저장

In [10]:
np.savez_compressed('ch2_dataset.npz', inputs=pts, labels= labels)

W_h, b_h = model.d1.get_weights()
W_o, b_o = model.d2.get_weights()
W_h = np.transpose(W_h)
W_o = np.transpose(W_o)
np.savez_compressed('ch2_parameters.npz',
                    w_h = W_h,
                    b_h = b_h,
                    w_o = W_o,
                    b_o = b_o)