In [None]:
%matplotlib inline
# %load noise_reduction.py

# Noise reduction
# using NN

# The observed data are the desired signal signal(x) with noise noise(x) on top:
# obs(x) = signal(x) + noise(x)
# This notebook trains a model that removes the noise from such observations.

import matplotlib.pyplot as plt
import tensorflow as tf
import math
import numpy as np
import os
from datetime import datetime


# Problem settings ------------------------------------------------------------
# The noise has a Gaussian shape.
# Its mean (mu) and sigma are set here.
noise_mu = 15
noise_sigma = 10

# One observation is a set of number_of_points real numbers
# obs(x[0]), obs(x[1]), ..., obs(x[number_of_points-1]).
# The sample points are the integers -size1 .. size1, hence 2 * size1 + 1.
size1 = 20
number_of_points = 2 * size1 + 1
number_of_data = 100
# The last number_of_test observations are held out as test data; the first
# number_of_train observations are used for training.
number_of_test = 3
number_of_train = number_of_data - number_of_test


# Algorithm settings ----------------------------------------------------------
# Number of observations fed to a single training step
mini_batch_size = 20
# Number of perceptrons at 2nd layer
# default: 10000
number_of_2nd_perceptrons = 10000
# Standard deviation of the truncated-normal initial values of the weights
stddev_of_perceptrons = 0.03  # other values noted: 0.1, 0.03
# Learning rate handed to the gradient-descent optimizer
learning_rate = 0.9  # other values noted: 0.1, 0.9


# Helper functions -------------------------------------------------------------
def Gaussian(x, mu, sigma):
    """Evaluate an unnormalised Gaussian at a single point.

    Computes exp(-(x - mu)^2 / (2 * sigma^2)); the peak value at x == mu is 1
    (no 1 / (sigma * sqrt(2 * pi)) prefactor is applied).

    Args:
        x: Scalar position at which to evaluate the curve.
        mu: Centre of the curve.
        sigma: Width of the curve. A value of 0 raises ZeroDivisionError.

    Returns:
        The curve height at x as a Python float.
    """
    offset = x - mu
    exponent = -0.5 * offset ** 2 / sigma ** 2
    return math.exp(exponent)


# Generation of the pseudo observation data -----------------------------------
print("# of sample points: ", number_of_points)
sample_points = range(-size1, size1 + 1)
signal_set = np.zeros((number_of_data, number_of_points))
obs_set = np.zeros((number_of_data, number_of_points))

# Draw the per-observation width factors from a dedicated, seeded generator so
# that "Restart & Run All" regenerates exactly the same data set and the global
# NumPy random state is left untouched.
data_rng = np.random.RandomState(0)
d_sigs = data_rng.rand(number_of_data).astype("float32")

# The noise depends only on the sample position, not on the observation index,
# so it is evaluated once here instead of once per observation.
noise_profile = np.array(
    [0.2 * Gaussian(x, noise_mu, noise_sigma) for x in sample_points])

for j in range(number_of_data):
    # Width of the j-th signal: 0.5 plus a random share of 0.05 * size1.
    signal_sigma = 0.5 + 0.05 * size1 * d_sigs[j]
    # Clean signal: Gaussian centred at 0 with the width chosen above.
    signal_set[j] = [Gaussian(x, 0, signal_sigma) for x in sample_points]
    # Observation: clean signal plus the fixed noise profile.
    obs_set[j] = signal_set[j] + noise_profile


# Visual sanity check of the generated data ------------------------------------
# Print the legend, then draw the first (at most two) observations: the clean
# signal first and the noisy observation on top of it in the same figure.
for legend_line in (
        '\nSome data for training are shown bellow.',
        '  blue: signal',
        '  yellow: signal+noise'):
    print(legend_line)

preview_count = min(2, number_of_data)
for j in range(preview_count):
    print('\nTraining data #' + str(j))
    for curve in (signal_set[j], obs_set[j]):
        plt.scatter(sample_points, curve)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.show()

In [None]:
# Definition and execution of the computation graph ---------------------------
# Start from an empty default graph. Without this, re-running the cell in the
# same kernel adds a second copy of every op to the old graph, and
# tf.summary.merge_all() below would also collect the summaries of the earlier
# run, which depend on placeholders that are no longer fed.
tf.reset_default_graph()

# Input: a batch of observed curves (signal + noise), one curve per row
obs = tf.placeholder(tf.float32, [None, number_of_points], name="observed")

# Log the input data: each curve is written as a number_of_points x 1
# single-channel image, at most 2 images per summary
img = tf.reshape(obs, [-1, number_of_points, 1, 1])
tf.summary.image("log_input_data", img, 2)

