In [7]:
import os
import sys
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
# Assuming pad_sequences is available if needed, though direct use might change.
# from tensorflow.keras.preprocessing.sequence import pad_sequences
import matplotlib.pyplot as plt
import seaborn as sns # Seaborn was commented out in original, keeping it that way
#import test train split sklearn
from sklearn.model_selection import train_test_split

In [None]:
# Helper function to compute increments and cumulative times from proportions and total time
def calculate_times_from_proportions(proportions_per_step, total_time_for_sequence, mask_per_step):
    """
    Calculates time increments and cumulative times from step-wise proportions and a total time.

    The masked proportions of each sequence are renormalised to sum to 1 over its valid
    steps, scaled by that sequence's total time to obtain per-step increments, and then
    accumulated along the step axis. Padded steps are zero in every output.

    Args:
        proportions_per_step (tf.Tensor or np.ndarray):
            Proportions for each step in sequences (batch_size, seq_len).
        total_time_for_sequence (tf.Tensor or np.ndarray):
            The total time for each sequence, shaped (batch_size, 1) or (batch_size,).
            A scalar is also accepted and is broadcast to every sequence.
        mask_per_step (tf.Tensor or np.ndarray):
            Mask indicating valid steps (batch_size, seq_len); 1 = valid, 0 = padding.

    Returns:
        tuple: (normalized_proportions, increments, cumulative_times),
               all float32 tf.Tensor of shape (batch_size, seq_len).
    """
    proportions_tf = tf.cast(proportions_per_step, tf.float32)
    total_time_tf = tf.cast(total_time_for_sequence, tf.float32)
    mask_tf = tf.cast(mask_per_step, tf.float32)

    # Promote (batch_size,) to (batch_size, 1) so it broadcasts across the step axis.
    # Use the static rank: `len(tf.shape(x))` only works on eager tensors and is
    # fragile when this helper is traced inside a tf.function / Keras graph.
    if total_time_tf.shape.rank == 1:
        total_time_tf = tf.expand_dims(total_time_tf, axis=-1)

    # Apply mask to proportions so padded steps contribute nothing.
    masked_proportions = proportions_tf * mask_tf

    # Normalize proportions over valid steps.
    row_sums = tf.reduce_sum(masked_proportions, axis=1, keepdims=True)
    # A sequence that is all padding (or has all-zero proportions) would divide by
    # zero; divide by 1 instead so its outputs stay at 0.
    row_sums = tf.where(tf.equal(row_sums, 0.0), tf.ones_like(row_sums), row_sums)
    normalized_proportions = masked_proportions / row_sums

    # Per-step increments (broadcast total time over steps) and their running sum.
    increments = normalized_proportions * total_time_tf
    cumulative_times = tf.cumsum(increments, axis=1)

    # cumsum carries the running total into padded positions; zero them out
    # (and re-apply the mask to the other outputs for consistency).
    increments *= mask_tf
    cumulative_times *= mask_tf
    normalized_proportions *= mask_tf

    return normalized_proportions, increments, cumulative_times

