This is an interactive workflow for E2C training and evaluation. Note that this specific case may produce results that differ from those reported in the paper.

During training (while the training cell in Section 1 is running), you can monitor the training status with TensorBoard. Make sure `tensorboard` is installed properly. To install `tensorboard`:  
`pip install tensorboard`  

All data used by `tensorboard` are stored in the `logs/` directory. If there is no `logs/` directory in your cloned repo, please create one. To start `tensorboard`:  
`tensorboard --logdir=logs --port=5678` (`--port` is necessary for port-forwarding)


Zhaoyang Larry Jin  
Stanford University  
zjin@stanford.edu

This notebook has three sections: `0. E2C setup`, `1. E2C Training`, and `2. E2C Eval`

A typical workflow is `sec 0` -> `sec 1` -> `sec 2`.

If you have already run `sec 1` before and have saved the model weights, you can do `sec 0` -> `sec 2`.

# Section 0: E2C setup

## Step 1. Load libraries and configure hardware (GPU)

In [None]:
import numpy as np
import h5py
import tensorflow as tf
from datetime import datetime

from e2c import E2C
from loss import CustomizedLoss
from ROMWithE2C import ROMWithE2C
import matplotlib.pyplot as plt
plt.rcParams['image.cmap'] = 'jet'  # default colormap for every imshow() plot in this notebook
import timeit

In [None]:
# Report the installed TensorFlow version (handy when debugging API differences).
tf_version = tf.__version__
print(tf_version)

In [None]:
# List every physical device (CPUs and GPUs) visible to TensorFlow.
for physical_device in tf.config.list_physical_devices():
    print(physical_device.name)

In [None]:
# Global device selection: the first GPU when TensorFlow sees one, otherwise the CPU.
# USE_GPU is the number of visible GPUs (truthy when at least one is present).
USE_GPU = len(tf.config.list_physical_devices('GPU'))

# The '/device:' prefix is optional; tf.device accepts either form.
device = '/device:GPU:0' if USE_GPU else '/device:CPU:0'

print('Using device: ', device)

## Step 2. Specify params and filenames

In [None]:
################### case specification ######################

# Where the input datasets live and where trained weights are written.
data_dir = '../data/'
output_dir = './saved_models/'

# Name components shared by the dataset files and the saved weights.
case_name = '9w_ms_bhp_rate'
case_suffix = '_fix_wl_rel_8'
train_suffix = '_with_p'
model_suffix = '_flux_loss'

# Simulation runs per split, snapshots per run, and the timestep
# (the filename tag below labels it in days).
n_train_run, n_eval_run = 300, 100
num_t = 20
dt = 100
# Total number of (state_t, state_t1) pairs in each split.
n_train_step, n_eval_step = n_train_run * num_t, n_eval_run * num_t

train_file = '%s_e2c_train%s%s_n%d_dt%dday_nt%d_nrun%d.mat' % (
    case_name, case_suffix, train_suffix, n_train_step, dt, num_t, n_train_run)
eval_file = '%s_e2c_eval%s%s_n%d_dt%dday_nt%d_nrun%d.mat' % (
    case_name, case_suffix, train_suffix, n_eval_step, dt, num_t, n_eval_run)

In [None]:
#################### model specification ##################
# Optimisation settings.
epoch, batch_size, learning_rate = 10, 4, 1e-4
# Dimension of the latent state vector.
latent_dim = 50

# Control dimension: 18 = 9 wells x 2 (Gaussian 9-well case).
u_dim = 2 * 9

In [None]:
# Expected sample counts of the training / evaluation splits; checked against
# the loaded data and embedded in the weight filenames.
num_train, num_eval = 6000, 2000

In [None]:
# Per-sample shapes fed to the E2C model on a 60 x 60 grid.
grid = (60, 60)
input_shape = grid + (2,)     # state: saturation and pressure channels
perm_shape = grid + (1,)      # single-channel permeability field (from logk1.dat)
prod_loc_shape = (5, 2)       # 5 producers x 2 location indices (presumably grid row/col)

In [None]:
# The encoder, decoder and transition weight files share one descriptive suffix
# and differ only in their component prefix.
weights_suffix = (case_name + case_suffix + train_suffix + model_suffix
                  + '_nt%d_l%d_lr%.0e_ep%d.h5' % (num_train, latent_dim, learning_rate, epoch))
encoder_file = output_dir + 'e2c_encoder_dt_' + weights_suffix
decoder_file = output_dir + 'e2c_decoder_dt_' + weights_suffix
transition_file = output_dir + 'e2c_transition_dt_' + weights_suffix

