In [31]:
import os
import sys
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
# Assuming pad_sequences is available if needed, though direct use might change.
# from tensorflow.keras.preprocessing.sequence import pad_sequences
import matplotlib.pyplot as plt
import seaborn as sns # Seaborn was commented out in original, keeping it that way

In [None]:
# Helper function to compute increments and cumulative times from proportions and total time
def calculate_times_from_proportions(proportions_per_step, total_time_for_sequence, mask_per_step):
    """
    Calculates time increments and cumulative times from step-wise proportions and a total time.

    Args:
        proportions_per_step (tf.Tensor or np.ndarray): 
            Proportions for each step in sequences (batch_size, seq_len).
        total_time_for_sequence (tf.Tensor or np.ndarray): 
            The total time for each sequence (batch_size, 1) or (batch_size,).
        mask_per_step (tf.Tensor or np.ndarray): 
            Mask indicating valid steps (batch_size, seq_len).

    Returns:
        tuple: (normalized_proportions, increments, cumulative_times)
               all as tf.Tensor.
    """
    proportions_tf = tf.cast(proportions_per_step, tf.float32)
    total_time_tf = tf.cast(total_time_for_sequence, tf.float32)
    mask_tf = tf.cast(mask_per_step, tf.float32)

    # Ensure total_time_tf is (batch_size, 1) for broadcasting
    if len(tf.shape(total_time_tf)) == 1:
        total_time_tf = tf.expand_dims(total_time_tf, axis=-1)

    # Apply mask to proportions
    masked_proportions = proportions_tf * mask_tf

    # Normalize proportions over valid steps
    row_sums = tf.reduce_sum(masked_proportions, axis=1, keepdims=True)
    row_sums = tf.where(tf.equal(row_sums, 0), tf.ones_like(row_sums), row_sums) # Avoid division by zero
    normalized_proportions = masked_proportions / row_sums

    # Calculate increments
    increments = normalized_proportions * total_time_tf # Broadcasting

    # Calculate cumulative times
    cumulative_times = tf.cumsum(increments, axis=1)
    
    # Ensure masked steps in final outputs are zero
    increments *= mask_tf
    cumulative_times *= mask_tf
    normalized_proportions *= mask_tf

    return normalized_proportions, increments, cumulative_times