# %%
class TotalTimeLSTM(tf.keras.Model):
    """
    LSTM model to predict only the total time of a sequence.

    Architecture: LSTM -> Bidirectional LSTM -> multi-head self-attention with a
    residual connection and LayerNormalization -> masked global average pooling
    -> Dense(1) (linear output, one value per sequence).

    Padding is inferred from the inputs: a time step whose features are all exactly
    0.0 is treated as padding. NOTE(review): a genuine step whose features are all
    zero would also be masked out — confirm upstream features can never be all-zero.

    Args:
        hidden_units: Units of each LSTM; the bidirectional output is 2 * hidden_units wide.
        num_heads: Number of attention heads; 2 * hidden_units must be divisible by it.
        dropout_rate: Input and recurrent dropout applied in the first LSTM.

    Raises:
        ValueError: if 2 * hidden_units is not divisible by num_heads.
    """
    def __init__(self, hidden_units=64, num_heads=4, dropout_rate=0.2):
        super(TotalTimeLSTM, self).__init__()

        self.hidden_units = hidden_units
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        # Attention runs over the bidirectional output (2 * hidden_units wide), which
        # must split evenly across heads. Validate before creating any layers.
        if (2 * self.hidden_units) % self.num_heads != 0:
            raise ValueError(f"(2 * hidden_units) must be divisible by num_heads. "
                             f"Got 2 * {self.hidden_units} and {self.num_heads}.")
        mha_key_dim = (2 * self.hidden_units) // self.num_heads

        self.lstm_layer = layers.LSTM(self.hidden_units,
                                      return_sequences=True,
                                      dropout=dropout_rate,
                                      recurrent_dropout=dropout_rate,
                                      name="lstm_1")

        self.bi_lstm = layers.Bidirectional(
            layers.LSTM(self.hidden_units, return_sequences=True, name="lstm_bidirectional_inner"),
            name="bidirectional_lstm_1"
        )

        self.attention = layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=mha_key_dim, name="multi_head_attention_1"
        )
        self.layer_norm_attn = layers.LayerNormalization(name="layer_norm_attention")
        self.global_avg_pool = layers.GlobalAveragePooling1D(name="global_avg_pooling_1d")
        self.total_time_head = layers.Dense(1, activation='linear', name="total_time_dense_output")

    def call(self, inputs, training=False):
        """Return one predicted total time per sequence, shape (batch_size, 1)."""
        # Boolean (batch_size, seq_len) mask: True for valid steps, False for all-zero padding rows.
        mask_bool = tf.reduce_any(tf.not_equal(inputs, 0.0), axis=-1)

        lstm_out = self.lstm_layer(inputs, mask=mask_bool, training=training)
        bi_lstm_out = self.bi_lstm(lstm_out, mask=mask_bool, training=training)

        # MultiHeadAttention documents attention_mask as (B, T, S) (True = attend).
        # A (B, 1, S) key mask broadcasts over every query position and combines
        # cleanly with any mask Keras derives automatically from query/value.
        mha_attention_mask = mask_bool[:, tf.newaxis, :]
        attn_output = self.attention(query=bi_lstm_out, value=bi_lstm_out, key=bi_lstm_out,
                                     attention_mask=mha_attention_mask,
                                     training=training)

        x = self.layer_norm_attn(attn_output + bi_lstm_out)  # Residual connection
        # Average only over valid steps.
        sequence_encoding = self.global_avg_pool(x, mask=mask_bool)
        total_time_pred = self.total_time_head(sequence_encoding)
        # Output is unconstrained (linear); callers clip negatives if needed.
        return total_time_pred

