## tensorflow 实现线性回归

In [1]:
import tensorflow as tf
import tensorflow.keras as keras
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import time
import random
import os

# Report the name and installed version of each library used in this
# notebook, so the results can be tied to a specific environment.
for module in (tf, keras, pd, mpl, np, sklearn):
    print(module.__name__,
          module.__version__)

tensorflow 2.3.0
tensorflow.keras 2.4.0
pandas 1.3.3
matplotlib 3.4.2
numpy 1.18.5
sklearn 1.0


### 1. 生成数据集

In [2]:
# Number of features (columns) per sample.
feature_num = 3
# Number of samples (rows) to generate.
input_num = 10000

# Draw the features from a normal distribution with stddev 2.
feature = tf.random.normal(shape=(input_num, feature_num), stddev=2)

# Ground-truth weights and bias used to synthesize the labels.
true_w = [1.2, -2.2, 3.]
true_b = 4.2
# The label formula below needs exactly one weight per feature column.
if len(true_w) != feature_num:
    raise ValueError("true_w must have exactly feature_num entries")

# labels = feature @ true_w + true_b, written as one matrix-vector product
# so it stays correct for any feature_num (instead of one term per column).
true_w_tensor = tf.constant(true_w, dtype=feature.dtype)
labels = tf.linalg.matvec(feature, true_w_tensor) + true_b
# Add small Gaussian noise to the labels.
labels += tf.random.normal(labels.shape, stddev=0.02)
print(feature.shape, labels.shape)
print(feature[:5].numpy(), "\n", labels[:5].numpy())

(10000, 3) (10000,)
[[-3.549167   -1.2205924  -2.4376402 ]
 [-2.0892808   1.4278817   0.4083136 ]
 [-3.5043025   1.7323956  -3.1746438 ]
 [ 2.7191412  -0.4684315  -0.63798606]
 [ 1.140864    0.6738385   0.50637037]] 
 [ -4.692578    -0.22133102 -13.344547     6.610779     5.5883646 ]


### 2. 读取数据

- dataset.shuffle / dataset.batch :
    1. shuffle 用于打乱数据顺序。
    2. shuffle 的 buffer_size 大于等于样本数时，才能对整个数据集做完全均匀的打乱；buffer_size 较小时只在一个滑动窗口内打乱。之后用 batch 按 batch_size 把数据划分成小批量。

In [13]:

# Wrap the full (feature, labels) tensors in a tf.data pipeline,
# one element per sample.
dataset = tf.data.Dataset.from_tensor_slices((feature, labels))

# Shuffle the samples. A buffer_size >= the number of samples (input_num)
# shuffles uniformly over the whole dataset; a smaller buffer only
# shuffles within a sliding window.
dataset = dataset.shuffle(buffer_size=input_num)
# Split into mini-batches of batch_size samples (the last batch may be
# smaller, since drop_remainder defaults to False).
batch_size = 3000
dataset_batch = dataset.batch(batch_size)
# Print one mini-batch. The loop variables use distinct names so they do
# not overwrite the full `feature` tensor created in the previous cell;
# reusing `feature` here would break re-running this cell.
for batch_feature, batch_label in dataset_batch.take(1):
    print(batch_feature, "\n", batch_label)

tf.Tensor(
[[ 0.727077   -1.912732   -0.75884783]
 [-3.2444994  -0.53297764 -3.7582388 ]
 [ 0.25169364  1.02628     1.434147  ]
 ...
 [ 3.1750088  -1.8997864   1.6457089 ]
 [ 0.22117324  1.7619385  -4.0208354 ]
 [ 1.1663364  -0.89056647 -3.559205  ]], shape=(3000, 3), dtype=float32) 
 tf.Tensor(
[  6.9954944  -9.786368    6.5648537 ...  17.133238  -11.490167
  -3.1036744], shape=(3000,), dtype=float32)
