Grants command for Access on Demand (AoD)

https://grants.corp.google.com/#/grants?request=20h%2Fchr-ards-electrodes-deid-colab-jobs&reason=b%2F314799341


# Load TF Dataset

In [None]:
from colabtools import adhoc_import
import matplotlib.pyplot as plt
import numpy as np
from google3.pyglib import gfile
import tensorflow as tf
import tensorflow_datasets as tfds

# pandas is needed for `pd.read_csv` below; the cell's import list does not
# provide it, so import it here to keep the cell runnable on its own.
import pandas as pd

# Load the training split of the minutely dataset as a tf.data.Dataset.
ds = tfds.load('lsm_prod/lsm_300min_100K_unimpute', data_dir='/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/raw/datasets/msa_1_5/lsm_tfds_datasets', split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(ds)

# The feature names are read from the header row of a CSV stored next to the
# dataset; only the column labels are used afterwards.
with gfile.Open('/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/raw/datasets/msa_1_5/lsm_tfds_datasets/lsm_prod/lsm_300min_100K_unimpute/1.0.0/Dataset_FeatureNames.csv', 'r') as f:
  df = pd.read_csv(f)

features = df.columns

# Plot Minutely Data Sample

In [None]:
# Keep only the first element; `ds` stays rebound to this one-element dataset.
ds = ds.take(1)

for example in ds:
  # Each element is a dict of tensors; its keys are printed for reference and
  # the "input_signal" and "label" entries are read below.
  print(list(example.keys()))
  inputs, label = example["input_signal"], example["label"]
  print(inputs.shape, label)

  # Draw the signal as an image with its first two axes exchanged.
  # NOTE(review): the axis labels assume axis 0 of `input_signal` is time and
  # axis 1 is the feature index - confirm against the dataset spec.
  plt.figure(figsize=(15, 10))
  imgplot = plt.imshow(np.swapaxes(inputs, 0, 1))
  plt.grid(None)
  plt.xlabel('Time (minutes)')
  plt.ylabel('Feature #')
  plt.show()

In [None]:
# Limit the dataset to a single element before plotting.
ds = ds.take(1)

for example in ds:
  # Each element is a dict of tensors; "input_signal" and "label" are used.
  print(list(example.keys()))
  inputs, label = example["input_signal"], example["label"]
  print(inputs.shape, label)

  # 25 vertically stacked panels, one per plotted column of the signal.
  fig, axs = plt.subplots(25, 1, figsize=(10, 35))

  for i, ax in enumerate(axs):
    ax.plot(inputs[:, i])
    ax.set_title(features[i])
    # Every panel except the bottom one has its x-axis ticks removed.
    if i != len(axs) - 1:
      ax.get_xaxis().set_ticks([])

# Example of Training

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf
from colabtools import adhoc_import
with adhoc_import.Google3(behavior='preferred'):
  import tensorflow_addons as tfa


def _prepare_signal(example):
  """Turn one dataset element into the tensor fed to the model.

  Adds 5 to `example['input_signal']`, swaps its first two axes, and clips the
  result to the range [0, 10].

  Args:
    example: A dataset element (dict) containing an 'input_signal' tensor.

  Returns:
    The shifted, transposed, and clipped signal tensor.
  """
  shifted = example['input_signal'] + 5
  return tf.clip_by_value(tf.transpose(shifted, [1, 0, 2]), 0, 10)


train_ds = tfds.load('lsm_prod/lsm_300min_100K_unimpute', data_dir='/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/raw/datasets/msa_1_5/lsm_tfds_datasets', split='train', shuffle_files=True)
# Validate and print the split that was just loaded, not `ds` left over from
# earlier cells (which may be undefined or a one-element dataset).
assert isinstance(train_ds, tf.data.Dataset)
print(train_ds)
test_ds = tfds.load('lsm_prod/lsm_300min_100K_unimpute', data_dir='/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/raw/datasets/msa_1_5/lsm_tfds_datasets', split='test', shuffle_files=True)

# Both splits go through exactly the same preprocessing.
train_ds = train_ds.map(_prepare_signal)
test_ds = test_ds.map(_prepare_signal)

examples = train_ds.take(1)  # Only take a single example

# Preview one preprocessed example as an image with a colour scale.
for example in examples:
  inputs = example
  imgplot = plt.imshow(inputs)
  plt.xlabel('Time (minutes)')
  plt.ylabel('Feature #')
  plt.colorbar()
  plt.show()

assert isinstance(train_ds, tf.data.Dataset)
print(train_ds)



# Create a Mirrored scope to allow for training across multiple GPUs
mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():

    from tensorflow.keras import layers
    #import tensorflow_addons as tfa
    from tensorflow import keras
    import tensorflow as tf

    from datetime import datetime
    import matplotlib.pyplot as plt
    import numpy as np
    import random

    # Seed TensorFlow, NumPy, and Python's `random` module with one value so
    # repeated runs draw the same random numbers.
    SEED = 42
    tf.random.set_seed(SEED)
    np.random.seed(SEED)
    random.seed(SEED)

    # DATA
    BUFFER_SIZE = 1024  # Not read by any live code in this cell.
    BATCH_SIZE = 32  # Batch size for both splits (alternative value noted by the author: 256).
    AUTO = tf.data.AUTOTUNE  # Not read by any live code in this cell.
    INPUT_SHAPE = (300, 26, 1)  # Only referenced from commented-out augmentation layers.
    NUM_CLASSES = 10  # Not read by any live code in this cell.

    # OPTIMIZER
    LEARNING_RATE = 5e-3  # Peak learning rate handed to the warm-up/cosine schedule.
    WEIGHT_DECAY = 1e-4  # Weight decay passed to the AdamW optimizer.

    # TRAINING
    EPOCHS = 20

    # PATCHING / MASKING
    # No resizing layer is active; these sizes describe the image extent that
    # the patch reconstruction and the decoder output assume.
    IMAGE_SIZE_1 = 25  # Extent along the first spatial axis.
    IMAGE_SIZE_2 = 300  # Extent along the second spatial axis.
    PATCH_SIZE_1 = 5  # Patch extent along the first spatial axis.
    PATCH_SIZE_2 = 10  # Patch extent along the second spatial axis.
    NO_CHANNELS = 1
    # Patches per example: (25 // 5) * (300 // 10) * 1 = 150. The decoder's
    # input layer is declared with this sequence length.
    NUM_PATCHES = (IMAGE_SIZE_1 // PATCH_SIZE_1) * (IMAGE_SIZE_2 // PATCH_SIZE_2) * NO_CHANNELS
    MASK_PROPORTION = 0.75  # Fraction of patches the patch encoder masks out.

    # ENCODER and DECODER
    LAYER_NORM_EPS = 1e-6
    ENC_PROJECTION_DIM = 32  # Encoder embedding width (alternative value noted by the author: 128).
    DEC_PROJECTION_DIM = 16  # Decoder embedding width (alternative value noted by the author: 64).
    ENC_NUM_HEADS = 4
    ENC_LAYERS = 3
    DEC_NUM_HEADS = 4
    DEC_LAYERS = 1  # Number of transformer blocks in the decoder.
    # Widths of the two dense layers in each encoder block's MLP; the last one
    # equals the embedding width so the residual addition lines up.
    ENC_TRANSFORMER_UNITS = [
        ENC_PROJECTION_DIM * 2,
        ENC_PROJECTION_DIM,
    ]
    # Same layout for the decoder blocks.
    DEC_TRANSFORMER_UNITS = [
        DEC_PROJECTION_DIM * 2,
        DEC_PROJECTION_DIM,
    ]

    def get_train_augmentation_model():
        """Build the preprocessing model applied to training batches.

        Returns:
            A `keras.Sequential` named "train_data_augmentation" whose single
            layer multiplies its input by 1/10.
        """
        rescale = layers.Rescaling(1 / 10.0)
        return keras.Sequential([rescale], name="train_data_augmentation")


    def get_test_augmentation_model():
        """Build the preprocessing model applied to evaluation batches.

        Returns:
            A `keras.Sequential` named "test_data_augmentation" whose single
            layer multiplies its input by 1/10.
        """
        rescale = layers.Rescaling(1 / 10.0)
        return keras.Sequential([rescale], name="test_data_augmentation")

    class Patches(layers.Layer):
        """Cut a batch of images into non-overlapping rectangular patches.

        Patches measure `patch_size_1` by `patch_size_2` along the two spatial
        axes and are extracted with strides equal to the patch size and
        "VALID" padding.

        Args:
            patch_size_1: Patch extent along the first spatial axis.
            patch_size_2: Patch extent along the second spatial axis.
            no_channels: Number of channels per pixel.
            img_size_1: Image extent along the first spatial axis; determines
                how many patch rows `reconstruct_from_patch` rebuilds.
            img_size_2: Image extent along the second spatial axis (stored on
                the layer; not read by the methods of this class).
            **kwargs: Forwarded to `layers.Layer`.
        """

        def __init__(self, patch_size_1=PATCH_SIZE_1, patch_size_2=PATCH_SIZE_2, no_channels=NO_CHANNELS, img_size_1=IMAGE_SIZE_1, img_size_2=IMAGE_SIZE_2, **kwargs):
            super().__init__(**kwargs)
            self.patch_size_1 = patch_size_1
            self.patch_size_2 = patch_size_2
            self.img_size_1 = img_size_1
            self.img_size_2 = img_size_2
            self.no_channels = no_channels

            # Each flattened patch holds patch_size_1 * patch_size_2 *
            # no_channels values; -1 lets the patch count be inferred.
            self.resize = layers.Reshape((-1, patch_size_1 * patch_size_2 * no_channels))

        def call(self, images):
            """Return the patches of `images` as (batch, num_patches, patch_area)."""
            patches = tf.image.extract_patches(
                images=images,
                sizes=[1, self.patch_size_1, self.patch_size_2, 1],
                strides=[1, self.patch_size_1, self.patch_size_2, 1],
                rates=[1, 1, 1, 1],
                padding="VALID",
            )

            # Collapse the patch grid into a single sequence axis.
            return self.resize(patches)

        def show_patched_image(self, images, patches):
            """Plot one randomly chosen image and, separately, its patches.

            Args:
                images: Batch of images the patches were extracted from.
                patches: Output of this layer for `images`.

            Returns:
                The batch index that was drawn, so callers can refer to the
                same example afterwards.
            """
            idx = np.random.choice(patches.shape[0])

            plt.figure(figsize=(4, 4))
            plt.imshow(keras.utils.array_to_img(images[idx]))
            plt.axis("off")
            plt.show()

            # The subplot grid mirrors the patch grid; it is the same for
            # every patch, so compute it once outside the loop.
            grid_rows = images[idx].shape[0] // self.patch_size_1
            grid_cols = images[idx].shape[1] // self.patch_size_2
            patch_shape = (self.patch_size_1, self.patch_size_2, self.no_channels)

            plt.figure(figsize=(4, 4))
            for i, patch in enumerate(patches[idx]):
                plt.subplot(grid_rows, grid_cols, i + 1)
                patch_img = tf.reshape(patch, patch_shape)
                plt.imshow(keras.utils.img_to_array(patch_img))
                plt.axis("off")
            plt.show()

            return idx

        # taken from https://stackoverflow.com/a/58082878/10319735
        def reconstruct_from_patch(self, patch):
            """Reassemble the patches of a *single* image into that image.

            Args:
                patch: Array or tensor of shape (num_patches, patch_area)
                    holding the patches of one image in extraction order.

            Returns:
                A tensor in which the patches are laid out on their original
                grid: patches within a row are joined along axis 1 and the
                rows are stacked along axis 0.
            """
            num_patches = patch.shape[0]
            patch_rows = self.img_size_1 // self.patch_size_1

            patch = tf.reshape(
                patch,
                (num_patches, self.patch_size_1, self.patch_size_2, self.no_channels),
            )
            # Split the sequence into one group per patch row, then glue each
            # group side by side before stacking the rows.
            rows = tf.split(patch, patch_rows, axis=0)
            rows = [tf.concat(tf.unstack(x), axis=1) for x in rows]
            return tf.concat(rows, axis=0)

    # Batch both splits; the validation pipeline is built from the test split.
    # NOTE: `train_ds` is rebound to the *batched* dataset from here on, so
    # its length counts batches, not individual examples.
    train_ds = train_ds.batch(BATCH_SIZE)
    val_ds = test_ds.batch(BATCH_SIZE)
    # Pull a single training batch to drive the visual checks below.
    image_batch = next(iter(train_ds))

    # Run the batch through the training preprocessing model.
    augmentation_model = get_train_augmentation_model()
    augmeneted_images = augmentation_model(image_batch)

    # Define the patch layer.
    patch_layer = Patches()

    # Get the patches from the preprocessed batch.
    patches = patch_layer(images=augmeneted_images)

    # Plot one randomly chosen image together with its patches; the method
    # returns the batch index it picked.
    random_index = patch_layer.show_patched_image(images=augmeneted_images, patches=patches)

    # Rebuild that same image from its patches and display it, as a visual
    # check that patching and reconstruction agree.
    image = patch_layer.reconstruct_from_patch(patches[random_index])
    plt.imshow(image)
    plt.axis("off")
    plt.show()

    class PatchEncoder(layers.Layer):
        """Embed patches, add positional embeddings, and apply random masking.

        Args:
            patch_size_1: Patch extent along the first spatial axis.
            patch_size_2: Patch extent along the second spatial axis.
            no_channels: Number of channels per pixel.
            projection_dim: Width of the patch and position embeddings.
            mask_proportion: Fraction of patches to mask per example.
            downstream: If True, `call` skips masking and returns the
                embeddings of all patches.
            **kwargs: Forwarded to `layers.Layer`.
        """

        def __init__(
            self,
            patch_size_1=PATCH_SIZE_1,
            patch_size_2=PATCH_SIZE_2,
            no_channels=NO_CHANNELS,
            projection_dim=ENC_PROJECTION_DIM,
            mask_proportion=MASK_PROPORTION,
            downstream=False,
            **kwargs,
        ):
            super().__init__(**kwargs)
            self.patch_size_1 = patch_size_1
            self.patch_size_2 = patch_size_2
            self.no_channels = no_channels
            self.projection_dim = projection_dim
            self.mask_proportion = mask_proportion
            self.downstream = downstream

            # Trainable stand-in for a masked patch, drawn from a normal
            # distribution; it has the size of one flattened patch.
            self.mask_token = tf.Variable(
                tf.random.normal([1, patch_size_1 * patch_size_2 * no_channels]), trainable=True
            )

        def build(self, input_shape):
            """Create the sub-layers once the (batch, patches, area) shape is known."""
            (_, self.num_patches, self.patch_area) = input_shape

            # Linear projection of each flattened patch.
            self.projection = layers.Dense(units=self.projection_dim)

            # One learned embedding per patch position.
            self.position_embedding = layers.Embedding(
                input_dim=self.num_patches, output_dim=self.projection_dim
            )

            # Number of patches masked in every example.
            self.num_mask = int(self.mask_proportion * self.num_patches)

        def call(self, patches):
            """Embed `patches` and, unless `downstream`, split them by a random mask.

            Args:
                patches: Tensor of shape (batch, num_patches, patch_area).

            Returns:
                If `downstream` is True, the embeddings of all patches with
                shape (batch, num_patches, projection_dim). Otherwise a tuple
                `(unmasked_embeddings, masked_embeddings, unmasked_positions,
                mask_indices, unmask_indices)`.
            """
            # Positional embeddings, repeated for every example in the batch.
            batch_size = tf.shape(patches)[0]
            positions = tf.range(start=0, limit=self.num_patches, delta=1)
            pos_embeddings = self.position_embedding(positions[tf.newaxis, ...])
            pos_embeddings = tf.tile(
                pos_embeddings, [batch_size, 1, 1]
            )  # (B, num_patches, projection_dim)

            # Embed the patches.
            patch_embeddings = (
                self.projection(patches) + pos_embeddings
            )  # (B, num_patches, projection_dim)

            if self.downstream:
                return patch_embeddings

            mask_indices, unmask_indices = self.get_random_indices(batch_size)

            # The encoder only sees the embeddings of the unmasked patches.
            unmasked_embeddings = tf.gather(
                patch_embeddings, unmask_indices, axis=1, batch_dims=1
            )  # (B, unmask_numbers, projection_dim)

            # Position embeddings of both groups are needed for the decoder.
            unmasked_positions = tf.gather(
                pos_embeddings, unmask_indices, axis=1, batch_dims=1
            )  # (B, unmask_numbers, projection_dim)
            masked_positions = tf.gather(
                pos_embeddings, mask_indices, axis=1, batch_dims=1
            )  # (B, mask_numbers, projection_dim)

            # Every masked slot gets a copy of the mask token: repeat it once
            # per masked patch, then once per example.
            mask_tokens = tf.repeat(self.mask_token, repeats=self.num_mask, axis=0)
            mask_tokens = tf.repeat(
                mask_tokens[tf.newaxis, ...], repeats=batch_size, axis=0
            )

            # Project the mask tokens like real patches and add their positions.
            masked_embeddings = self.projection(mask_tokens) + masked_positions
            return (
                unmasked_embeddings,  # input to the encoder
                masked_embeddings,  # first part of input to the decoder
                unmasked_positions,  # added to the encoder outputs
                mask_indices,  # the indices that were masked
                unmask_indices,  # the indices that were unmasked
            )

        def get_random_indices(self, batch_size):
            """Draw an independent random mask/unmask split for each example.

            Returns:
                `(mask_indices, unmask_indices)`: the first `num_mask` and the
                remaining entries of a random permutation of the patch
                positions, per example.
            """
            # Sorting uniform noise yields a random permutation per row.
            rand_indices = tf.argsort(
                tf.random.uniform(shape=(batch_size, self.num_patches)), axis=-1
            )
            mask_indices = rand_indices[:, : self.num_mask]
            unmask_indices = rand_indices[:, self.num_mask :]

            return mask_indices, unmask_indices

        def show_masked_image(self, patches, unmask_indices):
            """Zero out the masked patches of one randomly chosen example.

            Args:
                patches: Batch of patches, (batch, num_patches, patch_area).
                unmask_indices: Per-example indices of the unmasked patches.

            Returns:
                `(new_patch, idx)`: a NumPy array shaped like one example's
                patches in which only the unmasked patches are kept (the rest
                are zero), and the batch index that was drawn.
            """
            # Choose a random example and its unmasked patch indices.
            idx = np.random.choice(patches.shape[0])
            patch = np.asarray(patches[idx])
            unmask_index = np.asarray(unmask_indices[idx])

            # Start from zeros and copy the unmasked patches across in one
            # indexed assignment.
            new_patch = np.zeros_like(patch)
            new_patch[unmask_index] = patch[unmask_index]
            return new_patch, idx

    # Create the patch encoder layer with its default settings.
    patch_encoder = PatchEncoder()

    # Run the demo patches through it. With the default `downstream=False`
    # the layer returns the five-element tuple unpacked here.
    (
        unmasked_embeddings,
        masked_embeddings,
        unmasked_positions,
        mask_indices,
        unmask_indices,
    ) = patch_encoder(patches=patches)

    # Pick a random example and zero out its masked patches.
    # NOTE: this rebinds `random_index` to the newly drawn example.
    new_patch, random_index = patch_encoder.show_masked_image(patches, unmask_indices)

    # Show the masked reconstruction (left) next to the preprocessed
    # original of the same example (right).
    plt.figure(figsize=(10, 10))
    plt.subplot(1, 2, 1)
    img = patch_layer.reconstruct_from_patch(new_patch)
    plt.imshow(keras.utils.array_to_img(img))
    plt.axis("off")
    plt.title("Masked")
    plt.subplot(1, 2, 2)
    img = augmeneted_images[random_index]
    plt.imshow(keras.utils.array_to_img(img))
    plt.axis("off")
    plt.title("Original")
    plt.show()

    def mlp(x, dropout_rate, hidden_units):
        """Stack Dense(GELU) + Dropout pairs on top of `x`.

        Args:
            x: Input tensor.
            dropout_rate: Rate used by every dropout layer.
            hidden_units: Iterable with the width of each dense layer, in order.

        Returns:
            The output of the last dropout layer, or `x` itself when
            `hidden_units` is empty.
        """
        out = x
        for width in hidden_units:
            hidden = layers.Dense(width, activation=tf.nn.gelu)(out)
            out = layers.Dropout(dropout_rate)(hidden)
        return out

    def create_encoder(num_heads=ENC_NUM_HEADS, num_layers=ENC_LAYERS):
        """Build the transformer encoder as a functional Keras model.

        Args:
            num_heads: Attention heads per block.
            num_layers: Number of transformer blocks.

        Returns:
            A `keras.Model` named "mae_encoder" that maps sequences of
            `ENC_PROJECTION_DIM`-wide embeddings (any length) to sequences of
            the same width.
        """

        def layer_norm(tensor):
            # Fresh normalization layer per call site, as each use has its own weights.
            return layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(tensor)

        inputs = layers.Input((None, ENC_PROJECTION_DIM))
        hidden = inputs

        for _ in range(num_layers):
            # Pre-norm self-attention with a residual connection.
            normed = layer_norm(hidden)
            attended = layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=ENC_PROJECTION_DIM, dropout=0.1
            )(normed, normed)
            after_attention = layers.Add()([attended, hidden])

            # Pre-norm MLP with a second residual connection.
            feed_forward = mlp(
                layer_norm(after_attention),
                hidden_units=ENC_TRANSFORMER_UNITS,
                dropout_rate=0.1,
            )
            hidden = layers.Add()([feed_forward, after_attention])

        return keras.Model(inputs, layer_norm(hidden), name="mae_encoder")

    def create_decoder(
        num_layers=DEC_LAYERS, num_heads=DEC_NUM_HEADS, image_size_1=IMAGE_SIZE_1, image_size_2=IMAGE_SIZE_2, no_channels=NO_CHANNELS
    ):
        """Build the transformer decoder as a functional Keras model.

        Args:
            num_layers: Number of transformer blocks.
            num_heads: Attention heads per block.
            image_size_1: Output extent along the first spatial axis.
            image_size_2: Output extent along the second spatial axis.
            no_channels: Number of output channels.

        Returns:
            A `keras.Model` named "mae_decoder" that takes `NUM_PATCHES`
            embeddings of width `ENC_PROJECTION_DIM` and produces a
            sigmoid-activated tensor of shape
            (image_size_1, image_size_2, no_channels) per example.
        """

        def layer_norm(tensor):
            # Fresh normalization layer per call site, as each use has its own weights.
            return layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(tensor)

        inputs = layers.Input((NUM_PATCHES, ENC_PROJECTION_DIM))
        # Project encoder-width embeddings down to the decoder width.
        hidden = layers.Dense(DEC_PROJECTION_DIM)(inputs)

        for _ in range(num_layers):
            # Pre-norm self-attention with a residual connection.
            normed = layer_norm(hidden)
            attended = layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=DEC_PROJECTION_DIM, dropout=0.1
            )(normed, normed)
            after_attention = layers.Add()([attended, hidden])

            # Pre-norm MLP with a second residual connection.
            feed_forward = mlp(
                layer_norm(after_attention),
                hidden_units=DEC_TRANSFORMER_UNITS,
                dropout_rate=0.1,
            )
            hidden = layers.Add()([feed_forward, after_attention])

        # Flatten the whole sequence and predict every output value at once.
        flat = layers.Flatten()(layer_norm(hidden))
        pixel_count = image_size_1 * image_size_2 * no_channels
        pre_final = layers.Dense(units=pixel_count, activation="sigmoid")(flat)
        outputs = layers.Reshape((image_size_1, image_size_2, no_channels))(pre_final)

        return keras.Model(inputs, outputs, name="mae_decoder")

    class MaskedAutoencoder(keras.Model):
        """Masked autoencoder that wires preprocessing, patching, and both transformers.

        Args:
            train_augmentation_model: Preprocessing applied during training.
            test_augmentation_model: Preprocessing applied during evaluation.
            patch_layer: Layer that turns images into patches.
            patch_encoder: Layer that embeds and masks the patches.
            encoder: Model applied to the unmasked patch embeddings.
            decoder: Model that reconstructs the full image.
            **kwargs: Forwarded to `keras.Model`.
        """

        def __init__(
            self,
            train_augmentation_model,
            test_augmentation_model,
            patch_layer,
            patch_encoder,
            encoder,
            decoder,
            **kwargs
        ):
            super().__init__(**kwargs)
            self.train_augmentation_model = train_augmentation_model
            self.test_augmentation_model = test_augmentation_model
            self.patch_layer = patch_layer
            self.patch_encoder = patch_encoder
            self.encoder = encoder
            self.decoder = decoder

        def calculate_loss(self, images, test=False):
            """Run one forward pass and score the reconstruction of masked patches.

            Args:
                images: Batch of input images.
                test: Selects the test preprocessing model when True.

            Returns:
                `(total_loss, loss_patch, loss_output)`: the compiled loss and
                the target / predicted values of the masked patches.
            """
            preprocess = (
                self.test_augmentation_model if test else self.train_augmentation_model
            )
            patches = self.patch_layer(preprocess(images))

            # Embed the patches and split them by the random mask.
            (
                visible_embeddings,
                masked_embeddings,
                visible_positions,
                mask_indices,
                _,
            ) = self.patch_encoder(patches)

            # Encode the visible patches, re-add their positions, and append
            # the mask-token embeddings to form the decoder input.
            encoded = self.encoder(visible_embeddings) + visible_positions
            decoder_inputs = tf.concat([encoded, masked_embeddings], axis=1)

            # Reconstruct the image and cut it into patches again so it can be
            # compared patch by patch.
            reconstructed_patches = self.patch_layer(self.decoder(decoder_inputs))

            # Only the masked positions contribute to the loss.
            loss_patch = tf.gather(patches, mask_indices, axis=1, batch_dims=1)
            loss_output = tf.gather(
                reconstructed_patches, mask_indices, axis=1, batch_dims=1
            )

            total_loss = self.compiled_loss(loss_patch, loss_output)
            return total_loss, loss_patch, loss_output

        def train_step(self, images):
            """Apply one optimizer update and return the current metric values."""
            with tf.GradientTape() as tape:
                total_loss, loss_patch, loss_output = self.calculate_loss(images)

            # Gather the trainable variables of every component into one flat
            # list, in component order.
            components = (
                self.train_augmentation_model,
                self.patch_layer,
                self.patch_encoder,
                self.encoder,
                self.decoder,
            )
            train_vars = [
                var for component in components for var in component.trainable_variables
            ]
            grads = tape.gradient(total_loss, train_vars)
            self.optimizer.apply_gradients(list(zip(grads, train_vars)))

            # Report progress.
            self.compiled_metrics.update_state(loss_patch, loss_output)
            return {metric.name: metric.result() for metric in self.metrics}

        def test_step(self, images):
            """Evaluate one batch and return the current metric values."""
            _, loss_patch, loss_output = self.calculate_loss(images, test=True)

            self.compiled_metrics.update_state(loss_patch, loss_output)
            return {metric.name: metric.result() for metric in self.metrics}

    # Reset Keras' global state before building the model that is actually
    # trained; the layers created for the visual checks above are replaced
    # by the fresh instances below.
    keras.backend.clear_session()

    # Fresh instances of every component (note: `patch_layer` and
    # `patch_encoder` are rebound here).
    train_augmentation_model = get_train_augmentation_model()
    test_augmentation_model = get_test_augmentation_model()
    patch_layer = Patches()
    patch_encoder = PatchEncoder()
    encoder = create_encoder()
    decoder = create_decoder()

    # Assemble the components into the trainable masked autoencoder.
    mae_model = MaskedAutoencoder(
        train_augmentation_model=train_augmentation_model,
        test_augmentation_model=test_augmentation_model,
        patch_layer=patch_layer,
        patch_encoder=patch_encoder,
        encoder=encoder,
        decoder=decoder,
    )

    # One fixed validation batch, read by the TrainMonitor callback to
    # visualize the model's progress.
    test_images = next(iter(val_ds))

    class TrainMonitor(tf.keras.callbacks.Callback):
        """Plot original, masked, and reconstructed images during training.

        The callback reads the module-level `test_images` batch and runs it
        through the components of the model it is attached to.

        Args:
            epoch_interval: Plot on epochs divisible by this value. `None`
                (or 0) disables plotting.
        """

        def __init__(self, epoch_interval=None):
            # Let the base class set up the attributes Keras expects on every
            # callback before adding our own state.
            super().__init__()
            self.epoch_interval = epoch_interval

        def on_epoch_end(self, epoch, logs=None):
            """Show a three-panel progress figure on the selected epochs."""
            if not self.epoch_interval or epoch % self.epoch_interval != 0:
                return

            # Reproduce the model's forward pass on the fixed test batch.
            test_augmeneted_images = self.model.test_augmentation_model(test_images)
            test_patches = self.model.patch_layer(test_augmeneted_images)
            (
                test_unmasked_embeddings,
                test_masked_embeddings,
                test_unmasked_positions,
                test_mask_indices,
                test_unmask_indices,
            ) = self.model.patch_encoder(test_patches)
            test_encoder_outputs = self.model.encoder(test_unmasked_embeddings)
            test_encoder_outputs = test_encoder_outputs + test_unmasked_positions
            test_decoder_inputs = tf.concat(
                [test_encoder_outputs, test_masked_embeddings], axis=1
            )
            test_decoder_outputs = self.model.decoder(test_decoder_inputs)

            # Pick a random example and zero out its masked patches.
            test_masked_patch, idx = self.model.patch_encoder.show_masked_image(
                test_patches, test_unmask_indices
            )
            print(f"\nIdx chosen: {idx}")
            original_image = test_augmeneted_images[idx]
            masked_image = self.model.patch_layer.reconstruct_from_patch(
                test_masked_patch
            )
            reconstructed_image = test_decoder_outputs[idx]

            fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
            ax[0].imshow(original_image)
            ax[0].set_title(f"Original: {epoch:03d}")

            ax[1].imshow(masked_image)
            ax[1].set_title(f"Masked: {epoch:03d}")

            ax[2].imshow(reconstructed_image)
            ax[2].set_title(f"Resonstructed: {epoch:03d}")

            plt.show()
            plt.close()


    class WarmUpCosine(keras.optimizers.schedules.LearningRateSchedule):
        """Linear warm-up followed by cosine decay of the learning rate.

        The rate rises linearly from `warmup_learning_rate` to
        `learning_rate_base` over `warmup_steps`, follows half a cosine period
        down to zero at `total_steps`, and is zero afterwards.

        Args:
            learning_rate_base: Peak learning rate reached after warm-up.
            total_steps: Step at which the cosine decay reaches zero.
            warmup_learning_rate: Learning rate at step 0.
            warmup_steps: Number of warm-up steps (0 disables warm-up).
        """

        def __init__(
            self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps
        ):
            super().__init__()

            self.learning_rate_base = learning_rate_base
            self.total_steps = total_steps
            self.warmup_learning_rate = warmup_learning_rate
            self.warmup_steps = warmup_steps
            self.pi = tf.constant(np.pi)

        def __call__(self, step):
            """Return the learning rate for `step`.

            Raises:
                ValueError: If `total_steps < warmup_steps`, or if warm-up is
                    enabled and `learning_rate_base < warmup_learning_rate`.
            """
            if self.total_steps < self.warmup_steps:
                raise ValueError("Total_steps must be larger or equal to warmup_steps.")

            step_f = tf.cast(step, tf.float32)

            # Guard the denominator: when total_steps == warmup_steps the decay
            # phase has zero length and dividing by it would yield a NaN rate.
            decay_steps = max(self.total_steps - self.warmup_steps, 1)
            cos_annealed_lr = tf.cos(
                self.pi * (step_f - self.warmup_steps) / float(decay_steps)
            )
            learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)

            if self.warmup_steps > 0:
                if self.learning_rate_base < self.warmup_learning_rate:
                    raise ValueError(
                        "Learning_rate_base must be larger or equal to "
                        "warmup_learning_rate."
                    )
                # Straight line from the warm-up rate to the peak rate.
                slope = (
                    self.learning_rate_base - self.warmup_learning_rate
                ) / self.warmup_steps
                warmup_rate = slope * step_f + self.warmup_learning_rate
                learning_rate = tf.where(
                    step < self.warmup_steps, warmup_rate, learning_rate
                )

            # Past the end of the schedule the rate is pinned to zero.
            return tf.where(
                step > self.total_steps, 0.0, learning_rate, name="learning_rate"
            )

        def get_config(self):
            """Return the constructor arguments so the schedule can be serialized."""
            return {
                "learning_rate_base": self.learning_rate_base,
                "total_steps": self.total_steps,
                "warmup_learning_rate": self.warmup_learning_rate,
                "warmup_steps": self.warmup_steps,
            }

    # `train_ds` has already been batched, so its length is the number of
    # optimizer steps per epoch. Dividing by BATCH_SIZE again would shrink the
    # schedule by that factor and leave the learning rate at zero for most of
    # training.
    steps_per_epoch = len(train_ds)
    total_steps = int(steps_per_epoch * EPOCHS)
    warmup_steps = int(total_steps * 0.15)  # Warm up over the first 15% of steps.
    scheduled_lrs = WarmUpCosine(
        learning_rate_base=LEARNING_RATE,
        total_steps=total_steps,
        warmup_learning_rate=0.0,
        warmup_steps=warmup_steps,
    )

    # Evaluate the schedule for all steps in one vectorized call (the schedule
    # is built from element-wise ops) instead of one eager call per step.
    lrs = scheduled_lrs(tf.range(total_steps)).numpy()
    plt.plot(lrs)
    plt.xlabel("Step", fontsize=14)
    plt.ylabel("LR", fontsize=14)
    plt.show()

    from datetime import timezone

    # UTC timestamp used to give every run its own TensorBoard directory.
    # `datetime.utcnow()` is deprecated and returns a naive value; the
    # timezone-aware call below formats to the same string.
    timestamp = datetime.now(timezone.utc).strftime("%y%m%d-%H%M%S")

    train_callbacks = [
        keras.callbacks.TensorBoard(log_dir=f"mae_logs_{timestamp}"),
        # Plot reconstructions after every epoch.
        TrainMonitor(epoch_interval=1),
    ]

    # AdamW driven by the warm-up/cosine schedule.
    optimizer = tfa.optimizers.AdamW(learning_rate=scheduled_lrs, weight_decay=WEIGHT_DECAY)

    # Mean squared error is the training loss; mean absolute error is tracked
    # as an additional metric.
    mae_model.compile(
        optimizer=optimizer, loss=keras.losses.MeanSquaredError(), metrics=["mae"]
    )

    history = mae_model.fit(
        train_ds, epochs=EPOCHS, validation_data=val_ds, callbacks=train_callbacks,
    )

    # Final evaluation on the validation pipeline.
    loss, mae = mae_model.evaluate(val_ds)
    print(f"Loss: {loss:.2f}")
    print(f"MAE: {mae:.2f}")

