In [27]:
import tensorflow as tf
import numpy as np

In [28]:
# 하이퍼 파라미터
epochs = 1000

In [29]:
# 네트워크 구조 정의
# input 2, hidden 128(sigmoid), output 10(softmax)

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.d1 = tf.keras.layers.Dense(128, input_dim=2, activation='sigmoid')
        self.d2 = tf.keras.layers.Dense(10, activation='softmax')
            
    def call(self, x, training=None, mask=None):
        x = self.d1(x)
        return self.d2(x)

In [48]:
# 학습 루프 정의

@tf.function
def train_step(model, inputs, labels, loss_object, optimizer, train_loss, train_metric):
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables) # grad(loss) df(x)/dx
    
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_metric(labels, predictions)

In [71]:
# 데이터셋 생성, 전처리

np.random.seed(0)

pts = list()
labels = list()
center_pts = np.random.uniform(-8.0, 8.0, (10, 2))
for label, center_pt in enumerate(center_pts):
    for _ in range(100):
        pts.append(center_pt + np.random.randn(*center_pt.shape))
        labels.append(label)

pts = np.stack(pts, axis=0).astype(np.float32)
labels = np.stack(labels, axis=0)

train_ds = tf.data.Dataset.from_tensor_slices((pts, labels)).shuffle(1000).batch(32)

[[ 2.2750952  3.2378716]
 [ 1.0940838  2.5889342]
 [-1.7719737  4.0966487]
 ...
 [ 4.8216805  4.9303718]
 [ 5.094139   6.609091 ]
 [ 4.7251554  5.316574 ]]


In [50]:
# 모델 생성

model = MyModel()

In [51]:
# 손실 함수 및 최적화 알고리즘 설정
# CrossEntropy, Adam optimzier

loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

In [52]:
# 평가 지표

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

In [55]:
# 학습 루프

for epoch in range(epochs):
    for x, label in train_ds:
        train_step(model, x, label, loss_object, optimizer, train_loss, train_accuracy)
        
    template = 'Epoch {}, Loss: {}, Accuracy: {}'
    print(template.format(epoch + 1,
                        train_loss.result(),
                        train_accuracy.result() * 100))

Epoch 1, Loss: 0.30156880617141724, Accuracy: 88.72235870361328
Epoch 2, Loss: 0.3014827072620392, Accuracy: 88.72313690185547
Epoch 3, Loss: 0.3014134168624878, Accuracy: 88.72488403320312
Epoch 4, Loss: 0.30132919549942017, Accuracy: 88.72501373291016
Epoch 5, Loss: 0.30124300718307495, Accuracy: 88.72611236572266
Epoch 6, Loss: 0.3011716604232788, Accuracy: 88.72752380371094
Epoch 7, Loss: 0.30108392238616943, Accuracy: 88.72860717773438
Epoch 8, Loss: 0.3010120391845703, Accuracy: 88.73082733154297
Epoch 9, Loss: 0.30093568563461304, Accuracy: 88.7314224243164
Epoch 10, Loss: 0.30085715651512146, Accuracy: 88.73249053955078
Epoch 11, Loss: 0.3007792830467224, Accuracy: 88.73420715332031
Epoch 12, Loss: 0.30069825053215027, Accuracy: 88.73543548583984
Epoch 13, Loss: 0.30061593651771545, Accuracy: 88.7361831665039
Epoch 14, Loss: 0.3005494177341461, Accuracy: 88.7370834350586
Epoch 15, Loss: 0.30047714710235596, Accuracy: 88.73878479003906
Epoch 16, Loss: 0.3004002273082733, Accurac

Epoch 136, Loss: 0.29280006885528564, Accuracy: 88.8777847290039
Epoch 137, Loss: 0.2927463948726654, Accuracy: 88.87875366210938
Epoch 138, Loss: 0.2926924228668213, Accuracy: 88.87970733642578
Epoch 139, Loss: 0.2926429510116577, Accuracy: 88.88066864013672
Epoch 140, Loss: 0.29258736968040466, Accuracy: 88.88162994384766
Epoch 141, Loss: 0.29252487421035767, Accuracy: 88.88245391845703
Epoch 142, Loss: 0.2924659550189972, Accuracy: 88.88353729248047
Epoch 143, Loss: 0.29241374135017395, Accuracy: 88.88435363769531
Epoch 144, Loss: 0.29235246777534485, Accuracy: 88.88556671142578
Epoch 145, Loss: 0.292289137840271, Accuracy: 88.88717651367188
Epoch 146, Loss: 0.29223567247390747, Accuracy: 88.88797760009766
Epoch 147, Loss: 0.29217755794525146, Accuracy: 88.88957214355469
Epoch 148, Loss: 0.2921234965324402, Accuracy: 88.89090728759766
Epoch 149, Loss: 0.29206401109695435, Accuracy: 88.8919677734375
Epoch 150, Loss: 0.2920067012310028, Accuracy: 88.89263916015625
Epoch 151, Loss: 0.2

