In [1]:
%load_ext autoreload
%autoreload 2

from model_awg import AdversarialWeightGenenerator, create_model
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K

In [2]:
# Build an untrained model + optimizer.
# Input layout: (batch, num_layers, seq_len, seq_len).
model, optimizer = create_model(
    "test_model",
    load_checkpoint=False,
)

Input shape is: (4, 100, 100)
Output features: 100


In [3]:
# Random smoke-test batch of 55 samples in the model's input layout.
x = tf.random.normal(shape=[55, 4, 100, 100])
x.shape

TensorShape([55, 4, 100, 100])

In [4]:
# Forward pass on the random batch; the printed shape shows whether the
# model's output keeps the input layout.
y = model(x)
print(y.shape)

(55, 4, 100, 100)


In [5]:
# Shape of the first trainable tensor (trainable_variables aliases trainable_weights).
model.trainable_variables[0].shape

TensorShape([100, 16])

In [6]:
# Parameter counts per weight group. Builtin sum() starts from the int 0, and the
# explicit int() keeps both counts integral even when a group is empty
# (np.sum([]) would return the float 0.0).
trainable_count = int(sum(K.count_params(w) for w in model.trainable_weights))
non_trainable_count = int(sum(K.count_params(w) for w in model.non_trainable_weights))

print(f'Total params: {trainable_count + non_trainable_count}')
print(f'Trainable params: {trainable_count}')
print(f'Non-trainable params: {non_trainable_count}')

Total params: 3316
Trainable params: 3316
Non-trainable params: 0


In [40]:
def create_dataset(N=100, shape=(4, 100, 100), seed=None):
    """Generate a synthetic (input, target) pair of row-normalised tensors.

    Inputs are uniform noise passed through a softmax over the last axis, so
    every row along that axis sums to 1. Targets apply the fixed cubic
    ``-2x^3 + 5x^2 - 2x + 1`` element-wise to the inputs and re-normalise with
    another last-axis softmax, so target rows also sum to 1.

    Args:
        N: number of samples (leading batch dimension).
        shape: per-sample shape. The default (4, 100, 100) matches the input
            shape reported by ``create_model`` in this notebook.
        seed: optional op-level seed forwarded to ``tf.random.uniform`` for
            reproducible data. ``None`` keeps the unseeded behaviour.

    Returns:
        Tuple ``(x, y)`` of float32 tensors, each of shape ``(N, *shape)``.
    """
    x = tf.nn.softmax(tf.random.uniform((N, *shape), seed=seed))
    y = tf.nn.softmax(-2.0 * x**3 + 5.0 * x**2 - 2.0 * x + 1.0)
    return (x, y)

In [41]:
# Training data with the default sample count spelled out.
data = create_dataset(N=100)

In [42]:
# Split the (inputs, targets) pair.
(x, y) = data

In [43]:
# Sanity-check the generated data. Both tensors are last-axis softmax outputs,
# so every row must sum to 1. Assert it for all rows, not just one.
print(x.shape)
print(y.shape)
np.testing.assert_allclose(tf.reduce_sum(x, axis=-1).numpy(), 1.0, rtol=1e-5)
np.testing.assert_allclose(tf.reduce_sum(y, axis=-1).numpy(), 1.0, rtol=1e-5)
tf.reduce_sum(y[0, 0, 0])

(100, 4, 100, 100)
(100, 4, 100, 100)


<tf.Tensor: id=7961844, shape=(), dtype=float32, numpy=1.0>

In [44]:
def evaluate(model, xt, yt, loss_fn=None):
    """Compute the loss of ``model`` on ``(xt, yt)`` in inference mode.

    Args:
        model: callable Keras model; it is called with ``training=False``.
        xt: input batch.
        yt: target batch compared against ``model(xt)``.
        loss_fn: callable ``loss_fn(y_true, y_pred)``. Defaults to
            ``tf.keras.losses.KLDivergence()``, the loss the training loop in
            this notebook uses.

    Returns:
        Whatever ``loss_fn`` returns for ``(yt, model(xt))``.
    """
    # Take the loss explicitly instead of reading a global that is only
    # defined in a later cell, so this function works on a fresh kernel.
    if loss_fn is None:
        loss_fn = tf.keras.losses.KLDivergence()

    ypred = model(xt, training=False)
    return loss_fn(yt, ypred)

