In [3]:
%matplotlib qt

import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
from poisson_fem import PoissonFEM
import ROM
import GenerativeSurrogate as gs
import Data as dta
import numpy as np
import scipy.sparse as sps
import scipy.sparse.linalg as lg
import time
import petsc4py
import sys
petsc4py.init(sys.argv)
from petsc4py import PETSc
import torch
from torch import optim

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Notebook-wide configuration
lin_dim_rom = 4                             # number of ROM elements per linear dimension
a = np.array([1, 1, 0])                     # coefficients of the boundary condition function (read by flux)
dtype = torch.float32                       # tensor data type (torch.float is an alias of torch.float32)
supervised_samples = set(range(16))         # sample indices 0..15
unsupervised_samples = set(range(16, 272))  # sample indices 16..271
dim_z = 20                                  # handed to GenerativeSurrogate as dim_z
dim_z_supervised = 8                        # handed to GenerativeSurrogate as dim_z_supervised

In [3]:
# Define mesh and boundary conditions
# The argument is a vector of lin_dim_rom equal spacings, each of size 1/lin_dim_rom
# (they sum to 1). NOTE(review): presumably the same spacing vector is used along both
# axes so the mesh covers the unit square -- confirm in PoissonFEM.RectangularMesh.
mesh = PoissonFEM.RectangularMesh(np.ones(lin_dim_rom)/lin_dim_rom)
# mesh.plot()

def origin(x):
    """Return a truthy value iff both x[0] and x[1] are within machine epsilon of zero.

    x[1] is only inspected when x[0] already passes the check.
    """
    tol = np.finfo(float).eps
    first_is_zero = np.abs(x[0]) < tol
    if not first_is_zero:
        return first_is_zero
    return np.abs(x[1]) < tol

def ess_boundary_fun(x):
    """Essential boundary value at point ``x``: always 0.0 (``x`` itself is ignored)."""
    boundary_value = 0.0
    return boundary_value
# Hand the mesh the location predicate `origin` (true only at (0, 0) up to machine epsilon)
# together with `ess_boundary_fun`, which always returns 0.0.
# NOTE(review): presumably this pins the solution to 0.0 at the corner vertex as a
# Dirichlet condition -- confirm in PoissonFEM.
mesh.set_essential_boundary(origin, ess_boundary_fun)

def domain_boundary(x):
    """Return a truthy value iff ``x`` lies on the boundary of the unit square.

    A point counts as boundary when |x[0]| or |x[1]| is below machine epsilon,
    or when |x[0]| or |x[1]| exceeds 1 - machine epsilon. The checks are
    evaluated lazily in that order.
    """
    tol = np.finfo(float).eps
    # lower edges: a coordinate is (numerically) zero
    on_lower_edge = np.abs(x[0]) < tol or np.abs(x[1]) < tol
    if on_lower_edge:
        return on_lower_edge
    # upper edges: a coordinate magnitude reaches one
    return np.abs(x[0]) > 1.0 - tol or np.abs(x[1]) > 1.0 - tol
# Hand the mesh the predicate that is true on the edges of the unit square.
# NOTE(review): presumably marks those edges as natural (flux) boundaries -- confirm in PoissonFEM.
mesh.set_natural_boundary(domain_boundary)

def flux(x):
    """Evaluate the two-component flux at point ``x``.

    Uses the module-level coefficient sequence ``a`` and returns
    ``np.array([a[0] + a[2]*x[1], a[1] + a[2]*x[0]])``.
    """
    first_component = a[0] + a[2]*x[1]
    second_component = a[1] + a[2]*x[0]
    return np.array([first_component, second_component])

In [4]:
# Specify right hand side and stiffness matrix
# Order matters here: the right-hand side object is created and given the natural-boundary
# flux first, then the stiffness matrix is built, and only then is the rhs stencil set up,
# because set_rhs_stencil takes both the mesh and K.
rhs = PoissonFEM.RightHandSide(mesh)
rhs.set_natural_rhs(mesh, flux)
K = PoissonFEM.StiffnessMatrix(mesh)
# NOTE(review): presumably precomputes how essential boundary values enter the rhs via K -- confirm in PoissonFEM.
rhs.set_rhs_stencil(mesh, K)

In [5]:
# Build the data container from the supervised / unsupervised sample index sets and read the data.
trainingData = dta.StokesData(supervised_samples, unsupervised_samples)
trainingData.read_data()
# trainingData.plotMicrostruct(1)
# NOTE(review): presumably reshapes the microstructure images into the layout the model expects -- confirm in Data.py.
trainingData.reshape_microstructure_image()

In [6]:
# define rom
# The ROM is assembled from the mesh, stiffness matrix and right-hand side defined above; the last
# argument is the square of the data's output_resolution.
# NOTE(review): presumably the number of fine-scale output points -- confirm in ROM.py.
rom = ROM.ROM(mesh, K, rhs, trainingData.output_resolution**2)

In [7]:
# Build the generative surrogate from the ROM and the training data, using the dim_z and
# dim_z_supervised values set in the parameter cell.
model = gs.GenerativeSurrogate(rom, trainingData, dim_z=dim_z, dim_z_supervised=dim_z_supervised)

In [8]:
# model.save()
# loaded_model = gs.GenerativeSurrogate()
# loaded_model.load()

In [9]:
# Run converge() once per supervised sample on that sample's log_lambdac_mean entry before
# the joint model.fit() calls below.
for n in range(model.data.n_supervised_samples):
    print('sample == ', n)
    # NOTE(review): 3e4 is a float; fine if max_iter is only compared against a counter,
    # but it would fail if it is ever passed to range() -- confirm in GenerativeSurrogate.
    model.log_lambdac_mean[n].max_iter = 3e4
    # The sample index is passed as `mode`; the second argument is the supervised sample count.
    model.log_lambdac_mean[n].converge(model, model.data.n_supervised_samples, mode=n)

sample ==  0
loss_lambda_c =  45945471369216.0
loss_lambda_c =  357083152384.0
loss_lambda_c =  341057994752.0
loss_lambda_c =  335339487232.0
loss_lambda_c =  332537856000.0
loss_lambda_c =  331105796096.0
Epoch  2503: reducing learning rate of group 0 to 3.0000e-03.
Epoch  2520: reducing learning rate of group 0 to 3.0000e-04.
Epoch  2536: reducing learning rate of group 0 to 3.0000e-05.
Epoch  2552: reducing learning rate of group 0 to 3.0000e-06.
sample ==  1
loss_lambda_c =  20042003513344.0
loss_lambda_c =  196687839232.0
loss_lambda_c =  190270406656.0
loss_lambda_c =  188073230336.0
Epoch  1718: reducing learning rate of group 0 to 3.0000e-03.
Epoch  1735: reducing learning rate of group 0 to 3.0000e-04.
Epoch  1751: reducing learning rate of group 0 to 3.0000e-05.
Epoch  1767: reducing learning rate of group 0 to 3.0000e-06.
sample ==  2
loss_lambda_c =  655100825567232.0
Epoch    83: reducing learning rate of group 0 to 3.0000e-03.
loss_lambda_c =  2471118766080.0
loss_lambda

In [10]:
# Short first run: 5 steps with with_precisions=False, giving thetac a large iteration budget
# (10000) while z, thetaf and lambdac get 10 iterations each.
model.fit(n_steps=5, with_precisions=False, z_iterations=10, thetac_iterations=10000, thetaf_iterations=10, lambdac_iterations=10)

step =  0
loss z =  12748624.0
loss z =  12723624.0
loss z =  12700334.0
loss z =  12678778.0
loss z =  12658641.0
loss z =  12639934.0
loss z =  12622554.0
loss z =  12606394.0
loss z =  12591324.0
z step =  5.24725079536438 s
loss_f =  12559499.0
loss_f =  12141416.0
thetaf step =  6.934699296951294 s
sample ==  0
loss_lambda_c =  331093868544.0
Epoch  2568: reducing learning rate of group 0 to 1.0000e-03.
Epoch  2584: reducing learning rate of group 0 to 1.0000e-04.
Epoch  2600: reducing learning rate of group 0 to 1.0000e-05.
Epoch  2616: reducing learning rate of group 0 to 1.0000e-06.
sample ==  1
loss_lambda_c =  187719385088.0
Epoch  1783: reducing learning rate of group 0 to 1.0000e-03.
Epoch  1799: reducing learning rate of group 0 to 1.0000e-04.
Epoch  1815: reducing learning rate of group 0 to 1.0000e-05.
Epoch  1831: reducing learning rate of group 0 to 1.0000e-06.
sample ==  2
loss_lambda_c =  1249009074176.0
loss_lambda_c =  1241229295616.0
Epoch  8610: reducing learning

loss_c =  19858.251953125
loss_c =  19556.1015625
loss_c =  19250.921875
loss_c =  18942.8359375
loss_c =  18631.95703125
loss_c =  18318.4140625
loss_c =  18002.341796875
loss_c =  17683.876953125
loss_c =  17363.169921875
loss_c =  17040.376953125
loss_c =  16715.65625
loss_c =  16389.18359375
loss_c =  16061.1259765625
loss_c =  15731.669921875
loss_c =  15401.0029296875
loss_c =  15069.318359375
loss_c =  14736.818359375
loss_c =  14403.705078125
loss_c =  14070.1884765625
loss_c =  13736.482421875
loss_c =  13402.8046875
loss_c =  13069.3779296875
loss_c =  12736.421875
loss_c =  12404.1611328125
loss_c =  12072.8125
loss_c =  11742.59375
loss_c =  11413.697265625
loss_c =  11086.2861328125
loss_c =  10760.421875
loss_c =  10435.8564453125
loss_c =  10110.6826171875
loss_c =  9751.181640625
loss_c =  9406.8818359375
loss_c =  9070.755859375
loss_c =  8738.939453125
loss_c =  8411.693359375
loss_c =  8089.2568359375
loss_c =  7771.7646484375
loss_c =  7459.13134765625
loss_c =  715

loss_c =  219.93344116210938
loss_c =  219.9246368408203
loss_c =  219.91827392578125
loss_c =  219.91372680664062
loss_c =  219.9104461669922
loss_c =  219.90818786621094
loss_c =  219.9066162109375
loss_c =  219.905517578125
loss_c =  219.90475463867188
loss_c =  219.90426635742188
loss_c =  219.90390014648438
loss_c =  219.9036865234375
loss_c =  219.903564453125
loss_c =  219.90345764160156
loss_c =  219.90338134765625
loss_c =  219.90335083007812
loss_c =  219.90335083007812
loss_c =  219.90330505371094
loss_c =  219.90330505371094
loss_c =  219.9033203125
loss_c =  219.90328979492188
loss_c =  219.9033203125
Epoch 12501: reducing learning rate of group 0 to 1.4000e-04.
loss_c =  219.90330505371094
loss_c =  219.90330505371094
loss_c =  219.90328979492188
loss_c =  219.90330505371094
loss_c =  219.9033203125
loss_c =  219.90328979492188
loss_c =  219.90330505371094
loss_c =  219.90328979492188
loss_c =  219.90330505371094
loss_c =  219.90330505371094
loss_c =  219.90328979492188
l

loss_c =  219.94989013671875
loss_c =  219.9498748779297
loss_c =  219.94989013671875
loss_c =  219.9498748779297
loss_c =  219.94985961914062
loss_c =  219.94985961914062
loss_c =  219.94985961914062
loss_c =  219.9498748779297
loss_c =  219.9498748779297
loss_c =  219.94989013671875
loss_c =  219.9498748779297
loss_c =  219.9498748779297
loss_c =  219.94989013671875
loss_c =  219.94989013671875
loss_c =  219.94989013671875
loss_c =  219.9498748779297
loss_c =  219.94989013671875
loss_c =  219.94989013671875
loss_c =  219.9498748779297
loss_c =  219.9498748779297
Epoch 18507: reducing learning rate of group 0 to 1.4000e-04.
loss_c =  219.94989013671875
loss_c =  219.9498748779297
loss_c =  219.9498748779297
loss_c =  219.9498748779297
loss_c =  219.9498748779297
loss_c =  219.94989013671875
loss_c =  219.9498748779297
loss_c =  219.94989013671875
loss_c =  219.9498748779297
loss_c =  219.9498748779297
loss_c =  219.9498748779297
loss_c =  219.94989013671875
loss_c =  219.9498901367187

Epoch  2006: reducing learning rate of group 0 to 3.0000e-08.
sample ==  15
loss_lambda_c =  54343307264.0
Epoch  1561: reducing learning rate of group 0 to 3.0000e-05.
Epoch  1577: reducing learning rate of group 0 to 3.0000e-06.
Epoch  1593: reducing learning rate of group 0 to 3.0000e-07.
Epoch  1609: reducing learning rate of group 0 to 3.0000e-08.
lambdac step =  1.1580958366394043 s
loss_c =  219.96408081054688
loss_c =  219.96408081054688
loss_c =  219.96408081054688
loss_c =  219.96408081054688
loss_c =  219.96408081054688
loss_c =  219.96408081054688
loss_c =  219.9640655517578
loss_c =  219.964111328125
loss_c =  219.96408081054688
loss_c =  219.96408081054688
loss_c =  219.96408081054688
loss_c =  219.9640655517578
loss_c =  219.96408081054688
loss_c =  219.96408081054688
loss_c =  219.96408081054688
loss_c =  219.96408081054688
loss_c =  219.96408081054688
loss_c =  219.9640655517578
loss_c =  219.96405029296875
loss_c =  219.96408081054688
loss_c =  219.9640655517578
Epoch

