In [9]:
import numpy as np
import os as os
import tensorflow as tf
tfk = tf.keras

In [11]:
# Load the sampled point dataset and the precomputed basis for the quintic.
# SECURITY NOTE: np.load(..., allow_pickle=True) unpickles basis.pickle, which
# can execute arbitrary code — only load files from a trusted source.
dirname = 'QuinticData'
dataset_path = os.path.join(dirname, 'dataset.npz')
basis_path = os.path.join(dirname, 'basis.pickle')
data = np.load(dataset_path)
BASIS = np.load(basis_path, allow_pickle=True)

In [12]:
from cymetric.models.tfhelper import prepare_tf_basis, train_model

In [13]:
# Convert the loaded basis into the TensorFlow form used by cymetric's models.
# NOTE(review): this rebinds BASIS to its converted value, so re-running this cell
# feeds already-converted data back into prepare_tf_basis — run it once after loading.
BASIS = prepare_tf_basis(BASIS)

In [14]:
# Real part of the basis' KAPPA entry (presumably a normalization constant — confirm
# against cymetric's prepare_tf_basis), pulled out of TensorFlow into NumPy.
kappa = BASIS['KAPPA'].numpy().real
kappa


0.017635807

In [15]:
from cymetric.models.callbacks import RicciCallback, SigmaCallback, VolkCallback, KaehlerCallback, TransitionCallback
from cymetric.models.tfmodels import MultFSModel
from cymetric.models.metrics import SigmaLoss, KaehlerLoss, TransitionLoss, VolkLoss, RicciLoss, TotalLoss

In [16]:
# Every monitoring callback evaluates on the same validation split; the Ricci
# callback additionally receives the precomputed validation pullbacks.
val_data = (data['X_val'], data['y_val'])
rcb = RicciCallback(val_data, data['val_pullbacks'])
scb = SigmaCallback(val_data)
volkcb = VolkCallback(val_data)
kcb = KaehlerCallback(val_data)
tcb = TransitionCallback(val_data)
cb_list = [rcb, scb, kcb, tcb, volkcb]

In [17]:
# --- Network architecture ---
nlayer = 3       # number of hidden Dense layers
nHidden = 64     # units per hidden layer
act = 'gelu'     # activation of the hidden layers

# --- Training ---
nEpochs = 50
bSizes = [64, 50000]   # batch sizes handed to train_model
alpha = [1.] * 5       # presumably one weight per loss term (5 terms) — confirm against MultFSModel

# --- Geometry / network I/O sizes ---
nfold = 3              # complex dimension of the manifold
n_in = 2 * 5           # 10 real inputs, presumably real+imag parts of 5 ambient coordinates
n_out = nfold ** 2     # one output per entry of an nfold x nfold matrix

In [19]:
# Fully connected network: n_in real inputs -> nlayer hidden Dense layers of
# width nHidden with activation `act` -> n_out linear outputs (no bias).
nn = tfk.Sequential(
    # `shape` must be a tuple: `(n_in)` is just the int n_in, which newer Keras
    # versions reject; `(n_in,)` is the one-element tuple that was intended.
    [tfk.Input(shape=(n_in,))]
    + [tfk.layers.Dense(nHidden, activation=act) for _ in range(nlayer)]
    + [tfk.layers.Dense(n_out, use_bias=False)]
)

In [20]:
# Wrap the network `nn` and the TensorFlow basis in cymetric's MultFSModel.
# `alpha` is presumably the per-loss-term weighting — confirm against the MultFSModel docs.
fmodel = MultFSModel(nn, BASIS, alpha=alpha)

In [21]:
# Metrics reported during training: presumably the combined loss first, then
# each individual loss term.
cmetrics = [
    TotalLoss(),
    SigmaLoss(),
    KaehlerLoss(),
    TransitionLoss(),
    VolkLoss(),
    RicciLoss(),
]
opt = tfk.optimizers.Adam()  # Adam with Keras' default hyperparameters

In [22]:
# Train the model with the configured optimizer, metrics and callbacks.
# batch_sizes uses `bSizes` from the config cell rather than a duplicated
# literal, so changing the configuration actually affects training.
fmodel, training_history = train_model(
    fmodel,
    data,
    optimizer=opt,
    epochs=nEpochs,
    batch_sizes=bSizes,
    verbose=1,
    custom_metrics=cmetrics,
    callbacks=cb_list,
)


Epoch  1/50




 - Ricci measure val:      1.3031
 - Sigma measure val:      0.1150
 - Kaehler measure val:    0.0069
 - Transition measure val: 0.0107
 - Volk val:               4.9516

Epoch  2/50
 - Ricci measure val:      1.1804
 - Sigma measure val:      0.0992
 - Kaehler measure val:    0.0064
 - Transition measure val: 0.0062
 - Volk val:               5.1252

Epoch  3/50
 - Ricci measure val:      1.1867
 - Sigma measure val:      0.0953
 - Kaehler measure val:    0.0052
 - Transition measure val: 0.0036
 - Volk val:               5.1173

Epoch  4/50
 - Ricci measure val:      1.1401
 - Sigma measure val:      0.0999
 - Kaehler measure val:    0.0046
 - Transition measure val: 0.0029
 - Volk val:               5.2988

Epoch  5/50
 - Ricci measure val:      1.0901
 - Sigma measure val:      0.0923
 - Kaehler measure val:    0.0047
 - Transition measure val: 0.0027
 - Volk val:               5.0347