# %%
def process_input_data_for_lstm(transformer_predictions_file):
    """
    Process the transformer predictions CSV to prepare data for LSTM training.

    Each distinct, non-null ``Sequence`` becomes one sample. Its rows are sorted by
    ``Step`` and turned into a (steps, 2) feature matrix
    ``[Predicted_Proportion, Step / max(Step)]``. The regression target (total time)
    is the maximum finite ``GroundTruth_Cumulative`` of the sequence, or 0.0 if none.
    All per-step arrays are zero-padded to the longest sequence.

    Args:
        transformer_predictions_file: Path to a CSV containing the columns
            Predicted_Proportion, GroundTruth_Cumulative, GroundTruth_Increment,
            Sequence and Step.

    Returns:
        dict with padded inputs/targets/masks plus per-sequence metadata. Every
        per-sequence entry ('sequences_ids', 'original_dfs', and row i of each
        padded array / of the target vector) refers to the same sequence.

    Raises:
        FileNotFoundError: if the CSV does not exist.
        ValueError: if a required column is missing or no sequence can be processed.

    NOTE(review): the model infers padding from all-zero feature rows, so a real step
    with Predicted_Proportion == 0 and Step == 0 would look like padding — confirm
    Step never starts at 0 when proportions can be 0.
    """
    print(f"Processing data from: {transformer_predictions_file}")
    if not os.path.exists(transformer_predictions_file):
        raise FileNotFoundError(f"Transformer predictions file not found: {transformer_predictions_file}")
    df = pd.read_csv(transformer_predictions_file)

    required_cols = ['Predicted_Proportion', 'GroundTruth_Cumulative', 'GroundTruth_Increment', 'Sequence', 'Step']
    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"CSV must contain '{col}' column.")

    # Rows with a missing Sequence id cannot belong to any sequence. Keeping the NaN id
    # would produce an empty frame in original_dfs/sequences_ids with no matching row in
    # the padded arrays, shifting every later index out of alignment.
    sequences = df['Sequence'].dropna().unique()
    # Group once instead of re-scanning the whole frame for every sequence.
    groups_by_sequence = {seq_id: seq_group for seq_id, seq_group in df.groupby('Sequence', sort=False)}

    X_data_list = []
    y_total_times_list = []
    transformer_proportions_list = []
    ground_truth_increments_list = []
    ground_truth_cumulative_list = []  # Full GT cumulative path, kept for reference
    original_dfs_list = []

    for seq_id in sequences:
        seq_df = groups_by_sequence[seq_id].sort_values('Step').copy()
        original_dfs_list.append(seq_df)

        # Normalise Step to (0, 1]; guard against a max step of 0.
        current_max_steps = seq_df['Step'].max()
        if current_max_steps == 0:
            current_max_steps = 1

        features = np.column_stack([
            seq_df['Predicted_Proportion'].values,
            seq_df['Step'].values / current_max_steps,
        ])
        X_data_list.append(features)

        # Target total time = max GT cumulative of the sequence, ignoring NaN/inf entries.
        gt_cumulative_for_seq = seq_df['GroundTruth_Cumulative'].to_numpy(dtype=np.float64)
        finite_cumulative = gt_cumulative_for_seq[np.isfinite(gt_cumulative_for_seq)]
        total_time_for_seq = float(finite_cumulative.max()) if finite_cumulative.size > 0 else 0.0
        y_total_times_list.append(total_time_for_seq)

        transformer_proportions_list.append(seq_df['Predicted_Proportion'].values)
        ground_truth_increments_list.append(seq_df['GroundTruth_Increment'].values)
        ground_truth_cumulative_list.append(gt_cumulative_for_seq)

    if not X_data_list:
        raise ValueError("No valid sequences processed. Check CSV content or processing logic.")

    y_total_times_array_unpadded = np.array(y_total_times_list)
    print(f"\nStatistics for TARGET y_total_times_list (max GT_Cumulative per seq, {len(y_total_times_array_unpadded)} sequences):")
    print(f"  Mean: {np.mean(y_total_times_array_unpadded):.4f}, Std Dev: {np.std(y_total_times_array_unpadded):.4f}")
    print(f"  Min: {np.min(y_total_times_array_unpadded):.4f}, Max: {np.max(y_total_times_array_unpadded):.4f}")
    print(f"  Number of zeros (<=1e-6): {np.sum(y_total_times_array_unpadded <= 1e-6)}")
    print(f"  Number non-positive (<=0): {np.sum(y_total_times_array_unpadded <= 0)}\n")

    num_sequences = len(X_data_list)
    max_length = max(len(x) for x in X_data_list)
    num_features = X_data_list[0].shape[1]

    X_padded = np.zeros((num_sequences, max_length, num_features), dtype=np.float32)
    masks_padded_float = np.zeros((num_sequences, max_length), dtype=np.float32)
    transformer_proportions_padded = np.zeros((num_sequences, max_length), dtype=np.float32)
    # Original GT increments and full GT cumulative path, for the final CSV comparison.
    gt_increments_padded_original = np.zeros((num_sequences, max_length), dtype=np.float32)
    gt_cumulative_padded_original = np.zeros((num_sequences, max_length), dtype=np.float32)

    # Every group from groupby is non-empty, so each sequence has at least one step.
    for i, features in enumerate(X_data_list):
        seq_len = len(features)
        X_padded[i, :seq_len, :] = features
        masks_padded_float[i, :seq_len] = 1.0
        transformer_proportions_padded[i, :seq_len] = transformer_proportions_list[i]
        gt_increments_padded_original[i, :seq_len] = ground_truth_increments_list[i]
        gt_cumulative_padded_original[i, :seq_len] = ground_truth_cumulative_list[i]

    y_total_times_np = np.array(y_total_times_list, dtype=np.float32)

    return {
        'X_lstm_input': X_padded,
        'y_lstm_target_total_times': y_total_times_np,  # Max GT_Cumulative per sequence
        'masks_for_calc': masks_padded_float,
        'sequences_ids': sequences,
        'original_dfs': original_dfs_list,
        'transformer_proportions_padded': transformer_proportions_padded,
        'gt_increments_padded_original': gt_increments_padded_original,
        'gt_cumulative_padded_original': gt_cumulative_padded_original,
        'max_len': max_length,
        'num_features': num_features
    }