for label, path in (("encoder_file:", encoder_file),
                    ("decoder_file:", decoder_file),
                    ("transition_file:", transition_file)):
    print(label, path)

## Step 3. Construct E2C model

In [None]:
# Build the reduced-order model. It exposes the E2C network as `my_rom.model` and the
# update/evaluate/predict routines used below.
# NOTE(review): sigma presumably sets an injected-noise level; 0.0 disables it. Confirm in ROMWithE2C.
rom_args = (latent_dim, u_dim, input_shape, perm_shape, prod_loc_shape, learning_rate)
my_rom = ROMWithE2C(*rom_args, sigma=0.0)

# Section 1: E2C Training

## Load state data

In [None]:
def _load_e2c_split(path):
    """Read one E2C data split from an HDF5-based .mat file.

    Returns the datasets 'state_t', 'state_t1', 'bhp' and 'dt' as NumPy arrays,
    in that order. The ``with`` block closes the file even if a read fails.
    Indexing with ``hf[key]`` raises KeyError for a missing dataset, whereas
    ``hf.get(key)`` would return None and silently become ``np.array(None)``.
    """
    with h5py.File(path, 'r') as hf:
        return tuple(np.array(hf[key]) for key in ('state_t', 'state_t1', 'bhp', 'dt'))


# Training pairs: state at t, state at t+1 (the label), controls, timestep size.
state_t_train, state_t1_train, bhp_train, dt_train = _load_e2c_split(data_dir + train_file)
assert num_train == state_t_train.shape[0], "num_train not match!"

# Evaluation pairs, same layout.
state_t_eval, state_t1_eval, bhp_eval, dt_eval = _load_e2c_split(data_dir + eval_file)

print("state_t_eval.shape: ", state_t_eval.shape)
print("state_t1_eval.shape: ", state_t1_eval.shape)
print("bhp_eval.shape: ", bhp_eval.shape)
print("dt_eval.shape: ", dt_eval.shape)

assert num_eval == state_t_eval.shape[0], "num_eval not match!"

# Full mini-batches per epoch; leftover samples (num_train % batch_size) are skipped.
num_batch = num_train // batch_size
print("num_batch: ", num_batch)

## Load permeability data

In [None]:
# Single permeability field read from logk1.dat (Gaussian case), shaped as one grid channel.
perm_single = np.loadtxt(data_dir + "template/logk1.dat").reshape(60, 60, 1)
print('m shape is ', perm_single.shape)

# The field is the same for every sample, so stack one copy per eval / train sample.
m_eval = np.repeat(perm_single[np.newaxis, ...], state_t_eval.shape[0], axis=0)
print("m_eval shape is ", m_eval.shape)

m = np.repeat(perm_single[np.newaxis, ...], state_t_train.shape[0], axis=0)
print("m shape is ", m.shape)

## Load well location data

In [None]:
# Well-location template: row 0 holds (num_prod, num_inj); rows 1..num_prod hold the
# producer locations.
well_loc_file = data_dir + 'template/well_loc00.dat'

well_loc = np.loadtxt(well_loc_file).astype(int)
num_prod, num_inj = well_loc[0, 0], well_loc[0, 1]
num_well = num_prod + num_inj
print('num_inj: %d, num_prod: %d' % (num_inj, num_prod))

prod_loc = well_loc[1:num_prod + 1, :]
print("prod_loc:\n{}".format(prod_loc))
print('prod_loc shape is ', prod_loc.shape)

# The model was built with prod_loc_shape. Fail here with a clear message instead of
# deep inside the network if the template file disagrees.
assert prod_loc.shape == tuple(prod_loc_shape), \
    "prod_loc shape %s does not match prod_loc_shape %s" % (prod_loc.shape, prod_loc_shape)


In [None]:
## used to generate log directory
# Timestamp for the TensorBoard run directory, formatted as YYYY-MM-DD-HH:MM.
# strftime replaces the old fixed-length slice of str(datetime.now()). That string
# omits the '.ffffff' part when microsecond == 0, so slicing would then cut into the date.
# NOTE(review): ':' is not allowed in Windows paths; change the separator if needed there.
current_time = datetime.now().strftime('%Y-%m-%d-%H:%M')
print(current_time)
summary_writer = tf.summary.create_file_writer(
    'logs/' + case_name + case_suffix + '_ep' + str(epoch) + '_tr' + str(n_train_run) + '_' + current_time)


# @tf.function
def write_summary(value, tag, writer, global_step):
    """Log scalar `value` under `tag` at step `global_step` to summary `writer`."""
    with writer.as_default():
        tf.summary.scalar(tag, value, step=global_step)

## Start training process

