In [8]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau # Added ReduceLROnPlateau
import matplotlib.pyplot as plt
import os # For checking file existence
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler # For target scaling

In [9]:
# Helper function to compute increments and cumulative times from proportions and total time
def calculate_times_from_proportions(proportions_per_step, total_time_for_sequence, mask_per_step):
    """
    Distribute a per-sequence total time over its steps according to step-wise proportions.

    The proportions are multiplied by the mask, re-normalised so that each row
    sums to 1 (rows whose masked sum is exactly 0 stay all-zero instead of
    dividing by zero), and then scaled by that row's total time.

    Args:
        proportions_per_step: array-like, cast to float32. Reduced and
            cumulated along axis 1, i.e. treated as (rows, steps).
        total_time_for_sequence: array-like, cast to float32. A scalar or 1-D
            value is reshaped to a column so it broadcasts one total per row;
            values of rank >= 2 are used as given.
        mask_per_step: array-like, cast to float32 and multiplied element-wise
            with the proportions and with every output (the callers in this
            file pass 1.0 for real steps and 0.0 for padding).

    Returns:
        Tuple ``(normalized_proportions, increments, cumulative_times)`` of
        float32 tensors, each multiplied by the mask.
    """
    proportions_tf = tf.cast(proportions_per_step, tf.float32)
    total_time_tf = tf.cast(total_time_for_sequence, tf.float32)
    mask_tf = tf.cast(mask_per_step, tf.float32)

    # Use the static rank instead of len(tf.shape(...)): len() is not defined
    # for symbolic tensors, so the old check failed inside tf.function graphs.
    static_rank = total_time_tf.shape.rank
    if static_rank is not None and static_rank < 2:
        total_time_tf = tf.reshape(total_time_tf, [-1, 1])

    masked_proportions = proportions_tf * mask_tf
    row_sums = tf.reduce_sum(masked_proportions, axis=1, keepdims=True)
    # Replace zero sums by 1 so fully-masked / all-zero rows yield zeros, not NaN.
    safe_row_sums = tf.where(tf.equal(row_sums, 0), tf.ones_like(row_sums), row_sums)
    normalized_proportions = masked_proportions / safe_row_sums

    increments = normalized_proportions * total_time_tf
    # The running sum keeps its last value across trailing masked steps, so the
    # mask is re-applied below to zero those positions.
    cumulative_times = tf.cumsum(increments, axis=1)

    increments *= mask_tf
    cumulative_times *= mask_tf
    normalized_proportions *= mask_tf

    return normalized_proportions, increments, cumulative_times

# %%
class TotalTimeLSTM(tf.keras.Model):
    """
    Two-input Keras model that regresses one value per sequence (used in this
    file to predict a sequence's total time).

    Architecture:
        Input 1 (sequential): Bidirectional LSTM -> masked GlobalAveragePooling1D
        Input 2 (global):     per-sequence feature vector, passed through unchanged
        Concatenate -> Dense(relu) -> Dropout -> Dense(relu) -> Dropout -> Dense(1, linear)

    Padding is inferred inside `call`: a timestep whose features are all exactly
    0.0 is masked out of both the LSTM and the pooling.
    """
    def __init__(self, hidden_units=192, # Units per LSTM direction
                 dense_units_1=128, # First dense layer after concat
                 dense_units_2=64,  # Second intermediate dense layer
                 dropout_rate=0.3):
        """
        Create the layers; weights are built lazily on the first call.

        Args:
            hidden_units: units of the LSTM wrapped by the Bidirectional layer.
            dense_units_1: width of the first dense layer after concatenation.
            dense_units_2: width of the second dense layer.
            dropout_rate: rate used for the LSTM input dropout and for both
                Dropout layers (the LSTM recurrent dropout is fixed at 0.2).
        """
        super(TotalTimeLSTM, self).__init__()
        
        self.hidden_units = hidden_units
        self.dense_units_1 = dense_units_1
        self.dense_units_2 = dense_units_2
        self.dropout_rate = dropout_rate

        # --- Layers for sequential input ---
        # return_sequences=True keeps one output per timestep so the pooling
        # layer below can average over the unmasked steps.
        self.bi_lstm_layer = layers.Bidirectional(
            layers.LSTM(self.hidden_units, 
                        return_sequences=True, 
                        dropout=self.dropout_rate, # Dropout on inputs to LSTM
                        recurrent_dropout=0.2), # Dropout on the recurrent state
            name="bidirectional_lstm_v9"
        )
        self.global_avg_pool = layers.GlobalAveragePooling1D(name="global_avg_pooling_v9")
        
        # --- Layers for combined features ---
        self.concat_layer = layers.Concatenate(name="concatenate_features_v9")
        
        self.dense_1 = layers.Dense(
            self.dense_units_1, 
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(0.001), # L2 penalty on the kernel weights
            name="dense_1_v9"
        )
        self.dropout_1 = layers.Dropout(self.dropout_rate, name="dropout_1_v9")
        
        self.dense_2 = layers.Dense(
            self.dense_units_2,
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(0.001), # L2 penalty on the kernel weights
            name="dense_2_v9"
        )
        self.dropout_2 = layers.Dropout(self.dropout_rate, name="dropout_2_v9")

        # Single linear unit: unbounded regression output.
        self.total_time_head = layers.Dense(1, activation='linear', name="total_time_dense_v9") 
        
    def call(self, inputs, training=False): 
        """
        Run the forward pass.

        Args:
            inputs: a pair ``(sequence_input, global_features_input)``.
                sequence_input is presumably (batch, steps, features) and
                global_features_input (batch, n_global) -- TODO confirm against callers.
            training: forwarded to the LSTM and Dropout layers to switch
                dropout on or off.

        Returns:
            The output of the final Dense(1) layer.
        """
        sequence_input, global_features_input = inputs 

        # A step is real if any of its features is non-zero; all-zero steps are
        # treated as padding. NOTE(review): a genuine step whose features are
        # all exactly 0.0 would be masked as well.
        mask_bool_seq = tf.reduce_any(tf.not_equal(sequence_input, 0.0), axis=-1)
        x_seq = self.bi_lstm_layer(sequence_input, mask=mask_bool_seq, training=training)
        # Average the per-step LSTM outputs with the same mask applied.
        x_seq_pooled = self.global_avg_pool(x_seq, mask=mask_bool_seq)
        
        combined_features = self.concat_layer([x_seq_pooled, global_features_input])
        
        x = self.dense_1(combined_features)
        x = self.dropout_1(x, training=training)
        x = self.dense_2(x)
        x = self.dropout_2(x, training=training)
        total_time_pred = self.total_time_head(x)
        
        return total_time_pred

