In [1]:
import numpy as np
import h5py

import e2c_1 as e2c_util

# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# The GPU id to use, usually either "0" or "1", "2", "3"

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import datetime

from keras import backend as K
from keras.layers import Input
from keras.models import Model
from keras.optimizers import Adam
from keras import losses


# GPU memory management
import tensorflow as tf

Using TensorFlow backend.


In [2]:
# tf.session specification
# TensorFlow wizardry
config = tf.ConfigProto()

# Don't pre-allocate memory; allocate as-needed
config.gpu_options.allow_growth = True

# Only allow a total of half the GPU memory to be allocated
config.gpu_options.per_process_gpu_memory_fraction = 0.75

# Create a session with the above options specified.
K.tensorflow_backend.set_session(tf.Session(config=config))

In [3]:
def reconstruction_loss(x, t_decoded):
    '''
    Reconstruction loss for the plain VAE
    '''
    v = 0.1
    # return K.mean(K.sum((K.batch_flatten(x) - K.batch_flatten(t_decoded)) ** 2 / (2*v) + 0.5*K.log(2*np.pi*v), axis=-1))
    return K.mean(K.sum((K.batch_flatten(x) - K.batch_flatten(t_decoded)) ** 2 / (2*v), axis=-1))
    # return K.sum((K.batch_flatten(x) - K.batch_flatten(t_decoded)) ** 2, axis=-1)


def l2_reg_loss(qm):
    # 0.5 * (torch.log(pv) - torch.log(qv) + qv / pv + (qm - pm).pow(2) / pv - 1)
    # -0.5 * K.sum(1 + t_log_var - K.square(t_mean) - K.exp(t_log_var), axis=-1)
#     kl = -0.5 * (1 - p_logv + q_logv - K.exp(q_logv) / K.exp(p_logv) - K.square(qm - pm) / K.exp(p_logv))
    l2_reg = 0.5*K.square(qm)
    return K.mean(K.sum(l2_reg, axis=-1))


def get_flux_loss(m, state, state_pred):
    '''
    @params:  state, state_pred shape (batch_size, 60, 60, 2)
              p, p_pred shape (batch_size, 60, 60, 1)
              m shape (batch_size, 60, 60, 1)
    
    @return:  loss_flux: scalar
    
    Only consider discrepancies in total flux, not in phases (saturation not used) 
    '''
    
    perm = K.exp(m)
    p = K.expand_dims(state[:, :, :, 1], -1)
    p_pred = K.expand_dims(state_pred[:, :, :, 1], -1)

    #print(K.int_shape(xxx))
    
    tran_x = 1./perm[:, 1:, ...] + 1./perm[:, :-1, ...]
    tran_y = 1./perm[:, :, 1:, ...] + 1./perm[:, :, :-1, ...]
    flux_x = (p[:, 1:, ...] - p[:, :-1, ...]) / tran_x
    flux_y = (p[:, :, 1:, :] - p[:, :, :-1, :]) / tran_y
    flux_x_pred = (p_pred[:, 1:, ...] - p_pred[:, :-1, ...]) / tran_x
    flux_y_pred = (p_pred[:, :, 1:, :] - p_pred[:, :, :-1, :]) / tran_y

    loss_x = K.sum(K.abs(K.batch_flatten(flux_x) - K.batch_flatten(flux_x_pred)), axis=-1)
    loss_y = K.sum(K.abs(K.batch_flatten(flux_y) - K.batch_flatten(flux_y_pred)), axis=-1)

    loss_flux = K.mean(loss_x + loss_y)
    return loss_flux