# %%
class TotalTimeLSTM(tf.keras.Model):
    """
    LSTM model to predict only the total time of a sequence,
    using features that can include Transformer-predicted proportions.
    """
    def __init__(self, hidden_units=64, num_heads=4, dropout_rate=0.2):
        super(TotalTimeLSTM, self).__init__()
        
        self.hidden_units = hidden_units
        self.num_heads = num_heads

        # LSTM for sequence processing
        self.lstm_layer = layers.LSTM(self.hidden_units, 
                                      return_sequences=True, 
                                      dropout=dropout_rate,
                                      recurrent_dropout=dropout_rate,
                                      name="lstm_1")
        
        # Bidirectional LSTM. Output dim: 2 * hidden_units
        self.bi_lstm = layers.Bidirectional(
            layers.LSTM(self.hidden_units, return_sequences=True, name="lstm_bidirectional_inner"),
            name="bidirectional_lstm_1"
        )
        
        # Attention mechanism
        mha_key_dim = (2 * self.hidden_units) // self.num_heads
        if (2 * self.hidden_units) % self.num_heads != 0:
            # This check ensures that the output dimension of MHA can match BiLSTM output
            raise ValueError(f"(2 * hidden_units) must be divisible by num_heads for the intended MHA output dimension. "
                             f"Got 2 * {self.hidden_units} (={2*self.hidden_units}) and num_heads={self.num_heads}.")

        self.attention = layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=mha_key_dim, name="multi_head_attention_1"
        )
        self.layer_norm_attn = layers.LayerNormalization(name="layer_norm_attention")

        self.global_avg_pool = layers.GlobalAveragePooling1D(name="global_avg_pooling_1d")
        self.total_time_head = layers.Dense(1, activation='relu', name="total_time_dense_output") 
        
    def call(self, inputs, training=False): 
        # inputs shape: (batch_size, seq_len, num_features)
        
        # Create a boolean mask from inputs. Assumes padding is all zeros.
        # True for non-padded, False for padded.
        mask_bool = tf.reduce_any(tf.not_equal(inputs, 0.0), axis=-1)

        lstm_out = self.lstm_layer(inputs, mask=mask_bool, training=training)
        bi_lstm_out = self.bi_lstm(lstm_out, mask=mask_bool, training=training)
        
        # Prepare attention mask for MHA: (batch_size, 1, 1, seq_length)
        # This will be broadcast to (batch_size, num_heads, query_seq_length, key_seq_length)
        # True means keep, False means mask out.
        # MHA expects False for positions to mask. So, if mask_bool is True for valid,
        # we might need to invert it depending on MHA's interpretation or use it as is if it masks where mask is False.
        # Keras MHA `attention_mask`: True for allowed, False for masked.
        # Our `mask_bool` is True for allowed. So, it should be fine.
        # The shape should be (B, T, S) or (B, N, T, S). For self-attention, T=S.
        # Mask for MHA should indicate which tokens are padding.
        # (B, S) -> (B, 1, 1, S) for key padding mask.
        mha_attention_mask = mask_bool[:, tf.newaxis, tf.newaxis, :] # (batch, 1, 1, key_seq_len)

        attn_output = self.attention(query=bi_lstm_out, value=bi_lstm_out, key=bi_lstm_out, 
                                     attention_mask=mha_attention_mask, 
                                     training=training)
        
        x = self.layer_norm_attn(attn_output + bi_lstm_out) # Residual connection
        
        sequence_encoding = self.global_avg_pool(x, mask=mask_bool)
        
        total_time_pred = self.total_time_head(sequence_encoding) 
        
        return total_time_pred

