In [5]:
import numpy as np
import pandas as pd
from sklearn import preprocessing
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split, cross_val_score, KFold
from PIL import Image
from tensorflow import keras
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics

print(tf.__version__)

2.3.0


In [6]:
def save_img(imgs,names):
  img_new = Image.new('L',(280,280))
  index = 0
  for i in range(0,280,80):
    for j in range(0,280,80):
      img = imgs[index]
      img = Image.fromarray(img,mode='L')
      img_new.paste(img,(i,j))
      index+=1
  img_new.save(names)

In [7]:
def feature_scale(x):
  x = tf.cast(x,dtype=tf.float32)/255.
#  y = tf.cast(y,dtype=tf.int32)
  return x

In [8]:
#Dim reduct nums
dim_reduce = 10
batch_num = 128
lr = 1e-3

In [9]:
(x,y),(x_test,y_test) = datasets.fashion_mnist.load_data()
data = tf.data.Dataset.from_tensor_slices(x)
data = data.map(feature_scale).shuffle(10000).batch(128)

data_test = tf.data.Dataset.from_tensor_slices(x_test)
data_test = data_test.map(feature_scale).batch(128)

data_iter = iter(data)
samples = next(data_iter)
print(samples[0].shape,samples[1].shape)

(28, 28) (28, 28)


In [10]:
class VAE(keras.Model):
  def __init__(self):
    super(VAE,self).__init__()
    #encoder
    self.fc_layer_1 = layers.Dense(128)
    self.fc_layer_2 = layers.Dense(dim_reduce)
    self.fc_layer_3 = layers.Dense(dim_reduce)
    
    
    self.fc_layer_4 = layers.Dense(128)
    self.fc_layer_5 = layers.Dense(784)
    

  def model_encoder(self, x):
    h = tf.nn.relu(self.fc_layer_1(x))
    mean_fc = self.fc_layer_2(h)
    var_fc = self.fc_layer_3(h)
    return mean_fc,var_fc

  def model_decoder(self, z):
    out = tf.nn.relu(self.fc_layer_4(z))
    out = self.fc_layer_5(out)
    return out

  def reparameter(self,mean_x,var_x):
    eps = tf.random.normal(var_x.shape)
    std = tf.exp(var_x)**0.5
    z = mean_x + std*eps
    return z
  
  def call(self, inputs, training=None):
    mean_x,var_x = self.model_encoder(inputs)
    z = self.reparameter(mean_x,var_x)
    x = self.model_decoder(z)
    return x,mean_x,var_x

In [11]:
model = VAE()
model.build(input_shape=(4,784))
optimizer = optimizers.Adam(lr=lr)
model.summary()

Model: "vae"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                multiple                  100480    
_________________________________________________________________
dense_1 (Dense)              multiple                  1290      
_________________________________________________________________
dense_2 (Dense)              multiple                  1290      
_________________________________________________________________
dense_3 (Dense)              multiple                  1408      
_________________________________________________________________
dense_4 (Dense)              multiple                  101136    
Total params: 205,604
Trainable params: 205,604
Non-trainable params: 0
_________________________________________________________________


In [12]:
!rm -rf img_result
!mkdir img_result

In [13]:
optimizer = optimizers.Adam(lr=lr)

for i in range(10):
  for step,x in enumerate(data):
    x = tf.reshape(x,[-1,784])
    with tf.GradientTape() as tape:
      logits,mean_x,var_x = model(x)
      loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=x,logits=logits)
      loss = tf.reduce_sum(loss)/x.shape[0]
      kl_div = -0.5*(var_x+1-mean_x**2-tf.exp(var_x))
      kl_div = tf.reduce_sum(kl_div)/x.shape[0]
      
      loss = loss + 1.*kl_div
    grads = tape.gradient(loss,model.trainable_variables)
    optimizer.apply_gradients(zip(grads,model.trainable_variables))
    
    if step %100==0:
      print(i,step,'loss:',float(loss),'kl_div:',float(kl_div))
      
  x = next(iter(data_test))
  val_x = tf.reshape(x,[-1,784])
  logits,_,_ = model(val_x)
  x_hat = tf.sigmoid(logits)
  x_hat = tf.reshape(x_hat,[-1,28,28])
  x_hat = x_hat.numpy()*255
  x_hat = x_hat.astype(np.uint8)
  save_img(x_hat,'img_result/AE_img_%d.png'%i)

0 0 loss: 548.5337524414062 kl_div: 2.107118844985962
0 100 loss: 307.5446472167969 kl_div: 16.437299728393555
0 200 loss: 282.6195983886719 kl_div: 15.404300689697266
0 300 loss: 274.8070983886719 kl_div: 15.179455757141113
0 400 loss: 285.6806640625 kl_div: 14.350698471069336
1 0 loss: 278.7014465332031 kl_div: 14.927692413330078
1 100 loss: 262.3731994628906 kl_div: 14.367724418640137
1 200 loss: 258.19793701171875 kl_div: 13.738924026489258
1 300 loss: 257.8957214355469 kl_div: 14.30457878112793
1 400 loss: 272.4373474121094 kl_div: 14.147831916809082
2 0 loss: 249.55059814453125 kl_div: 15.126795768737793
2 100 loss: 254.0316925048828 kl_div: 14.885117530822754
2 200 loss: 264.9345703125 kl_div: 14.743483543395996
2 300 loss: 255.69561767578125 kl_div: 14.14968490600586
2 400 loss: 255.53799438476562 kl_div: 14.801573753356934
3 0 loss: 246.38577270507812 kl_div: 14.415285110473633
3 100 loss: 234.06495666503906 kl_div: 14.18582820892334
3 200 loss: 253.3296356201172 kl_div: 14.22