loss_lambda_c =  257742323712.0
Epoch  4898: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4914: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4930: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4946: reducing learning rate of group 0 to 1.0000e-08.
sample ==  14
loss_lambda_c =  507602567168.0
Epoch  2022: reducing learning rate of group 0 to 1.0000e-05.
Epoch  2038: reducing learning rate of group 0 to 1.0000e-06.
Epoch  2054: reducing learning rate of group 0 to 1.0000e-07.
Epoch  2070: reducing learning rate of group 0 to 1.0000e-08.
sample ==  15
loss_lambda_c =  54343245824.0
Epoch  1625: reducing learning rate of group 0 to 1.0000e-05.
Epoch  1641: reducing learning rate of group 0 to 1.0000e-06.
Epoch  1657: reducing learning rate of group 0 to 1.0000e-07.
Epoch  1673: reducing learning rate of group 0 to 1.0000e-08.
lambdac step =  1.297306776046753 s
loss_c =  219.9687042236328
loss_c =  219.9686279296875
loss_c =  219.96810913085938
loss_c =  21

In [11]:
# Long run: up to 10000 steps with with_precisions=True. The logged step counter continues from
# the previous fit call (the first line printed below is `step =  5`).
model.fit(n_steps=10000, z_iterations=50, with_precisions=True, thetaf_iterations=200, thetac_iterations=500, lambdac_iterations=100)

step =  5
loss z =  12102745.0
loss z =  12101985.0
loss z =  12101238.0
loss z =  12100503.0
Epoch    54: reducing learning rate of group 0 to 2.0000e-05.
loss z =  12099777.0
loss z =  12099635.0
loss z =  12099492.0
loss z =  12099349.0
loss z =  12099207.0
loss z =  12099068.0
loss z =  12098927.0
loss z =  12098785.0
loss z =  12098645.0
loss z =  12098506.0
loss z =  12098366.0
Epoch    65: reducing learning rate of group 0 to 4.0000e-06.
loss z =  12098228.0
loss z =  12098199.0
loss z =  12098171.0
loss z =  12098142.0
loss z =  12098087.0
loss z =  12098060.0
loss z =  12098031.0
loss z =  12098003.0
loss z =  12097976.0
loss z =  12097948.0
Epoch    76: reducing learning rate of group 0 to 8.0000e-07.
loss z =  12097920.0
loss z =  12097916.0
loss z =  12097909.0
loss z =  12097904.0
loss z =  12097898.0
loss z =  12097893.0
loss z =  12097888.0
loss z =  12097881.0
loss z =  12097876.0
loss z =  12097870.0
loss z =  12097864.0
Epoch    87: reducing learning rate of group 0 t

loss_f =  9161749.0
loss_f =  9149955.0
loss_f =  9137338.0
loss_f =  9127921.0
loss_f =  9119121.0
loss_f =  9110574.0
loss_f =  9102600.0
loss_f =  9094747.0
loss_f =  9087185.0
loss_f =  9079846.0
loss_f =  9072704.0
loss_f =  9065800.0
loss_f =  9059102.0
loss_f =  9052601.0
loss_f =  9046259.0
loss_f =  9040058.0
loss_f =  9033962.0
loss_f =  9027976.0
loss_f =  9041393.0
loss_f =  9016392.0
loss_f =  9010658.0
loss_f =  9004854.0
loss_f =  8999164.0
loss_f =  8993521.0
loss_f =  8987916.0
loss_f =  8982344.0
loss_f =  8976816.0
loss_f =  8971324.0
loss_f =  8965867.0
loss_f =  8960444.0
loss_f =  8955057.0
loss_f =  8949705.0
loss_f =  8944394.0
loss_f =  8939115.0
thetaf step =  139.2079062461853 s
sample ==  0
loss_lambda_c =  11067.1513671875
Epoch  2953: reducing learning rate of group 0 to 1.0000e-05.
Epoch  2969: reducing learning rate of group 0 to 1.0000e-06.
Epoch  2985: reducing learning rate of group 0 to 1.0000e-07.
Epoch  3001: reducing learning rate of group 0 to 1.

Epoch  9108: reducing learning rate of group 0 to 1.0000e-08.
sample ==  3
loss_lambda_c =  43826.3984375
Epoch  4100: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4116: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4132: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4148: reducing learning rate of group 0 to 1.0000e-08.
sample ==  4
loss_lambda_c =  24291.0078125
Epoch  5162: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5178: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5194: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5210: reducing learning rate of group 0 to 1.0000e-08.
sample ==  5
loss_lambda_c =  37792.8046875
Epoch  2300: reducing learning rate of group 0 to 1.0000e-05.
Epoch  2316: reducing learning rate of group 0 to 1.0000e-06.
Epoch  2332: reducing learning rate of group 0 to 1.0000e-07.
Epoch  2348: reducing learning rate of group 0 to 1.0000e-08.
sample ==  6
loss_lambda_c =  23237.1640625
Epoch  3187: reduc

Epoch  3703: reducing learning rate of group 0 to 1.0000e-06.
Epoch  3719: reducing learning rate of group 0 to 1.0000e-07.
Epoch  3735: reducing learning rate of group 0 to 1.0000e-08.
sample ==  9
loss_lambda_c =  8165.09423828125
Epoch  2198: reducing learning rate of group 0 to 1.0000e-05.
Epoch  2214: reducing learning rate of group 0 to 1.0000e-06.
Epoch  2230: reducing learning rate of group 0 to 1.0000e-07.
Epoch  2246: reducing learning rate of group 0 to 1.0000e-08.
sample ==  10
loss_lambda_c =  19093.2734375
Epoch  3734: reducing learning rate of group 0 to 1.0000e-05.
Epoch  3750: reducing learning rate of group 0 to 1.0000e-06.
Epoch  3766: reducing learning rate of group 0 to 1.0000e-07.
Epoch  3782: reducing learning rate of group 0 to 1.0000e-08.
sample ==  11
loss_lambda_c =  10354.796875
Epoch  8719: reducing learning rate of group 0 to 1.0000e-05.
Epoch  8735: reducing learning rate of group 0 to 1.0000e-06.
Epoch  8751: reducing learning rate of group 0 to 1.0000e-

Epoch  5270: reducing learning rate of group 0 to 1.0000e-08.
sample ==  14
loss_lambda_c =  15625.32421875
Epoch  2343: reducing learning rate of group 0 to 1.0000e-05.
Epoch  2359: reducing learning rate of group 0 to 1.0000e-06.
Epoch  2375: reducing learning rate of group 0 to 1.0000e-07.
Epoch  2391: reducing learning rate of group 0 to 1.0000e-08.
sample ==  15
loss_lambda_c =  1820.1455078125
Epoch  1947: reducing learning rate of group 0 to 1.0000e-05.
Epoch  1963: reducing learning rate of group 0 to 1.0000e-06.
Epoch  1979: reducing learning rate of group 0 to 1.0000e-07.
Epoch  1995: reducing learning rate of group 0 to 1.0000e-08.
lambdac step =  1.1771996021270752 s
loss_c =  263.59075927734375
loss_c =  258.0021057128906
loss_c =  256.96026611328125
loss_c =  256.08941650390625
loss_c =  255.296875
loss_c =  254.5257568359375
loss_c =  253.7234344482422
loss_c =  252.83477783203125
loss_c =  251.8072509765625
loss_c =  250.62158203125
thetac step =  2.325016736984253 s
st

loss_c =  255.03109741210938
loss_c =  254.50762939453125
loss_c =  253.97039794921875
loss_c =  253.4197235107422
loss_c =  252.8570556640625
loss_c =  252.2843780517578
loss_c =  251.70428466796875
Epoch 42516: reducing learning rate of group 0 to 2.0000e-05.
thetac step =  2.802664041519165 s
step =  11
loss z =  13748923.0
loss z =  13563496.0
loss z =  13396714.0
loss z =  13243769.0
loss z =  13102802.0
loss z =  12972042.0
loss z =  12848922.0
loss z =  12732394.0
loss z =  12621846.0
loss z =  12517354.0
Epoch   322: reducing learning rate of group 0 to 2.0000e-05.
loss z =  12419062.0
loss z =  12400548.0
loss z =  12382292.0
loss z =  12364484.0
loss z =  12346878.0
loss z =  12329594.0
loss z =  12312718.0
loss z =  12296100.0
loss z =  12279680.0
loss z =  12247698.0
Epoch   333: reducing learning rate of group 0 to 4.0000e-06.
loss z =  12232208.0
loss z =  12229174.0
loss z =  12226106.0
loss z =  12223090.0
loss z =  12220076.0
loss z =  12217035.0
loss z =  12214001.0
l

loss z =  13339704.0
loss z =  13338000.0
loss z =  13336281.0
loss z =  13334559.0
loss z =  13332868.0
loss z =  13331131.0
loss z =  13329442.0
loss z =  13327734.0
loss z =  13326010.0
loss z =  13324303.0
loss z =  13322601.0
Epoch   388: reducing learning rate of group 0 to 8.0000e-07.
loss z =  13320927.0
loss z =  13320570.0
loss z =  13320236.0
loss z =  13319900.0
loss z =  13319545.0
loss z =  13319212.0
loss z =  13318877.0
loss z =  13318166.0
loss z =  13317830.0
loss z =  13317497.0
Epoch   399: reducing learning rate of group 0 to 1.6000e-07.
z step =  24.19505214691162 s
loss_f =  10564726.0
Epoch  1452: reducing learning rate of group 0 to 2.5000e-04.
loss_f =  10210422.0
loss_f =  10094984.0
loss_f =  10020502.0
loss_f =  9966384.0
loss_f =  9923632.0
loss_f =  9887497.0
Epoch  1483: reducing learning rate of group 0 to 1.2500e-04.
loss_f =  9858521.0
loss_f =  9843812.0
loss_f =  9829938.0
loss_f =  9816634.0
loss_f =  9803703.0
loss_f =  9791027.0
Epoch  1514: redu

loss_f =  9168032.0
loss_f =  9135566.0
Epoch  1669: reducing learning rate of group 0 to 2.5000e-04.
loss_f =  9118657.0
loss_f =  9105987.0
loss_f =  9094174.0
loss_f =  9083057.0
loss_f =  9072471.0
loss_f =  9062275.0
Epoch  1700: reducing learning rate of group 0 to 1.2500e-04.
loss_f =  9055354.0
loss_f =  9050520.0
loss_f =  9045754.0
loss_f =  9041049.0
loss_f =  9036396.0
loss_f =  9031788.0
Epoch  1731: reducing learning rate of group 0 to 6.2500e-05.
loss_f =  9028134.0
loss_f =  9025866.0
loss_f =  9023612.0
loss_f =  9021368.0
loss_f =  9019128.0
loss_f =  9016865.0
Epoch  1762: reducing learning rate of group 0 to 3.1250e-05.
loss_f =  9014870.0
loss_f =  9013758.0
loss_f =  9012642.0
loss_f =  9011532.0
loss_f =  9010424.0
loss_f =  9009320.0
Epoch  1793: reducing learning rate of group 0 to 1.5625e-05.
loss_f =  9008198.0
loss_f =  9007648.0
loss_f =  9007078.0
loss_f =  9006516.0
loss_f =  9005964.0
loss_f =  9005402.0
loss_f =  9004840.0
Epoch  1824: reducing learning

loss_f =  8783126.0
loss_f =  8782764.0
loss_f =  8782403.0
loss_f =  8782050.0
loss_f =  8781688.0
loss_f =  8781326.0
Epoch  2010: reducing learning rate of group 0 to 7.8125e-06.
loss_f =  8781074.0
loss_f =  8780896.0
loss_f =  8780720.0
loss_f =  8780537.0
loss_f =  8780358.0
thetaf step =  138.17585372924805 s
sample ==  0
loss_lambda_c =  11070.74609375
Epoch  3465: reducing learning rate of group 0 to 1.0000e-05.
Epoch  3481: reducing learning rate of group 0 to 1.0000e-06.
Epoch  3497: reducing learning rate of group 0 to 1.0000e-07.
Epoch  3513: reducing learning rate of group 0 to 1.0000e-08.
sample ==  1
loss_lambda_c =  6408.7548828125
Epoch  2686: reducing learning rate of group 0 to 1.0000e-05.
Epoch  2702: reducing learning rate of group 0 to 1.0000e-06.
Epoch  2718: reducing learning rate of group 0 to 1.0000e-07.
Epoch  2734: reducing learning rate of group 0 to 1.0000e-08.
sample ==  2
loss_lambda_c =  30089.958984375
Epoch  9508: reducing learning rate of group 0 to

Epoch  9620: reducing learning rate of group 0 to 1.0000e-08.
sample ==  3
loss_lambda_c =  43826.24609375
Epoch  4612: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4628: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4644: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4660: reducing learning rate of group 0 to 1.0000e-08.
sample ==  4
loss_lambda_c =  24292.5
Epoch  5674: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5690: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5706: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5722: reducing learning rate of group 0 to 1.0000e-08.
sample ==  5
loss_lambda_c =  37790.4921875
Epoch  2812: reducing learning rate of group 0 to 1.0000e-05.
Epoch  2828: reducing learning rate of group 0 to 1.0000e-06.
Epoch  2844: reducing learning rate of group 0 to 1.0000e-07.
Epoch  2860: reducing learning rate of group 0 to 1.0000e-08.
sample ==  6
loss_lambda_c =  23240.693359375
Epoch  3699: reducing

