In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import plotly.graph_objects as go  # for 3d plots

from models.agg_3d import Agg3D

In [None]:
# Synchronous data-parallel training: the model below is built inside this
# strategy's scope. With no arguments, MirroredStrategy picks the devices
# itself (per the TF docs: all visible GPUs, else the CPU).
strategy = tf.distribute.MirroredStrategy()
n_replicas = strategy.num_replicas_in_sync
print("REPLICAS: ", n_replicas)

Load the data and define helper functions for the input pipeline and for data visualisation.

In [None]:
# List of materials by radiation length
radiation_lengths = [
    1000000000000000000000000, # air
    49.834983498349835,  # benzene
    49.82309830679809,   # methanol
    36.08,               # water
    14.385057471264368,  # magnesium
    11.552173913043479,  # concrete
    10.607758620689655,  # gypsum
    10.412903225806451,  # calcium
    9.75,                # sulfur
    9.368827823100043,   # silicon
    8.895887365690998,   # aluminium
    4.436732514682328,   # caesium
    1.967741935483871,   # manganese
    1.7576835153670307,  # iron
    1.7200811359026373,  # iodine
    1.4243990114580993,  # nickel
    0.9589041095890413,  # molybdenum
    0.8542857142857143,  # silver
    0.6609442060085837,  # polonium
    0.5612334801762114,  # lead
    0.33436853002070394, # gold
    0.316622691292876    # uranium
]

inverse_radiation_length = [1/x for x in radiation_lengths]

In [None]:
# Schema of one serialized example: fields "x" and "y" each hold a tensor
# serialized to bytes (decoded below with tf.io.parse_tensor).
feature_description = {
    name: tf.io.FixedLenFeature([], tf.string) for name in ('x', 'y')
}


def _parse_example(example_proto):
    """Decode one serialized example into an (x, y) pair.

    Returns:
        x: float32 tensor decoded from the float64 "x" field. Downstream code
           indexes it as rows of 12 columns with a per-example row count
           (TODO confirm against the dataset writer).
        y: int32 tensor with static shape (64, 64, 64); its values are used
           downstream as indices into inverse_radiation_length.
    """
    fields = tf.io.parse_single_example(example_proto, feature_description)

    labels = tf.io.parse_tensor(fields['y'], out_type=tf.int32)
    labels.set_shape((64, 64, 64))

    muons = tf.cast(tf.io.parse_tensor(fields['x'], out_type=tf.double), dtype=tf.float32)
    return muons, labels

def set_dosage(x, y, dosage):
    """Keep `dosage` randomly chosen rows of x (without replacement); pass y through.

    The kept rows get the static shape (dosage, 12). The caller is expected to
    supply at least `dosage` rows (construct_ds filters shorter examples out
    before calling this).
    """
    subset = tf.random.shuffle(x)[:dosage]
    subset.set_shape((dosage, 12))
    return subset, y

# Standardisation constants for the noisy momentum-magnitude feature.
# NOTE(review): presumably mean/std of |p| over the training set — confirm.
_MOMENTUM_MEAN = 5585.2666
_MOMENTUM_STD = 13839.263

# File name of each split inside `data_dir`.
_SPLIT_FILES = {
    "train": "voxels_prediction.tfrecord",
    "val": "voxels_prediction_val.tfrecord",
    "test": "voxels_prediction_test.tfrecord",
}


def _scale_positions(positions, detector_resolution):
    """Map coordinates into the unit cube, optionally snapped to a detector grid.

    positions / 1000 + 0.5 maps [-500, 500] onto [0, 1] (presumably mm inside
    a 1 m box — TODO confirm). When detector_resolution > 0 the scaled value is
    rounded to a multiple of 1 / detector_resolution before the +0.5 shift;
    any value <= 0 means "no quantization" (0 would otherwise divide by zero).
    """
    scaled = positions / 1000
    if detector_resolution > 0:
        scaled = tf.math.rint(scaled * detector_resolution) / detector_resolution
    return scaled + 0.5


def _muon_features(x, detector_resolution, p_error):
    """Build the 14 per-row input features from the 12 raw columns of x.

    Output columns:
        0:3   _scale_positions(x[:, 0:3])
        3:6   x[:, 3:6] normalized to unit length
        6:9   _scale_positions(x[:, 6:9])
        9:12  x[:, 9:12] unchanged
        12    |x[:, 3:6]| times Gaussian noise N(1, p_error), standardized;
              all zeros when p_error <= -1e-8 (i.e. "momentum unknown")
        13    |x[:, 9:12] - unit(x[:, 3:6])|
    NOTE(review): column meanings (positions / momentum / direction) are
    inferred from how they are used here — confirm against the dataset.
    """
    momentum = x[:, 3:6]
    p = tf.norm(momentum, axis=-1, keepdims=True)
    direction = momentum / p

    if p_error > -1e-8:
        noise = tf.random.normal(tf.shape(p), mean=1.0, stddev=p_error)
        p_feature = (p * noise - _MOMENTUM_MEAN) / _MOMENTUM_STD
    else:
        p_feature = tf.zeros_like(p)

    return tf.concat([
        _scale_positions(x[:, :3], detector_resolution),
        direction,
        _scale_positions(x[:, 6:9], detector_resolution),
        x[:, 9:12],
        p_feature,
        tf.norm(x[:, 9:12] - direction, axis=-1, keepdims=True),
    ], axis=1)