Epoch 262, Loss: 0.2866162061691284, Accuracy: 89.00394439697266
Epoch 263, Loss: 0.2865673899650574, Accuracy: 89.00553894042969
Epoch 264, Loss: 0.28652068972587585, Accuracy: 89.00656127929688
Epoch 265, Loss: 0.2864765226840973, Accuracy: 89.00724029541016
Epoch 266, Loss: 0.28643348813056946, Accuracy: 89.00779724121094
Epoch 267, Loss: 0.2863905131816864, Accuracy: 89.00869750976562
Epoch 268, Loss: 0.28635096549987793, Accuracy: 89.00949096679688
Epoch 269, Loss: 0.28631532192230225, Accuracy: 89.00992584228516
Epoch 270, Loss: 0.28627142310142517, Accuracy: 89.01116943359375
Epoch 271, Loss: 0.28622549772262573, Accuracy: 89.01240539550781
Epoch 272, Loss: 0.28618451952934265, Accuracy: 89.01317596435547
Epoch 273, Loss: 0.28614386916160583, Accuracy: 89.01350402832031
Epoch 274, Loss: 0.2860976457595825, Accuracy: 89.0145034790039
Epoch 275, Loss: 0.28605180978775024, Accuracy: 89.01561737060547
Epoch 276, Loss: 0.28600645065307617, Accuracy: 89.01728820800781
Epoch 277, Loss:

Epoch 388, Loss: 0.2817050814628601, Accuracy: 89.12190246582031
Epoch 389, Loss: 0.2816716730594635, Accuracy: 89.12287902832031
Epoch 390, Loss: 0.28163808584213257, Accuracy: 89.1237564086914
Epoch 391, Loss: 0.2816062569618225, Accuracy: 89.12452697753906
Epoch 392, Loss: 0.281582772731781, Accuracy: 89.12520599365234
Epoch 393, Loss: 0.28154340386390686, Accuracy: 89.1258773803711
Epoch 394, Loss: 0.28150510787963867, Accuracy: 89.12684631347656
Epoch 395, Loss: 0.2814728617668152, Accuracy: 89.1272201538086
Epoch 396, Loss: 0.28143510222435, Accuracy: 89.12808227539062
Epoch 397, Loss: 0.2813982665538788, Accuracy: 89.12884521484375
Epoch 398, Loss: 0.2813619077205658, Accuracy: 89.12970733642578
Epoch 399, Loss: 0.2813222110271454, Accuracy: 89.13076782226562
Epoch 400, Loss: 0.2812904417514801, Accuracy: 89.13162994384766
Epoch 401, Loss: 0.28125709295272827, Accuracy: 89.1325912475586
Epoch 402, Loss: 0.28121793270111084, Accuracy: 89.13334655761719
Epoch 403, Loss: 0.28117990

Epoch 514, Loss: 0.2777085304260254, Accuracy: 89.23664855957031
Epoch 515, Loss: 0.2776825428009033, Accuracy: 89.23741149902344
Epoch 516, Loss: 0.2776491343975067, Accuracy: 89.23834991455078
Epoch 517, Loss: 0.2776152789592743, Accuracy: 89.23929595947266
Epoch 518, Loss: 0.27758461236953735, Accuracy: 89.24049377441406
Epoch 519, Loss: 0.27755773067474365, Accuracy: 89.24161529541016
Epoch 520, Loss: 0.27752920985221863, Accuracy: 89.2422866821289
Epoch 521, Loss: 0.27750566601753235, Accuracy: 89.24295043945312
Epoch 522, Loss: 0.2774868905544281, Accuracy: 89.24406433105469
Epoch 523, Loss: 0.2774563133716583, Accuracy: 89.24437713623047
Epoch 524, Loss: 0.2774309813976288, Accuracy: 89.2452163696289
Epoch 525, Loss: 0.27740970253944397, Accuracy: 89.24588012695312
Epoch 526, Loss: 0.27737969160079956, Accuracy: 89.24681091308594
Epoch 527, Loss: 0.2773492932319641, Accuracy: 89.24755096435547
Epoch 528, Loss: 0.2773158848285675, Accuracy: 89.24839782714844
Epoch 529, Loss: 0.27

