In [1]:
import numpy as np
import h5py

import e2c as e2c_util

# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# The GPU id to use, usually either "0" or "1", "2", "3"

import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import datetime

from keras import backend as K
from keras.layers import Input
from keras.models import Model
from keras.optimizers import Adam
from keras import losses


# GPU memory management
import tensorflow as tf

Using TensorFlow backend.


In [2]:
# tf.session specification
# TensorFlow wizardry
config = tf.ConfigProto()

# Don't pre-allocate memory; allocate as-needed
config.gpu_options.allow_growth = True

# Only allow a total of half the GPU memory to be allocated
config.gpu_options.per_process_gpu_memory_fraction = 0.85

# Create a session with the above options specified.
K.tensorflow_backend.set_session(tf.Session(config=config))

In [3]:
def reconstruction_loss(x, t_decoded):
    '''
    Reconstruction loss for the plain VAE
    '''
    v = 0.1
    # return K.mean(K.sum((K.batch_flatten(x) - K.batch_flatten(t_decoded)) ** 2 / (2*v) + 0.5*K.log(2*np.pi*v), axis=-1))
    return K.mean(K.sum((K.batch_flatten(x) - K.batch_flatten(t_decoded)) ** 2 / (2*v), axis=-1))
    # return K.sum((K.batch_flatten(x) - K.batch_flatten(t_decoded)) ** 2, axis=-1)


def l2_reg_loss(qm):
    # 0.5 * (torch.log(pv) - torch.log(qv) + qv / pv + (qm - pm).pow(2) / pv - 1)
    # -0.5 * K.sum(1 + t_log_var - K.square(t_mean) - K.exp(t_log_var), axis=-1)
#     kl = -0.5 * (1 - p_logv + q_logv - K.exp(q_logv) / K.exp(p_logv) - K.square(qm - pm) / K.exp(p_logv))
    l2_reg = 0.5*K.square(qm)
    return K.mean(K.sum(l2_reg, axis=-1))


def get_flux_loss(m, state, state_pred):
    '''
    @params:  state, state_pred shape (batch_size, 60, 60, 2)
              p, p_pred shape (batch_size, 60, 60, 1)
              m shape (batch_size, 60, 60, 1)
    
    @return:  loss_flux: scalar
    
    Only consider discrepancies in total flux, not in phases (saturation not used) 
    '''
    
    perm = K.exp(m)
    p = K.expand_dims(state[:, :, :, 1], -1)
    p_pred = K.expand_dims(state_pred[:, :, :, 1], -1)

    #print(K.int_shape(xxx))
    
    tran_x = 1./perm[:, 1:, ...] + 1./perm[:, :-1, ...]
    tran_y = 1./perm[:, :, 1:, ...] + 1./perm[:, :, :-1, ...]
    flux_x = (p[:, 1:, ...] - p[:, :-1, ...]) / tran_x
    flux_y = (p[:, :, 1:, :] - p[:, :, :-1, :]) / tran_y
    flux_x_pred = (p_pred[:, 1:, ...] - p_pred[:, :-1, ...]) / tran_x
    flux_y_pred = (p_pred[:, :, 1:, :] - p_pred[:, :, :-1, :]) / tran_y

    loss_x = K.sum(K.abs(K.batch_flatten(flux_x) - K.batch_flatten(flux_x_pred)), axis=-1)
    loss_y = K.sum(K.abs(K.batch_flatten(flux_y) - K.batch_flatten(flux_y_pred)), axis=-1)

    loss_flux = K.mean(loss_x + loss_y)
    return loss_flux

In [4]:
def get_binary_sat_loss(state, state_pred):
    
    sat_threshold = 0.105
    sat = K.expand_dims(state[:, :, :, 0], -1)
    sat_pred = K.expand_dims(state_pred[:, :, :, 0], -1)
    
    
    sat_bool = K.greater_equal(sat, sat_threshold) #will return boolean values
    sat_bin = K.cast(sat_bool, dtype=K.floatx()) #will convert bool to 0 and 1  
    
    sat_pred_bool = K.greater_equal(sat_pred, sat_threshold) #will return boolean values
    sat_pred_bin = K.cast(sat_pred_bool, dtype=K.floatx()) #will convert bool to 0 and 1  
    
