In [4]:
import os
import sys
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
# Assuming pad_sequences is available if needed, though direct use might change.
# from tensorflow.keras.preprocessing.sequence import pad_sequences
import matplotlib.pyplot as plt
import seaborn as sns # Seaborn was commented out in original, keeping it that way
#import test train split sklearn
from sklearn.model_selection import train_test_split
#import standard scaler
from sklearn.preprocessing import StandardScaler

In [5]:
# Helper function to compute increments and cumulative times from proportions and total time
def calculate_times_from_proportions(proportions_per_step, total_time_for_sequence, mask_per_step):
    """
    Calculates time increments and cumulative times from step-wise proportions and a total time.

    The masked proportions of each row are renormalised to sum to 1 and then scaled by that
    row's total time, so the increments of every row add up to its total time.

    Args:
        proportions_per_step: 2-D array-like (rows x padded steps) of per-step proportions.
        total_time_for_sequence: per-row total times, shape (rows,) or (rows, 1); a scalar
            is broadcast to every row.
        mask_per_step: array-like shaped like ``proportions_per_step``; 1 for real steps,
            0 for padding.

    Returns:
        Tuple ``(normalized_proportions, increments, cumulative_times)`` of float32 tensors
        shaped like ``proportions_per_step``, with padded positions set to 0.
    """
    proportions_tf = tf.cast(proportions_per_step, tf.float32)
    total_time_tf = tf.cast(total_time_for_sequence, tf.float32)
    mask_tf = tf.cast(mask_per_step, tf.float32)

    # Use the static rank instead of len(tf.shape(...)): len() on a tensor raises TypeError
    # for symbolic tensors (inside tf.function / graph mode), while .shape.rank always works.
    if total_time_tf.shape.rank == 1:
        # (rows,) -> (rows, 1) so the total time broadcasts across the step axis.
        total_time_tf = tf.expand_dims(total_time_tf, axis=-1)

    masked_proportions = proportions_tf * mask_tf
    row_sums = tf.reduce_sum(masked_proportions, axis=1, keepdims=True)
    # Rows whose masked proportions are all zero would divide by zero; divide by 1 instead.
    row_sums = tf.where(tf.equal(row_sums, 0), tf.ones_like(row_sums), row_sums)
    normalized_proportions = masked_proportions / row_sums
    increments = normalized_proportions * total_time_tf
    cumulative_times = tf.cumsum(increments, axis=1)

    # cumsum carries the running total into the padded positions, so zero them out explicitly.
    increments *= mask_tf
    cumulative_times *= mask_tf
    normalized_proportions *= mask_tf

    return normalized_proportions, increments, cumulative_times

# %%
class TotalTimeLSTM(tf.keras.Model):
    """
    Regressor mapping one padded feature sequence to a single total-time value.

    Pipeline: masked LSTM (full sequence output) -> masked global average pooling
    -> Dense(1) with linear activation.

    Args:
        hidden_units: number of LSTM units.
        dropout_rate: applied as both input dropout and recurrent dropout of the LSTM.
    """
    def __init__(self, hidden_units=64, dropout_rate=0.2):
        super().__init__()

        self.hidden_units = hidden_units
        self.dropout_rate = dropout_rate

        # Layers are created in a fixed order with fixed names so summaries stay stable.
        self.lstm_layer = layers.LSTM(
            units=self.hidden_units,
            return_sequences=True,
            dropout=self.dropout_rate,
            recurrent_dropout=self.dropout_rate,
            name="lstm_simplified",
        )
        self.global_avg_pool = layers.GlobalAveragePooling1D(
            name="global_avg_pooling_simplified"
        )
        self.total_time_head = layers.Dense(
            1,
            activation='linear',
            name="total_time_dense_simplified",
        )

    def call(self, inputs, training=False):
        # A timestep counts as padding when every one of its features is exactly 0.0.
        # NOTE(review): a genuine step whose features are all zero would be masked as well.
        step_is_real = tf.math.reduce_any(tf.math.not_equal(inputs, 0.0), axis=-1)
        sequence_states = self.lstm_layer(inputs, mask=step_is_real, training=training)
        pooled_state = self.global_avg_pool(sequence_states, mask=step_is_real)
        return self.total_time_head(pooled_state)