In [4]:
def get_binary_sat_loss(state, state_pred):
    
    sat_threshold = 0.105
    sat = K.expand_dims(state[:, :, :, 0], -1)
    sat_pred = K.expand_dims(state_pred[:, :, :, 0], -1)
    
    
    sat_bool = K.greater_equal(sat, sat_threshold) #will return boolean values
    sat_bin = K.cast(sat_bool, dtype=K.floatx()) #will convert bool to 0 and 1  
    
    sat_pred_bool = K.greater_equal(sat_pred, sat_threshold) #will return boolean values
    sat_pred_bin = K.cast(sat_pred_bool, dtype=K.floatx()) #will convert bool to 0 and 1  
    
#     binary_loss = K.sum(K.abs(K.batch_flatten(sat_bin) - K.batch_flatten(sat_pred_bin)), axis=-1)
    
    binary_loss = losses.binary_crossentropy(sat_bin, sat_pred_bin)
    return K.mean(binary_loss)

In [5]:
def get_well_bhp_loss(state, state_pred, wl_mask):
    '''
    @params: state: shape (batch_size, 60, 60, 2)
             state_pred: shape (batch_size, 60, 60, 2)
             prod_well_loc: shape (batch_size, 5, 2)
             
    p_true: shape (batch_size, 60, 60, 1)
    p_pred: shape (batch_size, 60, 60, 1)
    
    @return: bhp_loss: scalar
    '''
    
    p_true = K.expand_dims(state[:, :, :, 1], -1) # shape (batch_size, 60, 60 ,1)
    p_pred = K.expand_dims(state_pred[:, :, :, 1], -1)
    
    bhp_loss = K.mean(K.sum(K.abs(p_true -p_pred) * wl_mask[:,:,np.newaxis], axis=(1,2,3)))
    
    return bhp_loss

In [6]:
def create_e2c(latent_dim, u_dim, input_shape, sigma=0.0):
    '''
    Creates a E2C.

    Args:
        latent_dim: dimensionality of latent space
        return_kl_loss_op: whether to return the operation for
                           computing the KL divergence loss.

    Returns:
        The VAE model. If return_kl_loss_op is True, then the
        operation for computing the KL divergence loss is
        additionally returned.
    '''

    encoder_, hidden_shapes_ = e2c_util.create_encoder(latent_dim, input_shape, sigma=sigma)
    decoder_ = e2c_util.create_decoder(latent_dim, input_shape, hidden_shapes_)
    transition_ = e2c_util.create_trans(latent_dim, u_dim)
#     wc_encoder_ = e2c_util.create_wc_encoder(latent_dim, input_shape)
    wc_encoder_, _ = e2c_util.create_encoder(latent_dim, input_shape) # input_shape (60,60,2)

    return encoder_, decoder_, transition_, wc_encoder_

In [7]:
## -- Create plain E2C model and associated loss operations
## -- It should be put int0 main function from this part on,
## -- if converted to .py file


################### case specification ######################

# data_dir = '/data/cees/zjin/lstm_rom/datasets/9W_BHP/'
# data_dir = '/data/cees/zjin/lstm_rom/datasets/9W_BHP_RATE/'
# data_dir = '/data/cees/zjin/lstm_rom/datasets/9W_MS_BHP_RATE/'
data_dir = '/data3/Astro/personal/zjin/datasets/9W_MS_BHP_RATE_GAU/'
# data_dir = '/data3/Astro/personal/zjin/datasets/7W_CHA/'

output_dir = '/data3/Astro/lstm_rom/e2c_larry/saved_models/'

# case_name = '9w_bhp'
# case_name = '9w_bhp_rate'
case_name = '9w_ms_bhp_rate'
# case_name = '7w_cha'

# case_suffix = '_single_out_rel_2'
# case_suffix = '_fix_wl_rel_8'
case_suffix = '_var_wl_rel_1'

train_suffix = '_with_p'

model_suffix = '_flux_loss'
# model_suffix = '_ae_no_l2_ep_10'
# model_suffix = '_no_fl'


n_train_run = 300
n_eval_run = 100
num_t = 20 
# dt = 200 // num_t
dt = 100
n_train_step = n_train_run * num_t
n_eval_step = n_eval_run * num_t