# %%
def process_input_data_for_lstm(transformer_predictions_file):
    """
    Process the transformer predictions CSV to prepare data for LSTM training.
    """
    print(f"Processing data from: {transformer_predictions_file}")
    if not os.path.exists(transformer_predictions_file):
        raise FileNotFoundError(f"Transformer predictions file not found: {transformer_predictions_file}")
    df = pd.read_csv(transformer_predictions_file)
    
    if 'Predicted_Proportion' not in df.columns:
        raise ValueError("CSV must contain 'Predicted_Proportion' column from Transformer.")
    if 'GroundTruth_Cumulative' not in df.columns:
        raise ValueError("CSV must contain 'GroundTruth_Cumulative' column for target extraction.")
    if 'GroundTruth_Increment' not in df.columns:
         raise ValueError("CSV must contain 'GroundTruth_Increment' column.")


    sequences = df['Sequence'].unique()
    
    X_data_list = []
    y_total_times_list = []
    masks_list = [] # This will store boolean masks (True for valid)
    transformer_proportions_list = [] 
    ground_truth_increments_list = []
    ground_truth_cumulative_list = []
    original_dfs_list = [] 

    for seq_id in sequences:
        seq_df = df[df['Sequence'] == seq_id].sort_values('Step').copy()
        
        if seq_df.empty:
            print(f"Warning: Sequence {seq_id} is empty. Skipping.")
            original_dfs_list.append(seq_df) # Append empty df to keep counts consistent if needed
            continue
        
        original_dfs_list.append(seq_df) 

        current_max_steps = seq_df['Step'].max()
        if current_max_steps == 0: current_max_steps = 1 
        
        features = np.column_stack([
            seq_df['Predicted_Proportion'].values,
            seq_df['Step'].values / current_max_steps 
        ])
        
        X_data_list.append(features)
        
        gt_cumulative_for_seq = seq_df['GroundTruth_Cumulative'].values
        total_time_for_seq = gt_cumulative_for_seq[-1] if len(gt_cumulative_for_seq) > 0 else 0.0
        y_total_times_list.append(total_time_for_seq)
        
        masks_list.append(np.ones(len(features), dtype=bool)) # Store boolean True for valid steps
        
        transformer_proportions_list.append(seq_df['Predicted_Proportion'].values)
        ground_truth_increments_list.append(seq_df['GroundTruth_Increment'].values)
        ground_truth_cumulative_list.append(gt_cumulative_for_seq)

    if not X_data_list:
        raise ValueError("No valid sequences processed. Check CSV content or processing logic.")

    max_length = max(len(x) for x in X_data_list)
    num_features = X_data_list[0].shape[1]

    X_padded = np.zeros((len(X_data_list), max_length, num_features), dtype=np.float32)
    # masks_padded stores the float mask for calculate_times_from_proportions
    masks_padded_float = np.zeros((len(X_data_list), max_length), dtype=np.float32) 
    transformer_proportions_padded = np.zeros((len(X_data_list), max_length), dtype=np.float32)
    gt_increments_padded = np.zeros((len(X_data_list), max_length), dtype=np.float32)
    gt_cumulative_padded = np.zeros((len(X_data_list), max_length), dtype=np.float32)

    for i in range(len(X_data_list)):
        seq_len = len(X_data_list[i])
        if seq_len > 0:
            X_padded[i, :seq_len, :] = X_data_list[i]
            masks_padded_float[i, :seq_len] = 1.0 # Float mask for helper
            transformer_proportions_padded[i, :seq_len] = transformer_proportions_list[i]
            gt_increments_padded[i, :seq_len] = ground_truth_increments_list[i]
            gt_cumulative_padded[i, :seq_len] = ground_truth_cumulative_list[i]
        
    y_total_times_np = np.array(y_total_times_list, dtype=np.float32)

    return {
        'X_lstm_input': X_padded,
        'y_lstm_target_total_times': y_total_times_np,
        'masks_for_calc': masks_padded_float, # Float mask for calculate_times_from_proportions
        'sequences_ids': sequences, 
        'original_dfs': original_dfs_list, 
        'transformer_proportions_padded': transformer_proportions_padded, 
        'gt_increments_padded': gt_increments_padded, 
        'gt_cumulative_padded': gt_cumulative_padded, 
        'max_len': max_length,
        'num_features': num_features
    }

# %%
def train_total_time_lstm(transformer_predictions_file, epochs=50, batch_size=32):
    """
    Train LSTM model to predict total_time.
    """
    print("Processing data for LSTM training...")
    data_for_lstm = process_input_data_for_lstm(transformer_predictions_file)
    
    print(f"Number of features for LSTM input: {data_for_lstm['num_features']}")
    print(f"Max sequence length for LSTM input: {data_for_lstm['max_len']}")

    lstm_model = TotalTimeLSTM(hidden_units=64, num_heads=4, dropout_rate=0.2) 
    
    # Explicitly build the model
    # Batch dimension is None, seq_len is max_len, features is num_features
    input_shape = (None, data_for_lstm['max_len'], data_for_lstm['num_features'])
    lstm_model.build(input_shape=input_shape) 
    
    print("LSTM Model Summary:")
    lstm_model.summary() 

    lstm_model.compile(
        optimizer=Adam(learning_rate=0.001), 
        loss='mse' 
    )
    
    y_train_lstm = data_for_lstm['y_lstm_target_total_times']
    X_train_lstm = data_for_lstm['X_lstm_input']
    
    # Check for NaN/inf in inputs and targets
    if np.any(np.isnan(X_train_lstm)) or np.any(np.isinf(X_train_lstm)):
        print("Warning: NaN or Inf found in X_train_lstm input.")
    if np.any(np.isnan(y_train_lstm)) or np.any(np.isinf(y_train_lstm)):
        print("Warning: NaN or Inf found in y_train_lstm targets.")


    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=10, 
        restore_best_weights=True,
        verbose=1
    )
    
    print("Starting LSTM model training...")
    
    history = lstm_model.fit(
        X_train_lstm, 
        y_train_lstm,          
        epochs=epochs,
        batch_size=batch_size,
        validation_split=0.2, 
        callbacks=[early_stopping],
        verbose=1
    )
    
    print("LSTM training finished.")
    return lstm_model, history, data_for_lstm