Epoch 640, Loss: 0.27435898780822754, Accuracy: 89.34281921386719
Epoch 641, Loss: 0.27434107661247253, Accuracy: 89.34366607666016
Epoch 642, Loss: 0.2743181586265564, Accuracy: 89.34442901611328
Epoch 643, Loss: 0.2742922604084015, Accuracy: 89.34534454345703
Epoch 644, Loss: 0.27426597476005554, Accuracy: 89.34579467773438
Epoch 645, Loss: 0.27424535155296326, Accuracy: 89.34671020507812
Epoch 646, Loss: 0.274222731590271, Accuracy: 89.34818267822266
Epoch 647, Loss: 0.2741992175579071, Accuracy: 89.34862518310547
Epoch 648, Loss: 0.27417850494384766, Accuracy: 89.34954071044922
Epoch 649, Loss: 0.27415403723716736, Accuracy: 89.35028839111328
Epoch 650, Loss: 0.27413010597229004, Accuracy: 89.35128784179688
Epoch 651, Loss: 0.27410465478897095, Accuracy: 89.35243225097656
Epoch 652, Loss: 0.27408596873283386, Accuracy: 89.35270690917969
Epoch 653, Loss: 0.2740619480609894, Accuracy: 89.35297393798828
Epoch 654, Loss: 0.2740464210510254, Accuracy: 89.35396575927734
Epoch 655, Loss: 

Epoch 766, Loss: 0.27144938707351685, Accuracy: 89.43916320800781
Epoch 767, Loss: 0.27142608165740967, Accuracy: 89.44000244140625
Epoch 768, Loss: 0.271406352519989, Accuracy: 89.44098663330078
Epoch 769, Loss: 0.2713828384876251, Accuracy: 89.4417495727539
Epoch 770, Loss: 0.2713591754436493, Accuracy: 89.44259643554688
Epoch 771, Loss: 0.27133792638778687, Accuracy: 89.44342803955078
Epoch 772, Loss: 0.27131417393684387, Accuracy: 89.44419860839844
Epoch 773, Loss: 0.2712911367416382, Accuracy: 89.44509887695312
Epoch 774, Loss: 0.2712709307670593, Accuracy: 89.44535827636719
Epoch 775, Loss: 0.2712465822696686, Accuracy: 89.44611358642578
Epoch 776, Loss: 0.27122583985328674, Accuracy: 89.4468765258789
Epoch 777, Loss: 0.27120521664619446, Accuracy: 89.44756317138672
Epoch 778, Loss: 0.27118197083473206, Accuracy: 89.44824981689453
Epoch 779, Loss: 0.271159291267395, Accuracy: 89.44886016845703
Epoch 780, Loss: 0.2711416780948639, Accuracy: 89.44925689697266
Epoch 781, Loss: 0.271

Epoch 892, Loss: 0.2689889967441559, Accuracy: 89.52041625976562
Epoch 893, Loss: 0.2689710557460785, Accuracy: 89.52066040039062
Epoch 894, Loss: 0.2689577639102936, Accuracy: 89.52124786376953
Epoch 895, Loss: 0.268937885761261, Accuracy: 89.52196502685547
Epoch 896, Loss: 0.2689160406589508, Accuracy: 89.5224838256836
Epoch 897, Loss: 0.2688998579978943, Accuracy: 89.52306365966797
Epoch 898, Loss: 0.2688785791397095, Accuracy: 89.52377319335938
Epoch 899, Loss: 0.26885923743247986, Accuracy: 89.52415466308594
Epoch 900, Loss: 0.2688368558883667, Accuracy: 89.5250015258789
Epoch 901, Loss: 0.2688194215297699, Accuracy: 89.5255126953125
Epoch 902, Loss: 0.26880475878715515, Accuracy: 89.5264892578125
Epoch 903, Loss: 0.2687869668006897, Accuracy: 89.52726745605469
Epoch 904, Loss: 0.2687755227088928, Accuracy: 89.52796936035156
Epoch 905, Loss: 0.2687586545944214, Accuracy: 89.5284194946289
Epoch 906, Loss: 0.26874205470085144, Accuracy: 89.52925109863281
Epoch 907, Loss: 0.268721789

In [59]:
np.savez_compressed('ch2_dataset.npz', inputs=pts, labels=labels)

W_h, b_h = model.d1.get_weights()
W_o, b_o = model.d2.get_weights()
W_h = np.transpose(W_h)
W_o = np.transpose(W_o)
np.savez_compressed('ch2_parameters.npz',
                   W_h=W_h,
                   b_h=b_h,
                   W_o=W_o,
                   b_o=b_o)