# %%
def process_input_data_for_lstm(transformer_predictions_file):
    """
    Load the transformer predictions CSV and turn it into padded arrays for LSTM training.

    Rows are grouped by 'Sequence' (in order of first appearance) and sorted by
    'Step' within each group. Rows whose 'Sequence' value is missing (NaN) are
    ignored, so every returned per-sequence container has exactly one entry per
    sequence and index ``i`` refers to the same sequence everywhere.

    Per sequence the function builds:
      * sequential features per step: ``[Predicted_Proportion, Step / max(Step)]``
      * global features: ``[length, sum, mean, std]`` of the predicted proportions
      * target: the maximum of 'GroundTruth_Cumulative'

    Args:
        transformer_predictions_file: path to a CSV containing at least the
            columns 'Predicted_Proportion', 'GroundTruth_Cumulative',
            'GroundTruth_Increment', 'Sequence' and 'Step'.

    Returns:
        dict with keys
            'X_sequential_input'              float32 (n_seq, max_len, 2), zero padded
            'X_global_features_input'         float32 (n_seq, 4)
            'y_lstm_target_total_times'       float32 (n_seq,)
            'masks_for_calc'                  float32 (n_seq, max_len), 1.0 on real steps
            'sequences_ids'                   array of sequence ids
            'original_dfs'                    list of per-sequence DataFrames sorted by Step
            'transformer_proportions_padded'  float32 (n_seq, max_len)
            'gt_increments_padded_original'   float32 (n_seq, max_len)
            'gt_cumulative_padded_original'   float32 (n_seq, max_len)
            'max_len_sequential', 'num_sequential_features', 'num_global_features'

    Raises:
        FileNotFoundError: if the CSV does not exist.
        ValueError: if a required column is missing or no sequence could be processed.
    """
    print(f"Processing data from: {transformer_predictions_file}")
    if not os.path.exists(transformer_predictions_file):
        raise FileNotFoundError(f"Transformer predictions file not found: {transformer_predictions_file}")
    df = pd.read_csv(transformer_predictions_file)

    required_cols = ['Predicted_Proportion', 'GroundTruth_Cumulative', 'GroundTruth_Increment', 'Sequence', 'Step']
    for col in required_cols:
        if col not in df.columns:
            raise ValueError(f"CSV must contain '{col}' column.")

    # NaN ids can never match any row, so they are dropped up front; keeping
    # them would add entries to 'sequences_ids'/'original_dfs' that have no
    # counterpart in the padded arrays and shift every later index.
    sequences = df['Sequence'].dropna().unique()
    # Group once instead of re-filtering the whole frame for every sequence.
    rows_by_sequence = df.groupby('Sequence', sort=False)

    X_sequential_data_list = []
    X_global_features_list = []
    y_total_times_list = []
    original_dfs_list = []
    transformer_proportions_list = []
    ground_truth_increments_list = []
    ground_truth_cumulative_list = []

    for seq_id in sequences:
        seq_df = rows_by_sequence.get_group(seq_id).sort_values('Step').copy()
        original_dfs_list.append(seq_df)

        props_for_seq = seq_df['Predicted_Proportion'].values
        actual_sequence_length = float(len(seq_df))  # float so it can be scaled with the other globals

        # Normalise the step index to (0, 1]; guard against a sequence whose only step is 0.
        current_max_steps = seq_df['Step'].max()
        if current_max_steps == 0:
            current_max_steps = 1

        X_sequential_data_list.append(np.column_stack([
            props_for_seq,
            seq_df['Step'].values / current_max_steps
        ]))

        X_global_features_list.append(np.array([
            actual_sequence_length,
            np.sum(props_for_seq),
            np.mean(props_for_seq),
            np.std(props_for_seq) if actual_sequence_length > 1 else 0.0
        ], dtype=np.float32))

        gt_cumulative_for_seq = seq_df['GroundTruth_Cumulative'].values
        y_total_times_list.append(np.max(gt_cumulative_for_seq))

        transformer_proportions_list.append(props_for_seq)
        ground_truth_increments_list.append(seq_df['GroundTruth_Increment'].values)
        ground_truth_cumulative_list.append(gt_cumulative_for_seq)

    if not X_sequential_data_list:
        raise ValueError("No valid sequences processed.")

    y_total_times_array_unpadded = np.array(y_total_times_list)
    print(f"\nStatistics for TARGET y_total_times_list (max GT_Cumulative per seq, {len(y_total_times_array_unpadded)} sequences):")
    print(f"  Mean: {np.mean(y_total_times_array_unpadded):.4f}, Std Dev: {np.std(y_total_times_array_unpadded):.4f}")
    print(f"  Min: {np.min(y_total_times_array_unpadded):.4f}, Max: {np.max(y_total_times_array_unpadded):.4f}")
    print(f"  Number of zeros (<=1e-6): {np.sum(y_total_times_array_unpadded <= 1e-6)}")
    print(f"  Number non-positive (<=0): {np.sum(y_total_times_array_unpadded <= 0)}\n")

    # Every group has at least one row, so these are all >= 1.
    num_sequences = len(X_sequential_data_list)
    max_length_sequential = max(len(features) for features in X_sequential_data_list)
    num_sequential_features = X_sequential_data_list[0].shape[1]
    num_global_features = X_global_features_list[0].shape[0]

    padded_2d_shape = (num_sequences, max_length_sequential)
    X_sequential_padded = np.zeros(padded_2d_shape + (num_sequential_features,), dtype=np.float32)
    masks_padded_float = np.zeros(padded_2d_shape, dtype=np.float32)
    transformer_proportions_padded = np.zeros(padded_2d_shape, dtype=np.float32)
    gt_increments_padded_original = np.zeros(padded_2d_shape, dtype=np.float32)
    gt_cumulative_padded_original = np.zeros(padded_2d_shape, dtype=np.float32)

    for i, sequential_features in enumerate(X_sequential_data_list):
        seq_len = len(sequential_features)
        X_sequential_padded[i, :seq_len, :] = sequential_features
        masks_padded_float[i, :seq_len] = 1.0
        transformer_proportions_padded[i, :seq_len] = transformer_proportions_list[i]
        gt_increments_padded_original[i, :seq_len] = ground_truth_increments_list[i]
        gt_cumulative_padded_original[i, :seq_len] = ground_truth_cumulative_list[i]

    return {
        'X_sequential_input': X_sequential_padded,
        'X_global_features_input': np.array(X_global_features_list, dtype=np.float32),
        'y_lstm_target_total_times': np.array(y_total_times_list, dtype=np.float32),
        'masks_for_calc': masks_padded_float,
        'sequences_ids': sequences,
        'original_dfs': original_dfs_list,
        'transformer_proportions_padded': transformer_proportions_padded,
        'gt_increments_padded_original': gt_increments_padded_original,
        'gt_cumulative_padded_original': gt_cumulative_padded_original,
        'max_len_sequential': max_length_sequential,
        'num_sequential_features': num_sequential_features,
        'num_global_features': num_global_features
    }