train_file = case_name + '_e2c_train' + case_suffix + train_suffix + '_n%d_dt%dday_nt%d_nrun%d.mat' %(n_train_step, dt, num_t, n_train_run)
eval_file = case_name + '_e2c_eval' + case_suffix + train_suffix +'_n%d_dt%dday_nt%d_nrun%d.mat' %(n_eval_step, dt, num_t, n_eval_run)

#################### model specification ##################
epoch = 20
batch_size = 4
learning_rate = 1e-4
latent_dim = 50

# u_dim = 9*2 # control dimension
u_dim = latent_dim # wwll control encoded to latent dimension

In [8]:
# load data
hf_r = h5py.File(data_dir + train_file, 'r')
state_t_train = np.array(hf_r.get('state_t'))
state_t1_train = np.array(hf_r.get('state_t1'))
bhp_train = np.array(hf_r.get('bhp'))
dt_train = np.array(hf_r.get('dt'))
wl_mask_train = np.array(hf_r.get('wl_mask'))
hf_r.close()

num_train = state_t_train.shape[0]
# dt_train = np.ones((num_train,1)) # dt=20days, normalized to 1

hf_r = h5py.File(data_dir + eval_file, 'r')
state_t_eval = np.array(hf_r.get('state_t'))
state_t1_eval = np.array(hf_r.get('state_t1'))
bhp_eval = np.array(hf_r.get('bhp'))
dt_eval = np.array(hf_r.get('dt'))
wl_mask_eval = np.array(hf_r.get('wl_mask'))
hf_r.close()

num_eval = state_t_eval.shape[0]

In [9]:
# m = np.loadtxt("/data/cees/zjin/lstm_rom/sim_runs/case6_9w_bhp_rate_ms_h5/template/logk1.dat")
m = np.loadtxt("/data3/Astro/personal/zjin/sim_runs/case8_9w_bhp_rate_ms_gau/template/logk1.dat") # Gaussian
# m = np.loadtxt("/data3/Astro/personal/zjin/sim_runs/case9_cha/template/logk1.dat") # channelized

m = m.reshape(60, 60, 1)
print('m shape is ', m.shape)
#     m_tf = K.placeholder((batch_size, 60, 60 ,1))
m_tf = Input(shape=(60, 60, 1))


m_eval = np.repeat(np.expand_dims(m, axis = 0), state_t_eval.shape[0], axis = 0)
print("m_eval shape is ", m_eval.shape)
m = np.repeat(np.expand_dims(m,axis = 0), state_t_train.shape[0], axis = 0)
print("m shape is ", m.shape)

m shape is  (60, 60, 1)
m_eval shape is  (2000, 60, 60, 1)
m shape is  (6000, 60, 60, 1)


In [10]:
# import importlib
# importlib.reload(e2c_util)

In [11]:
# Construct E2C
input_shape = (60, 60, 2)

#############################################
# here we use a UAE framework, sigma = 0.0001
#############################################
encoder, decoder, transition, wc_encoder = create_e2c(latent_dim, u_dim, input_shape, sigma=0.0) 


xt = Input(shape=input_shape)
xt1 = Input(shape=input_shape)
# ut = Input(shape=(u_dim, ))
ut = Input(shape=input_shape) # (60,60,2)
dt = Input(shape=(1,))
wl_mask = Input(shape=(60, 60)) # both prod and inj

zt= encoder(xt)
xt_rec = decoder(zt)

ut_encoded = wc_encoder(ut)
zt1 = encoder(xt1)

zt1_pred = transition([zt, ut_encoded, dt])
xt1_pred = decoder(zt1_pred)

# Compute loss
loss_rec_t = reconstruction_loss(xt, xt_rec)
loss_rec_t1 = reconstruction_loss(xt1, xt1_pred)

loss_flux_t = get_flux_loss(m_tf, xt, xt_rec) / 1000.
loss_flux_t1 = get_flux_loss(m_tf, xt1, xt1_pred) / 1000.

