In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import os
# NOTE: every data / weight / output path in this notebook is relative
# ('./data/...', '.././data/...', './pretrained_weights/...', './plotting/...'),
# so it resolves against the kernel's current working directory. The former
# `os.chdir(".")` was a no-op and has been removed; start the kernel from the
# directory those paths expect (or chdir explicitly here).
from parc.model import model_burgers

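# Data pipeline

Each training file is sliced into overlapping windows of consecutive snapshots, stacked along the channel axis, with an extra constant channel holding the normalised parameter R/15000.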

In [7]:
# Training parameter grid. Every (R, a, w) combination names one candidate
# file 'burgers_train_<R>_<10*a>_<10*w>.npy'; combinations with no file on
# disk are skipped by the loader. NOTE: the validation section later rebinds
# these same names to the test grid.
R_list = [1_000, 2_500, 5_000, 7_500, 10_000]
a_list = [0.50, 0.60, 0.70, 0.80, 0.90]
w_list = [0.70, 0.80, 0.90, 1.00]

def clip_raw_data(sequence_length=2, R_values=None, a_values=None, w_values=None,
                  data_dir='./data/burgers/train_data/'):
    """Slice every available training simulation into overlapping time windows.

    For each (R, a, w) combination the file
    ``<data_dir>/burgers_train_<R>_<round(10*a)>_<round(10*w)>.npy`` is loaded
    if it exists; missing combinations are skipped. The second-to-last axis of
    the stored array (the time axis) is moved to the front, and every run of
    ``sequence_length`` consecutive snapshots becomes one sample, stacked along
    the channel axis as::

        [snapshot_j, R/15000, snapshot_{j+1}, ..., snapshot_{j+sequence_length-1}]

    Args:
        sequence_length: snapshots per window (>= 2): one input snapshot plus
            ``sequence_length - 1`` target snapshots.
        R_values, a_values, w_values: parameter grids to scan. ``None`` means
            "use the module-level ``R_list`` / ``a_list`` / ``w_list`` as they
            are at call time". The validation section rebinds those names, so
            pass the grids explicitly when calling this after that cell.
        data_dir: directory containing the training ``.npy`` files.

    Returns:
        float32 array of shape (num_windows, H, W, sequence_length * C + 1),
        where C is the per-snapshot channel count (the loaded array is assumed
        to be 4-D once the time axis is moved to the front).

    Raises:
        ValueError: if ``sequence_length < 2`` or no window could be built
            (no matching file, or every trajectory is shorter than a window).
    """
    if sequence_length < 2:
        raise ValueError('sequence_length must be >= 2, got %r' % (sequence_length,))
    if R_values is None:
        R_values = R_list
    if a_values is None:
        a_values = a_list
    if w_values is None:
        w_values = w_list

    vel_seq_whole = []
    for R in R_values:
        for a in a_values:
            for w in w_values:
                # round() before int(): a float product such as a*10 can land
                # just below the intended integer, and int() alone truncates it.
                data_file_name = ('burgers_train_' + str(int(R)) + '_' + str(int(round(a * 10)))
                                  + '_' + str(int(round(w * 10))) + '.npy')
                file_path = os.path.join(data_dir, data_file_name)
                if not os.path.exists(file_path):
                    continue
                # Move the time axis (second to last on disk) to the front.
                raw_data = np.moveaxis(np.float32(np.load(file_path)), -2, 0)
                num_time_steps, height, width = raw_data.shape[:3]
                # Constant channel with the normalised R. Built as float32 so the
                # concatenations below keep every sample in float32 (a float64
                # channel would silently upcast the whole dataset).
                r_img = np.full((1, height, width, 1), R / 15000, dtype=np.float32)
                # Start indices 0 .. num_time_steps - sequence_length inclusive, so
                # the last window ends exactly on the final snapshot.
                for j in range(num_time_steps - sequence_length + 1):
                    init_snapshot = np.concatenate([raw_data[j:j + 1], r_img], axis=-1)
                    following_snapshot = np.concatenate(
                        [raw_data[j + k:j + k + 1] for k in range(1, sequence_length)],
                        axis=-1)
                    vel_seq_whole.append(
                        np.concatenate([init_snapshot, following_snapshot], axis=-1))

    if not vel_seq_whole:
        raise ValueError(
            'clip_raw_data: no training windows built from %r; check data_dir, the '
            'parameter grids, and that trajectories have at least %d time steps'
            % (data_dir, sequence_length))
    return np.concatenate(vel_seq_whole, axis=0)

# Windows of 3 consecutive snapshots: one input snapshot (+ R channel) followed by two target snapshots.
seq_clipped = clip_raw_data(sequence_length=3)

# Training


### Stage 1: Differentiator training

In [8]:
# Create the training tf.data pipeline from the clipped windows.
# Channels [:3] hold the initial snapshot and the constant R channel (model
# input); the remaining channels hold the following snapshots (targets).
n_input_channels = 3
# Cast to float32: the loader's R channel may be float64, which upcasts (and
# doubles the memory of) every sample, whereas inference feeds float32 inputs.
seq_clipped_f32 = np.asarray(seq_clipped, dtype=np.float32)
dataset_input = tf.data.Dataset.from_tensor_slices(seq_clipped_f32[..., :n_input_channels])
dataset_label = tf.data.Dataset.from_tensor_slices(seq_clipped_f32[..., n_input_channels:])
dataset = tf.data.Dataset.zip((dataset_input, dataset_label))
# Shuffling happens here (reshuffled every epoch by default); Keras' own
# `fit(shuffle=...)` has no effect on a tf.data.Dataset.
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(8)

In [None]:
# Stage 1: fit the differentiator (mode="differentiator_training", n_time_step=1).
tf.keras.backend.clear_session()
parc = model_burgers.PARCv2_burgers(n_time_step = 1, step_size= 1/100, solver = "heun", mode = "differentiator_training")
parc.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.00005, beta_1 = 0.9, beta_2 = 0.999))
# No `shuffle=` here: Keras ignores it for tf.data.Dataset inputs; the
# pipeline's dataset.shuffle(...) already reshuffles every epoch.
parc.fit(dataset, epochs = 1)

In [None]:
# Persist the stage-1 differentiator weights (reloaded by stage 2 and by the
# weight-file model-loading option). Saving to an HDF5 file does not create
# missing parent directories, so ensure the target folder exists first.
os.makedirs('./pretrained_weights/burgers', exist_ok=True)
parc.differentiator.save_weights('./pretrained_weights/burgers/parc2_diff_burgers_heun.h5')

### Stage 2: Data-driven integration training

In [None]:
# Stage 2: train the integrator (mode="integrator_training", n_time_step=2)
# on top of the stage-1 differentiator weights.
tf.keras.backend.clear_session()
parc = model_burgers.PARCv2_burgers(n_time_step = 2, step_size= 1/100, solver = "heun", mode = "integrator_training")
parc.differentiator.load_weights('./pretrained_weights/burgers/parc2_diff_burgers_heun.h5')
parc.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.00001, beta_1 = 0.9, beta_2 = 0.999))
# No `shuffle=` here: Keras ignores it for tf.data.Dataset inputs; the
# pipeline's dataset.shuffle(...) already reshuffles every epoch.
parc.fit(dataset, epochs = 1)

In [None]:
# Persist the stage-2 integrator weights (reloaded by the weight-file
# model-loading option). Saving to an HDF5 file does not create missing parent
# directories, so ensure the target folder exists first.
os.makedirs('./pretrained_weights/burgers', exist_ok=True)
parc.integrator.save_weights('./pretrained_weights/burgers/parc2_int_burgers_heun.h5')

# Validation on the test set (R, a, w values not used in training)

In [15]:
# Test-time parameter grid (no value overlaps the training grid).
# NOTE: this rebinds the training grid names R_list / a_list / w_list; re-run
# the training-grid cell before calling clip_raw_data again.
R_list = [100, 500, 3_000, 6_500, 12_500, 15_000]
a_list = [0.35, 0.4, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95, 1.0]
w_list = [0.55, 0.60, 0.65, 0.75, 0.85, 0.95, 1.05]

def clip_raw_data_for_validation(sequence_length=2, R_values=None, a_values=None, w_values=None,
                                 data_dir='.././data/burgers/test_data/'):
    """Build one evaluation window per available test simulation.

    For each (R, a, w) combination the file
    ``<data_dir>/burgers_test_<R>_<round(100*a)>_<round(100*w)>.npy`` is loaded
    if it exists; missing combinations are skipped. After moving the stored
    array's second-to-last axis (the time axis) to the front, the first
    ``sequence_length`` snapshots form a single sample, stacked along the
    channel axis as::

        [snapshot_0, R/15000, snapshot_1, ..., snapshot_{sequence_length-1}]

    Args:
        sequence_length: snapshots per window (>= 2): the initial snapshot plus
            ``sequence_length - 1`` reference snapshots.
        R_values, a_values, w_values: parameter grids to scan. ``None`` means
            "use the module-level ``R_list`` / ``a_list`` / ``w_list`` as they
            are at call time".
        data_dir: directory containing the test ``.npy`` files.
            NOTE(review): this default ('.././data/...', i.e. '../data/...')
            differs from the training loader's './data/...'. It is kept for
            backward compatibility; confirm which layout is intended.

    Returns:
        float32 array of shape (num_files_found, H, W, sequence_length * C + 1),
        where C is the per-snapshot channel count.

    Raises:
        ValueError: if ``sequence_length < 2``, a trajectory has fewer than
            ``sequence_length`` snapshots, or no matching file exists.
    """
    if sequence_length < 2:
        raise ValueError('sequence_length must be >= 2, got %r' % (sequence_length,))
    if R_values is None:
        R_values = R_list
    if a_values is None:
        a_values = a_list
    if w_values is None:
        w_values = w_list

    vel_seq_whole = []
    for R in R_values:
        for a in a_values:
            for w in w_values:
                # round() before int(): e.g. 0.57 * 100 == 56.99999999999999, which
                # int() alone would truncate to 56.
                data_file_name = ('burgers_test_' + str(int(R)) + '_' + str(int(round(a * 100)))
                                  + '_' + str(int(round(w * 100))) + '.npy')
                file_path = os.path.join(data_dir, data_file_name)
                if not os.path.exists(file_path):
                    continue
                # Move the time axis (second to last on disk) to the front.
                raw_data = np.moveaxis(np.float32(np.load(file_path)), -2, 0)
                num_time_steps, height, width = raw_data.shape[:3]
                if num_time_steps < sequence_length:
                    raise ValueError('%s has %d time steps, fewer than sequence_length=%d'
                                     % (file_path, num_time_steps, sequence_length))
                # float32 so the concatenations below do not upcast to float64.
                r_img = np.full((1, height, width, 1), R / 15000, dtype=np.float32)
                # Only the window starting at the first snapshot is used for evaluation.
                init_snapshot = np.concatenate([raw_data[0:1], r_img], axis=-1)
                following_snapshot = np.concatenate(
                    [raw_data[k:k + 1] for k in range(1, sequence_length)], axis=-1)
                vel_seq_whole.append(
                    np.concatenate([init_snapshot, following_snapshot], axis=-1))

    if not vel_seq_whole:
        raise ValueError('clip_raw_data_for_validation: no test files found in %r for the '
                         'given parameter grids' % (data_dir,))
    return np.concatenate(vel_seq_whole, axis=0)

# One window per test file: the initial snapshot followed by 100 reference snapshots.
seq_clipped_test = clip_raw_data_for_validation(101)

In [16]:
# (num_test_cases, H, W, channels) -- one row per test file found
np.shape(seq_clipped_test)

(378, 64, 64, 203)

## Load model

### Option 1: from weight files

In [2]:
# Option 1: rebuild the model and load the two stage-wise weight files.
tf.keras.backend.clear_session()
# NOTE(review): the test windows carry 100 reference snapshots
# (sequence_length=101); confirm n_time_step=99 is the intended rollout length.
parc = model_burgers.PARCv2_burgers(n_time_step = 99, step_size= 1/100, solver = "heun")
parc.differentiator.load_weights('./pretrained_weights/burgers/parc2_diff_burgers_heun.h5')
parc.integrator.load_weights('./pretrained_weights/burgers/parc2_int_burgers_heun.h5')
# The prediction cell calls `loaded_parc`. Bind it here so this option works
# without also running Option 2. Otherwise that cell raises NameError, or
# silently uses a stale model left over from an earlier run.
loaded_parc = parc

### Option 2: from a saved model file (.keras)

In [None]:
# Option 2: restore the complete saved model from a single .keras archive.
# NOTE(review): deserialising a subclassed model usually requires it to be
# registered as serializable (or passed via custom_objects) -- confirm.
loaded_parc = tf.keras.models.load_model(filepath='./pretrained_weights/burgers/parcv2_burgers.keras')

## Make prediction on the test set

In [None]:
# Roll out the loaded model on the first test cases, one case per predict()
# call. `loaded_parc` comes from the "Load model" section above.
N_PREDICTION_CASES = 50
# Cap at the number of available cases: slicing past the end would feed an
# empty batch to predict().
n_cases = min(N_PREDICTION_CASES, len(seq_clipped_test))
prediction_data = []
for j in range(n_cases):
    # Model input: the first 3 channels (initial snapshot + R channel), as float32.
    input_seq_current = tf.cast(seq_clipped_test[j:j + 1, :, :, :3], dtype=tf.float32)
    res = loaded_parc.predict(input_seq_current, verbose=0)
    prediction_data.append(res)
    print('Finish case ', j)
prediction_data = np.concatenate(prediction_data, axis=0)

In [38]:
# Drop every singleton axis from the stacked predictions (note: this also
# removes the case axis if only one case was predicted).
prediction_data = np.asarray(prediction_data).squeeze()

In [8]:
# Save the predictions under ./plotting/burgers/. np.save does not create
# missing directories, so ensure the folder exists first.
os.makedirs('./plotting/burgers', exist_ok=True)
np.save('./plotting/burgers/parcv2_burgers.npy', prediction_data)