In [4]:
import numpy as np
import tensorflow as tf

In [7]:
def preprocess(
    state: tf.Tensor,
    varsigma: tf.Tensor = tf.constant(0.02, dtype=tf.float32),
) -> tf.Tensor:
    """Apply the time transformation of Traeger (2014) and rescale the carbon reservoirs.

    The state variables are laid out along the last axis. This function rewrites
    two groups of them and copies the others through unchanged:

    * indices 1-3 (carbon reservoirs) are divided by 1000, i.e. converted from
      GtCO2 to units of 1000 GtCO2;
    * index 6 (time) is mapped to hat(t) = 1 - exp(-varsigma * t), which sends
      t = 0 to 0 and, for varsigma > 0 and t >= 0, maps time into [0, 1);
    * indices 0, 4 and 5 are copied as-is.

    Only indices 0-6 are assembled into the output; any state variables beyond
    index 6 are not included.

    Args:
        state (tf.Tensor): State tensor whose last axis holds the state
            variables, e.g. shape [batch, state, state variables]. Any leading
            shape works because all slicing is done with ``...``.
        varsigma (tf.Tensor or float): Rate of the time transformation.
            Defaults to 0.02. It is cast to ``state.dtype`` before use, so the
            float32 default can also be used with e.g. float64 states.

    Returns:
        tf.Tensor: Preprocessed states with the same leading shape as ``state``
        and (for inputs with at least 7 state variables) 7 state variables on
        the last axis.
    """
    # Match the state's dtype; otherwise a float32 rate times a float64 state raises.
    rate = tf.cast(varsigma, state.dtype)

    # Time is the state variable at index 6; slicing 6:7 keeps the trailing axis for concat.
    t = state[..., 6:7]
    hat_t = 1.0 - tf.exp(-rate * t)

    # Carbon reservoirs (indices 1-3): GtCO2 -> 1000 GtCO2.
    m_scaled = state[..., 1:4] / 1000

    return tf.concat(
        [state[..., :1], m_scaled, state[..., 4:6], hat_t],
        axis=-1,
    )


In [12]:
def test_preprocess():
    """Check ``preprocess`` against independently computed reference values.

    The expected output is built with NumPy rather than by replaying the
    TensorFlow expression used inside ``preprocess``. That way an error in the
    implementation's formula cannot also hide in the expectation.

    Two inputs are checked:

    * a 2-D [batch, state variables] input;
    * the same data as a 3-D [batch, state, state variables] input, because
      ``preprocess`` slices with ``...`` and documents the 3-D layout.
    """
    varsigma_value = 0.02
    mock_state_np = np.array([
        [0.1, 2000.0, 3000.0, 4000.0, 0.5, 0.6, 1.0],  # Sample 1
        [0.2, 5000.0, 6000.0, 7000.0, 1.5, 1.6, 2.0],  # Sample 2
    ], dtype=np.float32)
    mock_state = tf.constant(mock_state_np)
    varsigma = tf.constant(varsigma_value, dtype=tf.float32)

    # Reference: reservoirs (1-3) / 1000, time (6) -> 1 - exp(-varsigma * t), others unchanged.
    expected_output = mock_state_np.astype(np.float64)  # copy, so mock_state_np is untouched
    expected_output[..., 1:4] /= 1000
    expected_output[..., 6] = 1.0 - np.exp(-varsigma_value * expected_output[..., 6])

    # Run the preprocess function on the 2-D input.
    processed_state = preprocess(mock_state, varsigma)
    np.testing.assert_array_almost_equal(processed_state.numpy(), expected_output, decimal=5)

    # Same data with an extra leading axis: [batch=1, state=2, state variables=7].
    processed_state_3d = preprocess(mock_state[tf.newaxis], varsigma)
    np.testing.assert_array_almost_equal(
        processed_state_3d.numpy(), expected_output[np.newaxis], decimal=5
    )

    print("Test passed for preprocess function.")
    print(processed_state)

# Execute the test
test_preprocess()

Test passed for preprocess function.
tf.Tensor(
[[0.1        2.         3.         4.         0.5        0.6
  0.01980132]
 [0.2        5.         6.         7.         1.5        1.6
  0.03921056]], shape=(2, 7), dtype=float32)
