In [10]:
import os
# Restrict JAX to the CPU backend. This is set before `jax` is imported so the
# variable is already in the environment when JAX selects its platform.
os.environ['JAX_PLATFORMS'] = 'cpu'

import jax.numpy as jnp
from jax import vmap

from pidon.dataset import DataGenerator

In [3]:
# Problem sizes used by every cell below.
N, m, P, Q = (
    2,  # number of input samples in the training data-set
    3,  # number of input sensors (locations for evaluating the input functions u)
    4,  # number of output sensors (locations for evaluating the output functions G(u))
    5,  # number of collocation points for evaluating the PDE residual
)

In [15]:
# Builds one sample's (P, m) block of u; the next cell stacks N of them into (N*P, m).
def one_u(v, m, P):
    """Return a (P, m) array whose entries all equal ``v``.

    One "sensor row" holds ``v`` at each of the m input sensors. That row is
    repeated P times (once per output sensor) so that u lines up row-for-row
    with y and s after stacking.

    Args:
        v: scalar or length-1 array with the constant value of u.
        m: number of input sensors (columns of the result).
        P: number of output sensors (rows of the result).

    Returns:
        Array of shape (P, m).
    """
    sensor_row = jnp.ones((m,)) * v        # (m,); a length-1 v broadcasts
    return jnp.tile(sensor_row, (P, 1))    # P stacked copies -> (P, m)

# Per-sample constant values of u: one row per training sample, so exactly N rows.
u_values = jnp.array([[1], [2]])
# Guard: if this row count disagreed with N, the reshape below would either fail
# or (because of the -1) silently regroup the data into rows of the wrong width.
assert u_values.shape[0] == N, "u_values must have one row per training sample (N rows)"

u = vmap(lambda v: one_u(v, m, P))(u_values)   # (N, P, m)
u = u.reshape(N * P, -1).astype(jnp.float32)    # (N*P, m)
assert u.shape == (N * P, m)

print(u.shape)
print(u)

(8, 3)
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]


In [43]:
# Builds one sample's P output-sensor locations; the next cell stacks N of them into (N*P, 1).
def one_y(v, P):
    """Return P evenly spaced points on [0, 1], scaled by ``v``.

    Args:
        v: scalar or length-1 array scale factor (the grid then spans [0, v]).
        P: number of output sensors (points in the grid).

    Returns:
        Array of shape (P,).
    """
    unit_grid = jnp.linspace(0, 1, P)
    return jnp.multiply(v, unit_grid)

# Per-sample scale factors for the output-sensor grid: one row per training sample (N rows).
y_scales = jnp.array([[1], [2]])
# Guard: a row count that disagrees with N would make the reshape below fail or
# silently produce rows of the wrong width.
assert y_scales.shape[0] == N, "y_scales must have one row per training sample (N rows)"

y = vmap(lambda v: one_y(v, P))(y_scales)      # (N, P)
y = y.reshape(N * P, -1).astype(jnp.float32)    # (N*P, 1)
assert y.shape == (N * P, 1)
print(y.shape)
print(y)

(8, 1)
[[0.        ]
 [0.33333334]
 [0.6666667 ]
 [1.        ]
 [0.        ]
 [0.6666667 ]
 [1.3333334 ]
 [2.        ]]


In [44]:
# Builds one sample's P target values; the next cell stacks N of them into (N*P, 1).
def one_s(v, P):
    """Return ``v`` times P evenly spaced points on [0, 1] (same formula as one_y).

    Args:
        v: scalar or length-1 array scale factor (values then span [0, v]).
        P: number of output sensors (number of values returned).

    Returns:
        Array of shape (P,).
    """
    return jnp.linspace(0, 1, P) * v

# Per-sample scale factors for the targets s: one row per training sample (N rows).
s_scales = jnp.array([[5], [6]])
# Guard: a row count that disagrees with N would make the reshape below fail or
# silently produce rows of the wrong width.
assert s_scales.shape[0] == N, "s_scales must have one row per training sample (N rows)"

s = vmap(lambda v: one_s(v, P))(s_scales)      # (N, P)
s = s.reshape(N * P, -1).astype(jnp.float32)    # (N*P, 1)
assert s.shape == (N * P, 1)
print(s.shape)
print(s)

(8, 1)
[[0.       ]
 [1.6666667]
 [3.3333335]
 [5.       ]
 [0.       ]
 [2.       ]
 [4.       ]
 [6.       ]]


In [45]:
# Wrap the row-aligned (u, y, s) arrays in pidon's DataGenerator, batch_size rows per batch.
# NOTE(review): the first batch printed by the following cells holds rows 0, 1, 6, 7
# of the full arrays, so rows appear to be sampled rather than taken in order.
# Confirm whether DataGenerator is seeded, so that re-running gives the same batches.
batch_size = 4
dataset = DataGenerator(u, y, s, batch_size)

In [46]:
# Pull one batch from the dataset iterator.
# inputs  : (u, y) -- printed below with shapes (batch_size, m) and (batch_size, 1)
# outputs : s      -- printed below with shape (batch_size, 1)
inputs, outputs = next(iter(dataset))

In [47]:
# u part of the batch: its shape, then its values (one print per line).
print(inputs[0].shape, inputs[0], sep="\n")

(4, 3)
[[1. 1. 1.]
 [1. 1. 1.]
 [2. 2. 2.]
 [2. 2. 2.]]


In [48]:
# y part of the batch: its shape, then its values (one print per line).
print(inputs[1].shape, inputs[1], sep="\n")

(4, 1)
[[0.        ]
 [0.33333334]
 [1.3333334 ]
 [2.        ]]


In [49]:
# Targets s for the batch: their shape, then their values (one print per line).
print(outputs.shape, outputs, sep="\n")

(4, 1)
[[0.       ]
 [1.6666667]
 [4.       ]
 [6.       ]]
