# PAT Explainability: fill out the parameters below to select the model you want to examine

WARNING: For now, you have to use a TPU v2 to load the model.

In [None]:
# Dependencies
import tensorflow as tf
import numpy as np
import os
from tensorflow.keras import layers, models
from sklearn.preprocessing import StandardScaler


# Choose the model you want and load it

In [None]:
"""
Please Fill out Parameters Below
"""
## Model size
# eg. ["small", "medium", "large", "huge"]
size = "large"

## Mask ratio
# eg. [.25, .50, .75]
mask_ratio = 0.90

## Smoothing
# eg. [True, False]
smoothing = False

## Loss Function
# eg. [True, False], meaning MSE on only the masked portion or everything in the reconstruction
mse_only_masked = False

## dataset
# eg. ["n100", "n250", etc]
dataset = "n5769"

## Finetuning Style
# eg. ["full", "linear_probe"]
finetuning_style = "full"

## Where is this model located
model_root = "/content/drive/MyDrive/Extra Curricular /ActigraphyTransformer/A-NEW/PAT Experiments /PAT Finetuning/Models"

## Where the original encoder is from
encoder_root = "/content/drive/MyDrive/Extra Curricular /ActigraphyTransformer/A-NEW/PAT Experiments /PAT Pretraining/Encoders"

In [None]:
# Encoder naming
# Build the pretrained-encoder file name, e.g. "/encoder_large_90_unsmoothed_mse_all.h5".
# The leading "/" matters: the name is concatenated directly onto encoder_root.
# round() instead of int(): most ratios are not exact in binary floating point
# (0.29 * 100 == 28.999999999999996), so int() would truncate to the wrong
# percentage and point at a file that does not exist.
mask_name = round(mask_ratio * 100)

encoder_name = f"/encoder_{size}_{mask_name}"

if smoothing == True:
  encoder_name = f"{encoder_name}_smoothed"
else:
  encoder_name = f"{encoder_name}_unsmoothed"

if mse_only_masked == True:
  encoder_name = f"{encoder_name}_mse_only_masked.h5"
else:
  encoder_name = f"{encoder_name}_mse_all.h5"

print(encoder_name)

In [None]:
# FT model name
# Build the fine-tuned model file name, e.g. "/AcT_large_90_unsmoothed_mse_all_n5769_full.h5".
# The leading "/" matters: the name is concatenated directly onto model_root.
# round() instead of int(): 0.29 * 100 == 28.999999999999996, which int() would
# truncate to 28 and so name the wrong file.
mask_name = round(mask_ratio * 100)

# Start of finetuning name
ft_name = f"/AcT_{size}_{mask_name}"

if smoothing == True:
  ft_name = f"{ft_name}_smoothed"
else:
  ft_name = f"{ft_name}_unsmoothed"

if mse_only_masked == True:
  ft_name = f"{ft_name}_mse_only_masked"
else:
  ft_name = f"{ft_name}_mse_all"

ft_name = f"{ft_name}_{dataset}_{finetuning_style}.h5"

print(ft_name)

# Additional hyperparameter info

In [None]:
"""
Model Size
"""
## Model Size
if size == "small":

  patch_size = 18
  embed_dim = 96
  # encoder
  encoder_num_heads = 6
  encoder_ff_dim = 256
  encoder_num_layers = 1
  encoder_rate = 0.1
  # decoder
  decoder_num_heads = 6
  decoder_ff_dim = 256
  decoder_num_layers = 1
  decoder_rate = 0.1

if size == "medium":

  patch_size = 18
  embed_dim = 96
  # encoder
  encoder_num_heads = 12
  encoder_ff_dim = 256
  encoder_num_layers = 2
  encoder_rate = 0.1
  # decoder
  decoder_num_heads = 12
  decoder_ff_dim = 256
  decoder_num_layers = 1
  decoder_rate = 0.1

if size == "large":

  patch_size = 9
  embed_dim = 96
  # encoder
  encoder_num_heads = 12
  encoder_ff_dim = 256
  encoder_num_layers = 4
  encoder_rate = 0.1
  # decoder
  decoder_num_heads = 12
  decoder_ff_dim = 256
  decoder_num_layers = 1
  decoder_rate = 0.1

if size == "huge":

  patch_size = 5
  embed_dim = 96
  # encoder
  encoder_num_heads = 12
  encoder_ff_dim = 256
  encoder_num_layers = 8
  encoder_rate = 0.1
  # decoder
  decoder_num_heads = 12
  decoder_ff_dim = 256
  decoder_num_layers = 1
  decoder_rate = 0.1

