In [1]:
import tensorflow as tf
import h5py
import numpy as np
from layers.mae import MAE
from layers.diffusion import DenoiseCT

 The versions of TensorFlow you are currently using is 2.10.1 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


In [2]:
@tf.function
def transform(sinogram, gt):
    """Normalise a batch of sinograms and resize it to 1024 x 513.

    Args:
        sinogram: batch of sinograms without a channel axis; a trailing
            channel axis is appended before resizing.
        gt: accepted for signature compatibility but not used; the
            return value depends only on ``sinogram``.

    Returns:
        The shifted/scaled sinogram batch with a trailing channel axis,
        resized to ``(1024, 513)`` with ``tf.image.resize``.
    """
    # The earlier version also drew random mask indices and normalised/resized
    # `gt`, but none of those results were returned, so that per-batch work
    # has been dropped.
    # Offset/scale constants are presumably dataset mean/std -- TODO confirm.
    sinogram = tf.expand_dims(sinogram - 42.932495, -1) / 31.87962
    return tf.image.resize(sinogram, (1024, 513))

In [10]:
@tf.function
def transform(sinogram, gt):
    """Normalise a batch and build the ``(inputs, target)`` pair for training.

    Args:
        sinogram: batch of sinograms without a channel axis.
        gt: batch of ground-truth images without a channel axis.

    Returns:
        ``((sinogram, mask_indices, unmask_indices, noisy_gt), gt)`` where

        * ``sinogram`` is shifted/scaled, given a trailing channel axis and
          resized to ``(1024, 513)``;
        * ``mask_indices`` / ``unmask_indices`` are the first 512 and the
          remaining 512 entries of a per-example random permutation of
          ``0..1023``;
        * ``noisy_gt`` is the normalised ``gt`` plus Gaussian noise with
          standard deviation 0.1;
        * ``gt`` is the normalised ground truth with a trailing channel axis.
    """
    # Read the batch size from the data instead of hard-coding 8: the last
    # batch of a dataset batched without `drop_remainder` can be smaller, and
    # fixed-size random tensors would then mismatch (or fail to broadcast).
    batch_size = tf.shape(sinogram)[0]

    # argsort of uniform noise yields an independent random permutation of
    # the 1024 positions for every example in the batch.
    rand_indices = tf.argsort(
        tf.random.uniform(shape=tf.stack([batch_size, 1024])), axis=-1
    )
    mask_indices = rand_indices[:, :512]
    unmask_indices = rand_indices[:, 512:]

    # Offset/scale constants are presumably dataset mean/std -- TODO confirm.
    sinogram = tf.expand_dims(sinogram - 42.932495, -1) / 31.87962
    sinogram = tf.image.resize(sinogram, (1024, 513))

    gt = tf.expand_dims(gt - 0.16737686, -1) / 0.11505456
    # Noise takes its shape from `gt` itself, so it always matches the batch.
    noisy_gt = gt + tf.random.normal(shape=tf.shape(gt), mean=0.0, stddev=0.1)
    return (sinogram, mask_indices, unmask_indices, noisy_gt), gt

In [11]:
# Schema for one serialized example: both entries are scalar byte-string
# features, decoded into tensors later with `tf.io.parse_tensor`.
feature_desc = {
    key: tf.io.FixedLenFeature([], tf.string)
    for key in ('observation', 'ground_truth')
}

def _parse_example(example_proto):
    """Decode one serialized example into an ``(observation, ground_truth)`` pair.

    Args:
        example_proto: a serialized example matching ``feature_desc``.

    Returns:
        Two float32 tensors: the observation with static shape
        ``(1000, 513)`` and the ground truth with static shape ``(362, 362)``.
    """
    parsed = tf.io.parse_single_example(example_proto, feature_desc)

    def _decode(key, static_shape):
        # `parse_tensor` cannot infer a static shape, so pin it explicitly.
        tensor = tf.io.parse_tensor(parsed[key], out_type=tf.float32)
        tensor.set_shape(static_shape)
        return tensor

    return (
        _decode('observation', (1000, 513)),
        _decode('ground_truth', (362, 362)),
    )

In [12]:
def _make_dataset(tfrecord_path):
    """Build the parse -> batch(8) -> transform -> prefetch pipeline for one file.

    Args:
        tfrecord_path: path of the TFRecord file to read.

    Returns:
        A ``tf.data.Dataset`` yielding whatever ``transform`` returns for
        each batch of 8 parsed examples.
    """
    return (tf.data.TFRecordDataset(tfrecord_path)
            # Parallel maps keep element order (deterministic by default)
            # while overlapping decoding/preprocessing with training.
            .map(_parse_example, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(8)
            .map(transform, num_parallel_calls=tf.data.AUTOTUNE)
            .prefetch(tf.data.AUTOTUNE))


train_ds = _make_dataset('lodopab_full_dose_train.tfrecord')

# NOTE(review): this path ends in '.tfrecord.tfrecord' while the training file
# has a single extension. TFRecordDataset opens files lazily, so a wrong name
# would only fail once validation starts -- confirm the file name on disk.
test_ds = _make_dataset('lodopab_full_dose_validation.tfrecord.tfrecord')

In [5]:
# Build the MAE with a single encoder layer and a single decoder layer.
# NOTE(review): this cell's execution count (5) is lower than that of the
# cells defining `transform`/`train_ds` (10-12), so the recorded run likely
# used an earlier pipeline -- confirm the current `train_ds` element structure
# is what `MAE.fit` expects.
model = MAE(
    enc_layers=1,
    dec_layers=1,
    sinogram_width=513,
    sinogram_height=1,
    input_shape=(1024, 513, 1),
    enc_dim=512,
    enc_mlp_units=2048,
    dec_dim=513,
    dec_mlp_units=2048,
)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='mse',
)
# One epoch over the training set, validating on `test_ds`.
model.fit(train_ds, epochs=1, validation_data=test_ds)

    142/Unknown - 80s 535ms/step - loss: 0.7669

KeyboardInterrupt: 

In [6]:
from layers.diffusion import CircleTransformer

In [7]:
# Wrap the MAE `model` in a CircleTransformer with a 362 x 362 output.
# NOTE(review): the meaning of the positional arguments (512, 256, 1) is not
# visible in this notebook -- check CircleTransformer's signature.
ds_model = CircleTransformer(
    model,
    512,
    256,
    1,
    output_width=362,
    output_height=362,
)

In [8]:
# Print the layer-by-layer summary (output shapes, parameter counts,
# connections) of the wrapped model.
ds_model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 1024, 513,   0           []                               
                                1)]                                                               
                                                                                                  
 mae_patches (Patches)          (None, None, 513)    0           ['input_1[0][0]']                
                                                                                                  
 input_2 (InputLayer)           [(None, 512)]        0           []                               
                                                                                                  
 input_3 (InputLayer)           [(None, 512)]        0           []                           

In [13]:
# Compile with Adam (learning rate 1e-5) and mean-squared-error loss, then
# run one epoch over `train_ds`, validating on `test_ds`.
ds_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='mse',
)
ds_model.fit(train_ds, epochs=1, validation_data=test_ds)

    515/Unknown - 269s 516ms/step - loss: 0.6920

KeyboardInterrupt: 