In [2]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import os
# NOTE: this is a relative chdir, so re-running this cell moves the working
# directory up two more levels each time. Run it exactly once per kernel.
# NOTE(review): presumably the project imports below rely on the new working
# directory being importable — confirm.
os.chdir("../..")
from parc.data.data import EnergeticMatDataPipeLine as EmData
from parc import misc, metrics, visualization
from parc.model import model


# Data pipeline

In [None]:
# Get data and normalization
# Clip the raw cases into sequences of 2 consecutive time steps (current step + next step).
state_seq_whole, vel_seq_whole = EmData.clip_raw_data(idx_range = (0,150),sequence_length = 2)
# The normalization results are indexed by later cells as
#   [0] -> normalized array, [1] -> values passed as min_val, [2] -> values passed as max_val.
# The last argument is the number of channels per time step (3 state, 2 velocity).
state_seq_norm = EmData.data_normalization(state_seq_whole,3)
vel_seq_norm = EmData.data_normalization(vel_seq_whole,2)


# Model definition

In [8]:
from tensorflow import keras
from tensorflow.keras import  layers, regularizers
from keras.layers import *
import tensorflow as tf
from parc import layer

from tensorflow.keras.layers import Concatenate, Input
from tensorflow.keras.models import Model

"""
Differentiator for EM problems: 
    - state vars including temperature, pressure, microstructure evolution
    - no constant field is used
"""

def differentiator_em(n_state_var=3):
    """Build the differentiator network used by the update schemes.

    The network takes a ``(128, 256, n_state_var + 2)`` field tensor (the
    callers concatenate state variables and velocity along the last axis),
    runs it through the U-Net feature extractor from ``layer`` and projects
    the features back to ``n_state_var + 2`` channels with a 1x1 convolution.

    Args:
        n_state_var: number of state-variable channels; two further channels
            are added to the input and output.

    Returns:
        A Keras ``Model`` mapping the input fields to a tensor with the same
        number of channels, which the solvers scale by the step size and add
        to the current fields.
    """
    n_channel = n_state_var + 2

    # The feature extractor is created before the Input layer, in the same
    # order as before, so automatically generated layer names are unchanged.
    unet = layer.feature_extraction_unet(input_shape=(128, 256), n_channel=n_channel)

    fields_in = Input(shape=(128, 256, n_channel), dtype=tf.float32)
    features = unet(fields_in)

    # 1x1 convolution: per-pixel linear map from features to output channels.
    fields_dot = Conv2D(filters=n_channel, kernel_size=1, padding='same')(features)

    return Model(fields_in, fields_dot)

# def integrator(n_state_var = 3):
#     state_integrators = []
#     for _ in range(n_state_var):
#         state_integrators.append(layer.integrator_cnn(input_shape = (128,192)))

#     velocity_integrator = layer.integrator_cnn(input_shape = (128,192), n_output=2)

#     state_var_prev = keras.layers.Input(shape = (128, 192, n_state_var), dtype = tf.float32)
#     velocity_prev = keras.layers.Input(shape = (128, 192,2), dtype = tf.float32)
    
#     state_var_dot = keras.layers.Input(shape = (128, 192,n_state_var), dtype = tf.float32)
#     velocity_dot = keras.layers.Input(shape = (128, 192,2), dtype = tf.float32)

#     state_var_next = []
        
#     for i in range(n_state_var): 
#         state_var_next.append(state_integrators[i]([state_var_dot[:,:,:,i:i+1], state_var_prev[:,:,:,i:i+1]]))

#     state_var_next = keras.layers.concatenate(state_var_next, axis=-1)
#     velocity_next = velocity_integrator([velocity_dot, velocity_prev])
#     integrator = keras.Model([state_var_dot, velocity_dot, state_var_prev, velocity_prev], [state_var_next, velocity_next])
#     return integrator