In [None]:
# (TensorBoard tag, name of the my_rom metric attribute) logged after every batch.
TRAIN_SCALARS = (
    ('train/total_loss', 'train_loss'),
    ('train/reconstruction_loss', 'train_reconstruction_loss'),
    ('train/flux_loss', 'train_flux_loss'),
    ('train/well_loss', 'train_well_loss'),
)
EVAL_EVERY = 50  # run evaluation on the full eval split every this many batches

with tf.device(device):
    # The evaluation split never changes, so its inputs are assembled once.
    test_inputs = (state_t_eval, bhp_eval, dt_eval, m_eval, prod_loc)
    test_labels = state_t1_eval

    for e in range(epoch):
        for ib in range(num_batch):
            rows = slice(ib * batch_size, (ib + 1) * batch_size)
            inputs = (state_t_train[rows, ...], bhp_train[rows, ...], dt_train[rows, ...],
                      m[rows, ...], prod_loc)
            labels = state_t1_train[rows, ...]
            my_rom.update(inputs, labels)

            # TensorBoard x-axis: number of training samples seen so far.
            n_itr = e * num_train + (ib + 1) * batch_size
            for tag, metric_name in TRAIN_SCALARS:
                write_summary(getattr(my_rom, metric_name).result(), tag, summary_writer, n_itr)
            summary_writer.flush()

            if ib % EVAL_EVERY == 0:
                print('Epoch %d/%d, Batch %d/%d, Loss %f,' % (e+1, epoch, ib+1, num_batch, my_rom.train_loss.result()))
                my_rom.evaluate(test_inputs, test_labels)
                write_summary(my_rom.test_loss.result(), 'eval/total_loss', summary_writer, n_itr)
                summary_writer.flush()

        banner = '===================================================='
        print(banner)
        print('\n')
        print('Epoch %d/%d, Train loss %f, Eval loss %f' % (e + 1, epoch, my_rom.train_loss.result(), my_rom.test_loss.result()))
        print('\n')
        print(banner)

## Save model parameters to file

In [None]:
import os

# Create the output directory first. Otherwise writing the weight files into a missing
# directory would most likely fail, and only after the (long) training run had finished.
os.makedirs(output_dir, exist_ok=True)
my_rom.model.saveWeightsToFile(encoder_file, decoder_file, transition_file)

# Section 2: E2C Test (Eval)

## Load the saved E2C model weights (if you did not run Section 1)

In [None]:
# Restore encoder, decoder and transition weights saved by Section 1.
weight_paths = (encoder_file, decoder_file, transition_file)
my_rom.model.loadWeightsFromFile(*weight_paths)

## Load and manipulate data

In [None]:
target_suffix = '_fix_wl_rel_8'  # the dataset being evaluated here
eval_file = '%s_e2c_eval%s%s_n%d_dt%dday_nt%d_nrun%d.mat' % (
    case_name, target_suffix, train_suffix, n_eval_step, dt, num_t, n_eval_run)

# Full-trajectory state file and control file covering all 400 runs of the target dataset.
state_file, ctrl_file = case_name + '_train_n_400_full', case_name + '_norm_bhps_n_400'
state_data, ctrl_data = (stem + target_suffix + '.mat' for stem in (state_file, ctrl_file))

In [None]:
# Full simulated trajectories (saturation and pressure) of the target dataset.
# `with` closes the file even if a read fails; `hf_r[key]` raises KeyError for a
# missing dataset instead of yielding np.array(None).
with h5py.File(data_dir + state_data, 'r') as hf_r:
    sat = np.array(hf_r['sat'])
    pres = np.array(hf_r['pres'])

In [None]:
# Control schedules ('bhp' and 'rate' datasets) of the target dataset.
# `with` closes the file even if a read fails; `hf_r[key]` raises KeyError for a
# missing dataset instead of yielding np.array(None).
with h5py.File(data_dir + ctrl_data, 'r') as hf_r:
    bhp0 = np.array(hf_r['bhp'])
    rate0 = np.array(hf_r['rate'])

In [None]:
# Join the 'bhp' and 'rate' control columns into a single control matrix.
control_parts = [bhp0, rate0]
bhp = np.concatenate(control_parts, axis=1)
print(bhp.shape)

In [None]:
# Transpose (presumably undoing MATLAB's column-major storage as read by h5py) and split into
# (run, snapshot, grid cell) = (400, 201, 3600).
sat, pres = [field.T.reshape((400, 201, 3600)) for field in (sat, pres)]