# %%
def process_input_data_for_lstm(transformer_predictions_file):
    """
    Process the transformer predictions CSV to prepare data for LSTM training.

    Each sequence (rows sharing a ``Sequence`` id, ordered by ``Step``) becomes one sample
    with two features per step: ``Predicted_Proportion`` and ``Step`` divided by the
    sequence's maximum step. The target total time is the max ``GroundTruth_Cumulative``
    of each sequence. Per-sequence arrays are zero-padded on the right to the longest
    sequence.

    Rows whose ``Sequence`` id is missing (NaN) are ignored. Every per-sequence output
    (``X_lstm_input`` rows, targets, ``sequences_ids``, ``original_dfs``, padded arrays)
    uses the same ordering: first appearance of the id in the CSV.

    Args:
        transformer_predictions_file: path to the transformer predictions CSV.

    Returns:
        dict with keys ``X_lstm_input``, ``y_lstm_target_total_times``, ``masks_for_calc``,
        ``sequences_ids``, ``original_dfs``, ``transformer_proportions_padded``,
        ``gt_increments_padded_original``, ``gt_cumulative_padded_original``, ``max_len``
        and ``num_features``.

    Raises:
        FileNotFoundError: the CSV does not exist.
        ValueError: a required column is missing, or no sequence could be processed.
    """
    print(f"Processing data from: {transformer_predictions_file}")
    if not os.path.exists(transformer_predictions_file):
        raise FileNotFoundError(f"Transformer predictions file not found: {transformer_predictions_file}")
    df = pd.read_csv(transformer_predictions_file)

    required_cols = ['Predicted_Proportion', 'GroundTruth_Cumulative', 'GroundTruth_Increment', 'Sequence', 'Step']
    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"CSV must contain '{col}' column.")

    # Group once instead of boolean-filtering the whole frame per sequence (O(rows x sequences)).
    # sort=False keeps first-appearance order, matching Series.unique(); NaN ids are dropped on
    # both sides. Previously a NaN id yielded an empty frame that was appended to original_dfs
    # but not to X, misaligning original_dfs/sequences_ids with the model input rows.
    sequences = df['Sequence'].dropna().unique()
    rows_by_sequence = {seq_key: seq_rows for seq_key, seq_rows in df.groupby('Sequence', sort=False)}

    X_data_list = []
    y_total_times_list = []
    original_dfs_list = []
    transformer_proportions_list = []
    ground_truth_increments_list = []
    ground_truth_cumulative_list = []

    for seq_id in sequences:
        seq_df = rows_by_sequence[seq_id].sort_values('Step').copy()
        original_dfs_list.append(seq_df)

        current_max_steps = seq_df['Step'].max()
        if current_max_steps == 0:
            # Avoid division by zero when the only step index is 0.
            current_max_steps = 1

        features = np.column_stack([
            seq_df['Predicted_Proportion'].values,
            seq_df['Step'].values / current_max_steps
        ])
        X_data_list.append(features)

        gt_cumulative_for_seq = seq_df['GroundTruth_Cumulative'].values
        total_time_for_seq = np.max(gt_cumulative_for_seq) if len(gt_cumulative_for_seq) > 0 else 0.0
        y_total_times_list.append(total_time_for_seq)

        transformer_proportions_list.append(seq_df['Predicted_Proportion'].values)
        ground_truth_increments_list.append(seq_df['GroundTruth_Increment'].values)
        ground_truth_cumulative_list.append(gt_cumulative_for_seq)

    if not X_data_list:
        raise ValueError("No valid sequences processed. Check CSV content or processing logic.")

    y_total_times_array_unpadded = np.array(y_total_times_list)
    print(f"\nStatistics for TARGET y_total_times_list (max GT_Cumulative per seq, {len(y_total_times_array_unpadded)} sequences):")
    print(f"  Mean: {np.mean(y_total_times_array_unpadded):.4f}, Std Dev: {np.std(y_total_times_array_unpadded):.4f}")
    print(f"  Min: {np.min(y_total_times_array_unpadded):.4f}, Max: {np.max(y_total_times_array_unpadded):.4f}")
    print(f"  Number of zeros (<=1e-6): {np.sum(y_total_times_array_unpadded <= 1e-6)}")
    print(f"  Number non-positive (<=0): {np.sum(y_total_times_array_unpadded <= 0)}\n")

    max_length = max(len(x) for x in X_data_list)
    if max_length == 0:
        raise ValueError("Max length is 0 after processing sequences.")
    num_features = X_data_list[0].shape[1] if X_data_list[0].shape[0] > 0 else 0
    if num_features == 0:
        raise ValueError("Number of features is 0 after processing sequences.")

    num_sequences = len(X_data_list)
    X_padded = np.zeros((num_sequences, max_length, num_features), dtype=np.float32)
    masks_padded_float = np.zeros((num_sequences, max_length), dtype=np.float32)
    transformer_proportions_padded = np.zeros((num_sequences, max_length), dtype=np.float32)
    gt_increments_padded_original = np.zeros((num_sequences, max_length), dtype=np.float32)
    gt_cumulative_padded_original = np.zeros((num_sequences, max_length), dtype=np.float32)

    for i, features in enumerate(X_data_list):
        seq_len = len(features)
        if seq_len > 0:
            # Right-padding: real steps first, zeros after (mask marks the real steps with 1).
            X_padded[i, :seq_len, :] = features
            masks_padded_float[i, :seq_len] = 1.0
            transformer_proportions_padded[i, :seq_len] = transformer_proportions_list[i]
            gt_increments_padded_original[i, :seq_len] = ground_truth_increments_list[i]
            gt_cumulative_padded_original[i, :seq_len] = ground_truth_cumulative_list[i]

    y_total_times_np = np.array(y_total_times_list, dtype=np.float32)

    return {
        'X_lstm_input': X_padded,
        'y_lstm_target_total_times': y_total_times_np,
        'masks_for_calc': masks_padded_float,
        'sequences_ids': sequences,
        'original_dfs': original_dfs_list,
        'transformer_proportions_padded': transformer_proportions_padded,
        'gt_increments_padded_original': gt_increments_padded_original,
        'gt_cumulative_padded_original': gt_cumulative_padded_original,
        'max_len': max_length,
        'num_features': num_features
    }

