# Mini-batch SGD: linear regression on the Boston Housing dataset

In [6]:
import tensorflow as tf

In [12]:
# Load the Boston Housing train/test splits into notebook globals.
# NOTE(review): main() below reloads the dataset itself and never reads these
# globals, so this cell is redundant for the training run.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.boston_housing.load_data()

In [15]:
%matplotlib inline
import numpy as np
import tensorflow as tf
import keras
import matplotlib.pyplot as plt

def get_batches(x, y, batch_size, rng=None):
    """Yield shuffled mini-batches ``(x_batch, y_batch)`` covering the data once.

    The rows of ``x`` and ``y`` are permuted with the same random index order
    and then cut into consecutive slices of ``batch_size`` rows. The final
    batch is smaller when ``len(x)`` is not a multiple of ``batch_size``.

    Args:
        x: Array supporting integer-array (fancy) indexing, e.g. a NumPy
            array; rows are samples.
        y: Targets aligned row-for-row with ``x``; same indexing requirement.
        batch_size: Positive number of samples per batch.
        rng: Optional NumPy random source with a ``shuffle`` method
            (``np.random.Generator`` or ``np.random.RandomState``). Pass one
            for reproducible batches; defaults to NumPy's global random state.

    Yields:
        Tuples ``(x_batch, y_batch)``.

    Raises:
        ValueError: if ``batch_size`` is not positive or ``x`` and ``y`` differ
            in length. Because this is a generator, the error surfaces on the
            first iteration.
    """
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got {}'.format(batch_size))
    n_data = len(x)
    if len(y) != n_data:
        raise ValueError(
            'x and y must have the same length, got {} and {}'.format(n_data, len(y))
        )

    # One shuffled permutation per call, so every sample appears exactly once.
    indices = np.arange(n_data)
    (np.random if rng is None else rng).shuffle(indices)
    x_shuffled = x[indices]
    y_shuffled = y[indices]

    # Slice the shuffled data into consecutive chunks of batch_size samples.
    for start in range(0, n_data, batch_size):
        stop = start + batch_size
        yield x_shuffled[start:stop], y_shuffled[start:stop]

def main():
    """Fit a linear regression to Boston Housing with mini-batch SGD (TF1 graph API).

    Pipeline:
      1. Load the dataset via ``tf.keras.datasets.boston_housing``.
      2. Standardize features and targets with *training-set* mean/std only.
         The test split reuses those statistics, so no test information
         leaks into training.
      3. Build ``pred = x @ w`` (no bias term) and minimize the mean squared
         error with plain gradient descent on shuffled mini-batches from
         ``get_batches``, printing the training loss at every step.
      4. Evaluate the trained weights on the test split and print the test
         MSE (standardized units) and RMSE (original target units).

    NOTE(review): relies on the TensorFlow 1.x API (``tf.placeholder``,
    ``tf.Session``, ``tf.train``). Under TF 2.x it would need ``tf.compat.v1``
    with eager execution disabled.
    """
    # Training hyper-parameters, kept in one place.
    BATCH_SIZE = 32
    N_EPOCHS = 100
    LEARNING_RATE = 0.1

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.boston_housing.load_data()

    # 3.6.5 Data preprocessing

    ## Per-column mean and standard deviation, computed on the training split only.
    x_train_mean = x_train.mean(axis=0)
    x_train_std = x_train.std(axis=0)

    y_train_mean = y_train.mean(axis=0)
    y_train_std = y_train.std(axis=0)

    ## Standardize. The test split deliberately reuses the training statistics.
    x_train = (x_train - x_train_mean) / x_train_std
    y_train = (y_train - y_train_mean) / y_train_std
    x_test = (x_test - x_train_mean) / x_train_std
    y_test = (y_test - y_train_mean) / y_train_std

    # 3.6.6 Model definition

    n_features = x_train.shape[1]

    ## Placeholders for the explanatory variables and the target.
    x = tf.placeholder(tf.float32, (None, n_features), name='x')
    y = tf.placeholder(tf.float32, (None, 1), name='y')

    ## Simplest model: a weighted sum of the features with no bias term.
    ## On the training split x and y are both centered to zero mean, so the
    ## least-squares intercept there is 0.
    w = tf.Variable(tf.random_normal((n_features, 1)))
    pred = tf.matmul(x, w)

    # 3.6.7 Loss function and training

    ## Loss: mean of the squared differences between observed and predicted values.
    loss = tf.reduce_mean((y - pred) ** 2)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE)
    train_step = optimizer.minimize(loss)

    step = 0
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _epoch in range(N_EPOCHS):
            # get_batches reshuffles on every call, i.e. once per epoch.
            for x_batch, y_batch in get_batches(x_train, y_train, BATCH_SIZE):
                train_loss, _ = sess.run(
                    [loss, train_step],
                    feed_dict={
                        x: x_batch,
                        # Reshape to the (None, 1) shape of the y placeholder.
                        y: y_batch.reshape((-1, 1))
                    }
                )
                print('step: {}, train_loss: {}'.format(
                    step, train_loss
                ))
                step += 1

        # Generalization check on the held-out split (standardized scale).
        test_loss = sess.run(
            loss,
            feed_dict={
                x: x_test,
                y: y_test.reshape((-1, 1))
            }
        )

    # Undo the target standardization for an interpretable error:
    # RMSE_original = sqrt(MSE_standardized) * y_train_std.
    test_rmse = float(test_loss) ** 0.5 * y_train_std
    print('test_loss: {}, test_rmse (original units): {}'.format(
        test_loss, test_rmse
    ))
        