# %%
def train_total_time_lstm(transformer_predictions_file, epochs=50, batch_size=32, val_split_ratio=0.2):
    """
    Train a TotalTimeLSTM to predict each sequence's total time.

    The data is loaded with `process_input_data_for_lstm`. Global features and
    targets are standardised with scalers fitted on the training split only.
    With fewer than 10 sequences no validation split is made and all samples
    are used for training.

    Args:
        transformer_predictions_file: path to the transformer predictions CSV.
        epochs: maximum number of training epochs.
        batch_size: mini-batch size passed to `fit`.
        val_split_ratio: fraction of sequences held out for validation.

    Returns:
        Tuple ``(lstm_model, history, data_for_lstm)``. ``data_for_lstm`` is the
        processed-data dict extended with the fitted 'target_scaler' and
        'global_feature_scaler', which are needed to run predictions later.
    """
    print("Processing data for LSTM training...")
    data_for_lstm = process_input_data_for_lstm(transformer_predictions_file)

    print(f"Num sequential features: {data_for_lstm['num_sequential_features']}, Max seq length: {data_for_lstm['max_len_sequential']}")
    print(f"Num global features: {data_for_lstm['num_global_features']}")

    lstm_model = TotalTimeLSTM(hidden_units=192, dense_units_1=128, dense_units_2=64, dropout_rate=0.3)

    X_sequential_all = data_for_lstm['X_sequential_input']
    X_global_all = data_for_lstm['X_global_features_input']
    y_targets_all = data_for_lstm['y_lstm_target_total_times']

    if len(X_sequential_all) > 0:
        # A subclassed model has no weights until it is called once; run a
        # single sample through it so summary() can report the layers.
        sample_seq_input_for_build = tf.convert_to_tensor(X_sequential_all[:1], dtype=tf.float32)
        sample_glob_input_for_build = tf.convert_to_tensor(X_global_all[:1], dtype=tf.float32)
        _ = lstm_model((sample_seq_input_for_build, sample_glob_input_for_build))
        print("\nEnhanced LSTM Model Summary (v9 - after sample call):")
        lstm_model.summary(expand_nested=True)
    else:
        print("Warning: No data to build model with sample call.")

    lstm_model.compile(optimizer=Adam(learning_rate=0.0005), loss='mse')

    # Sanity warnings only: training still proceeds so the problem shows up in the logs.
    if np.any(np.isnan(X_sequential_all)) or np.any(np.isinf(X_sequential_all)): print("CRITICAL WARNING: NaN/Inf in X_sequential_all.")
    if np.any(np.isnan(X_global_all)) or np.any(np.isinf(X_global_all)): print("CRITICAL WARNING: NaN/Inf in X_global_all.")
    if np.any(np.isnan(y_targets_all)) or np.any(np.isinf(y_targets_all)): print("CRITICAL WARNING: NaN/Inf in y_targets_all.")
    if len(y_targets_all) > 0 and np.all(np.abs(y_targets_all) <= 1e-6) : print("CRITICAL WARNING: All target total times are near zero.")

    target_scaler = StandardScaler()
    global_feature_scaler = StandardScaler()
    y_targets_all_reshaped = y_targets_all.reshape(-1, 1)  # scalers expect 2-D input

    if len(X_sequential_all) < 10:
        print("Warning: Very few samples (<10), using all for training.")
        X_train_seq = X_sequential_all
        X_train_glob_scaled = global_feature_scaler.fit_transform(X_global_all)
        y_train_scaled = target_scaler.fit_transform(y_targets_all_reshaped)
        validation_data_for_fit = None
    else:
        indices = np.arange(len(X_sequential_all))
        train_indices, val_indices = train_test_split(indices, test_size=val_split_ratio, random_state=42, shuffle=True)

        X_train_seq = X_sequential_all[train_indices]
        X_val_seq = X_sequential_all[val_indices]

        # Fit the scalers on the training split only to avoid leaking
        # validation statistics into training.
        X_train_glob_scaled = global_feature_scaler.fit_transform(X_global_all[train_indices])
        X_val_glob_scaled = global_feature_scaler.transform(X_global_all[val_indices])

        y_train_orig_reshaped = y_targets_all_reshaped[train_indices]
        y_val_orig_reshaped = y_targets_all_reshaped[val_indices]
        y_train_scaled = target_scaler.fit_transform(y_train_orig_reshaped)
        y_val_scaled = target_scaler.transform(y_val_orig_reshaped)

        validation_data_for_fit = ([X_val_seq, X_val_glob_scaled], y_val_scaled)

        print(f"\nManually split data: {len(X_train_seq)} train, {len(X_val_seq)} validation samples.")
        print("Training target statistics (y_train - original scale):")
        print(f"  Mean: {np.mean(y_train_orig_reshaped.flatten()):.4f}, Std: {np.std(y_train_orig_reshaped.flatten()):.4f}")
        print("Validation target statistics (y_val - original scale):")
        print(f"  Mean: {np.mean(y_val_orig_reshaped.flatten()):.4f}, Std: {np.std(y_val_orig_reshaped.flatten()):.4f}\n")

    # 'val_loss' only exists when validation data is passed to fit(). Without
    # it the callbacks would never trigger (no early stop, no best-weight
    # restore, no LR reduction), so fall back to the training loss.
    monitored_metric = 'val_loss' if validation_data_for_fit is not None else 'loss'
    callbacks_list = [
        EarlyStopping(monitor=monitored_metric, patience=35, restore_best_weights=True, verbose=1),
        ReduceLROnPlateau(monitor=monitored_metric, factor=0.2, patience=10, min_lr=1e-6, verbose=1)
    ]

    print("Starting LSTM model training (with scaled targets and global features)...")
    history = lstm_model.fit(
        [X_train_seq, X_train_glob_scaled], y_train_scaled,
        epochs=epochs, batch_size=batch_size,
        validation_data=validation_data_for_fit,
        callbacks=callbacks_list, verbose=1
    )

    print("LSTM training finished.")
    data_for_lstm['target_scaler'] = target_scaler
    data_for_lstm['global_feature_scaler'] = global_feature_scaler
    return lstm_model, history, data_for_lstm