# %%
def train_total_time_lstm(transformer_predictions_file, epochs=50, batch_size=32, val_split_ratio=0.2,
                          hidden_units=64, num_heads=4, dropout_rate=0.2,
                          learning_rate=0.001, patience=20):
    """
    Train LSTM model to predict total_time with manual validation split and checks.

    Args:
        transformer_predictions_file: CSV path passed to process_input_data_for_lstm.
        epochs: Maximum number of training epochs.
        batch_size: Mini-batch size for fit().
        val_split_ratio: Fraction of sequences held out for validation (used only
            when there are at least 5 sequences).
        hidden_units, num_heads, dropout_rate: TotalTimeLSTM hyperparameters.
        learning_rate: Adam learning rate.
        patience: EarlyStopping patience (epochs without improvement).

    Returns:
        (lstm_model, history, data_for_lstm): the trained model, the History object
        returned by fit(), and the processed-data dict.
    """
    print("Processing data for LSTM training...")
    data_for_lstm = process_input_data_for_lstm(transformer_predictions_file)

    print(f"Number of features for LSTM input: {data_for_lstm['num_features']}")
    print(f"Max sequence length for LSTM input: {data_for_lstm['max_len']}")

    lstm_model = TotalTimeLSTM(hidden_units=hidden_units, num_heads=num_heads, dropout_rate=dropout_rate)

    # Build with an explicit input shape so the weights exist before summary().
    input_shape = (None, data_for_lstm['max_len'], data_for_lstm['num_features'])
    lstm_model.build(input_shape=input_shape)

    print("\nLSTM Model Summary:")
    lstm_model.summary()

    lstm_model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='mse'
    )

    y_targets_all = data_for_lstm['y_lstm_target_total_times']
    X_inputs_all = data_for_lstm['X_lstm_input']

    # Basic sanity checks; these only warn so a diagnostic run can still proceed.
    if np.any(np.isnan(X_inputs_all)) or np.any(np.isinf(X_inputs_all)):
        print("CRITICAL WARNING: NaN or Inf found in X_inputs_all. Training may fail or be unstable.")
    if np.any(np.isnan(y_targets_all)) or np.any(np.isinf(y_targets_all)):
        print("CRITICAL WARNING: NaN or Inf found in y_targets_all. Training may fail or be unstable.")
    if len(y_targets_all) > 0 and np.all(np.abs(y_targets_all) <= 1e-6):
        print("CRITICAL WARNING: All target total times are near zero. Model will likely predict zero and not learn effectively.")

    if len(X_inputs_all) < 5:
        print("Warning: Very few samples (<5), using all for training and no validation during fit.")
        X_train, y_train = X_inputs_all, y_targets_all
        validation_data_for_fit = None
    else:
        X_train, X_val, y_train, y_val = train_test_split(
            X_inputs_all, y_targets_all, test_size=val_split_ratio, random_state=42, shuffle=True
        )
        validation_data_for_fit = (X_val, y_val)
        print(f"\nManually split data: {len(X_train)} train, {len(X_val)} validation samples.")
        print("Training target statistics (y_train):")
        print(f"  Mean: {np.mean(y_train):.4f}, Std: {np.std(y_train):.4f}, Min: {np.min(y_train):.4f}, Max: {np.max(y_train):.4f}")
        print(f"  Number of zeros (<=1e-6): {np.sum(np.abs(y_train) <= 1e-6)}")
        if len(y_val) > 0:
            print("Validation target statistics (y_val):")
            print(f"  Mean: {np.mean(y_val):.4f}, Std: {np.std(y_val):.4f}, Min: {np.min(y_val):.4f}, Max: {np.max(y_val):.4f}")
            print(f"  Number of zeros (<=1e-6): {np.sum(np.abs(y_val) <= 1e-6)}\n")
            if np.all(np.abs(y_val) <= 1e-6):
                print("CRITICAL WARNING: All validation targets (y_val) are effectively zero. val_loss will likely be zero.")
        else:
            print("No validation samples after split.\n")

    # Without validation data fit() never logs 'val_loss'; monitoring it would only
    # emit warnings and restore_best_weights would never trigger. Fall back to 'loss'.
    monitored_metric = 'val_loss' if validation_data_for_fit is not None else 'loss'
    early_stopping = EarlyStopping(
        monitor=monitored_metric,
        patience=patience,
        restore_best_weights=True,
        verbose=1
    )

    print("Starting LSTM model training...")

    history = lstm_model.fit(
        X_train,
        y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=validation_data_for_fit,
        callbacks=[early_stopping],
        verbose=1
    )

    print("LSTM training finished.")
    return lstm_model, history, data_for_lstm

# %%
def _add_error_columns(seq_df, gt_col, transformer_col, lstm_col, prefix):
    """Add per-row abs errors and % improvement of LSTM over Transformer to seq_df (in place)."""
    if gt_col not in seq_df.columns or transformer_col not in seq_df.columns:
        return
    gt_values = seq_df[gt_col].fillna(0)
    err_transformer = np.abs(gt_values - seq_df[transformer_col].fillna(0))
    err_lstm = np.abs(gt_values - seq_df[lstm_col].fillna(0))

    seq_df[f'{prefix}_MAE_Transformer'] = err_transformer
    seq_df[f'{prefix}_MAE_LSTM'] = err_lstm
    # Improvement is undefined where the Transformer error is ~0; report 0 there.
    # np.where evaluates both branches, so silence the masked-out divide warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        seq_df[f'{prefix}_Improvement_Pct'] = np.where(
            err_transformer > 1e-6, (err_transformer - err_lstm) / err_transformer * 100, 0)