if __name__ == '__main__':
    # A notebook kernel also runs with __name__ == '__main__', so executing
    # this cell starts training right away.
    main()

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Using TensorFlow backend.


step: 0, train_loss: 14.008430480957031
step: 1, train_loss: 3.232260227203369
step: 2, train_loss: 4.01450252532959
step: 3, train_loss: 2.9002275466918945
step: 4, train_loss: 2.1070008277893066
step: 5, train_loss: 1.9160442352294922
step: 6, train_loss: 1.711913824081421
step: 7, train_loss: 1.160210371017456
step: 8, train_loss: 1.379097819328308
step: 9, train_loss: 0.6659513711929321
step: 10, train_loss: 1.1420364379882812
step: 11, train_loss: 0.6633274555206299
step: 12, train_loss: 0.8447467684745789
step: 13, train_loss: 0.8466216325759888
step: 14, train_loss: 0.5403729677200317
step: 15, train_loss: 0.6734892129898071
step: 16, train_loss: 0.7682428359985352
step: 17, train_loss: 0.6870162487030029
step: 18, train_loss: 0.4518558084964752
step: 19, train_loss: 0.4145524203777313
step: 20, train_loss: 0.2626575231552124
step: 21, train_loss: 0.22827431559562683
step: 22, train_loss: 0.3983951508998871
step: 23, train_loss: 0.34788060188293457
step: 24, train_loss: 0.344885

step: 386, train_loss: 0.17334899306297302
step: 387, train_loss: 0.6522247195243835
step: 388, train_loss: 0.13865801692008972
step: 389, train_loss: 0.4219469130039215
step: 390, train_loss: 0.18905581533908844
step: 391, train_loss: 0.3926783502101898
step: 392, train_loss: 0.24527069926261902
step: 393, train_loss: 0.2804713249206543
step: 394, train_loss: 0.34605276584625244
step: 395, train_loss: 0.15872813761234283
step: 396, train_loss: 0.1271091252565384
step: 397, train_loss: 0.14514890313148499
step: 398, train_loss: 0.46624553203582764
step: 399, train_loss: 0.31294482946395874
step: 400, train_loss: 0.3285251557826996
step: 401, train_loss: 0.5208274126052856
step: 402, train_loss: 0.21300990879535675
step: 403, train_loss: 0.1602523922920227
step: 404, train_loss: 0.1344030350446701
step: 405, train_loss: 0.34937334060668945
step: 406, train_loss: 0.10045763850212097
step: 407, train_loss: 0.2658616304397583
step: 408, train_loss: 0.2750292718410492
step: 409, train_loss:

