Skip to content

Latest commit

 

History

History
206 lines (121 loc) · 7.33 KB

README.md

File metadata and controls

206 lines (121 loc) · 7.33 KB

C 语言实现 LSTM 算法

LSTM 算法简介

LSTM 全称是 "Long Short-Term Memory",一种用来学习大量时序序列中隐含的相关性并用于预测其可能的趋势的机器学习算法。它的应用范围包括但不局限于价格走势预测、估计剩余寿命、分析语言的情感趋向、自动写作和语音合成。

算法描述了一个用于计算的工作单元,它按照时间顺序接受自然数作为输入,通过计算得到对应的输出。当一系列输入计算完成时,也就得到了对应的一个输出序列。单个 LSTM 单元的学习能力是有限的,可以将输出的序列作为输入序列给另外一个 LSTM 作为输入,通过这种方式利用多个 LSTM 单元的组合提高整体学习能力。

LSTM 的计算过程

作为机器学习算法的一种,LSTM 的应用包括利用大量数据进行训练和根据训练得到的参数预测两个步骤。其预测过程使用正向传播算法,训练过程采用误差反向传播算法。

LSTM 的正向传播算法

一个处理 1 维序列的 LSTM 单元有 12 个参数。这里将这 12 个参数表示为:

假设输入序列 x 和输出序列 h 分别有 n 个元素,表示为:

每次计算中还会产生以下的临时变量:

最后 LSTM 的计算可以表示为:

里面的一个函数符号表示 Sigmoid 函数:

LSTM 的误差反向传播算法

假设期望输出为:

采用均方误差(MSE,Mean Squard Error)来评估实际输出与期望输出的误差:

那么在一次正向传播后,LSTM 输出序列 h 的每个元素对 E 的影响可以用下面的一阶偏导数表示:

进行误差反向传播需要使用 E 对 12 个参数的每一个的偏导数。与普通的神经网络算法不同的是,LSTM 利用 C(t-1) 和 h(t-1) 参与第 t 次的计算,使得第 t 次之前的计算结果会对第 t 次的输出 h(t) 产生影响。

由于:

所以 t = 1 时:

当 t > 1 时:

为方便计算,激活函数导数可取:

最后:

一般情况下学习率取值:

采用简单的梯度下降,可以在正向传播后修正参数:

C 实现方法

用结构体来保存计算过程所需的变量,并提供一个函数用来初始化并返回这个结构体。后续提供一系列的函数用于操作这个结构体。

数据结构

结构体中变量名称和算法的参数之间的对应关系是:

结构体变量定义对应算法中的变量
整数 lengthint length;表示 LSTM 计算序列长度
浮点数指针 xdouble *x;输入序列 x
浮点数指针 hdouble *h;输出序列 h
浮点数指针 fdouble *f;中间变量序列 f
浮点数指针 idouble *i;中间变量序列 i
浮点数指针 tilde_Cdouble *tilde_C;中间变量序列
浮点数指针 Cdouble *C;中间变量序列 C
浮点数指针 odouble *o;中间变量序列 o
浮点数指针 hat_hdouble *hat_h;期望输出序列
浮点数 W_fhdouble W_fh;参数
浮点数 W_fxdouble W_fx;参数
浮点数 b_fdouble b_f;参数
浮点数 W_ihdouble W_ih;参数
浮点数 W_ixdouble W_ix;参数
浮点数 b_idouble b_i;参数
浮点数 W_Chdouble W_Ch;参数
浮点数 W_Cxdouble W_Cx;参数
浮点数 b_Cdouble b_C;参数
浮点数 W_ohdouble W_oh;参数
浮点数 W_oxdouble W_ox;参数
浮点数 b_odouble b_o;参数

下面的参数不是算法必须的,但是实现时会使用:

结构体变量定义说明
整数 error_noint error_no;错误号,无错误默认0。用于记录最后一次程序发生的错误。
字符指针 error_msgchar *error_msg;发生的错误的文字说明,默认无错误,内容为指向字符串"\0"的指针。

错误编号和错误信息

错误号信息说明
0"\0"无错误。
1"not enough memory"内存不足。

接口文档

1. struct lstmlib* lstmlib_create(int length);

参数
  1. length:LSTM 接受输入序列的长度。
返回值

返回一个 struct lstmlib* 结构体指针,或者失败时返回 NULL

功能

创建一个 LSTM 单元,并返回一个结构体指针。可以对这个结构体指针使用 lstmlib 其他函数进行操作。lstmlib_create 方法会自动调用 lstmlib_random_params 对参数进行初始化赋值,赋值的范围是[-1,1]。

2. char lstmlib_random_params(struct lstmlib *unit, double min, double max);

参数
  1. unit:一个 LSTM 单元结构体指针。
  2. min:最小值。
  3. max:最大值。
返回值

成功返回 1,失败返回0

功能

对指定的 LSTM 单元的参数进行随机初始化,初始化的范围是[min, max]。

3. char lstmlib_run(struct lstmlib *unit, double *input, double *output);

参数
  1. unit:一个 LSTM 单元结构体指针。
  2. input:输入浮点数数组指针。
  3. output:输出浮点数数组指针。
返回值

执行成功返回 1,失败返回0

功能

以指定inputoutput作为输入输出区域,运行 LSTM。

4. double lstmlib_get_mse(struct lstmlib *unit);

参数
  1. unit:LSTM 单元结构体。
返回值

一个浮点数,MSE。