In [45]:
# Train a fresh model with full-batch gradient steps and log the loss on the
# fixed training set ("In") and on a newly generated batch ("Out").
SEED = 42         # reproducible synthetic data (and presumably weight init — confirm in model_awg)
NUM_EPOCHS = 20   # gradient steps after the initial, untrained evaluation
N_SAMPLES = 100   # samples per generated dataset

tf.random.set_seed(SEED)

model, optimizer = create_model("test_model", load_checkpoint=False)
# Kept as a module-level name because other cells may refer to it.
loss_obj = tf.keras.losses.KLDivergence()

(xt, yt) = create_dataset(N=N_SAMPLES)
(x, y) = create_dataset(N=N_SAMPLES)

epoch = 1
print(f"{epoch:2}. Out: {evaluate(model, xt, yt):.7e}  In: {evaluate(model, x, y):.7e}  ")

for epoch in range(2, NUM_EPOCHS + 2):
    # One full-batch gradient step on the fixed training set.
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = loss_obj(y, pred)

    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

    # Fresh held-out batch each epoch, so "Out" is never trained on.
    (xt, yt) = create_dataset(N=N_SAMPLES)
    print(f"{epoch:2}. Out: {evaluate(model, xt, yt):.7e}  In: {evaluate(model, x, y):.7e}  ")
    

Input shape is: (4, 100, 100)
Output features: 100
 1. Out: 7.3020419e-05  In: 7.2906783e-05  
 2. Out: 5.3025986e-05  In: 5.2954994e-05  
 3. Out: 3.8366714e-05  In: 3.8296585e-05  
 4. Out: 2.9036850e-05  In: 2.8941542e-05  
 5. Out: 2.4016526e-05  In: 2.3980332e-05  
 6. Out: 2.2108687e-05  In: 2.2092632e-05  
 7. Out: 2.1981714e-05  In: 2.1979897e-05  
 8. Out: 2.2623657e-05  In: 2.2630724e-05  
 9. Out: 2.3392247e-05  In: 2.3387109e-05  
10. Out: 2.3902703e-05  In: 2.3908628e-05  
11. Out: 2.4087725e-05  In: 2.4080786e-05  
12. Out: 2.3933972e-05  In: 2.3924104e-05  
13. Out: 2.3508001e-05  In: 2.3523116e-05  
14. Out: 2.2976332e-05  In: 2.2980428e-05  
15. Out: 2.2380866e-05  In: 2.2391505e-05  
16. Out: 2.1825324e-05  In: 2.1823582e-05  
17. Out: 2.1298260e-05  In: 2.1312508e-05  
18. Out: 2.0881540e-05  In: 2.0870457e-05  
19. Out: 2.0485140e-05  In: 2.0494521e-05  
20. Out: 2.0172822e-05  In: 2.0176629e-05  
21. Out: 1.9923020e-05  In: 1.9908130e-05  


In [46]:
# Compare one target distribution with the trained model's prediction for it.
ypred = model(xt, training=False)
print(yt[0, 0, 0])
print(ypred[0, 0, 0])

tf.Tensor(
[0.01008101 0.01006052 0.01008044 0.01006802 0.0099774  0.01005746
 0.00992219 0.01003918 0.01004559 0.00996095 0.01005557 0.01000738
 0.00990914 0.00997245 0.00990048 0.00998525 0.01005339 0.01004901
 0.00991599 0.00991833 0.01003806 0.00997801 0.01006414 0.00992602
 0.01006282 0.00995271 0.01001343 0.00991134 0.01001556 0.0099549
 0.01003063 0.00996015 0.01005401 0.00990058 0.01003408 0.01007544
 0.009912   0.00995179 0.01001517 0.00992082 0.01000205 0.01000024
 0.01000135 0.00999754 0.00997605 0.01002908 0.01004526 0.01000728
 0.00997173 0.0100193  0.00997164 0.00991191 0.01002974 0.01003049
 0.00991107 0.00994424 0.01000276 0.00994613 0.0100759  0.01008024
 0.01005971 0.0099936  0.01007728 0.01005337 0.00994357 0.01003835
 0.00997705 0.0100767  0.01005906 0.00990703 0.01002434 0.01006802
 0.00994824 0.00993563 0.01001273 0.00998096 0.00992093 0.009959
 0.00992031 0.01002712 0.00991291 0.0100529  0.01003497 0.00997655
 0.00996934 0.0100457  0.01005608 0.0100636  0.0100120