def _voxel_targets(y):
    """Replace each voxel's material index with its inverse radiation length (adds a channel axis)."""
    return tf.gather(inverse_radiation_length, y)[..., tf.newaxis]


def construct_ds(dosage, detector_resolution=-1, p_error=0.2, batch_size=8, dataset="test",
                 data_dir="/path/to/dataset", shuffle_buffer=200):
    """Build a batched tf.data pipeline of (features, targets) for one split.

    Args:
        dosage: number of rows sampled per example; examples with fewer rows are dropped.
        detector_resolution: if > 0, positions are quantized to this grid (see _scale_positions).
        p_error: relative std of the multiplicative momentum noise; a negative value zeroes
            the momentum feature.
        batch_size: examples per batch.
        dataset: one of "train", "val", "test".
        data_dir: directory containing the split TFRecord files.
        shuffle_buffer: number of individual records in the shuffle buffer; 0/None disables
            shuffling (order is then deterministic, though row sampling and noise stay random).

    Returns:
        tf.data.Dataset of (features (batch, dosage, 14), targets (batch, 64, 64, 64, 1)).

    Raises:
        ValueError: if `dataset` is not a known split name (previously any typo silently
            loaded the training set).
    """
    if dataset not in _SPLIT_FILES:
        raise ValueError(f"dataset must be one of {sorted(_SPLIT_FILES)}, got {dataset!r}")
    path = tf.io.gfile.join(data_dir, _SPLIT_FILES[dataset])

    records = tf.data.TFRecordDataset(path, compression_type="GZIP")
    if shuffle_buffer:
        # Shuffle individual (still serialized, hence small) records before batching:
        # shuffling after .batch() would only reorder fixed batches and buffer whole
        # batches of 64^3 targets.
        records = records.shuffle(shuffle_buffer)

    return (
        records
        .map(_parse_example)
        .filter(lambda x, y: tf.shape(x)[0] >= dosage)
        .map(lambda x, y: set_dosage(x, y, dosage))
        .map(lambda x, y: (_muon_features(x, detector_resolution, p_error), _voxel_targets(y)))
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

Plot the scattering points estimated with the point-of-closest-approach (POCA) method.

In [None]:
def plot_scattering_points(pts):
    """Show an interactive 3D scatter plot of points in the unit cube.

    pts: array-like indexed as pts[:, k]; columns 0, 1 and 2 are drawn as
    x, y and z. Every axis is fixed to the range [0, 1] with an equal aspect ratio.
    """
    scene = {f"{axis}axis": dict(nticks=4, range=[0, 1]) for axis in "xyz"}
    scene["aspectratio"] = {"x": 1, "y": 1, "z": 1}

    markers = go.Scatter3d(x=pts[:, 0], y=pts[:, 1], z=pts[:, 2], mode='markers')
    fig = go.Figure(data=[markers])
    fig.update_layout(scene=scene)
    fig.show()

Plot a voxel grid as a 3D volume.

In [None]:
def plot_voxels(voxels, resolution=64, maximum=3.4, file=None):
    """Render a cubic voxel grid as a translucent 3D volume with plotly.

    Args:
        voxels: tensor/array with resolution**3 elements, e.g. shape (R, R, R) or
            (R, R, R, 1). Element [i, j, k] is drawn at (x, y, z) = (i, j, k) / R.
        resolution: voxels per side, R.
        maximum: value mapped to the top of the colour/opacity scale (plotly isomax).
        file: None to show the figure interactively; otherwise a path passed to
            fig.write_image.
    """
    n_voxels = resolution ** 3
    # indexing="ij" keeps the grid axes in the same order as the voxel axes; the
    # default "xy" swaps the first two, drawing voxel [i, j, k] at (x=j, y=i).
    grid = tf.meshgrid(tf.range(resolution), tf.range(resolution), tf.range(resolution), indexing="ij")
    x_vals, y_vals, z_vals = (tf.reshape(axis_index, (n_voxels,)) for axis_index in grid)
    values = tf.reshape(voxels, (n_voxels,))

    fig = go.Figure(data=go.Volume(
        x=x_vals / resolution,
        y=y_vals / resolution,
        z=z_vals / resolution,
        value=values,
        isomin=0,
        isomax=maximum,
        opacityscale=[[0, 0], [1, 1], [1, 1]],
        surface_count=20,  # needs to be a large number for good volume rendering
        colorscale="blues",
    ))

    if file is None:
        fig.show()
    else:
        fig.write_image(file)

Now we define an evaluation metric (PSNR), then build, train, and evaluate the model.

In [None]:
def psnr_max_value(y, y_pred, max_value=inverse_radiation_length[-1]):
    """Peak signal-to-noise ratio in dB, used as a Keras metric.

    PSNR = 20 * log10(max_value / RMSE), with the RMSE taken over every element
    of the batch.

    Args:
        y: ground-truth tensor (Keras' y_true); cast to y_pred's dtype.
        y_pred: predicted tensor with the same shape as y.
        max_value: peak signal value; defaults to the last, i.e. largest, inverse
            radiation length.

    Returns:
        Scalar tensor in dB (+inf when y_pred equals y exactly).
    """
    y = tf.cast(y, y_pred.dtype)
    rmse = tf.sqrt(tf.reduce_mean(tf.square(y_pred - y)))
    return 20 * tf.math.log(max_value / rmse) / tf.math.log(tf.cast(10.0, rmse.dtype))

In [None]:
# Rows sampled per example (see set_dosage); also the sequence length of the
# dummy batch used to build the model.
dosage = 1024
print(f"Training on {dosage}...")

with strategy.scope():
    # Model and optimizer variables are created inside the strategy scope so
    # that they are mirrored across replicas.
    model = Agg3D(
        point_size=1,
        downward_convs=[1, 2, 3, 4, 5],
        downward_filters=[8, 16, 32, 64, 128],
        upward_convs=[4, 3, 2, 1],
        upward_filters=[64, 32, 16, 8],
        resolution=64,
        threshold=1e-8,
    )
    optimizer = tf.keras.optimizers.AdamW(learning_rate=2e-3)
    model.compile(optimizer=optimizer, loss="mse", metrics=["mse", "mae", psnr_max_value])

    # One forward pass on random input (8 examples x dosage rows x 14 features,
    # matching construct_ds) builds the weights so summary() can report them.
    dummy_batch = tf.random.normal((8, dosage, 14))
    print(model(dummy_batch).shape)

    model.summary()

# Load the pre-split train / validation / test TFRecord files.
train_ds, val_ds, test_ds = (construct_ds(dosage, dataset=split) for split in ("train", "val", "test"))

# Train, then checkpoint the weights under a dosage-specific name.
model.fit(train_ds, validation_data=val_ds, epochs=15)
model.save_weights(f"model_{dosage}.h5")

# Held-out score: loss, mse, mae, psnr_max_value (compile() metric order).
model.evaluate(test_ds)

Evaluate on the test set and look at some sample reconstructions.

In [None]:
# Score the trained model on the held-out test split; values follow the
# compile() order: loss, mse, mae, psnr_max_value.
test_scores = model.evaluate(test_ds)
test_scores

In [None]:
# Take the 4th test batch (skip 3) as a sample for visual inspection. Its
# contents vary between runs because the pipeline is randomized (shuffling,
# random row subsets). Failing loudly here avoids silently reusing a stale
# x / y left over from an earlier cell when the split is too short.
sample_batch = next(iter(test_ds.skip(3).take(1)), None)
if sample_batch is None:
    raise ValueError("test_ds has fewer than 4 batches; cannot take a sample batch")
x, y = sample_batch
y_pred = model.predict(x)

In [None]:
# Slices at index 32 (the middle of the 64-voxel grid) along the third spatial
# axis for the first 8 samples of the batch. Rows: truth 0-3, prediction 0-3,
# truth 4-7, prediction 4-7. All panels share one colour scale (0 to the largest
# inverse radiation length); otherwise imshow rescales each panel to its own
# min/max and truth and prediction would not be visually comparable.
slice_index = 32
vmax = max(inverse_radiation_length)
rows = [(y, "truth", 0), (y_pred, "prediction", 0), (y, "truth", 4), (y_pred, "prediction", 4)]

fig, ax = plt.subplots(4, 4, figsize=(20, 20))
for row, (volumes, label, offset) in enumerate(rows):
    for col in range(4):
        sample = offset + col
        ax[row, col].imshow(volumes[sample, :, :, slice_index], vmin=0.0, vmax=vmax)
        ax[row, col].set_title(f"{label} (sample {sample})")

In [None]:
# Slices at index 32 (the middle of the 64-voxel grid) along the second spatial
# axis for the first 8 samples of the batch. Rows: truth 0-3, prediction 0-3,
# truth 4-7, prediction 4-7. All panels share one colour scale (0 to the largest
# inverse radiation length); otherwise imshow rescales each panel to its own
# min/max and truth and prediction would not be visually comparable.
slice_index = 32
vmax = max(inverse_radiation_length)
rows = [(y, "truth", 0), (y_pred, "prediction", 0), (y, "truth", 4), (y_pred, "prediction", 4)]

fig, ax = plt.subplots(4, 4, figsize=(20, 20))
for row, (volumes, label, offset) in enumerate(rows):
    for col in range(4):
        sample = offset + col
        ax[row, col].imshow(volumes[sample, :, slice_index, :], vmin=0.0, vmax=vmax)
        ax[row, col].set_title(f"{label} (sample {sample})")