binary_sat_loss_t = get_binary_sat_loss(xt, xt_rec) * 1
binary_sat_loss_t1 = get_binary_sat_loss(xt1, xt1_pred) * 1

loss_prod_bhp_t = get_well_bhp_loss(xt, xt_rec, wl_mask) * 1
loss_prod_bhp_t1 = get_well_bhp_loss(xt1, xt1_pred, wl_mask) * 1

loss_l2_reg = l2_reg_loss(zt)  # log(1.) = 0.


## -- loss bound: combine data losses
# loss_bound = loss_rec_t + loss_rec_t1 + loss_l2_reg  + loss_flux_t + loss_flux_t1
# loss_bound = loss_rec_t + loss_rec_t1 + loss_kl + binary_sat_loss_t + binary_sat_loss_t1
# loss_bound = loss_rec_t + loss_rec_t1 + loss_l2_reg  + loss_flux_t + loss_flux_t1 + loss_prod_bhp_t + loss_prod_bhp_t1 # JCP 2019 Gaussian case
# loss_bound = loss_rec_t + loss_rec_t1  + loss_flux_t + loss_flux_t1 + loss_prod_bhp_t + loss_prod_bhp_t1 # UAE **
# loss_bound = loss_rec_t + loss_rec_t1 + loss_l2_reg # no flux/bhp loss comparison

loss_bound = loss_rec_t + loss_flux_t + loss_prod_bhp_t # just reconstruction


# Use zt_logvar to approximate zt1_logvar_pred
loss_trans = l2_reg_loss(zt1_pred - zt1)
# loss_trans = kl_normal_loss(zt1_mean_pred, zt1_logvar_pred, zt1_mean, zt1_logvar)


trans_loss_weight = 1.0 # lambda in E2C paper Eq. (11)
loss = loss_bound + trans_loss_weight * loss_trans

In [12]:
## log for tensorboard
def write_summary(value, tag, summary_writer, global_step):
    """Write a single summary value to tensorboard"""
    summary = tf.Summary()
    summary.value.add(tag=tag, simple_value=value)
    summary_writer.add_summary(summary, global_step)

## used to generate log directory
currentDT = datetime.datetime.now()
current_time = str(currentDT).replace(" ", "-")[:-10]
print(current_time)

suffix = ''

summary_writer = tf.summary.FileWriter('logs/' + case_name + case_suffix + '_ep' + str(epoch) + '_tr' + str(n_train_run) + '_' + current_time)

2019-05-27-14:21


In [13]:
# Optimization
opt = Adam(lr=learning_rate)

trainable_weights = encoder.trainable_weights + decoder.trainable_weights + transition.trainable_weights + wc_encoder.trainable_weights

updates = opt.get_updates(loss, trainable_weights)

# iterate = K.function([xt, ut, xt1, m_tf, dt], [loss, loss_rec_t, loss_rec_t1, loss_l2_reg, loss_trans, loss_flux_t, loss_flux_t1, binary_sat_loss_t, binary_sat_loss_t1], updates=updates)
iterate = K.function([xt, ut, xt1, m_tf, wl_mask, dt], [loss, loss_rec_t, loss_rec_t1, loss_l2_reg, loss_trans, loss_flux_t, loss_flux_t1, loss_prod_bhp_t, loss_prod_bhp_t1], updates=updates)

eval_loss = K.function([xt, ut, xt1, m_tf, wl_mask, dt], [loss])

num_batch = int(num_train/batch_size)