In [None]:
"""
Smoothing
"""
if smoothing == True:
  data_folder_location = "/content/drive/MyDrive/Extra Curricular /ActigraphyTransformer/A-NEW/Baseline Tests/Data_2013/All_Meds/Smooth/TestSize2000_set1"

else:
  data_folder_location = "/content/drive/MyDrive/Extra Curricular /ActigraphyTransformer/A-NEW/Baseline Tests/Data_2013/All_Meds/Raw/TestSize2000_set1"

In [None]:
# This cell defines the helpers needed to load the saved encoder (TransformerBlock,
# get_positional_embeddings) and create_model, which builds the fine-tuning model on top of it.
# TransformerBlock: a Transformer block that also returns its attention scores, with every
# layer explicitly named. NOTE(review): the original comment ended "(otherwise the same as
# the )"; presumably it matches the block used during pretraining — confirm.
# In this notebook it is only referenced through load_model's custom_objects.
def TransformerBlock(embed_dim, num_heads, ff_dim, rate=0.1, name_prefix="encoder"):
    """Build one post-norm Transformer block as a standalone Keras Model.

    Args:
        embed_dim: Token embedding width; also passed as ``key_dim`` to every head.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward sub-layer.
        rate: Dropout rate applied after the attention and feed-forward sub-layers.
        name_prefix: Prefix for every layer name and for the model name
            ``f"{name_prefix}_transformer"``.

    Returns:
        A Model taking ``(batch, seq_len, embed_dim)`` inputs and returning
        ``[block_output, attention_scores]``.
    """
    input_layer = layers.Input(shape=(None, embed_dim), name=f"{name_prefix}_input")
    # key_dim is the full embed_dim per head (not embed_dim // num_heads); presumably it
    # must stay this way to match the saved weights.
    attention_layer = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, name=f"{name_prefix}_attention")
    # Self-attention: the same tensor is used as query and value.
    attention_output, attention_weights = attention_layer(input_layer, input_layer, return_attention_scores=True)
    attention_output = layers.Dropout(rate, name=f"{name_prefix}_dropout")(attention_output)
    # Residual connection followed by LayerNorm (post-norm).
    out1 = layers.LayerNormalization(epsilon=1e-6, name=f"{name_prefix}_norm1")(input_layer + attention_output)
    ff_output = layers.Dense(ff_dim, activation="relu", name=f"{name_prefix}_ff1")(out1)
    ff_output = layers.Dense(embed_dim, name=f"{name_prefix}_ff2")(ff_output)
    ff_output = layers.Dropout(rate, name=f"{name_prefix}_dropout2")(ff_output)
    # Second residual connection + LayerNorm.
    final_output = layers.LayerNormalization(epsilon=1e-6, name=f"{name_prefix}_norm2")(out1 + ff_output)
    return models.Model(inputs=input_layer, outputs=[final_output, attention_weights], name=f"{name_prefix}_transformer")

# Sine/Cosine positional embeddings
def get_positional_embeddings(num_patches, embed_dim):
    """Return fixed sinusoidal position embeddings for ``num_patches`` positions.

    All sine channels come first and all cosine channels second (not interleaved).
    The result has ``2 * ceil(embed_dim / 2)`` columns, i.e. ``embed_dim`` when it is
    even and ``embed_dim + 1`` when it is odd.
    """
    positions = tf.expand_dims(tf.range(num_patches, dtype=tf.float32), axis=-1)  # (num_patches, 1)
    log_scale = -tf.math.log(10000.0) / embed_dim
    frequencies = tf.exp(tf.range(0, embed_dim, 2, dtype=tf.float32) * log_scale)
    angles = positions * frequencies
    return tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)