#     binary_loss = K.sum(K.abs(K.batch_flatten(sat_bin) - K.batch_flatten(sat_pred_bin)), axis=-1)
    
    binary_loss = losses.binary_crossentropy(sat_bin, sat_pred_bin)
    return K.mean(binary_loss)

In [5]:
def get_well_bhp_loss(state, state_pred, wl_mask):
    '''
    @params: state: shape (batch_size, 60, 60, 2)
             state_pred: shape (batch_size, 60, 60, 2)
             prod_well_loc: shape (batch_size, 5, 2)
             
    p_true: shape (batch_size, 60, 60, 1)
    p_pred: shape (batch_size, 60, 60, 1)
    
    @return: bhp_loss: scalar
    '''
    
    p_true = K.expand_dims(state[:, :, :, 1], -1) # shape (batch_size, 60, 60 ,1)
    p_pred = K.expand_dims(state_pred[:, :, :, 1], -1)
    
    bhp_loss = K.mean(K.sum(K.abs(p_true -p_pred) * wl_mask[:,:,np.newaxis], axis=(1,2,3)))
    
    return bhp_loss

In [6]:
def create_e2c(latent_dim, u_dim, input_shape, sigma=0.0):
    '''
    Creates a E2C.

    Args:
        latent_dim: dimensionality of latent space
        return_kl_loss_op: whether to return the operation for
                           computing the KL divergence loss.

    Returns:
        The VAE model. If return_kl_loss_op is True, then the
        operation for computing the KL divergence loss is
        additionally returned.
    '''

    encoder_ = e2c_util.create_encoder(latent_dim, input_shape, sigma=sigma)
    decoder_ = e2c_util.create_decoder(latent_dim, input_shape)
    transition_ = e2c_util.create_trans(latent_dim, u_dim)
    wc_encoder_ = e2c_util.create_wc_encoder(latent_dim, input_shape)

    return encoder_, decoder_, transition_, wc_encoder_

In [7]:
# Create plain E2C model and associated loss operations

################### case specification ######################

# data_dir = '/data/cees/zjin/lstm_rom/datasets/9W_BHP/'
# data_dir = '/data/cees/zjin/lstm_rom/datasets/9W_BHP_RATE/'
# data_dir = '/data/cees/zjin/lstm_rom/datasets/9W_MS_BHP_RATE/'
data_dir = '/data3/Astro/personal/zjin/datasets/9W_MS_BHP_RATE_GAU/'
# data_dir = '/data3/Astro/personal/zjin/datasets/7W_CHA/'

output_dir = '/data3/Astro/lstm_rom/e2c_larry/saved_models/'

# case_name = '9w_bhp'
# case_name = '9w_bhp_rate'
case_name = '9w_ms_bhp_rate'
# case_name = '7w_cha'

# case_suffix = '_single_out_rel_2'
# case_suffix = '_fix_wl_rel_8'
case_suffix = '_var_wl_rel_1'

train_suffix = '_with_p'

model_suffix = '_flux_loss'
# model_suffix = '_ae_no_l2_ep_10'
# model_suffix = '_no_fl'


n_train_run = 300
n_eval_run = 100
num_t = 20 
# dt = 200 // num_t
dt = 100
n_train_step = n_train_run * num_t
n_eval_step = n_eval_run * num_t


train_file = case_name + '_e2c_train' + case_suffix + train_suffix + '_n%d_dt%dday_nt%d_nrun%d.mat' %(n_train_step, dt, num_t, n_train_run)
eval_file = case_name + '_e2c_eval' + case_suffix + train_suffix +'_n%d_dt%dday_nt%d_nrun%d.mat' %(n_eval_step, dt, num_t, n_eval_run)

#################### model specification ##################
epoch = 20
batch_size = 4
learning_rate = 1e-4
latent_dim = 50

# u_dim = 9*2 # control dimension
u_dim = latent_dim # wwll control encoded to latent dimension

In [8]:
# load data
hf_r = h5py.File(data_dir + train_file, 'r')
state_t_train = np.array(hf_r.get('state_t'))
state_t1_train = np.array(hf_r.get('state_t1'))
bhp_train = np.array(hf_r.get('bhp'))
dt_train = np.array(hf_r.get('dt'))
wl_mask_train = np.array(hf_r.get('wl_mask'))
hf_r.close()

num_train = state_t_train.shape[0]
# dt_train = np.ones((num_train,1)) # dt=20days, normalized to 1