def _print_mae_summary(results_df, prefix, title, label):
    """Print mean Transformer vs LSTM abs error for one metric family, if its columns exist."""
    transformer_col, lstm_col = f'{prefix}_MAE_Transformer', f'{prefix}_MAE_LSTM'
    if transformer_col not in results_df.columns or lstm_col not in results_df.columns:
        return
    # Skip NaNs in case some rows were skipped or had issues.
    transformer_errors = results_df[transformer_col].dropna()
    lstm_errors = results_df[lstm_col].dropna()
    if transformer_errors.empty or lstm_errors.empty:
        return
    avg_transformer = transformer_errors.mean()
    avg_lstm = lstm_errors.mean()
    print(f"\n--- Mean Absolute Error for {title} ---")
    print(f"Transformer MAE ({label}): {avg_transformer:.4f}")
    print(f"LSTM-Refined MAE ({label}): {avg_lstm:.4f}")
    if avg_transformer > 1e-6:
        improvement = (avg_transformer - avg_lstm) / avg_transformer * 100
        print(f"Improvement ({label}): {improvement:.2f}%")


def generate_refined_predictions_with_lstm(lstm_model, processed_data,
                                           output_filename='predictions_lstm_refined_total_time_v4.csv'):
    """
    Generate refined time predictions using LSTM's total_time.

    The LSTM's predicted total time (clipped at 0) is redistributed over the steps of
    each sequence using the Transformer's per-step proportions. Results are joined onto
    each sequence's original rows, together with per-row error columns, then saved to
    ``output_filename`` and summarised on stdout.

    Args:
        lstm_model: Trained model whose predict() returns one total time per sequence.
        processed_data: Dict produced by process_input_data_for_lstm.
        output_filename: CSV path for the combined results.

    Returns:
        pd.DataFrame with the combined results (empty if nothing could be produced).
    """
    print("Generating refined predictions using LSTM's total time...")

    raw_predictions = lstm_model.predict(processed_data['X_lstm_input'])
    # reshape(-1) (not squeeze) keeps a 1-D array even for a single sequence; squeeze
    # would yield a 0-d array and break len() / [i] indexing below.
    lstm_predicted_total_times = np.reshape(raw_predictions, -1)
    # Ensure predicted total times are non-negative.
    lstm_predicted_total_times = np.maximum(0, lstm_predicted_total_times)

    _, lstm_refined_increments, lstm_refined_cumulative = calculate_times_from_proportions(
        processed_data['transformer_proportions_padded'],
        lstm_predicted_total_times,
        processed_data['masks_for_calc']
    )
    refined_increments_np = lstm_refined_increments.numpy()
    refined_cumulative_np = lstm_refined_cumulative.numpy()

    original_dfs = processed_data['original_dfs']  # Per-sequence frames as read from the CSV
    results_list_df = []

    for i, seq_id in enumerate(processed_data['sequences_ids']):
        if i >= len(original_dfs):
            print(f"Warning: Index {i} out of bounds for original_dfs_from_processing. Skipping sequence ID {seq_id}.")
            continue

        seq_df = original_dfs[i].copy()
        seq_len = len(seq_df)
        if seq_len == 0:
            results_list_df.append(seq_df)
            continue

        if i >= len(lstm_predicted_total_times):
            print(f"Warning: Index {i} out of bounds for lstm_predicted_total_times. Skipping sequence ID {seq_id}.")
            continue

        seq_df['LSTM_Predicted_TotalTime'] = lstm_predicted_total_times[i]
        seq_df['LSTM_Predicted_Increment'] = refined_increments_np[i, :seq_len]
        seq_df['LSTM_Predicted_Cumulative'] = refined_cumulative_np[i, :seq_len]

        _add_error_columns(seq_df, 'GroundTruth_Increment', 'Predicted_Increment',
                           'LSTM_Predicted_Increment', 'Increment')
        _add_error_columns(seq_df, 'GroundTruth_Cumulative', 'Predicted_Cumulative',
                           'LSTM_Predicted_Cumulative', 'Cumulative')

        results_list_df.append(seq_df)

    if not results_list_df:
        print("Warning: No results to concatenate for the final CSV after processing sequences.")
        return pd.DataFrame()

    final_results_df = pd.concat(results_list_df, ignore_index=True)

    final_results_df.to_csv(output_filename, index=False)
    print(f"Combined and refined predictions saved to {output_filename}")

    if not final_results_df.empty:
        _print_mae_summary(final_results_df, 'Increment', 'Increments', 'Increments')
        _print_mae_summary(final_results_df, 'Cumulative', 'Cumulative Times', 'Cumulative')

    # LSTM total-time performance against its own training targets (max GT cumulative).
    gt_total_times_for_lstm_training = processed_data['y_lstm_target_total_times']
    if len(lstm_predicted_total_times) == len(gt_total_times_for_lstm_training):
        mae_total_time_lstm = np.mean(np.abs(gt_total_times_for_lstm_training - lstm_predicted_total_times))
        print("\n--- LSTM Total Time Prediction Performance (vs Max GT Cumulative) ---")
        print(f"MAE for LSTM Predicted Total Time: {mae_total_time_lstm:.4f}")
    else:
        print("\nWarning: Mismatch in lengths for GT total times (max GT cum) and LSTM predicted total times. Cannot compute overall MAE for total time.")
    return final_results_df