def create_model(encoder_path=None, input_size=10080, patch_size=patch_size, embed_dim=embed_dim, return_attention=False):
    """Build the binary fine-tuning classifier on top of a saved, pretrained encoder.

    Args:
        encoder_path: Path to the saved encoder file. Defaults to
            ``encoder_root + encoder_name``, looked up when the function is called.
        input_size: Length of the flat per-participant input sequence.
        patch_size: Accepted for backward compatibility; not used here (presumably
            patching happens inside the loaded encoder).
        embed_dim: Accepted for backward compatibility; not used here.
        return_attention: If True, the model outputs
            ``[prediction, *attention_weights]`` instead of only the prediction.

    Returns:
        An uncompiled Keras Model named "finetuning_model" that ends in a single
        sigmoid unit.
    """
    # Resolve the default here rather than in the signature. A signature default is
    # evaluated once, when this cell runs, so re-running the configuration cells
    # afterwards would silently keep loading the previously selected encoder.
    if encoder_path is None:
        encoder_path = encoder_root + encoder_name

    # Load the saved encoder model
    encoder_model = tf.keras.models.load_model(encoder_path, custom_objects={'TransformerBlock': TransformerBlock, 'get_positional_embeddings': get_positional_embeddings})

    # Define new inputs for the fine-tuning model
    inputs = layers.Input(shape=(input_size,), name="finetuning_inputs")

    # Assumes the encoder returns [sequence_output, *attention_weights] — TODO confirm
    # against the saved encoder.
    encoder_outputs = encoder_model(inputs)
    sequence_output = encoder_outputs[0]
    # list() so the concatenation below also works if the outputs come back as a tuple.
    attention_weights = list(encoder_outputs[1:])

    # Classification head: pool over the sequence, then a small MLP.
    x = layers.GlobalAveragePooling1D(name="global_avg_pool")(sequence_output)
    x = layers.Dropout(0.1, name="dropout")(x)
    x = layers.Dense(128, activation='relu', name="dense_128")(x)
    outputs = layers.Dense(1, activation="sigmoid", name="output")(x)

    # Include attention weights in the final model outputs if requested
    if return_attention:
        outputs = [outputs] + attention_weights

    # Create and return the fine-tuning model
    finetuning_model = models.Model(inputs=inputs, outputs=outputs, name="finetuning_model")
    return finetuning_model

## Once you have chosen your model, this is all you need to do to load it

In [None]:
# Build the classifier on the pretrained encoder. return_attention=True adds the
# encoder's attention tensors as extra model outputs, which the plotting cells unpack.
eval_model = create_model(return_attention=True)
eval_model.summary()

In [None]:
# Load the fine-tuned weights. Plain concatenation is deliberate: ft_name starts with "/",
# and os.path.join would treat it as absolute and drop model_root.
eval_model.load_weights(model_root+ft_name)

# Load Data

In [None]:
test_size = 2000  # fixed
# Load the held-out test split (features and labels) for the selected smoothing setting.
X_test = np.load(os.path.join(data_folder_location, f'X_test_{test_size}.npy'))
y_test = np.load(os.path.join(data_folder_location, f'y_test_{test_size}.npy'))

# Standardize each feature column using statistics of the test set itself. The fitted
# scaler is kept so the plotting cells can undo the scaling with inverse_transform.
# NOTE(review): if fine-tuning used training-set statistics, these inputs differ slightly
# from what the model saw — confirm against the fine-tuning notebook.
scaler = StandardScaler()
X_test = scaler.fit_transform(X_test)

In [None]:
# Quick sanity check: feature array first, then labels.
for loaded_array in (X_test, y_test):
    print(loaded_array.shape)

# Plot Attention Weights

In [None]:
# @title Graphing Functions
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

##### SQUARE PLOTTER -----------

def plot_average_attention_weights(attention_weights, layer_num):
    """Show one layer's head-averaged attention matrix for the first sample in the batch.

    Args:
        attention_weights: Per-layer attention arrays indexed by ``layer_num``; each is
            assumed to be (batch, num_heads, seq_len, seq_len).
        layer_num: Zero-based layer index (shown one-based in the title).
    """
    head_mean = np.mean(attention_weights[layer_num], axis=1)  # average over heads
    first_sample = head_mean[0]

    fig, ax = plt.subplots(figsize=(7, 7))
    image = ax.matshow(first_sample, cmap='bwr')
    # A thin colorbar placed close to the matrix.
    fig.colorbar(image, fraction=0.046, pad=0.04)

    ax.set_title(f'Layer {layer_num+1} Mean Attention Weights Across All Heads')
    ax.set_xlabel('Key')
    ax.set_ylabel('Query')
    plt.show()