step: 596, train_loss: 0.2911413908004761
step: 597, train_loss: 0.1331476867198944
step: 598, train_loss: 0.34271323680877686
step: 599, train_loss: 0.31566277146339417
step: 600, train_loss: 0.2940559983253479
step: 601, train_loss: 0.20008929073810577
step: 602, train_loss: 0.27096623182296753
step: 603, train_loss: 0.2775583267211914
step: 604, train_loss: 0.5009604096412659
step: 605, train_loss: 0.2723824083805084
step: 606, train_loss: 0.23215752840042114
step: 607, train_loss: 0.25479376316070557
step: 608, train_loss: 0.4163118898868561
step: 609, train_loss: 0.20994526147842407
step: 610, train_loss: 0.27825847268104553
step: 611, train_loss: 0.24447865784168243
step: 612, train_loss: 0.21700716018676758
step: 613, train_loss: 0.23154279589653015
step: 614, train_loss: 0.38820087909698486
step: 615, train_loss: 0.2664022445678711
step: 616, train_loss: 0.389880508184433
step: 617, train_loss: 0.4533211290836334
step: 618, train_loss: 0.43576347827911377
step: 619, train_loss:

step: 822, train_loss: 0.20410215854644775
step: 823, train_loss: 0.1889324188232422
step: 824, train_loss: 0.2196887880563736
step: 825, train_loss: 0.4603906571865082
step: 826, train_loss: 0.20757940411567688
step: 827, train_loss: 0.3443073332309723
step: 828, train_loss: 0.36068934202194214
step: 829, train_loss: 0.28228893876075745
step: 830, train_loss: 0.3012748658657074
step: 831, train_loss: 0.5194291472434998
step: 832, train_loss: 0.6296333074569702
step: 833, train_loss: 0.1747715175151825
step: 834, train_loss: 0.22036001086235046
step: 835, train_loss: 0.15081924200057983
step: 836, train_loss: 0.5396764874458313
step: 837, train_loss: 0.21151328086853027
step: 838, train_loss: 0.24379168450832367
step: 839, train_loss: 0.27113819122314453
step: 840, train_loss: 0.22049781680107117
step: 841, train_loss: 0.5321215987205505
step: 842, train_loss: 0.3399600386619568
step: 843, train_loss: 0.25596368312835693
step: 844, train_loss: 0.15425148606300354
step: 845, train_loss:

step: 1040, train_loss: 0.33052343130111694
step: 1041, train_loss: 0.35218024253845215
step: 1042, train_loss: 0.10492303222417831
step: 1043, train_loss: 0.19614559412002563
step: 1044, train_loss: 0.22317419946193695
step: 1045, train_loss: 0.21263869106769562
step: 1046, train_loss: 0.5240539908409119
step: 1047, train_loss: 0.25581812858581543
step: 1048, train_loss: 0.4670151174068451
step: 1049, train_loss: 0.3727512061595917
step: 1050, train_loss: 0.2629610300064087
step: 1051, train_loss: 0.22346968948841095
step: 1052, train_loss: 0.1627078801393509
step: 1053, train_loss: 0.1544293463230133
step: 1054, train_loss: 0.2090354859828949
step: 1055, train_loss: 0.24574433267116547
step: 1056, train_loss: 0.5191115736961365
step: 1057, train_loss: 0.18802635371685028
step: 1058, train_loss: 0.4988161027431488
step: 1059, train_loss: 0.24234847724437714
step: 1060, train_loss: 0.1852511763572693
step: 1061, train_loss: 0.21938839554786682
step: 1062, train_loss: 0.2282642126083374

step: 1258, train_loss: 0.2492714375257492
step: 1259, train_loss: 0.3447043001651764
step: 1260, train_loss: 0.11174464225769043
step: 1261, train_loss: 0.1824822723865509
step: 1262, train_loss: 0.26380839943885803
step: 1263, train_loss: 0.3814055919647217
step: 1264, train_loss: 0.198882594704628
step: 1265, train_loss: 0.5457421541213989
step: 1266, train_loss: 0.24775193631649017
step: 1267, train_loss: 0.2575637698173523
step: 1268, train_loss: 0.5060690641403198
step: 1269, train_loss: 0.12608584761619568
step: 1270, train_loss: 0.262966513633728
step: 1271, train_loss: 0.21340428292751312
step: 1272, train_loss: 0.2053951919078827
step: 1273, train_loss: 0.3181235194206238
step: 1274, train_loss: 0.15516111254692078
step: 1275, train_loss: 0.1997602880001068
step: 1276, train_loss: 0.258098840713501
step: 1277, train_loss: 0.22002345323562622
step: 1278, train_loss: 0.10333068668842316
step: 1279, train_loss: 0.22742587327957153
step: 1280, train_loss: 0.13219033181667328
step