# %%
def _save_training_loss_plot(training_history):
    """Plot training (and, when logged, validation) loss from a Keras History object."""
    if not (training_history and hasattr(training_history, 'history')):
        print("Warning: No training history provided or history object is not as expected.")
        return
    history_dict = training_history.history
    if 'loss' not in history_dict:
        print("Warning: Training history does not contain 'loss' or 'val_loss' keys.")
        return

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(history_dict['loss'], label='Training Loss')
    # Training can run without validation data (very small datasets), so val_loss
    # may be absent; still plot the training curve in that case.
    if 'val_loss' in history_dict:
        ax.plot(history_dict['val_loss'], label='Validation Loss')
    ax.set_title('LSTM Model Loss (Predicting Total Time)')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Mean Squared Error (Loss)')
    ax.legend()
    fig.tight_layout()
    fig.savefig('lstm_total_time_training_loss_v4.png')
    print("Saved LSTM training loss plot.")
    plt.close(fig)


def _save_sequence_comparison_plot(results_df, sequence_ids, series_specs, title_prefix, ylabel, filename):
    """One stacked panel per sequence, plotting each (column, fmt, label) that exists; save to filename."""
    num_plots = len(sequence_ids)
    fig, axes = plt.subplots(num_plots, 1, figsize=(15, max(8, 3 * num_plots)), squeeze=False)
    for ax, seq_id in zip(axes[:, 0], sequence_ids):
        seq_data_plot = results_df[results_df['Sequence'] == seq_id].sort_values('Step')
        if seq_data_plot.empty:
            continue
        for column, fmt, label in series_specs:
            if column in seq_data_plot.columns:
                ax.plot(seq_data_plot['Step'], seq_data_plot[column], fmt, label=label, ms=4)
        ax.set_title(f'{title_prefix}: Seq {seq_id}')
        ax.set_xlabel('Step')
        ax.set_ylabel(ylabel)
        ax.legend()
    fig.tight_layout()
    fig.savefig(filename)
    plt.close(fig)


def _save_total_time_analysis_plot(gt_total_times, predicted_total_times):
    """Histogram GT vs predicted total times, plus the prediction-error distribution."""
    fig, (ax_dist, ax_err) = plt.subplots(1, 2, figsize=(12, 6))
    ax_dist.hist(gt_total_times, bins=30, alpha=0.7, label='GT Total Times (Max Cumul.)')
    ax_dist.hist(predicted_total_times, bins=30, alpha=0.7, label='LSTM Pred Total Times')
    ax_dist.set_xlabel('Total Time')
    ax_dist.set_ylabel('Frequency')
    ax_dist.set_title('Distribution of Total Times')
    ax_dist.legend()

    if len(gt_total_times) == len(predicted_total_times):
        errors_total_time = gt_total_times - predicted_total_times
        ax_err.hist(errors_total_time, bins=30, alpha=0.7, color='red')
        ax_err.set_xlabel('Prediction Error (GT Max Cumul. - Pred)')
        ax_err.set_ylabel('Frequency')
        ax_err.set_title('LSTM Total Time Prediction Errors')
        if errors_total_time.size > 0:
            mean_error_val = errors_total_time.mean()
            ax_err.axvline(mean_error_val, color='k', ls='--', lw=1, label=f'Mean Error: {mean_error_val:.2f}')
            ax_err.legend()
    else:
        print("Warning: Mismatch length GT total times and predictions for error histogram.")

    fig.tight_layout()
    fig.savefig('lstm_total_time_prediction_analysis_v4.png')
    print("Saved LSTM total time prediction analysis plot.")
    plt.close(fig)