#@title Function 2
def process_attention_weights(attention_weights, layer_num, patch_size, sample_idx=0):
    """Turn one layer's attention matrix into a per-timestep attention profile.

    For the chosen sample, attention is summed over the query axis (how much
    attention each key patch *receives*) and averaged over heads. The result is
    normalized so the per-patch values sum to 1. Each patch value is then repeated
    ``patch_size`` times so the profile lines up with the per-timestep input series.

    Args:
        attention_weights: Sequence of per-layer arrays, each assumed shaped
            (batch, num_heads, num_queries, num_keys).
        layer_num: Index of the layer to use.
        patch_size: Number of timesteps covered by each patch.
        sample_idx: Which sample in the batch to use (default 0, the first sample).

    Returns:
        1-D array of length ``num_keys * patch_size``.
    """
    sample_weights = np.asarray(attention_weights[layer_num][sample_idx])  # (heads, queries, keys)

    # axis=1 is the *query* axis, so this is the attention each key receives. Summing
    # over keys instead would be uninformative: each query's row of softmax-normalized
    # attention scores already sums to 1.
    received = np.sum(sample_weights, axis=1)  # (heads, keys)

    # Average across heads, then normalize the per-patch profile to sum to 1.
    profile = np.mean(received, axis=0)
    profile = profile / np.sum(profile)

    # Expand each patch's weight to every timestep the patch covers.
    return np.repeat(profile, patch_size)


def plot_data_with_attention_overlay_7d(original_data, attention_weights, layer_num, patch_size):
    """Plot a week of activity with its attention profile as a colored background.

    Args:
        original_data: Minute-level series for one participant; the x-axis is laid out
            for 7 days x 1440 minutes.
        attention_weights: Per-layer attention arrays, as accepted by
            ``process_attention_weights``.
        layer_num: Index of the layer whose attention is overlaid.
        patch_size: Timesteps per patch, used to expand patch attention to timesteps.
    """
    # Process the attention weights
    processed_attention_profile = process_attention_weights(attention_weights, layer_num, patch_size)

    # Create a colormap based on the processed attention weights
    cmap = plt.cm.seismic
    norm = mcolors.Normalize(vmin=processed_attention_profile.min(), vmax=processed_attention_profile.max())

    fig, ax1 = plt.subplots(figsize=(12, 4))

    # Plot the original data. The y-axis label and ticks use the same color as the trace
    # (black), so they do not suggest a blue series that is not on the plot.
    ax1.plot(original_data, label='Original Data', color='black')
    ax1.set_xlabel('Time (minutes)')
    ax1.set_ylabel('Original Data Value', color='black')
    ax1.tick_params(axis='y', labelcolor='black')

    # Set up x-axis for daily ticks and labels
    days = ['Day 1', 'Day 2', 'Day 3', 'Day 4', 'Day 5', 'Day 6', 'Day 7']
    minutes_per_day = 1440
    week_minutes = minutes_per_day * len(days)
    day_starts = np.arange(0, week_minutes, minutes_per_day)
    ax1.set_xticks(day_starts)
    ax1.set_xticklabels(days, rotation=45)
    ax1.set_xlim([0, week_minutes])

    # Create a twin Axes sharing the x-axis to plot the attention weights
    ax2 = ax1.twinx()

    # Plot the attention as a background colormap
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])  # Only needed for older versions of matplotlib
    ax2.imshow([processed_attention_profile], aspect='auto', extent=[0, len(original_data), ax1.get_ylim()[0], ax1.get_ylim()[1]], cmap=cmap, norm=norm, alpha=0.3)
    ax2.set_yticks([])
    ax2.set_yticklabels([])

    # Add a colorbar for the attention overlay
    cbar = plt.colorbar(sm, ax=ax2, pad=0.1, aspect=10)
    cbar.set_label('Attention Weight', rotation=270, labelpad=15)

    fig.tight_layout()  # To ensure no overlap of y-ticks
    plt.show()


#@title Function 3

def downsample_to_daily(data, days=7, minutes_per_day=1440):
    """Average a multi-day, minute-level series into a single typical day.

    The series is folded into a ``(days, minutes_per_day)`` grid, and the average is
    taken down each column. Element ``m`` of the result is therefore the mean of minute
    ``m`` across all days. ``data`` must contain exactly ``days * minutes_per_day``
    elements, or the reshape raises ``ValueError``.
    """
    return data.reshape(days, minutes_per_day).mean(axis=0)


