In [11]:
import tensorflow as tf
import numpy as np

### 하이퍼 파라미터 설정

In [12]:
EPOCHS = 1000

### 네트워크 구조 정의

In [20]:
#input 2 / hidden 128 (sigmoid) / output 10 (softmax) => classification
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.d1 = tf.keras.layers.Dense(128, input_dim=2, activation='sigmoid')
        self.d2 = tf.keras.layers.Dense(10, activation='softmax')
    
    def call(self, x, training=None, mask=None):
        x = self.d1(x)
        return self.d2(x)

### 학습 루프 정의

In [21]:
@tf.function
def train_step(model, inputs, labels, loss_object, optimizer, train_loss, train_metric):
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables) #df(x)/dx
    
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_metric(labels, predictions)

### 데이터셋 생성, 전처리

In [22]:
np.random.seed(0)

pts = list() #input
labels = list() #output
center_pts = np.random.uniform(-8.0, 8.0, (10, 2))

for label, center_pt in enumerate(center_pts):
    for _ in range(100):
        pts.append(center_pt + np.random.randn(*center_pt.shape))
        labels.append(label)
        
pts = np.stack(pts, axis=0).astype(np.float32)
labels = np.stack(labels, axis=0)

train_ds = tf.data.Dataset.from_tensor_slices((pts, labels)).shuffle(1000).batch(8)

### 모델 생성

In [23]:
model = MyModel()

### loss, optimize algorithm

- CrossEntropy, Adam Optimizer

In [24]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy() #ohe
optimizer = tf.keras.optimizers.Adam()

### Metric ; acc

In [25]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

### 학습 루프

In [27]:
for epoch in range(EPOCHS):
    for x, label in train_ds:
        train_step(model, x, label, loss_object, optimizer, train_loss, train_accuracy)
        
    template = 'Epoch {}, Loss : {}, Accuracy : {}%'
    print(template.format(epoch + 1, train_loss.result(), train_accuracy.result()*100))

Epoch 1, Loss : 0.4151080548763275, Accuracy : 86.52652740478516%
Epoch 2, Loss : 0.4131194055080414, Accuracy : 86.5607681274414%
Epoch 3, Loss : 0.4112066328525543, Accuracy : 86.595458984375%
Epoch 4, Loss : 0.4093279242515564, Accuracy : 86.61375427246094%
Epoch 5, Loss : 0.40746328234672546, Accuracy : 86.64537048339844%
Epoch 6, Loss : 0.4056413769721985, Accuracy : 86.6761245727539%
Epoch 7, Loss : 0.40383878350257874, Accuracy : 86.6979751586914%
Epoch 8, Loss : 0.40211665630340576, Accuracy : 86.71925354003906%
Epoch 9, Loss : 0.40042370557785034, Accuracy : 86.74388885498047%
Epoch 10, Loss : 0.39881575107574463, Accuracy : 86.7704849243164%
Epoch 11, Loss : 0.3972133994102478, Accuracy : 86.8066177368164%
Epoch 12, Loss : 0.3956109583377838, Accuracy : 86.83805847167969%
Epoch 13, Loss : 0.3941078186035156, Accuracy : 86.87244415283203%
Epoch 14, Loss : 0.39261239767074585, Accuracy : 86.89984130859375%
Epoch 15, Loss : 0.39118295907974243, Accuracy : 86.92170715332031%
Epoc

Epoch 124, Loss : 0.3205508291721344, Accuracy : 88.13665771484375%
Epoch 125, Loss : 0.3202449083328247, Accuracy : 88.140625%
Epoch 126, Loss : 0.31995150446891785, Accuracy : 88.14507293701172%
Epoch 127, Loss : 0.31965723633766174, Accuracy : 88.15049743652344%
Epoch 128, Loss : 0.31938090920448303, Accuracy : 88.15535736083984%
Epoch 129, Loss : 0.31912288069725037, Accuracy : 88.15813446044922%
Epoch 130, Loss : 0.3188386857509613, Accuracy : 88.16291046142578%
Epoch 131, Loss : 0.31854861974716187, Accuracy : 88.1711654663086%
Epoch 132, Loss : 0.3182729184627533, Accuracy : 88.17582702636719%
Epoch 133, Loss : 0.31800660490989685, Accuracy : 88.1804428100586%
Epoch 134, Loss : 0.31772908568382263, Accuracy : 88.18600463867188%
Epoch 135, Loss : 0.3174368739128113, Accuracy : 88.18952941894531%
Epoch 136, Loss : 0.3171665072441101, Accuracy : 88.19696044921875%
Epoch 137, Loss : 0.3169054687023163, Accuracy : 88.20040130615234%
Epoch 138, Loss : 0.31664156913757324, Accuracy : 8