def visualize_lstm_results(results_df, processed_data, lstm_model, training_history):
    """
    Save diagnostic plots for the LSTM total-time approach (PNG files in the CWD).

    Produces: the training loss curve, per-sequence cumulative-time and increment
    comparisons for up to 5 sequences, and total-time distribution / error histograms.

    Args:
        results_df: DataFrame returned by generate_refined_predictions_with_lstm.
        processed_data: Dict produced by process_input_data_for_lstm.
        lstm_model: Trained model (re-run here to obtain per-sequence total times), or None.
        training_history: History object returned by fit(), or None.
    """
    if results_df.empty:
        print("Results DataFrame is empty, skipping visualizations.")
        return
    print("Generating visualizations for LSTM results...")
    plt.style.use('ggplot')

    _save_training_loss_plot(training_history)

    sample_sequence_ids = results_df['Sequence'].unique()
    if len(sample_sequence_ids) == 0:
        print("No sequences found in results_df for plotting.")
        return
    sample_sequence_ids = sample_sequence_ids[:5]

    _save_sequence_comparison_plot(
        results_df, sample_sequence_ids,
        [('GroundTruth_Cumulative', 'o-', 'GT Cumul.'),
         ('Predicted_Cumulative', 's--', 'Transformer Cumul.'),
         ('LSTM_Predicted_Cumulative', '^-.', 'LSTM-Refined Cumul.')],
        'Cumulative Times', 'Cumulative Time', 'lstm_refined_cumulative_time_comparison_v4.png')
    print("Saved LSTM-refined cumulative time comparison plot.")

    _save_sequence_comparison_plot(
        results_df, sample_sequence_ids,
        [('GroundTruth_Increment', 'o-', 'GT Incr.'),
         ('Predicted_Increment', 's--', 'Transformer Incr.'),
         ('LSTM_Predicted_Increment', '^-.', 'LSTM-Refined Incr.')],
        'Time Increments', 'Time Increment', 'lstm_refined_increment_comparison_v4.png')
    print("Saved LSTM-refined increment comparison plot.")

    # Target total times used for LSTM training (max GT cumulative per sequence).
    gt_total_times_for_lstm_training = processed_data.get('y_lstm_target_total_times', np.array([]))

    if lstm_model is None or 'X_lstm_input' not in processed_data:
        print("Warning: LSTM model or input data missing for total time prediction plot.")
    else:
        X_input_tensor = tf.convert_to_tensor(processed_data['X_lstm_input'], dtype=tf.float32)
        # reshape(-1) keeps a 1-D array even for a single sequence. Clip negatives exactly
        # as generate_refined_predictions_with_lstm does, so the plot reflects the total
        # times actually used for the refined predictions.
        lstm_pred_total_t_for_plot = np.maximum(0, np.reshape(lstm_model.predict(X_input_tensor), -1))

        if gt_total_times_for_lstm_training.size > 0 and lstm_pred_total_t_for_plot.size > 0:
            _save_total_time_analysis_plot(gt_total_times_for_lstm_training, lstm_pred_total_t_for_plot)
        else:
            print("Warning: Not enough data for total time distribution plots (GT or Pred).")
    print("Visualizations for LSTM (total time approach) completed!")

# %%
def _write_dummy_transformer_predictions(csv_path, num_sequences=150, seed=42):
    """
    Write a synthetic transformer-predictions CSV so the flow can run without real data.

    For each sequence: gamma-distributed GT increments (floored at 0.01) and their
    cumulative sum, random proportions normalised to 1, and "Transformer" increments
    obtained by applying those proportions to the GT total time perturbed by x0.7-1.3.
    A seeded generator makes the file reproducible.

    Returns:
        The DataFrame that was written.
    """
    rng = np.random.default_rng(seed)
    dummy_data = []
    for seq_idx in range(num_sequences):
        num_steps = int(rng.integers(5, 40))
        steps = np.arange(1, num_steps + 1)
        gt_increments = (rng.gamma(shape=2.5, scale=rng.uniform(low=3.0, high=12.0), size=num_steps)
                         + rng.uniform(low=0.1, high=2.5))
        gt_increments = np.maximum(gt_increments, 0.01)  # Ensure positive
        gt_cumulative = np.cumsum(gt_increments)

        raw_props = rng.random(num_steps) + 0.01
        pred_proportions = raw_props / raw_props.sum()

        # The LSTM's target is the GT total (max cumulative); the dummy Transformer
        # applies its proportions to a perturbed version of that total.
        actual_sequence_total_time = max(gt_cumulative[-1], 0.1)
        dummy_transformer_effective_total_time = max(actual_sequence_total_time * rng.uniform(0.7, 1.3), 0.1)

        pred_increments_from_transformer = pred_proportions * dummy_transformer_effective_total_time
        pred_cumulative_from_transformer = np.cumsum(pred_increments_from_transformer)

        for s_idx in range(num_steps):
            dummy_data.append({
                'Sequence': seq_idx, 'Step': steps[s_idx], 'SourceID': f'MRI_DUMMY_{s_idx%4 +1}',
                'Predicted_Proportion': pred_proportions[s_idx],
                'Predicted_Increment': pred_increments_from_transformer[s_idx],
                'Predicted_Cumulative': pred_cumulative_from_transformer[s_idx],
                'GroundTruth_Increment': gt_increments[s_idx],
                'GroundTruth_Cumulative': gt_cumulative[s_idx]
            })
    if not dummy_data:
        dummy_data.append({'Sequence': 0, 'Step': 1, 'SourceID': 'MRI_DUMMY_0', 'Predicted_Proportion': 1.0,
                           'Predicted_Increment': 10.0, 'Predicted_Cumulative': 10.0,
                           'GroundTruth_Increment': 10.0, 'GroundTruth_Cumulative': 10.0})
    dummy_df = pd.DataFrame(dummy_data)
    dummy_df.to_csv(csv_path, index=False)
    print(f"Dummy '{csv_path}' created with {len(dummy_df)} rows and {len(dummy_df['Sequence'].unique())} sequences.")
    return dummy_df