for e in range(epoch):
    for ib in range(num_batch):
        ind0 = ib * batch_size
        state_t_batch  = state_t_train[ind0:ind0+batch_size, ...]
        state_t1_batch = state_t1_train[ind0:ind0 + batch_size, ...]
        bhp_batch      = bhp_train[ind0:ind0 + batch_size, ...]
        m_batch        = m[ind0:ind0 + batch_size, ...]
        dt_batch       = dt_train[ind0:ind0 + batch_size, ...]
        wl_mask_batch  = wl_mask_train[ind0:ind0 + batch_size, ...]

        output = iterate([state_t_batch, bhp_batch, state_t1_batch, m_batch, wl_mask_batch, dt_batch])

        # tf.session.run(feed_dict={xt: sat_t_batch, ut: bhp_batch, xt1: sat_t1_batch}, ...
        #                fetches= [loss, loss_rec_t, loss_rec_t1, loss_kl, loss_trans, updates])
        # But output tensor for the updates operation is not returned
        
        n_itr = e * num_train + ib * batch_size + batch_size
        write_summary(output[0], 'train/total_loss', summary_writer, n_itr) # log for tensorboard
        write_summary(output[1]+output[2], 'train/sum_rec_loss', summary_writer, n_itr) # log for tensorboard
        write_summary(output[5]+output[6], 'train/sum_flux_loss', summary_writer, n_itr) # log for tensorboard
        write_summary(output[7]+output[8], 'train/sum_well_loss', summary_writer, n_itr) # log for tensorboard
        
        if ib % 100 == 0:
            print('Epoch %d/%d, Batch %d/%d, Loss %f, Loss rec %f, loss rec t1 %f, loss kl %f, loss_trans %f, loss flux %f, loss flux t1 %f, prod bhp loss %f, prod bhp loss t1 %f'
                  % (e+1, epoch, ib+1, num_batch, output[0], output[1], output[2], output[3], output[4], output[5], output[6], output[7], output[8]))
            
            eval_loss_val = eval_loss([state_t_eval, bhp_eval, state_t1_eval, m_eval, wl_mask_eval, dt_eval])
            write_summary(eval_loss_val[0], 'eval/total_loss', summary_writer, n_itr) # log for tensorboard

    
    print('====================================================')
    print('\n')
    print('Epoch %d/%d, Train loss %f, Eval loss %f' % (e + 1, epoch, output[0], eval_loss_val[0]))
    print('\n')
    print('====================================================')

encoder.save_weights(output_dir + 'e2c_encoder_dt_' + case_name + case_suffix + train_suffix + model_suffix + '_nt%d_l%d_lr%.0e_ep%d.h5' % (num_train, latent_dim, learning_rate, epoch))
decoder.save_weights(output_dir + 'e2c_decoder_dt_' + case_name + case_suffix + train_suffix + model_suffix + '_nt%d_l%d_lr%.0e_ep%d.h5' % (num_train, latent_dim, learning_rate, epoch))
transition.save_weights(output_dir + 'e2c_transition_dt_' + case_name + case_suffix + train_suffix + model_suffix + '_nt%d_l%d_lr%.0e_ep%d.h5' % (num_train, latent_dim, learning_rate, epoch))
wc_encoder.save_weights(output_dir + 'e2c_wc_encoder_dt_' + case_name + case_suffix + train_suffix + model_suffix + '_nt%d_l%d_lr%.0e_ep%d.h5' % (num_train, latent_dim, learning_rate, epoch))

Epoch 1/20, Batch 1/1500, Loss 8098.198242, Loss rec 7705.566895, loss rec t1 8593.504883, loss kl 0.235634, loss_trans 0.269626, loss flux 74.871117, loss flux t1 79.320747, prod bhp loss 317.490845, prod bhp loss t1 335.844727
Epoch 1/20, Batch 101/1500, Loss 672.066284, Loss rec 438.714111, loss rec t1 514.988953, loss kl 4.465278, loss_trans 0.055676, loss flux 201.718079, loss flux t1 221.134216, prod bhp loss 31.578407, prod bhp loss t1 39.385521
Epoch 1/20, Batch 201/1500, Loss 613.011353, Loss rec 419.524841, loss rec t1 491.309418, loss kl 7.946171, loss_trans 0.069762, loss flux 164.019135, loss flux t1 171.242691, prod bhp loss 29.397640, prod bhp loss t1 34.851440
Epoch 1/20, Batch 301/1500, Loss 508.416168, Loss rec 346.250397, loss rec t1 437.358429, loss kl 5.795074, loss_trans 0.158001, loss flux 128.610870, loss flux t1 128.702271, prod bhp loss 33.396908, prod bhp loss t1 43.457573
Epoch 1/20, Batch 401/1500, Loss 396.329956, Loss rec 261.361969, loss rec t1 311.61627