Epoch  2924: reducing learning rate of group 0 to 1.0000e-08.
sample ==  6
loss_lambda_c =  23240.41015625
Epoch  3763: reducing learning rate of group 0 to 1.0000e-05.
Epoch  3779: reducing learning rate of group 0 to 1.0000e-06.
Epoch  3795: reducing learning rate of group 0 to 1.0000e-07.
Epoch  3811: reducing learning rate of group 0 to 1.0000e-08.
sample ==  7
loss_lambda_c =  14117.64453125
Epoch 13082: reducing learning rate of group 0 to 1.0000e-05.
Epoch 13098: reducing learning rate of group 0 to 1.0000e-06.
Epoch 13114: reducing learning rate of group 0 to 1.0000e-07.
Epoch 13130: reducing learning rate of group 0 to 1.0000e-08.
sample ==  8
loss_lambda_c =  2421.74072265625
Epoch  4203: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4219: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4235: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4251: reducing learning rate of group 0 to 1.0000e-08.
sample ==  9
loss_lambda_c =  8160.7431640625
Epoch  2710

Epoch  4315: reducing learning rate of group 0 to 1.0000e-08.
sample ==  9
loss_lambda_c =  8155.7568359375
Epoch  2774: reducing learning rate of group 0 to 1.0000e-05.
Epoch  2790: reducing learning rate of group 0 to 1.0000e-06.
Epoch  2806: reducing learning rate of group 0 to 1.0000e-07.
Epoch  2822: reducing learning rate of group 0 to 1.0000e-08.
sample ==  10
loss_lambda_c =  19097.03125
Epoch  4310: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4326: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4342: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4358: reducing learning rate of group 0 to 1.0000e-08.
sample ==  11
loss_lambda_c =  10349.2197265625
Epoch  9297: reducing learning rate of group 0 to 1.0000e-05.
Epoch  9313: reducing learning rate of group 0 to 1.0000e-06.
Epoch  9329: reducing learning rate of group 0 to 1.0000e-07.
Epoch  9345: reducing learning rate of group 0 to 1.0000e-08.
sample ==  12
loss_lambda_c =  9313.521484375
Epoch  4531

Epoch  9377: reducing learning rate of group 0 to 1.0000e-06.
Epoch  9393: reducing learning rate of group 0 to 1.0000e-07.
Epoch  9409: reducing learning rate of group 0 to 1.0000e-08.
sample ==  12
loss_lambda_c =  9315.4501953125
Epoch  4595: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4611: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4627: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4643: reducing learning rate of group 0 to 1.0000e-08.
sample ==  13
loss_lambda_c =  8872.8017578125
Epoch  5801: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5817: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5833: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5849: reducing learning rate of group 0 to 1.0000e-08.
sample ==  14
loss_lambda_c =  15654.1591796875
Epoch  2919: reducing learning rate of group 0 to 1.0000e-05.
Epoch  2935: reducing learning rate of group 0 to 1.0000e-06.
Epoch  2951: reducing learning rate of group 0 to 1.

loss_c =  252.50633239746094
loss_c =  252.3580780029297
loss_c =  252.20738220214844
loss_c =  252.0531768798828
loss_c =  251.8949432373047
loss_c =  251.7328643798828
loss_c =  251.56649780273438
loss_c =  251.39593505859375
loss_c =  251.22085571289062
thetac step =  4.041189670562744 s
Saving model...
...saving done.
step =  20
loss z =  10990074.0
loss z =  10983318.0
loss z =  10977654.0
loss z =  10972662.0
loss z =  10968103.0
loss z =  10963862.0
loss z =  10959904.0
loss z =  10956206.0
loss z =  10952742.0
loss z =  10949490.0
loss z =  10946424.0
Epoch   719: reducing learning rate of group 0 to 2.0000e-05.
loss z =  10943503.0
loss z =  10942942.0
loss z =  10942385.0
loss z =  10941833.0
loss z =  10941286.0
loss z =  10940742.0
loss z =  10940204.0
loss z =  10939669.0
loss z =  10938602.0
loss z =  10938075.0
Epoch   730: reducing learning rate of group 0 to 4.0000e-06.
loss z =  10937550.0
loss z =  10937446.0
loss z =  10937340.0
loss z =  10937237.0
loss z =  109371

loss z =  10858709.0
loss z =  10858690.0
loss z =  10858670.0
loss z =  10858650.0
loss z =  10858630.0
loss z =  10858610.0
loss z =  10858571.0
loss z =  10858551.0
loss z =  10858532.0
loss z =  10858512.0
Epoch   797: reducing learning rate of group 0 to 1.6000e-07.
z step =  24.493196964263916 s
loss_f =  8093798.5
loss_f =  8078272.0
loss_f =  8069469.0
loss_f =  8063256.0
loss_f =  8058440.0
loss_f =  8054440.5
loss_f =  8050999.0
Epoch  3247: reducing learning rate of group 0 to 2.5000e-04.
loss_f =  8049143.0
loss_f =  8047709.0
loss_f =  8046375.0
loss_f =  8045086.0
loss_f =  8043841.5
loss_f =  8042615.0
Epoch  3278: reducing learning rate of group 0 to 1.2500e-04.
loss_f =  8041773.5
loss_f =  8041180.0
loss_f =  8040589.0
loss_f =  8040009.0
loss_f =  8039430.5
loss_f =  8038853.0
loss_f =  8038271.5
loss_f =  8037695.5
loss_f =  8037131.0
loss_f =  8036560.0
loss_f =  8035994.5
loss_f =  8035427.5
loss_f =  8034870.0
loss_f =  8034297.5
loss_f =  8033742.0
loss_f =  803

loss_f =  7960390.5
loss_f =  7958718.0
loss_f =  7957062.5
loss_f =  7955408.5
loss_f =  7953788.5
loss_f =  7952176.5
loss_f =  7950578.5
loss_f =  7949000.0
thetaf step =  139.29452753067017 s
sample ==  0
loss_lambda_c =  11069.1953125
Epoch  3977: reducing learning rate of group 0 to 1.0000e-05.
Epoch  3993: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4009: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4025: reducing learning rate of group 0 to 1.0000e-08.
sample ==  1
loss_lambda_c =  6406.73095703125
Epoch  3200: reducing learning rate of group 0 to 1.0000e-05.
Epoch  3216: reducing learning rate of group 0 to 1.0000e-06.
Epoch  3232: reducing learning rate of group 0 to 1.0000e-07.
Epoch  3248: reducing learning rate of group 0 to 1.0000e-08.
sample ==  2
loss_lambda_c =  30090.072265625
Epoch 10020: reducing learning rate of group 0 to 1.0000e-05.
Epoch 10036: reducing learning rate of group 0 to 1.0000e-06.
Epoch 10052: reducing learning rate of group 0 

Epoch  3372: reducing learning rate of group 0 to 1.0000e-08.
sample ==  6
loss_lambda_c =  23241.76953125
Epoch  4211: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4227: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4243: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4259: reducing learning rate of group 0 to 1.0000e-08.
sample ==  7
loss_lambda_c =  14120.333984375
Epoch 13530: reducing learning rate of group 0 to 1.0000e-05.
Epoch 13546: reducing learning rate of group 0 to 1.0000e-06.
Epoch 13562: reducing learning rate of group 0 to 1.0000e-07.
Epoch 13578: reducing learning rate of group 0 to 1.0000e-08.
sample ==  8
loss_lambda_c =  2420.864013671875
Epoch  4652: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4668: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4684: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4700: reducing learning rate of group 0 to 1.0000e-08.
sample ==  9
loss_lambda_c =  8155.19580078125
Epoch  3

Epoch  9794: reducing learning rate of group 0 to 1.0000e-08.
sample ==  12
loss_lambda_c =  9305.62109375
Epoch  4980: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4996: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5012: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5028: reducing learning rate of group 0 to 1.0000e-08.
sample ==  13
loss_lambda_c =  8870.927734375
Epoch  6186: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6202: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6218: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6234: reducing learning rate of group 0 to 1.0000e-08.
sample ==  14
loss_lambda_c =  15622.1513671875
Epoch  3303: reducing learning rate of group 0 to 1.0000e-05.
Epoch  3319: reducing learning rate of group 0 to 1.0000e-06.
Epoch  3335: reducing learning rate of group 0 to 1.0000e-07.
Epoch  3351: reducing learning rate of group 0 to 1.0000e-08.
sample ==  15
loss_lambda_c =  1813.604736328125
Epoch 

Epoch 49523: reducing learning rate of group 0 to 2.0000e-05.
loss_c =  258.9169006347656
loss_c =  258.73883056640625
loss_c =  258.5659484863281
loss_c =  258.3948059082031
loss_c =  258.2250061035156
loss_c =  258.055908203125
loss_c =  257.8868103027344
loss_c =  257.7181396484375
loss_c =  257.54901123046875
thetac step =  4.991853475570679 s
step =  26
loss z =  10607161.0
loss z =  10602578.0
loss z =  10599052.0
loss z =  10596082.0
loss z =  10593371.0
loss z =  10590794.0
loss z =  10588300.0
loss z =  10585867.0
loss z =  10583486.0
loss z =  10581156.0
loss z =  10578869.0
Epoch   989: reducing learning rate of group 0 to 2.0000e-05.
loss z =  10576627.0
loss z =  10576184.0
loss z =  10575745.0
loss z =  10575307.0
loss z =  10574869.0
loss z =  10574432.0
loss z =  10574001.0
loss z =  10573571.0
loss z =  10572717.0
loss z =  10572291.0
Epoch  1000: reducing learning rate of group 0 to 4.0000e-06.
loss z =  10571866.0
loss z =  10571782.0
loss z =  10571697.0
loss z =  1

loss z =  10526920.0
loss z =  10526836.0
loss z =  10526753.0
loss z =  10526669.0
Epoch  1056: reducing learning rate of group 0 to 8.0000e-07.
loss z =  10526586.0
loss z =  10526568.0
loss z =  10526552.0
loss z =  10526536.0
loss z =  10526518.0
loss z =  10526503.0
loss z =  10526468.0
loss z =  10526453.0
loss z =  10526436.0
loss z =  10526421.0
Epoch  1067: reducing learning rate of group 0 to 1.6000e-07.
z step =  24.822913885116577 s
loss_f =  7788205.0
loss_f =  7776305.0
loss_f =  7769732.0
loss_f =  7764244.0
loss_f =  7760338.0
loss_f =  7757125.0
loss_f =  7754357.0
loss_f =  7752077.0
loss_f =  7750088.0
loss_f =  7748246.5
loss_f =  7746536.0
loss_f =  7744890.0
loss_f =  7743312.0
loss_f =  7741788.5
loss_f =  7740299.5
loss_f =  7738843.5
loss_f =  7737422.0
loss_f =  7736022.0
loss_f =  7734661.5
loss_f =  7733319.0
loss_f =  7731997.0
loss_f =  7730697.0
loss_f =  7729419.5
loss_f =  7728157.5
loss_f =  7726918.0
loss_f =  7725694.0
loss_f =  7724489.5
loss_f =  7

loss_f =  7691216.0
loss_f =  7690532.0
loss_f =  7689856.0
loss_f =  7689185.0
loss_f =  7688522.0
loss_f =  7687863.0
loss_f =  7687205.0
loss_f =  7686550.0
loss_f =  7685910.0
loss_f =  7685268.5
loss_f =  7684633.5
thetaf step =  140.00559329986572 s
sample ==  0
loss_lambda_c =  11073.9228515625
Epoch  4361: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4377: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4393: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4409: reducing learning rate of group 0 to 1.0000e-08.
sample ==  1
loss_lambda_c =  6410.36474609375
Epoch  3584: reducing learning rate of group 0 to 1.0000e-05.
Epoch  3600: reducing learning rate of group 0 to 1.0000e-06.
Epoch  3616: reducing learning rate of group 0 to 1.0000e-07.
Epoch  3632: reducing learning rate of group 0 to 1.0000e-08.
sample ==  2
loss_lambda_c =  30094.87890625
Epoch 10404: reducing learning rate of group 0 to 1.0000e-05.
Epoch 10420: reducing learning rate of group 0 

Epoch 10516: reducing learning rate of group 0 to 1.0000e-08.
sample ==  3
loss_lambda_c =  43826.3515625
Epoch  5508: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5524: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5540: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5556: reducing learning rate of group 0 to 1.0000e-08.
sample ==  4
loss_lambda_c =  24293.828125
Epoch  6570: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6586: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6602: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6618: reducing learning rate of group 0 to 1.0000e-08.
sample ==  5
loss_lambda_c =  37790.7109375
Epoch  3708: reducing learning rate of group 0 to 1.0000e-05.
Epoch  3724: reducing learning rate of group 0 to 1.0000e-06.
Epoch  3740: reducing learning rate of group 0 to 1.0000e-07.
Epoch  3756: reducing learning rate of group 0 to 1.0000e-08.
sample ==  6
loss_lambda_c =  23244.955078125
Epoch  4595: redu