Epoch 246, Loss : 0.297529935836792, Accuracy : 88.56507873535156%
Epoch 247, Loss : 0.2974061667919159, Accuracy : 88.5667724609375%
Epoch 248, Loss : 0.29729050397872925, Accuracy : 88.56941223144531%
Epoch 249, Loss : 0.29716911911964417, Accuracy : 88.57172393798828%
Epoch 250, Loss : 0.29705023765563965, Accuracy : 88.57591247558594%
Epoch 251, Loss : 0.29693126678466797, Accuracy : 88.57818603515625%
Epoch 252, Loss : 0.2968160808086395, Accuracy : 88.5788803100586%
Epoch 253, Loss : 0.2966962158679962, Accuracy : 88.58050537109375%
Epoch 254, Loss : 0.2965872883796692, Accuracy : 88.58119201660156%
Epoch 255, Loss : 0.29646435379981995, Accuracy : 88.58341979980469%
Epoch 256, Loss : 0.29634982347488403, Accuracy : 88.58470916748047%
Epoch 257, Loss : 0.2962389588356018, Accuracy : 88.58537292480469%
Epoch 258, Loss : 0.296118825674057, Accuracy : 88.58726501464844%
Epoch 259, Loss : 0.2960049510002136, Accuracy : 88.5879135131836%
Epoch 260, Loss : 0.2958947718143463, Accuracy 

Epoch 367, Loss : 0.2866007685661316, Accuracy : 88.77719116210938%
Epoch 368, Loss : 0.2865365445613861, Accuracy : 88.77885437011719%
Epoch 369, Loss : 0.2864694893360138, Accuracy : 88.77981567382812%
Epoch 370, Loss : 0.2864019572734833, Accuracy : 88.78055572509766%
Epoch 371, Loss : 0.2863345742225647, Accuracy : 88.78082275390625%
Epoch 372, Loss : 0.2862589955329895, Accuracy : 88.78382110595703%
Epoch 373, Loss : 0.2861862778663635, Accuracy : 88.7874984741211%
Epoch 374, Loss : 0.2861136198043823, Accuracy : 88.78888702392578%
Epoch 375, Loss : 0.28605276346206665, Accuracy : 88.78913879394531%
Epoch 376, Loss : 0.28599298000335693, Accuracy : 88.79051208496094%
Epoch 377, Loss : 0.28592461347579956, Accuracy : 88.79076385498047%
Epoch 378, Loss : 0.2858545780181885, Accuracy : 88.79280090332031%
Epoch 379, Loss : 0.28578752279281616, Accuracy : 88.79461669921875%
Epoch 380, Loss : 0.2857225835323334, Accuracy : 88.79551696777344%
Epoch 381, Loss : 0.2856620252132416, Accurac

Epoch 490, Loss : 0.2796526551246643, Accuracy : 88.9388656616211%
Epoch 491, Loss : 0.27960652112960815, Accuracy : 88.93951416015625%
Epoch 492, Loss : 0.27955982089042664, Accuracy : 88.93944549560547%
Epoch 493, Loss : 0.2795122563838959, Accuracy : 88.9400863647461%
Epoch 494, Loss : 0.27946361899375916, Accuracy : 88.94161987304688%
Epoch 495, Loss : 0.27941322326660156, Accuracy : 88.94314575195312%
Epoch 496, Loss : 0.27936825156211853, Accuracy : 88.9443130493164%
Epoch 497, Loss : 0.279326856136322, Accuracy : 88.94512176513672%
Epoch 498, Loss : 0.27928319573402405, Accuracy : 88.94628143310547%
Epoch 499, Loss : 0.2792414724826813, Accuracy : 88.94743347167969%
Epoch 500, Loss : 0.2791941165924072, Accuracy : 88.94964599609375%
Epoch 501, Loss : 0.2791444659233093, Accuracy : 88.94990539550781%
Epoch 502, Loss : 0.2791011929512024, Accuracy : 88.95087432861328%
Epoch 503, Loss : 0.2790597975254059, Accuracy : 88.95271301269531%
Epoch 504, Loss : 0.2790156900882721, Accuracy

Epoch 612, Loss : 0.2746812403202057, Accuracy : 89.07807159423828%
Epoch 613, Loss : 0.27464449405670166, Accuracy : 89.07942199707031%
Epoch 614, Loss : 0.27460816502571106, Accuracy : 89.07974243164062%
Epoch 615, Loss : 0.2745717465877533, Accuracy : 89.08080291748047%
Epoch 616, Loss : 0.2745320498943329, Accuracy : 89.08243560791016%
Epoch 617, Loss : 0.274495929479599, Accuracy : 89.08406829833984%
Epoch 618, Loss : 0.2744643986225128, Accuracy : 89.08512115478516%
Epoch 619, Loss : 0.27442917227745056, Accuracy : 89.0870361328125%
Epoch 620, Loss : 0.27439358830451965, Accuracy : 89.08836364746094%
Epoch 621, Loss : 0.27435746788978577, Accuracy : 89.08881378173828%
Epoch 622, Loss : 0.27432289719581604, Accuracy : 89.09027862548828%
Epoch 623, Loss : 0.2742879092693329, Accuracy : 89.09159851074219%
Epoch 624, Loss : 0.27425405383110046, Accuracy : 89.09204864501953%
Epoch 625, Loss : 0.2742207646369934, Accuracy : 89.09306335449219%
Epoch 626, Loss : 0.27418291568756104, Accu