Epoch 3/20, Batch 501/1500, Loss 133.767517, Loss rec 66.767303, loss rec t1 98.238159, loss kl 1.012325, loss_trans 0.059785, loss flux 52.567505, loss flux t1 54.096107, prod bhp loss 14.372931, prod bhp loss t1 24.943945
Epoch 3/20, Batch 601/1500, Loss 122.362473, Loss rec 54.737404, loss rec t1 68.182030, loss kl 1.579574, loss_trans 0.028503, loss flux 55.658260, loss flux t1 57.787800, prod bhp loss 11.938307, prod bhp loss t1 14.589559
Epoch 3/20, Batch 701/1500, Loss 98.870918, Loss rec 42.943928, loss rec t1 48.502678, loss kl 0.864948, loss_trans 0.014382, loss flux 48.489822, loss flux t1 50.228676, prod bhp loss 7.422793, prod bhp loss t1 10.278042
Epoch 3/20, Batch 801/1500, Loss 87.735382, Loss rec 37.677307, loss rec t1 48.880745, loss kl 0.669064, loss_trans 0.020694, loss flux 43.825905, loss flux t1 45.513710, prod bhp loss 6.211479, prod bhp loss t1 12.184855
Epoch 3/20, Batch 901/1500, Loss 102.861656, Loss rec 40.594734, loss rec t1 45.609966, loss kl 1.629957, lo

Epoch 5/20, Batch 1101/1500, Loss 73.488335, Loss rec 28.958164, loss rec t1 50.976875, loss kl 1.210084, loss_trans 0.027695, loss flux 38.618668, loss flux t1 39.780762, prod bhp loss 5.883809, prod bhp loss t1 17.941092
Epoch 5/20, Batch 1201/1500, Loss 75.075943, Loss rec 29.345175, loss rec t1 63.047771, loss kl 1.806052, loss_trans 0.032393, loss flux 39.779140, loss flux t1 40.624901, prod bhp loss 5.919237, prod bhp loss t1 19.361492
Epoch 5/20, Batch 1301/1500, Loss 58.879692, Loss rec 21.510487, loss rec t1 35.602089, loss kl 0.351746, loss_trans 0.016381, loss flux 31.508165, loss flux t1 32.218674, prod bhp loss 5.844657, prod bhp loss t1 14.026278
Epoch 5/20, Batch 1401/1500, Loss 61.680401, Loss rec 24.104862, loss rec t1 77.133842, loss kl 0.995898, loss_trans 0.060550, loss flux 29.485481, loss flux t1 45.970722, prod bhp loss 8.029507, prod bhp loss t1 18.379951


Epoch 5/20, Train loss 49.105812, Eval loss 93.977654


Epoch 6/20, Batch 1/1500, Loss 53.033875, Loss rec

