In [11]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [12]:
# Constants and Encoding
START_TOKEN = 13
END_TOKEN = 14
# Add a PAD token if not using mask_zero, but we are using mask_zero=END_TOKEN
# So tokens 0-12 are source IDs, 13 is START, 14 is END (and PAD)
ENCODING_LEGEND = {
    'MRI_CCS_11': 1, 'MRI_EXU_95': 2, 'MRI_FRR_18': 3, 'MRI_FRR_257': 4,
    'MRI_FRR_264': 5, 'MRI_FRR_2': 6, 'MRI_FRR_3': 7, 'MRI_FRR_34': 8, 'MRI_MPT_1005': 9,
    'MRI_MSR_100': 10, 'MRI_MSR_104': 11, 'MRI_MSR_21': 12, 'MRI_MSR_24': 99,
    'START': START_TOKEN, 'END': END_TOKEN
}
reverse_encoding = {v: k for k, v in ENCODING_LEGEND.items()}

# Define valid source IDs for filtering (excluding START and END tokens)
VALID_SOURCE_IDS = set([k for k in ENCODING_LEGEND.keys() if k not in ['START', 'END']])

# Define the columns from the original data to keep in the final output
COLUMNS_TO_KEEP = ['timediff', 'PTAB', 'BodyGroup_from', 'BodyGroup_to', 'PatientID_from', 'PatientID_to']

# Binning parameters
NUM_BINS = 250
# Define bin edges from 0 to 1 (inclusive of 0, exclusive of 1 for all but the last bin)
# The last bin will include 1.0
BIN_EDGES = np.linspace(0.0, 1.0, NUM_BINS + 1)

In [13]:
# Replace the load_and_preprocess_data function (cell [36]) with this one.

def load_and_preprocess_data(data_file):
    """
    Loads and preprocesses data from a CSV file, filtering out invalid sourceIDs.
    Splits data into sequences based on 'MRI_MSR_104' (start) and 'MRI_MSR_100' (end).
    Assigns a sequence number during loading and keeps specified additional columns.
    """
    print(f"Loading data from {data_file}...")
    data = pd.read_csv(data_file)

    all_sequences_tokens = []
    all_sequences_times = []
    all_sequences_sourceids = []
    all_sequences_extra_data = [] # New: To store the extra columns

    current_tokens = []
    current_times = []
    current_sourceids = []
    current_extra_data = [] # New: For the current sequence

    # Iterate through rows to build sequences
    for idx, row in data.iterrows():
        s_id = str(row['sourceID'])
        t_diff = float(row['timediff'])

        if s_id not in VALID_SOURCE_IDS:
            continue

        # New: Extract extra data for the current valid row
        extra_data = {col: row.get(col) for col in COLUMNS_TO_KEEP}

        if s_id == 'MRI_MSR_104':
            if current_tokens:
                token_seq = [START_TOKEN] + [int(ENCODING_LEGEND[x]) for x in current_tokens] + [END_TOKEN]
                time_seq = [0.0] + current_times
                all_sequences_tokens.append(token_seq)
                all_sequences_times.append(time_seq)
                all_sequences_sourceids.append(current_sourceids)
                all_sequences_extra_data.append(current_extra_data) # New

            current_tokens = [s_id]
            current_times = [t_diff]
            current_sourceids = [s_id]
            current_extra_data = [extra_data] # New

        elif s_id == 'MRI_MSR_100':
             if current_tokens:
                current_tokens.append(s_id)
                current_times.append(t_diff)
                current_sourceids.append(s_id)
                current_extra_data.append(extra_data) # New

                token_seq = [START_TOKEN] + [int(ENCODING_LEGEND[x]) for x in current_tokens] + [END_TOKEN]
                time_seq = [0.0] + current_times
                all_sequences_tokens.append(token_seq)
                all_sequences_times.append(time_seq)
                all_sequences_sourceids.append(current_sourceids)
                all_sequences_extra_data.append(current_extra_data) # New

                current_tokens, current_times, current_sourceids, current_extra_data = [], [], [], [] # New

        elif current_tokens:
            current_tokens.append(s_id)
            current_times.append(t_diff)
            current_sourceids.append(s_id)
            current_extra_data.append(extra_data) # New

    if current_tokens:
         token_seq = [START_TOKEN] + [int(ENCODING_LEGEND[x]) for x in current_tokens] + [END_TOKEN]
         time_seq = [0.0] + current_times
         all_sequences_tokens.append(token_seq)
         all_sequences_times.append(time_seq)
         all_sequences_sourceids.append(current_sourceids)
         all_sequences_extra_data.append(current_extra_data) # New

    print(f"Loaded {len(all_sequences_tokens)} sequences.")
    # New: Return the extra data list as well
    return all_sequences_tokens, all_sequences_times, all_sequences_sourceids, all_sequences_extra_data