# %%
def generate_refined_predictions_with_lstm(lstm_model, processed_data):
    """
    Generate refined time predictions using LSTM's total_time.
    """
    print("Generating refined predictions using LSTM's total time...")
    
    X_input_for_prediction = processed_data['X_lstm_input']
    lstm_predicted_total_times = lstm_model.predict(X_input_for_prediction) 
    lstm_predicted_total_times = np.squeeze(lstm_predicted_total_times) 

    transformer_step_proportions = processed_data['transformer_proportions_padded'] 
    # Use the float mask prepared for calculate_times_from_proportions
    masks_for_calc = processed_data['masks_for_calc'] 

    _, lstm_refined_increments, lstm_refined_cumulative = calculate_times_from_proportions(
        transformer_step_proportions,
        lstm_predicted_total_times, 
        masks_for_calc # Pass the correct mask here
    )

    lstm_refined_increments_np = lstm_refined_increments.numpy()
    lstm_refined_cumulative_np = lstm_refined_cumulative.numpy()

    results_list_df = []
    original_dfs = processed_data['original_dfs'] 

    for i, seq_id in enumerate(processed_data['sequences_ids']):
        # Ensure index i is within bounds for original_dfs
        if i >= len(original_dfs):
            print(f"Warning: Index {i} out of bounds for original_dfs. Skipping sequence ID {seq_id}.")
            continue
        original_seq_df = original_dfs[i].copy() 
        seq_len = len(original_seq_df)

        if seq_len == 0:
            # This sequence was empty in the original data, skip adding LSTM predictions to it
            # but it might be part of final_results_df if original_dfs included it.
            # Best to ensure original_dfs only contains non-empty DFs if they are used this way.
            # For now, if it's empty, we just append it as is or skip.
            # If process_input_data_for_lstm filters them, this won't be an issue.
            if original_seq_df.empty: # If it was truly empty and not just filtered later
                 results_list_df.append(original_seq_df) # Add empty if it was there
            continue


        # Ensure index i is within bounds for lstm_predicted_total_times
        if i >= len(lstm_predicted_total_times):
            print(f"Warning: Index {i} out of bounds for lstm_predicted_total_times. Skipping sequence ID {seq_id}.")
            continue

        original_seq_df['LSTM_Predicted_TotalTime'] = lstm_predicted_total_times[i]
        original_seq_df['LSTM_Predicted_Increment'] = lstm_refined_increments_np[i, :seq_len]
        original_seq_df['LSTM_Predicted_Cumulative'] = lstm_refined_cumulative_np[i, :seq_len]
        
        # MAE Calculation and Improvement Pct
        if 'GroundTruth_Increment' in original_seq_df.columns and 'Predicted_Increment' in original_seq_df.columns:
            gt_increment = original_seq_df['GroundTruth_Increment'].fillna(0) # Handle potential NaNs
            transformer_pred_increment = original_seq_df['Predicted_Increment'].fillna(0)
            lstm_pred_increment = original_seq_df['LSTM_Predicted_Increment'].fillna(0)
            
            diff_transformer = np.abs(gt_increment - transformer_pred_increment)
            diff_lstm = np.abs(gt_increment - lstm_pred_increment)
            
            original_seq_df['Increment_MAE_Transformer'] = diff_transformer
            original_seq_df['Increment_MAE_LSTM'] = diff_lstm
            
            original_seq_df['Increment_Improvement_Pct'] = np.where(
                diff_transformer > 1e-6, 
                (diff_transformer - diff_lstm) / diff_transformer * 100,
                0 
            )
        
        if 'GroundTruth_Cumulative' in original_seq_df.columns and 'Predicted_Cumulative' in original_seq_df.columns:
            gt_cumulative = original_seq_df['GroundTruth_Cumulative'].fillna(0)
            transformer_pred_cumulative = original_seq_df['Predicted_Cumulative'].fillna(0)
            lstm_pred_cumulative = original_seq_df['LSTM_Predicted_Cumulative'].fillna(0)

            diff_transformer_cum = np.abs(gt_cumulative - transformer_pred_cumulative)
            diff_lstm_cum = np.abs(gt_cumulative - lstm_pred_cumulative)

            original_seq_df['Cumulative_MAE_Transformer'] = diff_transformer_cum
            original_seq_df['Cumulative_MAE_LSTM'] = diff_lstm_cum

            original_seq_df['Cumulative_Improvement_Pct'] = np.where(
                diff_transformer_cum > 1e-6,
                (diff_transformer_cum - diff_lstm_cum) / diff_transformer_cum * 100,
                0
            )
        results_list_df.append(original_seq_df)
    
    if not results_list_df:
        print("Warning: No results to concatenate for the final CSV after processing sequences.")
        return pd.DataFrame()

    final_results_df = pd.concat(results_list_df, ignore_index=True)
    
    output_filename = 'predictions_lstm_refined_total_time_approach.csv'
    final_results_df.to_csv(output_filename, index=False)
    print(f"Combined and refined predictions saved to {output_filename}")
    
    # Overall Performance Metrics
    if not final_results_df.empty:
        if 'Increment_MAE_Transformer' in final_results_df.columns and 'Increment_MAE_LSTM' in final_results_df.columns:
            avg_transformer_inc_mae = final_results_df['Increment_MAE_Transformer'].mean()
            avg_lstm_inc_mae = final_results_df['Increment_MAE_LSTM'].mean()
            print("\n--- Mean Absolute Error for Increments ---")
            print(f"Transformer MAE (Increments): {avg_transformer_inc_mae:.4f}")
            print(f"LSTM-Refined MAE (Increments): {avg_lstm_inc_mae:.4f}")
            if avg_transformer_inc_mae > 1e-6: # Avoid division by zero
                improvement_inc = (avg_transformer_inc_mae - avg_lstm_inc_mae) / avg_transformer_inc_mae * 100
                print(f"Improvement (Increments): {improvement_inc:.2f}%")

        if 'Cumulative_MAE_Transformer' in final_results_df.columns and 'Cumulative_MAE_LSTM' in final_results_df.columns:
            avg_transformer_cum_mae = final_results_df['Cumulative_MAE_Transformer'].mean()
            avg_lstm_cum_mae = final_results_df['Cumulative_MAE_LSTM'].mean()
            print("\n--- Mean Absolute Error for Cumulative Times ---")
            print(f"Transformer MAE (Cumulative): {avg_transformer_cum_mae:.4f}")
            print(f"LSTM-Refined MAE (Cumulative): {avg_lstm_cum_mae:.4f}")
            if avg_transformer_cum_mae > 1e-6: # Avoid division by zero
                improvement_cum = (avg_transformer_cum_mae - avg_lstm_cum_mae) / avg_transformer_cum_mae * 100
                print(f"Improvement (Cumulative): {improvement_cum:.2f}%")

    gt_total_times_all_seqs = processed_data['y_lstm_target_total_times']
    # Ensure lstm_predicted_total_times has the same length as gt_total_times_all_seqs
    if len(lstm_predicted_total_times) == len(gt_total_times_all_seqs):
        mae_total_time_lstm = np.mean(np.abs(gt_total_times_all_seqs - lstm_predicted_total_times))
        print("\n--- LSTM Total Time Prediction Performance ---")
        print(f"MAE for LSTM Predicted Total Time (vs GT Total Time): {mae_total_time_lstm:.4f}")
    else:
        print("\nWarning: Mismatch in lengths for GT total times and predicted total times. Cannot compute overall MAE for total time.")


    return final_results_df