Epoch 8/20, Batch 101/1500, Loss 39.320084, Loss rec 13.972629, loss rec t1 92.861191, loss kl 0.307724, loss_trans 0.062860, loss flux 20.601231, loss flux t1 36.140800, prod bhp loss 4.683363, prod bhp loss t1 22.195560
Epoch 8/20, Batch 201/1500, Loss 50.775909, Loss rec 23.319038, loss rec t1 70.410522, loss kl 0.778952, loss_trans 0.049519, loss flux 23.707146, loss flux t1 41.057819, prod bhp loss 3.700208, prod bhp loss t1 13.330336
Epoch 8/20, Batch 301/1500, Loss 52.771805, Loss rec 16.908665, loss rec t1 32.410778, loss kl 0.702722, loss_trans 0.015178, loss flux 30.023849, loss flux t1 31.479666, prod bhp loss 5.824113, prod bhp loss t1 15.194958
Epoch 8/20, Batch 401/1500, Loss 63.801495, Loss rec 29.286896, loss rec t1 47.881599, loss kl 0.502374, loss_trans 0.014386, loss flux 27.456240, loss flux t1 28.946518, prod bhp loss 7.043976, prod bhp loss t1 13.538947
Epoch 8/20, Batch 501/1500, Loss 57.572712, Loss rec 21.764191, loss rec t1 74.737465, loss kl 0.719887, loss_tr

Epoch 10/20, Batch 701/1500, Loss 43.583588, Loss rec 15.938660, loss rec t1 20.009575, loss kl 0.677735, loss_trans 0.007046, loss flux 23.795338, loss flux t1 25.291010, prod bhp loss 3.842543, prod bhp loss t1 7.353870
Epoch 10/20, Batch 801/1500, Loss 36.821091, Loss rec 11.931694, loss rec t1 26.440323, loss kl 0.451307, loss_trans 0.020315, loss flux 21.748516, loss flux t1 23.743818, prod bhp loss 3.120565, prod bhp loss t1 13.807207
Epoch 10/20, Batch 901/1500, Loss 46.482677, Loss rec 14.418891, loss rec t1 25.179073, loss kl 1.411686, loss_trans 0.014353, loss flux 27.189486, loss flux t1 28.615213, prod bhp loss 4.859947, prod bhp loss t1 9.930984
Epoch 10/20, Batch 1001/1500, Loss 41.371861, Loss rec 14.274059, loss rec t1 30.917412, loss kl 0.281029, loss_trans 0.014850, loss flux 22.767321, loss flux t1 25.473013, prod bhp loss 4.315629, prod bhp loss t1 15.002759
Epoch 10/20, Batch 1101/1500, Loss 61.606773, Loss rec 24.161793, loss rec t1 43.460632, loss kl 1.239890, lo

Epoch 12/20, Batch 1301/1500, Loss 38.634052, Loss rec 14.498227, loss rec t1 24.007837, loss kl 0.355383, loss_trans 0.012748, loss flux 19.863861, loss flux t1 21.430305, prod bhp loss 4.259217, prod bhp loss t1 9.236208
Epoch 12/20, Batch 1401/1500, Loss 33.453888, Loss rec 9.778265, loss rec t1 65.110550, loss kl 0.883693, loss_trans 0.055392, loss flux 18.585289, loss flux t1 34.231232, prod bhp loss 5.034940, prod bhp loss t1 18.295868


Epoch 12/20, Train loss 26.364985, Eval loss 67.427391


Epoch 13/20, Batch 1/1500, Loss 28.386251, Loss rec 9.461993, loss rec t1 46.432938, loss kl 0.693724, loss_trans 0.046725, loss flux 16.254896, loss flux t1 29.622316, prod bhp loss 2.622639, prod bhp loss t1 15.296923
Epoch 13/20, Batch 101/1500, Loss 28.866442, Loss rec 10.129685, loss rec t1 40.151329, loss kl 0.349053, loss_trans 0.027138, loss flux 16.100250, loss flux t1 28.616247, prod bhp loss 2.609369, prod bhp loss t1 15.743257
Epoch 13/20, Batch 201/1500, Loss 31.467731, Loss re