# %%
def train_total_time_lstm(transformer_predictions_file, epochs=50, batch_size=32, val_split_ratio=0.2,
                          hidden_units=64, dropout_rate=0.2, learning_rate=0.001, patience=20,
                          random_state=42):
    """
    Train the simplified LSTM model to predict total_time, with target scaling.

    Targets are standardised with a StandardScaler that is fitted on the training split only.
    The fitted scaler is stored in ``data_for_lstm['target_scaler']`` so predictions can be
    mapped back to the original time scale.

    Args:
        transformer_predictions_file: CSV path passed to ``process_input_data_for_lstm``.
        epochs: maximum number of training epochs (forwarded to ``fit``).
        batch_size: batch size (forwarded to ``fit``).
        val_split_ratio: fraction of sequences held out for validation. It is ignored when
            fewer than 5 sequences are available; all data is then used for training.
        hidden_units: LSTM width passed to ``TotalTimeLSTM``.
        dropout_rate: LSTM dropout passed to ``TotalTimeLSTM``.
        learning_rate: Adam learning rate.
        patience: EarlyStopping patience in epochs.
        random_state: seed for the train/validation split.

    Returns:
        ``(lstm_model, history, data_for_lstm)``, where ``data_for_lstm`` is the dict from
        ``process_input_data_for_lstm`` plus the fitted ``'target_scaler'``.
    """
    print("Processing data for LSTM training...")
    data_for_lstm = process_input_data_for_lstm(transformer_predictions_file)

    print(f"Number of features for LSTM input: {data_for_lstm['num_features']}")
    print(f"Max sequence length for LSTM input: {data_for_lstm['max_len']}")

    lstm_model = TotalTimeLSTM(hidden_units=hidden_units, dropout_rate=dropout_rate)

    y_targets_all = data_for_lstm['y_lstm_target_total_times']
    X_inputs_all = data_for_lstm['X_lstm_input']

    # --- Explicit model call to build layers before summary and compile ---
    if len(X_inputs_all) > 0:
        sample_batch_for_build = tf.convert_to_tensor(X_inputs_all[:1], dtype=tf.float32)
        _ = lstm_model(sample_batch_for_build)  # This call builds the layers
        print("\nSimplified LSTM Model Summary (after sample call):")
        lstm_model.summary()
    else:
        # Fallback if X_inputs_all is empty, though process_input_data_for_lstm should raise error earlier
        input_shape_for_build = (None, data_for_lstm['max_len'], data_for_lstm['num_features'])
        lstm_model.build(input_shape=input_shape_for_build)
        print("\nSimplified LSTM Model Summary (after .build()):")
        lstm_model.summary()

    lstm_model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='mse'
    )

    if np.any(np.isnan(X_inputs_all)) or np.any(np.isinf(X_inputs_all)):
        print("CRITICAL WARNING: NaN or Inf found in X_inputs_all.")
    if np.any(np.isnan(y_targets_all)) or np.any(np.isinf(y_targets_all)):
        print("CRITICAL WARNING: NaN or Inf found in y_targets_all.")
    if len(y_targets_all) > 0 and np.all(np.abs(y_targets_all) <= 1e-6):
        print("CRITICAL WARNING: All target total times are near zero. Model training will be ineffective.")

    # --- Target Scaling ---
    target_scaler = StandardScaler()
    # StandardScaler expects a 2-D (n_samples, n_features) array.
    y_targets_all_reshaped = y_targets_all.reshape(-1, 1)

    if len(X_inputs_all) < 5:
        print("Warning: Very few samples (<5), using all for training and no validation during fit.")
        X_train, y_train_scaled = X_inputs_all, target_scaler.fit_transform(y_targets_all_reshaped)
        validation_data_for_fit = None
    else:
        X_train, X_val, y_train_orig, y_val_orig = train_test_split(
            X_inputs_all, y_targets_all_reshaped, test_size=val_split_ratio,
            random_state=random_state, shuffle=True
        )
        # Fit scaler ONLY on training data, then transform both train and val (no leakage).
        y_train_scaled = target_scaler.fit_transform(y_train_orig)
        y_val_scaled = target_scaler.transform(y_val_orig)
        validation_data_for_fit = (X_val, y_val_scaled)

        print(f"\nManually split data: {len(X_train)} train, {len(X_val)} validation samples.")
        # Print stats for original scale before scaling for clarity
        print("Training target statistics (y_train - original scale):")
        print(f"  Mean: {np.mean(y_train_orig.flatten()):.4f}, Std: {np.std(y_train_orig.flatten()):.4f}")
        print("Validation target statistics (y_val - original scale):")
        print(f"  Mean: {np.mean(y_val_orig.flatten()):.4f}, Std: {np.std(y_val_orig.flatten()):.4f}\n")
        if np.all(np.abs(y_val_orig) <= 1e-6):
            print("CRITICAL WARNING: All original validation targets (y_val_orig) are effectively zero.")

    # Without a validation split there is no 'val_loss'. Monitoring it would make EarlyStopping
    # warn every epoch and never stop or restore the best weights, so fall back to training loss.
    if validation_data_for_fit is None:
        monitor_metric = 'loss'
        print("Note: no validation data; EarlyStopping will monitor training loss.")
    else:
        monitor_metric = 'val_loss'

    early_stopping = EarlyStopping(
        monitor=monitor_metric, patience=patience, restore_best_weights=True, verbose=1)

    print("Starting LSTM model training (with scaled targets)...")
    history = lstm_model.fit(
        X_train, y_train_scaled,  # Train on scaled targets
        epochs=epochs, batch_size=batch_size,
        validation_data=validation_data_for_fit,
        callbacks=[early_stopping], verbose=1
    )

    print("LSTM training finished.")
    # Store the scaler in data_for_lstm to use for inverse transform during prediction
    data_for_lstm['target_scaler'] = target_scaler
    return lstm_model, history, data_for_lstm