def plot_data_with_attention_overlay_1day(original_data, attention_weights, layer_num, patch_size):
    """Plot the day-averaged activity with the day-averaged attention as a background.

    Both the activity series and its per-timestep attention profile are folded into a
    single 24-hour day with ``downsample_to_daily``. Both must therefore cover exactly
    7 x 1440 minutes.

    Args:
        original_data: Minute-level series for one participant.
        attention_weights: Per-layer attention arrays, as accepted by
            ``process_attention_weights``.
        layer_num: Index of the layer whose attention is overlaid.
        patch_size: Timesteps per patch, used to expand patch attention to timesteps.
    """
    minutes_per_day = 1440
    minutes_per_hour = 60

    # Collapse the week of attention and activity into one average day.
    attention_profile = process_attention_weights(attention_weights, layer_num, patch_size)
    day_activity = downsample_to_daily(original_data)
    day_attention = downsample_to_daily(attention_profile)

    colormap = plt.cm.bwr
    color_scale = mcolors.Normalize(vmin=day_attention.min(), vmax=day_attention.max())

    fig, data_ax = plt.subplots(figsize=(12, 4))

    data_ax.plot(day_activity, label='Daily Mean Original Data', color='black')
    data_ax.set_xlabel('Time (hours)')
    data_ax.set_ylabel('Daily Mean Original Data Value', color='black')
    data_ax.tick_params(axis='y', labelcolor='black')

    # One tick every 3 hours from 00:00 to 24:00, positioned in minutes.
    tick_hours = np.arange(0, 25, 3)
    data_ax.set_xticks(tick_hours * minutes_per_hour)
    data_ax.set_xticklabels([f'{hour:02d}:00' for hour in tick_hours])

    # Secondary axes sharing x, used only as a canvas for the attention background.
    overlay_ax = data_ax.twinx()

    mappable = plt.cm.ScalarMappable(cmap=colormap, norm=color_scale)
    mappable.set_array([])  # Only needed for older versions of matplotlib
    y_low, y_high = data_ax.get_ylim()
    overlay_ax.imshow(
        [day_attention],
        aspect='auto',
        extent=[0, minutes_per_day, y_low, y_high],
        cmap=colormap,
        norm=color_scale,
        alpha=0.5,
    )
    overlay_ax.set_yticks([])
    overlay_ax.set_yticklabels([])

    colorbar = plt.colorbar(mappable, ax=overlay_ax, pad=0.1, aspect=10)
    colorbar.set_label('Attention Weight', rotation=270, labelpad=15)

    fig.tight_layout()  # keep the y-tick labels from overlapping
    plt.show()

In [None]:
# Plot for control patients: the first 50 participants whose label is not 1.
max_participants = 50
last_layer = encoder_num_layers - 1  # plot the final encoder layer's attention

n_plotted = 0
for idx, participant in enumerate(y_test):
    if n_plotted == max_participants:
        break
    if participant == 1:
        continue  # not a control participant

    n_plotted += 1
    # Run the model on this single participant. With return_attention=True the outputs
    # are [prediction, *attention tensors from the encoder].
    predictions, *att_weights = eval_model.predict(X_test[idx:idx + 1])
    # Undo the standardization so the plots show the original values, as one column.
    a = scaler.inverse_transform(X_test[idx:idx + 1]).reshape(-1, 1)

    # Make All The PLOTS
    print("=====================================================")
    print("CONTROL PARTICIPANT: " + str(idx + 1))
    print("=====================================================")
    plot_average_attention_weights(att_weights, last_layer)
    plot_data_with_attention_overlay_7d(a, att_weights, last_layer, patch_size)
    plot_data_with_attention_overlay_1day(a, att_weights, last_layer, patch_size)



In [None]:
# Plot for exp patients: the first 50 participants whose label is 1 (not control).
max_participants = 50
last_layer = encoder_num_layers - 1  # plot the final encoder layer's attention

n_plotted = 0
for idx, participant in enumerate(y_test):
    if n_plotted == max_participants:
        break
    if participant != 1:
        continue  # control participant; these are plotted in the previous cell

    n_plotted += 1
    # Run the model on this single participant. With return_attention=True the outputs
    # are [prediction, *attention tensors from the encoder].
    predictions, *att_weights = eval_model.predict(X_test[idx:idx + 1])
    # Undo the standardization so the plots show the original values, as one column.
    a = scaler.inverse_transform(X_test[idx:idx + 1]).reshape(-1, 1)

    # Make All The PLOTS
    print("=====================================================")
    print("NOT CONTROL PARTICIPANT: " + str(idx + 1))
    print("=====================================================")
    plot_average_attention_weights(att_weights, last_layer)
    plot_data_with_attention_overlay_7d(a, att_weights, last_layer, patch_size)
    plot_data_with_attention_overlay_1day(a, att_weights, last_layer, patch_size)