Epoch 15/20, Batch 301/1500, Loss 37.096344, Loss rec 11.540946, loss rec t1 31.350548, loss kl 0.804325, loss_trans 0.023568, loss flux 20.526648, loss flux t1 23.087677, prod bhp loss 5.005183, prod bhp loss t1 15.189991
Epoch 15/20, Batch 401/1500, Loss 29.466064, Loss rec 8.770996, loss rec t1 17.045967, loss kl 0.563765, loss_trans 0.015346, loss flux 17.832813, loss flux t1 20.074335, prod bhp loss 2.846910, prod bhp loss t1 10.201005
Epoch 15/20, Batch 501/1500, Loss 40.345398, Loss rec 16.898621, loss rec t1 48.211502, loss kl 0.848844, loss_trans 0.039446, loss flux 19.501198, loss flux t1 21.212622, prod bhp loss 3.906133, prod bhp loss t1 20.724104
Epoch 15/20, Batch 601/1500, Loss 50.018139, Loss rec 21.421396, loss rec t1 23.216169, loss kl 1.505534, loss_trans 0.011007, loss flux 21.581827, loss flux t1 22.612822, prod bhp loss 7.003912, prod bhp loss t1 7.541952
Epoch 15/20, Batch 701/1500, Loss 35.246037, Loss rec 11.677794, loss rec t1 17.415852, loss kl 0.742031, loss

Epoch 17/20, Batch 901/1500, Loss 35.839592, Loss rec 11.602075, loss rec t1 17.429419, loss kl 1.739032, loss_trans 0.012710, loss flux 21.138317, loss flux t1 22.597641, prod bhp loss 3.086489, prod bhp loss t1 5.980621
Epoch 17/20, Batch 1001/1500, Loss 31.173719, Loss rec 10.167690, loss rec t1 15.198217, loss kl 0.353552, loss_trans 0.011197, loss flux 17.612162, loss flux t1 19.178368, prod bhp loss 3.382669, prod bhp loss t1 8.890501
Epoch 17/20, Batch 1101/1500, Loss 45.128357, Loss rec 21.578274, loss rec t1 31.803448, loss kl 1.411554, loss_trans 0.014703, loss flux 20.208590, loss flux t1 21.595541, prod bhp loss 3.326788, prod bhp loss t1 10.182642
Epoch 17/20, Batch 1201/1500, Loss 34.031891, Loss rec 10.426437, loss rec t1 14.837912, loss kl 2.208182, loss_trans 0.012558, loss flux 20.717802, loss flux t1 21.112904, prod bhp loss 2.875094, prod bhp loss t1 7.329598
Epoch 17/20, Batch 1301/1500, Loss 27.969257, Loss rec 8.051281, loss rec t1 14.146624, loss kl 0.424138, lo



Epoch 19/20, Train loss 20.004053, Eval loss 59.401882


Epoch 20/20, Batch 1/1500, Loss 19.585705, Loss rec 5.272161, loss rec t1 20.974812, loss kl 0.839402, loss_trans 0.023843, loss flux 12.676744, loss flux t1 21.373720, prod bhp loss 1.612955, prod bhp loss t1 9.560857
Epoch 20/20, Batch 101/1500, Loss 22.759220, Loss rec 6.350091, loss rec t1 26.035458, loss kl 0.436497, loss_trans 0.024907, loss flux 12.189466, loss flux t1 21.945580, prod bhp loss 4.194754, prod bhp loss t1 11.328128
Epoch 20/20, Batch 201/1500, Loss 22.783590, Loss rec 6.259301, loss rec t1 22.420132, loss kl 1.089322, loss_trans 0.023499, loss flux 14.482325, loss flux t1 23.220360, prod bhp loss 2.018467, prod bhp loss t1 12.043717
Epoch 20/20, Batch 301/1500, Loss 39.307453, Loss rec 12.030291, loss rec t1 24.915689, loss kl 1.007515, loss_trans 0.013969, loss flux 18.606422, loss flux t1 20.888697, prod bhp loss 8.656771, prod bhp loss t1 12.967157
Epoch 20/20, Batch 401/1500, Loss 29.516436, Loss rec 8