# %%
def _improvement_pct(baseline_error, candidate_error, eps=1e-6):
    """
    Element-wise percentage by which ``candidate_error`` improves on ``baseline_error``.

    Returns ``(baseline - candidate) / baseline * 100``, and 0 wherever the baseline error
    is <= ``eps``. The division only runs where it is valid. This avoids the divide-by-zero
    RuntimeWarnings that ``np.where(cond, a / b, 0)`` emits, because ``np.where`` evaluates
    both branches everywhere.
    """
    baseline = np.asarray(baseline_error, dtype=float)
    candidate = np.asarray(candidate_error, dtype=float)
    ratio = np.zeros_like(baseline)
    np.divide(baseline - candidate, baseline, out=ratio, where=baseline > eps)
    return ratio * 100


def _report_mean_abs_error(results_df, transformer_col, lstm_col, section_title, label):
    """Print the mean Transformer vs LSTM absolute errors (and % improvement) for one error type."""
    if transformer_col not in results_df.columns or lstm_col not in results_df.columns:
        return
    transformer_errors = results_df[transformer_col].dropna()
    lstm_errors = results_df[lstm_col].dropna()
    if transformer_errors.empty or lstm_errors.empty:
        return
    avg_transformer = transformer_errors.mean()
    avg_lstm = lstm_errors.mean()
    print(f"\n--- Mean Absolute Error for {section_title} ---")
    print(f"Transformer MAE ({label}): {avg_transformer:.4f}")
    print(f"LSTM-Refined MAE ({label}): {avg_lstm:.4f}")
    if avg_transformer > 1e-6:
        print(f"Improvement ({label}): {(avg_transformer - avg_lstm) / avg_transformer * 100:.2f}%")