Epoch  5148: reducing learning rate of group 0 to 1.0000e-08.
sample ==  9
loss_lambda_c =  8150.59521484375
Epoch  3606: reducing learning rate of group 0 to 1.0000e-05.
Epoch  3622: reducing learning rate of group 0 to 1.0000e-06.
Epoch  3638: reducing learning rate of group 0 to 1.0000e-07.
Epoch  3654: reducing learning rate of group 0 to 1.0000e-08.
sample ==  10
loss_lambda_c =  19102.18359375
Epoch  5142: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5158: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5174: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5190: reducing learning rate of group 0 to 1.0000e-08.
sample ==  11
loss_lambda_c =  10354.291015625
Epoch 10130: reducing learning rate of group 0 to 1.0000e-05.
Epoch 10146: reducing learning rate of group 0 to 1.0000e-06.
Epoch 10162: reducing learning rate of group 0 to 1.0000e-07.
Epoch 10178: reducing learning rate of group 0 to 1.0000e-08.
sample ==  12
loss_lambda_c =  9305.13671875
Epoch  53

Epoch  6682: reducing learning rate of group 0 to 1.0000e-08.
sample ==  14
loss_lambda_c =  15624.1748046875
Epoch  3751: reducing learning rate of group 0 to 1.0000e-05.
Epoch  3767: reducing learning rate of group 0 to 1.0000e-06.
Epoch  3783: reducing learning rate of group 0 to 1.0000e-07.
Epoch  3799: reducing learning rate of group 0 to 1.0000e-08.
sample ==  15
loss_lambda_c =  1812.937255859375
Epoch  3357: reducing learning rate of group 0 to 1.0000e-05.
Epoch  3373: reducing learning rate of group 0 to 1.0000e-06.
Epoch  3389: reducing learning rate of group 0 to 1.0000e-07.
Epoch  3405: reducing learning rate of group 0 to 1.0000e-08.
lambdac step =  1.1797993183135986 s
loss_c =  256.115966796875
Epoch 52526: reducing learning rate of group 0 to 2.0000e-05.
loss_c =  255.08331298828125
loss_c =  254.93875122070312
loss_c =  254.81008911132812
loss_c =  254.6889190673828
loss_c =  254.57159423828125
loss_c =  254.4554901123047
loss_c =  254.33981323242188
loss_c =  254.2234

loss_c =  251.26968383789062
thetac step =  5.850454807281494 s
step =  33
loss z =  10440106.0
loss z =  10425001.0
loss z =  10411490.0
loss z =  10399506.0
loss z =  10389057.0
loss z =  10380322.0
loss z =  10373503.0
loss z =  10368129.0
loss z =  10363580.0
loss z =  10359480.0
loss z =  10355696.0
Epoch  1304: reducing learning rate of group 0 to 2.0000e-05.
loss z =  10352136.0
loss z =  10351450.0
loss z =  10350778.0
loss z =  10350108.0
loss z =  10349446.0
loss z =  10348791.0
loss z =  10348142.0
loss z =  10347498.0
loss z =  10346236.0
loss z =  10345612.0
Epoch  1315: reducing learning rate of group 0 to 4.0000e-06.
loss z =  10344993.0
loss z =  10344870.0
loss z =  10344749.0
loss z =  10344625.0
loss z =  10344501.0
loss z =  10344379.0
loss z =  10344257.0
loss z =  10344136.0
loss z =  10344016.0
loss z =  10343893.0
loss z =  10343772.0
Epoch  1326: reducing learning rate of group 0 to 8.0000e-07.
loss z =  10343651.0
loss z =  10343627.0
loss z =  10343603.0
loss

loss z =  10321802.0
loss z =  10321780.0
loss z =  10321732.0
loss z =  10321710.0
loss z =  10321686.0
loss z =  10321662.0
Epoch  1382: reducing learning rate of group 0 to 1.6000e-07.
z step =  24.87155508995056 s
loss_f =  7611435.0
loss_f =  7596435.0
loss_f =  7588597.5
loss_f =  7580639.0
loss_f =  7574542.5
loss_f =  7569843.0
Epoch  5846: reducing learning rate of group 0 to 2.5000e-04.
loss_f =  7566273.0
loss_f =  7564764.0
loss_f =  7563439.0
loss_f =  7562232.0
loss_f =  7561117.0
loss_f =  7560073.5
loss_f =  7559082.0
Epoch  5877: reducing learning rate of group 0 to 1.2500e-04.
loss_f =  7558509.0
loss_f =  7558044.0
loss_f =  7557589.0
loss_f =  7557134.5
loss_f =  7556681.5
loss_f =  7556238.5
loss_f =  7555806.0
loss_f =  7555369.0
loss_f =  7554941.0
loss_f =  7554519.5
loss_f =  7554097.0
loss_f =  7553670.5
loss_f =  7553252.0
loss_f =  7552840.5
loss_f =  7552418.0
loss_f =  7552016.0
loss_f =  7551620.5
loss_f =  7551205.0
loss_f =  7550805.0
loss_f =  7550409.

loss_f =  7524322.5
loss_f =  7523649.0
loss_f =  7523009.5
thetaf step =  139.82229089736938 s
sample ==  0
loss_lambda_c =  11075.2255859375
Epoch  4809: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4825: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4841: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4857: reducing learning rate of group 0 to 1.0000e-08.
sample ==  1
loss_lambda_c =  6410.3857421875
Epoch  4032: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4048: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4064: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4080: reducing learning rate of group 0 to 1.0000e-08.
sample ==  2
loss_lambda_c =  30096.36328125
Epoch 10852: reducing learning rate of group 0 to 1.0000e-05.
Epoch 10868: reducing learning rate of group 0 to 1.0000e-06.
Epoch 10884: reducing learning rate of group 0 to 1.0000e-07.
Epoch 10900: reducing learning rate of group 0 to 1.0000e-08.
sample ==  3
loss_lamb

Epoch  4172: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4188: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4204: reducing learning rate of group 0 to 1.0000e-08.
sample ==  6
loss_lambda_c =  23246.685546875
Epoch  5043: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5059: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5075: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5091: reducing learning rate of group 0 to 1.0000e-08.
sample ==  7
loss_lambda_c =  14114.615234375
Epoch 14362: reducing learning rate of group 0 to 1.0000e-05.
Epoch 14378: reducing learning rate of group 0 to 1.0000e-06.
Epoch 14394: reducing learning rate of group 0 to 1.0000e-07.
Epoch 14410: reducing learning rate of group 0 to 1.0000e-08.
sample ==  8
loss_lambda_c =  2422.3779296875
Epoch  5484: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5500: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5516: reducing learning rate of group 0 to 1.0000

Epoch 10610: reducing learning rate of group 0 to 1.0000e-07.
Epoch 10626: reducing learning rate of group 0 to 1.0000e-08.
sample ==  12
loss_lambda_c =  9303.220703125
Epoch  5814: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5830: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5846: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5862: reducing learning rate of group 0 to 1.0000e-08.
sample ==  13
loss_lambda_c =  8874.162109375
Epoch  7018: reducing learning rate of group 0 to 1.0000e-05.
Epoch  7034: reducing learning rate of group 0 to 1.0000e-06.
Epoch  7050: reducing learning rate of group 0 to 1.0000e-07.
Epoch  7066: reducing learning rate of group 0 to 1.0000e-08.
sample ==  14
loss_lambda_c =  15624.0
Epoch  4135: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4151: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4167: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4183: reducing learning rate of group 0 to 1.0000e-08.
s

loss_c =  255.15489196777344
loss_c =  254.80857849121094
loss_c =  254.46376037597656
loss_c =  254.11526489257812
loss_c =  253.7621612548828
loss_c =  253.40426635742188
loss_c =  253.04107666015625
loss_c =  252.67242431640625
loss_c =  252.29827880859375
thetac step =  5.927362680435181 s
step =  39
loss z =  10311402.0
loss z =  10296637.0
loss z =  10283661.0
loss z =  10272344.0
loss z =  10262645.0
loss z =  10254684.0
loss z =  10248373.0
loss z =  10243394.0
loss z =  10239230.0
loss z =  10235398.0
loss z =  10231808.0
Epoch  1574: reducing learning rate of group 0 to 2.0000e-05.
loss z =  10228402.0
loss z =  10227746.0
loss z =  10227095.0
loss z =  10226446.0
loss z =  10225805.0
loss z =  10225171.0
loss z =  10224538.0
loss z =  10223912.0
loss z =  10222641.0
loss z =  10222002.0
Epoch  1585: reducing learning rate of group 0 to 4.0000e-06.
loss z =  10221364.0
loss z =  10221239.0
loss z =  10221114.0
loss z =  10220988.0
loss z =  10220863.0
loss z =  10220738.0
los

loss z =  10213362.0
loss z =  10213225.0
loss z =  10213088.0
loss z =  10212948.0
loss z =  10212811.0
loss z =  10212674.0
loss z =  10212537.0
loss z =  10212400.0
loss z =  10212263.0
Epoch  1641: reducing learning rate of group 0 to 8.0000e-07.
loss z =  10212127.0
loss z =  10212099.0
loss z =  10212072.0
loss z =  10212045.0
loss z =  10212018.0
loss z =  10211990.0
loss z =  10211937.0
loss z =  10211910.0
loss z =  10211883.0
loss z =  10211855.0
Epoch  1652: reducing learning rate of group 0 to 1.6000e-07.
z step =  25.630122661590576 s
loss_f =  7521719.5
loss_f =  7509704.5
loss_f =  7499644.5
loss_f =  7491442.5
loss_f =  7483431.0
loss_f =  7477002.5
Epoch  7045: reducing learning rate of group 0 to 2.5000e-04.
loss_f =  7473225.5
loss_f =  7471284.0
loss_f =  7469565.0
loss_f =  7468148.5
loss_f =  7466851.0
loss_f =  7465636.0
loss_f =  7464506.5
loss_f =  7463420.0
loss_f =  7462382.5
loss_f =  7461379.0
loss_f =  7460406.0
loss_f =  7459461.0
loss_f =  7458522.0
loss

loss_f =  7444620.0
loss_f =  7444372.0
loss_f =  7444137.0
loss_f =  7443898.0
loss_f =  7443659.0
loss_f =  7443422.5
loss_f =  7443185.5
Epoch  7337: reducing learning rate of group 0 to 3.1250e-05.
loss_f =  7443047.5
loss_f =  7442928.0
loss_f =  7442814.0
loss_f =  7442692.0
loss_f =  7442570.0
loss_f =  7442458.0
Epoch  7368: reducing learning rate of group 0 to 1.5625e-05.
loss_f =  7442378.5
loss_f =  7442312.0
loss_f =  7442249.5
loss_f =  7442204.0
loss_f =  7442133.0
loss_f =  7442086.0
Epoch  7399: reducing learning rate of group 0 to 7.8125e-06.
loss_f =  7442038.0
loss_f =  7442006.0
loss_f =  7441976.0
thetaf step =  140.5460696220398 s
sample ==  0
loss_lambda_c =  11075.2939453125
Epoch  5193: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5209: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5225: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5241: reducing learning rate of group 0 to 1.0000e-08.
sample ==  1
loss_lambda_c =  6410.390625
Ep

Epoch 11348: reducing learning rate of group 0 to 1.0000e-08.
sample ==  3
loss_lambda_c =  43826.5
Epoch  6340: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6356: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6372: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6388: reducing learning rate of group 0 to 1.0000e-08.
sample ==  4
loss_lambda_c =  24292.798828125
Epoch  7402: reducing learning rate of group 0 to 1.0000e-05.
Epoch  7418: reducing learning rate of group 0 to 1.0000e-06.
Epoch  7434: reducing learning rate of group 0 to 1.0000e-07.
Epoch  7450: reducing learning rate of group 0 to 1.0000e-08.
sample ==  5
loss_lambda_c =  37791.11328125
Epoch  4540: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4556: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4572: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4588: reducing learning rate of group 0 to 1.0000e-08.
sample ==  6
loss_lambda_c =  23247.119140625
Epoch  5427: reduci

Epoch  5980: reducing learning rate of group 0 to 1.0000e-08.
sample ==  9
loss_lambda_c =  8145.00244140625
Epoch  4438: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4454: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4470: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4486: reducing learning rate of group 0 to 1.0000e-08.
sample ==  10
loss_lambda_c =  19104.552734375
Epoch  5974: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5990: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6006: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6022: reducing learning rate of group 0 to 1.0000e-08.
sample ==  11
loss_lambda_c =  10355.634765625
Epoch 10962: reducing learning rate of group 0 to 1.0000e-05.
Epoch 10978: reducing learning rate of group 0 to 1.0000e-06.
Epoch 10994: reducing learning rate of group 0 to 1.0000e-07.
Epoch 11010: reducing learning rate of group 0 to 1.0000e-08.
sample ==  12
loss_lambda_c =  9301.4599609375
Epoch 