hf_r = h5py.File(data_dir + eval_file, 'r')
state_t_eval = np.array(hf_r.get('state_t'))
state_t1_eval = np.array(hf_r.get('state_t1'))
bhp_eval = np.array(hf_r.get('bhp'))
dt_eval = np.array(hf_r.get('dt'))
wl_mask_eval = np.array(hf_r.get('wl_mask'))
hf_r.close()

num_eval = state_t_eval.shape[0]

In [9]:
# m = np.loadtxt("/data/cees/zjin/lstm_rom/sim_runs/case6_9w_bhp_rate_ms_h5/template/logk1.dat")
m = np.loadtxt("/data3/Astro/personal/zjin/sim_runs/case8_9w_bhp_rate_ms_gau/template/logk1.dat") # Gaussian
# m = np.loadtxt("/data3/Astro/personal/zjin/sim_runs/case9_cha/template/logk1.dat") # channelized

m = m.reshape(60, 60, 1)
print('m shape is ', m.shape)
#     m_tf = K.placeholder((batch_size, 60, 60 ,1))
m_tf = Input(shape=(60, 60, 1))


m_eval = np.repeat(np.expand_dims(m, axis = 0), state_t_eval.shape[0], axis = 0)
print("m_eval shape is ", m_eval.shape)
m = np.repeat(np.expand_dims(m,axis = 0), state_t_train.shape[0], axis = 0)
print("m shape is ", m.shape)

m shape is  (60, 60, 1)
m_eval shape is  (2000, 60, 60, 1)
m shape is  (6000, 60, 60, 1)


In [10]:
# import importlib
# importlib.reload(e2c_util)

In [11]:
# Construct E2C
input_shape = (60, 60, 2)

#############################################
# here we use a UAE framework, sigma = 0.0001
#############################################
encoder, decoder, transition, wc_encoder = create_e2c(latent_dim, u_dim, input_shape, sigma=0.0) 


xt = Input(shape=input_shape)
xt1 = Input(shape=input_shape)
# ut = Input(shape=(u_dim, ))
ut = Input(shape=input_shape) # (60,60,2)
dt = Input(shape=(1,))
wl_mask = Input(shape=(60, 60)) # both prod and inj

zt = encoder(xt)
xt_rec = decoder(zt)


ut_encoded = wc_encoder(ut)
zt1 = encoder(xt1)

zt1_pred = transition([zt, ut_encoded, dt])
xt1_pred = decoder(zt1_pred)

# Compute loss
loss_rec_t = reconstruction_loss(xt, xt_rec)
loss_rec_t1 = reconstruction_loss(xt1, xt1_pred)

loss_flux_t = get_flux_loss(m_tf, xt, xt_rec) / 1000.
loss_flux_t1 = get_flux_loss(m_tf, xt1, xt1_pred) / 1000.

binary_sat_loss_t = get_binary_sat_loss(xt, xt_rec) * 1
binary_sat_loss_t1 = get_binary_sat_loss(xt1, xt1_pred) * 1

loss_prod_bhp_t = get_well_bhp_loss(xt, xt_rec, wl_mask) * 1
loss_prod_bhp_t1 = get_well_bhp_loss(xt1, xt1_pred, wl_mask) * 1

loss_l2_reg = l2_reg_loss(zt)  # log(1.) = 0.


## -- loss bound: combine data losses
# loss_bound = loss_rec_t + loss_rec_t1 + loss_l2_reg  + loss_flux_t + loss_flux_t1
# loss_bound = loss_rec_t + loss_rec_t1 + loss_kl + binary_sat_loss_t + binary_sat_loss_t1
# loss_bound = loss_rec_t + loss_rec_t1 + loss_l2_reg  + loss_flux_t + loss_flux_t1 + loss_prod_bhp_t + loss_prod_bhp_t1 # JCP 2019 Gaussian case
loss_bound = loss_rec_t + loss_rec_t1  + loss_flux_t + loss_flux_t1 + loss_prod_bhp_t + loss_prod_bhp_t1 # UAE
# loss_bound = loss_rec_t + loss_rec_t1 + loss_l2_reg # no flux/bhp loss comparison



# Use zt_logvar to approximate zt1_logvar_pred
loss_trans = l2_reg_loss(zt1_pred - zt1)
# loss_trans = kl_normal_loss(zt1_mean_pred, zt1_logvar_pred, zt1_mean, zt1_logvar)