# Input layer -> hidden layer (fully connected, ReLU activation)
with tf.name_scope("second_layer"):
    w_1 = tf.Variable(tf.truncated_normal(
        [number_of_points, number_of_2nd_perceptrons],
        stddev=stddev_of_perceptrons), name="w1")
    b_1 = tf.Variable(tf.zeros([number_of_2nd_perceptrons]), name="b1")
    h_1 = tf.nn.relu(tf.matmul(obs, w_1) + b_1)

    # Log the distribution of the hidden-layer weights
    tf.summary.histogram('log_w_1', w_1)

# Hidden layer -> output layer (fully connected, linear)
with tf.name_scope("output_layer"):
    w_2 = tf.Variable(tf.truncated_normal(
        [number_of_2nd_perceptrons, number_of_points],
        stddev=stddev_of_perceptrons), name="w2")
    b_2 = tf.Variable(tf.zeros([number_of_points]), name="b2")
    out = tf.matmul(h_1, w_2) + b_2

# Loss function: mean squared error between the true signal and the output
real_signal = tf.placeholder(
    tf.float32, [None, number_of_points], name="real_signal")
with tf.name_scope("loss"):
    loss = tf.reduce_mean(tf.square(real_signal - out))

    # Log the loss
    tf.summary.scalar("log_loss", loss)

# Training step: plain gradient descent on the loss
with tf.name_scope("train"):
    train_step = tf.train.GradientDescentOptimizer(
        learning_rate).minimize(loss)

# Variable initialisation
init = tf.global_variables_initializer()

# Single op that evaluates every summary registered above
summary_op = tf.summary.merge_all()

# Train for one pass over the training rows in mini-batches; periodically
# denoise the held-out test rows, plot them against the true signal, and write
# the summaries to a time-stamped directory under "logs".
with tf.Session() as sess:
    log_path = os.path.join("logs", datetime.now().strftime("%Y%m%d-%H%M%S"))
    summary_writer = tf.summary.FileWriter(log_path, sess.graph)

    # The held-out rows never change during training, so slice them once.
    test_obs = obs_set[number_of_train:number_of_data]
    # True signals of the held-out rows (used for plotting and for logging)
    test_real_signal = signal_set[number_of_train:number_of_data]

    try:
        sess.run(init)

        for j in range(0, number_of_train, mini_batch_size):
            # The last mini-batch may be shorter than mini_batch_size.
            idx_end = min(j + mini_batch_size, number_of_train)
            print(
                "\n\nTraining data from #" + str(j),
                "to #" + str(idx_end - 1), "are trained."
            )

            sess.run(
                train_step,
                feed_dict={
                    obs: obs_set[j:idx_end],
                    real_signal: signal_set[j:idx_end]}
            )

            # Report only for mini-batches whose start index is a multiple of
            # 40; the test-set inference is needed only then.
            if j % 40 == 0:
                # Remove the noise from test_obs
                outVal = sess.run(out, feed_dict={obs: test_obs})

                print('\nPredicted signal and real signal are shown bellow.')
                print('  blue: real signal')
                print('  magenda: predicted signal')
                for i in range(len(outVal)):
                    print('\nTest data #' + str(i))
                    plt.scatter(sample_points, test_real_signal[i])
                    # Marker size and colour spelled out as keywords
                    # (previously passed positionally).
                    plt.scatter(
                        sample_points, outVal[i], s=number_of_points, c='m')
                    plt.xlabel("x")
                    plt.ylabel("y")
                    plt.show()

                # Evaluate the summaries (the result is a serialized protocol
                # buffer holding the log information)
                summary_str = sess.run(
                    summary_op,
                    feed_dict={obs: test_obs, real_signal: test_real_signal})
                # Write the summary, using the number of training rows seen so
                # far as the step
                summary_writer.add_summary(summary_str, idx_end)
    finally:
        # Flush pending events and release the log file even if training
        # fails part-way through.
        summary_writer.close()

In [None]:
# Visualize for description----------------------------------------------------
def _show_scatter(values, ylim=None, **scatter_kwargs):
    """Draw one scatter figure of `values` over `sample_points` and show it.

    Args:
        values: y-values, one per entry of the module-level `sample_points`.
        ylim: Optional (low, high) pair passed to plt.ylim; when None the
            y-axis limits are left to matplotlib.
        **scatter_kwargs: Extra keyword arguments forwarded to plt.scatter
            (e.g. c=... or color=...).
    """
    plt.scatter(sample_points, values, **scatter_kwargs)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.show()


for legend_line in (
        '\nSome data are shown bellow.',
        '  blue: signal',
        '  yellow: signal+noise'):
    print(legend_line)

# Only the first observation (if any) is shown, split into three figures:
# the clean signal, the noisy observation, and their difference (the noise).
for j in range(min(1, number_of_data)):
    index_label = str(j)
    unit_range = (-0.1, 1.1)

    print('\nSignal data #' + index_label)
    _show_scatter(signal_set[j], ylim=unit_range)

    print('\nObserved data #' + index_label)
    _show_scatter(obs_set[j], c='#ff7f0e')

    print('\nNoise data #' + index_label)
    _show_scatter(obs_set[j] - signal_set[j], ylim=unit_range, color='g')