# %%
def _add_error_columns(seq_df, kind):
    """Add '<kind>_MAE_Transformer', '<kind>_MAE_LSTM' and '<kind>_Improvement_Pct' to seq_df in place.

    `kind` is 'Increment' or 'Cumulative'. Nothing is added when the ground
    truth or the transformer prediction column for that kind is missing.
    """
    gt_col = f'GroundTruth_{kind}'
    transformer_col = f'Predicted_{kind}'
    if gt_col not in seq_df.columns or transformer_col not in seq_df.columns:
        return
    ground_truth = seq_df[gt_col].fillna(0)
    transformer_error = np.abs(ground_truth - seq_df[transformer_col].fillna(0))
    lstm_error = np.abs(ground_truth - seq_df[f'LSTM_Predicted_{kind}'].fillna(0))
    seq_df[f'{kind}_MAE_Transformer'] = transformer_error
    seq_df[f'{kind}_MAE_LSTM'] = lstm_error
    # Relative improvement is reported as 0 where the transformer error is
    # (near) zero, because the ratio is meaningless there.
    seq_df[f'{kind}_Improvement_Pct'] = np.where(
        transformer_error > 1e-6, (transformer_error - lstm_error) / transformer_error * 100, 0)


def _print_mae_summary(results_df, kind, label, leading=""):
    """Print the mean transformer vs. LSTM-refined absolute error for `kind`, if both columns exist."""
    transformer_col = f'{kind}_MAE_Transformer'
    lstm_col = f'{kind}_MAE_LSTM'
    if transformer_col not in results_df.columns or lstm_col not in results_df.columns:
        return
    transformer_errors = results_df[transformer_col].dropna()
    lstm_errors = results_df[lstm_col].dropna()
    if transformer_errors.empty or lstm_errors.empty:
        return
    avg_transformer = transformer_errors.mean()
    avg_lstm = lstm_errors.mean()
    print(f"{leading}Transformer MAE ({label}): {avg_transformer:.4f}, LSTM-Refined MAE ({label}): {avg_lstm:.4f}")
    if avg_transformer > 1e-6:
        print(f"Improvement ({label}): {(avg_transformer - avg_lstm) / avg_transformer * 100:.2f}%")