Epoch 734, Loss : 0.2707628309726715, Accuracy : 89.2093276977539%
Epoch 735, Loss : 0.27073535323143005, Accuracy : 89.2100601196289%
Epoch 736, Loss : 0.27070605754852295, Accuracy : 89.2109146118164%
Epoch 737, Loss : 0.27068111300468445, Accuracy : 89.2112808227539%
Epoch 738, Loss : 0.27065324783325195, Accuracy : 89.2122573852539%
Epoch 739, Loss : 0.27062490582466125, Accuracy : 89.2133560180664%
Epoch 740, Loss : 0.2705939710140228, Accuracy : 89.21495056152344%
Epoch 741, Loss : 0.2705664336681366, Accuracy : 89.21665954589844%
Epoch 742, Loss : 0.2705375552177429, Accuracy : 89.21775817871094%
Epoch 743, Loss : 0.2705092430114746, Accuracy : 89.21859741210938%
Epoch 744, Loss : 0.27048054337501526, Accuracy : 89.21981048583984%
Epoch 745, Loss : 0.2704515755176544, Accuracy : 89.22089385986328%
Epoch 746, Loss : 0.27042484283447266, Accuracy : 89.22110748291016%
Epoch 747, Loss : 0.2703993022441864, Accuracy : 89.2223129272461%
Epoch 748, Loss : 0.27037206292152405, Accuracy 

Epoch 857, Loss : 0.26754817366600037, Accuracy : 89.33541870117188%
Epoch 858, Loss : 0.267525851726532, Accuracy : 89.3365707397461%
Epoch 859, Loss : 0.26750364899635315, Accuracy : 89.33706665039062%
Epoch 860, Loss : 0.2674805819988251, Accuracy : 89.33746337890625%
Epoch 861, Loss : 0.26745662093162537, Accuracy : 89.33807373046875%
Epoch 862, Loss : 0.26743653416633606, Accuracy : 89.3388900756836%
Epoch 863, Loss : 0.2674120366573334, Accuracy : 89.33981323242188%
Epoch 864, Loss : 0.2673880159854889, Accuracy : 89.34052276611328%
Epoch 865, Loss : 0.2673627734184265, Accuracy : 89.341552734375%
Epoch 866, Loss : 0.2673371732234955, Accuracy : 89.34300231933594%
Epoch 867, Loss : 0.26731494069099426, Accuracy : 89.34403228759766%
Epoch 868, Loss : 0.2672897279262543, Accuracy : 89.34504699707031%
Epoch 869, Loss : 0.26726651191711426, Accuracy : 89.34628295898438%
Epoch 870, Loss : 0.267246812582016, Accuracy : 89.34677124023438%
Epoch 871, Loss : 0.2672223448753357, Accuracy :

Epoch 979, Loss : 0.26488494873046875, Accuracy : 89.44645690917969%
Epoch 980, Loss : 0.2648635506629944, Accuracy : 89.44698333740234%
Epoch 981, Loss : 0.2648426592350006, Accuracy : 89.44808197021484%
Epoch 982, Loss : 0.264822781085968, Accuracy : 89.44880676269531%
Epoch 983, Loss : 0.2648009955883026, Accuracy : 89.4496078491211%
Epoch 984, Loss : 0.2647811472415924, Accuracy : 89.45022583007812%
Epoch 985, Loss : 0.26476284861564636, Accuracy : 89.45132446289062%
Epoch 986, Loss : 0.2647441625595093, Accuracy : 89.45194244384766%
Epoch 987, Loss : 0.2647254765033722, Accuracy : 89.45293426513672%
Epoch 988, Loss : 0.2647068500518799, Accuracy : 89.4537353515625%
Epoch 989, Loss : 0.26468831300735474, Accuracy : 89.45415496826172%
Epoch 990, Loss : 0.2646670937538147, Accuracy : 89.45523834228516%
Epoch 991, Loss : 0.26464587450027466, Accuracy : 89.45584869384766%
Epoch 992, Loss : 0.26462867856025696, Accuracy : 89.45674133300781%
Epoch 993, Loss : 0.2646068334579468, Accuracy

In [28]:
np.savez_compressed('ch2_dataset.npz', inputs=pts, labels=labels)

W_h, b_h = model.d1.get_weights()
W_o, b_o = model.d2.get_weights()

W_h = np.transpose(W_h)
W_o = np.transpose(W_o)
np.savez_compressed('ch2_parameters.npz', W_h=W_h, b_h=b_h, W_o=W_o, b_o=b_o)