Commit

Bug fixed
edwardyehuang committed Jul 23, 2022
1 parent cbbff92 commit 0b4a72e
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions layers/self_attention.py
@@ -54,7 +54,7 @@ def __init__(
 
     def call(self, inputs, training=None):
 
-        inputs_shape = tf.shape(inputs)
+        batch_size, height, width, channels = get_tensor_shape(inputs)
 
         query = self.query_conv(inputs, training=training)
         key = self.key_conv(inputs, training=training)
@@ -66,17 +66,14 @@ def call(self, inputs, training=None):
         attention = get_attention(query, key, apply_scale=self.apply_scale)
 
         if self.vis_manager.recording:
-            height = inputs_shape[-3]
-            width = inputs_shape[-2]
-
             self.vis_manager.easy_add(tf.reshape(attention, (-1, height, width, height, width)), name="attention_map")
 
         value = flatten_hw(value)
 
         attention = self.attention_dropout(attention, training=training)
 
         value = tf.matmul(attention, value)
-        value = tf.reshape(value, inputs_shape)
+        value = tf.reshape(value, [batch_size, height, width, value.shape[-1]])
 
         value = self.feature_dropout(value, training=training)
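The bug this commit addresses is visible in the final reshape: the old code restored the attention output with tf.reshape(value, inputs_shape), which assumes the value tensor has the same channel count as the input; the new code uses value.shape[-1], so the reshape also holds when the value projection changes the number of channels. The diff also swaps the fully dynamic tf.shape(inputs) for a get_tensor_shape helper whose implementation is not shown here. A minimal sketch, assuming it follows the usual static-preferring pattern (static dimensions where known, dynamic tf.shape entries otherwise), could be:

import tensorflow as tf

def get_tensor_shape(x):
    # Assumed behavior, not necessarily the repository's actual helper:
    # prefer static dimensions known at graph-construction time, and
    # fall back to the dynamic tf.shape() entry for any dimension that
    # is None (e.g. a variable batch size).
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

Returning static ints where available keeps height and width usable as concrete Python values when the model is built with a fixed input size, while the dynamic fallback keeps the reshape valid under a variable batch size inside tf.function.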