def generate_refined_predictions_with_lstm(lstm_model, processed_data,
                                           output_filename='predictions_lstm_refined_total_time_v6.csv'):
    """
    Generate refined time predictions using the LSTM's predicted total time per sequence.

    The LSTM predicts standardised total times. These are inverse-transformed with
    ``processed_data['target_scaler']`` and clipped at 0. They are then distributed over
    the steps using the transformer's proportions. Per-row absolute errors and %
    improvements vs the transformer are added when the needed columns are present. The
    combined frame is written to ``output_filename`` (CSV) and summary MAEs are printed.

    Args:
        lstm_model: trained model whose ``predict`` returns scaled total times.
        processed_data: dict from ``train_total_time_lstm`` (must include ``target_scaler``).
        output_filename: CSV path for the combined results.

    Returns:
        The combined results DataFrame (empty if no sequence produced rows).
    """
    print("Generating refined predictions using LSTM's total time...")

    lstm_predicted_scaled_total_times = lstm_model.predict(processed_data['X_lstm_input'])

    # Inverse transform to the original time scale. reshape(-1) instead of np.squeeze:
    # squeeze turns a single-sequence (1, 1) prediction into a 0-d array, which breaks
    # the len() and [i] indexing below.
    target_scaler = processed_data['target_scaler']
    lstm_predicted_total_times_original_scale = target_scaler.inverse_transform(
        lstm_predicted_scaled_total_times).reshape(-1)
    lstm_predicted_total_times_original_scale = np.maximum(0, lstm_predicted_total_times_original_scale)

    _, lstm_refined_increments, lstm_refined_cumulative = calculate_times_from_proportions(
        processed_data['transformer_proportions_padded'],
        lstm_predicted_total_times_original_scale,
        processed_data['masks_for_calc']
    )
    lstm_refined_increments_np = lstm_refined_increments.numpy()
    lstm_refined_cumulative_np = lstm_refined_cumulative.numpy()

    original_dfs_from_processing = processed_data['original_dfs']
    # Only rows present in every per-sequence structure can be paired up safely.
    num_aligned = min(len(processed_data['sequences_ids']),
                      len(original_dfs_from_processing),
                      len(lstm_predicted_total_times_original_scale))

    # (ground truth, transformer prediction, LSTM prediction, output column prefix)
    comparisons = (
        ('GroundTruth_Increment', 'Predicted_Increment', 'LSTM_Predicted_Increment', 'Increment'),
        ('GroundTruth_Cumulative', 'Predicted_Cumulative', 'LSTM_Predicted_Cumulative', 'Cumulative'),
    )

    results_list_df = []
    for i in range(num_aligned):
        current_seq_df_base = original_dfs_from_processing[i].copy()
        seq_len = len(current_seq_df_base)
        if seq_len == 0:
            # Empty frames add no rows and only trigger pandas' empty-concat FutureWarning.
            continue

        current_seq_df_base['LSTM_Predicted_TotalTime'] = lstm_predicted_total_times_original_scale[i]
        current_seq_df_base['LSTM_Predicted_Increment'] = lstm_refined_increments_np[i, :seq_len]
        current_seq_df_base['LSTM_Predicted_Cumulative'] = lstm_refined_cumulative_np[i, :seq_len]

        for gt_col, transformer_col, lstm_col, prefix in comparisons:
            if gt_col not in current_seq_df_base.columns or transformer_col not in current_seq_df_base.columns:
                continue
            gt_values = current_seq_df_base[gt_col].fillna(0)
            diff_transformer = np.abs(gt_values - current_seq_df_base[transformer_col].fillna(0))
            diff_lstm = np.abs(gt_values - current_seq_df_base[lstm_col].fillna(0))
            current_seq_df_base[f'{prefix}_MAE_Transformer'] = diff_transformer
            current_seq_df_base[f'{prefix}_MAE_LSTM'] = diff_lstm
            current_seq_df_base[f'{prefix}_Improvement_Pct'] = _improvement_pct(diff_transformer, diff_lstm)

        results_list_df.append(current_seq_df_base)

    if not results_list_df:
        return pd.DataFrame()
    final_results_df = pd.concat(results_list_df, ignore_index=True)

    final_results_df.to_csv(output_filename, index=False)
    print(f"Combined and refined predictions saved to {output_filename}")

    _report_mean_abs_error(final_results_df, 'Increment_MAE_Transformer', 'Increment_MAE_LSTM',
                           'Increments', 'Increments')
    _report_mean_abs_error(final_results_df, 'Cumulative_MAE_Transformer', 'Cumulative_MAE_LSTM',
                           'Cumulative Times', 'Cumulative')

    gt_total_times_for_lstm_training = processed_data['y_lstm_target_total_times']
    if len(lstm_predicted_total_times_original_scale) == len(gt_total_times_for_lstm_training):
        mae_total_time_lstm = np.mean(np.abs(gt_total_times_for_lstm_training - lstm_predicted_total_times_original_scale))
        print("\n--- LSTM Total Time Prediction Performance (vs Max GT Cumulative) ---")
        print(f"MAE for LSTM Predicted Total Time: {mae_total_time_lstm:.4f}")
    else:
        print("\nWarning: Mismatch in lengths for GT total times and LSTM predicted total times for MAE calculation.")
    return final_results_df