trans_loss_weight = 1.0 # lambda in E2C paper Eq. (11)
loss = loss_bound + trans_loss_weight * loss_trans

In [12]:
## log for tensorboard
def write_summary(value, tag, summary_writer, global_step):
    """Write a single summary value to tensorboard"""
    summary = tf.Summary()
    summary.value.add(tag=tag, simple_value=value)
    summary_writer.add_summary(summary, global_step)

## used to generate log directory
currentDT = datetime.datetime.now()
current_time = str(currentDT).replace(" ", "-")[:-10]
print(current_time)

summary_writer = tf.summary.FileWriter('logs/' + case_name + case_suffix + '_ep' + str(epoch) + '_tr' + str(n_train_run) + '_' + current_time)

2019-05-17-14:03


In [13]:
# Optimization
opt = Adam(lr=learning_rate)

trainable_weights = encoder.trainable_weights + decoder.trainable_weights + transition.trainable_weights + wc_encoder.trainable_weights

updates = opt.get_updates(loss, trainable_weights)

# iterate = K.function([xt, ut, xt1, m_tf, dt], [loss, loss_rec_t, loss_rec_t1, loss_l2_reg, loss_trans, loss_flux_t, loss_flux_t1, binary_sat_loss_t, binary_sat_loss_t1], updates=updates)
iterate = K.function([xt, ut, xt1, m_tf, wl_mask, dt], [loss, loss_rec_t, loss_rec_t1, loss_l2_reg, loss_trans, loss_flux_t, loss_flux_t1, loss_prod_bhp_t, loss_prod_bhp_t1], updates=updates)

eval_loss = K.function([xt, ut, xt1, m_tf, wl_mask, dt], [loss])

num_batch = int(num_train/batch_size)

for e in range(epoch):
    for ib in range(num_batch):
        ind0 = ib * batch_size
        state_t_batch  = state_t_train[ind0:ind0+batch_size, ...]
        state_t1_batch = state_t1_train[ind0:ind0 + batch_size, ...]
        bhp_batch      = bhp_train[ind0:ind0 + batch_size, ...]
        m_batch        = m[ind0:ind0 + batch_size, ...]
        dt_batch       = dt_train[ind0:ind0 + batch_size, ...]
        wl_mask_batch  = wl_mask_train[ind0:ind0 + batch_size, ...]

        output = iterate([state_t_batch, bhp_batch, state_t1_batch, m_batch, wl_mask_batch, dt_batch])

        # tf.session.run(feed_dict={xt: sat_t_batch, ut: bhp_batch, xt1: sat_t1_batch}, ...
        #                fetches= [loss, loss_rec_t, loss_rec_t1, loss_kl, loss_trans, updates])
        # But output tensor for the updates operation is not returned
        
        n_itr = e * num_train + ib * batch_size + batch_size
        write_summary(output[0], 'train/total_loss', summary_writer, n_itr) # log for tensorboard
        write_summary(output[1]+output[2], 'train/sum_rec_loss', summary_writer, n_itr) # log for tensorboard
        write_summary(output[5]+output[6], 'train/sum_flux_loss', summary_writer, n_itr) # log for tensorboard
        write_summary(output[7]+output[8], 'train/sum_well_loss', summary_writer, n_itr) # log for tensorboard
        
        if ib % 50 == 0:
            print('Epoch %d/%d, Batch %d/%d, Loss %f, Loss rec %f, loss rec t1 %f, loss kl %f, loss_trans %f, loss flux %f, loss flux t1 %f, prod bhp loss %f, prod bhp loss t1 %f'
                  % (e+1, epoch, ib+1, num_batch, output[0], output[1], output[2], output[3], output[4], output[5], output[6], output[7], output[8]))
            
            eval_loss_val = eval_loss([state_t_eval, bhp_eval, state_t1_eval, m_eval, wl_mask_eval, dt_eval])
            write_summary(eval_loss_val[0], 'eval/total_loss', summary_writer, n_itr) # log for tensorboard

    
    print('====================================================')
    print('\n')
    print('Epoch %d/%d, Train loss %f, Eval loss %f' % (e + 1, epoch, output[0], eval_loss_val[0]))
    print('\n')
    print('====================================================')

