In [1]:
import os

# Restrict this process to GPU index 1. This has to happen before JAX /
# TensorFlow are imported (next cell) so their GPU backends only see that device.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

# Import and Build Model

In [2]:
from deeprte.train import Trainer
import jax
import haiku as hk

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Directory holding the MATLAB training data. Set the DEEPRTE_SOURCE_DIR
# environment variable to point elsewhere instead of editing this
# machine-specific absolute path.
source_dir = os.environ.get(
    "DEEPRTE_SOURCE_DIR",
    "/workspaces/deeprte/rte_data/matlab/train-scattering-kernel-0204",
)
# .mat files inside `source_dir` that are loaded into the dataset.
data_name_list = ["train_random_kernel_1.mat", "train_random_kernel_2.mat"]

In [4]:
from deeprte.config import get_config

# Only the nested experiment config (`experiment_kwargs.config`) is handed to
# the Trainer; point its dataset section at the files chosen above.
full_config = get_config()
config = full_config.experiment_kwargs.config
config.dataset.source_dir = source_dir
config.dataset.data_name_list = data_name_list

# Seeded key stream; later cells draw step/eval keys from it with next(rng).
rng = hk.PRNGSequence(jax.random.PRNGKey(42))
# NOTE(review): init_rng is the integer 0 rather than a key drawn from `rng`;
# confirm this is what Trainer expects.
e = Trainer(mode="train", init_rng=0, config=config)

# Data Shape Visualization

In [5]:
import tensorflow as tf

# Shape of every field in the dataset loaded by the Trainer.
tf.nest.map_structure(lambda tensor: tensor.shape, e.tf_data)

{'sigma': TensorShape([600, 1521, 2]),
 'psi_label': TensorShape([600, 36504]),
 'scattering_kernel': TensorShape([600, 36504, 24]),
 'boundary_scattering_kernel': TensorShape([600, 1968, 24]),
 'self_scattering_kernel': TensorShape([600, 24, 24]),
 'boundary': TensorShape([600, 1968]),
 'position_coords': TensorShape([1521, 2]),
 'velocity_coords': TensorShape([24, 2]),
 'phase_coords': TensorShape([36504, 4]),
 'boundary_coords': TensorShape([1968, 4]),
 'boundary_weights': TensorShape([1968]),
 'velocity_weights': TensorShape([24])}

## Train Input

In [6]:
# Build the training iterator and show the shape of every field in one batch.
# Note: this next() consumes the first batch of `train_input`.
train_input = e._build_train_input()
tf.nest.map_structure(lambda arr: arr.shape, next(train_input))

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


{'sigma': (1, 4, 1521, 2),
 'psi_label': (1, 4, 100),
 'scattering_kernel': (1, 4, 100, 24),
 'boundary_scattering_kernel': (1, 4, 1968, 24),
 'self_scattering_kernel': (1, 4, 24, 24),
 'boundary': (1, 4, 1968),
 'position_coords': (1, 1521, 2),
 'velocity_coords': (1, 24, 2),
 'phase_coords': (1, 100, 4),
 'boundary_coords': (1, 1968, 4),
 'boundary_weights': (1, 1968),
 'velocity_weights': (1, 24),
 'sampled_boundary_coords': (1, 50, 4),
 'sampled_boundary': (1, 4, 50),
 'sampled_boundary_scattering_kernel': (1, 4, 50, 24)}

## Val Input

In [7]:
# Build the evaluation iterator and show the shape of every field in one batch.
# Note: this next() consumes the first batch of `val_input`.
val_input = e._build_eval_input()
tf.nest.map_structure(lambda arr: arr.shape, next(val_input))

{'sigma': (1, 4, 1521, 2),
 'psi_label': (1, 4, 36504),
 'scattering_kernel': (1, 4, 36504, 24),
 'boundary_scattering_kernel': (1, 4, 1968, 24),
 'self_scattering_kernel': (1, 4, 24, 24),
 'boundary': (1, 4, 1968),
 'position_coords': (1, 1521, 2),
 'velocity_coords': (1, 24, 2),
 'phase_coords': (1, 36504, 4),
 'boundary_coords': (1, 1968, 4),
 'boundary_weights': (1, 1968),
 'velocity_weights': (1, 24)}