In [14]:
def get_bin_indices(proportions, bin_edges):
    """
    Maps continuous proportions to discrete bin indices.
    Handles the edge case for the maximum value (1.0).
    """
    # Use np.digitize to find the bin index for each proportion
    # digitize returns index i if bin_edges[i-1] <= x < bin_edges[i]
    # For the last bin, we want to include the upper edge (1.0)
    # np.digitize with right=False is default: bins[i-1] <= x < bins[i]
    # To include the rightmost edge in the last bin, we can adjust values >= 1.0
    proportions = np.clip(proportions, bin_edges[0], bin_edges[-1]) # Clip to [0, 1] range

    # Use right=True to include the rightmost edge in the last bin
    # bins[i-1] < x <= bins[i]
    bin_indices = np.digitize(proportions, bin_edges, right=True) - 1 # -1 because bin_edges has N+1 edges for N bins

    # Handle values exactly equal to the last edge (1.0) - np.digitize with right=True puts them in N+1 bin
    # We want them in bin N-1 (0-indexed)
    bin_indices[proportions == bin_edges[-1]] = len(bin_edges) - 2 # Index of the last bin (0-indexed)

    # Ensure indices are within valid range [0, NUM_BINS - 1]
    bin_indices = np.clip(bin_indices, 0, len(bin_edges) - 2)

    return bin_indices

def get_bin_centers(bin_indices, bin_edges):
    """
    Returns the center value for a given array of bin indices.
    """
    # Calculate bin centers as the midpoint of each bin
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    return bin_centers[bin_indices]