# %%
def _plot_lstm_training_loss(training_history, output_path='lstm_total_time_training_loss_v6.png'):
    """Save the training loss curve (plus validation loss when recorded); warn if history is unusable."""
    if not (training_history and hasattr(training_history, 'history')):
        print("Warning: Training history not available or malformed.")
        return
    loss_history = training_history.history
    if 'loss' not in loss_history:
        return
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(loss_history['loss'], label='Training Loss')
    # 'val_loss' is absent when training ran without a validation split; still plot training loss.
    if 'val_loss' in loss_history:
        ax.plot(loss_history['val_loss'], label='Validation Loss')
    ax.set_title('LSTM Model Loss (Predicting Scaled Total Time)')  # Loss is on scaled values
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Mean Squared Error (Scaled Loss)')
    ax.legend()
    fig.tight_layout()
    fig.savefig(output_path)
    print("Saved LSTM training loss plot.")
    plt.close(fig)


def _plot_sequence_panels(results_df, sequence_ids, series_specs, title_prefix, y_label,
                          output_path, saved_message):
    """Save one stacked panel per sequence, plotting each ``(column, fmt, label)`` present."""
    num_plots = len(sequence_ids)
    fig = plt.figure(figsize=(15, max(8, 3 * num_plots)))
    for i, seq_id in enumerate(sequence_ids):
        seq_data_plot = results_df[results_df['Sequence'] == seq_id].sort_values('Step')
        if seq_data_plot.empty:
            continue
        ax = fig.add_subplot(num_plots, 1, i + 1)
        for column, fmt, label in series_specs:
            if column in seq_data_plot.columns:
                ax.plot(seq_data_plot['Step'], seq_data_plot[column], fmt, label=label, ms=4)
        ax.set_title(f'{title_prefix}: Seq {seq_id}')
        ax.set_xlabel('Step')
        ax.set_ylabel(y_label)
        ax.legend()
    fig.tight_layout()
    fig.savefig(output_path)
    print(saved_message)
    plt.close(fig)


def _plot_total_time_analysis(gt_total_times, predicted_total_times,
                              output_path='lstm_total_time_prediction_analysis_v6.png'):
    """Histogram GT vs predicted total times and, when lengths match, the prediction errors."""
    if gt_total_times.size == 0 or predicted_total_times.size == 0:
        print("Warning: Not enough data for total time distribution plots.")
        return
    fig, (dist_ax, err_ax) = plt.subplots(1, 2, figsize=(12, 6))
    dist_ax.hist(gt_total_times, bins=30, alpha=0.7, label='GT Total Times (Max Cumul.)')
    dist_ax.hist(predicted_total_times, bins=30, alpha=0.7, label='LSTM Pred Total Times (Original Scale)')
    dist_ax.set_xlabel('Total Time')
    dist_ax.set_ylabel('Frequency')
    dist_ax.set_title('Distribution of Total Times')
    dist_ax.legend()

    if len(gt_total_times) == len(predicted_total_times):
        errors_total_time = gt_total_times - predicted_total_times
        err_ax.hist(errors_total_time, bins=30, alpha=0.7, color='red')
        err_ax.set_xlabel('Prediction Error (GT Max Cumul. - Pred)')
        err_ax.set_ylabel('Frequency')
        err_ax.set_title('LSTM Total Time Prediction Errors')
        mean_error_val = errors_total_time.mean()
        err_ax.axvline(mean_error_val, color='k', ls='--', lw=1, label=f'Mean Error: {mean_error_val:.2f}')
        err_ax.legend()
    else:
        print("Warning: Mismatch length GT total times and predictions for error histogram.")
    fig.tight_layout()
    fig.savefig(output_path)
    print("Saved total time prediction analysis plot.")
    plt.close(fig)