def generate_refined_predictions_with_lstm(lstm_model, processed_data,
                                           output_filename='combined_model_results_SN_175651_175974_182625_v9.csv'):
    """
    Rescale the transformer's step proportions by the LSTM-predicted total time and save the result.

    For every sequence the LSTM predicts one total time (clipped at 0). That
    total is distributed over the steps with `calculate_times_from_proportions`
    and written next to the original rows, together with per-step error columns.
    'LSTM_Predicted_TotalTime' is filled only on the last row of each sequence
    and is NaN elsewhere.

    Args:
        lstm_model: trained model exposing ``predict([seq_input, global_input])``.
        processed_data: dict returned by `train_total_time_lstm`, i.e. the
            processed arrays plus 'target_scaler' and 'global_feature_scaler'.
        output_filename: CSV path the combined results are written to.

    Returns:
        DataFrame with all sequences concatenated, or an empty DataFrame when
        there was nothing to write.
    """
    print("Generating refined predictions using LSTM's total time...")

    X_sequential_input_all = processed_data['X_sequential_input']
    X_global_scaled = processed_data['global_feature_scaler'].transform(
        processed_data['X_global_features_input'])

    predicted_scaled = lstm_model.predict([X_sequential_input_all, X_global_scaled])

    # reshape(-1) rather than squeeze(): squeezing a single-sequence (1, 1)
    # prediction yields a 0-d array, which breaks len() and [i] indexing below.
    predicted_total_times = processed_data['target_scaler'].inverse_transform(predicted_scaled).reshape(-1)
    predicted_total_times = np.maximum(0, predicted_total_times)  # a total time cannot be negative

    _, lstm_refined_increments, lstm_refined_cumulative = calculate_times_from_proportions(
        processed_data['transformer_proportions_padded'],
        predicted_total_times,
        processed_data['masks_for_calc']
    )
    lstm_refined_increments_np = lstm_refined_increments.numpy()
    lstm_refined_cumulative_np = lstm_refined_cumulative.numpy()

    results_list_df = []
    # Rows of the padded arrays exist only for non-empty sequences, so they are
    # addressed with their own counter rather than the position in
    # 'original_dfs' (which may contain empty placeholder frames).
    padded_row = 0
    for original_seq_df in processed_data['original_dfs']:
        seq_len = len(original_seq_df)
        if seq_len == 0:
            continue
        if padded_row >= len(predicted_total_times):
            break
        seq_df = original_seq_df.copy()

        seq_df['LSTM_Predicted_TotalTime'] = np.nan
        seq_df.iloc[-1, seq_df.columns.get_loc('LSTM_Predicted_TotalTime')] = predicted_total_times[padded_row]

        seq_df['LSTM_Predicted_Increment'] = lstm_refined_increments_np[padded_row, :seq_len]
        seq_df['LSTM_Predicted_Cumulative'] = lstm_refined_cumulative_np[padded_row, :seq_len]
        padded_row += 1

        _add_error_columns(seq_df, 'Increment')
        _add_error_columns(seq_df, 'Cumulative')
        results_list_df.append(seq_df)

    if not results_list_df: return pd.DataFrame()
    final_results_df = pd.concat(results_list_df, ignore_index=True)

    final_results_df.to_csv(output_filename, index=False)
    print(f"Combined and refined predictions saved to {output_filename}")

    if not final_results_df.empty:
        _print_mae_summary(final_results_df, 'Increment', 'Increments', leading="\n")
        _print_mae_summary(final_results_df, 'Cumulative', 'Cumulative')

    gt_total_times_for_lstm_training = processed_data['y_lstm_target_total_times']
    if len(predicted_total_times) == len(gt_total_times_for_lstm_training):
        mae_total_time_lstm = np.mean(np.abs(gt_total_times_for_lstm_training - predicted_total_times))
        print(f"\nLSTM MAE for Total Time (vs Max GT Cumulative): {mae_total_time_lstm:.4f}")

    return final_results_df