In [None]:
# Test runs: the last 25 runs of each block of 100
# (75-99, 175-199, 275-299, 375-399), 100 cases in total.
test_case = np.concatenate(
    [np.arange(block_start, block_start + 25) for block_start in range(75, 400, 100)]
).astype(int)

In [None]:
# Reload the single log-permeability field (Gaussian case). If Section 1 ran,
# `m` currently holds a per-sample stack.
m = np.loadtxt(data_dir + "template/logk1.dat").reshape(60, 60, 1)
print('perm shape is ', m.shape)

## Pick 4 representative test cases to visualize  
Note that there are 100 test cases. In the E2C sequential workflow, predictions are made for all of them. However, to keep the notebook clean and short, we only visualize 4 of the 100 (the cases selected by `ind_case`).

In [None]:
# Positions within `test_case` of the four cases that get plotted.
ind_case = np.array((10, 25, 77, 97))

In [None]:
# Every test case is rolled out (100 cases, not 4); only `ind_case` is plotted.
num_case = len(test_case)
num_tstep = 20
# Prediction buffers: (case, rollout step, 60, 60, 1).
sat_pred, pres_pred = (np.zeros((num_case, num_tstep, 60, 60, 1)) for _ in range(2))

# Well counts of this case.
num_prod, num_inj = 5, 4
num_well = num_prod + num_inj

num_all_case = 400  # runs stored in the full dataset
num_ctrl = 20       # control steps per run

Specify timesteps, time intervals, etc.

## Reshape the input data  
Convert the data into a format that the E2C model can consume directly.

In [None]:
# Snapshot indices (into the 201-snapshot trajectories) at which each rollout step starts.
rollout_stride = 200 // num_tstep
t_steps = np.arange(0, 200, rollout_stride)

# NOTE: overwrites the dt = 100 set in Step 2. Each rollout step advances dt snapshot
# indices; indt_del is that step size normalized by its maximum (all ones here).
dt = 10
t_steps1 = (t_steps + dt).astype(int)
indt_del = t_steps1 - t_steps
indt_del = indt_del / max(indt_del)

# tmp1[i] = i and tmp[i] = max(i - 1, 0): indices of the current and previous rollout step.
tmp1 = np.array(range(num_tstep))
tmp = tmp1 - 1
tmp[0] = 0

In [None]:
# Each control must be held for a whole number of rollout steps. Check this before
# using num_tstep // num_ctrl as a repeat count.
assert num_tstep % num_ctrl == 0, "no exact division num_tstep = %d, num_ctrl=%d" % (num_tstep, num_ctrl)
steps_per_ctrl = num_tstep // num_ctrl

# Controls per run: (case, well, control step).
bhp_b0 = bhp.reshape(num_all_case, num_well, num_ctrl)
# Hold each control for steps_per_ctrl consecutive steps -> (case, well, num_tstep).
bhp_b1 = np.repeat(bhp_b0[..., np.newaxis], steps_per_ctrl, axis=3)
bhp_b2 = bhp_b1.reshape(num_all_case, num_well, num_tstep)

# Controls at the previous step (step 0 reuses itself) and at the current step.
bhp_tt = bhp_b2[:, :, tmp]
bhp_tt1 = bhp_b2[:, :, tmp1]

# Stack them per well -> (case, step, 2 * num_well), the per-step control vector.
bhp_tt0 = np.concatenate((bhp_tt, bhp_tt1), axis=1)
bhp_t = np.swapaxes(bhp_tt0, 1, 2)

bhp_seq = bhp_t[test_case, :, :]

In [None]:
# Rollout initial condition: snapshot index 0 of every test case, as (case, 60, 60, 1).
init_shape = (num_case, 60, 60, 1)
sat_t_seq, pres_t_seq = (field[test_case, 0, :].reshape(init_shape) for field in (sat, pres))

# Channels: [saturation, pressure].
state_t_seq = np.concatenate((sat_t_seq, pres_t_seq), axis=3)
state_pred = np.concatenate((sat_pred, pres_pred), axis=4)

In [None]:
# One copy of the permeability field per test case.
m_t_seq = np.repeat(m[np.newaxis, ...], state_t_seq.shape[0], axis=0)

In [None]:
# One copy of the producer locations per test case.
# NOTE(review): Section 1 passes `prod_loc` unbatched, while the rollout passes this
# batched copy; presumably ROMWithE2C accepts both. Confirm.
prod_loc_t_seq = np.repeat(prod_loc[np.newaxis, ...], len(state_t_seq), axis=0)

## E2C sequential workflow

In [None]:
# Autoregressive rollout: start from snapshot 0 and feed each prediction back as the next
# input state. state_pred[:, k] records the input state of step k.
start = timeit.default_timer()