Epoch 11074: reducing learning rate of group 0 to 1.0000e-08.
sample ==  12
loss_lambda_c =  9301.322265625
Epoch  6263: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6279: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6295: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6311: reducing learning rate of group 0 to 1.0000e-08.
sample ==  13
loss_lambda_c =  8874.322265625
Epoch  7466: reducing learning rate of group 0 to 1.0000e-05.
Epoch  7482: reducing learning rate of group 0 to 1.0000e-06.
Epoch  7498: reducing learning rate of group 0 to 1.0000e-07.
Epoch  7514: reducing learning rate of group 0 to 1.0000e-08.
sample ==  14
loss_lambda_c =  15624.767578125
Epoch  4583: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4599: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4615: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4631: reducing learning rate of group 0 to 1.0000e-08.
sample ==  15
loss_lambda_c =  1813.322998046875
Epoch 

Epoch 59533: reducing learning rate of group 0 to 2.0000e-05.
loss_c =  256.20855712890625
loss_c =  255.97409057617188
loss_c =  255.8062286376953
loss_c =  255.67306518554688
loss_c =  255.55804443359375
loss_c =  255.45281982421875
loss_c =  255.35305786132812
loss_c =  255.25689697265625
loss_c =  255.16253662109375
thetac step =  7.96532940864563 s
step =  46
loss z =  10247017.0
loss z =  10227393.0
loss z =  10209922.0
loss z =  10194306.0
loss z =  10180287.0
loss z =  10167791.0
loss z =  10156794.0
loss z =  10147367.0
loss z =  10139554.0
loss z =  10133137.0
loss z =  10127848.0
Epoch  1888: reducing learning rate of group 0 to 2.0000e-05.
loss z =  10123391.0
loss z =  10122581.0
loss z =  10121781.0
loss z =  10120994.0
loss z =  10120219.0
loss z =  10119455.0
loss z =  10118698.0
loss z =  10117949.0
loss z =  10116476.0
loss z =  10115752.0
Epoch  1899: reducing learning rate of group 0 to 4.0000e-06.
loss z =  10115036.0
loss z =  10114893.0
loss z =  10114750.0
loss 

loss z =  10102026.0
loss z =  10101393.0
loss z =  10100770.0
loss z =  10100156.0
loss z =  10098937.0
Epoch  1943: reducing learning rate of group 0 to 4.0000e-06.
loss z =  10098336.0
loss z =  10098216.0
loss z =  10098098.0
loss z =  10097978.0
loss z =  10097860.0
loss z =  10097742.0
loss z =  10097622.0
loss z =  10097502.0
loss z =  10097382.0
loss z =  10097263.0
loss z =  10097143.0
Epoch  1954: reducing learning rate of group 0 to 8.0000e-07.
loss z =  10097024.0
loss z =  10096999.0
loss z =  10096976.0
loss z =  10096952.0
loss z =  10096929.0
loss z =  10096905.0
loss z =  10096882.0
loss z =  10096834.0
loss z =  10096811.0
loss z =  10096786.0
Epoch  1965: reducing learning rate of group 0 to 1.6000e-07.
z step =  24.85142207145691 s
loss_f =  7427132.0
loss_f =  7417933.0
loss_f =  7407749.5
loss_f =  7400704.5
loss_f =  7394081.5
loss_f =  7388622.5
loss_f =  7384705.0
loss_f =  7381476.0
loss_f =  7378898.0
loss_f =  7376518.0
loss_f =  7374440.0
loss_f =  7372448.

loss_f =  7340644.0
Epoch  8665: reducing learning rate of group 0 to 1.2500e-04.
loss_f =  7339955.0
loss_f =  7339482.5
loss_f =  7339010.5
loss_f =  7338543.0
loss_f =  7338092.5
loss_f =  7337639.0
Epoch  8696: reducing learning rate of group 0 to 6.2500e-05.
loss_f =  7337281.0
loss_f =  7337055.5
loss_f =  7336834.0
loss_f =  7336611.5
loss_f =  7336400.5
loss_f =  7336180.0
Epoch  8727: reducing learning rate of group 0 to 3.1250e-05.
loss_f =  7335986.0
loss_f =  7335874.5
loss_f =  7335768.0
loss_f =  7335658.0
loss_f =  7335542.0
loss_f =  7335438.0
Epoch  8758: reducing learning rate of group 0 to 1.5625e-05.
loss_f =  7335326.5
loss_f =  7335272.0
loss_f =  7335216.0
loss_f =  7335166.5
loss_f =  7335112.0
loss_f =  7335059.0
loss_f =  7335005.5
Epoch  8789: reducing learning rate of group 0 to 7.8125e-06.
loss_f =  7334962.0
loss_f =  7334939.0
thetaf step =  150.2099404335022 s
sample ==  0
loss_lambda_c =  11075.5546875
Epoch  5641: reducing learning rate of group 0 to 1

loss_f =  7333868.0
loss_f =  7333849.0
loss_f =  7333812.0
loss_f =  7333791.0
loss_f =  7333751.0
thetaf step =  144.97485828399658 s
sample ==  0
loss_lambda_c =  11075.6552734375
Epoch  5705: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5721: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5737: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5753: reducing learning rate of group 0 to 1.0000e-08.
sample ==  1
loss_lambda_c =  6410.0380859375
Epoch  4928: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4944: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4960: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4976: reducing learning rate of group 0 to 1.0000e-08.
sample ==  2
loss_lambda_c =  30098.44140625
Epoch 11748: reducing learning rate of group 0 to 1.0000e-05.
Epoch 11764: reducing learning rate of group 0 to 1.0000e-06.
Epoch 11780: reducing learning rate of group 0 to 1.0000e-07.
Epoch 11796: reducing learning rate of group

Epoch 11860: reducing learning rate of group 0 to 1.0000e-08.
sample ==  3
loss_lambda_c =  43826.5703125
Epoch  6852: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6868: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6884: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6900: reducing learning rate of group 0 to 1.0000e-08.
sample ==  4
loss_lambda_c =  24292.828125
Epoch  7914: reducing learning rate of group 0 to 1.0000e-05.
Epoch  7930: reducing learning rate of group 0 to 1.0000e-06.
Epoch  7946: reducing learning rate of group 0 to 1.0000e-07.
Epoch  7962: reducing learning rate of group 0 to 1.0000e-08.
sample ==  5
loss_lambda_c =  37791.4921875
Epoch  5052: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5068: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5084: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5100: reducing learning rate of group 0 to 1.0000e-08.
sample ==  6
loss_lambda_c =  23248.17578125
Epoch  5939: reduc

Epoch  6492: reducing learning rate of group 0 to 1.0000e-08.
sample ==  9
loss_lambda_c =  8141.97802734375
Epoch  4950: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4966: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4982: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4998: reducing learning rate of group 0 to 1.0000e-08.
sample ==  10
loss_lambda_c =  19105.876953125
Epoch  6486: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6502: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6518: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6534: reducing learning rate of group 0 to 1.0000e-08.
sample ==  11
loss_lambda_c =  10356.2060546875
Epoch 11474: reducing learning rate of group 0 to 1.0000e-05.
Epoch 11490: reducing learning rate of group 0 to 1.0000e-06.
Epoch 11506: reducing learning rate of group 0 to 1.0000e-07.
Epoch 11522: reducing learning rate of group 0 to 1.0000e-08.
sample ==  12
loss_lambda_c =  9299.4375
Epoch  6713

Epoch  5127: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5143: reducing learning rate of group 0 to 1.0000e-08.
sample ==  15
loss_lambda_c =  1813.3597412109375
Epoch  4701: reducing learning rate of group 0 to 1.0000e-05.
Epoch  4717: reducing learning rate of group 0 to 1.0000e-06.
Epoch  4733: reducing learning rate of group 0 to 1.0000e-07.
Epoch  4749: reducing learning rate of group 0 to 1.0000e-08.
lambdac step =  1.1582176685333252 s
loss_c =  256.795166015625
loss_c =  255.23707580566406
loss_c =  254.7921142578125
loss_c =  254.44189453125
loss_c =  254.08840942382812
loss_c =  253.72581481933594
loss_c =  253.35174560546875
loss_c =  252.96450805664062
loss_c =  252.5640411376953
loss_c =  252.15347290039062
thetac step =  8.93932819366455 s
step =  53
loss z =  10069432.0
loss z =  10057140.0
loss z =  10046980.0
loss z =  10039001.0
loss z =  10032782.0
loss z =  10027565.0
loss z =  10022846.0
loss z =  10018452.0
loss z =  10014300.0
loss z =  10010356.0
los

thetac step =  9.174630641937256 s
step =  54
loss z =  10060784.0
loss z =  10046015.0
loss z =  10033598.0
loss z =  10023663.0
loss z =  10016060.0
loss z =  10010269.0
loss z =  10005474.0
loss z =  10001180.0
loss z =  9997224.0
loss z =  9993517.0
Epoch  2246: reducing learning rate of group 0 to 2.0000e-05.
loss z =  9990008.0
loss z =  9989336.0
loss z =  9988668.0
loss z =  9988006.0
loss z =  9987348.0
loss z =  9986696.0
loss z =  9986052.0
loss z =  9985412.0
loss z =  9984776.0
loss z =  9983526.0
Epoch  2257: reducing learning rate of group 0 to 4.0000e-06.
loss z =  9982905.0
loss z =  9982781.0
loss z =  9982658.0
loss z =  9982536.0
loss z =  9982414.0
loss z =  9982293.0
loss z =  9982170.0
loss z =  9982048.0
loss z =  9981926.0
loss z =  9981804.0
loss z =  9981680.0
Epoch  2268: reducing learning rate of group 0 to 8.0000e-07.
loss z =  9981558.0
loss z =  9981535.0
loss z =  9981509.0
loss z =  9981486.0
loss z =  9981461.0
loss z =  9981437.0
loss z =  9981413.0


loss z =  9971797.0
loss z =  9971772.0
loss z =  9971746.0
loss z =  9971721.0
loss z =  9971695.0
loss z =  9971669.0
loss z =  9971642.0
loss z =  9971590.0
loss z =  9971564.0
loss z =  9971538.0
Epoch  2323: reducing learning rate of group 0 to 1.6000e-07.
z step =  24.964227437973022 s
loss_f =  7323039.5
Epoch 10004: reducing learning rate of group 0 to 2.5000e-04.
loss_f =  7313339.0
loss_f =  7309375.0
loss_f =  7305088.5
loss_f =  7301935.0
loss_f =  7298907.0
loss_f =  7296294.0
Epoch 10035: reducing learning rate of group 0 to 1.2500e-04.
loss_f =  7294786.0
loss_f =  7293802.0
loss_f =  7292901.5
loss_f =  7292096.5
loss_f =  7291377.0
loss_f =  7290684.0
Epoch 10066: reducing learning rate of group 0 to 6.2500e-05.
loss_f =  7290165.0
loss_f =  7289842.0
loss_f =  7289526.5
loss_f =  7289217.0
loss_f =  7288905.0
loss_f =  7288593.0
Epoch 10097: reducing learning rate of group 0 to 3.1250e-05.
loss_f =  7288321.0
loss_f =  7288183.5
loss_f =  7288027.0
loss_f =  7287878.0

loss_f =  7260117.0
loss_f =  7259343.0
loss_f =  7258584.0
loss_f =  7257829.5
loss_f =  7257103.5
loss_f =  7256371.0
loss_f =  7255646.5
loss_f =  7254930.5
loss_f =  7254221.0
loss_f =  7253524.5
loss_f =  7252820.5
loss_f =  7252141.5
loss_f =  7251451.0
loss_f =  7250774.0
loss_f =  7250100.0
loss_f =  7249442.0
loss_f =  7248778.5
loss_f =  7248122.0
loss_f =  7247464.0
loss_f =  7246825.0
thetaf step =  139.18917107582092 s
sample ==  0
loss_lambda_c =  11076.353515625
Epoch  6153: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6169: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6185: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6201: reducing learning rate of group 0 to 1.0000e-08.
sample ==  1
loss_lambda_c =  6409.39453125
Epoch  5376: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5392: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5408: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5424: reducing learning rate of g

Epoch 12292: reducing learning rate of group 0 to 1.0000e-07.
Epoch 12308: reducing learning rate of group 0 to 1.0000e-08.
sample ==  3
loss_lambda_c =  43826.7890625
Epoch  7300: reducing learning rate of group 0 to 1.0000e-05.
Epoch  7316: reducing learning rate of group 0 to 1.0000e-06.
Epoch  7332: reducing learning rate of group 0 to 1.0000e-07.
Epoch  7348: reducing learning rate of group 0 to 1.0000e-08.
sample ==  4
loss_lambda_c =  24291.681640625
Epoch  8362: reducing learning rate of group 0 to 1.0000e-05.
Epoch  8378: reducing learning rate of group 0 to 1.0000e-06.
Epoch  8394: reducing learning rate of group 0 to 1.0000e-07.
Epoch  8410: reducing learning rate of group 0 to 1.0000e-08.
sample ==  5
loss_lambda_c =  37792.1171875
Epoch  5500: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5516: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5532: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5548: reducing learning rate of group 0 to 1.0000e-08

Epoch  5596: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5612: reducing learning rate of group 0 to 1.0000e-08.
sample ==  6
loss_lambda_c =  23249.28125
Epoch  6451: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6467: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6483: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6499: reducing learning rate of group 0 to 1.0000e-08.
sample ==  7
loss_lambda_c =  14114.3603515625
Epoch 15770: reducing learning rate of group 0 to 1.0000e-05.
Epoch 15786: reducing learning rate of group 0 to 1.0000e-06.
Epoch 15802: reducing learning rate of group 0 to 1.0000e-07.
Epoch 15818: reducing learning rate of group 0 to 1.0000e-08.
sample ==  8
loss_lambda_c =  2421.623779296875
Epoch  6892: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6908: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6924: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6940: reducing learning rate of group 0 to 1.0000e