def visualize_lstm_results(results_df, processed_data, lstm_model, training_history):
    """
    Save diagnostic plots for the LSTM total-time refinement.

    Produces PNG files in the working directory:
    - the training/validation loss curve;
    - cumulative-time and increment comparisons (GT vs transformer vs LSTM) for up to
      5 sequences;
    - total-time distribution and error histograms.

    For consistency with ``generate_refined_predictions_with_lstm``, predicted total times
    are inverse-scaled and clipped at 0 before plotting.

    Side effect: sets the global matplotlib style to 'ggplot'.
    """
    if results_df.empty:
        print("Results DataFrame is empty, skipping visualizations.")
        return
    print("Generating visualizations for LSTM results...")
    plt.style.use('ggplot')

    _plot_lstm_training_loss(training_history)

    sample_sequence_ids = results_df['Sequence'].unique()
    if len(sample_sequence_ids) == 0:
        print("No sequences in results_df for plotting.")
        return
    sample_sequence_ids = sample_sequence_ids[:5]

    _plot_sequence_panels(
        results_df, sample_sequence_ids,
        [('GroundTruth_Cumulative', 'o-', 'GT Cumul.'),
         ('Predicted_Cumulative', 's--', 'Transformer Cumul.'),
         ('LSTM_Predicted_Cumulative', '^-.', 'LSTM-Refined Cumul.')],
        'Cumulative Times', 'Cumulative Time',
        'lstm_refined_cumulative_time_comparison_v6.png', "Saved cumulative time comparison plot.")
    _plot_sequence_panels(
        results_df, sample_sequence_ids,
        [('GroundTruth_Increment', 'o-', 'GT Incr.'),
         ('Predicted_Increment', 's--', 'Transformer Incr.'),
         ('LSTM_Predicted_Increment', '^-.', 'LSTM-Refined Incr.')],
        'Time Increments', 'Time Increment',
        'lstm_refined_increment_comparison_v6.png', "Saved increment comparison plot.")

    gt_total_times_for_lstm_training = np.asarray(
        processed_data.get('y_lstm_target_total_times', np.array([]))).reshape(-1)  # Original scale
    if lstm_model is not None and 'X_lstm_input' in processed_data and 'target_scaler' in processed_data:
        X_input_tensor = tf.convert_to_tensor(processed_data['X_lstm_input'], dtype=tf.float32)
        lstm_pred_scaled_total_t = lstm_model.predict(X_input_tensor)
        # reshape(-1) keeps a 1-D array even for a single sequence (squeeze would give 0-d).
        lstm_pred_original_scale_total_t = processed_data['target_scaler'].inverse_transform(
            lstm_pred_scaled_total_t).reshape(-1)
        # Same non-negativity clip as the refined predictions written to the results CSV.
        lstm_pred_original_scale_total_t = np.maximum(0, lstm_pred_original_scale_total_t)
        _plot_total_time_analysis(gt_total_times_for_lstm_training, lstm_pred_original_scale_total_t)
    else:
        print("Warning: LSTM model, input data, or target_scaler missing for total time prediction plot.")
    print("Visualizations for LSTM completed!")

# %%
def _write_dummy_transformer_predictions(csv_path, n_sequences=250, seed=42):
    """
    Write a synthetic transformer-predictions CSV (same columns as the real file) for smoke tests.

    Uses a seeded ``np.random.Generator`` so repeated runs create identical dummy data.
    Returns the DataFrame that was written.
    """
    rng = np.random.default_rng(seed)
    dummy_rows = []
    for seq_idx in range(n_sequences):
        num_steps = int(rng.integers(8, 50))  # 8..49 steps, so every sequence is non-empty
        steps = np.arange(1, num_steps + 1)
        # Increments that lead to a good range of total times
        gt_increments = rng.lognormal(mean=1.5, sigma=0.8, size=num_steps) + 0.5
        gt_increments = np.maximum(gt_increments, 0.01)
        gt_cumulative = np.cumsum(gt_increments)

        raw_props = rng.random(num_steps) + 0.1  # Ensure non-zero proportions
        pred_proportions = raw_props / raw_props.sum()

        actual_sequence_total_time = max(gt_cumulative[-1], 1.0)
        # The transformer's implied total time deviates from the actual one (~30% noise).
        dummy_transformer_effective_total_time = max(
            actual_sequence_total_time * rng.normal(loc=1.0, scale=0.3), 0.1)

        pred_increments_from_transformer = pred_proportions * dummy_transformer_effective_total_time
        pred_cumulative_from_transformer = np.cumsum(pred_increments_from_transformer)

        for s_idx in range(num_steps):
            dummy_rows.append({
                'Sequence': seq_idx, 'Step': steps[s_idx], 'SourceID': f'MRI_DUMMY_{s_idx%5 +1}',
                'Predicted_Proportion': pred_proportions[s_idx],
                'Predicted_Increment': pred_increments_from_transformer[s_idx],
                'Predicted_Cumulative': pred_cumulative_from_transformer[s_idx],
                'GroundTruth_Increment': gt_increments[s_idx],
                'GroundTruth_Cumulative': gt_cumulative[s_idx]})
    if not dummy_rows:
        # Only reachable with n_sequences == 0; keep the CSV non-empty.
        dummy_rows.append({'Sequence': 0, 'Step': 1, 'SourceID': 'MRI_DUMMY_0', 'Predicted_Proportion': 1.0,
                           'Predicted_Increment': 10.0, 'Predicted_Cumulative': 10.0,
                           'GroundTruth_Increment': 10.0, 'GroundTruth_Cumulative': 10.0})
    dummy_df = pd.DataFrame(dummy_rows)
    dummy_df.to_csv(csv_path, index=False)
    return dummy_df