## Dummy Input for Model Init

In [8]:
# Minimal batch used for model initialisation; its shapes are shown below
# (sample and point axes have length 1).
ds = e._build_dummy_input()

In [9]:
# Shape of every field in the dummy initialisation batch.
tf.nest.map_structure(lambda leaf: leaf.shape, ds)

{'sigma': (1, 1, 1521, 2),
 'psi_label': (1, 1, 1),
 'scattering_kernel': (1, 1, 1, 24),
 'boundary_scattering_kernel': (1, 1, 1968, 24),
 'self_scattering_kernel': (1, 1, 24, 24),
 'boundary': (1, 1, 1968),
 'position_coords': (1, 1521, 2),
 'velocity_coords': (1, 24, 2),
 'phase_coords': (1, 1, 4),
 'boundary_coords': (1, 1968, 4),
 'boundary_weights': (1, 1968),
 'velocity_weights': (1, 24),
 'sampled_boundary_coords': (1, 1, 4),
 'sampled_boundary': (1, 1, 1),
 'sampled_boundary_scattering_kernel': (1, 1, 1, 24)}

# Test Train Step

In [10]:
from jaxline import utils as jl_utils

# NOTE(review): `init_rng` is drawn here (advancing `rng`) but is never passed
# to `_initialize_training()`; the Trainer was constructed with init_rng=0.
# It is kept so the keys later drawn from `rng` are unchanged — confirm which
# initialisation key is actually intended.
init_rng = jl_utils.bcast_local_devices(next(rng))
e._initialize_training()

In [11]:
# Point the trainer at the iterator built in the "Train Input" section,
# overriding whatever `_train_input` held before (presumably set by
# `_initialize_training()` — confirm). Its first batch was already consumed
# by the shape check in that section.
e._train_input = train_input

In [12]:
# Run a single training step. The step counter and a fresh key from `rng` are
# replicated across local devices before being passed in.
# NOTE(review): on the recorded run this raised RESOURCE_EXHAUSTED (a ~9 GiB
# allocation); likely the batch/sample sizes in `config` need reducing for
# this GPU before rerunning.
global_step = jl_utils.bcast_local_devices(0)
step_rng = jl_utils.bcast_local_devices(next(rng))
scalars = e.step(global_step, step_rng)

2023-03-08 11:37:52.037174: W external/org_tensorflow/tensorflow/tsl/framework/bfc_allocator.cc:290] Allocator (GPU_0_bfc) ran out of memory trying to allocate 9.02GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 9689890816 bytes.

In [None]:
# NOTE(review): the training step above raised RESOURCE_EXHAUSTED, so on a
# fresh kernel `scalars` is undefined here. This cell is marked In [None], so
# the output below is stale, left over from an earlier session.
scalars

{'learning_rate': DeviceArray(0.0008, dtype=float32),
 'train_mse': DeviceArray(0.06758966, dtype=float32),
 'train_rmspe': DeviceArray(64.88431, dtype=float32)}

# Test Evaluation

In [None]:
# Evaluate on a fresh eval iterator. `val_input` is not reused because the
# shape check in the "Val Input" section already consumed its first batch.
e._eval_input = e._build_eval_input()

In [None]:
# Run evaluation with the step counter and a fresh key from `rng`, both
# replicated across local devices.
eval_global_step = jl_utils.bcast_local_devices(0)
eval_rng = jl_utils.bcast_local_devices(next(rng))
metrics = e.evaluate(eval_global_step, eval_rng)

In [None]:
# Metrics dict returned by `e.evaluate(...)` in the previous cell.
metrics

{'eval_mse': 0.0852733626961708, 'eval_rmspe': 124.36205291748047}