def prepare_training_data(sequences_tokens, sequences_times, bin_edges):
    """
    Prepares sequences for transformer training, including padding and masks.
    Calculates target cumulative times, total times, and binned proportion targets.
    """
    X_list, Y_list, masks_list, total_times_list, Y_binned_list = [], [], [], [], []

    for tokens, times in zip(sequences_tokens, sequences_times):
        # Ensure sequence has at least START and END tokens plus one event
        if len(tokens) < 3:
            # print(f"Skipping short sequence with {len(tokens)} tokens.") # Optional: uncomment for debugging
            continue

        # The last element in times should be the cumulative time of the last event
        # which corresponds to the total time of the sequence.
        total_time = times[-1]

        # Input sequence X: START, Event1, Event2, ... EventN
        # We predict the time *until* the event represented by the input token.
        # So, the input sequence should be tokens[:-1] (START, Event1, ..., EventN-1)
        x_seq = tokens[:-1]

        # Target cumulative times Y: Time1, Time2, ... TimeN
        # These are the cumulative times *at the end* of each step.
        # These correspond to the time *at* the event represented by the token at the corresponding index in the input sequence.
        # The first target time (times[1]) corresponds to the time of the first event (input token at index 1).
        y_seq = times[1:]

        # Calculate true time differences for proportion calculation
        # time_diffs shape: (seq_len - 1) - corresponds to steps 1 to N
        # These are the durations between events: duration[i] = time[i+1] - time[i]
        time_diffs_unpadded = np.diff(times) # time_diffs[i] = times[i+1] - times[i]

        # Calculate true proportions for the steps *after* the START token
        # true_total is the last cumulative time
        true_total = times[-1]
        # Avoid division by zero
        true_total_safe = true_total if true_total > 0 else 1.0
        # true_props_unpadded shape: (seq_len - 1) - corresponds to steps 1 to N
        # These are the proportions of the total time for each time difference.
        true_props_unpadded = time_diffs_unpadded / true_total_safe

        # Pad true_props to match input sequence length (X_list)
        # The first position (corresponding to START token input at index 0) should have 0 proportion.
        # The proportions for events 1 to N (indices 1 to N in input) are in true_props_unpadded.
        true_props_padded = np.pad(true_props_unpadded, (1, 0), constant_values=0.0)

        # Bin the true proportions
        # Y_binned_seq shape: (seq_len) - corresponds to the input sequence length
        y_binned_seq = get_bin_indices(true_props_padded, bin_edges)


        # Mask: 1 for valid input tokens (not END_TOKEN), 0 otherwise
        # The mask applies to the *input* sequence (X_list).
        mask_seq = [1 if t != END_TOKEN else 0 for t in x_seq]

        X_list.append(x_seq)
        Y_list.append(y_seq) # Keep cumulative targets for CSV generation
        masks_list.append(mask_seq)
        total_times_list.append(total_time) # Keep total times for CSV generation
        Y_binned_list.append(y_binned_seq) # Add binned targets


    if not X_list:
        print("No valid sequences found after preprocessing.")
        return np.array([]), np.array([]), np.array([]), np.array([]), np.array([])


    # Determine max length based on the processed sequences
    max_len = max(len(x) for x in X_list)
    print(f"Padding sequences to max length: {max_len}")

    # Pad sequences
    # X_train: pad with END_TOKEN (mask_zero=True in embedding will ignore this)
    X_train = pad_sequences(X_list, maxlen=max_len, padding='post', value=END_TOKEN)
    # Y_cum_target: pad with 0.0
    Y_cum_target = pad_sequences(Y_list, maxlen=max_len, padding='post', value=0.0)
    # mask_train: pad with 0
    mask_train = pad_sequences(masks_list, maxlen=max_len, padding='post', value=0)
    # Y_binned_target: pad with a value that is within the valid bin range (e.g., 0)
    # We will use the mask to ignore padded positions in the loss calculation
    Y_binned_target = pad_sequences(Y_binned_list, maxlen=max_len, padding='post', value=0) # Changed padding value to 0

    X_train = np.array(X_train, dtype=np.int32)
    Y_cum_target = np.array(Y_cum_target, dtype=np.float32)
    mask_train = np.array(mask_train, dtype=np.float32)
    total_times = np.array(total_times_list, dtype=np.float32)
    Y_binned_target = np.array(Y_binned_target, dtype=np.int32) # Binned targets are integers

    print(f"Prepared {X_train.shape[0]} sequences for training.")
    return X_train, Y_cum_target, mask_train, total_times, Y_binned_target

In [15]:
# ----------------------------
# Transformer Components (unchanged)
# ----------------------------
def positional_encoding(length, depth):
    depth = depth / 2
    positions = np.arange(length)[:, np.newaxis]
    depths = np.arange(depth)[np.newaxis, :] / depth
    angle_rates = 1 / (10000 ** depths)
    angle_rads = positions * angle_rates
    pos_encoding = np.concatenate([np.sin(angle_rads), np.cos(angle_rads)], axis=-1)
    return tf.cast(pos_encoding, dtype=tf.float32)

class PositionalEmbedding(layers.Layer):
    def __init__(self, vocab_size, d_model, max_len=16384, use_embedding=True):
        super(PositionalEmbedding, self).__init__()
        self.d_model = d_model
        self.use_embedding = use_embedding
        if self.use_embedding:
            # Set mask_zero to the actual padding value (END_TOKEN)
            self.embedding = layers.Embedding(vocab_size, d_model, mask_zero=END_TOKEN)
        else:
            # If not using embedding, assume input is already dense (e.g., time features)
            self.embedding = layers.Dense(d_model, activation="relu")
        self.max_len = max_len
        # Ensure pos_encoding is created once and is large enough
        self.pos_encoding = positional_encoding(self.max_len, d_model)

    # Correct compute_mask signature to accept optional mask argument
    def compute_mask(self, x, mask=None):
         # If using embedding with mask_zero, the mask is computed based on mask_zero value
         if self.use_embedding:
              # Return a boolean mask indicating which elements are NOT the mask_zero value
              return tf.math.not_equal(x, self.embedding.mask_zero)
         # Otherwise, assume all steps are valid unless explicitly masked later
         return None

    def call(self, x):
        # x is assumed to be token IDs if use_embedding is True, otherwise dense features
        if self.use_embedding:
            # The embedding layer itself computes and propagates the mask because mask_zero is set
            x = self.embedding(x)
        else:
             # Apply dense layer if input is not token IDs
             x = self.embedding(x)

        # Scale the embedding output
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))

        # Add positional encoding
        seq_len = tf.shape(x)[1]
        # Ensure positional encoding slice matches sequence length
        x += self.pos_encoding[tf.newaxis, :seq_len, :]
        return x