def main_lstm_total_time_flow(transformer_predictions_file="predictions_transformer_182625.csv",
                              epochs=150, batch_size=32, seed=42):
    """
    Run the full pipeline: ensure the predictions CSV exists, train the LSTM, refine, plot.

    If ``transformer_predictions_file`` is missing, a seeded dummy CSV is created at that
    path so the flow can still be exercised. Any exception is reported with a traceback
    instead of being raised (best-effort notebook entry point).

    Args:
        transformer_predictions_file: path of the transformer predictions CSV.
        epochs: maximum training epochs.
        batch_size: training batch size.
        seed: seed for the dummy-data generator.
    """
    try:
        if not os.path.exists(transformer_predictions_file):
            print(f"Error: Transformer predictions file not found: {transformer_predictions_file}")
            print("Creating DUMMY CSV for testing flow.")
            dummy_df = _write_dummy_transformer_predictions(transformer_predictions_file, seed=seed)
            print(f"Dummy '{transformer_predictions_file}' created with {len(dummy_df)} rows and {dummy_df['Sequence'].nunique()} sequences.")

        lstm_model, lstm_history, processed_lstm_data = train_total_time_lstm(
            transformer_predictions_file, epochs=epochs, batch_size=batch_size)

        if lstm_model is None or processed_lstm_data is None:
            print("LSTM training failed or returned None. Exiting.")
            return

        refined_results_df = generate_refined_predictions_with_lstm(lstm_model, processed_lstm_data)
        if refined_results_df.empty:
            print("No refined predictions were generated by the LSTM flow.")
            return

        print("\nSample of Refined Predictions (LSTM Total Time Approach - v6):")
        display_cols = ['Sequence', 'Step', 'SourceID',
                        'Predicted_Increment', 'LSTM_Predicted_Increment', 'GroundTruth_Increment',
                        'Predicted_Cumulative', 'LSTM_Predicted_Cumulative', 'GroundTruth_Cumulative',
                        'LSTM_Predicted_TotalTime', 'Increment_Improvement_Pct']
        actual_display_cols = [col for col in display_cols if col in refined_results_df.columns]
        print(refined_results_df[actual_display_cols].head(10))
        print("\nGenerating visualizations for LSTM (total time approach - v6)...")
        visualize_lstm_results(refined_results_df, processed_lstm_data, lstm_model, lstm_history)
    except Exception as e:
        import traceback
        print(f"Error in LSTM (total time) main function: {e}")
        traceback.print_exc()

In [6]:
# Entry point: run the whole flow (load or create the CSV -> train the LSTM -> refine predictions -> plots).
# Inside a Jupyter kernel __name__ is "__main__", so executing this cell runs the pipeline.
if __name__ == "__main__":
    main_lstm_total_time_flow()

Processing data for LSTM training...
Processing data from: predictions_transformer_182625.csv

Statistics for TARGET y_total_times_list (max GT_Cumulative per seq, 186 sequences):
  Mean: 374.8065, Std Dev: 348.4868
  Min: 0.0000, Max: 2900.0000
  Number of zeros (<=1e-6): 5
  Number non-positive (<=0): 5

Number of features for LSTM input: 2
Max sequence length for LSTM input: 42

Simplified LSTM Model Summary (after sample call):



Manually split data: 148 train, 38 validation samples.
Training target statistics (y_train - original scale):
  Mean: 390.0405, Std: 362.5637
Validation target statistics (y_val - original scale):
  Mean: 315.4737, Std: 279.3477

Starting LSTM model training (with scaled targets)...
Epoch 1/150
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 76ms/step - loss: 1.2112 - val_loss: 0.6538
Epoch 2/150
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - loss: 1.3508 - val_loss: 0.6462
Epoch 3/150
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - loss: 1.2826 - val_loss: 0.6417
Epoch 4/150
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - loss: 1.1398 - val_loss: 0.6318
Epoch 5/150
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - loss: 0.8268 - val_loss: 0.6302
Epoch 6/150
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - loss: 0.8439 - val_loss: 0.6316
Epoch 7