# %%
def visualize_lstm_results(results_df, processed_data, lstm_model, training_history):
    """
    Generate visualizations for the LSTM model that predicts total time.
    """
    if results_df.empty:
        print("Results DataFrame is empty, skipping visualizations.")
        return
        
    print("Generating visualizations for LSTM results...")
    plt.style.use('ggplot')

    if training_history and hasattr(training_history, 'history'):
        if 'loss' in training_history.history and 'val_loss' in training_history.history:
            plt.figure(figsize=(10, 5))
            plt.plot(training_history.history['loss'], label='Training Loss')
            plt.plot(training_history.history['val_loss'], label='Validation Loss')
            plt.title('LSTM Model Loss (Predicting Total Time)')
            plt.xlabel('Epoch')
            plt.ylabel('Mean Squared Error (Loss)')
            plt.legend()
            plt.tight_layout()
            plt.savefig('lstm_total_time_training_loss.png')
            print("Saved LSTM training loss plot.")
            plt.close()
        else:
            print("Warning: Training history does not contain 'loss' or 'val_loss' keys.")
    else:
        print("Warning: No training history provided or history object is not as expected.")


    sample_sequence_ids = results_df['Sequence'].unique()
    if len(sample_sequence_ids) == 0 :
        print("No sequences found in results_df for plotting.")
        return
    sample_sequence_ids = sample_sequence_ids[:min(5, len(sample_sequence_ids))]
    
    if len(sample_sequence_ids) > 0:
        num_plots = len(sample_sequence_ids)
        fig_height_cumulative = max(8, 3 * num_plots) 
        plt.figure(figsize=(15, fig_height_cumulative))
        for i, seq_id in enumerate(sample_sequence_ids):
            seq_data_plot = results_df[results_df['Sequence'] == seq_id].sort_values('Step')
            if seq_data_plot.empty: continue
            plt.subplot(num_plots, 1, i + 1)
            plt.plot(seq_data_plot['Step'], seq_data_plot['GroundTruth_Cumulative'], 'o-', label='Ground Truth Cumul.', markersize=4)
            if 'Predicted_Cumulative' in seq_data_plot.columns:
                 plt.plot(seq_data_plot['Step'], seq_data_plot['Predicted_Cumulative'], 's--', label='Transformer Pred. Cumul.', markersize=4)
            if 'LSTM_Predicted_Cumulative' in seq_data_plot.columns:
                plt.plot(seq_data_plot['Step'], seq_data_plot['LSTM_Predicted_Cumulative'], '^-.', label='LSTM-Refined Pred. Cumul.', markersize=4)
            else:
                print(f"Warning: LSTM_Predicted_Cumulative not found for seq {seq_id}")
            plt.title(f'Cumulative Times: Sequence {seq_id}')
            plt.xlabel('Step')
            plt.ylabel('Cumulative Time')
            plt.legend()
        plt.tight_layout()
        plt.savefig('lstm_refined_cumulative_time_comparison.png')
        print("Saved LSTM-refined cumulative time comparison plot.")
        plt.close()

        fig_height_increment = max(8, 3 * num_plots)
        plt.figure(figsize=(15, fig_height_increment))
        for i, seq_id in enumerate(sample_sequence_ids):
            seq_data_plot = results_df[results_df['Sequence'] == seq_id].sort_values('Step')
            if seq_data_plot.empty: continue
            plt.subplot(num_plots, 1, i + 1)
            plt.plot(seq_data_plot['Step'], seq_data_plot['GroundTruth_Increment'], 'o-', label='Ground Truth Incr.', markersize=4)
            if 'Predicted_Increment' in seq_data_plot.columns:
                 plt.plot(seq_data_plot['Step'], seq_data_plot['Predicted_Increment'], 's--', label='Transformer Pred. Incr.', markersize=4)
            if 'LSTM_Predicted_Increment' in seq_data_plot.columns:
                plt.plot(seq_data_plot['Step'], seq_data_plot['LSTM_Predicted_Increment'], '^-.', label='LSTM-Refined Pred. Incr.', markersize=4)
            else:
                 print(f"Warning: LSTM_Predicted_Increment not found for seq {seq_id}")
            plt.title(f'Time Increments: Sequence {seq_id}')
            plt.xlabel('Step')
            plt.ylabel('Time Increment')
            plt.legend()
        plt.tight_layout()
        plt.savefig('lstm_refined_increment_comparison.png')
        print("Saved LSTM-refined increment comparison plot.")
        plt.close()

    gt_total_times = processed_data.get('y_lstm_target_total_times', np.array([]))
    if lstm_model is not None and 'X_lstm_input' in processed_data:
        X_input_tensor = tf.convert_to_tensor(processed_data['X_lstm_input'], dtype=tf.float32)
        lstm_pred_total_t = lstm_model.predict(X_input_tensor).squeeze()

        if gt_total_times.size > 0 and lstm_pred_total_t.size > 0:
            plt.figure(figsize=(12, 6))
            plt.subplot(1, 2, 1)
            plt.hist(gt_total_times, bins=30, alpha=0.7, label='Ground Truth Total Times')
            plt.hist(lstm_pred_total_t, bins=30, alpha=0.7, label='LSTM Predicted Total Times')
            plt.xlabel('Total Time')
            plt.ylabel('Frequency')
            plt.title('Distribution of Total Times')
            plt.legend()

            plt.subplot(1, 2, 2)
            errors_total_time = gt_total_times - lstm_pred_total_t
            plt.hist(errors_total_time, bins=30, alpha=0.7, color='red')
            plt.xlabel('Prediction Error (GT - Pred)')
            plt.ylabel('Frequency')
            plt.title('Distribution of LSTM Total Time Prediction Errors')
            if errors_total_time.size > 0:
                mean_error_val = errors_total_time.mean()
                plt.axvline(mean_error_val, color='k', linestyle='dashed', linewidth=1, label=f'Mean Error: {mean_error_val:.2f}')
            plt.legend()
            
            plt.tight_layout()
            plt.savefig('lstm_total_time_prediction_analysis.png')
            print("Saved LSTM total time prediction analysis plot.")
            plt.close()
        else:
            print("Warning: Not enough data for total time distribution plots.")
    else:
        print("Warning: LSTM model or input data missing for total time prediction plot.")
    
    print("Visualizations for LSTM (total time approach) completed!")