class FeedForward(layers.Layer):
    def __init__(self, d_model, dff, dropout_rate=0.1):
        super().__init__()
        self.seq = tf.keras.Sequential([
            layers.Dense(dff, activation='relu'),
            layers.Dense(d_model),
            layers.Dropout(dropout_rate)
        ])
        self.add = layers.Add()
        self.layer_norm = layers.LayerNormalization()

    def call(self, x):
        # Apply feed forward network with residual connection and layer normalization
        x = self.add([x, self.seq(x)])
        x = self.layer_norm(x)
        return x

class CausalSelfAttention(layers.Layer):
    def __init__(self, num_heads, d_model, dropout_rate=0.1):
        super().__init__()
        # MultiHeadAttention layer with causal mask
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model, dropout=dropout_rate)
        self.add = layers.Add()
        self.layer_norm = layers.LayerNormalization()

    def call(self, x):
        # Apply multi-head self-attention
        # Keras automatically uses the mask attached to the input 'x'
        attn_output = self.mha(query=x, key=x, value=x, use_causal_mask=True)
        # Add residual connection and layer normalization
        x = self.add([x, attn_output])
        x = self.layer_norm(x)
        return x

class SelfAttentionFeedForwardLayer(layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1):
        super().__init__()
        # Composes CausalSelfAttention and FeedForward layers
        self.self_attention = CausalSelfAttention(num_heads=num_heads, d_model=d_model, dropout_rate=dropout_rate)
        self.ffn = FeedForward(d_model, dff, dropout_rate)

    def call(self, x):
        # Pass input through self-attention and then feed-forward network
        # Mask from 'x' is propagated through these layers
        x = self.self_attention(x)
        x = self.ffn(x)
        return x

class Encoder(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff, vocab_size, dropout_rate=0.1, max_len=16384):
        super().__init__()
        # Positional embedding for the input tokens
        self.pos_embedding = PositionalEmbedding(vocab_size, d_model, max_len=max_len)
        # Stack of encoder layers
        self.enc_layers = [SelfAttentionFeedForwardLayer(d_model, num_heads, dff, dropout_rate)
                           for _ in range(num_layers)]
        self.dropout = layers.Dropout(dropout_rate)

    def call(self, x):
        # Apply positional embedding and dropout.
        # The output 'x' from pos_embedding will carry the mask computed by PositionalEmbedding.compute_mask.
        x = self.pos_embedding(x)
        x = self.dropout(x)

        # Pass through encoder layers. Keras will automatically propagate the mask
        # through the layers that support masking (like MultiHeadAttention).
        for layer in self.enc_layers:
            x = layer(x)

        return x # The output tensor carries the mask

In [16]:
class TimeDiffTransformer(tf.keras.Model):
    """
    Transformer model predicting proportions of total time for each sequence step.
    This version predicts a probability distribution over bins for proportions.
    """
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, num_bins, dropout_rate=0.1, max_len=16384):
        super().__init__()
        # Encoder processes the input sequence of tokens
        self.encoder = Encoder(num_layers, d_model, num_heads, dff, input_vocab_size, dropout_rate, max_len)

        # Head to predict the probability distribution over bins for proportions
        # Output is NUM_BINS values per sequence step with softmax activation
        self.proportion_head = layers.Dense(num_bins, activation='softmax')

    def call(self, inputs):
        # Pass input through the encoder
        encoder_out = self.encoder(inputs) # encoder_out shape: (batch_size, seq_len, d_model)
        # The mask from the embedding layer is propagated to encoder_out

        # Predict probability distribution over bins for each step
        # pred_bin_probs shape: (batch_size, seq_len, num_bins)
        pred_bin_probs = self.proportion_head(encoder_out)

        # Return the predicted bin probabilities
        return pred_bin_probs # pred_bin_probs shape: (batch_size, seq_len, num_bins)

