
Commit

Change infinity to np.inf; optimize GPU memory usage
bojone committed Mar 6, 2023
1 parent e2f899a commit 20a4694
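
The gist of the change, as a minimal illustrative sketch (not part of the commit; the tensors below are invented): masked positions are now filled with -np.inf through the broadcasting K.where helper instead of a large finite "infinity" constant through K.switch, e.g.

import numpy as np
from bert4keras.backend import K, sequence_masking

x = K.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # (batch_size, seq_len)
mask = K.constant([[1, 1, 0], [1, 0, 0]])            # padding mask
# Previously value='-inf' resolved to -K.infinity() (default 1e12);
# now -np.inf is passed directly and applied via K.where.
masked = sequence_masking(x, mask, -np.inf, axis=1)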
Showing 4 changed files with 130 additions and 107 deletions.
143 changes: 86 additions & 57 deletions bert4keras/backend.py
@@ -58,18 +58,6 @@ def set_gelu(version):
keras.utils.get_custom_objects()['gelu'] = gelu_tanh


def infinity():
"""Return the default value used to represent infinity
"""
return keras.utils.get_custom_objects().get('infinity', 1e12)


def set_infinity(value):
"""Set a new value to represent infinity
"""
keras.utils.get_custom_objects()['infinity'] = value


def piecewise_linear(t, schedule, from_zero=True):
"""Piecewise linear function
where schedule is a dict of the form {1000: 1, 2000: 0.1},
@@ -93,7 +81,7 @@ def piecewise_linear(t, schedule, from_zero=True):
x = schedule[i][1] + slope * (t - t_begin)
else:
x = (t * 0 + 1) * schedule[i][1]
x = K.switch(t >= t_begin, x, x_begin)
x = K.where(t >= t_begin, x, x_begin)

return x

@@ -177,32 +165,86 @@ def flatten(tensor, start=None, end=None):
return K.reshape(tensor, shape)


def sequence_masking(x, mask, value=0, axis=None):
def dtype(x):
"""A more fault-tolerant version of K.dtype
"""
try:
return K.dtype(x)
except:
pass


def where(cond, x, y):
"""tf.where with automatic broadcasting
"""
shape = tf.broadcast_dynamic_shape(K.shape(x), K.shape(y))
shape = tf.broadcast_dynamic_shape(K.shape(cond), shape)

if dtype(x) is None and dtype(y) is None:
x = tf.broadcast_to(K.constant(x, dtype=K.floatx()), shape)
y = tf.broadcast_to(K.constant(y, dtype=K.floatx()), shape)
elif dtype(x) is None:
x = tf.broadcast_to(K.constant(x, dtype=dtype(y)), shape)
elif dtype(y) is None:
y = tf.broadcast_to(K.constant(y, dtype=dtype(x)), shape)
else:
x = tf.broadcast_to(x, shape)
y = tf.broadcast_to(y, shape)

if dtype(cond) != 'bool':
cond = K.cast(cond, 'bool')

cond = tf.broadcast_to(cond, shape)
return tf.where(cond, x, y)
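
# Usage sketch (illustrative, not part of this commit): because `where` broadcasts
# all three arguments, a plain Python scalar can serve as either branch, e.g.
#
#     scores = K.constant([[1.0, 2.0], [3.0, 4.0]])
#     keep = K.constant([[1, 0], [1, 1]])       # cast to bool internally
#     masked = where(keep, scores, -np.inf)     # -> [[1.0, -inf], [3.0, 4.0]]
#
# The scalar -np.inf is materialized with K.constant and broadcast to the shape
# of `scores` before tf.where is called, so the branch shapes always match.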


def sequence_masking(
x, mask=None, value=0, axis=None, bias=None, return_mask=False
):
"""Function to conditionally mask a sequence
mask: bool matrix of shape (batch_size, seq_len);
value: the value that masked positions are replaced with; may be '-inf' or 'inf';
axis: the axis along which the sequence lies, default 1;
bias: an extra bias term, or an additional mask;
return_mask: whether to also return the aligned mask.
"""
if mask is None:
return x
else:
if value == '-inf':
value = -K.infinity()
elif value == 'inf':
value = K.infinity()
value = K.zeros_like(x) + value
if not (mask is None and bias is None):
if mask is None:
if K.dtype(bias) == 'bool':
mask = bias
x = K.where(mask, x, value)
else:
x = x + bias
else:
if axis is None:
axes = [1]
elif isinstance(axis, list):
axes = axis
else:
axes = [axis]

axes = [axis if axis >= 0 else K.ndim(x) + axis for axis in axes]

if K.dtype(mask) != 'bool':
mask = K.cast(mask, 'bool')

if axis is None:
axis = 1
elif axis < 0:
axis = K.ndim(x) + axis
assert axis > 0, 'axis must be greater than 0'
full_mask = align(mask, [0, axes[0]], K.ndim(x))
for axis in axes[1:]:
full_mask = full_mask & align(mask, [0, axis], K.ndim(x))

if K.dtype(mask) != 'bool':
mask = K.cast(mask, 'bool')
mask = align(mask, [0, axis], K.ndim(x))
mask = full_mask
if bias is None:
x = K.where(mask, x, value)
elif K.dtype(bias) == 'bool':
mask = mask & bias
x = K.where(mask, x, value)
else:
x = K.where(mask, x + bias, value)

return K.switch(mask, x, value)
if return_mask:
return x, mask
else:
return x
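
# Usage sketch (illustrative, not from the diff): masking attention logits of shape
# (batch_size, heads, q_len, k_len) along both sequence axes, with an optional bool
# bias such as a causal mask, now takes a single call:
#
#     a, mask = sequence_masking(a, mask, -np.inf, [2, 3], a_bias, return_mask=True)
#
# where `mask` is the (batch_size, seq_len) padding mask and `a_bias` an extra bool
# mask; masked entries become -np.inf and the aligned bool mask is returned so that
# callers such as attention_normalize can reuse it.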


def batch_gather(params, indices):
@@ -246,7 +288,7 @@ def divisible_temporal_padding(x, n):
"""Right-pad a 1D vector sequence so that its length is divisible by n
"""
r_len = K.shape(x)[1] % n
p_len = K.switch(r_len > 0, n - r_len, 0)
p_len = K.where(r_len > 0, n - r_len, 0)
return K.temporal_padding(x, (0, p_len))
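
# Worked example (illustrative): for an input of shape (batch_size, 7, dim) and
# n=4, r_len = 7 % 4 = 3 and p_len = K.where(3 > 0, 4 - 3, 0) = 1, so the sequence
# is right-padded to length 8.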


@@ -268,17 +310,21 @@ def leaky_relu(x, alpha=0.2):
return tf.nn.leaky_relu(x, alpha=alpha)


def attention_normalize(a, axis=-1, method='softmax'):
def attention_normalize(a, mask=None, axis=-1, method='softmax', bias=None):
"""Different attention normalization schemes
softmax: the conventional/standard exponential normalization;
squared_relu: from https://arxiv.org/abs/2202.10447 ;
softmax_plus: from https://kexue.fm/archives/8823 .
"""
a, mask = sequence_masking(a, mask, -np.inf, axis, bias, True)
if method == 'softmax':
return K.softmax(a, axis=axis)
else:
mask = K.cast(a >= -K.infinity() / 10, K.floatx())
l = K.sum(mask, axis=axis, keepdims=True)
if mask is None:
l = K.cast(K.shape(a)[-1], K.floatx())
else:
mask = K.cast(mask, K.floatx())
l = K.sum(mask, axis=axis, keepdims=True)
if method == 'squared_relu':
return K.relu(a)**2 / l
elif method == 'softmax_plus':
@@ -331,16 +377,6 @@ def apply_rotary_position_embeddings(sinusoidal, *tensors):
return outputs[0] if len(outputs) == 1 else outputs


def log(x, epsilon=None):
"""Add an epsilon inside log to prevent NaN
"""
if epsilon is None:
return tf.math.log(x)
elif epsilon is True:
epsilon = K.epsilon()
return tf.math.log(K.maximum(x, epsilon))


def multilabel_categorical_crossentropy(y_true, y_pred):
"""Cross entropy for multi-label classification
Notes:
@@ -353,12 +389,9 @@ def multilabel_categorical_crossentropy(y_true, y_pred):
4. For details see https://kexue.fm/archives/7359 and
https://kexue.fm/archives/9064 .
"""
y_mask = y_pred > -K.infinity() / 10
n_mask = (y_true < 1 - K.epsilon()) & y_mask
p_mask = (y_true > K.epsilon()) & y_mask
infs = K.zeros_like(y_pred) + K.infinity()
y_neg = K.switch(n_mask, y_pred, -infs) + K.log(1 - y_true, True)
y_pos = K.switch(p_mask, -y_pred, -infs) + K.log(y_true, True)
y_mask = K.not_equal(y_pred, -np.inf)
y_neg = K.where(y_mask, y_pred, -np.inf) + K.log(1 - y_true)
y_pos = K.where(y_mask, -y_pred, -np.inf) + K.log(y_true)
zeros = K.zeros_like(y_pred[..., :1])
y_neg = K.concatenate([y_neg, zeros], axis=-1)
y_pos = K.concatenate([y_pos, zeros], axis=-1)
@@ -381,7 +414,7 @@ def sparse_multilabel_categorical_crossentropy(y_true, y_pred, mask_zero=False):
zeros = K.zeros_like(y_pred[..., :1])
y_pred = K.concatenate([y_pred, zeros], axis=-1)
if mask_zero:
infs = zeros + K.infinity()
infs = zeros + np.inf
y_pred = K.concatenate([infs, y_pred[..., 1:]], axis=-1)
y_pos_2 = batch_gather(y_pred, y_true)
y_pos_1 = K.concatenate([y_pos_2, zeros], axis=-1)
@@ -485,14 +518,10 @@ def actual_grad_fn(*doutputs):
# Supplement tf.keras with logsumexp
K.logsumexp = getattr(K, 'logsumexp', None) or tf.math.reduce_logsumexp

# Modified log function
K.log = log

# Attach to keras.backend so they can be used like K.epsilon()
K.reshape = reshape
K.flatten = flatten
K.infinity = infinity
K.set_infinity = set_infinity
K.where = where
sys.modules['tensorflow.keras.backend'] = K
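
# Usage sketch (illustrative, not part of the commit): after these assignments the
# helpers above are available as attributes of K, alongside the stock backend
# functions, e.g.
#
#     from bert4keras.backend import K
#     y = K.where(cond, scores, -np.inf)   # the broadcasting `where` defined above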

custom_objects = {
68 changes: 32 additions & 36 deletions bert4keras/layers.py
@@ -144,7 +144,7 @@ def __init__(self, data_format='channels_last', **kwargs):

def call(self, inputs, mask=None):
axis = 1 if self.data_format == 'channels_last' else 2
inputs = sequence_masking(inputs, mask, '-inf', axis)
inputs = sequence_masking(inputs, mask, -np.inf, axis)
return K.max(inputs, axis=axis)

def compute_mask(self, inputs, mask=None):
@@ -530,12 +530,9 @@ def pay_attention_to(self, inputs, mask=None, **kwargs):
# Attention (continued)
if self.attention_scale:
a = a / self.key_size**0.5
if a_bias is not None:
if K.ndim(a_bias) == 3:
a_bias = align(a_bias, [0, -2, -1], K.ndim(a))
a = a + a_bias
a = sequence_masking(a, v_mask, '-inf', -1)
A = attention_normalize(a, -1, self.normalization)
if a_bias is not None and K.ndim(a_bias) == 3:
a_bias = align(a_bias, [0, -2, -1], K.ndim(a))
A = attention_normalize(a, v_mask, -1, self.normalization, a_bias)
if self.attention_dropout:
A = Dropout(self.attention_dropout)(A)
# Finish the output
@@ -679,10 +676,7 @@ def call(self, inputs, mask=None, a_bias=None, p_bias=None):
a = tf.einsum('bmd,bnd->bmn', q, k)
if self.attention_scale:
a = a / self.key_size**0.5
if a_bias is not None:
a = a + a_bias
a = sequence_masking(a, mask, '-inf', -1)
A = attention_normalize(a, -1, self.normalization)
A = attention_normalize(a, mask, -1, self.normalization, a_bias)
if self.attention_dropout:
A = Dropout(self.attention_dropout)(A)
# Compute the output
@@ -741,7 +735,9 @@ def call(self, inputs):
inputs = inputs - mean
if self.unit_variance:
variance = K.mean(K.square(inputs), axis=-1, keepdims=True)
inputs = inputs / K.sqrt(variance + self.epsilon)
inputs = tf.math.divide_no_nan(
inputs, K.sqrt(variance + self.epsilon)
)

if self.conditional:
inputs = [inputs, conds]
@@ -1007,7 +1003,7 @@ def compute_position_ids(self, inputs):
'int32',
)
val_if_large = K.minimum(val_if_large, num_buckets - 1)
ret += K.switch(is_small, n, val_if_large)
ret += K.where(is_small, n, val_if_large)
return ret

def get_config(self):
@@ -1114,7 +1110,7 @@ def compute_mask(self, inputs, mask=None):
return None

def call(self, inputs, mask=None):
return sequence_masking(inputs, mask, '-inf', 1)
return sequence_masking(inputs, mask, -np.inf, 1)

def target_score(self, y_true, y_pred):
"""Compute the relative probability of the target path (not yet normalized)
@@ -1142,13 +1138,14 @@ def dense_loss(self, y_true, y_pred):
"""y_true needs to be in one-hot form
"""
# Derive the mask and convert its dtype
mask = K.all(K.greater(y_pred, -1e6), axis=2, keepdims=True)
mask = K.cast(mask, K.floatx())
mask = K.all(K.not_equal(y_pred, -np.inf), axis=2, keepdims=True)
# Compute the target score
y_true, y_pred = y_true * mask, y_pred * mask
y_true = K.where(mask, y_true, 0)
y_pred = K.where(mask, y_pred, 0)
target_score = self.target_score(y_true, y_pred)
# Recursively compute log Z
init_states = [y_pred[:, 0]]
mask = K.cast(mask, K.floatx())
y_pred = K.concatenate([y_pred, mask], axis=2)
input_length = K.int_shape(y_pred[:, 1:])[1]
log_norm, _, _ = K.rnn(
@@ -1183,7 +1180,7 @@ def sparse_accuracy(self, y_true, y_pred):
Here y_true needs to be in integer form (not one-hot)
"""
# Derive the mask and convert its dtype
mask = K.all(K.greater(y_pred, -1e6), axis=2)
mask = K.all(K.not_equal(y_pred, -np.inf), axis=2)
mask = K.cast(mask, K.floatx())
# Re-assert the shape and dtype of y_true
y_true = K.reshape(y_true, K.shape(y_pred)[:-1])
@@ -1273,7 +1270,7 @@ def compute_mask(self, inputs, mask=None):
return None

def call(self, inputs, mask=None):
return sequence_masking(inputs, mask, '-inf', 1)
return sequence_masking(inputs, mask, -np.inf, 1)

def reverse_sequence(self, inputs, mask=None):
if mask is None:
@@ -1286,7 +1283,7 @@ def basic_loss(self, y_true, y_pred, go_backwards=False):
"""y_true needs to be in integer form (not one-hot)
"""
# Derive the mask and convert its dtype
mask = K.all(K.greater(y_pred, -1e6), axis=2)
mask = K.all(K.not_equal(y_pred, -np.inf), axis=2)
mask = K.cast(mask, K.floatx())
# Re-assert the shape and dtype of y_true
y_true = K.reshape(y_true, K.shape(y_pred)[:-1])
@@ -1309,7 +1306,7 @@ def basic_loss(self, y_true, y_pred, go_backwards=False):
history = tf.einsum('bnd,kd->bnk', history, r_trans)
# Compute the loss
history = K.concatenate([y_pred[:, :1], history[:, :-1]], 1)
y_pred = (y_pred + history) / 2
y_pred = K.where(mask[..., None], (y_pred + history) / 2, 0)
loss = K.sparse_categorical_crossentropy(
y_true, y_pred, from_logits=True
)
@@ -1333,7 +1330,7 @@ def basic_accuracy(self, y_true, y_pred, go_backwards=False):
Here y_true needs to be in integer form (not one-hot)
"""
# Derive the mask and convert its dtype
mask = K.all(K.greater(y_pred, -1e6), axis=2)
mask = K.all(K.not_equal(y_pred, -np.inf), axis=2)
mask = K.cast(mask, K.floatx())
# y_true需要重新明确一下shape和dtype
y_true = K.reshape(y_true, K.shape(y_pred)[:-1])
@@ -1431,16 +1428,16 @@ def call(self, inputs, mask=None):
pos = SinusoidalPositionEmbedding(self.head_size, 'zero')(inputs)
qw, kw = apply_rotary_position_embeddings(pos, qw, kw)
# Compute the inner product
logits = tf.einsum('bmhd,bnhd->bhmn', qw, kw)
# Mask out padding
logits = sequence_masking(logits, mask, '-inf', 2)
logits = sequence_masking(logits, mask, '-inf', 3)
logits = tf.einsum('bmhd,bnhd->bhmn', qw, kw) / self.head_size**0.5
# Mask out the lower triangle
if self.tril_mask:
mask = tf.linalg.band_part(K.ones_like(logits), 0, -1)
logits = logits - (1 - mask) * K.infinity()
# Scale and return
return logits / self.head_size**0.5
tril_mask = tf.linalg.band_part(K.ones_like(logits[0, 0]), 0, -1)
tril_mask = K.cast(tril_mask, 'bool')
else:
tril_mask = None
# Return the final result
return sequence_masking(logits, mask, -np.inf, [2, 3], tril_mask)
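
# Equivalence sketch (illustrative, not part of the diff): the call above folds the
# padding mask on axes 2 and 3 together with the optional causal mask and applies
# them in a single K.where, roughly:
#
#     full_mask = align(mask, [0, 2], 4) & align(mask, [0, 3], 4)  # padding on both axes
#     if tril_mask is not None:
#         full_mask = full_mask & tril_mask                        # keep only the upper triangle
#     logits = K.where(full_mask, logits, -np.inf)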

def compute_output_shape(self, input_shape):
return (input_shape[0], self.heads, input_shape[1], input_shape[1])
@@ -1489,15 +1486,14 @@ def call(self, inputs, mask=None):
logits = tf.einsum('bmd,bnd->bmn', qw, kw) / self.head_size**0.5
bias = tf.einsum('bnh->bhn', self.q_dense(inputs)) / 2
logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]
# Mask out padding
logits = sequence_masking(logits, mask, '-inf', 2)
logits = sequence_masking(logits, mask, '-inf', 3)
# Mask out the lower triangle
if self.tril_mask:
mask = tf.linalg.band_part(K.ones_like(logits), 0, -1)
logits = logits - (1 - mask) * K.infinity()
tril_mask = tf.linalg.band_part(K.ones_like(logits[0, 0]), 0, -1)
tril_mask = K.cast(tril_mask, 'bool')
else:
tril_mask = None
# Return the final result
return logits
return sequence_masking(logits, mask, -np.inf, [2, 3], tril_mask)


class Loss(Layer):