# %%
def _plot_sequence_comparison(results_df, sequence_ids, kind, short_label, title_prefix,
                              y_label, output_path, show_total_time=False):
    """Save one figure with a subplot per sequence comparing ground truth, transformer and LSTM values.

    `kind` is 'Cumulative' or 'Increment' and selects the 'GroundTruth_<kind>',
    'Predicted_<kind>' and 'LSTM_Predicted_<kind>' columns. With
    `show_total_time` the sequence's single 'LSTM_Predicted_TotalTime' value is
    drawn as a horizontal line.
    """
    num_plots = len(sequence_ids)
    fig, axes = plt.subplots(num_plots, 1, figsize=(15, max(8, 3 * num_plots)), squeeze=False)
    for ax, seq_id in zip(axes[:, 0], sequence_ids):
        seq_rows = results_df[results_df['Sequence'] == seq_id].sort_values('Step')
        ax.plot(seq_rows['Step'], seq_rows[f'GroundTruth_{kind}'], 'o-', label=f'GT {short_label}', ms=4)
        if f'Predicted_{kind}' in seq_rows.columns:
            ax.plot(seq_rows['Step'], seq_rows[f'Predicted_{kind}'], 's--', label=f'Transformer {short_label}', ms=4)
        if f'LSTM_Predicted_{kind}' in seq_rows.columns:
            ax.plot(seq_rows['Step'], seq_rows[f'LSTM_Predicted_{kind}'], '^-.', label=f'LSTM-Refined {short_label}', ms=4)
        if show_total_time and 'LSTM_Predicted_TotalTime' in seq_rows.columns:
            total_times = seq_rows['LSTM_Predicted_TotalTime'].dropna().unique()
            if len(total_times) == 1:
                ax.axhline(y=total_times[0], color='purple', linestyle=':',
                           label=f'LSTM Total Pred: {total_times[0]:.2f}')
        ax.set_title(f'{title_prefix}: Seq {seq_id}')
        ax.set_xlabel('Step')
        ax.set_ylabel(y_label)
        ax.legend()
    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)