Epoch  7004: reducing learning rate of group 0 to 1.0000e-08.
sample ==  9
loss_lambda_c =  8140.4765625
Epoch  5464: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5480: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5496: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5512: reducing learning rate of group 0 to 1.0000e-08.
sample ==  10
loss_lambda_c =  19107.015625
Epoch  6998: reducing learning rate of group 0 to 1.0000e-05.
Epoch  7014: reducing learning rate of group 0 to 1.0000e-06.
Epoch  7030: reducing learning rate of group 0 to 1.0000e-07.
Epoch  7046: reducing learning rate of group 0 to 1.0000e-08.
sample ==  11
loss_lambda_c =  10357.0859375
Epoch 11986: reducing learning rate of group 0 to 1.0000e-05.
Epoch 12002: reducing learning rate of group 0 to 1.0000e-06.
Epoch 12018: reducing learning rate of group 0 to 1.0000e-07.
Epoch 12034: reducing learning rate of group 0 to 1.0000e-08.
sample ==  12
loss_lambda_c =  9297.720703125
Epoch  7227: red

Epoch  5623: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5639: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5655: reducing learning rate of group 0 to 1.0000e-08.
sample ==  15
loss_lambda_c =  1813.7261962890625
Epoch  5213: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5229: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5245: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5261: reducing learning rate of group 0 to 1.0000e-08.
lambdac step =  1.1594839096069336 s
loss_c =  256.64031982421875
loss_c =  254.77774047851562
loss_c =  254.2616729736328
loss_c =  253.8778076171875
loss_c =  253.51467895507812
loss_c =  253.16305541992188
loss_c =  252.81813049316406
loss_c =  252.4768829345703
loss_c =  252.13714599609375
loss_c =  251.79762268066406
thetac step =  10.725608348846436 s
step =  61
loss z =  9978391.0
loss z =  9963605.0
loss z =  9950674.0
loss z =  9939373.0
loss z =  9929522.0
loss z =  9921006.0
loss z =  9913692.0
loss 

loss_c =  255.09860229492188
loss_c =  255.02377319335938
thetac step =  10.224644899368286 s
step =  62
loss z =  9960380.0
loss z =  9942396.0
loss z =  9928398.0
loss z =  9917316.0
loss z =  9908194.0
loss z =  9900421.0
loss z =  9893740.0
loss z =  9887991.0
loss z =  9882988.0
loss z =  9878576.0
loss z =  9874644.0
Epoch  2602: reducing learning rate of group 0 to 2.0000e-05.
loss z =  9871079.0
loss z =  9870413.0
loss z =  9869755.0
loss z =  9869107.0
loss z =  9868466.0
loss z =  9867836.0
loss z =  9867214.0
loss z =  9866598.0
loss z =  9865392.0
loss z =  9864798.0
Epoch  2613: reducing learning rate of group 0 to 4.0000e-06.
loss z =  9864210.0
loss z =  9864094.0
loss z =  9863978.0
loss z =  9863862.0
loss z =  9863748.0
loss z =  9863630.0
loss z =  9863515.0
loss z =  9863400.0
loss z =  9863284.0
loss z =  9863168.0
loss z =  9863054.0
Epoch  2624: reducing learning rate of group 0 to 8.0000e-07.
loss z =  9862941.0
loss z =  9862918.0
loss z =  9862895.0
loss z = 

loss z =  9854742.0
Epoch  2680: reducing learning rate of group 0 to 1.6000e-07.
z step =  24.397275686264038 s
loss_f =  7229435.0
loss_f =  7220266.5
loss_f =  7215521.0
loss_f =  7209175.0
loss_f =  7202728.0
loss_f =  7197924.0
Epoch 11616: reducing learning rate of group 0 to 2.5000e-04.
loss_f =  7194996.5
loss_f =  7193397.5
loss_f =  7191980.5
loss_f =  7190711.0
loss_f =  7189572.0
loss_f =  7188471.5
Epoch 11647: reducing learning rate of group 0 to 1.2500e-04.
loss_f =  7187542.0
loss_f =  7187027.0
loss_f =  7186466.0
loss_f =  7185914.0
loss_f =  7185391.0
loss_f =  7184893.0
Epoch 11678: reducing learning rate of group 0 to 6.2500e-05.
loss_f =  7184436.0
loss_f =  7184197.5
loss_f =  7183996.0
loss_f =  7183769.0
loss_f =  7183561.0
loss_f =  7183350.0
loss_f =  7183134.0
Epoch 11709: reducing learning rate of group 0 to 3.1250e-05.
loss_f =  7183009.5
loss_f =  7182898.5
loss_f =  7182800.0
loss_f =  7182691.0
loss_f =  7182583.0
loss_f =  7182481.0
Epoch 11740: reduci

loss_f =  7180227.0
loss_f =  7180113.0
loss_f =  7180001.0
loss_f =  7179892.0
loss_f =  7179782.0
loss_f =  7179674.0
Epoch 11926: reducing learning rate of group 0 to 1.5625e-05.
loss_f =  7179572.0
loss_f =  7179526.5
loss_f =  7179467.0
loss_f =  7179403.0
loss_f =  7179355.5
loss_f =  7179297.0
Epoch 11957: reducing learning rate of group 0 to 7.8125e-06.
loss_f =  7179255.5
loss_f =  7179225.0
loss_f =  7179194.5
loss_f =  7179166.0
loss_f =  7179136.0
loss_f =  7179107.5
Epoch 11988: reducing learning rate of group 0 to 3.9063e-06.
thetaf step =  136.79358530044556 s
sample ==  0
loss_lambda_c =  11078.1513671875
Epoch  6665: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6681: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6697: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6713: reducing learning rate of group 0 to 1.0000e-08.
sample ==  1
loss_lambda_c =  6408.478515625
Epoch  5888: reducing learning rate of group 0 to 1.0000e-05.
Epoch  5904: red

Epoch 12820: reducing learning rate of group 0 to 1.0000e-08.
sample ==  3
loss_lambda_c =  43826.9921875
Epoch  7812: reducing learning rate of group 0 to 1.0000e-05.
Epoch  7828: reducing learning rate of group 0 to 1.0000e-06.
Epoch  7844: reducing learning rate of group 0 to 1.0000e-07.
Epoch  7860: reducing learning rate of group 0 to 1.0000e-08.
sample ==  4
loss_lambda_c =  24290.783203125
Epoch  8874: reducing learning rate of group 0 to 1.0000e-05.
Epoch  8890: reducing learning rate of group 0 to 1.0000e-06.
Epoch  8906: reducing learning rate of group 0 to 1.0000e-07.
Epoch  8922: reducing learning rate of group 0 to 1.0000e-08.
sample ==  5
loss_lambda_c =  37792.99609375
Epoch  6012: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6028: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6044: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6060: reducing learning rate of group 0 to 1.0000e-08.
sample ==  6
loss_lambda_c =  23250.46875
Epoch  6899: redu

Epoch  6092: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6108: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6124: reducing learning rate of group 0 to 1.0000e-08.
sample ==  6
loss_lambda_c =  23250.5
Epoch  6963: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6979: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6995: reducing learning rate of group 0 to 1.0000e-07.
Epoch  7011: reducing learning rate of group 0 to 1.0000e-08.
sample ==  7
loss_lambda_c =  14115.703125
Epoch 16282: reducing learning rate of group 0 to 1.0000e-05.
Epoch 16298: reducing learning rate of group 0 to 1.0000e-06.
Epoch 16314: reducing learning rate of group 0 to 1.0000e-07.
Epoch 16330: reducing learning rate of group 0 to 1.0000e-08.
sample ==  8
loss_lambda_c =  2420.84716796875
Epoch  7404: reducing learning rate of group 0 to 1.0000e-05.
Epoch  7420: reducing learning rate of group 0 to 1.0000e-06.
Epoch  7436: reducing learning rate of group 0 to 1.0000e-07.
Epoc

Epoch  6009: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6025: reducing learning rate of group 0 to 1.0000e-08.
sample ==  10
loss_lambda_c =  19108.119140625
Epoch  7510: reducing learning rate of group 0 to 1.0000e-05.
Epoch  7526: reducing learning rate of group 0 to 1.0000e-06.
Epoch  7542: reducing learning rate of group 0 to 1.0000e-07.
Epoch  7558: reducing learning rate of group 0 to 1.0000e-08.
sample ==  11
loss_lambda_c =  10357.2373046875
Epoch 12498: reducing learning rate of group 0 to 1.0000e-05.
Epoch 12514: reducing learning rate of group 0 to 1.0000e-06.
Epoch 12530: reducing learning rate of group 0 to 1.0000e-07.
Epoch 12546: reducing learning rate of group 0 to 1.0000e-08.
sample ==  12
loss_lambda_c =  9296.189453125
Epoch  7740: reducing learning rate of group 0 to 1.0000e-05.
Epoch  7756: reducing learning rate of group 0 to 1.0000e-06.
Epoch  7772: reducing learning rate of group 0 to 1.0000e-07.
Epoch  7788: reducing learning rate of group 0 to 1.0

Epoch  5741: reducing learning rate of group 0 to 1.0000e-06.
Epoch  5757: reducing learning rate of group 0 to 1.0000e-07.
Epoch  5773: reducing learning rate of group 0 to 1.0000e-08.
lambdac step =  1.3362598419189453 s
loss_c =  257.0139465332031
loss_c =  255.58389282226562
loss_c =  255.20909118652344
loss_c =  254.93875122070312
loss_c =  254.67416381835938
loss_c =  254.40892028808594
loss_c =  254.14193725585938
loss_c =  253.8724365234375
loss_c =  253.59986877441406
loss_c =  253.323974609375
thetac step =  12.331274271011353 s
step =  69
loss z =  9920506.0
loss z =  9887994.0
loss z =  9861317.0
loss z =  9840285.0
loss z =  9823695.0
loss z =  9810092.0
loss z =  9798592.0
loss z =  9789047.0
loss z =  9781419.0
loss z =  9775372.0
loss z =  9770895.0
Epoch  2914: reducing learning rate of group 0 to 2.0000e-05.
loss z =  9767513.0
loss z =  9766939.0
loss z =  9766385.0
loss z =  9765844.0
loss z =  9765311.0
loss z =  9764790.0
loss z =  9764280.0
loss z =  9763778.0
lo

loss z =  9935807.0
loss z =  9924570.0
Epoch  2959: reducing learning rate of group 0 to 2.0000e-05.
loss z =  9913326.0
loss z =  9911008.0
loss z =  9908834.0
loss z =  9906580.0
loss z =  9904552.0
loss z =  9902393.0
loss z =  9900288.0
loss z =  9898083.0
loss z =  9893638.0
loss z =  9891550.0
Epoch  2970: reducing learning rate of group 0 to 4.0000e-06.
loss z =  9889513.0
loss z =  9889080.0
loss z =  9888690.0
loss z =  9888217.0
loss z =  9887829.0
loss z =  9887374.0
loss z =  9886946.0
loss z =  9886495.0
loss z =  9886043.0
loss z =  9885611.0
loss z =  9885203.0
Epoch  2981: reducing learning rate of group 0 to 8.0000e-07.
loss z =  9884794.0
loss z =  9884716.0
loss z =  9884617.0
loss z =  9884540.0
loss z =  9884422.0
loss z =  9884345.0
loss z =  9884170.0
loss z =  9884091.0
loss z =  9884011.0
loss z =  9883934.0
Epoch  2992: reducing learning rate of group 0 to 1.6000e-07.
z step =  25.053831577301025 s
loss_f =  7278785.0
loss_f =  7268721.0
loss_f =  7257947.5
E

loss z =  10035812.0
loss z =  10035745.0
loss z =  10035680.0
loss z =  10035589.0
loss z =  10035522.0
loss z =  10035436.0
Epoch  3036: reducing learning rate of group 0 to 1.6000e-07.
z step =  24.900869607925415 s
loss_f =  7425736.0
loss_f =  7374698.5
loss_f =  7349769.0
loss_f =  7323403.0
loss_f =  7302257.5
loss_f =  7285959.5
loss_f =  7273115.5
Epoch 13216: reducing learning rate of group 0 to 2.5000e-04.
loss_f =  7266305.0
loss_f =  7261076.5
loss_f =  7256502.5
loss_f =  7252312.5
loss_f =  7248393.0
loss_f =  7244630.0
Epoch 13247: reducing learning rate of group 0 to 1.2500e-04.
loss_f =  7242160.0
loss_f =  7240348.0
loss_f =  7238670.0
loss_f =  7237024.0
loss_f =  7235390.0
loss_f =  7233790.0
Epoch 13278: reducing learning rate of group 0 to 6.2500e-05.
loss_f =  7232512.5
loss_f =  7231728.0
loss_f =  7230967.0
loss_f =  7230155.0
loss_f =  7229412.0
loss_f =  7228637.5
Epoch 13309: reducing learning rate of group 0 to 3.1250e-05.
loss_f =  7227958.0
loss_f =  722