encoder.save_weights(output_dir + 'e2c_encoder_dt_' + case_name + case_suffix + train_suffix + model_suffix + '_nt%d_l%d_lr%.0e_ep%d.h5' % (num_train, latent_dim, learning_rate, epoch))
decoder.save_weights(output_dir + 'e2c_decoder_dt_' + case_name + case_suffix + train_suffix + model_suffix + '_nt%d_l%d_lr%.0e_ep%d.h5' % (num_train, latent_dim, learning_rate, epoch))
transition.save_weights(output_dir + 'e2c_transition_dt_' + case_name + case_suffix + train_suffix + model_suffix + '_nt%d_l%d_lr%.0e_ep%d.h5' % (num_train, latent_dim, learning_rate, epoch))
wc_encoder.save_weights(output_dir + 'e2c_wc_encoder_dt_' + case_name + case_suffix + train_suffix + model_suffix + '_nt%d_l%d_lr%.0e_ep%d.h5' % (num_train, latent_dim, learning_rate, epoch))

Epoch 1/10, Batch 1/1500, Loss 16998.460938, Loss rec 7621.967285, loss rec t1 8587.764648, loss kl 0.037656, loss_trans 0.043145, loss flux 59.968620, loss flux t1 79.230949, prod bhp loss 313.883148, prod bhp loss t1 335.603088
Epoch 1/10, Batch 51/1500, Loss 2069.970459, Loss rec 379.654236, loss rec t1 1254.191406, loss kl 9.352005, loss_trans 10.955702, loss flux 153.029648, loss flux t1 147.739883, prod bhp loss 35.288845, prod bhp loss t1 89.110550
Epoch 1/10, Batch 101/1500, Loss 1292.242554, Loss rec 461.143921, loss rec t1 427.905243, loss kl 10.342924, loss_trans 10.590265, loss flux 148.115234, loss flux t1 156.272263, prod bhp loss 47.818321, prod bhp loss t1 40.397308
Epoch 1/10, Batch 151/1500, Loss 1547.635376, Loss rec 573.058533, loss rec t1 543.290649, loss kl 14.458227, loss_trans 13.846134, loss flux 155.002548, loss flux t1 166.146835, prod bhp loss 51.548645, prod bhp loss t1 44.742023
Epoch 1/10, Batch 201/1500, Loss 1130.814331, Loss rec 401.421478, loss rec t1

Epoch 2/10, Batch 301/1500, Loss 480.711121, Loss rec 114.513763, loss rec t1 171.381470, loss kl 1.493479, loss_trans 0.669249, loss flux 67.812332, loss flux t1 68.961311, prod bhp loss 25.120136, prod bhp loss t1 32.252842
Epoch 2/10, Batch 351/1500, Loss 322.053284, Loss rec 72.592041, loss rec t1 95.900421, loss kl 1.671621, loss_trans 0.601464, loss flux 61.474728, loss flux t1 60.839954, prod bhp loss 12.537811, prod bhp loss t1 18.106846
Epoch 2/10, Batch 401/1500, Loss 421.334320, Loss rec 101.546768, loss rec t1 150.238739, loss kl 1.468535, loss_trans 0.410820, loss flux 58.671902, loss flux t1 57.769142, prod bhp loss 22.233505, prod bhp loss t1 30.463470
Epoch 2/10, Batch 451/1500, Loss 498.505798, Loss rec 115.604385, loss rec t1 195.894501, loss kl 2.491703, loss_trans 1.117937, loss flux 55.202736, loss flux t1 69.734344, prod bhp loss 27.769165, prod bhp loss t1 33.182705
Epoch 2/10, Batch 501/1500, Loss 504.866913, Loss rec 122.419647, loss rec t1 202.814209, loss kl 

Epoch 3/10, Batch 601/1500, Loss 323.274506, Loss rec 75.469635, loss rec t1 88.818420, loss kl 1.538796, loss_trans 0.330426, loss flux 59.022690, loss flux t1 59.569454, prod bhp loss 19.058285, prod bhp loss t1 21.005606
Epoch 3/10, Batch 651/1500, Loss 313.668182, Loss rec 72.479477, loss rec t1 94.222046, loss kl 1.399804, loss_trans 0.317713, loss flux 56.274853, loss flux t1 54.213322, prod bhp loss 14.514521, prod bhp loss t1 21.646225
Epoch 3/10, Batch 701/1500, Loss 224.213074, Loss rec 47.616035, loss rec t1 54.829407, loss kl 0.962479, loss_trans 0.195670, loss flux 48.576088, loss flux t1 49.958412, prod bhp loss 10.705578, prod bhp loss t1 12.331879
Epoch 3/10, Batch 751/1500, Loss 214.221115, Loss rec 43.531059, loss rec t1 46.162704, loss kl 1.208065, loss_trans 0.366457, loss flux 48.376644, loss flux t1 49.147579, prod bhp loss 12.933926, prod bhp loss t1 13.702745
Epoch 3/10, Batch 801/1500, Loss 186.275879, Loss rec 31.904953, loss rec t1 43.766308, loss kl 0.771807