def visualize_lstm_results(results_df, processed_data, lstm_model, training_history):
    """
    Save diagnostic plots for the LSTM total-time model as PNG files in the working directory.

    Figures written (each only when its inputs are available):
      * lstm_total_time_training_loss_v9.png           - training/validation loss and learning rate
      * lstm_refined_cumulative_time_comparison_v9.png - cumulative times for up to 5 sequences
      * lstm_refined_increment_comparison_v9.png       - increments for the same sequences
      * lstm_total_time_prediction_analysis_v9.png     - total-time distribution and error histogram

    Args:
        results_df: output of `generate_refined_predictions_with_lstm`.
        processed_data: processed-data dict including the fitted scalers.
        lstm_model: trained model, or None to skip the total-time analysis plot.
        training_history: Keras History object returned by `fit`, or None.

    Returns:
        None.
    """
    if results_df.empty: print("Results DataFrame is empty, skipping visualizations."); return
    print("Generating visualizations for LSTM results...")
    plt.style.use('ggplot')

    history = getattr(training_history, 'history', None) if training_history else None
    if history and 'loss' in history:
        # Explicit axes are used throughout: after twinx() the pyplot state
        # machine targets the secondary axis, which previously put the loss
        # y-label and legend on the learning-rate axis.
        fig, loss_ax = plt.subplots(figsize=(10, 5))
        loss_ax.plot(history['loss'], label='Training Loss')
        if 'val_loss' in history:
            loss_ax.plot(history['val_loss'], label='Validation Loss')
        # The learning rate is logged as 'lr' by older Keras and 'learning_rate' by newer releases.
        lr_key = next((key for key in ('lr', 'learning_rate') if key in history), None)
        if lr_key is not None:
            lr_ax = loss_ax.twinx()
            lr_ax.plot(history[lr_key], label='Learning Rate', color='g', linestyle='--')
            lr_ax.set_ylabel('Learning Rate')
            lr_ax.legend(loc='upper center')
        loss_ax.set_title('LSTM Model Loss (Predicting Scaled Total Time)')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Mean Squared Error (Scaled Loss)')
        loss_ax.legend(loc='upper left')
        fig.tight_layout()
        fig.savefig('lstm_total_time_training_loss_v9.png'); print("Saved LSTM training loss plot.")
        plt.close(fig)
    elif not history:
        print("Warning: Training history not available or malformed.")

    # NaN ids are dropped: they cannot be matched with == and would only yield empty subplots.
    sample_sequence_ids = results_df['Sequence'].dropna().unique()
    if len(sample_sequence_ids) == 0 : print("No sequences in results_df for plotting."); return
    sample_sequence_ids = sample_sequence_ids[:5]

    _plot_sequence_comparison(results_df, sample_sequence_ids, 'Cumulative', 'Cumul.',
                              'Cumulative Times', 'Cumulative Time',
                              'lstm_refined_cumulative_time_comparison_v9.png', show_total_time=True)
    print("Saved cumulative time comparison plot.")
    _plot_sequence_comparison(results_df, sample_sequence_ids, 'Increment', 'Incr.',
                              'Time Increments', 'Time Increment',
                              'lstm_refined_increment_comparison_v9.png')
    print("Saved increment comparison plot.")

    gt_total_times = processed_data.get('y_lstm_target_total_times', np.array([]))
    required_keys = ('X_sequential_input', 'X_global_features_input', 'target_scaler', 'global_feature_scaler')
    if lstm_model is not None and all(key in processed_data for key in required_keys):
        X_seq_tensor = tf.convert_to_tensor(processed_data['X_sequential_input'], dtype=tf.float32)
        X_glob_scaled = processed_data['global_feature_scaler'].transform(processed_data['X_global_features_input'])
        X_glob_tensor = tf.convert_to_tensor(X_glob_scaled, dtype=tf.float32)

        predicted_scaled = lstm_model.predict([X_seq_tensor, X_glob_tensor])
        # reshape(-1) keeps a 1-D array even for a single sequence.
        predicted_total_times = processed_data['target_scaler'].inverse_transform(predicted_scaled).reshape(-1)

        if gt_total_times.size > 0 and predicted_total_times.size > 0:
            fig, (dist_ax, error_ax) = plt.subplots(1, 2, figsize=(12, 6))
            dist_ax.hist(gt_total_times, bins=30, alpha=0.7, label='GT Total Times (Max Cumul.)')
            dist_ax.hist(predicted_total_times, bins=30, alpha=0.7, label='LSTM Pred Total Times (Original Scale)')
            dist_ax.set_xlabel('Total Time')
            dist_ax.set_ylabel('Frequency')
            dist_ax.set_title('Distribution of Total Times')
            dist_ax.legend()

            if len(gt_total_times) == len(predicted_total_times):
                errors_total_time = gt_total_times - predicted_total_times
                error_ax.hist(errors_total_time, bins=30, alpha=0.7, color='red')
                error_ax.set_xlabel('Prediction Error (GT Max Cumul. - Pred)')
                error_ax.set_ylabel('Frequency')
                error_ax.set_title('LSTM Total Time Prediction Errors')
                mean_error_val = errors_total_time.mean()
                error_ax.axvline(mean_error_val, color='k', ls='--', lw=1, label=f'Mean Error: {mean_error_val:.2f}')
                error_ax.legend()
            else: print("Warning: Mismatch length GT total times and predictions for error histogram.")
            fig.tight_layout()
            fig.savefig('lstm_total_time_prediction_analysis_v9.png'); print("Saved total time prediction analysis plot.")
            plt.close(fig)
        else: print("Warning: Not enough data for total time distribution plots.")
    else: print("Warning: LSTM model, input data, or scaler missing for total time prediction plot.")
    print("Visualizations for LSTM completed!")