Epoch  6/50
 - Ricci measure val:      1.0815
 - Sigma measure val:      0.0929
 - Kaehler measur

 - Ricci measure val:      0.9364
 - Sigma measure val:      0.0833
 - Kaehler measure val:    0.0048
 - Transition measure val: 0.0024
 - Volk val:               4.9764

Epoch 16/50
 - Ricci measure val:      0.9117
 - Sigma measure val:      0.0837
 - Kaehler measure val:    0.0046
 - Transition measure val: 0.0024
 - Volk val:               5.0757

Epoch 17/50
 - Ricci measure val:      0.9173
 - Sigma measure val:      0.0864
 - Kaehler measure val:    0.0048
 - Transition measure val: 0.0024
 - Volk val:               5.2349

Epoch 18/50
 - Ricci measure val:      0.9011
 - Sigma measure val:      0.0808
 - Kaehler measure val:    0.0047
 - Transition measure val: 0.0024
 - Volk val:               5.0138

Epoch 19/50
 - Ricci measure val:      0.8863
 - Sigma measure val:      0.0864
 - Kaehler measure val:    0.0046
 - Transition measure val: 0.0024
 - Volk val:               5.1697

Epoch 20/50
 - Ricci measure val:      0.8747
 - Sigma measure val:      0.0849
 - Kaehler measur

 - Kaehler measure val:    0.0054
 - Transition measure val: 0.0025
 - Volk val:               5.2155

Epoch 30/50
 - Ricci measure val:      0.6766
 - Sigma measure val:      0.0638
 - Kaehler measure val:    0.0058
 - Transition measure val: 0.0025
 - Volk val:               5.0414

Epoch 31/50
 - Ricci measure val:      0.6755
 - Sigma measure val:      0.0656
 - Kaehler measure val:    0.0057
 - Transition measure val: 0.0024
 - Volk val:               5.1008

Epoch 32/50
 - Ricci measure val:      0.6369
 - Sigma measure val:      0.0599
 - Kaehler measure val:    0.0055
 - Transition measure val: 0.0024
 - Volk val:               5.1654

Epoch 33/50
 - Ricci measure val:      0.5182
 - Sigma measure val:      0.0568
 - Kaehler measure val:    0.0051
 - Transition measure val: 0.0022
 - Volk val:               5.1481

Epoch 34/50
 - Ricci measure val:      0.4516
 - Sigma measure val:      0.0463
 - Kaehler measure val:    0.0051
 - Transition measure val: 0.0020
 - Volk val:     

 - Ricci measure val:      0.1829
 - Sigma measure val:      0.0212
 - Kaehler measure val:    0.0035
 - Transition measure val: 0.0013
 - Volk val:               5.0965

Epoch 45/50
 - Ricci measure val:      0.1708
 - Sigma measure val:      0.0231
 - Kaehler measure val:    0.0034
 - Transition measure val: 0.0013
 - Volk val:               5.0842

Epoch 46/50
 - Ricci measure val:      0.1550
 - Sigma measure val:      0.0211
 - Kaehler measure val:    0.0034
 - Transition measure val: 0.0013
 - Volk val:               5.0578

Epoch 47/50
 - Ricci measure val:      0.1749
 - Sigma measure val:      0.0208
 - Kaehler measure val:    0.0032
 - Transition measure val: 0.0013
 - Volk val:               5.1465

Epoch 48/50
 - Ricci measure val:      0.1465
 - Sigma measure val:      0.0201
 - Kaehler measure val:    0.0033
 - Transition measure val: 0.0013
 - Volk val:               5.0578

Epoch 49/50
 - Ricci measure val:      0.1502
 - Sigma measure val:      0.0186
 - Kaehler measur

In [42]:
data['X_val'][:2]  # first two validation points

array([[-1.54617016e-01,  1.00000000e+00,  8.03067615e-01,
        -2.29330237e-02,  3.64707644e-01,  1.46141291e-01,
         5.55111512e-17, -5.87226954e-01,  2.89385198e-01,
         3.36354160e-01],
       [ 1.00000000e+00, -6.46316341e-01, -9.98160936e-01,
        -1.10571894e-01, -2.50638013e-01,  0.00000000e+00,
        -1.89901480e-01,  2.63191987e-02,  2.58948551e-02,
        -2.83181408e-01]])

In [44]:
# The same two validation points as a float32 tensor, ready to feed to the model.
first_two = data['X_val'][:2]
arg = tf.convert_to_tensor(first_two, dtype=tf.float32)
arg

<tf.Tensor: shape=(2, 10), dtype=float32, numpy=
array([[-1.5461701e-01,  1.0000000e+00,  8.0306762e-01, -2.2933023e-02,
         3.6470765e-01,  1.4614129e-01,  5.5511151e-17, -5.8722693e-01,
         2.8938520e-01,  3.3635417e-01],
       [ 1.0000000e+00, -6.4631635e-01, -9.9816096e-01, -1.1057189e-01,
        -2.5063801e-01,  0.0000000e+00, -1.8990149e-01,  2.6319198e-02,
         2.5894854e-02, -2.8318140e-01]], dtype=float32)>

In [47]:
# Evaluate the Ricci measure of the trained model on the two validation points in `arg`.
# This must be a call, not a pasted `def`-style signature with a trailing colon.
from cymetric.models.measures import ricci_measure

# Matching labels for those points, cast to the same dtype as `arg`.
y_true = tf.convert_to_tensor(data['y_val'][:2], dtype=tf.float32)
# pullbacks=None leaves it to ricci_measure to obtain pullbacks itself
# (presumably from the model) — confirm against the cymetric docs.
ricci_measure(fmodel, arg, y_true, pullbacks=None, verbose=0)