loss_f =  7145441.0
loss_f =  7145064.5
loss_f =  7144694.5
loss_f =  7144320.0
Epoch 13495: reducing learning rate of group 0 to 3.1250e-05.
loss_f =  7143945.0
loss_f =  7143780.0
loss_f =  7143589.5
loss_f =  7143412.0
loss_f =  7143229.0
loss_f =  7143051.0
loss_f =  7142872.0
Epoch 13526: reducing learning rate of group 0 to 1.5625e-05.
loss_f =  7142756.0
loss_f =  7142677.0
loss_f =  7142586.0
loss_f =  7142492.0
loss_f =  7142400.0
loss_f =  7142317.0
Epoch 13557: reducing learning rate of group 0 to 7.8125e-06.
loss_f =  7142243.0
loss_f =  7142201.0
loss_f =  7142165.0
loss_f =  7142117.5
loss_f =  7142078.0
thetaf step =  145.96098256111145 s
sample ==  0
loss_lambda_c =  11079.1142578125
Epoch  7177: reducing learning rate of group 0 to 1.0000e-05.
Epoch  7193: reducing learning rate of group 0 to 1.0000e-06.
Epoch  7209: reducing learning rate of group 0 to 1.0000e-07.
Epoch  7225: reducing learning rate of group 0 to 1.0000e-08.
sample ==  1
loss_lambda_c =  6407.85742187

Epoch 13316: reducing learning rate of group 0 to 1.0000e-07.
Epoch 13332: reducing learning rate of group 0 to 1.0000e-08.
sample ==  3
loss_lambda_c =  43826.7421875
Epoch  8324: reducing learning rate of group 0 to 1.0000e-05.
Epoch  8340: reducing learning rate of group 0 to 1.0000e-06.
Epoch  8356: reducing learning rate of group 0 to 1.0000e-07.
Epoch  8372: reducing learning rate of group 0 to 1.0000e-08.
sample ==  4
loss_lambda_c =  24291.095703125
Epoch  9386: reducing learning rate of group 0 to 1.0000e-05.
Epoch  9402: reducing learning rate of group 0 to 1.0000e-06.
Epoch  9418: reducing learning rate of group 0 to 1.0000e-07.
Epoch  9434: reducing learning rate of group 0 to 1.0000e-08.
sample ==  5
loss_lambda_c =  37794.046875
Epoch  6524: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6540: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6556: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6572: reducing learning rate of group 0 to 1.0000e-08.

Epoch  7948: reducing learning rate of group 0 to 1.0000e-07.
Epoch  7964: reducing learning rate of group 0 to 1.0000e-08.
sample ==  9
loss_lambda_c =  8139.42578125
Epoch  6425: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6441: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6457: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6473: reducing learning rate of group 0 to 1.0000e-08.
sample ==  10
loss_lambda_c =  19108.955078125
Epoch  7958: reducing learning rate of group 0 to 1.0000e-05.
Epoch  7974: reducing learning rate of group 0 to 1.0000e-06.
Epoch  7990: reducing learning rate of group 0 to 1.0000e-07.
Epoch  8006: reducing learning rate of group 0 to 1.0000e-08.
sample ==  11
loss_lambda_c =  10356.9013671875
Epoch 12946: reducing learning rate of group 0 to 1.0000e-05.
Epoch 12962: reducing learning rate of group 0 to 1.0000e-06.
Epoch 12978: reducing learning rate of group 0 to 1.0000e-07.
Epoch 12994: reducing learning rate of group 0 to 1.000

Epoch  9498: reducing learning rate of group 0 to 1.0000e-08.
sample ==  14
loss_lambda_c =  15623.056640625
Epoch  6567: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6583: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6599: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6615: reducing learning rate of group 0 to 1.0000e-08.
sample ==  15
loss_lambda_c =  1814.5977783203125
Epoch  6173: reducing learning rate of group 0 to 1.0000e-05.
Epoch  6189: reducing learning rate of group 0 to 1.0000e-06.
Epoch  6205: reducing learning rate of group 0 to 1.0000e-07.
Epoch  6221: reducing learning rate of group 0 to 1.0000e-08.
lambdac step =  1.185487985610962 s
loss_c =  256.26202392578125
Epoch 74548: reducing learning rate of group 0 to 2.0000e-05.
loss_c =  255.60238647460938
loss_c =  255.54898071289062
loss_c =  255.4976806640625
loss_c =  255.44723510742188
loss_c =  255.3970184326172
loss_c =  255.34661865234375
loss_c =  255.2956085205078
loss_c =  255.2439

loss z =  9716994.0
loss z =  9708488.0
loss z =  9703434.0
loss z =  9699978.0
loss z =  9697159.0
loss z =  9694627.0
loss z =  9692255.0
loss z =  9689993.0
loss z =  9687815.0
loss z =  9685710.0
Epoch  3270: reducing learning rate of group 0 to 2.0000e-05.
loss z =  9683664.0
loss z =  9683263.0
loss z =  9682864.0
loss z =  9682467.0
loss z =  9682075.0
loss z =  9681683.0
loss z =  9681293.0
loss z =  9680905.0
loss z =  9680132.0
loss z =  9679750.0
Epoch  3281: reducing learning rate of group 0 to 4.0000e-06.
loss z =  9679371.0
loss z =  9679297.0
loss z =  9679219.0
loss z =  9679143.0
loss z =  9679068.0
loss z =  9678992.0
loss z =  9678917.0
loss z =  9678842.0
loss z =  9678766.0
loss z =  9678688.0
loss z =  9678614.0
Epoch  3292: reducing learning rate of group 0 to 8.0000e-07.
loss z =  9678538.0
loss z =  9678522.0
loss z =  9678508.0
loss z =  9678492.0
loss z =  9678477.0
loss z =  9678462.0
loss z =  9678431.0
loss z =  9678416.0
loss z =  9678402.0
loss z =  9678

loss z =  9663558.0
loss z =  9663542.0
loss z =  9663526.0
loss z =  9663510.0
Epoch  3348: reducing learning rate of group 0 to 1.6000e-07.
z step =  25.912616968154907 s
loss_f =  7073211.0
loss_f =  7066216.5
Epoch 14583: reducing learning rate of group 0 to 2.5000e-04.
loss_f =  7063882.0
loss_f =  7060578.5
loss_f =  7057621.5
loss_f =  7055419.5
loss_f =  7053732.0
loss_f =  7052319.5
Epoch 14614: reducing learning rate of group 0 to 1.2500e-04.
loss_f =  7051157.0
loss_f =  7050643.0
loss_f =  7050175.0
loss_f =  7049719.0
loss_f =  7049312.0
loss_f =  7048926.0
loss_f =  7048550.0
loss_f =  7048187.5
loss_f =  7047838.5
loss_f =  7047500.0
loss_f =  7047159.5
loss_f =  7046842.0
loss_f =  7046522.0
loss_f =  7046197.0
loss_f =  7045893.0
loss_f =  7045583.0
loss_f =  7045274.5
loss_f =  7044974.0
loss_f =  7044678.0
loss_f =  7044374.0
loss_f =  7044084.0
loss_f =  7043793.5
loss_f =  7043509.0
loss_f =  7043218.5
loss_f =  7042942.5
loss_f =  7042643.0
loss_f =  7042362.0
los

KeyboardInterrupt: 

In [None]:
# Inspect the `tauc` attribute of the model (its meaning is not visible in this notebook;
# presumably a precision parameter — TODO confirm).
model.tauc

In [None]:
# Inspect `model.z_mean`; later cells index it as [sample, :] and plot it per latent component.
model.z_mean

In [None]:
# Disabled alternative to training in this session: build a fresh GenerativeSurrogate and
# call `load()` (presumably restores previously saved state — TODO confirm).
# model = gs.GenerativeSurrogate()
# model.load()

In [None]:
# Build the test data set from sample indices 0..3.
# NOTE(review): these indices also appear in `supervised_samples` (0..15) defined at the top
# of the notebook; whether they address the same data depends on `StokesData` — TODO confirm.
test_samples = set(range(4))
testData = dta.StokesData(unsupervised_samples=test_samples)
testData.read_data()
# trainingData.plotMicrostruct(1)
testData.reshape_microstructure_image()

In [None]:
# Run prediction on the test data with SGD (at most 3000 iterations, learning rate 1e-3).
# The call returns three values: the predictions, the latent variables Z, and a
# gradient-norm history that is plotted against the iteration in the next cells.
uf_pred, Z, grad_norm = model.predict(testData, max_iterations=3000, optimizer='SGD', lr=1e-3)

In [None]:
# Gradient-norm history from the prediction run, on a logarithmic y axis
# clipped to the range [1e1, 1e5].
fig = plt.figure()
ax = fig.gca()
ax.plot(grad_norm)
ax.set_yscale('log')
ax.grid()
ax.set_ylim(1e1, 1e5)

In [None]:
# Label the *current* axes via the pyplot state machine; this targets the gradient-norm
# plot only if that figure is still the active one when the cell is run.
plt.xlabel('iteration')
plt.ylabel('gradient norm')
plt.title('Gradient norm w.r.t. z in prediction inference')

In [None]:
# Show the last ten entries of the gradient-norm history.
grad_norm[-10:]

In [None]:
# Raw values of the first four rows of model.z_mean (one row per sample).
model.z_mean[:4, :].data

In [None]:
# Raw values of the latent variables Z returned by model.predict.
Z.data

In [None]:
# Display the complete prediction output.
# NOTE(review): each entry is reshaped to 129 x 129 in the plotting cells, so this dumps a
# large object into the notebook; consider showing a slice or the shape instead.
uf_pred

In [None]:
# First four rows of model.data.P, the field that the figure cells overlay on the predictions.
model.data.P[:4, :]

In [None]:
# Call fit with n_steps=10 (presumably continues training the existing model for ten more
# steps — TODO confirm against GenerativeSurrogate.fit).
model.fit(n_steps=10)

In [None]:
# Surface plot of the prediction for test sample 3 on a 129 x 129 grid over the unit square.
fig = plt.figure()
ax = plt.axes(projection='3d')

grid = np.linspace(0, 1, 129)
x = np.outer(grid, np.ones(129))   # x[i, j] = grid[i]
y = x.T.copy()                     # y[i, j] = grid[j]
field = np.reshape(uf_pred[3], (129, 129))
ax.plot_surface(x, y, field, cmap='viridis', edgecolor='none')

In [None]:
# Row 3 of model.data.P arranged on the 129 x 129 grid.
# NOTE(review): the following cell calls `pp.detach()`, so `pp` is expected to still be a
# torch tensor after np.reshape — confirm; `model.data.P[3, :].reshape(129, 129)` would
# express that directly.
pp = np.reshape(model.data.P[3, :], (129, 129))

In [None]:
# Overlay `pp` on whatever 3D axes `ax` is currently bound to (the single surface plot
# created two cells earlier when the notebook is run in order).
ax.plot_surface(x, y, pp.detach().numpy())

In [None]:
# Evaluate model.pcNet at the predicted latent variables Z. Outputs of this network are
# passed through torch.exp before model.rom_autograd in later cells, so the result is
# presumably log(lambda_c) — TODO confirm.
lg_lambdac = model.pcNet(Z)

In [None]:
# Display the pcNet output for the predicted Z.
lg_lambdac

In [None]:
# Evaluate model.pcNet at the first four rows of model.z_mean and display the result.
lg_lambdac2 = model.pcNet(model.z_mean[:4, :])
lg_lambdac2

In [None]:
# First four rows of model.log_lambdac_mean_tensor (presumably shown for comparison with
# the pcNet outputs — TODO confirm).
model.log_lambdac_mean_tensor[:4, :]

In [None]:
# Run model.rom_autograd on exp(lg_lambdac2[0]), i.e. for the first sample only.
# Here `uf_pred2` is a single result; a later cell rebinds the same name to a list.
uf_pred2 = model.rom_autograd(torch.exp(lg_lambdac2[0]))

In [None]:
# Surface plot of `uf_pred2` on a 129 x 129 grid over the unit square.
# Requires `uf_pred2` to be a single tensor (as returned by model.rom_autograd), not the
# list that the later multi-panel cell stores under the same name.
fig = plt.figure()
ax = plt.axes(projection='3d')

grid = np.linspace(0, 1, 129)
x = np.outer(grid, np.ones(129))   # x[i, j] = grid[i]
y = x.T.copy()                     # y[i, j] = grid[j]
field = np.reshape(uf_pred2.detach().numpy(), (129, 129))
ax.plot_surface(x, y, field, cmap='viridis', edgecolor='none')

In [None]:
# Same inspection as in an earlier cell: first four rows of model.log_lambdac_mean_tensor.
model.log_lambdac_mean_tensor[:4, :]

In [None]:
# Comparison figure for the first four test samples.
#   Top row:    predicted field (default colours) and model.data.P (inferno) as surfaces.
#   Bottom row: predicted latent variables Z next to model.z_mean, one bar pair per component.
fig = plt.figure()
figManager = plt.get_current_fig_manager()
# showMaximized() belongs to the Qt window object; this matches the `%matplotlib qt`
# backend selected at the top of the notebook and will not exist on other backends.
figManager.window.showMaximized()

