-
Notifications
You must be signed in to change notification settings - Fork 0
/
decoder.py
31 lines (23 loc) · 878 Bytes
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""
Decoder for VQ-VAE.
"""
import tensorflow as tf
from tensorflow.keras.models import Model
class Decoder(Model):
    """Convolutional decoder that maps latent feature maps back to image space.

    The network is three ``Conv2DTranspose`` layers:

    * The first two use ``strides=2`` with ``padding="same"``. Each therefore
      doubles the spatial height and width, giving 4x upsampling overall
      (e.g. 7x7 -> 28x28).
    * The last keeps the spatial size and projects to ``output_channels``
      channels with no activation (linear output).

    Args:
        embedding_dim: Dimensionality of the codebook embeddings. Stored as
            ``self.embedding_dim`` for reference only. The layers do not use
            it, and the channel count of the inputs is not checked against it.
        output_channels: Number of channels in the decoded output. Defaults
            to 1, the value previously hard-coded here.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, embedding_dim: int, output_channels: int = 1, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.output_channels = output_channels
        # Layers are created in the same order as before so that the
        # model's weight ordering (and thus checkpoint layout) is unchanged.
        self.deconv_1 = tf.keras.layers.Conv2DTranspose(
            64, 3, strides=2, padding="same", activation="relu"
        )
        self.deconv_2 = tf.keras.layers.Conv2DTranspose(
            32, 3, strides=2, padding="same", activation="relu"
        )
        # No activation on the final projection: the output is linear.
        self.deconv_3 = tf.keras.layers.Conv2DTranspose(
            output_channels, 3, padding="same"
        )

    def call(self, inputs):
        """Decode a batch of latent feature maps.

        Args:
            inputs: 4-D tensor. Assumed ``(batch, height, width, channels)``,
                i.e. channels-last, the Keras default (TODO confirm for
                non-default ``image_data_format`` settings).

        Returns:
            Tensor whose height and width are 4x those of ``inputs`` and whose
            last dimension is ``output_channels``.
        """
        x = self.deconv_1(inputs)
        x = self.deconv_2(x)
        return self.deconv_3(x)
def main():
    """Smoke-test the decoder on random latents and print the output shape."""
    embedding_dim = 64
    decoder = Decoder(embedding_dim)
    # Leave the batch dimension unspecified. The weights do not depend on it,
    # and a fixed value (previously 128) misrepresents the model in the
    # summary, since it is then called with a different batch size below.
    decoder.build(input_shape=(None, 7, 7, embedding_dim))
    decoder.summary()
    reconstruction = decoder(tf.random.normal((5, 7, 7, embedding_dim)))
    print(reconstruction.shape)


if __name__ == "__main__":
    main()