class PARC_EM(keras.Model):
    """Roll state and velocity fields forward in time with a learned differentiator.

    Starting from an initial (state, velocity) pair, the model repeatedly
    applies one explicit integration step (RK4, Heun or Euler). The slope of
    each step is produced by the network returned by ``differentiator_em``.

    Args:
        n_state_var: number of state-variable channels per time step.
        n_time_step: number of integration steps rolled out by ``call`` and
            ``train_step``.
        step_size: factor multiplied with the slope in every update.
        solver: ``"rk4"`` or ``"heun"``; any other value selects the Euler update.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, n_state_var, n_time_step, step_size, solver = "rk4", **kwargs):
        super(PARC_EM, self).__init__(**kwargs)
        self.n_state_var = n_state_var
        self.n_time_step = n_time_step
        self.step_size = step_size
        self.solver = solver
        # Network predicting the slope of all n_state_var + 2 channels.
        self.differentiator = differentiator_em(n_state_var=self.n_state_var)
        # Running mean of the training loss reported by train_step.
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")

    @property
    def metrics(self):
        """Metrics tracked by this model (only the running total loss)."""
        return [
        self.total_loss_tracker,
        ]

    def call(self, input):
        """Roll the fields forward ``n_time_step`` steps.

        Args:
            input: pair ``(state_var_init, velocity_init)``; both are cast to
                float32 and concatenated along the last axis.

        Returns:
            The fields predicted at every step, concatenated along the last
            axis. Each step contributes ``n_state_var + 2`` channels, state
            variables first and velocity after them.
        """
        state_var_init = tf.cast(input[0],dtype = tf.float32)
        velocity_init = tf.cast(input[1], dtype = tf.float32)
        input_seq_current = Concatenate(axis = -1)([state_var_init, velocity_init])

        res = []
        for _ in range(self.n_time_step):
            # explicit_update returns (next_fields, update). Only the fields
            # are carried to the next step and collected; carrying the whole
            # tuple (as before) fed a tuple back into the solver.
            input_seq_current, _update = self.explicit_update(input_seq_current)
            res.append(input_seq_current)
        output = Concatenate(axis = -1)(res)
        return output

    @tf.function
    def train_step(self, data):
        """Run one optimisation step on a batch.

        Args:
            data: ``((state_var_init, velocity_init), (state_var_gt, velocity_gt))``.
                The ground-truth tensors are compared against the
                ``n_time_step`` predicted steps concatenated along the last axis.

        Returns:
            Dict with the running mean of the total loss under ``"total_loss"``.
        """
        state_var_init = tf.cast(data[0][0],dtype = tf.float32)
        velocity_init = tf.cast(data[0][1], dtype = tf.float32)
        input_seq = Concatenate(axis = -1)([state_var_init, velocity_init])

        state_var_gt = tf.cast(data[1][0], dtype = tf.float32)
        velocity_gt = tf.cast(data[1][1], dtype = tf.float32)

        input_seq_current = input_seq
        with tf.GradientTape() as tape:
            state_whole = []
            vel_whole = []
            for _ in range(self.n_time_step):
                input_seq_current, _update = self.explicit_update(input_seq_current)
                # Split at n_state_var rather than a literal 3 so the model
                # also works with a different number of state variables.
                state_whole.append(input_seq_current[:,:,:,:self.n_state_var])
                vel_whole.append(input_seq_current[:,:,:,self.n_state_var:])
            state_pred = Concatenate(axis = -1)(state_whole)
            vel_pred = Concatenate(axis = -1)(vel_whole)

            # Average of the summed absolute errors of the state and velocity fields.
            total_loss  = (tf.keras.losses.MeanAbsoluteError(reduction = 'sum')(state_pred,state_var_gt) +
                            tf.keras.losses.MeanAbsoluteError(reduction = 'sum')(vel_pred,velocity_gt))/2

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)

        return {
            "total_loss": self.total_loss_tracker.result(),
        }

    # Update scheme
    def explicit_update(self, input_seq_current):
        """Apply one step of the configured solver.

        Returns:
            Tuple ``(next_fields, update)`` as produced by the selected
            ``*_update`` method.
        """
        if self.solver == "rk4":
            return self.rk4_update(input_seq_current)
        if self.solver == 'heun':
            return self.heun_update(input_seq_current)
        return self.euler_update(input_seq_current)

    # Fourth-order Runge-Kutta update function
    def rk4_update(self, input_seq_current):
        """One RK4 step; returns ``(next_fields, update)``."""
        # Inputs are clipped to [0, 1] before the slopes are evaluated.
        input_seq_current = tf.clip_by_value(input_seq_current, 0, 1)

        # Compute k1
        k1 = self.differentiator(input_seq_current)

        # Compute k2
        inp_k2 = input_seq_current + self.step_size*1/2*k1
        k2 = self.differentiator(inp_k2)

        # Compute k3
        inp_k3 = input_seq_current + self.step_size*1/2*k2
        k3 = self.differentiator(inp_k3)

        # Compute k4
        inp_k4 = input_seq_current + self.step_size*k3
        k4 = self.differentiator(inp_k4)

        # Final: weighted average of the four slopes
        update = 1/6*(k1 + 2*k2 + 2*k3 + k4)
        input_seq_current = input_seq_current + self.step_size*update
        return input_seq_current, update

    # Heun update function
    def heun_update(self, input_seq_current):
        """One Heun step; returns ``(next_fields, update)``."""
        input_seq_current = tf.clip_by_value(input_seq_current, 0, 1)
        # Compute k1 (slope at the current fields)
        k1 = self.differentiator(input_seq_current)

        # Compute k2 (slope at the Euler-predicted fields)
        inp_k2 = input_seq_current + self.step_size*k1
        k2 = self.differentiator(inp_k2)

        update = 1/2*(k1 + k2)
        input_seq_current = input_seq_current + self.step_size*update

        return input_seq_current, update

    # Euler update function
    def euler_update(self, input_seq_current):
        """One forward-Euler step; returns ``(next_fields, update)``."""
        input_seq_current = tf.clip_by_value(input_seq_current, 0, 1)
        # Compute update
        update = self.differentiator(input_seq_current)
        input_seq_current = input_seq_current + self.step_size*update

        return input_seq_current, update

# Training


### Stage 1: Differentiator training

In [None]:
# Create tf.dataset
# Inputs: the first 3 state channels and first 2 velocity channels of each sequence.
dataset_input = tf.data.Dataset.from_tensor_slices((state_seq_norm[0][:,:,:,:3],vel_seq_norm[0][:,:,:,:2]))
# Labels: all remaining channels of each sequence.
dataset_label = tf.data.Dataset.from_tensor_slices((state_seq_norm[0][:,:,:,3:],vel_seq_norm[0][:,:,:,2:]))
# Pair them as ((state, velocity), (state_gt, velocity_gt)), the layout train_step indexes.
dataset = tf.data.Dataset.zip((dataset_input, dataset_label))
# NOTE(review): 2192 is a hard-coded shuffle buffer, presumably the number of
# training samples — confirm, or derive it from state_seq_norm[0].shape[0].
dataset = dataset.shuffle(buffer_size = 2192) 
dataset = dataset.batch(4)

In [9]:
tf.keras.backend.clear_session()
# NOTE(review): train_step concatenates n_time_step predicted steps along the
# channel axis, so the labels must contain n_time_step steps. The dataset cell
# in this notebook builds labels from sequence_length = 2 clips — confirm which
# dataset this run actually used.
parc = PARC_EM(n_state_var = 3, n_time_step = 5, step_size= 1/15, solver = "rk4")
# parc.differentiator.load_weights('parc_diff.h5')
# No loss is passed to compile(): the loss is computed inside PARC_EM.train_step.
parc.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.00001, beta_1 = 0.9, beta_2 = 0.999))
# NOTE(review): per the Keras Model.fit documentation, `shuffle` is ignored for
# tf.data.Dataset inputs; shuffling here comes from dataset.shuffle(...).
parc.fit(dataset, epochs = 50, shuffle = True)

2024-01-07 23:17:55.183888: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1636] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 78791 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-80GB, pci bus id: 0000:b7:00.0, compute capability: 8.0


In [6]:
# Persist only the differentiator sub-model's weights (path is relative to the current working directory).
parc.differentiator.save_weights('parc_diff.h5')

In [None]:
# Define sequence length for training
sequence_length = 3
# Clip sequence_length + 1 steps per sample.
# NOTE: this rebinds state_seq_whole / vel_seq_whole from the data-pipeline cell.
state_seq_whole, vel_seq_whole = EmData.clip_raw_data(idx_range = (0,150), sequence_length = sequence_length + 1, n_state_var = 3)
# Normalize with the values stored in state_seq_norm[1:3] / vel_seq_norm[1:3] by the first normalization.
# NOTE(review): state_seq_norm_whole / vel_seq_norm_whole are not used by any other visible cell — confirm they are still needed.
state_seq_norm_whole = EmData.data_normalization_test(state_seq_whole,state_seq_norm[1], state_seq_norm[2],3)
vel_seq_norm_whole = EmData.data_normalization_test(vel_seq_whole,vel_seq_norm[1], vel_seq_norm[2],2)

In [5]:
import time
import os
import numpy as np
import skimage
from skimage.measure import block_reduce


def clip_raw_data(idx_range, sequence_length=2, n_state_var=3, purpose = "diff_training", data_dir=None):
    """Load raw ``void_<i>.npy`` arrays and cut them into state / velocity sequences.

    Each file is loaded as float32, edge-padded along its first two axes by
    the distance of their sizes from 512 and 1024, reduced by 4x4 max pooling
    and cut to the first 256 columns. The last axis is read as consecutive
    time steps of ``n_state_var + 2`` channels each (state variables followed
    by 2 velocity channels).

    Args:
        idx_range: ``(start, stop)`` file indices; ``stop`` is exclusive.
            Indices whose file does not exist are skipped.
        sequence_length: number of consecutive time steps per sequence.
        n_state_var: number of state-variable channels per time step.
        purpose: ``"diff_training"`` extracts every window of
            ``sequence_length`` steps whose start index is below
            ``num_time_steps - sequence_length``; any other value extracts
            only the window starting at the first step.
        data_dir: directory containing the ``void_<i>.npy`` files. Defaults to
            the original hard-coded project location.

    Returns:
        ``(state_seq_whole, vel_seq_whole)``: arrays stacked along axis 0 with
        ``sequence_length * n_state_var`` and ``sequence_length * 2`` channels
        on the last axis respectively.

    Raises:
        ValueError: if no sequence could be extracted from any file.
    """
    if data_dir is None:
        # Original hard-coded location, kept as the default for backward compatibility.
        data_dir = os.path.join(os.sep,'project','SDS','research', 'Nguyen_storage', 'data', 'single_void_data')

    n_channel = n_state_var + 2  # channels stored per time step

    def _window(raw_data, start, offset, width):
        # Stack, over `sequence_length` consecutive steps beginning at `start`,
        # the `width` channels located `offset` channels into each step.
        return np.concatenate(
            [raw_data[:, :, :256, (start + k) * n_channel + offset:(start + k) * n_channel + offset + width]
             for k in range(sequence_length)],
            axis=-1)

    state_seq_whole = []
    vel_seq_whole = []

    for i in range(idx_range[0],idx_range[1]):
        file_path = os.path.join(data_dir, f'void_{i}.npy')
        if not os.path.exists(file_path):
            continue

        raw_data = np.float32(np.load(file_path))
        data_shape = raw_data.shape
        # NOTE(review): this compares the raw channel count (not the number of
        # time steps) with sequence_length — confirm that is the intended filter.
        if data_shape[2] <= sequence_length:
            continue

        print(i)
        # NOTE(review): abs() also pads inputs that are already larger than
        # 512 x 1024 — confirm such inputs never occur.
        npad = ((0, abs(data_shape[0] - 512)), (0, abs(data_shape[1] - 1024)), (0, 0))
        raw_data = np.pad(raw_data, pad_width=npad, mode='edge')
        raw_data = np.expand_dims(raw_data, axis=0)
        # 4x4 spatial max pooling; batch and channel axes are left untouched.
        raw_data = skimage.measure.block_reduce(raw_data, (1,4,4,1),np.max)

        num_time_steps = raw_data.shape[-1] // n_channel
        if purpose == "diff_training":
            j_range = num_time_steps - sequence_length
        else:
            j_range = 1

        # State variables are the first n_state_var channels of each step,
        # the 2 velocity channels follow them.
        state_seq_whole.extend(_window(raw_data, j, 0, n_state_var) for j in range(j_range))
        vel_seq_whole.extend(_window(raw_data, j, n_state_var, 2) for j in range(j_range))

    if not state_seq_whole:
        # Fail with a message that names the cause instead of the bare
        # "need at least one array to concatenate" from np.concatenate.
        raise ValueError(
            f"No sequences could be extracted from void_<i>.npy files in {data_dir} "
            f"for idx_range={idx_range} and sequence_length={sequence_length}."
        )

    state_seq_whole = np.concatenate(state_seq_whole, axis=0)
    vel_seq_whole = np.concatenate(vel_seq_whole, axis=0)

    return state_seq_whole, vel_seq_whole

# Normalization
def data_normalization(input_data,no_of_channel):
    """Min-max normalize a 4-D array channel group by channel group.

    The last axis is treated as repeating groups of ``no_of_channel``
    channels. For each ``i`` the slice ``i::no_of_channel`` is scaled with the
    global minimum and maximum of that slice.

    The small epsilon is added to the denominator, matching
    ``data_normalization_test`` and ``data_denormalization``, so a constant
    channel maps to 0 instead of producing NaN.

    Args:
        input_data: 4-D array, channels on the last axis.
        no_of_channel: number of distinct channels per group.

    Returns:
        ``(norm_data, min_val, max_val)``: the normalized float64 array and
        the per-channel minima and maxima as lists of length ``no_of_channel``.
    """
    norm_data = np.zeros(input_data.shape)
    min_val = []
    max_val = []
    for i in range(no_of_channel):
        channel_data = input_data[:,:,:,i::no_of_channel]
        # Compute the extrema once per channel and reuse them.
        channel_min = np.amin(channel_data)
        channel_max = np.amax(channel_data)
        norm_data[:,:,:,i::no_of_channel] = (channel_data - channel_min) / (channel_max - channel_min + 1E-9)
        min_val.append(channel_min)
        max_val.append(channel_max)
    return norm_data, min_val, max_val

def data_normalization_test(input_data, min_val, max_val, no_of_channel):
    """Min-max normalize a 4-D array with previously computed extrema.

    Args:
        input_data: 4-D array, channels on the last axis in repeating groups
            of ``no_of_channel``.
        min_val: per-channel minima, indexed ``0 .. no_of_channel - 1``.
        max_val: per-channel maxima, indexed ``0 .. no_of_channel - 1``.
        no_of_channel: number of distinct channels per group.

    Returns:
        Float64 array of the same shape as ``input_data``.
    """
    scaled = np.zeros(input_data.shape)
    for ch in range(no_of_channel):
        picked = input_data[:, :, :, ch::no_of_channel]
        lower = min_val[ch]
        # Epsilon keeps the division defined when max equals min.
        span = max_val[ch] - lower + 1E-9
        scaled[:, :, :, ch::no_of_channel] = (picked - lower) / span
    return scaled

# Validation

In [None]:
# Held-out cases 150-199. With purpose = "test" clip_raw_data extracts a single
# 15-step sequence per case, starting at the first time step.
# NOTE: this rebinds state_seq_whole / vel_seq_whole again.
state_seq_whole, vel_seq_whole = clip_raw_data(idx_range = (150,200), sequence_length = 15, n_state_var = 3, purpose = "test")
# Normalize with the min / max values stored by the training-data normalization.
state_seq_norm_test = data_normalization_test(state_seq_whole, state_seq_norm[1], state_seq_norm[2],3)
vel_seq_norm_test = data_normalization_test(vel_seq_whole, vel_seq_norm[1], vel_seq_norm[2],2)

In [17]:
# Fresh model that rolls out 15 steps for evaluation.
parc = PARC_EM(n_state_var = 3, n_time_step = 15, step_size= 1/15, solver = "rk4")
# NOTE(review): the training section saves 'parc_diff.h5' but this loads
# 'parc_diff_5.h5' — confirm which checkpoint is intended and where it is produced.
parc.differentiator.load_weights('parc_diff_5.h5')

In [None]:
# Make prediction
pred_whole = []
# NOTE(review): 34 is hard-coded, presumably the number of test sequences —
# confirm, or use state_seq_norm_test.shape[0].
for case_idx in range(34):
    # Batch of one case: the first 3 state channels and first 2 velocity
    # channels (the initial time step).
    state_var_current = state_seq_norm_test[case_idx:case_idx+1,:,:,0:3]
    velocity_current = vel_seq_norm_test[case_idx:case_idx+1,:,:,0:2]
    pred_state = parc.predict([state_var_current,velocity_current])
    pred_whole.append(pred_state)
# Stack the per-case predictions along the batch axis.
pred = np.concatenate(pred_whole, axis = 0)


In [19]:
def data_denormalization(input_data, min_val, max_val, no_of_channel):
    """Invert the min-max scaling applied by ``data_normalization_test``.

    Args:
        input_data: 4-D normalized array, channels on the last axis in
            repeating groups of ``no_of_channel``.
        min_val: per-channel minima, indexed ``0 .. no_of_channel - 1``.
        max_val: per-channel maxima, indexed ``0 .. no_of_channel - 1``.
        no_of_channel: number of distinct channels per group.

    Returns:
        Float64 array of the same shape as ``input_data`` in original units.
    """
    restored = np.zeros(input_data.shape)
    for ch in range(no_of_channel):
        offset = min_val[ch]
        # Same epsilon-padded range that the normalization divides by.
        span = max_val[ch] - offset + 1E-9
        restored[:, :, :, ch::no_of_channel] = input_data[:, :, :, ch::no_of_channel] * span + offset
    return restored

In [20]:
# Merge the stored minima / maxima in the order state channels first, velocity
# channels second — the same order in which PARC_EM concatenates its fields.
min_val_state = np.concatenate([state_seq_norm[1],vel_seq_norm[1]], axis = 0)
max_val_state = np.concatenate([state_seq_norm[2],vel_seq_norm[2]], axis = 0)

In [15]:
# Each predicted step holds 5 channels (3 state + 2 velocity), hence no_of_channel = 5.
pred_out = data_denormalization(pred,min_val_state,max_val_state, no_of_channel = 5)

In [21]:
# Path is relative to the current working directory.
# NOTE(review): np.save does not appear to create missing directories — make sure ./plotting/em exists.
np.save('./plotting/em/neuralode_em.npy',pred_out)