Epoch 4/10, Batch 901/1500, Loss 190.965057, Loss rec 34.361046, loss rec t1 41.466694, loss kl 1.239509, loss_trans 0.146438, loss flux 46.390610, loss flux t1 48.142509, prod bhp loss 9.651222, prod bhp loss t1 10.806534
Epoch 4/10, Batch 951/1500, Loss 329.030273, Loss rec 33.052937, loss rec t1 160.835175, loss kl 1.084533, loss_trans 0.301620, loss flux 37.792908, loss flux t1 56.241661, prod bhp loss 11.064322, prod bhp loss t1 29.741682
Epoch 4/10, Batch 1001/1500, Loss 162.802216, Loss rec 34.039917, loss rec t1 34.106491, loss kl 0.455316, loss_trans 0.063784, loss flux 37.200134, loss flux t1 37.974522, prod bhp loss 9.635582, prod bhp loss t1 9.781801
Epoch 4/10, Batch 1051/1500, Loss 189.171005, Loss rec 33.794918, loss rec t1 39.926479, loss kl 1.175774, loss_trans 0.178833, loss flux 45.797112, loss flux t1 45.895416, prod bhp loss 10.907959, prod bhp loss t1 12.670288
Epoch 4/10, Batch 1101/1500, Loss 209.114975, Loss rec 37.739662, loss rec t1 56.647713, loss kl 0.99487

Epoch 5/10, Batch 1201/1500, Loss 200.267670, Loss rec 30.943449, loss rec t1 64.925079, loss kl 1.415661, loss_trans 0.174955, loss flux 39.340954, loss flux t1 39.747799, prod bhp loss 5.989881, prod bhp loss t1 19.145546
Epoch 5/10, Batch 1251/1500, Loss 137.671417, Loss rec 21.880875, loss rec t1 29.119711, loss kl 0.823642, loss_trans 0.098331, loss flux 33.729530, loss flux t1 35.560642, prod bhp loss 6.956777, prod bhp loss t1 10.325560
Epoch 5/10, Batch 1301/1500, Loss 119.333633, Loss rec 16.317402, loss rec t1 23.163242, loss kl 0.402783, loss_trans 0.047628, loss flux 31.576704, loss flux t1 31.738760, prod bhp loss 6.902711, prod bhp loss t1 9.587182
Epoch 5/10, Batch 1351/1500, Loss 172.754868, Loss rec 26.126499, loss rec t1 45.662487, loss kl 0.748421, loss_trans 0.090629, loss flux 36.687469, loss flux t1 38.106369, prod bhp loss 9.720388, prod bhp loss t1 16.361042
Epoch 5/10, Batch 1401/1500, Loss 209.647964, Loss rec 25.581730, loss rec t1 74.874428, loss kl 0.858303



Epoch 6/10, Train loss 147.518066, Eval loss 250.327988


Epoch 7/10, Batch 1/1500, Loss 170.391724, Loss rec 14.879498, loss rec t1 68.147896, loss kl 0.518436, loss_trans 0.110388, loss flux 23.314312, loss flux t1 39.409149, prod bhp loss 5.833701, prod bhp loss t1 18.696800
Epoch 7/10, Batch 51/1500, Loss 374.244476, Loss rec 24.198294, loss rec t1 232.079895, loss kl 0.381052, loss_trans 0.270539, loss flux 24.509640, loss flux t1 44.399632, prod bhp loss 11.316113, prod bhp loss t1 37.470383
Epoch 7/10, Batch 101/1500, Loss 201.369171, Loss rec 11.742079, loss rec t1 102.094635, loss kl 0.405273, loss_trans 0.109185, loss flux 21.699709, loss flux t1 37.256641, prod bhp loss 5.472520, prod bhp loss t1 22.994394
Epoch 7/10, Batch 151/1500, Loss 216.602081, Loss rec 42.586342, loss rec t1 58.179565, loss kl 1.071868, loss_trans 0.087554, loss flux 36.723377, loss flux t1 38.958145, prod bhp loss 18.354828, prod bhp loss t1 21.712265
Epoch 7/10, Batch 201/1500, Loss 182.305435, Lo