# Move Dataset

In [None]:
import os
from google3.pyglib import gfile
# Make sure the experiment output directory exists before anything is
# written beneath it.
gfile.MakeDirs("/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/exp")

In [None]:
# Source: the TFDS build directory of the dataset.
read_dir = "/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/raw/datasets/msa_1_5/lsm_tfds_datasets/lsm_prod/lsm_300min_100K_unimpute"

# Destination for the copy. The original cell used `write_dir` without ever
# defining it, which fails with a NameError.
# NOTE(review): this path is an assumption (same dataset name under the
# `deid/exp` directory created above) - confirm the intended destination
# before running.
write_dir = "/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/exp/lsm_300min_100K_unimpute"

from google3.pyglib.contrib.gfile_util import gfile_util
gfile_util.CopyDir(read_dir, write_dir)

In [None]:
from colabtools import adhoc_import
import matplotlib.pyplot as plt
import numpy as np
from google3.pyglib import gfile
import tensorflow as tf
import tensorflow_datasets as tfds

# No `split` argument is given here; the result is indexed by split name
# below, so `ds` holds every available split.
ds = tfds.load('lsm_prod/lsm_300min_100K_unimpute', data_dir='/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/raw/datasets/msa_1_5/lsm_tfds_datasets')
print(ds)

# Write only the 'train' split to the experiment directory.
ds['train'].save('/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/exp/lsm_300min_100K_unimpute')