# %%
def _build_dummy_transformer_predictions(num_sequences=300, seed=42):
    """Return a synthetic transformer-predictions DataFrame for smoke-testing the flow.

    Each sequence has 10-59 steps, log-normal ground-truth increments, random
    proportions summing to 1, and transformer increments obtained by scaling
    those proportions with a noisy version of the true total time. A seeded
    generator is used so repeated runs produce the same file.
    """
    rng = np.random.default_rng(seed)
    frames = []
    for seq_idx in range(num_sequences):
        num_steps = int(rng.integers(10, 60))
        gt_increments = np.maximum(rng.lognormal(mean=2.0, sigma=0.7, size=num_steps) + 0.1, 0.01)
        gt_cumulative = np.cumsum(gt_increments)

        raw_props = rng.random(num_steps) + 0.05
        pred_proportions = raw_props / raw_props.sum()

        actual_total_time = max(gt_cumulative[-1], 1.0)
        # Simulate a transformer whose overall time scale is off by roughly +/-40%.
        transformer_total_time = max(actual_total_time * rng.normal(loc=1.0, scale=0.4), 0.1)
        pred_increments = pred_proportions * transformer_total_time

        frames.append(pd.DataFrame({
            'Sequence': seq_idx,
            'Step': np.arange(1, num_steps + 1),
            'SourceID': [f'MRI_DUMMY_{s_idx%5 +1}' for s_idx in range(num_steps)],
            'Predicted_Proportion': pred_proportions,
            'Predicted_Increment': pred_increments,
            'Predicted_Cumulative': np.cumsum(pred_increments),
            'GroundTruth_Increment': gt_increments,
            'GroundTruth_Cumulative': gt_cumulative,
        }))
    return pd.concat(frames, ignore_index=True)


def main_lstm_total_time_flow():
    """
    Run the whole pipeline: train the LSTM, generate refined predictions, and save the plots.

    If the transformer predictions CSV is missing, a synthetic one is written to
    the same path so the flow can be exercised end to end. Any exception is
    caught and printed with its traceback instead of being propagated.
    """
    import traceback

    try:
        transformer_predictions_file = "combined_model_results_SN_175651_175974_182625.csv"
        if not os.path.exists(transformer_predictions_file):
            print(f"Error: Transformer predictions file not found: {transformer_predictions_file}")
            print("Creating DUMMY CSV for testing flow.")
            # NOTE(review): the dummy file is written to the real input path, so
            # later runs will silently pick it up as if it were real data.
            dummy_df = _build_dummy_transformer_predictions()
            dummy_df.to_csv(transformer_predictions_file, index=False)
            print(f"Dummy '{transformer_predictions_file}' created with {len(dummy_df)} rows and {len(dummy_df['Sequence'].unique())} sequences.")

        lstm_model, lstm_history, processed_lstm_data = train_total_time_lstm(
            transformer_predictions_file, epochs=200, batch_size=32 )

        if lstm_model is None or processed_lstm_data is None:
            print("LSTM training failed or returned None. Exiting."); return

        refined_results_df = generate_refined_predictions_with_lstm(lstm_model, processed_lstm_data)
        if refined_results_df.empty:
            print("No refined predictions were generated by the LSTM flow.")
            return

        print("\nSample of Refined Predictions (LSTM Total Time Approach - v9):")
        display_cols = [ 'Sequence', 'Step', 'SourceID',
                         'Predicted_Increment', 'LSTM_Predicted_Increment', 'GroundTruth_Increment',
                         'Predicted_Cumulative', 'LSTM_Predicted_Cumulative', 'GroundTruth_Cumulative',
                         'LSTM_Predicted_TotalTime', 'Increment_Improvement_Pct']
        actual_display_cols = [col for col in display_cols if col in refined_results_df.columns]
        # 20 rows so that at least one end-of-sequence row (non-NaN total time) is likely visible.
        print(refined_results_df[actual_display_cols].head(20))
        print("\nGenerating visualizations for LSTM (total time approach - v9)...")
        visualize_lstm_results(refined_results_df, processed_lstm_data, lstm_model, lstm_history)
    except Exception as e:
        # Deliberate catch-all for interactive use: report the failure and keep the kernel usable.
        print(f"Error in LSTM (total time) main function: {e}"); traceback.print_exc()


In [10]:
# %%
# Entry point: run the full train -> refine -> visualize flow when this cell/script is executed directly.
if __name__ == "__main__":
    main_lstm_total_time_flow()

Processing data for LSTM training...
Processing data from: combined_model_results_SN_175651_175974_182625.csv
Error in LSTM (total time) main function: CSV must contain 'Predicted_Proportion' column.


Traceback (most recent call last):
  File "C:\Users\lukis\AppData\Local\Temp\ipykernel_5448\317463272.py", line 527, in main_lstm_total_time_flow
    lstm_model, lstm_history, processed_lstm_data = train_total_time_lstm(
  File "C:\Users\lukis\AppData\Local\Temp\ipykernel_5448\317463272.py", line 215, in train_total_time_lstm
    data_for_lstm = process_input_data_for_lstm(transformer_predictions_file)
  File "C:\Users\lukis\AppData\Local\Temp\ipykernel_5448\317463272.py", line 109, in process_input_data_for_lstm
    raise ValueError(f"CSV must contain '{col}' column.")
ValueError: CSV must contain 'Predicted_Proportion' column.