# 129 x 129 grid on the unit square: x[i, j] = grid[i], y[i, j] = grid[j].
x = np.outer(np.linspace(0, 1, 129), np.ones(129))
y = x.copy().T # transpose

ax = []
for i in range(4):
    ax.append(fig.add_subplot(2, 4, i + 1, projection='3d'))
    ax[i].plot_surface(x, y, np.reshape(uf_pred[i], (129, 129)), edgecolor='none')
    # Convert the tensor to an ndarray *before* reshaping. Calling np.reshape directly on a
    # torch tensor only works through NumPy's array-wrapping fallback and hides the type.
    reference = model.data.P[i, :].detach().numpy().reshape(129, 129)
    ax[i].plot_surface(x, y, reference, edgecolor='none', cmap='inferno')
    # Manual placement overrides the default 2 x 4 subplot geometry.
    ax[i].set_position([0.025 + .2375*i, 0.55, 0.22, 0.4])
    ax[i].elev = 10
    ax[i].set_title('pressure field prediction')

xx = np.arange(dim_z)  # one bar position per latent component
width = 0.42  # the width of the bars
for i in range(4):
    ax.append(fig.add_subplot(2, 4, i + 5))
    # Two bars per component, centred on the integer position.
    ax[i + 4].bar(xx - width/2, Z[i, :].detach().numpy(), width, label='prediction')
    ax[i + 4].bar(xx + width/2, model.z_mean[i, :].detach().numpy(), width, label='training')
    ax[i + 4].set_position([0.05 + .235*i, 0.08, 0.20, 0.4])
    ax[i + 4].set_title('location of z distribution')
    ax[i + 4].set_xlabel('component')
    ax[i + 4].set_ylabel('z')
    ax[i + 4].legend()

In [None]:
# Surfaces obtained by pushing the first four rows of model.z_mean through pcNet and
# model.rom_autograd, overlaid with model.data.P (inferno) for the same samples.
lg_lambdac2 = model.pcNet(model.z_mean[:4, :])
# lg_lambdac2 = model.log_lambdac_mean_tensor[:4, :]
# NOTE: `uf_pred2` becomes a list here; an earlier cell binds the same name to a single tensor.
uf_pred2 = []
fig = plt.figure()
figManager = plt.get_current_fig_manager()
# showMaximized() belongs to the Qt window object (see `%matplotlib qt` at the top).
figManager.window.showMaximized()

# 129 x 129 grid on the unit square: x[i, j] = grid[i], y[i, j] = grid[j].
x = np.outer(np.linspace(0, 1, 129), np.ones(129))
y = x.copy().T # transpose
ax = []
for i in range(4):
    uf_pred2.append(model.rom_autograd(torch.exp(lg_lambdac2[i])))
    ax.append(fig.add_subplot(1, 4, i + 1, projection='3d'))
    ax[i].plot_surface(x, y, np.reshape(uf_pred2[i].detach().numpy(), (129, 129)), edgecolor='none')
    # Convert the tensor to an ndarray *before* reshaping. Calling np.reshape directly on a
    # torch tensor only works through NumPy's array-wrapping fallback and hides the type.
    reference = model.data.P[i, :].detach().numpy().reshape(129, 129)
    ax[i].plot_surface(x, y, reference, edgecolor='none', cmap='inferno')
    # Manual placement overrides the 1 x 4 subplot geometry requested above.
    ax[i].set_position([0.025 + .2375*i, 0.55, 0.22, 0.4])
    ax[i].elev = 10
    ax[i].set_title('pressure field prediction')

In [None]:
# Single surface plot of one entry of the `uf_pred2` list.
# The original indexed with `i`, a loop variable leaked from another cell (3 after a
# `range(4)` loop); the index is written out so the cell no longer depends on that state.
# NOTE: this rebinds `ax` to a single Axes, so cells that index `ax[...]` afterwards need the
# list from one of the multi-panel figure cells to be rebuilt first.
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax.plot_surface(x, y, np.reshape(uf_pred2[3].detach().numpy(), (129, 129)), edgecolor='none')

In [None]:
# Close every open matplotlib figure.
plt.close('all')

In [None]:
# Query the placement of the first axes; requires `ax` to be a list of axes, as built by the
# multi-panel figure cells.
ax[0].get_position()

In [None]:
# Manually place the first axes; the list is [left, bottom, width, height] in figure fractions.
ax[0].set_position([0.02, 0.3, 0.2, 0.4])

In [None]:
# Set the elevation attribute of the first (3D) axes to 10, as done in the figure cells.
ax[0].elev = 10

In [None]:
# Show the z tick locations of the first axes (requires a 3D axes).
ax[0].get_zticks()

In [None]:
# NOTE(review): no call parentheses — this only displays the bound method and sets nothing.
# Index 5 requires the eight-axes list built by the 2 x 4 comparison figure cell.
ax[5].set_xlabel

In [None]:
# Display `a` (defined at the top of the notebook as the boundary-condition coefficients,
# unless a later scratch cell has rebound it).
a

In [None]:
# Scratch: fully connected layer with 5 input features and 2 output features.
fc = torch.nn.Linear(5, 2)

In [None]:
# Weight matrix shape of the layer; torch stores it as (out_features, in_features).
fc.weight.shape

In [None]:
# Bias parameter of the layer (one entry per output feature).
fc.bias

In [None]:
# Reset the model's `training_iterations` attribute to zero (how `fit` uses this counter is
# not visible in this notebook — TODO confirm the effect on schedulers/logging).
model.training_iterations = 0

In [None]:
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import _standard_normal, broadcast_all

In [None]:
# Scratch: univariate normal distribution with loc 0 and scale 1.
dist = torch.distributions.Normal(0, 1)

In [None]:
# NOTE(review): `stddev` is exposed as a read-only property on torch distributions, so this
# assignment is expected to raise AttributeError instead of changing the scale — confirm.
# To change the standard deviation, construct a new Normal(loc, scale).
dist.stddev = 2

In [None]:
# Scratch: 3 x 4 tensor of standard-normal samples (no seed is set, so values differ per run).
# NOTE: rebinds `x`, which the plotting cells use for the 129 x 129 grid.
x = torch.randn(3, 4)

In [None]:
# Display the current value of `x`.
x

In [None]:
# Python's built-in sum iterates over the first dimension, i.e. it adds the rows of `x`
# element-wise; for the 3 x 4 tensor above this matches x.sum(dim=0).
sum(x)

In [None]:
def log_emp_dist(x):
    """Evaluate the quadratic term ``0.5 * (x - mu)**2 / sigma**2`` with mu = -3, sigma = 2.

    The returned value is non-negative, i.e. it has the form of a *negative* Gaussian
    log-density without the normalisation constant (presumably the sign convention
    expected by the SVI routine it is passed to — TODO confirm).

    Args:
        x: value(s) at which to evaluate; combined by broadcasting with a
            one-element tensor.

    Returns:
        Tensor holding the element-wise quadratic term.
    """
    std = torch.tensor([2.0])
    mean = torch.tensor([-3.0])
    inverse_variance = 1 / std**2
    residual = x - mean
    return .5 * inverse_variance * residual**2

def log_emp_dist_grad(x):
    """Derivative with respect to ``x`` of ``0.5 * (x - mu)**2 / sigma**2``, mu = -3, sigma = 2.

    Args:
        x: value(s) at which to evaluate; combined by broadcasting with a
            one-element tensor.

    Returns:
        Tensor holding ``(x - mu) / sigma**2`` element-wise.
    """
    std = torch.tensor([2.0])
    mean = torch.tensor([-3.0])
    variance = std**2
    return (x - mean) / variance
    

In [None]:
from VI import variationalinference as vi
# Set up diagonal-Gaussian stochastic variational inference on the 1-D target defined by
# log_emp_dist and its gradient; the trailing 1 is presumably the dimension — TODO confirm.
svi = vi.DiagGaussianSVI(log_emp_dist, log_emp_dist_grad, 1)

In [None]:
# Run the SVI fit with its default settings.
svi.fit()

In [None]:
# Manual optimisation loop: 10000 optimiser steps on the loss returned by elbo.autograd_elbo.
# NOTE(review): neither `elbo` nor `optimizer` is created in the cells visible in this part
# of the notebook (the cells above build `svi`), so this cell appears to depend on leftover
# kernel state — define them explicitly or remove the cell.
for i in range(10000):
    loss = elbo.autograd_elbo(elbo.vi_mean, elbo.vi_log_std)
    loss.backward()
    optimizer.step()
    # Clear gradients so they do not accumulate into the next iteration.
    optimizer.zero_grad()

In [None]:
# Report the variational standard deviation and mean held by `elbo` (see the note on `elbo`
# not being defined in the visible cells).
print('sigma == ', elbo.vi_std)
print('mu == ', elbo.vi_mean)

In [None]:
# Variational standard deviation held by `svi`.
svi.vi_std

In [None]:
# NOTE(review): `log_sigma` is not defined in the cells visible here; this relies on
# leftover kernel state.
log_sigma.grad

In [None]:
# Scratch: the scalar is broadcast, giving a length-5 tensor filled with -1.
torch.zeros(5) - 1

In [13]:
# Scratch: empty dictionary.
d = {}

In [17]:
# Scratch: insert a key by assignment.
d['mean'] = 5

In [18]:
# Display the dictionary.
d

{'mean': 5}

In [58]:
# Scratch: batch of three 4-dimensional Gaussians. `loc` has shape (3, 4) and tracks
# gradients; the covariance is the 4 x 4 identity repeated three times, shape (3, 4, 4).
dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(3, 4, requires_grad=True), covariance_matrix=torch.eye(4).unsqueeze(0).repeat(3, 1, 1))

In [37]:
# NOTE(review): this rebinds the global `a`, which the top of the notebook defines as the
# boundary-condition coefficients read by `flux(x)`. Re-running the setup cells after this
# scratch cell would silently use the tensor below; use a different name for scratch work.
a = torch.tensor([1, 2, 3])

In [42]:
# Add a leading dimension of size 1. Not idempotent: every re-run adds another dimension.
a = a.unsqueeze(0)

In [44]:
# Scratch: two gradient-tracking leaf tensors to hand to an optimiser.
# NOTE: `a` again shadows the boundary-condition coefficients defined at the top.
a = torch.ones(3, requires_grad=True)
b = torch.ones(4, requires_grad=True)

In [45]:
# Adam optimiser over both tensors with default hyperparameters.
opt = optim.Adam([a, b])

In [6]:
# Display `dist`.
# NOTE(review): the recorded output (loc of size 3, In [6]) predates the In [58] cell above,
# which defines `dist` with loc of shape (3, 4); the execution counts are out of order, so
# the stored output does not correspond to the definition shown.
dist

MultivariateNormal(loc: torch.Size([3]), covariance_matrix: torch.Size([3, 3]))

In [8]:
# Expand the distribution to batch shape (3,); the result is only displayed, not stored.
dist.expand([3])

MultivariateNormal(loc: torch.Size([3, 3]), covariance_matrix: torch.Size([3, 3, 3]))

In [12]:
# Expand the distribution to batch shape (4,) and keep the result.
dist_e = dist.expand([4])

In [15]:
# Covariance of the expanded distribution: one copy of the matrix per batch entry.
dist_e.covariance_matrix

tensor([[[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],

        [[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],

        [[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]],

        [[1., 0., 0.],
         [0., 1., 0.],
         [0., 0., 1.]]])

In [40]:
# Scratch: 2 x 2 identity with the (0, 1) entry set to 0.5.
# Note that this makes L *upper* triangular (row 0, column 1 is above the diagonal).
L = torch.eye(2)
L[0, 1] = .5

In [41]:
# Stack three copies of L along a new leading batch dimension, giving shape (3, 2, 2).
# Not idempotent: a second run applies unsqueeze/repeat to the already batched tensor.
L = L.unsqueeze(0).repeat(3, 1, 1)

In [44]:
# Batched product L @ L^T: the transpose swaps the last two dimensions of every batch entry.
L @ torch.transpose(L, 1, 2)

tensor([[[1.2500, 0.5000],
         [0.5000, 1.0000]],

        [[1.2500, 0.5000],
         [0.5000, 1.0000]],

        [[1.2500, 0.5000],
         [0.5000, 1.0000]]])

In [56]:
# Scratch: three standard-normal 2-vectors, one per batch entry of L (no seed is set, so the
# recorded values are not reproducible).
eps = torch.randn(3,2)
eps

tensor([[ 0.7497,  0.3456],
        [ 1.9351,  1.3505],
        [ 1.1620, -0.4637]])

In [57]:
# Batched matrix-vector product: (3, 2, 2) times (3, 2, 1) gives (3, 2, 1), i.e. L[b] @ eps[b].
torch.bmm(L, eps.unsqueeze(2))

tensor([[[ 0.9225],
         [ 0.3456]],

        [[ 2.6103],
         [ 1.3505]],

        [[ 0.9302],
         [-0.4637]]])

In [54]:
# Shape check of the batched matrix.
L.shape

torch.Size([3, 2, 2])

In [55]:
# Shape of eps with a trailing singleton dimension.
# NOTE(review): the recorded output, torch.Size([2, 2, 1]), comes from In [55], which ran
# before `eps` was redefined with shape (3, 2) in In [56]; with the current `eps` the
# result is torch.Size([3, 2, 1]).
eps.unsqueeze(2).shape

torch.Size([2, 2, 1])