# %%
def main_lstm_total_time_flow():
    """
    Main function to run the LSTM model for total time prediction.
    """
    try:
        transformer_predictions_file = "predictions_transformer_182625.csv" 
        
        if not os.path.exists(transformer_predictions_file):
            print(f"Error: Transformer predictions file not found: {transformer_predictions_file}")
            print("Attempting to create a DUMMY CSV for testing flow.")
            dummy_data = []
            for seq_idx in range(20): # Increased dummy sequences
                num_steps = np.random.randint(5, 25) # Varied sequence lengths
                steps = np.arange(1, num_steps + 1)
                # Ensure gt_increments are positive and have some variance
                gt_increments = np.random.gamma(2, scale=5, size=num_steps) + 1 
                gt_cumulative = np.cumsum(gt_increments)
                
                # Ensure raw_props are positive and sum to 1 for realistic proportions
                raw_props = np.random.rand(num_steps) + 0.01 
                pred_proportions = raw_props / raw_props.sum() 
                
                # Transformer's predicted increments and cumulative (can be different from GT)
                # For dummy, let's make them somewhat related but not identical to GT
                dummy_transformer_total_time = gt_cumulative[-1] * np.random.uniform(0.8, 1.2) if num_steps > 0 else 0
                pred_increments = pred_proportions * dummy_transformer_total_time
                pred_cumulative = np.cumsum(pred_increments)

                for s_idx in range(num_steps):
                    dummy_data.append({
                        'Sequence': seq_idx,
                        'Step': steps[s_idx],
                        'SourceID': f'MRI_DUMMY_{s_idx%3 +1}',
                        'Predicted_Proportion': pred_proportions[s_idx],
                        'Predicted_Increment': pred_increments[s_idx],
                        'Predicted_Cumulative': pred_cumulative[s_idx],
                        'GroundTruth_Increment': gt_increments[s_idx],
                        'GroundTruth_Cumulative': gt_cumulative[s_idx]
                    })
            if not dummy_data:
                 print("Critical Error: Failed to generate any dummy data steps.")
                 dummy_data.append({
                        'Sequence': 0, 'Step': 1, 'SourceID': 'MRI_DUMMY_0',
                        'Predicted_Proportion': 1.0, 'Predicted_Increment': 10.0,
                        'Predicted_Cumulative': 10.0, 'GroundTruth_Increment': 10.0,
                        'GroundTruth_Cumulative': 10.0
                    }) # Fallback single row

            dummy_df = pd.DataFrame(dummy_data)
            dummy_df.to_csv(transformer_predictions_file, index=False)
            print(f"Dummy '{transformer_predictions_file}' created with {len(dummy_df)} rows and {len(dummy_df['Sequence'].unique())} sequences.")
        
        lstm_model, lstm_history, processed_lstm_data = train_total_time_lstm(
            transformer_predictions_file, epochs=50, batch_size=16 
        )
        
        if lstm_model is None or processed_lstm_data is None:
            print("LSTM training failed or returned None. Exiting.")
            return

        refined_results_df = generate_refined_predictions_with_lstm(lstm_model, processed_lstm_data)
        
        if not refined_results_df.empty:
            print("\nSample of Refined Predictions (LSTM Total Time Approach):")
            display_cols = [
                'Sequence', 'Step', 'SourceID', 
                'Predicted_Increment', # Transformer's original
                'LSTM_Predicted_Increment', # LSTM-refined
                'GroundTruth_Increment',
                'Predicted_Cumulative', # Transformer's original
                'LSTM_Predicted_Cumulative', # LSTM-refined
                'GroundTruth_Cumulative',
                'LSTM_Predicted_TotalTime',
                'Increment_Improvement_Pct'
            ]
            actual_display_cols = [col for col in display_cols if col in refined_results_df.columns]
            print(refined_results_df[actual_display_cols].head(10))
        
            print("\nGenerating visualizations for LSTM (total time approach)...")
            visualize_lstm_results(refined_results_df, processed_lstm_data, lstm_model, lstm_history)
        else:
            print("No refined predictions were generated by the LSTM flow.")
            
    except Exception as e:
        print(f"Error in LSTM (total time) main function: {e}")
        import traceback
        traceback.print_exc()

In [33]:
if __name__ == "__main__":
    main_lstm_total_time_flow()

Processing data for LSTM training...
Processing data from: predictions_transformer_182625.csv
Number of features for LSTM input: 2
Error in LSTM (total time) main function: object of type 'NoneType' has no len()


Traceback (most recent call last):
  File "C:\Users\lukis\AppData\Local\Temp\ipykernel_60648\3529871220.py", line 529, in main_lstm_total_time_flow
    lstm_model, lstm_history, processed_lstm_data = train_total_time_lstm(
  File "C:\Users\lukis\AppData\Local\Temp\ipykernel_60648\3529871220.py", line 229, in train_total_time_lstm
    lstm_model.summary() # Now this should work
  File "C:\Users\lukis\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\keras\src\utils\traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\Users\lukis\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\keras\src\utils\summary_utils.py", line 114, in format_layer_shape
    if len(output_shapes) == 1:
TypeError: object of type 'NoneType' has no len()