Epoch 8/10, Batch 301/1500, Loss 122.653770, Loss rec 14.093686, loss rec t1 29.224962, loss kl 0.535538, loss_trans 0.050883, loss flux 28.876875, loss flux t1 29.784798, prod bhp loss 5.912075, prod bhp loss t1 14.710499
Epoch 8/10, Batch 351/1500, Loss 106.793053, Loss rec 10.939372, loss rec t1 22.411701, loss kl 0.516372, loss_trans 0.047173, loss flux 27.590158, loss flux t1 28.643549, prod bhp loss 5.211547, prod bhp loss t1 11.949556
Epoch 8/10, Batch 401/1500, Loss 121.653999, Loss rec 17.677279, loss rec t1 33.341328, loss kl 0.475792, loss_trans 0.036582, loss flux 25.573851, loss flux t1 26.834185, prod bhp loss 4.407091, prod bhp loss t1 13.783686
Epoch 8/10, Batch 451/1500, Loss 187.677200, Loss rec 18.620583, loss rec t1 73.904144, loss kl 0.842036, loss_trans 0.137247, loss flux 26.411573, loss flux t1 41.904716, prod bhp loss 8.055933, prod bhp loss t1 18.642982
Epoch 8/10, Batch 501/1500, Loss 146.635056, Loss rec 14.451952, loss rec t1 50.031166, loss kl 0.615293, lo

Epoch 9/10, Batch 601/1500, Loss 112.725410, Loss rec 14.696182, loss rec t1 24.417311, loss kl 0.760971, loss_trans 0.047995, loss flux 25.931519, loss flux t1 27.615591, prod bhp loss 7.394941, prod bhp loss t1 12.621868
Epoch 9/10, Batch 651/1500, Loss 123.780685, Loss rec 13.470760, loss rec t1 32.434338, loss kl 0.663631, loss_trans 0.048179, loss flux 26.831444, loss flux t1 28.274302, prod bhp loss 6.086837, prod bhp loss t1 16.634827
Epoch 9/10, Batch 701/1500, Loss 125.988548, Loss rec 22.860191, loss rec t1 29.850471, loss kl 0.505839, loss_trans 0.040404, loss flux 24.811890, loss flux t1 26.741270, prod bhp loss 10.533954, prod bhp loss t1 11.150368
Epoch 9/10, Batch 751/1500, Loss 156.819244, Loss rec 16.052956, loss rec t1 55.798817, loss kl 0.622886, loss_trans 0.087285, loss flux 26.480082, loss flux t1 29.209044, prod bhp loss 5.033558, prod bhp loss t1 24.157505
Epoch 9/10, Batch 801/1500, Loss 95.074059, Loss rec 9.811247, loss rec t1 20.392448, loss kl 0.465293, los

Epoch 10/10, Batch 901/1500, Loss 103.115364, Loss rec 13.093022, loss rec t1 19.613970, loss kl 0.845531, loss_trans 0.053482, loss flux 26.256393, loss flux t1 28.306498, prod bhp loss 7.495326, prod bhp loss t1 8.296675
Epoch 10/10, Batch 951/1500, Loss 158.034912, Loss rec 11.840683, loss rec t1 63.101738, loss kl 0.800026, loss_trans 0.150233, loss flux 20.724724, loss flux t1 38.088768, prod bhp loss 6.246641, prod bhp loss t1 17.882113
Epoch 10/10, Batch 1001/1500, Loss 83.606812, Loss rec 9.609060, loss rec t1 16.198479, loss kl 0.356361, loss_trans 0.027877, loss flux 22.366554, loss flux t1 23.271360, prod bhp loss 3.957231, prod bhp loss t1 8.176249
Epoch 10/10, Batch 1051/1500, Loss 141.319611, Loss rec 29.523558, loss rec t1 32.918831, loss kl 0.837460, loss_trans 0.057644, loss flux 28.312979, loss flux t1 28.094795, prod bhp loss 11.079024, prod bhp loss t1 11.332771
Epoch 10/10, Batch 1101/1500, Loss 119.109657, Loss rec 17.173141, loss rec t1 29.846642, loss kl 0.77458