for step in range(num_tstep):
    state_pred[:, step, ...] = state_t_seq
    dt_seq = np.full((num_case, 1), indt_del[step])
    step_inputs = (state_t_seq, bhp_seq[:, step, :], dt_seq, m_t_seq, prod_loc_t_seq)
    state_t1_seq = my_rom.predict(step_inputs)
    state_t_seq = state_t1_seq.copy()

end = timeit.default_timer()
print("Time for sequential process: %f" % (end - start))

## Visualization

In [None]:
# Simulated ground truth at the same snapshot indices the rollout visited (t_steps).
# Indices were previously hard-coded and would misalign if num_tstep changed.
# Layout: (case, step, grid cell, [saturation, pressure]).
sat_seq_true = sat[test_case][:, t_steps, :]
pres_seq_true = pres[test_case][:, t_steps, :]
state_seq_true = np.zeros((len(test_case), len(t_steps), sat_seq_true.shape[-1], 2))
state_seq_true[..., 0] = sat_seq_true
state_seq_true[..., 1] = pres_seq_true

### Visualization for saturation

In [None]:
# Saturation normalization range (the model works on (s - s_min) / s_diff).
s_min, s_max = 0, 1
s_diff = s_max - s_min

In [None]:
# Undo the saturation normalization for plotting.
# NOTE: state_pred is also rescaled in place, so re-running this cell applies the
# transform twice (harmless only while s_min = 0 and s_diff = 1).
sat_pred_plot = state_pred[..., 0] * s_diff + s_min
state_pred[..., 0] = sat_pred_plot

In [None]:
def _show_panel(n_cols, row, col, image, clim=None, title=None):
    """Draw `image` without axis ticks in cell (row, col) of a 3 x n_cols subplot grid.

    `clim`, if given, fixes the color limits. A colorbar is added on the last column.
    """
    plt.subplot(3, n_cols, row * n_cols + col + 1)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.xticks([])
    plt.yticks([])
    if clim is not None:
        plt.clim(clim)
    if col == n_cols - 1:
        plt.colorbar(fraction=0.046)


# Rows: ROM prediction, simulated truth, absolute error; every `divide`-th step is shown.
# Loop over all selected cases and derive the colorbar column instead of hard-coding 4 and 9.
divide = 2
n_cols = num_tstep // divide
for case in ind_case:
    print("Case num: %d" % case)
    plt.figure(figsize=(16, 5))
    for col in range(n_cols):
        it = col * divide
        pred = sat_pred_plot[case, it, :, :]
        truth = state_seq_true[case, it, :, 0].reshape((60, 60))
        _show_panel(n_cols, 0, col, pred, clim=[0.1, 0.7], title='t=%d' % (t_steps[it] * dt))
        _show_panel(n_cols, 1, col, truth, clim=[0.1, 0.7])
        _show_panel(n_cols, 2, col, np.fabs(truth - pred), clim=[0, 0.15])

    plt.show()

### Visualization for pressure

In [None]:
# Pressure normalization range (the model works on (p - p_min) / p_diff).
p_min, p_max = 250, 425
p_diff = p_max - p_min

In [None]:
# Undo the pressure normalization of predictions and ground truth for plotting.
state_pred_plot, state_seq_true_plot = (
    state_pred[..., 1] * p_diff + p_min,
    state_seq_true[..., 1] * p_diff + p_min,
)

In [None]:
def _show_panel(n_cols, row, col, image, clim=None, title=None):
    """Draw `image` without axis ticks in cell (row, col) of a 3 x n_cols subplot grid.

    `clim`, if given, fixes the color limits. A colorbar is added on the last column.
    """
    plt.subplot(3, n_cols, row * n_cols + col + 1)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.xticks([])
    plt.yticks([])
    if clim is not None:
        plt.clim(clim)
    if col == n_cols - 1:
        plt.colorbar(fraction=0.046)


# Rows: ROM prediction, simulated truth, absolute error; every `divide`-th step is shown.
# Color limits are left to matplotlib's autoscaling; pass clim=[lo, hi] to fix them.
divide = 2
n_cols = num_tstep // divide
for case in ind_case:
    print("Case num: %d" % case)
    plt.figure(figsize=(16, 5))
    for col in range(n_cols):
        it = col * divide
        pred = state_pred_plot[case, it, :, :]
        truth = state_seq_true_plot[case, it, :].reshape((60, 60))
        _show_panel(n_cols, 0, col, pred, title='t=%d' % (t_steps[it] * dt))
        _show_panel(n_cols, 1, col, truth)
        _show_panel(n_cols, 2, col, np.fabs(truth - pred))

    plt.show()