def main_lstm_total_time_flow(transformer_predictions_file="predictions_transformer_182625.csv",
                              epochs=100, batch_size=16, dummy_seed=42):
    """
    End-to-end flow: train the total-time LSTM, refine predictions, and plot results.

    If ``transformer_predictions_file`` does not exist, a reproducible dummy CSV is
    written there first (seeded by ``dummy_seed``). Any exception is reported with a
    full traceback instead of propagating, so a notebook run prints the failure.
    """
    try:
        if not os.path.exists(transformer_predictions_file):
            print(f"Error: Transformer predictions file not found: {transformer_predictions_file}")
            print("Attempting to create a DUMMY CSV for testing flow.")
            _write_dummy_transformer_predictions(transformer_predictions_file, seed=dummy_seed)

        lstm_model, lstm_history, processed_lstm_data = train_total_time_lstm(
            transformer_predictions_file, epochs=epochs, batch_size=batch_size)

        if lstm_model is None or processed_lstm_data is None:
            print("LSTM training failed or returned None. Exiting.")
            return

        refined_results_df = generate_refined_predictions_with_lstm(lstm_model, processed_lstm_data)
        if refined_results_df.empty:
            print("No refined predictions were generated by the LSTM flow.")
            return

        print("\nSample of Refined Predictions (LSTM Total Time Approach):")
        display_cols = ['Sequence', 'Step', 'SourceID',
                        'Predicted_Increment', 'LSTM_Predicted_Increment', 'GroundTruth_Increment',
                        'Predicted_Cumulative', 'LSTM_Predicted_Cumulative', 'GroundTruth_Cumulative',
                        'LSTM_Predicted_TotalTime', 'Increment_Improvement_Pct']
        actual_display_cols = [col for col in display_cols if col in refined_results_df.columns]
        print(refined_results_df[actual_display_cols].head(10))
        print("\nGenerating visualizations for LSTM (total time approach)...")
        visualize_lstm_results(refined_results_df, processed_lstm_data, lstm_model, lstm_history)
    except Exception as e:
        import traceback
        print(f"Error in LSTM (total time) main function: {e}")
        traceback.print_exc()

In [9]:
# Run the whole pipeline when executed directly (script or notebook cell), not on import.
if "__main__" == __name__:
    main_lstm_total_time_flow()

Processing data for LSTM training...
Processing data from: predictions_transformer_182625.csv

Statistics for y_total_times_list (unpadded, 186 sequences):
  Mean: 0.0000, Std Dev: 0.0000
  Min: 0.0000, Max: 0.0000
  Number of zeros (<=1e-6): 186
  Number non-positive (<=0): 186

Number of features for LSTM input: 2
Max sequence length for LSTM input: 42

LSTM Model Summary:





Manually split data: 148 train, 38 validation samples.
Training target statistics (y_train):
  Mean: 0.0000, Std: 0.0000, Min: 0.0000, Max: 0.0000
  Number of zeros (<=1e-6): 148
Validation target statistics (y_val):
  Mean: 0.0000, Std: 0.0000, Min: 0.0000, Max: 0.0000
  Number of zeros (<=1e-6): 38

Starting LSTM model training...
Epoch 1/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 90ms/step - loss: 0.9654 - val_loss: 0.1496
Epoch 2/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - loss: 0.1066 - val_loss: 0.1212
Epoch 3/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - loss: 0.0825 - val_loss: 0.0271
Epoch 4/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step - loss: 0.0369 - val_loss: 4.1180e-04
Epoch 5/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - loss: 0.0169 - val_loss: 0.0012
Epoch 6/100
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m