In [1]:
import os
# Pin this kernel to GPU 6. Set before jax/tensorflow are imported (next cell)
# so those libraries only enumerate that device.
os.environ.update(CUDA_VISIBLE_DEVICES="6")

# Import Dependencies and Build the Trainer

In [2]:
from deeprte.experiment import Trainer
import jax
import haiku as hk

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from deeprte.config import get_config

# Evaluation .mat file to load. Override via the DEEPRTE_DATA_PATH environment
# variable rather than editing this machine-specific absolute path.
DATA_PATH = os.environ.get(
    "DEEPRTE_DATA_PATH",
    "/workspaces/deeprte/rte_data/rte_data/matlab/eval-data/test_shape.mat",
)
# Seed for the notebook-level PRNG sequence that supplies the step/eval keys below.
SEED = 42

# Keep only the experiment-level config and point its dataset at DATA_PATH.
config = get_config().experiment_kwargs.config
config.dataset.data_path = DATA_PATH

rng = hk.PRNGSequence(jax.random.PRNGKey(SEED))

# NOTE(review): the Trainer receives init_rng=0 while `rng` above is seeded with
# SEED; confirm which seed is meant to drive parameter initialization.
e = Trainer(mode="train", init_rng=0, config=config)

# Data Shape Inspection

In [4]:
import tensorflow as tf
# Per-field shapes of the dataset held by the Trainer (`e.tf_data`).
tf.nest.map_structure(lambda tensor: tensor.shape, e.tf_data)

{'sigma': TensorShape([8, 1681, 2]),
 'psi_label': TensorShape([8, 40344]),
 'scattering_kernel': TensorShape([8, 40344, 24]),
 'self_scattering_kernel': TensorShape([8, 24, 24]),
 'boundary': TensorShape([8, 1968]),
 'position_coords': TensorShape([1681, 2]),
 'velocity_coords': TensorShape([24, 2]),
 'phase_coords': TensorShape([40344, 4]),
 'boundary_coords': TensorShape([1968, 4]),
 'boundary_weights': TensorShape([1968]),
 'velocity_weights': TensorShape([24])}

## Train Input

In [5]:
# Build the training pipeline.
train_input = e._build_train_input()
# Pull one batch (this advances the iterator) and display each field's shape.
tf.nest.map_structure(lambda arr: arr.shape, next(train_input))

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


{'sigma': (1, 2, 1681, 2),
 'psi_label': (1, 2, 150),
 'scattering_kernel': (1, 2, 150, 24),
 'self_scattering_kernel': (1, 2, 24, 24),
 'boundary': (1, 2, 1968),
 'position_coords': (1, 1681, 2),
 'velocity_coords': (1, 24, 2),
 'phase_coords': (1, 150, 4),
 'boundary_coords': (1, 1968, 4),
 'boundary_weights': (1, 1968),
 'velocity_weights': (1, 24)}

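## Validation Input

Training batches carry 150 `phase_coords` per example, whereas validation batches carry all 40344.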

In [6]:
# Build the evaluation pipeline.
val_input = e._build_eval_input()
# Pull one batch (this advances the iterator) and display each field's shape.
tf.nest.map_structure(lambda arr: arr.shape, next(val_input))

{'sigma': (1, 2, 1681, 2),
 'psi_label': (1, 2, 40344),
 'scattering_kernel': (1, 2, 40344, 24),
 'self_scattering_kernel': (1, 2, 24, 24),
 'boundary': (1, 2, 1968),
 'position_coords': (1, 1681, 2),
 'velocity_coords': (1, 24, 2),
 'phase_coords': (1, 40344, 4),
 'boundary_coords': (1, 1968, 4),
 'boundary_weights': (1, 1968),
 'velocity_weights': (1, 24)}

## Dummy Input for Model Initialization

In [7]:
# Dummy batch from the Trainer's helper; reused below as the train/eval input.
ds = e._build_dummy_input()

In [8]:
# Per-field shapes of the dummy batch.
tf.nest.map_structure(lambda leaf: leaf.shape, ds)

{'sigma': (1, 1, 1681, 2),
 'psi_label': (1, 1, 1),
 'scattering_kernel': (1, 1, 1, 24),
 'self_scattering_kernel': (1, 1, 24, 24),
 'boundary': (1, 1, 1968),
 'position_coords': (1, 1681, 2),
 'velocity_coords': (1, 24, 2),
 'phase_coords': (1, 1, 4),
 'boundary_coords': (1, 1968, 4),
 'boundary_weights': (1, 1968),
 'velocity_weights': (1, 24)}

# Test Training Step

In [9]:
from jaxline import utils as jl_utils

# Broadcast a fresh key from the shared sequence to all local devices.
# NOTE(review): `init_rng` is not read by any visible cell, but drawing it still
# advances `rng`; removing it would shift every later key.
init_rng = jl_utils.bcast_local_devices(next(rng))

e._initialize_training()

# Replace the training input with a one-element numpy iterator over the dummy
# batch. Assigned after `_initialize_training()` so that call cannot overwrite it.
iterator = (
    tf.data.Dataset
    .from_tensors(ds)
    .as_numpy_iterator()
)
e._train_input = iterator

In [10]:
# Time a single training step with the dummy-batch input installed above.
# NOTE(review): the saved run of this cell was interrupted (KeyboardInterrupt);
# the first call presumably includes JIT compilation — confirm the expected cost.
%time scalars = e.step(jl_utils.bcast_local_devices(0), jl_utils.bcast_local_devices(next(rng)))

KeyboardInterrupt: 

# Test Evaluation

In [11]:
# `iterator` (training cell) is a one-shot iterator over a single element
# (`Dataset.from_tensors` yields exactly one), already installed as the training
# input and possibly consumed by `e.step`. Give evaluation its own fresh iterator
# over the same dummy batch instead of sharing that one.
e._eval_input = tf.data.Dataset.from_tensors(ds).as_numpy_iterator()

In [12]:
# Untimed evaluation run; the timed run follows in the next cell.
metrics = e.evaluate(
    jl_utils.bcast_local_devices(0),          # presumably the global step, replicated per device
    jl_utils.bcast_local_devices(next(rng)),  # fresh key from the shared sequence
)

In [None]:
# Timed re-run of evaluation; the untimed call above presumably absorbed one-off
# JIT compilation.
# NOTE(review): if `evaluate` drains `e._eval_input`, this call sees an already
# consumed iterator — confirm the eval input is rebuilt before timing.
%time metrics = e.evaluate(jl_utils.bcast_local_devices(0), jl_utils.bcast_local_devices(next(rng)))

CPU times: user 2min 9s, sys: 31.7 s, total: 2min 40s
Wall time: 2min 25s