In [17]:
# Replace the train_transformer function (cell [42]) with this one.

def train_transformer(data_file, epochs=50, batch_size=32, num_bins=NUM_BINS, bin_edges=BIN_EDGES):
    """
    Trains the TimeDiffTransformer model with proportion binning.
    """
    try:
        # Load and preprocess data, now also gets sequences_extra_data
        sequences_tokens, sequences_times, sequences_sourceids, sequences_extra_data = load_and_preprocess_data(data_file)

        # Prepare data for training (this function does not need the extra data)
        X_train, Y_cum_target, mask_train, total_times, Y_binned_target = prepare_training_data(
            sequences_tokens, sequences_times, bin_edges
        )

        if X_train.shape[0] == 0:
            print("No data available for training after preprocessing.")
            return None, None, None, None, None, None, None, None

        # --- Model definition, optimizer, loss, and training loop remain the same ---
        vocab_size = max(ENCODING_LEGEND.values()) + 1
        max_seq_len = X_train.shape[1]
        model = TimeDiffTransformer(
            num_layers=3, d_model=64, num_heads=8, dff=128,
            input_vocab_size=vocab_size, num_bins=num_bins,
            dropout_rate=0.1, max_len=max_seq_len
        )
        optimizer = tf.keras.optimizers.Adam()
        proportion_loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

        @tf.function
        def train_step(x, y_binned, mask):
            with tf.GradientTape() as tape:
                pred_bin_probs = model(x)
                mask_float = tf.cast(mask, tf.float32)
                masked_props_loss = proportion_loss_fn(y_binned, pred_bin_probs, sample_weight=mask_float)
                total_loss = masked_props_loss
            grads = tape.gradient(total_loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            return total_loss, masked_props_loss

        print("Starting training...")
        train_dataset = tf.data.Dataset.from_tensor_slices((X_train, Y_binned_target, mask_train)).batch(batch_size)

        for epoch in range(epochs):
            total_epoch_loss, total_proportion_loss, num_batches = 0, 0, 0
            for step, (batch_x, batch_y_binned, batch_mask) in enumerate(train_dataset):
                loss, props_loss = train_step(batch_x, batch_y_binned, batch_mask)
                total_epoch_loss += loss
                total_proportion_loss += props_loss
                num_batches += 1
            avg_epoch_loss = total_epoch_loss / num_batches if num_batches > 0 else 0
            avg_proportion_loss = total_proportion_loss / num_batches if num_batches > 0 else 0
            print(f"Epoch {epoch+1}/{epochs} - Total Loss: {avg_epoch_loss.numpy():.4f} - Proportion Loss: {avg_proportion_loss.numpy():.4f}")
        # --- End of unchanged training section ---

        print("Training finished.")
        # New: Return the extra data along with other results
        return model, X_train, Y_cum_target, mask_train, total_times, sequences_sourceids, Y_binned_target, sequences_extra_data

    except Exception as e:
        print(f"Error in train_transformer: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, None, None, None, None, None

In [18]:
def compute_time_differences(proportions, total_time, mask):
    """
    Computes predicted increments and cumulative times from proportions and total time.
    Applies masking to ignore padded steps.

    Args:
        proportions: Predicted proportions for each step (batch_size, seq_len).
        total_time: The total time for each sequence (batch_size, 1).
        mask: Mask indicating valid steps (batch_size, seq_len).

    Returns:
        proportions: Normalized proportions (batch_size, seq_len).
        increments: Predicted time increments (batch_size, seq_len).
        cumulative_times: Predicted cumulative times (batch_size, seq_len).
    """
    # Apply mask to ensure only valid tokens contribute to calculations
    proportions *= tf.cast(mask, tf.float32)

    # Compute row-wise sum for normalization to handle variable-length sequences
    # Sum across the sequence length dimension (axis=1)
    row_sums = tf.reduce_sum(proportions, axis=1, keepdims=True)
    # Prevent division by zero if a sequence is entirely masked (shouldn't happen with START token)
    row_sums = tf.where(tf.equal(row_sums, 0), tf.ones_like(row_sums), row_sums)

    # Normalize proportions so they sum to 1 over the valid (unmasked) steps
    proportions /= row_sums

    # Compute increments by multiplying normalized proportions by the total time
    # total_time should have shape (batch_size, 1) for correct broadcasting
    increments = proportions * total_time # Broadcasting total_time

    # Compute cumulative times by summing increments along the sequence dimension
    cumulative_times = tf.math.cumsum(increments, axis=1)

    return proportions, increments, cumulative_times

In [None]:
# Replace the generate_predictions_csv function (cell [43]) with this one.

def generate_predictions_csv(model, X_train, Y_cum_target, mask_train, total_times, sequences_sourceids, sequences_extra_data, bin_edges):
    """
    Generates predictions and saves to CSV, including additional original data columns.
    """
    if model is None:
        print("Model is None, cannot generate predictions.")
        return pd.DataFrame()

    print("Generating predictions...")

    # --- Prediction and calculation logic remains the same ---
    pred_bin_probs = model(X_train)
    predicted_bin_indices = tf.argmax(pred_bin_probs, axis=-1, output_type=tf.int32)
    predicted_proportions_continuous = get_bin_centers(predicted_bin_indices.numpy(), bin_edges)
    total_times_tf = tf.constant(total_times, dtype=tf.float32)
    total_times_expanded = tf.expand_dims(total_times_tf, axis=1)
    proportions_pred_norm, increments_pred, cumulative_pred = compute_time_differences(
        tf.constant(predicted_proportions_continuous, dtype=tf.float32),
        total_times_expanded,
        mask_train
    )
    proportions_pred_np = proportions_pred_norm.numpy()
    increments_pred_np = increments_pred.numpy()
    cumulative_pred_np = cumulative_pred.numpy()
    X_train_np, Y_cum_target_np, mask_train_np = X_train, Y_cum_target, mask_train
    gt_increments = np.zeros_like(Y_cum_target_np)
    gt_increments[:, 0] = Y_cum_target_np[:, 0]
    gt_increments[:, 1:] = Y_cum_target_np[:, 1:] - Y_cum_target_np[:, :-1]
    gt_increments *= mask_train_np
    # --- End of unchanged prediction logic ---

    output_records = []
    for seq_idx in range(X_train_np.shape[0]):
        valid_mask = mask_train_np[seq_idx] == 1
        valid_indices = np.where(valid_mask)[0]
        safe_sourceids = sequences_sourceids[seq_idx] if seq_idx < len(sequences_sourceids) else []
        safe_extra_data = sequences_extra_data[seq_idx] if seq_idx < len(sequences_extra_data) else [] # New
        step_counter = 1

        for i in range(len(valid_indices)):
            valid_idx = valid_indices[i]
            if valid_idx > 0:
                source_id_index = valid_idx - 1
                if source_id_index < len(safe_sourceids):
                    source_id = safe_sourceids[source_id_index]
                    
                    # New: Get the extra data for this specific step
                    extra_data_record = safe_extra_data[source_id_index] if source_id_index < len(safe_extra_data) else {}

                    # Create the base record with predictions and ground truth
                    record = {
                        'Sequence': seq_idx,
                        'Step': step_counter,
                        'SourceID': source_id,
                        'Predicted_Proportion': proportions_pred_np[seq_idx, valid_idx],
                        'Predicted_Increment': increments_pred_np[seq_idx, valid_idx],
                        'Predicted_Cumulative': cumulative_pred_np[seq_idx, valid_idx],
                        'GroundTruth_Increment': gt_increments[seq_idx, valid_idx],
                        'GroundTruth_Cumulative': Y_cum_target_np[seq_idx, valid_idx]
                    }
                    
                    # New: Merge the extra data into the record
                    record.update(extra_data_record)
                    output_records.append(record)
                    step_counter += 1

    # New: Define final column order, including the kept columns
    final_column_order = [
        'Sequence', 'Step', 'SourceID', 'Predicted_Proportion',
        'Predicted_Increment', 'Predicted_Cumulative',
        'GroundTruth_Increment', 'GroundTruth_Cumulative'
    ] + COLUMNS_TO_KEEP

    if not output_records:
        print("Warning: No valid prediction records generated.")
        predictions_df = pd.DataFrame(columns=final_column_order)
    else:
        predictions_df = pd.DataFrame(output_records)
        # Reorder columns and ensure all are present
        existing_cols = [col for col in final_column_order if col in predictions_df.columns]
        predictions_df = predictions_df[existing_cols]

    output_csv_path = 'predictions_transformer_182625_with_details.csv' # New output filename
    try:
        predictions_df.to_csv(output_csv_path, index=False)
        print(f"Predictions saved successfully to {output_csv_path}")
    except Exception as e:
        print(f"Error saving predictions to CSV: {e}")

    return predictions_df

In [20]:
# Replace the main function call (cell [44]) with this one.

def main():
    """
    Main function to run the training and prediction process.
    """
    try:
        data_file = "data/182625/encoded_182625_condensed.csv"
        if not os.path.exists(data_file):
            print(f"Error: Data file not found at {data_file}")
            return

        # New: Unpack the extra data from the training result
        result = train_transformer(data_file, epochs=50, num_bins=NUM_BINS, bin_edges=BIN_EDGES)

        if result is None or result[0] is None:
            print("Model training failed or no data was available. Exiting.")
            return

        model, X_train, Y_cum_target, mask_train, total_times, sequences_sourceids, Y_binned_target, sequences_extra_data = result

        # New: Pass the extra data to the prediction function
        predictions_df = generate_predictions_csv(
            model, X_train, Y_cum_target, mask_train, total_times, sequences_sourceids, sequences_extra_data, BIN_EDGES
        )

        if not predictions_df.empty:
            print("\nSample Predictions:")
            # Display a subset of columns for readability in the console
            display_cols = ['Sequence', 'Step', 'SourceID', 'Predicted_Cumulative', 'GroundTruth_Cumulative'] + COLUMNS_TO_KEEP
            print(predictions_df[display_cols].head(10))
        else:
            print("\nNo predictions were generated.")

    except Exception as e:
        print(f"Error in main: {e}")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    main()

Loading data from data/182625/encoded_182625_condensed.csv...
Loaded 186 sequences.
Padding sequences to max length: 43
Prepared 186 sequences for training.
Starting training...




Epoch 1/50 - Total Loss: 4.5470 - Proportion Loss: 4.5470
Epoch 2/50 - Total Loss: 3.8672 - Proportion Loss: 3.8672
Epoch 3/50 - Total Loss: 3.5588 - Proportion Loss: 3.5588
Epoch 4/50 - Total Loss: 3.3206 - Proportion Loss: 3.3206
Epoch 5/50 - Total Loss: 3.1682 - Proportion Loss: 3.1682
Epoch 6/50 - Total Loss: 3.0823 - Proportion Loss: 3.0823
Epoch 7/50 - Total Loss: 3.0330 - Proportion Loss: 3.0330
Epoch 8/50 - Total Loss: 3.0041 - Proportion Loss: 3.0041
Epoch 9/50 - Total Loss: 2.9837 - Proportion Loss: 2.9837
Epoch 10/50 - Total Loss: 2.9660 - Proportion Loss: 2.9660
Epoch 11/50 - Total Loss: 2.9486 - Proportion Loss: 2.9486
Epoch 12/50 - Total Loss: 2.9293 - Proportion Loss: 2.9293
Epoch 13/50 - Total Loss: 2.9076 - Proportion Loss: 2.9076
Epoch 14/50 - Total Loss: 2.8839 - Proportion Loss: 2.8839
Epoch 15/50 - Total Loss: 2.8533 - Proportion Loss: 2.8533
Epoch 16/50 - Total Loss: 2.8295 - Proportion Loss: 2.8295
Epoch 17/50 - Total Loss: 2.8060 - Proportion Loss: 2.8060
Epoch 