In [13]:
import tensorflow as tf

#### 선형회귀 모델 정의

In [14]:
class LinearRegression(tf.keras.Model):
    """Single-output linear regression model: y_pred = x @ W + b.

    The model is one Dense layer with a single unit and no activation, so its
    output is an affine function of the input features.
    """

    def __init__(self):
        super().__init__()
        # One unit with identity activation -> a plain affine map.
        self.linear_layer = tf.keras.layers.Dense(1, activation=None)

    def call(self, x):
        """Return the linear prediction for the batch `x`.

        Dense requires at least 2-D input (e.g. (batch, features)); the
        result has a trailing dimension of 1.
        """
        return self.linear_layer(x)


#### 손실 함수 정의
MSE(평균 제곱 오차) 손실 함수: 예측값 $y'$와 정답 $y$의 차이를 제곱하여 평균한 값

$$\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(y'_i - y_i\right)^2$$

In [15]:
@tf.function
def mse_loss(y_pred, y):
    """Mean squared error between predictions and targets.

    Computes mean((y_pred - y)^2) over every element of the difference tensor.

    NOTE(review): `y_pred - y` broadcasts, so both arguments should have the
    same shape. For example, (batch, 1) predictions against (batch,) targets
    would silently average over a (batch, batch) matrix when batch > 1.
    """
    residual = y_pred - y
    squared_residual = tf.square(residual)
    return tf.reduce_mean(squared_residual)

#### 옵티마이저 정의

In [16]:
# Plain SGD with learning rate 0.01 (all other arguments at their defaults).
optimizer = tf.optimizers.SGD(learning_rate=0.01)

#### 텐서보드 기록 경로 설정 및 SummaryWriter 선언 (`tf.summary.create_file_writer`)

In [17]:
# TensorBoard event-file writer; events are written under ./tensorboard_log.
summary_writer = tf.summary.create_file_writer(logdir='./tensorboard_log')

#### 최적화를 위한 학습 스텝(train_step) 함수 정의

In [18]:
@tf.function
def train_step(model, x, y):
    """Run one optimizer update of `model` on the batch (x, y) and log the loss.

    Args:
        model: Keras model to train; called as `model(x)`.
        x: Input batch passed to the model.
        y: Target batch; should have the same shape as `model(x)` so the
            squared error does not broadcast.

    Returns:
        The scalar MSE loss computed on this batch before the update.
    """
    with tf.GradientTape() as tape:
        y_pred = model(x)
        loss = mse_loss(y_pred, y)
    # Write the loss to TensorBoard every step. optimizer.iterations is read
    # here, before apply_gradients below increments it.
    with summary_writer.as_default():
        tf.summary.scalar('loss', loss, step=optimizer.iterations)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Returning the loss lets callers monitor training; existing callers that
    # ignore the return value are unaffected.
    return loss

#### 트레이닝 데이터셋

In [19]:
# Toy training set whose targets follow y = 2x exactly.
x_train = [1.0, 2.0, 3.0, 4.0]
y_train = [2.0, 4.0, 6.0, 8.0]

# Pair each x with its y, cycle through the pairs indefinitely, and serve
# them one example per batch.
train_data = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .repeat()
    .batch(1)
)
train_data_iter = iter(train_data)

#### 모델 선언

In [20]:
# Fix the global TensorFlow seed so the Dense layer's random weight
# initialization (and therefore the printed predictions) is reproducible
# across Restart & Run All.
# NOTE(review): under Keras 3, tf.keras.utils.set_random_seed may be needed
# instead — confirm against the installed TF version.
tf.random.set_seed(42)
LinearRegression_model = LinearRegression()

#### 경사하강법 수행

In [21]:
for _ in range(1000):
    batch_xs, batch_ys = next(train_data_iter)
    # tf.keras.layers.Dense needs at least 2-D input, so turn the (batch,)
    # vectors into (batch, 1) columns. Expanding the LAST axis keeps the batch
    # dimension first, so this also works for batch sizes > 1. Expanding the
    # targets the same way makes them match the (batch, 1) predictions, so the
    # loss does not broadcast into a (batch, batch) matrix.
    batch_xs = tf.expand_dims(batch_xs, -1)
    batch_ys = tf.expand_dims(batch_ys, -1)
    train_step(LinearRegression_model, batch_xs, batch_ys)

# Make sure all buffered TensorBoard events are written to disk.
summary_writer.flush()

#### 테스트 데이터셋

In [22]:
x_test = [3.5, 5.0, 5.5, 6.0]
# Inputs only (no targets), served one example per batch. Note that
# `(x_test)` is not a tuple, so the plain list is passed directly.
test_data = tf.data.Dataset.from_tensor_slices(x_test).batch(1)

#### 테스트 데이터에 대한 예측값 측정

In [23]:
for batch_x_test in test_data:
    # (batch,) -> (batch, 1): Dense needs 2-D input with features on the last
    # axis; expanding the last axis keeps this correct for batch sizes > 1.
    batch_x_test = tf.expand_dims(batch_x_test, -1)
    # Predictions come back as (batch, 1); drop the feature axis for printing.
    print(tf.squeeze(LinearRegression_model(batch_x_test), -1).numpy())

[6.9976463]
[9.985136]
[10.980967]
[11.976796]
