# *__Working on BTCUSD predictions with GRU model.__*

## __Check first before starting__

In [1]:
import os

# Change the working directory to the project root.
# After os.chdir(), relative paths are resolved against this directory.
# NOTE(review): the path below is an absolute, machine-specific location; the
# notebook will fail at os.chdir() on any machine where it does not exist.
# Working_directory = os.path.normpath("C:/Users/gilda/OneDrive/Documents/_NYCU/MASTER_S_studies/Master\'s Thesis/LABORATORY/_Global_Pytorch/Continual_Learning")
Working_directory = os.path.normpath("C:/Users/james/OneDrive/文件/Continual_Learning")
os.chdir(Working_directory)
print(f"Working directory: {os.getcwd()}")  # Prints the current working directory

Working directory: C:\Users\james\OneDrive\文件\Continual_Learning


## **__All imports__**

In [2]:
import ipywidgets as widgets
from IPython.display import display
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_interactions import zoom_factory, panhandler
from sklearn.model_selection import train_test_split
import pickle
from ta import trend, momentum, volatility, volume
import math
from scipy.ndimage import gaussian_filter1d
from typing import Callable, Tuple
import shutil
import contextlib
import traceback
import copy
import gc
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

## __**All functions (For data processing)**__

In [4]:
def ensure_folder(folder_path):
    """
    Create a folder (and any missing parent folders) if it does not already exist.

    Parameters:
    - folder_path (str): Path of the folder to create.

    Raises:
    - FileExistsError: If `folder_path` already exists but is not a directory.
    """
    # exist_ok=True makes the call idempotent and removes the window between
    # "check that it does not exist" and "create it" in which another process
    # could create the folder and make os.makedirs() fail.
    os.makedirs(folder_path, exist_ok=True)

def plot_with_matplotlib(data: pd.DataFrame, 
                         title: str, 
                         interactive: bool = False, 
                         save_path: str = None, 
                         show_plot: bool = True, 
                         save_matplotlib_object: str = None) -> None:
    """
    Plot the 'close' column with Matplotlib, optionally colored by a 'trend' column.

    When a 'trend' column is present, each segment joining two consecutive rows is
    colored by the trend value of the later row (the first row's trend is therefore
    never used). Consecutive segments sharing the same trend are drawn as a single
    line, so the number of Matplotlib artists grows with the number of trend
    changes rather than with the number of rows.

    Parameters:
    - data (pd.DataFrame): The data to plot, must contain 'close' column.
    - title (str): The title of the plot.
    - interactive (bool): If True, enables interactive zoom and pan (mpl-interactions).
    - save_path (Optional[str]): If provided, saves the plot image to this path.
    - show_plot (bool): If True, displays the plot. If False, the figure is closed
      after any requested saving so that it is neither leaked nor rendered later.
    - save_matplotlib_object (Optional[str]): If provided, pickles the Matplotlib
      figure object to this file path.

    Raises:
    - ValueError: If `data` has no 'close' column.
    """
    if 'close' not in data.columns:
        raise ValueError("The input DataFrame must contain 'close' column.")

    # First color of the default Matplotlib color cycle.
    default_blue = plt.rcParams['axes.prop_cycle'].by_key()['color'][0]

    # Explicit colors for the known trend codes; anything else uses the fallback
    # color instead of raising KeyError.
    trend_colors = {
        0: 'black',
        1: 'yellow',
        2: 'red',
        3: 'green',
        4: default_blue,
    }
    fallback_color = 'gray'

    fig, ax = plt.subplots(figsize=(12, 6))

    if 'trend' in data.columns:
        x_values = data.index
        y_values = data['close'].to_numpy()
        # Segment s joins point s to point s + 1 and takes the trend of point s + 1.
        segment_trends = [int(value) for value in data['trend'].iloc[1:]]
        segment_count = len(segment_trends)

        legend_added = set()  # Trends that already have a legend entry
        run_start = 0
        while run_start < segment_count:
            trend_key = segment_trends[run_start]
            run_end = run_start
            while run_end + 1 < segment_count and segment_trends[run_end + 1] == trend_key:
                run_end += 1

            # Segments run_start..run_end span points run_start..run_end + 1.
            label = f'Trend {trend_key}' if trend_key not in legend_added else None
            ax.plot(x_values[run_start:run_end + 2],
                    y_values[run_start:run_end + 2],
                    color=trend_colors.get(trend_key, fallback_color),
                    linestyle='-',
                    linewidth=1,
                    label=label)
            legend_added.add(trend_key)
            run_start = run_end + 1

        ax.set_title(f"{title} (Connected, Colored by Trend)")
    else:
        # Default plot if no trend column exists
        ax.plot(data.index, data['close'], label='Closing Price', linestyle='-', marker='o', 
                markersize=2, linewidth=1, color=default_blue, markerfacecolor='green', markeredgecolor='black')
        ax.set_title(title)

    ax.set_xlabel('Date')
    ax.set_ylabel('Closing Price (USD)')

    # Only build a legend when something was labeled; calling legend() with no
    # labeled artists makes Matplotlib emit a warning.
    legend_handles, _ = ax.get_legend_handles_labels()
    if legend_handles:
        ax.legend()
    ax.grid()

    # Enable interactivity if requested
    if interactive:
        zoom_factory(ax)  # Enable zoom with mouse wheel
        # NOTE(review): the object returned by panhandler() is discarded here, as in
        # the original code; confirm against the mpl-interactions docs whether a
        # reference must be kept alive for panning to keep working.
        panhandler(fig)
        print("Interactive mode enabled. Use mouse wheel to zoom and left click to pan.")

    # Save the plot if a path is provided
    if save_path:
        fig.tight_layout()  # Ensures the layout is clean
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to: {save_path}")

    # Save the Matplotlib figure object
    if save_matplotlib_object:
        with open(save_matplotlib_object, 'wb') as f:
            pickle.dump(fig, f)
        print(f"Matplotlib figure object saved to: {save_matplotlib_object}")

    if show_plot:
        plt.show()
    else:
        # pyplot keeps every figure it created until it is closed; closing here
        # frees it and stops inline backends from rendering it anyway.
        plt.close(fig)
        print("Plot display skipped.")

def load_and_show_pickle(pickle_file_path: str):
    """
    Unpickle a Matplotlib figure from disk and call `plt.show(block=True)`.

    Failures are reported with a printed message rather than raised.

    Parameters:
    - pickle_file_path (str): Path to the pickled Matplotlib figure file.

    Returns:
    - None
    """
    try:
        with open(pickle_file_path, "rb") as pickle_handle:
            # NOTE(review): pickle.load can execute arbitrary code contained in the
            # file; only load figure files that come from a trusted source.
            restored_figure = pickle.load(pickle_handle)

        print(f"Figure successfully loaded and displayed from: {pickle_file_path}")

        # Blocking show so the window stays open for interaction.
        plt.show(block=True)

    except Exception as load_error:
        # A missing file gets its own message; everything else is reported generically.
        if isinstance(load_error, FileNotFoundError):
            message = f"Error: File not found at {pickle_file_path}. Please check the path."
        else:
            message = f"Error loading the pickled figure: {load_error}"
        print(message)

def save_to_csv(df: pd.DataFrame, file_path: str):
    """
    Write a DataFrame (index included) to a CSV file and print a confirmation.

    Parameters:
        df (pd.DataFrame): The DataFrame to write.
        file_path (str): Destination path, including the file name.

    Returns:
        None
    """
    confirmation = f"\nSuccessfully saved data with moving average to CSV: \n\t{file_path}\n"
    df.to_csv(file_path)
    print(confirmation)

def read_csv_file(file_path: str, preview_rows: int = 5, 
                  days_towards_end: int = None, 
                  days_from_start: int = None, description: str = ""):
    """
    Load a CSV indexed by its 'date' column and optionally restrict it to a date window.

    Args:
        file_path (str): The path to the CSV file. It must contain a 'date' column,
            which is parsed as datetimes and used as the index.
        preview_rows (int): Number of head/tail rows to display (default is 5).
            A falsy value (0 or None) disables the preview.
        days_towards_end (int, optional): Keep only rows within this many days of the
            most recent date.
        days_from_start (int, optional): Keep only rows within this many days of the
            earliest date of the (already filtered) data.
        description (str): Optional text printed before loading.

        The two filters are applied in that order, so they can be combined to pick a
        window in the middle: with `days_towards_end=100` and `days_from_start=50`
        the last 100 days are selected first, then the first 50 days of that range.

    Returns:
        DataFrame: The loaded and filtered data. If reading or filtering fails, an
        error message is printed and None is returned instead.
    """
    try:
        if description:
            print(f"\nDescription: {description}")
        print(f"\nFile path: {file_path}")

        frame = pd.read_csv(file_path, parse_dates=['date'], index_col='date')

        # Window measured backwards from the newest timestamp.
        if days_towards_end is not None:
            end_cutoff_date = frame.index.max() - pd.Timedelta(days=days_towards_end)
            frame = frame.loc[frame.index >= end_cutoff_date]
            print(f"\nRetrieving data from the past {days_towards_end} days (from {end_cutoff_date.date()} onwards):")

        # Window measured forwards from the oldest remaining timestamp.
        if days_from_start is not None:
            start_cutoff_date = frame.index.min() + pd.Timedelta(days=days_from_start)
            frame = frame.loc[frame.index <= start_cutoff_date]
            print(f"\nRetrieving the first {days_from_start} days from the filtered data (up to {start_cutoff_date.date()}):")

        if preview_rows:
            # Same preview routine for both ends of the frame.
            for position, take_rows in (("first", frame.head), ("last", frame.tail)):
                print(f"\nPreview of the {position} {preview_rows} rows:")
                display(take_rows(preview_rows))
                print()

        return frame
    except FileNotFoundError:
        print("Error: File not found. Please check the file path.")
    except pd.errors.EmptyDataError:
        print("Error: The file is empty.")
    except pd.errors.ParserError:
        print("Error: The file could not be parsed. Please check the file format.")
    except Exception as unexpected_error:
        print(f"An unexpected error occurred: {unexpected_error}")

def downsample_minute_data(data: pd.DataFrame, n: int) -> pd.DataFrame:
    """
    Downsample minute data by keeping only rows whose timestamp minute is a multiple of `n`.

    Rows are selected with `index.minute % n == 0`, i.e. by the minute-of-hour of
    each timestamp, not by row position. For values of `n` that do not divide 60
    the spacing is therefore irregular across hour boundaries.

    Parameters:
        data (pd.DataFrame): The original DataFrame. Its index must be a
            DatetimeIndex or be convertible with `pd.to_datetime`. It is not modified.
        n (int): The number of minutes for the downsampling interval (must be >= 1).

    Returns:
        pd.DataFrame: New DataFrame with the selected rows and a DatetimeIndex.

    Raises:
        ValueError: If `n` is smaller than 1, or if the index cannot be converted
            to a DatetimeIndex.
    """
    # A zero or negative interval has no meaningful "every n minutes" selection.
    if n < 1:
        raise ValueError(f"n must be a positive number of minutes, got {n}.")

    print("\n========---> Downsampling the data! \n")

    # Work on the index only; the (possibly large) frame is not copied up front.
    timestamps = data.index
    if not isinstance(timestamps, pd.DatetimeIndex):
        try:
            timestamps = pd.to_datetime(timestamps)  # Convert to DatetimeIndex
        except Exception as e:
            raise ValueError("The DataFrame index could not be converted to DatetimeIndex.") from e

    # Keep rows where the minute-of-hour modulo n is 0
    keep_mask = timestamps.minute % n == 0

    # Boolean-mask selection builds a new frame, so `data` itself is left untouched.
    downsampled_data = data[keep_mask]
    downsampled_data.index = timestamps[keep_mask]

    return downsampled_data

def calculate_log_returns_all_columns(data: pd.DataFrame, exclude_columns: list = None, dropna: bool = True) -> pd.DataFrame:
    """
    Calculate log returns, log(x_t / x_{t-1}), for all numeric columns of a DataFrame.

    Columns listed in `exclude_columns` are removed from the result entirely.
    Non-numeric columns that are not excluded are kept unchanged.

    Args:
        data (pd.DataFrame): Input DataFrame. It is not modified.
        exclude_columns (list, optional): Columns to drop before the calculation.
            Defaults to None, meaning no column is dropped.
        dropna (bool): Whether to drop rows containing NaN after the calculation
            (the first row is always NaN for the transformed columns).

    Returns:
        pd.DataFrame: DataFrame with numeric columns replaced by their log returns,
                      without the excluded columns.

    Raises:
        ValueError: If a numeric column contains a zero or negative value.
    """
    # None instead of a shared mutable [] default; copy the caller's list so it is
    # never aliased.
    columns_to_drop = list(exclude_columns) if exclude_columns is not None else []

    # Drop first, then copy: only the columns that are kept get duplicated.
    data = data.drop(columns=columns_to_drop).copy()
    columns_to_transform = data.select_dtypes(include=[np.number]).columns
    print(f"columns_to_transform = \n{columns_to_transform}, \nlen(columns_to_transform) = {len(columns_to_transform)}")

    for col in columns_to_transform:
        # Ensure no negative or zero values
        if (data[col] <= 0).any():
            raise ValueError(f"Column '{col}' contains non-positive values. Log returns require strictly positive values.")
        data[col] = np.log(data[col] / data[col].shift(1))

    # Optionally drop rows with NaN values
    return data.dropna() if dropna else data

def created_sequences_2(data: pd.DataFrame, sequence_length: int = 60, sliding_interval: int = 60) -> list:
    """
    Cut the dataset into fixed-length windows.

    Windows start at row 0 and then every `sliding_interval` rows. Only windows
    that contain exactly `sequence_length` rows are produced, so a trailing
    partial window is dropped.

    Args:
    - data (pd.DataFrame): The input DataFrame.
    - sequence_length (int): Number of rows in each window.
    - sliding_interval (int): Number of rows between the starts of consecutive windows.

    Returns:
    - sequences (list): A list of windows, each an independent DataFrame copy.
    """
    last_start_exclusive = len(data) - sequence_length + 1
    window_starts = range(0, last_start_exclusive, sliding_interval)
    return [data.iloc[start:start + sequence_length].copy() for start in window_starts]

def gaussian_smoothing(data: pd.DataFrame, sigma=2) -> pd.DataFrame:
    """
    Applies Gaussian smoothing to numeric columns in a DataFrame and ensures the index is sorted in ascending order.

    Integer and boolean columns are converted to float before filtering:
    `gaussian_filter1d` returns an array of the same dtype as its input, so
    filtering an integer column directly would truncate the smoothed values back
    to integers. Floating-point columns are filtered as they are.

    Args:
        data (pd.DataFrame): Input DataFrame. It is not modified.
        sigma (float): Standard deviation for the Gaussian kernel. Defaults to 2.

    Returns:
        pd.DataFrame: A sorted copy of `data` with smoothed numeric columns.
        Non-numeric columns are left unchanged.
    """
    # Sort the DataFrame by index in ascending order (work on a copy)
    data = data.sort_index(ascending=True).copy()
    for column in data.columns:
        series = data[column]
        if not pd.api.types.is_numeric_dtype(series):  # Only apply to numeric columns
            continue

        needs_float = pd.api.types.is_integer_dtype(series) or pd.api.types.is_bool_dtype(series)
        if needs_float:
            # Upcast so the filter output keeps its fractional part.
            values = series.to_numpy(dtype=float)
        else:
            values = series.values
        data[column] = gaussian_filter1d(values, sigma=sigma)
    return data

def detect_trends_4(
    dataframe: pd.DataFrame, 
    column: str = 'close', 
    lower_threshold: float = 0.001, 
    upper_threshold: float = 0.02,
    reverse_steps: int = 7,
    trends_to_keep: set = {0, 1, 2, 3, 4}  # Default keeps all trends
) -> pd.DataFrame:
    """
    Label every row of a log-return series with a trend category (0-4).

    The rows are scanned once, in order. Consecutive rows are grouped into up to
    three pending "phases" (`dic1`, `dic2`, `dic3`), each holding the positional
    row ids it covers, the sign of its returns and the running product of
    `(1 + value)` over its rows. A phase is labelled from that product once the
    scan decides it is finished.

    Args:
        dataframe (pd.DataFrame): Input DataFrame. It is copied, not modified.
        column (str): Column whose values are read as log returns. Defaults to 'close'.
        lower_threshold (float): Distance from 1 beyond which a phase product counts
            as a moderate trend. Defaults to 0.001.
        upper_threshold (float): Distance from 1 beyond which a phase product counts
            as a very strong trend. Defaults to 0.02.
        reverse_steps (int): Number of rows an opposing phase must reach before the
            phase(s) preceding it are labelled and it takes over as `dic1`. Defaults to 7.
        trends_to_keep (set): Trend categories to retain; any other label is replaced
            by 0 (No Trend). Defaults to keeping all trends {0, 1, 2, 3, 4}.

    Returns:
        pd.DataFrame: A copy of `dataframe` with an added 'trend' column:
                        - 0: No trend
                        - 1: Moderate negative trend
                        - 2: Very strong negative trend
                        - 3: Moderate positive trend
                        - 4: Very strong positive trend
                      Any label not included in `trends_to_keep` is reset to 0.

    Function Details:
    1. **Sign of a row**:
    - A value > 0 has sign +1; every other value (including 0, and NaN because the
      comparison is False) has sign -1.

    2. **Phases**:
        - `dic1`: the current trend.
        - `dic2`: a run of rows whose sign is opposite to `dic1` (first reversal).
        - `dic3`: a run of rows whose sign is opposite to `dic2` (second reversal).

    3. **Labels** (from a phase's product `c`, see `assign_label`):
        - `c > 1 + upper_threshold`                            -> 4
        - `1 + lower_threshold < c <= 1 + upper_threshold`     -> 3
        - `1 - upper_threshold < c <= 1 - lower_threshold`     -> 1
        - `c <= 1 - upper_threshold`                           -> 2
        - otherwise                                            -> 0

    4. **Reversal handling**:
    - When `dic2` reaches `reverse_steps` rows (with no `dic3`), `dic1` is labelled
      and `dic2` becomes the new `dic1`.
    - While `dic3` grows, if the product of `dic2` and `dic3` has moved in the
      direction of `dic3`'s sign (> 1 for positive, < 1 for negative), the three
      phases are merged into `dic1`.
    - Otherwise, when `dic3` reaches `reverse_steps` rows, `dic1` and `dic2` are
      labelled separately and `dic3` becomes the new `dic1`.
    - If the sign flips again while `dic3` exists, `dic1` is labelled, `dic2` and
      `dic3` shift down to `dic1` and `dic2`, and a new `dic3` starts.

    5. **End of data**:
    - Phases still pending after the last row are labelled with the same rule.

    6. **Trend filtering**:
    - Finally, labels that are not in `trends_to_keep` are replaced by 0.
    """
    # Copy to avoid modifying the original DataFrame
    df = dataframe.copy()
    df['trend'] = None  # Default value 
    
    # print("\n#-------------------- Working on 'trend' patterns -----------------------#")
    dic1, dic2, dic3 = None, None, None # Initialize trend tracking dictionaries
    # dic1 = None # {'ids': [], 'last_sign': None, 'cumulative': 1.0}
    
    def assign_label(dictio_, lower_threshold, upper_threshold):
        # Write one label to all rows of the phase (dictio_['ids'] are positional
        # row numbers), chosen from where the phase's product falls around 1.
        cumulative = dictio_['cumulative']
        # print(f"cumulative = {cumulative}")
        if cumulative > (1 + upper_threshold):
            df.iloc[dictio_['ids'], df.columns.get_loc('trend')] = 4  # Very strong positive
        elif (1 + lower_threshold) < cumulative <= (1 + upper_threshold):
            df.iloc[dictio_['ids'], df.columns.get_loc('trend')] = 3  # Moderate positive
        elif (1 - upper_threshold) < cumulative <= (1 - lower_threshold):
            df.iloc[dictio_['ids'], df.columns.get_loc('trend')] = 1  # Moderate negative
        elif cumulative <= (1 - upper_threshold):
            df.iloc[dictio_['ids'], df.columns.get_loc('trend')] = 2  # Very strong negative
        else:
            df.iloc[dictio_['ids'], df.columns.get_loc('trend')] = 0  # No trend
    
    #----------------------- For Loop -----------------------#
    for idx, log_ret in enumerate(df[column]):
        # Zero (and NaN) values fall into the negative branch.
        sign = 1 if log_ret > 0 else -1

        if dic1 is None:  # Initialize dic1
            # print(f"\nThis one time condition 'if loop' is running \n")
            dic1 = {'ids': [idx], 'last_sign': sign, 'cumulative': (1 + log_ret)}
            continue
        last_sign = dic1['last_sign']
        if sign == last_sign and dic2 is None:  # Continue same trend
            dic1['ids'].append(idx)
            dic1['last_sign'] = sign
            dic1['cumulative'] *= (1 + log_ret)
            continue

        # 1st Reversal occuring
        if dic2 is None:  # Start dic2
            dic2 = {'ids': [idx], 'last_sign': sign, 'cumulative': (1 + log_ret)}
            continue
        last_sign = dic2['last_sign']
        if sign == last_sign and dic3 is None:  # Continue same trend
            dic2['ids'].append(idx)
            dic2['last_sign'] = sign
            dic2['cumulative'] *= (1 + log_ret)
            # The reversal lasted long enough: close dic1 and promote dic2.
            if len(dic2['ids']) == reverse_steps:
                assign_label(dic1, lower_threshold, upper_threshold) # Assign labels in the 'trend' column for ids of dic1
                # print(f"dic1['cumulative'] = {dic1['cumulative']}, and dic1['ids'] = {dic1['ids']}")
                dic1 = dic2
                dic2 = None
                # print(f"dic1 after trend reversal persisted and dic1 = dic2 = \n{dic1}")
                # print(f"dic2 after being reset: {dic2}\n")
            continue

        # 2nd Reversal occuring
        if dic3 is None:  # Start dic3
            dic3 = {'ids': [idx], 'last_sign': sign, 'cumulative': (1 + log_ret)}
            continue
        last_sign = dic3['last_sign']
        if sign == last_sign: # Continue same trend, there is no dic4 to check if is None
            dic3['ids'].append(idx)
            dic3['last_sign'] = sign
            dic3['cumulative'] *= (1 + log_ret)
            # Net move of the counter-run (dic2) and the run that followed it (dic3).
            dic_prod = dic2['cumulative'] * dic3['cumulative']
            # if (sign == 1 and dic1['cumulative'] * dic_prod > dic1['cumulative']) or (sign == -1 and dic1['cumulative'] * dic_prod < dic1['cumulative'])):
            if (sign == 1 and dic_prod > 1) or (sign == -1 and dic_prod < 1): # More beautiful
                # dic3 has more than undone dic2: treat all three as one trend.
                # Merge dic1, dic2, and dic3
                dic1['ids'] += dic2['ids'] + dic3['ids']
                dic1['last_sign'] = dic3['last_sign']
                dic1['cumulative'] *= dic2['cumulative'] * dic3['cumulative']
                dic2, dic3 = None, None
                continue

            # dic3 reached reverse_steps rows without offsetting dic2: close both
            # earlier phases and promote dic3.
            if len(dic3['ids']) == reverse_steps:      
                assign_label(dic1, lower_threshold, upper_threshold) # Assign labels in the 'trend' column for ids of dic1
                assign_label(dic2, lower_threshold, upper_threshold) # Assign labels in the 'trend' column for ids of dic2
                dic1 = dic3
                dic2, dic3 = None, None
                # print(f"dic2 after 2nd trend reversal didn't catch up fast enough, and now \ndic1 = dic3 = {dic1}")
                # print(f"dic3 and dic2 after being reset: {dic3}\n")
            continue
            
        # 3rd Reversal occuring
        # Reached only when all three phases exist and the sign differs from dic3's.
        assign_label(dic1, lower_threshold, upper_threshold) # Assign labels in the 'trend' column for ids of dic1
        # Reassign values
        dic1 = dic2
        dic2 = dic3
        dic3 = {'ids': [idx], 'last_sign': sign, 'cumulative': (1 + log_ret)}
        # print(f"There was a 3rd trend reversal, and now \ndic1 = dic2 = {dic1}, \ndic2 = dic3 = {dic2}")
        # print(f"dic3 after being reset: {dic3}\n")

    # Assign remaining labels
    if dic1:
        assign_label(dic1, lower_threshold, upper_threshold)
    if dic2:
        assign_label(dic2, lower_threshold, upper_threshold)
    if dic3:
        assign_label(dic3, lower_threshold, upper_threshold)
    # print("\n#-------------------- Returning 'trend' patterns ------------------------#")
    
    # Apply filtering: Keep only selected trends, set others to 0
    df['trend'] = df['trend'].apply(lambda x: x if x in trends_to_keep else 0)

    return df

def split_X_y(sequences: list[pd.DataFrame], 
              target_column: str = 'trend',
              detect_trends_function: Callable[[pd.DataFrame, str, float, float, int, set], pd.DataFrame] = detect_trends_4, 
              column: str = 'close', 
              lower_threshold: float = 0.0009, 
              upper_threshold: float = 0.015,
              reverse_steps: int = 7,
              trends_to_keep: set = {0, 1, 2, 3, 4}) -> Tuple[np.ndarray, np.ndarray]:
    """
    Label each sequence with the trend detector, then separate features from labels.

    Args:
    - sequences (list of pd.DataFrame): List of DataFrame sequences.
    - target_column (str): Column produced by the detector that is used as the
      label and removed from the features (default: 'trend').
    - detect_trends_function (Callable): Function for detecting trends, defaults to `detect_trends_4`.
    - column (str): Column passed to the detector as the series to analyse (default: 'close').
    - lower_threshold (float): Lower threshold forwarded to the detector.
    - upper_threshold (float): Upper threshold forwarded to the detector.
    - reverse_steps (int): Reversal step count forwarded to the detector.
    - trends_to_keep (set): Trend categories forwarded to the detector.

    Returns:
    - X (np.ndarray): Stacked feature values, one entry per sequence (every column
      except `target_column`).
    - y (np.ndarray): Stacked `target_column` values, one entry per sequence.
    """
    detector_options = dict(
        column=column,
        lower_threshold=lower_threshold,
        upper_threshold=upper_threshold,
        reverse_steps=reverse_steps,
        trends_to_keep=trends_to_keep,
    )

    # Lazily label one sequence at a time, then split it into (features, labels).
    labelled_frames = (detect_trends_function(frame, **detector_options) for frame in sequences)
    feature_label_pairs = [
        (labelled.drop(columns=[target_column]).values, labelled[target_column].values)
        for labelled in labelled_frames
    ]

    features = [pair[0] for pair in feature_label_pairs]
    labels = [pair[1] for pair in feature_label_pairs]
    return np.array(features), np.array(labels)

def process_and_return_splits(
    with_indicators_file_path: str,
    downsampled_data_minutes: int,
    exclude_columns: list[str],
    lower_threshold: float,
    upper_threshold: float,
    reverse_steps: int,
    sequence_length: int,
    sliding_interval: int,
    trends_to_keep: set = {0, 1, 2, 3, 4}  # Default keeps all trends
) -> tuple[
    list[list[float]],  # X_train: List of sequences, each containing a list of features
    list[list[int]],    # y_train: List of sequences, each containing a list of labels
    list[list[float]],  # X_val: List of sequences, each containing a list of features
    list[list[int]],    # y_val: List of sequences, each containing a list of labels
    list[list[float]],  # X_test: List of sequences, each containing a list of features
    list[list[int]]     # y_test: List of sequences, each containing a list of labels
]:
    """
    Processes time-series data from a CSV file and prepares it for machine learning.

    This function performs the following steps:
        1. Reads data from the specified CSV file and sorts it by date in descending order.
        2. Optionally downsamples the data to a lower frequency (e.g., 5-minute intervals).
        3. Applies Gaussian smoothing to reduce noise in the data.
        4. Calculates log returns for all numeric columns, excluding specified columns.
        5. Detects trends based on defined thresholds (`lower_threshold`, `upper_threshold`, and `reverse_steps`).
        6. Filters trends to keep only those specified in `trends_to_keep`, setting others to 0 (No Trend).
        7. Converts the processed data into sequences of a fixed length (`sequence_length`) with a sliding interval.
        8. Splits the sequences into training (80%), validation (10%), and test (10%) sets.
        9. Further splits the sequences into features (`X`) and labels (`y`) for supervised learning.

    Args:
        with_indicators_file_path (str): Path to the CSV file containing the time-series data.
        downsampled_data_minutes (int): Frequency for downsampling the data (e.g., 1 for no downsampling).
        exclude_columns (list[str]): List of column names to exclude from log return calculations.
        lower_threshold (float): Lower threshold for trend detection.
        upper_threshold (float): Upper threshold for trend detection.
        reverse_steps (int): Number of steps for reversing trends in trend detection.
        sequence_length (int): Length of sequences to create from the data.
        sliding_interval (int): Interval for sliding the window when creating sequences.
        trends_to_keep (set): A set of trend categories to retain; others will be set to 0 (No Trend). Defaults to keeping all trends {0, 1, 2, 3, 4}.

    Returns:
        tuple[list[list[float]], list[list[int]], list[list[float]], list[list[int]], list[list[float]], list[list[int]]]:
            A tuple containing:
            - X_train (list[list[float]]): Input sequences for training.
            - y_train (list[list[int]]): Target sequences for training.
            - X_val (list[list[float]]): Input sequences for validation.
            - y_val (list[list[int]]): Target sequences for validation.
            - X_test (list[list[float]]): Input sequences for testing.
            - y_test (list[list[int]]): Target sequences for testing.

    Example:
        X_train, y_train, X_val, y_val, X_test, y_test = process_and_return_splits(
            with_indicators_file_path="data.csv",
            downsampled_data_minutes=5,
            exclude_columns=["volume"],
            lower_threshold=-0.05,
            upper_threshold=0.05,
            reverse_steps=3,
            sequence_length=50,
            sliding_interval=5,
            trends_to_keep={1, 2, 3, 4}  # Only keep categorized trends, set others to 0
        )
    """

    data_retrieved = read_csv_file(with_indicators_file_path, preview_rows=0) # 190 days of data
    data_retrieved = data_retrieved.sort_index(ascending=False)

    #---------------------------------------------------------------------------------------
    if downsampled_data_minutes != 1:
        print("Downsampling the data! \n")
        data_retrieved = downsample_minute_data(data_retrieved, downsampled_data_minutes)
    #---------------------------------------------------------------------------------------

    # Get missing timestamps
    missing_timestamps = pd.date_range(
        start=data_retrieved.index.min(), # Returns smallest/earliest/oldest date
        end=data_retrieved.index.max(),
        freq='1min',  # Use 'min' for a frequency of 1 minute, '30s' for a frequency of 30 seconds
        tz=data_retrieved.index.tz,
    ).difference(data_retrieved.index)
    print(f"\ndata_retrieved - Missing timestamps time: \n{missing_timestamps}") 

    data_gaussian = gaussian_smoothing(data_retrieved, sigma=7)

    # Get missing timestamps
    missing_timestamps = pd.date_range(
        start=data_gaussian.index.min(), # Returns smallest/earliest/oldest date
        end=data_gaussian.index.max(),
        freq='1min',  # Use 'min' for a frequency of 1 minute, '30s' for a frequency of 30 seconds
        tz=data_gaussian.index.tz,
    ).difference(data_gaussian.index)
    print(f"\ndata_gaussian - Missing timestamps time: \n{missing_timestamps}\n")

    data_log_return = calculate_log_returns_all_columns(data_gaussian, exclude_columns=exclude_columns)

    # Get missing timestamps
    missing_timestamps = pd.date_range(
        start=data_log_return.index.min(), # Returns smallest/earliest/oldest date
        end=data_log_return.index.max(),
        freq='1min',  # Use 'min' for a frequency of 1 minute, '30s' for a frequency of 30 seconds
        tz=data_log_return.index.tz,
    ).difference(data_log_return.index)
    print(f"\ndata_log_return - Missing timestamps time: \n{missing_timestamps}\n") 

    # Check if there are missing timestamps
    if missing_timestamps.empty:
        print("No missing timestamps.")
    else:
        for timestamp in missing_timestamps:
            print(f"\nMissing timestamp: {timestamp}")
            
            # Create a placeholder for the missing timestamp
            if timestamp not in data_log_return.index:
                print('Missing')
            
            # Get data before and after the missing timestamp
            before_data = data_log_return[data_log_return.index < timestamp].tail(5)  # 5 data points before
            after_data = data_log_return[data_log_return.index > timestamp].head(5)  # 5 data points after
            
            # Display surrounding data
            if not before_data.empty:
                print("\nData before:")
                print(before_data)
            else:
                print("\nNo data available before the missing timestamp.")
            
            if not after_data.empty:
                print("\nData after:")
                print(after_data)
            else:
                print("\nNo data available after the missing timestamp.")

    sequences = created_sequences_2(data_log_return, sequence_length, sliding_interval)

    # Split sequences into training, validation, and test sets
    train_size = int(len(sequences) * 0.8)
    val_size = int(len(sequences) * 0.1)

    train_sequences = sequences[:train_size]
    val_sequences = sequences[train_size:train_size + val_size]
    test_sequences = sequences[train_size + val_size:]

    print(f"""
    Number of sequences:
        - sequences[0].shape: {sequences[0].shape}
        - Total sequences: {len(sequences)}
        - Train sequences: {len(train_sequences)}
        - Validation sequences: {len(val_sequences)}
        - Test sequences: {len(test_sequences)}
    """)

    # Process train, validation, and test sets
    X_train, y_train = split_X_y(train_sequences, 
                                target_column='trend',
                                detect_trends_function = detect_trends_4,
                                column= 'close',
                                lower_threshold=lower_threshold, 
                                upper_threshold=upper_threshold, 
                                reverse_steps=reverse_steps,
                                trends_to_keep=trends_to_keep)

    X_val, y_val = split_X_y(val_sequences, 
                            target_column='trend',
                            detect_trends_function = detect_trends_4,
                            column= 'close',
                            lower_threshold=lower_threshold, 
                            upper_threshold=upper_threshold, 
                            reverse_steps=reverse_steps,
                            trends_to_keep=trends_to_keep)

    X_test, y_test = split_X_y(test_sequences, 
                            target_column='trend',
                            detect_trends_function = detect_trends_4,
                            column= 'close',
                            lower_threshold=lower_threshold, 
                            upper_threshold=upper_threshold, 
                            reverse_steps=reverse_steps,
                            trends_to_keep=trends_to_keep)

    # Checking X arrays
    for idx, seq in enumerate(X_train):  # Loop through sequences
        for sub_idx, feature_set in enumerate(seq):  # Loop through data points
            for feature_idx, feature in enumerate(feature_set):  # Loop through features
                if not isinstance(feature, (float, np.float32)):  # Check each feature
                    print(f"Unexpected type in X_train at sequence {idx}, data point {sub_idx}, feature {feature_idx}: {type(feature)}")

    # Checking y arrays
    for idx, seq in enumerate(y_train):  # Loop through sequences
        for sub_idx, label in enumerate(seq):  # Loop through data points (labels)
            if not isinstance(label, (int, np.int64)):  # Check each label
                print(f"Unexpected type in y_train at sequence {idx}, data point {sub_idx}: {type(label)}")

    # Checking X arrays
    for idx, seq in enumerate(X_val):  # Loop through sequences
        for sub_idx, feature_set in enumerate(seq):  # Loop through data points
            for feature_idx, feature in enumerate(feature_set):  # Loop through features
                if not isinstance(feature, (float, np.float32)):  # Check each feature
                    print(f"Unexpected type in X_val at sequence {idx}, data point {sub_idx}, feature {feature_idx}: {type(feature)}")
    # Checking y arrays
    for idx, seq in enumerate(y_val):  # Loop through sequences
        for sub_idx, label in enumerate(seq):  # Loop through data points (labels)
            if not isinstance(label, (int, np.int64)):  # Check each label
                print(f"Unexpected type in y_val at sequence {idx}, data point {sub_idx}: {type(label)}")

    # Checking X arrays
    for idx, seq in enumerate(X_test):  # Loop through sequences
        for sub_idx, feature_set in enumerate(seq):  # Loop through data points
            for feature_idx, feature in enumerate(feature_set):  # Loop through features
                if not isinstance(feature, (float, np.float32)):  # Check each feature
                    print(f"Unexpected type in X_test at sequence {idx}, data point {sub_idx}, feature {feature_idx}: {type(feature)}")
    # Checking y arrays
    for idx, seq in enumerate(y_test):  # Loop through sequences
        for sub_idx, label in enumerate(seq):  # Loop through data points (labels)
            if not isinstance(label, (int, np.int64)):  # Check each label
                print(f"Unexpected type in y_test at sequence {idx}, data point {sub_idx}: {type(label)}")

    if isinstance(y_train, np.ndarray) and y_train.dtype == np.object_:
        # Convert to numeric if needed
        y_train = np.array(y_train, dtype=np.int64)

    if isinstance(y_val, np.ndarray) and y_val.dtype == np.object_:
        # Convert to numeric if needed
        y_val = np.array(y_val, dtype=np.int64)

    if isinstance(y_test, np.ndarray) and y_test.dtype == np.object_:
        # Convert to numeric if needed
        y_test = np.array(y_test, dtype=np.int64)

    close_col_index = data_log_return.columns.get_loc('close') # 'date' is set as index so doesnt count as a column
    Number_features = X_train.shape[-1]
    print(f"close_col_index = {close_col_index}")
    print(f"Number_features = {Number_features}")

    return X_train, y_train, X_val, y_val, X_test, y_test, Number_features

def print_class_distribution(y, var_name: str) -> None:
    """
    Print, on a single line, the percentage share of every class found in `y`.

    Args:
        y: Class labels given as a torch tensor, a NumPy array, or a (nested) list.
           Tensors are moved to CPU and converted to NumPy; the data is flattened
           before counting.
        var_name: Name shown in the header of the printed line.
    """
    labels = y.cpu().numpy() if isinstance(y, torch.Tensor) else y
    labels = np.array(labels).flatten()

    classes, occurrences = np.unique(labels, return_counts=True)
    n_labels = occurrences.sum()

    # One "Class k Percent: xx.xx%" fragment per distinct label, in sorted label order.
    fragments = []
    for cls, n_cls in zip(classes, occurrences):
        share = (n_cls / n_labels) * 100
        fragments.append(f"Class {int(cls):<3} Percent: {share:>6.2f}%")

    title = f"Class Distribution for '{var_name}':"
    print(title.ljust(40) + " || ".join(fragments))


## __All (Initial) parameters__

In [5]:
# --- Data source ---
ticker = 'BTC-USD'
downsampled_data_minutes = 1  # No downsampling

# --- Sequence windowing ---
sequence_length = 1000
sliding_interval = 60

# --- Step 0 (again): trend settings identified on the loaded data with 1,000 data points ---
lower_threshold = 0.0009
upper_threshold = 0.015
reverse_steps = 13

# --- Step 3 (Correlation Analysis): feature groups by correlation with the 'trend' column ---
# Obtained with:
#   corr = data_trends.corr()
#   trend_corr = corr['trend'].sort_values(ascending=False)
strongly_correlated = [  # Strongly correlated (correlation > 0.6)
    'close', 'open', 'SMA_5', 'high', 'low', 'EMA_10', 'SMA_10',
]
moderately_correlated = [  # Moderately correlated (correlation between 0.3 and 0.6)
    'BB_middle', 'BB_lower', 'BB_upper', 'RSI_14',
]
weakly_correlated = [  # Weakly correlated or negligible (correlation <~ 0.3)
    'SMA_50', 'volume', 'BBW', 'ATR_14',
]

# Features not to be included in the analysis: the fixed base list, followed by the
# weakly and the moderately correlated groups (in that order).
exclude_columns = [
    'MACD', 'MACD_signal', 'ROC_10', 'OBV', 'AD_Line',
    *weakly_correlated,
    *moderately_correlated,
]

## __Check GPU, CUDA, PyTorch__

### GPU details

In [6]:
!nvidia-smi

Sun Apr 20 17:32:52 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 566.03                 Driver Version: 566.03         CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4070      WDDM  |   00000000:01:00.0  On |                  N/A |
|  0%   42C    P8             11W /  200W |    1332MiB /  12282MiB |     13%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

### CUDA Details

In [7]:
def check_gpu_config():
    """
    Print a formatted report about GPU availability as seen by PyTorch.

    The report always contains the PyTorch version and whether CUDA is available.
    When a GPU is available it additionally lists, for device 0: the device name,
    the number of GPUs, the current device index, the compute capability, an
    approximate CUDA core count, and total / allocated / reserved memory in GB.
    """
    width = 50
    heavy_rule = "=" * width
    light_rule = "-" * width
    bytes_per_gb = 1024 ** 3

    def show(label, value):
        # Left-align the label in a 25-character column, then ": value".
        print(f"{label:<25}: {value}")

    has_gpu = torch.cuda.is_available()

    # Header
    print(heavy_rule)
    print("GPU Configuration Check".center(width))
    print(heavy_rule)

    # Always-available information
    show("PyTorch Version", torch.__version__)
    show("GPU Available", "Yes" if has_gpu else "No")

    print(light_rule)
    if not has_gpu:
        print("No GPU detected. Running on CPU.".center(width))
        print(light_rule)
    else:
        print("GPU Details".center(width))
        print(light_rule)

        # Device identity
        show("Device Name", torch.cuda.get_device_name(0))
        show("Number of GPUs", torch.cuda.device_count())
        show("Current Device Index", torch.cuda.current_device())

        # Compute capability and core estimate
        props = torch.cuda.get_device_properties(0)
        show("Compute Capability", f"{props.major}.{props.minor}")
        # Estimate only: assumes 128 cores per streaming multiprocessor.
        show("Total CUDA Cores", props.multi_processor_count * 128)

        # Memory figures, converted from bytes to GB
        show("Total Memory (GB)", f"{props.total_memory / bytes_per_gb:.2f}")
        show("Allocated Memory (GB)", f"{torch.cuda.memory_allocated(0) / bytes_per_gb:.2f}")
        show("Reserved Memory (GB)", f"{torch.cuda.memory_reserved(0) / bytes_per_gb:.2f}")

    print(heavy_rule)

if __name__ == "__main__":
    check_gpu_config()

             GPU Configuration Check              
PyTorch Version          : 2.4.1+cu124
GPU Available            : Yes
--------------------------------------------------
                   GPU Details                    
--------------------------------------------------
Device Name              : NVIDIA GeForce RTX 4070
Number of GPUs           : 1
Current Device Index     : 0
Compute Capability       : 8.9
Total CUDA Cores         : 5888
Total Memory (GB)        : 11.99
Allocated Memory (GB)    : 0.00
Reserved Memory (GB)     : 0.00


### PyTorch Details

In [8]:
def print_torch_config():
    """
    Print PyTorch and CUDA configuration in a formatted manner.

    Side effect: seeds PyTorch's random number generator with 42
    (via `torch.manual_seed`) before printing the seed line.
    """
    width = 50
    heavy_rule = "=" * width
    cuda_ready = torch.cuda.is_available()

    def show(label, value):
        # Left-align the label in a 25-character column, then ": value".
        print(f"{label:<25}: {value}")

    print(heavy_rule)
    print("PyTorch Configuration".center(width))
    print(heavy_rule)

    # Basic PyTorch and CUDA info
    show("PyTorch Version", torch.__version__)
    show("CUDA Compiled Version", torch.version.cuda)
    show("CUDA Available", "Yes" if cuda_ready else "No")
    show("Number of GPUs", torch.cuda.device_count())

    # GPU name only when a GPU is actually usable
    if torch.cuda.is_available():
        show("GPU Name", torch.cuda.get_device_name(0))

    print("-" * width)

    # Seed setting
    torch.manual_seed(42)
    show("Random Seed", "42 (Seeding successful!)")

    print(heavy_rule)

if __name__ == "__main__":
    print_torch_config()

              PyTorch Configuration               
PyTorch Version          : 2.4.1+cu124
CUDA Compiled Version    : 12.4
CUDA Available           : Yes
Number of GPUs           : 1
GPU Name                 : NVIDIA GeForce RTX 4070
--------------------------------------------------
Random Seed              : 42 (Seeding successful!)


## __Build the GRU Model__

#### Bi-Directional GRU with Attention

In [9]:
class BiGRUWithAttention(nn.Module):
    """
    Bi-directional GRU followed by an element-wise gating ("attention") layer and a
    linear classifier applied at every time step.

    Args:
        input_size (int): Number of features per time step.
        hidden_size (int): Hidden size of each GRU direction (the GRU output has
            `hidden_size * 2` features because it is bi-directional).
        output_size (int): Number of output classes per time step.
        num_layers (int): Number of stacked GRU layers.
        dropout (float): Dropout probability used before the classifier and, when
            `num_layers > 1`, between GRU layers.

    Attributes:
        input_size, hidden_size, output_size, num_layers: the constructor values,
            kept so that other code (e.g. evaluation helpers that reshape logits
            with `model.output_size`) can read them back from the model.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int, dropout: float = 0.0):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers

        # nn.GRU only applies dropout *between* stacked layers; with a single layer a
        # non-zero value has no effect and merely triggers a warning, so pass 0 then.
        gru_dropout = dropout if num_layers > 1 else 0.0
        # Bi-Directional GRU Layer
        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True,
                          bidirectional=True, dropout=gru_dropout)
        # Attention layer (hidden_size * 2 because of the two directions)
        self.attention_fc = nn.Linear(hidden_size * 2, hidden_size * 2)
        # Fully connected layer for classification
        self.fc = nn.Linear(hidden_size * 2, output_size)
        self.dropout = nn.Dropout(dropout)  # Applied before the fully connected layer
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform initialisation for every weight matrix, zeros for every bias."""
        for name, param in self.named_parameters():
            if 'weight' in name:  # GRU (weight_ih / weight_hh), attention and fc weights
                nn.init.xavier_uniform_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input of shape (batch_size, seq_length, input_size).

        Returns:
            Logits of shape (batch_size, seq_length, output_size).
        """
        # Zero initial hidden state, created on the input's device AND dtype so the
        # model also works after `.double()` / `.half()`. Bi-directional: num_layers * 2.
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size,
                         device=x.device, dtype=x.dtype)
        # Bi-Directional GRU forward pass
        out, _ = self.gru(x, h0)  # Shape: (batch_size, seq_length, hidden_size * 2)

        # Attention mechanism: a tanh gate per feature and time step (values in [-1, 1]),
        # multiplied element-wise with the GRU output; it is not normalised over time.
        attn_weights = torch.tanh(self.attention_fc(out))  # (batch_size, seq_length, hidden_size * 2)
        out = attn_weights * out
        out = self.dropout(out)

        # Fully connected layer (applied at each time step)
        out = self.fc(out)  # Shape: (batch_size, seq_length, output_size)
        return out


#### EWC Class Definition

In [10]:
##############################################
# EWC Class Definition
##############################################

class EWC:
    """
    Elastic Weight Consolidation (EWC) for preventing catastrophic forgetting.

    The class stores, for a previously learned task, a diagonal Fisher Information
    estimate and a copy of the model parameters (both kept on CPU), and provides a
    penalty term `sum_i F_i * (theta_i - theta*_i)^2` to add to the loss of a new task.
    `compute_fisher_and_params` produces the two dictionaries from a trained model
    and a dataloader.
    """

    def __init__(self, fisher: dict, params: dict):
        """
        Initializes EWC with pre-computed Fisher matrix and optimal parameters.

        Args:
            fisher (dict): Parameter name -> Fisher information tensor.
            params (dict): Parameter name -> optimal parameter tensor from the previous task.
        """
        # Detach so the penalty never back-propagates into the stored tensors; keep on CPU.
        self.fisher = {k: v.detach().cpu() for k, v in fisher.items()}
        self.params = {k: v.detach().cpu() for k, v in params.items()}

    @staticmethod
    def compute_fisher_and_params(model: nn.Module, dataloader: DataLoader, criterion: nn.Module,
                                  device: torch.device, sample_size: int = None):
        """
        Computes the diagonal Fisher Information Matrix (FIM) and copies optimal parameters.

        The model is put in train mode (needed for RNN backward) while dropout is
        switched off: `nn.Dropout` modules are set to eval, and the inter-layer dropout
        probability of recurrent modules (`nn.RNNBase`, e.g. `nn.GRU`) is set to 0.
        All per-module train/eval flags and dropout probabilities are restored when the
        function exits, including when an exception propagates out of the batch loop.

        For every batch the squared gradient of the scalar returned by `criterion`
        is accumulated, weighted by the batch size, and finally divided by the number
        of processed samples (first dimension of the inputs).

        Args:
            model (nn.Module): The trained model (best model for the task).
            dataloader (DataLoader): Iterable yielding `(x, y)` batches for the task.
            criterion (nn.Module): The loss function used for training.
            device (torch.device): The device ('cuda' or 'cpu') for computation.
            sample_size (int, optional): Stop after at least this many samples were processed.

        Returns:
            tuple(dict, dict): (fisher_dict_cpu, params_dict_cpu) on CPU. Both are empty
            if parameters could not be copied; the Fisher dict is empty if no parameter
            requires gradients.
        """
        print("Starting Fisher Information Matrix calculation (v3: train mode for RNN backward)...")

        # --- Store optimal parameters ---
        try:
            params_dict_cpu = {n: p.clone().cpu().detach()
                               for n, p in model.named_parameters() if p.requires_grad}
        except Exception as e:
            print(f"Error cloning/detaching parameters: {e}")
            return {}, {}

        # --- Initialize Fisher dictionary on the computation device ---
        fisher_dict = {n: torch.zeros_like(p, device=device)
                       for n, p in model.named_parameters() if p.requires_grad and n in params_dict_cpu}

        if not fisher_dict:
            print("Warning: No parameters requiring gradients found.")
            return {}, params_dict_cpu

        # --- Snapshot everything that is modified temporarily ---
        original_mode_was_training = model.training
        module_modes = {name: module.training for name, module in model.named_modules()}
        rnn_dropout_probs = {}
        dropout_disabled = False

        # --- Loop variables ---
        num_samples_processed = 0
        skipped_gradients = 0
        calculation_successful = True

        try:
            # Train mode is required for RNN backward.
            model.train()
            print(f"Model temporarily set to TRAIN mode for Fisher calculation.")

            # --- Temporarily disable dropout ---
            for name, module in model.named_modules():
                if isinstance(module, nn.Dropout):
                    module.eval()  # Put dropout layer itself in eval mode
                    dropout_disabled = True
                elif isinstance(module, nn.RNNBase) and module.dropout > 0:
                    # Inter-layer dropout of GRU/LSTM/RNN is active in train mode.
                    rnn_dropout_probs[name] = module.dropout
                    module.dropout = 0.0
                    dropout_disabled = True
            if dropout_disabled:
                print("Dropout layers temporarily disabled during forward pass.")

            print(f"Processing data batches on device: {device}")
            for batch_idx, (x, y) in enumerate(dataloader):
                x, y = x.to(device), y.to(device)
                current_batch_size = x.size(0)

                model.zero_grad()

                # --- Forward pass ---
                try:
                    outputs = model(x)
                    output_size = outputs.size(-1)
                    outputs_flat = outputs.reshape(-1, output_size)
                    y_flat = y.reshape(-1)
                except Exception as e:
                    print(f"----> ERROR during forward pass or reshape in batch {batch_idx}: {e}. Skipping.")
                    calculation_successful = False
                    continue  # Skip to next batch

                # --- Validate labels: keep only targets inside [0, output_size) ---
                valid_indices = (y_flat >= 0) & (y_flat < output_size)
                if not valid_indices.all():
                    outputs_flat = outputs_flat[valid_indices]
                    y_flat = y_flat[valid_indices]
                    if y_flat.numel() == 0:
                        continue  # Nothing left in this batch

                # --- Calculate loss ---
                try:
                    loss = criterion(outputs_flat, y_flat)
                    if torch.isnan(loss) or torch.isinf(loss):
                        print(f"----> WARNING: Loss is {loss} in batch {batch_idx}. Skipping backward pass.")
                        continue  # Skip gradient calculation
                except Exception as e:
                    print(f"----> ERROR calculating loss in batch {batch_idx}: {e}. Skipping.")
                    calculation_successful = False
                    continue

                # --- Calculate gradients (model must be in train mode here for RNN) ---
                try:
                    loss.backward()
                except Exception as e:
                    if "RNN backward can only be called in training mode" in str(e):
                        print(f"----> CRITICAL INTERNAL ERROR: RNN backward error despite model being in train mode? {e}")
                    else:
                        print(f"----> ERROR during loss.backward() in batch {batch_idx}: {e}. Skipping accumulation.")
                    calculation_successful = False
                    continue  # Skip accumulation for this batch

                # --- Accumulate squared gradients ---
                for n, p in model.named_parameters():
                    if not p.requires_grad or p.grad is None or n not in fisher_dict:
                        continue
                    squared_grad = p.grad.detach().to(device).pow(2)
                    # A NaN/Inf gradient yields a NaN/Inf square, so one check covers both.
                    if not torch.isfinite(squared_grad).all():
                        skipped_gradients += 1
                        continue
                    fisher_dict[n] += squared_grad * current_batch_size

                # --- Update counters ---
                num_samples_processed += current_batch_size

                # --- Check sample size limit ---
                if sample_size is not None and num_samples_processed >= sample_size:
                    print(f"Stopping Fisher computation early after {num_samples_processed} samples.")
                    break
        finally:
            # --- Restore dropout probabilities and every module's own train/eval flag ---
            if dropout_disabled:
                print("Restoring original Dropout layer states...")
            for name, module in model.named_modules():
                if name in rnn_dropout_probs:
                    module.dropout = rnn_dropout_probs[name]
                if name in module_modes:
                    module.training = module_modes[name]
            # Do not leave the gradients of the last processed batch on the model.
            model.zero_grad()
            print(f"Model mode restored to original state (training={original_mode_was_training}).")

        # --- Post-loop checks and Normalization ---
        if num_samples_processed == 0:
            print("Warning: No samples processed successfully. Fisher matrix will be zeros.")
            fisher_dict_cpu = {n: f.cpu().detach() for n, f in fisher_dict.items()}  # Still return zeros
            return fisher_dict_cpu, params_dict_cpu

        if not calculation_successful:
            # The (possibly inaccurate) result is still returned; the caller is warned.
            print("ERROR: Issues encountered during calculation. Fisher matrix may be inaccurate or invalid.")
        if skipped_gradients:
            print(f"----> WARNING: Skipped {skipped_gradients} non-finite (NaN/Inf) gradient tensor(s) during accumulation.")

        print(f"Finished accumulating gradients over {num_samples_processed} samples.")
        print("Normalizing Fisher matrix...")
        fisher_dict_cpu = {}
        for n, f_tensor in fisher_dict.items():
            normalized_f = f_tensor / num_samples_processed
            # Final check after normalization
            if normalized_f.min() < 0:
                print(f"----> CRITICAL WARNING: Negative value found AFTER NORM for '{n}'.")
            if torch.isnan(normalized_f).any() or torch.isinf(normalized_f).any():
                print(f"----> WARNING: NaN/Inf found AFTER NORM for '{n}'.")
            fisher_dict_cpu[n] = normalized_f.cpu().detach()  # Store on CPU

        print("Fisher Information Matrix calculation complete.")
        return fisher_dict_cpu, params_dict_cpu

    def penalty(self, model: nn.Module) -> torch.Tensor:
        """
        Calculates the EWC penalty term based on stored Fisher and params.

        Only parameters that require gradients, are present in both stored
        dictionaries and have an unchanged shape contribute. Entries whose Fisher
        tensor contains negative, NaN or Inf values, and terms that evaluate to
        NaN/Inf, are skipped.

        Args:
            model (nn.Module): The model currently being trained.

        Returns:
            torch.Tensor: Scalar penalty on the device of the model's first parameter
            (CPU if the model has no parameters); 0.0 when nothing is stored.
        """
        first_param = next(model.parameters(), None)
        model_device = first_param.device if first_param is not None else torch.device("cpu")
        penalty_loss = torch.tensor(0.0, device=model_device)

        if not getattr(self, 'fisher', None) or not getattr(self, 'params', None):
            return penalty_loss

        for n, p in model.named_parameters():
            # Parameter must exist in the EWC state and require grad in the current model
            if not p.requires_grad or n not in self.fisher or n not in self.params:
                continue
            stored_params_n = self.params[n]
            if p.shape != stored_params_n.shape:
                continue

            try:
                # Move stored tensors to the model's device
                fisher_n_device = self.fisher[n].to(model_device)
                params_n_device = stored_params_n.to(model_device)

                # --- Check Fisher validity before using ---
                if fisher_n_device.min() < 0:
                    continue
                if not torch.isfinite(fisher_n_device).all():
                    continue

                term_penalty = (fisher_n_device * (p - params_n_device).pow(2)).sum()

                # --- Only add well-defined terms ---
                if torch.isfinite(term_penalty):
                    penalty_loss = penalty_loss + term_penalty

            except Exception as e:
                print(f"Error processing penalty for layer {n}: {e}")
                continue
        return penalty_loss
    

## __Training and validation function__

### Analytical Function

In [None]:
def print_model_info(model):
    """
    Print the total number of parameters of `model` and the memory they would
    occupy in MB, assuming every parameter is stored as float32.
    """
    bytes_per_param = 4  # Assumes float32: each parameter takes 4 bytes

    total_params = 0
    for param in model.parameters():
        total_params += param.numel()

    param_size_MB = (total_params * bytes_per_param) / (1024**2)

    print(f"Total Parameters: {total_params}")
    print(f"Model Size (float32): {param_size_MB:.2f} MB")

In [None]:
def compute_fwt_fixed_verbose(previous_model, init_model, X_val, y_val, known_classes, batch_size=64):
    """
    Forward-transfer (FWT) computation for token-level labeling with sequence inputs.

    Both models are evaluated on the validation sequences that contain at least one
    token labelled with a class in `known_classes`; accuracy is counted only over
    those tokens. FWT is `acc(previous_model) - acc(init_model)`.

    Note: both models are moved to the selected device (CUDA if available, else CPU)
    and left in eval mode.

    Args:
        previous_model: Model trained on earlier tasks; maps [B, L, F] inputs to
            [B, L, num_classes] logits.
        init_model: Reference model (same input/output convention).
        X_val (torch.Tensor): Inputs of shape [B, L, F].
        y_val (torch.Tensor): Integer labels of shape [B, L].
        known_classes: Iterable of class ids to evaluate on.
        batch_size (int): Batch size used for evaluation.

    Returns:
        tuple: (fwt_value, acc_prev, acc_init), or (None, None, None) when no
        validation token belongs to `known_classes`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    previous_model.to(device).eval()
    init_model.to(device).eval()

    _, seq_len, _ = X_val.shape  # inputs must be [B, L, F]
    y_val_flat = y_val.reshape(-1)  # [B * L]

    # Tensor of known class ids, built once (dtype follows the labels).
    known = torch.as_tensor(known_classes, dtype=y_val.dtype, device=y_val.device)

    # Find the token positions whose label is a known class ...
    token_mask = torch.isin(y_val_flat, known)
    token_indices = token_mask.nonzero(as_tuple=False).reshape(-1)
    # ... and map token-level indices back to (unique) sequence indices.
    seq_indices = torch.unique(token_indices // seq_len)

    X_known = X_val[seq_indices]
    y_known = y_val[seq_indices]

    if len(y_known) == 0:
        print(f"⚠️ No validation samples for known classes {known_classes}.")
        return None, None, None

    print(f"📋 Total matching sequences for known classes {known_classes}: {len(y_known)}")

    loader = DataLoader(TensorDataset(X_known, y_known), batch_size=batch_size)
    known_on_device = known.to(device)

    correct_prev, correct_init, total = 0, 0, 0

    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)

            # [batch, seq, class] -> per-token predictions [batch * seq].
            # The class dimension is read from the logits themselves, so the models
            # do not need to expose an `output_size` attribute.
            preds_prev = previous_model(xb).argmax(dim=-1).reshape(-1)
            preds_init = init_model(xb).argmax(dim=-1).reshape(-1)
            yb_flat = yb.reshape(-1)

            # Only tokens belonging to known_classes are counted
            batch_mask = torch.isin(yb_flat, known_on_device)

            correct_prev += (preds_prev[batch_mask] == yb_flat[batch_mask]).sum().item()
            correct_init += (preds_init[batch_mask] == yb_flat[batch_mask]).sum().item()
            total += batch_mask.sum().item()

    # total > 0 here: every selected sequence contains at least one known-class token.
    acc_prev = correct_prev / total
    acc_init = correct_init / total
    fwt_value = acc_prev - acc_init

    print(f"\n### 🔍 FWT Debug Info:")
    print(f"- Total evaluated tokens: {total}")
    print(f"- Correct (PrevModel): {correct_prev} / {total} → Acc = {acc_prev:.4f}")
    print(f"- Correct (InitModel): {correct_init} / {total} → Acc = {acc_init:.4f}")
    print(f"- FWT = Acc_prev - Acc_init = {fwt_value:.4f}")

    return fwt_value, acc_prev, acc_init


In [11]:
def compute_classwise_accuracy(student_logits_flat, y_batch, class_correct, class_total):
    """
    Accumulate per-class counts of total and correctly predicted samples.

    The two dictionaries are updated in place, so the function can be called once
    per batch and the counts add up across calls. Nothing is returned.

    Args:
        student_logits_flat (torch.Tensor): Logits of shape [batch_size * seq_len, output_size].
        y_batch (torch.Tensor): True labels of shape [batch_size * seq_len].
        class_correct (dict): label -> number of correct predictions so far.
        class_total (dict): label -> number of samples seen so far.

    Raises:
        ValueError: If the logits and the labels live on different devices.
    """
    if student_logits_flat.device != y_batch.device:
        raise ValueError("student_logits_flat and y_batch must be on the same device")

    # Boolean vector: True where the arg-max class equals the true label.
    hits = torch.argmax(student_logits_flat, dim=-1) == y_batch

    # Visit each label present in this batch once (as a plain Python scalar).
    for cls in torch.unique(y_batch).tolist():
        if cls not in class_total:
            # First time this label is seen: start both counters at zero.
            class_total[cls] = class_correct[cls] = 0

        is_cls = y_batch == cls
        class_total[cls] += is_cls.sum().item()
        class_correct[cls] += (is_cls & hits).sum().item()

### Training and validation functions

In [12]:
def inspect_fisher_info(fisher_dict: dict, label: str = "Fisher Inspection"):
    """
    Print summary statistics of a Fisher-information dictionary.

    For every tensor an error line is printed if it contains negative values or
    NaN/Inf. Afterwards the overall sum, the largest absolute value, the smallest
    strictly positive value and the number of all-zero tensors are reported.

    Args:
        fisher_dict (dict): Parameter name -> Fisher tensor.
        label (str): Text used in the opening and closing banner lines.
    """
    print(f"\n--- Inspecting {label} ---")

    # Per-layer statistics, gathered in dictionary order.
    layer_sums, layer_peaks, layer_floors = [], [], []
    for name, values in fisher_dict.items():
        if values.min() < 0:
            print(f"----> ERROR: Negative value found in Fisher for {name}")
        if torch.isnan(values).any() or torch.isinf(values).any():
            print(f"----> ERROR: NaN/Inf found in Fisher for {name}")

        is_positive = values > 0
        layer_sums.append(values.sum().item())
        layer_peaks.append(values.abs().max().item())
        # Smallest strictly positive entry, or +inf when the layer has none.
        layer_floors.append(values[is_positive].min().item() if is_positive.any() else float('inf'))

    # Fold the per-layer values with the same starting points as a running update.
    total_fisher_sum = sum(layer_sums, 0.0)
    max_fisher_val = max([0.0] + layer_peaks)
    min_positive_fisher = min([float('inf')] + layer_floors)
    zero_fisher_layers = sum(1 for peak in layer_peaks if peak == 0)

    print(f"Total Fisher Sum: {total_fisher_sum:.4e}")
    print(f"Overall Max Fisher Abs Value: {max_fisher_val:.4e}")
    print(f"Overall Min Positive Fisher Value: {min_positive_fisher:.4e}")
    print(f"Number of layers with all-zero Fisher: {zero_fisher_layers} / {len(fisher_dict)}")
    print(f"--- Finished Inspecting {label} ---\n")


# Training and validation function for Period 1.
def train_and_validate(model, output_size, criterion, optimizer, 
                       X_train, y_train, X_val, y_val, scheduler, 
                       use_scheduler=False, num_epochs=10, batch_size=64, 
                       model_saving_folder=None, model_name=None, stop_signal_file=None):
    """
    Base training loop (Period 1, no EWC penalty) with validation after every epoch.

    Every epoch trains on shuffled mini-batches, evaluates on the validation set and
    keeps on disk only the checkpoints of the (at most) five epochs with the highest
    validation accuracy. After the loop, "<model_name>_final.pth" is written with the
    state reached at the last completed epoch.

    Parameters:
    - model: Network called as model(X_batch); its output is flattened with
      .view(-1, output_size) before being passed to the loss.
    - output_size (int): Size of the last dimension of the model output.
    - criterion: Loss callable taking (flattened outputs, flattened targets).
    - optimizer: Optimizer used for the updates; the 'lr' of its first param group is logged.
    - X_train, y_train, X_val, y_val: Data accepted by torch.tensor. Inputs are converted
      to float32, targets to long; both are moved to the selected device.
    - scheduler: Stepped with the validation loss after every epoch when use_scheduler is truthy.
    - use_scheduler (bool): Enables the scheduler step.
    - num_epochs (int): Maximum number of epochs.
    - batch_size (int): Batch size of both DataLoaders.
    - model_saving_folder (str): Checkpoint folder. If it is given and already exists it is
      DELETED and recreated. When omitted, './saved_models' is used (created if missing,
      never wiped).
    - model_name (str): Prefix of the checkpoint file names (default 'model').
    - stop_signal_file (str): If this path exists at the start of an epoch, training stops.

    Returns:
    - None. The per-epoch records of the kept checkpoints are published through the
      module-level list `best_results`, sorted by validation accuracy (descending).
    """
    print("'train_and_validate' function started. \n")

    # Resolve the output folder first, then create it, so the default folder exists too.
    if model_saving_folder:
        if os.path.exists(model_saving_folder):
            shutil.rmtree(model_saving_folder)  # Removes the folder with all its contents
            if not os.path.exists(model_saving_folder):
                print(f"Existing folder has been removed : {model_saving_folder}\n")
    else:
        model_saving_folder = './saved_models'
    os.makedirs(model_saving_folder, exist_ok=True)
    if not model_name:
        model_name = 'model'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Convert data to tensors (torch.tensor copies, the caller's arrays are untouched)
    X_train = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train = torch.tensor(y_train, dtype=torch.long).to(device)
    X_val = torch.tensor(X_val, dtype=torch.float32).to(device)
    y_val = torch.tensor(y_val, dtype=torch.long).to(device)

    train_dataset = TensorDataset(X_train, y_train)
    val_dataset = TensorDataset(X_val, y_val)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # ---- Diagnostic prints: tensor types/shapes, dataset sizes, label stats, devices ----
    for header, tensor in (("y_train:", y_train), ("X_train:", X_train),
                           ("\ny_val:", y_val), ("X_val:", X_val)):
        print(header)
        print(type(tensor))
        print(tensor.dtype)
        print(tensor.shape)

    print("\nDataset Lengths:")
    print(f"Train Dataset Length: {len(train_dataset)}")
    print(f"Validation Dataset Length: {len(val_dataset)}")

    print("\nDataLoader Batch Sizes:")
    print(f"Number of Batches in Train DataLoader: {len(train_loader)}")
    print(f"Number of Batches in Validation DataLoader: {len(val_loader)}")

    print("\ny_train Unique Values and Stats:")
    print(f"Unique values in y_train: {y_train.unique()}")
    print(f"y_train Min: {y_train.min()}, Max: {y_train.max()}")

    print("\ny_val Unique Values and Stats:")
    print(f"Unique values in y_val: {y_val.unique()}")
    print(f"y_val Min: {y_val.min()}, Max: {y_val.max()}")

    print("\nDevice Info:")
    print(f"X_train Device: {X_train.device}")
    print(f"y_train Device: {y_train.device}")
    print(f"X_val Device: {X_val.device}")
    print(f"y_val Device: {y_val.device}\n")

    def format_classwise(correct, total):
        # {class: "xx.xx%"} for every class seen, ordered by class id
        return {int(c): f"{(correct[c] / total[c]) * 100:.2f}%" if total[c] > 0 else "0.00%"
                for c in sorted(total.keys())}

    def build_checkpoint(info):
        # Everything stored in a .pth file for the epoch described by `info`
        return {
            'epoch': info["epoch"],
            'train_loss': info["train_loss"],
            'val_loss': info["val_loss"],
            'model_state_dict': model.state_dict(),  # Model weights
            'optimizer_state_dict': optimizer.state_dict(),  # Optimizer state
            'learning_rate': info["learning_rate"]
        }

    global best_results  # Published so that the calling cell can read the results
    best_results = []    # Start empty each training run
    current_epoch_info = None  # Stays None if no epoch completes
    model.train()

    for epoch in range(num_epochs):
        if stop_signal_file and os.path.exists(stop_signal_file):
            print("\nStop signal detected. Exiting training loop safely.\n")
            break
        epoch_loss = 0
        class_correct = {}  # Correct predictions per class
        class_total = {}    # Samples per class
        model.train()
        i = 0
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()  # Reset gradients before the forward pass

            outputs = model(X_batch).view(-1, output_size)
            y_batch = y_batch.view(-1)

            # Debug prints: first 3 batches of the second epoch only
            if epoch == 1 and i < 3:
                i += 1
                print(f"\nUnique target values: {y_batch.unique()}")
                print(f"Target dtype: {y_batch.dtype}")
                print(f"Min target: {y_batch.min()}, Max target: {y_batch.max()}")
                print("Unique classes in y_train:", y_train.unique())
                print(f"Unique classes in y_val: {y_val.unique()}\n")

            loss = criterion(outputs, y_batch)

            # Accumulates the per-class counters in the two dicts
            compute_classwise_accuracy(outputs, y_batch, class_correct, class_total)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item() * X_batch.size(0)  # Scale back to total loss

        train_loss = epoch_loss / len(train_loader.dataset)  # Average loss per sequence
        train_classwise_accuracy = format_classwise(class_correct, class_total)

        # Validation at the end of each epoch
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_class_correct = {}
        val_class_total = {}
        with torch.no_grad():
            for X_val_batch, y_val_batch in val_loader:
                val_outputs = model(X_val_batch).view(-1, output_size)
                val_labels = y_val_batch.view(-1)
                val_loss += criterion(val_outputs, val_labels).item() * X_val_batch.size(0)
                val_predictions = torch.argmax(val_outputs, dim=-1)
                val_correct += (val_predictions == val_labels).sum().item()
                val_total += val_labels.size(0)
                compute_classwise_accuracy(val_outputs, val_labels, val_class_correct, val_class_total)
        val_loss /= len(val_loader.dataset)  # Average loss per sequence
        val_accuracy = val_correct / val_total
        val_classwise_accuracy = format_classwise(val_class_correct, val_class_total)

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Train Loss: {train_loss:.9f}, "
              f"Train-Class-Acc: {train_classwise_accuracy}, "
              f"Val Loss: {val_loss:.9f}, "
              f"Val Accuracy: {val_accuracy * 100:.2f}%, "
              f"Val-Class-Acc: {val_classwise_accuracy}, "
              f"LR: {current_lr:.9f}")

        current_epoch_info = {
            "epoch": epoch+1,
            "train_loss": train_loss,
            "train_classwise_accuracy": train_classwise_accuracy,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
            "val_classwise_accuracy": val_classwise_accuracy,
            'learning_rate': current_lr,
            "model_path": os.path.join(model_saving_folder, f"{model_name}_epoch_{epoch+1}.pth")
        }

        # Keep this epoch if fewer than 5 are stored or it beats the worst stored one
        if len(best_results) < 5 or val_accuracy > best_results[-1]["val_accuracy"]:
            if len(best_results) == 5:
                worst = best_results.pop()  # Last element = lowest accuracy
                if os.path.exists(worst["model_path"]):
                    os.remove(worst["model_path"])
                    print(f"Removed old model with accuracy {worst['val_accuracy']*100:.2f}%, and file was at {worst['model_path']}")
            best_results.append(current_epoch_info)
            best_results.sort(key=lambda x: x["val_accuracy"], reverse=True)
            torch.save(build_checkpoint(current_epoch_info), current_epoch_info["model_path"])
            print(f"Model saved after epoch {epoch+1} to {current_epoch_info['model_path']} \n")

        if use_scheduler:
            # Step the scheduler last, once the epoch's validation loss is known
            scheduler.step(val_loss)

    # Save the final model (only when at least one epoch completed)
    if current_epoch_info is not None:
        final_model_path = os.path.join(model_saving_folder, f"{model_name}_final.pth")
        # Uses the last COMPLETED epoch's record, which stays correct after an early stop
        torch.save(build_checkpoint(current_epoch_info), final_model_path)
        print(f"\nFinal model saved to {final_model_path}")
    else:
        print("\nNo epoch was completed; no final model saved.")

    print("\nTraining complete. \n\nTop 5 Models Sorted by Validation Accuracy: ")
    for res in best_results:
        # Each row reports its own class-wise accuracies, not those of the last epoch
        print(f"Epoch {res['epoch']}/{num_epochs}, "
              f"Train Loss: {res['train_loss']:.9f}, "
              f"Train-Class-Acc: {res['train_classwise_accuracy']}, "
              f"Val Loss: {res['val_loss']:.9f}, "
              f"Val Accuracy: {res['val_accuracy'] * 100:.2f}%, "
              f"Val-Class-Acc: {res['val_classwise_accuracy']}, "
              f"Model Path: {res['model_path']}")
    print('\n')

    # Drop the references to the device tensors before emptying the CUDA cache
    del X_train, y_train, X_val, y_val, train_dataset, val_dataset, train_loader, val_loader
    torch.cuda.empty_cache()

    # Load the checkpoint
    # checkpoint = torch.load("path/to/model_checkpoint.pth")
    # # Restore model state
    # model.load_state_dict(checkpoint['model_state_dict'])
    # # Restore optimizer state
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # # Restore scheduler state (if used)
    # scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # # Restore epoch and other metadata
    # start_epoch = checkpoint['epoch'] + 1  # Resume from the next epoch
    # loss = checkpoint['loss']  # Optional
    # print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")


# Training and validation function for Period 2 and beyond 
def train_and_validate_ewc(model, deepcopy, output_size, criterion, optimizer, 
                           X_train, y_train, X_val, y_val, scheduler, use_scheduler=False, 
                           ewc: 'EWC' = None, lambda_ewc: float = 0.4,
                           num_epochs=10, batch_size=64, 
                           model_saving_folder=None, model_name=None, stop_signal_file=None):
    """
    Training loop for incremental periods: criterion loss plus an optional EWC penalty.

    Works like the base loop (per-epoch validation, the five checkpoints with the highest
    validation accuracy are kept, a "<model_name>_final.pth" is written). After training,
    every kept checkpoint and the final one are reloaded, the Fisher values / parameters
    are computed with EWC.compute_fisher_and_params on the training DataLoader, and they
    are stored in the checkpoint file under the keys 'ewc_fisher' and 'ewc_params'.

    Parameters:
    - model: Network being trained; its output is flattened with .view(-1, output_size).
    - deepcopy: Model instance that is copied with copy.deepcopy for every checkpoint; the
      copy receives the checkpoint weights and is used for the Fisher computation.
    - output_size (int): Size of the last dimension of the model output.
    - criterion: Loss callable taking (flattened outputs, flattened targets).
    - optimizer: Optimizer used for the updates; the 'lr' of its first param group is logged.
    - X_train, y_train, X_val, y_val: Data accepted by torch.tensor. Inputs are converted
      to float32, targets to long; both are moved to the selected device.
    - scheduler: Stepped with the validation loss after every epoch when use_scheduler is truthy.
    - use_scheduler (bool): Enables the scheduler step.
    - ewc: Object exposing penalty(model); when None no penalty is added.
    - lambda_ewc (float): The penalty is added as (lambda_ewc / 2) * ewc.penalty(model);
      values <= 0 disable it.
    - num_epochs (int): Maximum number of epochs.
    - batch_size (int): Batch size of both DataLoaders.
    - model_saving_folder (str): Checkpoint folder. If it is given and already exists it is
      DELETED and recreated. When omitted, './saved_models' is used (created if missing,
      never wiped).
    - model_name (str): Prefix of the checkpoint file names (default 'model').
    - stop_signal_file (str): If this path exists at the start of an epoch, training stops.

    Returns:
    - None. The per-epoch records of the kept checkpoints are published through the
      module-level list `best_results`, sorted by validation accuracy (descending).
      The reported training loss includes the EWC term; the validation loss does not.
    """
    print("'train_and_validate_ewc' function started.\n")

    # Resolve the output folder first, then create it, so the default folder exists too.
    if model_saving_folder:
        if os.path.exists(model_saving_folder):
            shutil.rmtree(model_saving_folder)  # Removes the folder with all its contents
            if not os.path.exists(model_saving_folder):
                print(f"Existing folder has been removed : {model_saving_folder}\n")
    else:
        model_saving_folder = './saved_models'
    os.makedirs(model_saving_folder, exist_ok=True)
    if not model_name:
        model_name = 'model'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Convert data to tensors (torch.tensor copies, the caller's arrays are untouched)
    X_train = torch.tensor(X_train, dtype=torch.float32).to(device)
    y_train = torch.tensor(y_train, dtype=torch.long).to(device)
    X_val = torch.tensor(X_val, dtype=torch.float32).to(device)
    y_val = torch.tensor(y_val, dtype=torch.long).to(device)

    train_dataset = TensorDataset(X_train, y_train)
    val_dataset = TensorDataset(X_val, y_val)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # ---- Diagnostic prints: tensor types/shapes, dataset sizes, label stats, devices ----
    for header, tensor in (("y_train:", y_train), ("X_train:", X_train),
                           ("\ny_val:", y_val), ("X_val:", X_val)):
        print(header)
        print(type(tensor))
        print(tensor.dtype)
        print(tensor.shape)

    print("\nDataset Lengths:")
    print(f"Train Dataset Length: {len(train_dataset)}")
    print(f"Validation Dataset Length: {len(val_dataset)}")

    print("\nDataLoader Batch Sizes:")
    print(f"Number of Batches in Train DataLoader: {len(train_loader)}")
    print(f"Number of Batches in Validation DataLoader: {len(val_loader)}")

    print("\ny_train Unique Values and Stats:")
    print(f"Unique values in y_train: {y_train.unique()}")
    print(f"y_train Min: {y_train.min()}, Max: {y_train.max()}")

    print("\ny_val Unique Values and Stats:")
    print(f"Unique values in y_val: {y_val.unique()}")
    print(f"y_val Min: {y_val.min()}, Max: {y_val.max()}")

    print("\nDevice Info:")
    print(f"X_train Device: {X_train.device}")
    print(f"y_train Device: {y_train.device}")
    print(f"X_val Device: {X_val.device}")
    print(f"y_val Device: {y_val.device}\n")

    def format_classwise(correct, total):
        # {class: "xx.xx%"} for every class seen, ordered by class id
        return {int(c): f"{(correct[c] / total[c]) * 100:.2f}%" if total[c] > 0 else "0.00%"
                for c in sorted(total.keys())}

    def build_checkpoint(info):
        # Everything stored in a .pth file for the epoch described by `info`
        return {
            'epoch': info["epoch"],
            'train_loss': info["train_loss"],
            'val_loss': info["val_loss"],
            'model_state_dict': model.state_dict(),  # Model weights
            'optimizer_state_dict': optimizer.state_dict(),  # Optimizer state
            'learning_rate': info["learning_rate"]
        }

    global best_results  # Published so that the calling cell can read the results
    best_results = []    # Start empty each training run
    current_epoch_info = None  # Stays None if no epoch completes

    for epoch in range(num_epochs):
        if stop_signal_file and os.path.exists(stop_signal_file):
            print("\nStop signal detected. Exiting training loop safely.\n")
            break
        epoch_train_loss = 0.0
        train_class_correct = {}  # Correct predictions per class
        train_class_total = {}    # Samples per class
        model.train()
        i = 0
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()  # Reset gradients before the forward pass

            outputs = model(X_batch).view(-1, output_size)
            y_batch = y_batch.view(-1)

            # Debug prints: first 3 batches of the second epoch only
            if epoch == 1 and i < 3:
                i += 1
                print(f"\nUnique target values: {y_batch.unique()}")
                print(f"Target dtype: {y_batch.dtype}")
                print(f"Min target: {y_batch.min()}, Max target: {y_batch.max()}")
                print("Unique classes in y_train:", y_train.unique())
                print(f"Unique classes in y_val: {y_val.unique()}\n")

            loss = criterion(outputs, y_batch)
            # EWC regularization term: (lambda / 2) * penalty
            if ewc is not None and lambda_ewc > 0:
                ewc_penalty = ewc.penalty(model)
                loss = loss + (lambda_ewc / 2) * ewc_penalty  # Out-of-place on purpose
            loss.backward()
            # Optional: Gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_train_loss += loss.item() * X_batch.size(0)  # Scale back to total loss

            # Accumulates the per-class counters in the two dicts
            compute_classwise_accuracy(outputs, y_batch, train_class_correct, train_class_total)

        train_loss = epoch_train_loss / len(train_loader.dataset)  # Average loss per sequence
        train_classwise_accuracy = format_classwise(train_class_correct, train_class_total)

        # Validation at the end of each epoch (criterion loss and accuracy only, no EWC term)
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_class_correct = {}
        val_class_total = {}
        with torch.no_grad():
            for X_val_batch, y_val_batch in val_loader:
                val_outputs = model(X_val_batch).view(-1, output_size)
                val_labels = y_val_batch.view(-1)
                val_loss += criterion(val_outputs, val_labels).item() * X_val_batch.size(0)
                val_predictions = torch.argmax(val_outputs, dim=-1)
                val_correct += (val_predictions == val_labels).sum().item()
                val_total += val_labels.size(0)
                compute_classwise_accuracy(val_outputs, val_labels, val_class_correct, val_class_total)
        val_loss /= len(val_loader.dataset)  # Average loss per sequence
        val_accuracy = val_correct / val_total
        val_classwise_accuracy = format_classwise(val_class_correct, val_class_total)

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Train Loss: {train_loss:.9f}, "
              f"Train-Class-Acc: {train_classwise_accuracy}, "
              f"Val Loss: {val_loss:.9f}, "
              f"Val Accuracy: {val_accuracy * 100:.2f}%, "
              f"Val-Class-Acc: {val_classwise_accuracy}, "
              f"LR: {current_lr:.9f}")

        current_epoch_info = {
            "epoch": epoch+1,
            "train_loss": train_loss,
            "train_classwise_accuracy": train_classwise_accuracy,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
            "val_classwise_accuracy": val_classwise_accuracy,
            'learning_rate': current_lr,
            "model_path": os.path.join(model_saving_folder, f"{model_name}_epoch_{epoch+1}.pth")
        }

        # Keep this epoch if fewer than 5 are stored or it beats the worst stored one
        if len(best_results) < 5 or val_accuracy > best_results[-1]["val_accuracy"]:
            if len(best_results) == 5:
                worst = best_results.pop()  # Last element = lowest accuracy
                if os.path.exists(worst["model_path"]):
                    os.remove(worst["model_path"])
                    print(f"Removed old model with accuracy {worst['val_accuracy']*100:.2f}%, and file was at {worst['model_path']}")
            best_results.append(current_epoch_info)
            best_results.sort(key=lambda x: x["val_accuracy"], reverse=True)
            torch.save(build_checkpoint(current_epoch_info), current_epoch_info["model_path"])
            print(f"Model saved after epoch {epoch+1} to {current_epoch_info['model_path']} \n")

        if use_scheduler:
            # Step the scheduler last, once the epoch's validation loss is known
            scheduler.step(val_loss)

    # Checkpoints that will receive the EWC state: the kept best ones (+ the final one)
    ewc_targets = list(best_results)
    if current_epoch_info is not None:
        final_model_path = os.path.join(model_saving_folder, f"{model_name}_final.pth")
        # Uses the last COMPLETED epoch's record, which stays correct after an early stop
        torch.save(build_checkpoint(current_epoch_info), final_model_path)
        print(f"\nFinal model saved to {final_model_path}")
        ewc_targets.append({
            "epoch": current_epoch_info["epoch"],
            "model_path": final_model_path
        })
    else:
        print("\nNo epoch was completed; no final model saved.")

    print(f"\nCalculating EWC state for 'best & final models', with sample_size = All...")
    sample_size_all = None  # None = use all samples
    for res in ewc_targets:
        step_model_path = os.path.normpath(res['model_path'])
        try:  # Reload the saved weights into a fresh, independent copy of the template model
            deep_copy_model = copy.deepcopy(deepcopy).to(device)
            step_checkpoint = torch.load(step_model_path, map_location=device, weights_only=True)
            deep_copy_model.load_state_dict(step_checkpoint['model_state_dict'])
            print(f"Loaded previous period checkpoint from: \n\t{step_model_path}")
        except Exception as e:
            print(f"ERROR reloading saved model: {e}")
            # Skip: Fisher values computed on weights that failed to load would be
            # meaningless and must not be written into the checkpoint.
            continue
        try:
            ewc_fisher_dict_all, ewc_params_dict_all = EWC.compute_fisher_and_params(
                model=deep_copy_model,
                dataloader=train_loader,
                criterion=criterion,
                device=device,
                sample_size=sample_size_all
            )
            print("EWC Fisher and Params computed (All Samples).")
            inspect_fisher_info(ewc_fisher_dict_all, label=f"Computed Fisher Values (All Samples) for epoch {res['epoch']}")
            step_checkpoint['ewc_fisher'] = ewc_fisher_dict_all
            step_checkpoint['ewc_params'] = ewc_params_dict_all
            # Write to a temporary file and swap it in, so that a failed save can
            # never destroy the existing checkpoint.
            temp_model_path = step_model_path + ".tmp"
            torch.save(step_checkpoint, temp_model_path)
            os.replace(temp_model_path, step_model_path)
            del ewc_fisher_dict_all, ewc_params_dict_all
        except Exception as e:
            print(f"ERROR calculating or saving EWC state (All Samples): {e}")
        # Release the per-checkpoint model copy and state before the next iteration
        del deep_copy_model, step_checkpoint
    del sample_size_all, deepcopy, ewc_targets

    print("\nTraining complete. \n\nTop 5 Models Sorted by Validation Accuracy: ")
    for res in best_results:
        # Each row reports its own class-wise accuracies, not those of the last epoch
        print(f"Epoch {res['epoch']}/{num_epochs}, "
              f"Train Loss: {res['train_loss']:.9f}, "
              f"Train-Class-Acc: {res['train_classwise_accuracy']}, "
              f"Val Loss: {res['val_loss']:.9f}, "
              f"Val Accuracy: {res['val_accuracy'] * 100:.2f}%, "
              f"Val-Class-Acc: {res['val_classwise_accuracy']}, "
              f"Model Path: {res['model_path']}")
    print('\n')

    # Drop the references to the device tensors before emptying the CUDA cache
    del X_train, y_train, X_val, y_val, train_dataset, val_dataset, train_loader, val_loader
    torch.cuda.empty_cache()


## __Setup before training__

### Define list_period_files_full_path

In [13]:
def setup_file_paths(pair='BTCUSD', base_dir='Data', days=190):
    """
    Build the data file paths of all periods for one trading pair.

    Args:
        pair (str): Trading pair used in the file names (e.g., 'BTCUSD').
        base_dir (str): Directory that contains the pair's data folder (default 'Data').
        days (int): Period length in days, as it appears in the file names (default 190).

    Returns:
        tuple: (base_folder_path, with_indicators_file_path, list_period_files_full_path),
        where the list holds five paths and its first element is with_indicators_file_path.

    Raises:
        FileNotFoundError: If the pair's data folder does not exist.
    """
    dataset_stem = f"Polygon_{pair}_4Y_1min"
    base_folder_path = os.path.normpath(os.path.join(base_dir, dataset_stem))

    # Fail early when the data folder is missing
    if not os.path.isdir(base_folder_path):
        raise FileNotFoundError(f"Directory '{base_folder_path}' does not exist.")

    def in_base_folder(csv_name):
        # Normalized path of a file located in the data folder
        return os.path.normpath(os.path.join(base_folder_path, csv_name))

    # Period 1 file (name starts with an underscore and carries no dates)
    with_indicators_file_path = in_base_folder(f"_{dataset_stem}_{days}_days_with_indicators.csv")

    # (start, end) dates of the later periods, as written in their file names
    later_period_bounds = (
        ("2020-11-11", "2021-05-20"),  # Period 2
        ("2021-05-20", "2021-11-26"),  # Period 3
        ("2021-11-26", "2022-06-04"),  # Period 4
        ("2022-06-04", "2022-12-11"),  # Period 5
    )

    list_period_files_full_path = [with_indicators_file_path]
    list_period_files_full_path.extend(
        in_base_folder(f"{dataset_stem}_{days}_days__{start}__{end}__with_indicators.csv")
        for start, end in later_period_bounds
    )

    return base_folder_path, with_indicators_file_path, list_period_files_full_path

def print_folder_contents(folder_path):
    """
    Print a header followed by one line per entry returned by os.listdir(folder_path).

    Parameters:
    - folder_path (str): Folder whose entries are listed.
    """
    print("\n📂 Folder Contents:")
    entry_names = os.listdir(folder_path)
    for entry_name in entry_names:
        print(f"Found file: {entry_name}")

if __name__ == "__main__":
    # Resolve the paths for the default pair / folder / period length
    (base_folder_path,
     with_indicators_file_path,
     list_period_files_full_path) = setup_file_paths()

    # Banner
    print("=" * 70, "File Path Configuration".center(70), "=" * 70, sep="\n")

    # Base folder and Period 1 file, labels padded to 25 characters
    print("Base Folder Path".ljust(25) + f": {base_folder_path}")
    print("Period 1 File Path".ljust(25) + f": {with_indicators_file_path}")
    print("-" * 70)

    # One line per period file, numbered from 1
    print("List of Period Files:")
    for i, path in enumerate(list_period_files_full_path, 1):
        print(f"Period {i}".ljust(25) + f": {path}")

    print("=" * 70)

    # List what is actually present in the data folder
    print_folder_contents(base_folder_path)

                       File Path Configuration                        
Base Folder Path         : Data\Polygon_BTCUSD_4Y_1min
Period 1 File Path       : Data\Polygon_BTCUSD_4Y_1min\_Polygon_BTCUSD_4Y_1min_190_days_with_indicators.csv
----------------------------------------------------------------------
List of Period Files:
Period 1                 : Data\Polygon_BTCUSD_4Y_1min\_Polygon_BTCUSD_4Y_1min_190_days_with_indicators.csv
Period 2                 : Data\Polygon_BTCUSD_4Y_1min\Polygon_BTCUSD_4Y_1min_190_days__2020-11-11__2021-05-20__with_indicators.csv
Period 3                 : Data\Polygon_BTCUSD_4Y_1min\Polygon_BTCUSD_4Y_1min_190_days__2021-05-20__2021-11-26__with_indicators.csv
Period 4                 : Data\Polygon_BTCUSD_4Y_1min\Polygon_BTCUSD_4Y_1min_190_days__2021-11-26__2022-06-04__with_indicators.csv
Period 5                 : Data\Polygon_BTCUSD_4Y_1min\Polygon_BTCUSD_4Y_1min_190_days__2022-06-04__2022-12-11__with_indicators.csv

📂 Folder Contents:
Found file: Polyg

### __All periods data__
'trend': Categorized trend values based on the detected phases:
- 0: No trend
- 1: Moderate negative trend
- 2: Very strong negative trend
- 3: Moderate positive trend
- 4: Very strong positive trend


## __Train the Model__

### Period 1 --> Training and saving in __*'1st_try'*__ (BiGRUWithAttention, num_layers = 4) ---> Val acc = 98.35 %
### Val-Class-Acc: {0: '98.63%', 1: '97.92%'}

In [14]:
"""
- 'trend': Categorized trend values based on the detected phases:
    - 0: No trend
    - 1: Moderate negative trend
    - 2: Very strong negative trend
    - 3: Moderate positive trend
    - 4: Very strong positive trend
"""
# Load the Period 1 splits while silencing the prints of the preprocessing pipeline.
# The devnull handle is opened in the same `with` statement so it is closed as soon
# as the call returns (passing a bare open(...) to redirect_stdout leaks the handle).
with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
    X_train, y_train, X_val, y_val, X_test, y_test, Number_features = process_and_return_splits(
        with_indicators_file_path = list_period_files_full_path[0], # Change 
        downsampled_data_minutes = downsampled_data_minutes,
        exclude_columns = exclude_columns,
        lower_threshold = lower_threshold,
        upper_threshold = upper_threshold,
        reverse_steps = reverse_steps,
        sequence_length = sequence_length,
        sliding_interval = sliding_interval,
        trends_to_keep = {0, 1}  # Default keeps all trends : {0, 1, 2, 3, 4}
        # trends_to_keep = {0, 1, 2, 3, 4}  # Default keeps all trends
    )

print(f"\nNumber_features = {Number_features}")

# The number of classes is derived from the labels present in the validation split
unique_classes = np.unique(y_val)
num_classes = len(unique_classes)
print(f"unique_classes = {unique_classes}")
print(f"num_classes = {num_classes}")



Number_features = 7
unique_classes = [0 1]
num_classes = 2


In [61]:
#----------------------------------------------------------------------
# Initialize the list to store results across runs
track_across_runs = []
#----------------------------------------------------------------------

# Model parameters
input_size = Number_features  # Number of features
hidden_size = 64  # Number of GRU units
output_size = num_classes # Must be dynamic, up to 5  # Number of trend classes (0, 15, 25, -15, -25)
num_layers = 4  # Number of GRU layers
num_epochs= 2000 # Number of epochs/ go through entire data
batch_size= 64 # How many sequences passed at once to the model
model_name = 'BiGRUWithAttention' # Name of the model to use for saving
best_results = [] # Overwritten by train_and_validate (it declares `global best_results`)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Define a global stop signal
stop_signal_file = os.path.normpath(os.path.join('Class_Incremental_CL', 'Classif_Bi_Dir_GRU_Model/stop_training.txt'))  # Create this file to stop training
model_saving_folder = os.path.normpath(os.path.join('Class_Incremental_CL', "Classif_Bi_Dir_GRU_Model/Trained_models/Baseline/Period_1/1st_try"))
ensure_folder(model_saving_folder)

# Instantiate the model
class_gru_model = BiGRUWithAttention(input_size, hidden_size, output_size, num_layers).to(device)

# Define the loss function, optimizer and scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(class_gru_model.parameters(), lr=0.0001) # lr=0.00005
# optimizer = optim.Adam(class_gru_model.parameters(), lr=0.001, weight_decay=1e-5)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=10)

# The scheduler is passed but disabled here (use_scheduler=False, 10th positional argument)
# train_and_validate(class_gru_model, output_size, criterion, optimizer, X_train_all, y_train_all, X_val_all, y_val_all, scheduler, 
train_and_validate(class_gru_model, output_size, criterion, optimizer, X_train, y_train, X_val, y_val, scheduler, 
                   False, num_epochs, batch_size, model_saving_folder, model_name, stop_signal_file)

#----------------------------------------------------------------------
# Append only the best result (already at index 0).
# best_results is empty when training stopped before completing any epoch.
if best_results:
    track_across_runs.append(best_results[0])
else:
    print("No epoch completed; nothing appended to track_across_runs.")
#----------------------------------------------------------------------

for res in best_results:        
    print(f"Epoch {res['epoch']}/{num_epochs}, "
            f"Train Loss: {res['train_loss']:.4f}, " 
            f"Val Loss: {res['val_loss']:.4f}, "
            f"Val Accuracy: {res['val_accuracy'] * 100:.2f}%, "
            f"Model Path: {res['model_path']}")      
print(f"\nclass_gru_model: \n{class_gru_model}")

print(f"\nunique_classes = {unique_classes}")
print(f"num_classes = {num_classes}")
# Free the period's data and metadata from the notebook namespace.
# NOTE: after this, the data-loading cell must be re-run before this cell can run again.
# globals().pop is used because mutating locals() only works at module scope by accident.
for var in ["X_train", "y_train", "X_val", "y_val", "X_test", "y_test", "Number_features", "unique_classes", "num_classes"]:
    globals().pop(var, None)


'train_and_validate' function started. 

Existing folder has been removed : Class_Incremental_CL\Classif_Bi_Dir_GRU_Model\Trained_models\Baseline\Period_1\1st_try

y_train:
<class 'torch.Tensor'>
torch.int64
torch.Size([3634, 1000])
X_train:
<class 'torch.Tensor'>
torch.float32
torch.Size([3634, 1000, 7])

y_val:
<class 'torch.Tensor'>
torch.int64
torch.Size([454, 1000])
X_val:
<class 'torch.Tensor'>
torch.float32
torch.Size([454, 1000, 7])

Dataset Lengths:
Train Dataset Length: 3634
Validation Dataset Length: 454

DataLoader Batch Sizes:
Number of Batches in Train DataLoader: 57
Number of Batches in Validation DataLoader: 8

y_train Unique Values and Stats:
Unique values in y_train: tensor([0, 1], device='cuda:0')
y_train Min: 0, Max: 1

y_val Unique Values and Stats:
Unique values in y_val: tensor([0, 1], device='cuda:0')
y_val Min: 0, Max: 1

Device Info:
X_train Device: cuda:0
y_train Device: cuda:0
X_val Device: cuda:0
y_val Device: cuda:0

Epoch 1/2000, Train Loss: 0.681195149, 

In [24]:
import torch
from torch.utils.data import TensorDataset, DataLoader

def compute_classwise_accuracy(student_logits_flat, y_batch, class_correct, class_total):
    """
    Accumulate per-class sample and hit counts for one batch, in place.

    Parameters:
    - student_logits_flat: Tensor of class scores; the predicted class is the
      argmax over its last dimension.
    - y_batch: Tensor of ground-truth labels, compared element-wise with the
      predictions.
    - class_correct (dict): Running number of correct predictions per label;
      updated in place.
    - class_total (dict): Running number of samples per label; updated in place.

    Returns:
    - None. Only the two dictionaries are modified.
    """
    predicted = student_logits_flat.argmax(dim=-1)
    is_hit = predicted.eq(y_batch)

    # Only labels that actually occur in this batch are touched.
    for cls in torch.unique(y_batch).tolist():
        # First occurrence of a label opens both counters at zero.
        if cls not in class_total:
            class_total[cls] = 0
            class_correct[cls] = 0
        is_cls = y_batch.eq(cls)
        class_total[cls] += is_cls.sum().item()
        class_correct[cls] += torch.logical_and(is_cls, is_hit).sum().item()

def add_val_class_acc(model_copy, output_size, criterion, X_val, y_val, file_path, model_checkpoint, batch_size=64):
    """
    Evaluate a restored model on the validation set and write the overall and
    per-class accuracy into its checkpoint file.

    Parameters:
    - model_copy: Model with the checkpoint weights already loaded; it is moved
      to the evaluation device and switched to eval mode.
    - output_size (int): Number of classes; model outputs are reshaped to
      (-1, output_size) before taking the argmax.
    - criterion: Not used by the evaluation. Kept in the signature so existing
      call sites keep working.
    - X_val: Validation inputs (array-like or tensor).
    - y_val: Validation labels (array-like or tensor).
    - file_path (str): Checkpoint file that is overwritten with the updated
      dictionary.
    - model_checkpoint (dict): Loaded checkpoint; updated in place with the keys
      'val_accuracy' (float) and 'Val-Class-Acc' (dict of label -> "xx.xx%").
    - batch_size (int): Evaluation batch size.

    Raises:
    - ValueError: If the validation set contains no labels.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_copy.to(device)
    model_copy.eval()

    # as_tensor accepts arrays and tensors alike. Unlike torch.tensor it neither
    # warns nor makes an extra copy when the caller already passes tensors.
    X_val = torch.as_tensor(X_val, dtype=torch.float32, device=device)
    y_val = torch.as_tensor(y_val, dtype=torch.long, device=device)
    if y_val.numel() == 0:
        raise ValueError("Validation set is empty; cannot compute validation accuracy.")

    val_dataset = TensorDataset(X_val, y_val)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    val_correct = 0
    val_total = 0
    val_class_correct = {}
    val_class_total = {}

    with torch.no_grad():
        for X_val_batch, y_val_batch in val_loader:
            # Flatten so every label position counts as one sample.
            val_outputs = model_copy(X_val_batch).view(-1, output_size)
            val_labels = y_val_batch.view(-1)
            val_predictions = torch.argmax(val_outputs, dim=-1)
            val_correct += (val_predictions == val_labels).sum().item()
            val_total += val_labels.size(0)
            compute_classwise_accuracy(val_outputs, val_labels, val_class_correct, val_class_total)

    val_accuracy = val_correct / val_total
    val_classwise_accuracy = {
        int(c): f"{(val_class_correct[c] / val_class_total[c]) * 100:.2f}%" if val_class_total[c] > 0 else "0.00%"
        for c in sorted(val_class_total.keys())
    }

    model_checkpoint.update({
        "val_accuracy": val_accuracy,
        "Val-Class-Acc": val_classwise_accuracy
    })
    torch.save(model_checkpoint, file_path)
    print(f"✅ Updated: {os.path.basename(file_path)} | Acc: {val_accuracy:.4f} | Saved to: {file_path}")

def show_val_accuracy_from_checkpoints(model_folder):
    """
    Print the validation metrics stored in every '*.pth' checkpoint of a folder.

    Parameters:
    - model_folder (str): Folder whose '*.pth' files are loaded (onto CPU) and
      reported in sorted filename order.

    For each checkpoint, prints 'val_accuracy' and 'Val-Class-Acc' when both
    keys are present, otherwise a "not found" line.
    """
    import glob
    import os
    import torch

    checkpoint_files = sorted(glob.glob(os.path.join(model_folder, "*.pth")))

    print(f"\n📝 Summary of val_accuracy and Val-Class-Acc:")
    for ckpt_file in checkpoint_files:
        state = torch.load(ckpt_file, map_location="cpu")
        short_name = os.path.basename(ckpt_file)

        overall_acc = state.get("val_accuracy", None)
        per_class_acc = state.get("Val-Class-Acc", None)

        # Either metric missing means the checkpoint was never evaluated.
        if overall_acc is None or per_class_acc is None:
            print(f"{short_name}: ❌ val_accuracy or Val-Class-Acc not found")
            continue
        print(f"{short_name}: val_accuracy = {overall_acc:.4f}, Val-Class-Acc: {per_class_acc}")


In [None]:
import os
import glob
import copy

# Re-evaluate every Period 1 baseline checkpoint on the validation split and
# store 'val_accuracy' / 'Val-Class-Acc' inside each checkpoint file.
# NOTE(review): this cell needs Number_features, num_classes, X_val and y_val.
# The Period 1 training cell deletes those names when it finishes, so the
# data-splitting cell has to be re-run before this one.

# Reconstruct model with same config
model_template = BiGRUWithAttention(
    input_size=Number_features,
    hidden_size=64,
    output_size=num_classes,
    num_layers=4
)

# Prepare validation data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(device)
criterion = nn.CrossEntropyLoss()

# Your model folder (single source of truth for both the update loop and the summary)
model_folder = 'Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/Baseline/Period_1/1st_try'
ckpt_paths = sorted(glob.glob(os.path.join(model_folder, "*.pth")))

# Loop through and update each model
for file_path in ckpt_paths:
    # Fresh copy per checkpoint so weights never leak from one file to the next.
    model_copy = copy.deepcopy(model_template).to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # matching how the Fisher cell loads checkpoints from this folder, and
    # silences the torch.load FutureWarning.
    checkpoint = torch.load(file_path, map_location=device, weights_only=True)
    model_copy.load_state_dict(checkpoint['model_state_dict'])

    add_val_class_acc(
        model_copy=model_copy,
        output_size=num_classes,
        criterion=criterion,
        X_val=X_val_tensor,
        y_val=y_val_tensor,
        file_path=file_path,
        model_checkpoint=checkpoint,
        batch_size=64
    )

show_val_accuracy_from_checkpoints(model_folder)


  checkpoint = torch.load(file_path, map_location=device)
  X_val = torch.tensor(X_val, dtype=torch.float32).to(device)
  y_val = torch.tensor(y_val, dtype=torch.long).to(device)


✅ Updated: BiGRUWithAttention_epoch_1550.pth | Acc: 0.9833 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/Baseline/Period_1/1st_try\BiGRUWithAttention_epoch_1550.pth
✅ Updated: BiGRUWithAttention_epoch_1986.pth | Acc: 0.9835 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/Baseline/Period_1/1st_try\BiGRUWithAttention_epoch_1986.pth
✅ Updated: BiGRUWithAttention_epoch_1987.pth | Acc: 0.9835 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/Baseline/Period_1/1st_try\BiGRUWithAttention_epoch_1987.pth
✅ Updated: BiGRUWithAttention_epoch_1989.pth | Acc: 0.9833 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/Baseline/Period_1/1st_try\BiGRUWithAttention_epoch_1989.pth
✅ Updated: BiGRUWithAttention_epoch_1999.pth | Acc: 0.9833 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/Baseline/Period_1/1st_try\BiGRUWithAttention_epoch_1999.pth
✅ Updated: BiGRUWithAttention_final.pth 

  checkpoint = torch.load(path, map_location="cpu")


### Period 1 --> Training and saving in __*'2nd_try'*__ (BiGRUWithAttention, num_layers = 3) ---> Val acc = 98.31 %

In [None]:
"""
- 'trend': Categorized trend values based on the detected phases:
    - 0: No trend
    - 1: Moderate negative trend
    - 2: Very strong negative trend
    - 3: Moderate positive trend
    - 4: Very strong positive trend
"""
# Build the Period 1 splits (trend classes 0 and 1 only) while silencing the
# output of process_and_return_splits. The devnull handle is opened in the same
# `with` statement so it is closed as soon as the call returns.
with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
    X_train, y_train, X_val, y_val, X_test, y_test, Number_features = process_and_return_splits(
        with_indicators_file_path = list_period_files_full_path[0], # Change 
        downsampled_data_minutes = downsampled_data_minutes,
        exclude_columns = exclude_columns,
        lower_threshold = lower_threshold,
        upper_threshold = upper_threshold,
        reverse_steps = reverse_steps,
        sequence_length = sequence_length,
        sliding_interval = sliding_interval,
        trends_to_keep = {0, 1}  # Default keeps all trends : {0, 1, 2, 3, 4}
        # trends_to_keep = {0, 1, 2, 3, 4}  # Default keeps all trends
    )

print(f"\nNumber_features = {Number_features}")

# The class count is taken from the labels present in the validation split.
# NOTE(review): a class that occurs only in y_train would not be counted here — confirm.
unique_classes = np.unique(y_val)
num_classes = len(unique_classes)
print(f"unique_classes = {unique_classes}")
print(f"num_classes = {num_classes}")


In [None]:
# Period 1 baseline training, '2nd_try': same pipeline as the '1st_try' run but
# with a 3-layer GRU. Everything below is a module-level name that the training
# function and the later cells read, so names and order are kept as they are.

# Model parameters
input_size = Number_features  # Number of input features per time step
hidden_size = 64  # Number of GRU units
output_size = num_classes  # One output per trend class present in the data (trend labels run 0..4, so at most 5)
num_layers = 3  # Number of GRU layers
num_epochs= 2000 # Number of epochs/ go through entire data
batch_size= 64 # How many sequences passed at once to the model
model_name = 'BiGRUWithAttention' # Name of the model to use for saving
# Reset before every run; printed below after training.
# NOTE(review): presumably filled by train_and_validate through the module namespace — confirm.
best_results = [] # Initialize this outside the training function or at the beginning of training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Define a global stop signal
stop_signal_file = os.path.normpath(os.path.join('Class_Incremental_CL', 'Classif_Bi_Dir_GRU_Model/stop_training.txt'))  # Create this file to stop training
model_saving_folder = os.path.normpath(os.path.join('Class_Incremental_CL', "Classif_Bi_Dir_GRU_Model/Trained_models/Baseline/Period_1/2nd_try"))
ensure_folder(model_saving_folder)

# Instantiate the model
class_gru_model = BiGRUWithAttention(input_size, hidden_size, output_size, num_layers).to(device)

# Define the loss function, optimizer and scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(class_gru_model.parameters(), lr=0.0001) # lr=0.00005
# optimizer = optim.Adam(class_gru_model.parameters(), lr=0.001, weight_decay=1e-5)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
# Multiply the learning rate by 0.9 after 10 epochs without improvement of the monitored ('min') quantity.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=10)

# train_and_validate(class_gru_model, output_size, criterion, optimizer, X_train_all, y_train_all, X_val_all, y_val_all, scheduler, 
# NOTE(review): the positional `False` after `scheduler` is presumably the
# use_scheduler flag — confirm against train_and_validate's signature.
train_and_validate(class_gru_model, output_size, criterion, optimizer, X_train, y_train, X_val, y_val, scheduler, 
                   False, num_epochs, batch_size, model_saving_folder, model_name, stop_signal_file)

# Report every entry kept in best_results.
for res in best_results:        
    print(f"Epoch {res['epoch']}/{num_epochs}, "
            f"Train Loss: {res['train_loss']:.4f}, " 
            f"Val Loss: {res['val_loss']:.4f}, "
            f"Val Accuracy: {res['val_accuracy'] * 100:.2f}%, "
            f"Model Path: {res['model_path']}")      
print(f"\nclass_gru_model: \n{class_gru_model}")

print(f"\nunique_classes = {unique_classes}")
print(f"num_classes = {num_classes}")
# del unique_classes, num_classes
# Drop the period's data and class metadata from the notebook namespace.
# At cell (module) level locals() is the global namespace, so the deletion takes
# effect; cells that need these names afterwards must re-run the data cell.
for var in ["X_train", "y_train", "X_val", "y_val", "X_test", "y_test", "Number_features", "unique_classes", "num_classes"]:
    if var in locals():
        del locals()[var]


'train_and_validate' function started. 

Existing folder has been removed : Class_Incremental_CL\Classif_Bi_Dir_GRU_Model\Trained_models\2nd_try

y_train:
<class 'torch.Tensor'>
torch.int64
torch.Size([3634, 1000])
X_train:
<class 'torch.Tensor'>
torch.float32
torch.Size([3634, 1000, 7])

y_val:
<class 'torch.Tensor'>
torch.int64
torch.Size([454, 1000])
X_val:
<class 'torch.Tensor'>
torch.float32
torch.Size([454, 1000, 7])

Dataset Lengths:
Train Dataset Length: 3634
Validation Dataset Length: 454

DataLoader Batch Sizes:
Number of Batches in Train DataLoader: 57
Number of Batches in Validation DataLoader: 8

y_train Unique Values and Stats:
Unique values in y_train: tensor([0, 1], device='cuda:0')
y_train Min: 0, Max: 1

y_val Unique Values and Stats:
Unique values in y_val: tensor([0, 1], device='cuda:0')
y_val Min: 0, Max: 1

Device Info:
X_train Device: cuda:0
y_train Device: cuda:0
X_val Device: cuda:0
y_val Device: cuda:0

Epoch 1/2000, Train Loss: 0.010724814, Val Loss: 0.011819

### -----------> Fisher Matrix computation for __*Period 1 '1st_try'*__

In [35]:
"""
- 'trend': Categorized trend values based on the detected phases:
    - 0: No trend
    - 1: Moderate negative trend
    - 2: Very strong negative trend
    - 3: Moderate positive trend
    - 4: Very strong positive trend
"""
# Rebuild the Period 1 splits (trend classes 0 and 1 only) for the Fisher
# computation while silencing the output of process_and_return_splits. The
# devnull handle is opened in the same `with` statement so it is closed as soon
# as the call returns.
with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
    X_train, y_train, X_val, y_val, X_test, y_test, Number_features = process_and_return_splits(
        with_indicators_file_path = list_period_files_full_path[0], # Change 
        downsampled_data_minutes = downsampled_data_minutes,
        exclude_columns = exclude_columns,
        lower_threshold = lower_threshold,
        upper_threshold = upper_threshold,
        reverse_steps = reverse_steps,
        sequence_length = sequence_length,
        sliding_interval = sliding_interval,
        trends_to_keep = {0, 1}  # Default keeps all trends : {0, 1, 2, 3, 4}
        # trends_to_keep = {0, 1, 2, 3, 4}  # Default keeps all trends
    )
print(f"\nNumber_features = {Number_features}")
# The class count is taken from the labels present in the validation split.
unique_classes = np.unique(y_val)
num_classes = len(unique_classes)
print(f"unique_classes = {unique_classes}")
print(f"num_classes = {num_classes}")


#-------------------------------------------------------------------------
# Rebuild the Period 1 network, restore the chosen checkpoint and prepare the
# training DataLoader used for the Fisher computation.
# Model parameters
input_size = Number_features  # Number of features
hidden_size = 64  # Number of GRU units
output_size = num_classes  # One output per trend class present in the data
num_layers = 4  # Number of GRU layers
dropout = 0.0
fisher_calc_batch_size= 50 
model_name = 'BiGRUWithAttentionEWC' # Name of the model to use for saving
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#-----------------
previous_model = BiGRUWithAttention(input_size, hidden_size, output_size, num_layers, dropout).to(device)
# Built with forward slashes + os.path.join like every other path in the
# notebook; normpath converts to the platform separator, so the checkpoint is
# found on Windows and on POSIX systems alike.
best_model_path = os.path.join('Class_Incremental_CL', "Classif_Bi_Dir_GRU_Model/Trained_models/Baseline/Period_1/1st_try/BiGRUWithAttention_epoch_1987.pth")
prev_checkpoint_path = os.path.normpath(best_model_path)
prev_checkpoint = torch.load(prev_checkpoint_path, map_location=device, weights_only=True)
# print(f"\n{prev_checkpoint}\n")
previous_model.load_state_dict(prev_checkpoint['model_state_dict'])
#-----------------
# Show the checkpoint metadata; the two state dicts are skipped because they
# only contain large tensors.
print(prev_checkpoint.keys(), "\n")
skip_keys = {"model_state_dict", "optimizer_state_dict"}
for name, content in prev_checkpoint.items():
    if name in skip_keys:
        continue
    print(f"{name}: (type: {type(content)}) \n{content}")
# for param_name, tensor in prev_checkpoint['model_state_dict'].items():
#     print(f"{param_name}: {tensor.shape}")
print(f"Loaded preious model from: \n\t{prev_checkpoint_path}")
#-----------------
criterion = nn.CrossEntropyLoss()
# Create a dataloader from Period 1 training data for Fisher computation
train_dataset_period1 = TensorDataset(torch.tensor(X_train, dtype=torch.float32), 
                                        torch.tensor(y_train, dtype=torch.long))
train_loader_period1 = DataLoader(train_dataset_period1, batch_size=fisher_calc_batch_size, shuffle=True) # Shuffle is good practice


#-------------------------------------------------------------------------
# == Case 1: Compute EWC using ALL samples ==
# Computes the Fisher information and reference parameters for the restored
# Period 1 model, then saves them next to the original checkpoint contents
# under '.../EWC_with_All_samples/'.
print(f"\nCalculating EWC state with sample_size = All...")
# NOTE(review): None presumably tells compute_fisher_and_params to use every sample — confirm.
sample_size_all = None
# Call EWC.compute_fisher_and_params (invoked on the class) to compute Fisher and Params
try:
    ewc_fisher_dict_all, ewc_params_dict_all = EWC.compute_fisher_and_params(
        model=previous_model,       # Pass the loaded Period 1 model
        dataloader=train_loader_period1,
        criterion=criterion,
        device=device,
        sample_size=sample_size_all # Use all samples
    )
    print("EWC Fisher and Params computed (All Samples).")
    #----------------------------------------<<<<<<<<<<<<<<<
    # --- !! START INSPECTION HERE !! ---
    inspect_fisher_info(ewc_fisher_dict_all, label="Computed Fisher Values (All Samples)")
    # --- !! END INSPECTION !! ---
    #----------------------------------------<<<<<<<<<<<<<<<
    # Shallow copy of the original checkpoint dictionary: the new EWC keys are
    # added to the copy only, while the existing values are shared with prev_checkpoint.
    checkpoint_to_save_all = prev_checkpoint.copy()
    # Add computed EWC state.
    # NOTE(review): nothing in this cell moves the tensors to CPU; whether they
    # are saved on CPU depends on what compute_fisher_and_params returns — confirm.
    checkpoint_to_save_all['ewc_fisher'] = ewc_fisher_dict_all
    checkpoint_to_save_all['ewc_params'] = ewc_params_dict_all
    #-----------------
    save_checkpoint_dir = os.path.normpath(os.path.join('Class_Incremental_CL', "Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_1_weights_with_ewc/EWC_with_All_samples"))
    ensure_folder(save_checkpoint_dir)
    # File name carries the epoch stored in the source checkpoint and the 'All' tag.
    save_checkpoint_path = os.path.normpath(os.path.join(save_checkpoint_dir, f"{model_name}_epoch_{prev_checkpoint['epoch']}_ewc_All.pth"))
    #-----------------
    torch.save(checkpoint_to_save_all, save_checkpoint_path)
    # Release the intermediate objects once the file is written (skipped if an error occurred earlier).
    del ewc_fisher_dict_all, ewc_params_dict_all, checkpoint_to_save_all, save_checkpoint_dir, save_checkpoint_path
except Exception as e:
    # Best effort: report the failure and let the cell continue with Case 2.
    print(f"ERROR calculating or saving EWC state (All Samples): {e}")


#-------------------------------------------------------------------------
# == Case 2: Compute EWC using 100 samples ==
# Same procedure as Case 1 with sample_size=100; the result is saved under
# '.../EWC_with_100_samples/'.
print(f"\nCalculating EWC state with sample_size = 100...")
sample_size_100 = 100
# Call EWC.compute_fisher_and_params again for the 100-sample case
try:
    ewc_fisher_dict_100, ewc_params_dict_100 = EWC.compute_fisher_and_params(
        model=previous_model,             # Use the same loaded model
        dataloader=train_loader_period1, # Use the same dataloader
        criterion=criterion,
        device=device,
        sample_size=sample_size_100      # Use 100 samples this time
    )
    print("EWC Fisher and Params computed (100 Samples).\n")
    #----------------------------------------<<<<<<<<<<<<<<<
    # --- !! START INSPECTION HERE !! ---
    inspect_fisher_info(ewc_fisher_dict_100, label="Computed Fisher Values (100 Samples)")
    # --- !! END INSPECTION !! ---
    #----------------------------------------<<<<<<<<<<<<<<<
    # Another shallow copy of the original checkpoint; only the copy receives the EWC keys.
    checkpoint_to_save_100 = prev_checkpoint.copy()
    # Add computed EWC state.
    # NOTE(review): nothing in this cell moves the tensors to CPU; whether they
    # are saved on CPU depends on what compute_fisher_and_params returns — confirm.
    checkpoint_to_save_100['ewc_fisher'] = ewc_fisher_dict_100
    checkpoint_to_save_100['ewc_params'] = ewc_params_dict_100
    #-----------------
    save_dir_100 = os.path.normpath(os.path.join('Class_Incremental_CL', f"Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_1_weights_with_ewc/EWC_with_{sample_size_100}_samples"))
    ensure_folder(save_dir_100) # Ensure the directory exists
    # File name carries the epoch stored in the source checkpoint and the sample count.
    save_path_100 = os.path.normpath(os.path.join(save_dir_100, f"{model_name}_epoch_{prev_checkpoint['epoch']}_ewc_{sample_size_100}.pth"))
    #-----------------
    torch.save(checkpoint_to_save_100, save_path_100)
    # Release the intermediate objects once the file is written (skipped if an error occurred earlier).
    del ewc_fisher_dict_100, ewc_params_dict_100, checkpoint_to_save_100, save_dir_100, save_path_100
except Exception as e:
    # Best effort: report the failure and let the cell continue with the cleanup.
    print(f"ERROR calculating or saving EWC state (100 Samples): {e}")


#-------------------------------------------------------------------------
# Drop the data, loader and model used for the Fisher computation from the
# notebook namespace. At cell (module) level locals() is the global namespace,
# so the deletion takes effect. Names that do not exist are skipped by the
# membership test ("ewc_period_1" is not assigned anywhere in this cell).
for var in ["X_train", "y_train", "X_val", "y_val", "X_test", "y_test", "fisher_calc_batch_size", "train_dataset_period1", 
            "train_loader_period1", "Number_features", "unique_classes", "num_classes", "previous_model", "ewc_period_1"]:
    if var in locals():
        del locals()[var]
#-----------------
# These three are always assigned earlier in the cell, so a plain del is used.
del prev_checkpoint, best_model_path, prev_checkpoint_path
gc.collect()



Number_features = 7
unique_classes = [0 1]
num_classes = 2
dict_keys(['epoch', 'train_loss', 'val_loss', 'model_state_dict', 'optimizer_state_dict', 'learning_rate']) 

epoch: (type: <class 'int'>) 
1987
train_loss: (type: <class 'float'>) 
0.004282388898165231
val_loss: (type: <class 'float'>) 
0.08081079054508966
learning_rate: (type: <class 'float'>) 
0.0001
Loaded preious model from: 
	Class_Incremental_CL\Classif_Bi_Dir_GRU_Model\Trained_models\Baseline\Period_1\1st_try\BiGRUWithAttention_epoch_1987.pth

Calculating EWC state with sample_size = All...
Starting Fisher Information Matrix calculation (v3: train mode for RNN backward)...
Model temporarily set to TRAIN mode for Fisher calculation.
Dropout layers temporarily disabled during forward pass.
Processing data batches on device: cuda
Restoring original Dropout layer states...
Model mode restored to original state (training=True).
Finished accumulating gradients over 3634 samples.
Normalizing Fisher matrix...
Fisher Informatio

8

### Period 2 --> Training and saving in __*'1st_try'*__ (BiGRUWithAttention, lambda_ewc= 40, learning_rate= 0.0001) ---> Val acc = 98.00 %
### Val-Class-Acc: {0: '99.70%', 1: '97.60%', 2: '91.06%'}

In [25]:
"""
- 'trend': Categorized trend values based on the detected phases:
    - 0: No trend
    - 1: Moderate negative trend
    - 2: Very strong negative trend
    - 3: Moderate positive trend
    - 4: Very strong positive trend
"""
# Build the Period 2 splits (trend classes 0, 1 and 2) while silencing the
# output of process_and_return_splits. The devnull handle is opened in the same
# `with` statement so it is closed as soon as the call returns.
with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
    X_train, y_train, X_val, y_val, X_test, y_test, Number_features = process_and_return_splits(
        with_indicators_file_path = list_period_files_full_path[1], # Change
        downsampled_data_minutes = downsampled_data_minutes,
        exclude_columns = exclude_columns,
        lower_threshold = lower_threshold,
        upper_threshold = upper_threshold,
        reverse_steps = reverse_steps,
        sequence_length = sequence_length,
        sliding_interval = sliding_interval,
        trends_to_keep = {0, 1, 2}  # Default keeps all trends : {0, 1, 2, 3, 4}
        # trends_to_keep = {0, 1, 2, 3, 4}  # Default keeps all trends
    )

print(f"\nNumber_features = {Number_features}")

# The class count is taken from the labels present in the validation split.
# NOTE(review): a class that occurs only in y_train would not be counted here — confirm.
unique_classes = np.unique(y_val)
num_classes = len(unique_classes)
print(f"unique_classes = {unique_classes}")
print(f"num_classes = {num_classes}\n")

# Per-split class balance, to spot distribution shift between the splits.
print_class_distribution(y_train, "y_train")
print_class_distribution(y_val, "y_val")
print_class_distribution(y_test, "y_test")



Number_features = 7
unique_classes = [0 1 2]
num_classes = 3

Class Distribution for 'y_train':       Class 0   Percent:  53.41% || Class 1   Percent:  36.03% || Class 2   Percent:  10.55%
Class Distribution for 'y_val':         Class 0   Percent:  52.72% || Class 1   Percent:  36.38% || Class 2   Percent:  10.90%
Class Distribution for 'y_test':        Class 0   Percent:  49.11% || Class 1   Percent:  32.27% || Class 2   Percent:  18.62%


In [None]:
# Period 2 EWC run: hyper-parameters, fresh model, and the Period 1 checkpoint
# that carries the EWC state (Fisher values + reference parameters).

# Model parameters
input_size = Number_features  # Number of features
hidden_size = 64  # Number of GRU units
output_size = num_classes  # One output per trend class present in this period's data
num_layers = 4  # Number of GRU layers
dropout = 0.0
learning_rate = 0.0001 # For the optimizer # lr=0.00005
weight_decay=1e-5 # For the optimizer
use_scheduler = False
lambda_ewc=40 # Strength of the EWC penalty (0.4 was tried before)
# ---- ---- ---- #
num_epochs= 2000 # Number of epochs/ go through entire data
batch_size= 64 # How many sequences passed at once to the model
model_name = 'BiGRUWithAttentionEWC' # Name of the model to use for saving
best_results = [] # Initialize this outside the training function or at the beginning of training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Define a global stop signal
stop_signal_file = os.path.normpath(os.path.join('Class_Incremental_CL', 'Classif_Bi_Dir_GRU_Model/stop_training.txt'))  # Create this file to stop training
model_saving_folder = os.path.normpath(os.path.join('Class_Incremental_CL', "Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_2/1st_try"))
ensure_folder(model_saving_folder)

# Instantiate the current period model (weights from the previous period are transferred further below)
current_model = BiGRUWithAttention(input_size, hidden_size, output_size, num_layers, dropout).to(device)
# Independent deep copy, handed to train_and_validate_ewc as `deepcopy`.
# It is intended for the Fisher Matrix computation of the current period's model.
# NOTE(review): the copy is taken before the Period 1 weights are transferred,
# so it holds the freshly initialised weights — confirm this is intended.
deepcopy = copy.deepcopy(current_model)
#-------------------------------------------------------------------------
period_1_epoch = 1987 # Epoch of the chosen Period 1 model
ewc_sample_tag = "All" # Options: "All" or "100" (must match the saved file)
best_model_dir = os.path.normpath(os.path.join('Class_Incremental_CL', f"Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_1_weights_with_ewc/EWC_with_{ewc_sample_tag}_samples"))
best_model_path = os.path.normpath(os.path.join(best_model_dir, f"{model_name}_epoch_{period_1_epoch}_ewc_{ewc_sample_tag}.pth"))
#-------- OR ---------
# best_model_path = r"Class_Incremental_CL\Classif_Bi_Dir_GRU_Model\Trained_models\Baseline\Period_1\1st_try\BiGRUWithAttention_epoch_1987.pth"
#----------------------------------------------------------------------
# Initialize the list to store results across runs
track_across_runs = []
#----------------------------------------------------------------------
#-------------------------------------------------------------------------
prev_checkpoint_path = os.path.normpath(best_model_path)
prev_checkpoint = torch.load(prev_checkpoint_path, map_location=device, weights_only=True)
required_keys = ['model_state_dict', 'ewc_fisher', 'ewc_params']
for key in required_keys:
    if key not in prev_checkpoint:
        # Report the file path: interpolating the checkpoint itself would dump
        # every stored tensor into the error message.
        raise KeyError(f"Checkpoint '{prev_checkpoint_path}' is missing required key: '{key}'")
prev_model_dict = prev_checkpoint['model_state_dict']
# The checkpoint was loaded with map_location=device, so the tensors in these
# dictionaries are mapped to `device` rather than left on CPU.
ewc_fisher_dict = prev_checkpoint['ewc_fisher']
ewc_params_dict = prev_checkpoint['ewc_params']
print(f"Loaded previous period checkpoint from: \n\t{prev_checkpoint_path}")
# print(f"\nprev_checkpoint: \n{prev_checkpoint}\n")
#-------------------------------------------------------------------------

#----------------------------------------<<<<<<<<<<<<<<<
# --- !! START INSPECTION HERE !! ---
# Summarise the Fisher values loaded from the previous period's checkpoint.
inspect_fisher_info(ewc_fisher_dict, label=f"Loaded Fisher Values ({ewc_sample_tag} Samples)")
# --- !! END INSPECTION !! ---
#----------------------------------------<<<<<<<<<<<<<<<

#-------------------------------------------------------------------------
# --- Transfer Compatible Weights from Previous Period Model to Current Period Model ---
print("Transferring compatible weights to the new model...")
current_model_dict = current_model.state_dict()
# Filter the loaded state dict: Keep only layers that exist in the current model AND have matching shapes
# Update only parameters with matching shapes (skip final fc if dimensions differ)
filtered_prev_state_dict = {
    k: v for k, v in prev_model_dict.items()
    if k in current_model_dict and v.size() == current_model_dict[k].size()
}
if not filtered_prev_state_dict:
     print("Warning: No compatible weights found to transfer. The new model will start from random initialization (except layers possibly initialized by default).")
else:
    # Load the filtered weights into the new model.
    # `strict=False` allows loading a partial state dict (ignoring missing keys like the final FC layer)
    # load_state_dict returns the keys the model expects but did not receive,
    # and the keys it received but does not know; both are printed for review.
    missing_keys, unexpected_keys = current_model.load_state_dict(filtered_prev_state_dict, strict=False)
    print(f"  Weights loaded. Keys missing in loaded dict (expected: fc layer): {missing_keys}")
    print(f"  Keys in loaded dict but not in model (should be empty): {unexpected_keys}")
# Ensure the model is on the correct device *after* loading state dict
current_model.to(device)
print(f"\nCurrent Period Model Structure: \n{current_model}\n")
#-------------------------------------------------------------------------
# # --- Transfer Compatible Weights from Previous Period Model to Current Period Model ---
# current_model_dict = current_model.state_dict()
# # Filter the loaded state dict: Keep only layers that exist in the current model AND have matching shapes
# # Update only parameters with matching shapes (skip final fc if dimensions differ)
# prev_model_dict = {
#     k: v for k, v in prev_model_dict.items() 
#     if k in current_model_dict and v.size() == current_model_dict[k].size()
# }
# current_model_dict.update(prev_model_dict)
# current_model.load_state_dict(current_model_dict)
# # Ensure the model is on the correct device *after* loading state dict
# current_model.to(device)
# print(f"\nCurrent Period Model Structure: \n{current_model}\n")
#-------------------------------------------------------------------------
# --- Instantiate EWC Object for Training ---
# Use the Fisher matrix and optimal parameters loaded from the Previous Period checkpoint
print("Instantiating EWC object using loaded state...")
ewc_object = EWC(fisher=ewc_fisher_dict, params=ewc_params_dict)
# NOTE(review): the checkpoint was loaded with map_location=device, so the
# fisher/params tensors are presumably on `device`, not on CPU as an earlier
# comment stated. How EWC.penalty handles device placement is not visible
# here — confirm.
#-------------------------------------------------------------------------
# Define the loss function, optimizer and scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(current_model.parameters(), lr=learning_rate, weight_decay=weight_decay) # lr=0.00005
# The scheduler stays None unless use_scheduler is enabled in the configuration.
scheduler = None
if use_scheduler:
    # Multiply the learning rate by 0.9 after 10 epochs without improvement of the monitored ('min') quantity.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.9, patience=10)
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    scheduler_name = scheduler.__class__.__name__
    print(f"Using {scheduler_name} scheduler.\n")

# raise Exception("Stop here!") # <<<-----------------------------------------

# Train the Period 2 model with the EWC penalty built from the Period 1 state.
train_and_validate_ewc(
    model=current_model,
    deepcopy=deepcopy, # Independent model copy created before the weight transfer
    output_size=output_size,
    criterion=criterion,
    optimizer=optimizer,
    X_train=X_train, y_train=y_train,
    X_val=X_val, y_val=y_val,
    scheduler=scheduler,
    use_scheduler=use_scheduler,
    ewc=ewc_object, # Pass the EWC object
    lambda_ewc=lambda_ewc, # Pass the EWC strength
    num_epochs=num_epochs,
    batch_size=batch_size,
    model_saving_folder=model_saving_folder,
    model_name=model_name, # Use base name for saved files in this period
    stop_signal_file=stop_signal_file
)


#----------------------------------------------------------------------
# Append only the best result (already at index 0).
# Guarded: best_results stays empty when training produced no result, and
# indexing it would abort the cell before the summary and the cleanup run.
if best_results:
    track_across_runs.append(best_results[0])
else:
    print("Warning: best_results is empty; nothing appended to track_across_runs.")
#----------------------------------------------------------------------

# Report every entry kept in best_results.
for res in best_results:
    print(f"Epoch {res['epoch']}/{num_epochs}, "
            f"Train Loss: {res['train_loss']:.4f}, "
            f"Val Loss: {res['val_loss']:.4f}, "
            f"Val Accuracy: {res['val_accuracy'] * 100:.2f}%, "
            f"Model Path: {res['model_path']}")
print(f"\nclass_gru_model with ewc (current_model): \n{current_model}")

print(f"\nunique_classes = {unique_classes}")
print(f"num_classes = {num_classes}")

# Drop the period's data, models and checkpoint objects from the notebook
# namespace. At cell (module) level locals() is the global namespace, so the
# deletion takes effect; missing names are skipped by the membership test.
for var in [
    "X_train", "y_train", "X_val", "y_val", "X_test", "y_test",
    "Number_features", "unique_classes", "num_classes",
    "current_model", "ewc_object",
    "prev_checkpoint", "prev_model_dict", "prev_checkpoint_path",
    "deepcopy", "filtered_prev_state_dict", "current_model_dict",
    "ewc_fisher_dict", "ewc_params_dict"
]:
    if var in locals():
        del locals()[var]
# --- Force garbage collection ---
gc.collect()
torch.cuda.empty_cache()


Loaded previous period checkpoint from: 
	Class_Incremental_CL\Classif_Bi_Dir_GRU_Model\Trained_models\EWC_CIL\Period_1_weights_with_ewc\EWC_with_All_samples\BiGRUWithAttentionEWC_epoch_1987_ewc_All.pth

--- Inspecting Loaded Fisher Values (All Samples) ---
Total Fisher Sum: 1.2840e+02
Overall Max Fisher Abs Value: 3.1890e+01
Overall Min Positive Fisher Value: 6.4935e-23
Number of layers with all-zero Fisher: 0 / 36
--- Finished Inspecting Loaded Fisher Values (All Samples) ---

Transferring compatible weights to the new model...
  Weights loaded. Keys missing in loaded dict (expected: fc layer): ['fc.weight', 'fc.bias']
  Keys in loaded dict but not in model (should be empty): []

Current Period Model Structure: 
BiGRUWithAttention(
  (gru): GRU(7, 64, num_layers=4, batch_first=True, bidirectional=True)
  (attention_fc): Linear(in_features=128, out_features=128, bias=True)
  (fc): Linear(in_features=128, out_features=3, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)

Instanti

In [79]:
# Manual cleanup cell: removes the Period 2 data, model and checkpoint names
# from the notebook namespace, e.g. after an interrupted training run.
# At cell (module) level locals() is the global namespace, so the deletion takes
# effect; names that are not defined are skipped by the membership test.
for var in [
    "X_train", "y_train", "X_val", "y_val", "X_test", "y_test",
    "Number_features", "unique_classes", "num_classes",
    "current_model", "ewc_object",
    "prev_checkpoint", "prev_model_dict", "prev_checkpoint_path",
    "period_1_epoch", "ewc_sample_tag", "best_model_dir",
    "filtered_prev_state_dict", "current_model_dict",
    "ewc_fisher_dict", "ewc_params_dict"
]:
    if var in locals():
        del locals()[var]
# --- Force garbage collection ---
gc.collect()
# Ask PyTorch to release cached GPU memory that is no longer referenced.
torch.cuda.empty_cache()

In [27]:
# Re-evaluate every Period 2 EWC checkpoint on the validation split and store
# 'val_accuracy' / 'Val-Class-Acc' inside each checkpoint file.
# NOTE(review): this cell needs Number_features, num_classes, X_val and y_val
# from the Period 2 data-splitting cell; re-run that cell first if they were
# cleaned up.
import glob  # not part of the notebook's import cell, so import it where it is used

# Reconstruct model with same config
model_template = BiGRUWithAttention(
    input_size=Number_features,
    hidden_size=64,
    output_size=num_classes,
    num_layers=4,
    dropout=0.0
)

# Prepare validation data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(device)
criterion = nn.CrossEntropyLoss()

# Your model folder
model_folder = 'Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_2/1st_try'
ckpt_paths = sorted(glob.glob(os.path.join(model_folder, "*.pth")))

# Loop through and update each model
for file_path in ckpt_paths:
    # Fresh copy per checkpoint so weights never leak from one file to the next.
    model_copy = copy.deepcopy(model_template).to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # as the other checkpoint loads in this notebook do, and silences the
    # torch.load FutureWarning.
    checkpoint = torch.load(file_path, map_location=device, weights_only=True)
    model_copy.load_state_dict(checkpoint['model_state_dict'])

    add_val_class_acc(
        model_copy=model_copy,
        output_size=num_classes,
        criterion=criterion,
        X_val=X_val_tensor,
        y_val=y_val_tensor,
        file_path=file_path,
        model_checkpoint=checkpoint,
        batch_size=64
    )

# ----------------------------
# Print val_accuracy and Val-Class-Acc for all models
# ----------------------------
show_val_accuracy_from_checkpoints(model_folder)


  checkpoint = torch.load(file_path, map_location=device)
  X_val = torch.tensor(X_val, dtype=torch.float32).to(device)
  y_val = torch.tensor(y_val, dtype=torch.long).to(device)


✅ Updated: BiGRUWithAttentionEWC_epoch_30.pth | Acc: 0.9789 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_2/1st_try\BiGRUWithAttentionEWC_epoch_30.pth
✅ Updated: BiGRUWithAttentionEWC_epoch_33.pth | Acc: 0.9790 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_2/1st_try\BiGRUWithAttentionEWC_epoch_33.pth
✅ Updated: BiGRUWithAttentionEWC_epoch_34.pth | Acc: 0.9795 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_2/1st_try\BiGRUWithAttentionEWC_epoch_34.pth
✅ Updated: BiGRUWithAttentionEWC_epoch_36.pth | Acc: 0.9800 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_2/1st_try\BiGRUWithAttentionEWC_epoch_36.pth
✅ Updated: BiGRUWithAttentionEWC_epoch_56.pth | Acc: 0.9790 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_2/1st_try\BiGRUWithAttentionEWC_epoch_56.pth
✅ Updated: BiGRUWithAttentionEWC_fi

  checkpoint = torch.load(path, map_location="cpu")


### Period 3 --> Training and saving in __*'1st_try'*__ (BiGRUWithAttention, lambda_ewc= 40, learning_rate= 0.0001) ---> Val acc = 97.32 %
### Val-Class-Acc: {0: '89.81%', 1: '99.00%', 2: '90.93%', 3: '98.00%'}

In [28]:
"""
- 'trend': Categorized trend values based on the detected phases:
    - 0: No trend
    - 1: Moderate negative trend
    - 2: Very strong negative trend
    - 3: Moderate positive trend
    - 4: Very strong positive trend
"""
# Load the Period 3 splits (index 2 of list_period_files_full_path) while
# discarding whatever process_and_return_splits prints. The devnull handle is
# opened in the same `with` so it is closed again when the block exits.
with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
    X_train, y_train, X_val, y_val, X_test, y_test, Number_features = process_and_return_splits(
        with_indicators_file_path = list_period_files_full_path[2], # Change 
        downsampled_data_minutes = downsampled_data_minutes,
        exclude_columns = exclude_columns,
        lower_threshold = lower_threshold,
        upper_threshold = upper_threshold,
        reverse_steps = reverse_steps,
        sequence_length = sequence_length,
        sliding_interval = sliding_interval,
        trends_to_keep = {0, 1, 2, 3}  # Default keeps all trends : {0, 1, 2, 3, 4}
        # trends_to_keep = {0, 1, 2, 3, 4}  # Default keeps all trends
    )

print(f"\nNumber_features = {Number_features}")

# NOTE(review): the class count is derived from the validation labels only;
# a class missing from y_val would not be counted -- confirm this cannot happen.
unique_classes = np.unique(y_val)
num_classes = len(unique_classes)
print(f"unique_classes = {unique_classes}")
print(f"num_classes = {num_classes}\n")

print_class_distribution(y_train, "y_train")
print_class_distribution(y_val, "y_val")
print_class_distribution(y_test, "y_test")



Number_features = 7
unique_classes = [0 1 2 3]
num_classes = 4

Class Distribution for 'y_train':       Class 0   Percent:  13.79% || Class 1   Percent:  37.59% || Class 2   Percent:   8.89% || Class 3   Percent:  39.73%
Class Distribution for 'y_val':         Class 0   Percent:  10.26% || Class 1   Percent:  46.19% || Class 2   Percent:   4.21% || Class 3   Percent:  39.34%
Class Distribution for 'y_test':        Class 0   Percent:  10.30% || Class 1   Percent:  43.50% || Class 2   Percent:   8.50% || Class 3   Percent:  37.70%


In [81]:
# ---- Architecture ---- #
input_size = Number_features   # number of input features
hidden_size = 64               # GRU hidden size
output_size = num_classes      # follows the classes present in this period (up to 5 trend classes)
num_layers = 4                 # stacked GRU layers
dropout = 0.0
# ---- Optimisation ---- #
learning_rate = 0.0001         # optimizer learning rate (alternative noted by the author: 0.00005)
weight_decay = 1e-5            # optimizer weight decay
use_scheduler = False
lambda_ewc = 40                # EWC strength (alternative noted by the author: 0.4)
# ---- Training loop ---- #
num_epochs = 2000              # passes over the training data
batch_size = 64                # sequences per batch
model_name = 'BiGRUWithAttentionEWC'  # base name used when saving checkpoints
# Not passed to train_and_validate_ewc explicitly; read after training below,
# so it is presumably filled by that function as a global -- TODO confirm.
best_results = []
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# ---- Paths ---- #
# Creating this file is the signal to stop training.
stop_signal_file = os.path.normpath(
    os.path.join('Class_Incremental_CL', 'Classif_Bi_Dir_GRU_Model/stop_training.txt')
)
model_saving_folder = os.path.normpath(
    os.path.join('Class_Incremental_CL', "Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_3/1st_try")
)
ensure_folder(model_saving_folder)

# Instantiate the current period model; compatible weights of the previous
# period's best model are transferred into it further down in this cell.
current_model = BiGRUWithAttention(input_size, hidden_size, output_size, num_layers, dropout).to(device)
# Independent deep copy, taken BEFORE any previous-period weights are loaded.
# It is handed to train_and_validate_ewc below (per the original note: used
# for the Fisher matrix computation of the current period's model).
deepcopy = copy.deepcopy(current_model)
#-------------------------------------------------------------------------
# `track_across_runs` must still contain the previous period's result(s).
if not track_across_runs:
    # Same exception type max() would raise on an empty list, but with a hint.
    raise ValueError(
        "track_across_runs is empty: run the previous period's training cell "
        "first so that its best checkpoint can be located."
    )
best_overall = max(track_across_runs, key=lambda res: res['val_accuracy'])
best_model_path = best_overall['model_path']
del best_overall
#-------------------------------------------------------------------------
prev_checkpoint_path = os.path.normpath(best_model_path)
# map_location=device places every tensor of the checkpoint on `device`.
prev_checkpoint = torch.load(prev_checkpoint_path, map_location=device, weights_only=True)
required_keys = ['model_state_dict', 'ewc_fisher', 'ewc_params']
for key in required_keys:
    if key not in prev_checkpoint:
        # Report the file path, not the (huge) checkpoint contents.
        raise KeyError(f"Checkpoint {prev_checkpoint_path} is missing required key: '{key}'")
prev_model_dict = prev_checkpoint['model_state_dict']
ewc_fisher_dict = prev_checkpoint['ewc_fisher']  # Fisher information saved with the previous model
ewc_params_dict = prev_checkpoint['ewc_params']  # parameter snapshot saved with the previous model
print(f"Loaded previous period checkpoint from: \n\t{prev_checkpoint_path}")
# print(f"\nprev_checkpoint: \n{prev_checkpoint}\n")
#----------------------------------------------------------------------
# Start a fresh result list for this period. Done only after the previous
# checkpoint was loaded and validated, so a failure above leaves the list
# intact and the cell can simply be re-run.
track_across_runs = []
#-------------------------------------------------------------------------

#----------------------------------------<<<<<<<<<<<<<<<
# Print summary statistics of the Fisher values loaded from the checkpoint.
inspect_fisher_info(ewc_fisher_dict, label="Loaded Fisher Values (All Samples)")
#----------------------------------------<<<<<<<<<<<<<<<

#-------------------------------------------------------------------------
# --- Carry over every previous-period tensor that still fits the new model ---
print("Transferring compatible weights to the new model...")
current_model_dict = current_model.state_dict()
# A tensor is transferable when its key exists in the new model AND the shapes
# agree; a final fc layer with a different number of classes is skipped.
filtered_prev_state_dict = dict(
    (name, tensor)
    for name, tensor in prev_model_dict.items()
    if name in current_model_dict and tensor.size() == current_model_dict[name].size()
)
if filtered_prev_state_dict:
    # strict=False accepts a partial state dict; layers that were filtered out
    # keep their fresh initialization and are reported as missing keys.
    missing_keys, unexpected_keys = current_model.load_state_dict(filtered_prev_state_dict, strict=False)
    print(f"  Weights loaded. Keys missing in loaded dict (expected: fc layer): {missing_keys}")
    print(f"  Keys in loaded dict but not in model (should be empty): {unexpected_keys}")
else:
    print("Warning: No compatible weights found to transfer. The new model will start from random initialization (except layers possibly initialized by default).")
# Make sure the model sits on the target device after the state dict was loaded
current_model.to(device)
print(f"\nCurrent Period Model Structure: \n{current_model}\n")
#-------------------------------------------------------------------------
# # --- Transfer Compatible Weights from Previous Period Model to Current Period Model ---
# current_model_dict = current_model.state_dict()
# # Filter the loaded state dict: Keep only layers that exist in the current model AND have matching shapes
# # Update only parameters with matching shapes (skip final fc if dimensions differ)
# prev_model_dict = {
#     k: v for k, v in prev_model_dict.items() 
#     if k in current_model_dict and v.size() == current_model_dict[k].size()
# }
# current_model_dict.update(prev_model_dict)
# current_model.load_state_dict(current_model_dict)
# # Ensure the model is on the correct device *after* loading state dict
# current_model.to(device)
# print(f"\nCurrent Period Model Structure: \n{current_model}\n")
#-------------------------------------------------------------------------
# --- Build the EWC object from the previous period's state ---
# Fisher values and reference parameters both come from the loaded checkpoint.
print("Instantiating EWC object using loaded state...")
ewc_object = EWC(params=ewc_params_dict, fisher=ewc_fisher_dict)
# NOTE(review): the original note says these tensors live on CPU and that
# EWC.penalty moves them to the model's device; confirm against the EWC class
# and the map_location used when the checkpoint was loaded.
#-------------------------------------------------------------------------
# Loss, optimizer and (optional) learning-rate scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    current_model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
if use_scheduler:
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10)
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    scheduler_name = type(scheduler).__name__
    print(f"Using {scheduler_name} scheduler.\n")
else:
    scheduler = None

# raise Exception("Stop here!") # <<<-----------------------------------------

# Run EWC-regularised training for this period.
train_and_validate_ewc(
    model=current_model, deepcopy=deepcopy, output_size=output_size,
    criterion=criterion, optimizer=optimizer,
    X_train=X_train, y_train=y_train,
    X_val=X_val, y_val=y_val,
    scheduler=scheduler, use_scheduler=use_scheduler,
    ewc=ewc_object,          # EWC object built from the previous period
    lambda_ewc=lambda_ewc,   # EWC strength
    num_epochs=num_epochs, batch_size=batch_size,
    model_saving_folder=model_saving_folder,
    model_name=model_name,   # base name for the files saved in this period
    stop_signal_file=stop_signal_file,
)

#----------------------------------------------------------------------
# Keep only the best result of this period (per the original note it sits at
# index 0 of best_results); the next period locates its checkpoint through it.
track_across_runs.append(best_results[0])
#----------------------------------------------------------------------

# One summary line per retained result
for res in best_results:
    print(", ".join((
        f"Epoch {res['epoch']}/{num_epochs}",
        f"Train Loss: {res['train_loss']:.4f}",
        f"Val Loss: {res['val_loss']:.4f}",
        f"Val Accuracy: {res['val_accuracy'] * 100:.2f}%",
        f"Model Path: {res['model_path']}",
    )))
print(f"\nclass_gru_model with ewc (current_model): \n{current_model}")

print(f"\nunique_classes = {unique_classes}")
print(f"num_classes = {num_classes}")

# Release the large per-period objects now that training has finished.
# X_val, y_val, Number_features and num_classes are deliberately NOT removed:
# the checkpoint-evaluation cell that follows reads them, and the next
# period's data-loading cell overwrites them anyway.
for var in [
    "X_train", "y_train", "X_test", "y_test",
    "unique_classes",
    "current_model", "ewc_object",
    "prev_checkpoint", "prev_model_dict", "prev_checkpoint_path",
    "deepcopy", "filtered_prev_state_dict", "current_model_dict",
    "ewc_fisher_dict", "ewc_params_dict"
]:
    # pop() with a default tolerates names that were never defined and avoids
    # mutating the mapping returned by locals().
    globals().pop(var, None)
# --- Force garbage collection ---
gc.collect()
# Only touch the CUDA allocator when a GPU is actually present.
if torch.cuda.is_available():
    torch.cuda.empty_cache()


Loaded previous period checkpoint from: 
	Class_Incremental_CL\Classif_Bi_Dir_GRU_Model\Trained_models\EWC_CIL\Period_2\1st_try\BiGRUWithAttentionEWC_epoch_36.pth

--- Inspecting Loaded Fisher Values (All Samples) ---
Total Fisher Sum: 5.0296e+02
Overall Max Fisher Abs Value: 1.2431e+02
Overall Min Positive Fisher Value: 1.7086e-20
Number of layers with all-zero Fisher: 0 / 36
--- Finished Inspecting Loaded Fisher Values (All Samples) ---

Transferring compatible weights to the new model...
  Weights loaded. Keys missing in loaded dict (expected: fc layer): ['fc.weight', 'fc.bias']
  Keys in loaded dict but not in model (should be empty): []

Current Period Model Structure: 
BiGRUWithAttention(
  (gru): GRU(7, 64, num_layers=4, batch_first=True, bidirectional=True)
  (attention_fc): Linear(in_features=128, out_features=128, bias=True)
  (fc): Linear(in_features=128, out_features=4, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)

Instantiating EWC object using loaded state...
'

In [30]:
# ----------------------------
# Backfill val_accuracy and Val-Class-Acc for Period 3
# ----------------------------
# Requires Number_features, num_classes, X_val and y_val from the
# data-loading cell of the same period.
import glob  # used below; not listed in the notebook's import cell

# Reconstruct model with same config as the one used for training
model_template = BiGRUWithAttention(
    input_size=Number_features,
    hidden_size=64,
    output_size=num_classes,
    num_layers=4,
    dropout=0.0
)

# Prepare validation data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(device)
criterion = nn.CrossEntropyLoss()

# Folder holding the checkpoints of this period
model_folder = 'Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_3/1st_try'
# NOTE: sorted() orders the paths lexicographically, not by epoch number.
ckpt_paths = sorted(glob.glob(os.path.join(model_folder, "*.pth")))

# Loop through and update each model
for file_path in ckpt_paths:
    # Fresh copy per checkpoint so no weights carry over between iterations
    model_copy = copy.deepcopy(model_template).to(device)
    # NOTE(review): torch.load is called without `weights_only`, which is what
    # triggers the warning shown in this cell's output.
    checkpoint = torch.load(file_path, map_location=device)
    model_copy.load_state_dict(checkpoint['model_state_dict'])

    # Judging by the cell output ("Updated ... Saved to ..."), this helper
    # rewrites the checkpoint file with the computed accuracies -- TODO confirm.
    add_val_class_acc(
        model_copy=model_copy,
        output_size=num_classes,
        criterion=criterion,
        X_val=X_val_tensor,
        y_val=y_val_tensor,
        file_path=file_path,
        model_checkpoint=checkpoint,
        batch_size=64
    )

# ----------------------------
# Print val_accuracy and Val-Class-Acc of all models
# ----------------------------
show_val_accuracy_from_checkpoints(model_folder)


  checkpoint = torch.load(file_path, map_location=device)
  X_val = torch.tensor(X_val, dtype=torch.float32).to(device)
  y_val = torch.tensor(y_val, dtype=torch.long).to(device)


✅ Updated: BiGRUWithAttentionEWC_epoch_1852.pth | Acc: 0.9731 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_3/1st_try\BiGRUWithAttentionEWC_epoch_1852.pth
✅ Updated: BiGRUWithAttentionEWC_epoch_1854.pth | Acc: 0.9727 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_3/1st_try\BiGRUWithAttentionEWC_epoch_1854.pth
✅ Updated: BiGRUWithAttentionEWC_epoch_245.pth | Acc: 0.9729 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_3/1st_try\BiGRUWithAttentionEWC_epoch_245.pth
✅ Updated: BiGRUWithAttentionEWC_epoch_372.pth | Acc: 0.9729 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_3/1st_try\BiGRUWithAttentionEWC_epoch_372.pth
✅ Updated: BiGRUWithAttentionEWC_epoch_422.pth | Acc: 0.9732 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_3/1st_try\BiGRUWithAttentionEWC_epoch_422.pth
✅ Updated: BiGRUWithA

  checkpoint = torch.load(path, map_location="cpu")


### Period 4 --> Training and saving in __*'1st_try'*__ (BiGRUWithAttention, lambda_ewc= 40, learning_rate= 0.0001) ---> Val acc = 97.44 %
### Val-Class-Acc: {0: '95.14%', 1: '98.82%', 2: '91.17%', 3: '98.53%', 4: '99.30%'}

In [31]:
"""
- 'trend': Categorized trend values based on the detected phases:
    - 0: No trend
    - 1: Moderate negative trend
    - 2: Very strong negative trend
    - 3: Moderate positive trend
    - 4: Very strong positive trend
"""
# Load the Period 4 splits (index 3 of list_period_files_full_path) while
# discarding whatever process_and_return_splits prints. The devnull handle is
# opened in the same `with` so it is closed again when the block exits.
with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
    X_train, y_train, X_val, y_val, X_test, y_test, Number_features = process_and_return_splits(
        with_indicators_file_path = list_period_files_full_path[3], # Change 
        downsampled_data_minutes = downsampled_data_minutes,
        exclude_columns = exclude_columns,
        lower_threshold = lower_threshold,
        upper_threshold = upper_threshold,
        reverse_steps = reverse_steps,
        sequence_length = sequence_length,
        sliding_interval = sliding_interval,
        trends_to_keep = {0, 1, 2, 3, 4}  # Default keeps all trends : {0, 1, 2, 3, 4}
        # trends_to_keep = {0, 1, 2, 3, 4}  # Default keeps all trends
    )

print(f"\nNumber_features = {Number_features}")

# NOTE(review): the class count is derived from the validation labels only;
# a class missing from y_val would not be counted -- confirm this cannot happen.
unique_classes = np.unique(y_val)
num_classes = len(unique_classes)
print(f"unique_classes = {unique_classes}")
print(f"num_classes = {num_classes}\n")

print_class_distribution(y_train, "y_train")
print_class_distribution(y_val, "y_val")
print_class_distribution(y_test, "y_test")



Number_features = 7
unique_classes = [0 1 2 3 4]
num_classes = 5

Class Distribution for 'y_train':       Class 0   Percent:   6.47% || Class 1   Percent:  39.17% || Class 2   Percent:   7.24% || Class 3   Percent:  41.06% || Class 4   Percent:   6.06%
Class Distribution for 'y_val':         Class 0   Percent:   5.35% || Class 1   Percent:  36.15% || Class 2   Percent:  14.78% || Class 3   Percent:  34.02% || Class 4   Percent:   9.70%
Class Distribution for 'y_test':        Class 0   Percent:   6.63% || Class 1   Percent:  40.53% || Class 2   Percent:   5.48% || Class 3   Percent:  41.29% || Class 4   Percent:   6.08%


In [83]:
# ---- Architecture ---- #
input_size = Number_features   # number of input features
hidden_size = 64               # GRU hidden size
output_size = num_classes      # follows the classes present in this period (up to 5 trend classes)
num_layers = 4                 # stacked GRU layers
dropout = 0.0
# ---- Optimisation ---- #
learning_rate = 0.0001         # optimizer learning rate (alternative noted by the author: 0.00005)
weight_decay = 1e-5            # optimizer weight decay
use_scheduler = False
lambda_ewc = 40                # EWC strength (alternative noted by the author: 0.4)
# ---- Training loop ---- #
num_epochs = 2000              # passes over the training data
batch_size = 64                # sequences per batch
model_name = 'BiGRUWithAttentionEWC'  # base name used when saving checkpoints
# Not passed to train_and_validate_ewc explicitly; read after training below,
# so it is presumably filled by that function as a global -- TODO confirm.
best_results = []
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# ---- Paths ---- #
# Creating this file is the signal to stop training.
stop_signal_file = os.path.normpath(
    os.path.join('Class_Incremental_CL', 'Classif_Bi_Dir_GRU_Model/stop_training.txt')
)
model_saving_folder = os.path.normpath(
    os.path.join('Class_Incremental_CL', "Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_4/1st_try")
)
ensure_folder(model_saving_folder)

# Instantiate the current period model; compatible weights of the previous
# period's best model are transferred into it further down in this cell.
current_model = BiGRUWithAttention(input_size, hidden_size, output_size, num_layers, dropout).to(device)
# Independent deep copy, taken BEFORE any previous-period weights are loaded.
# It is handed to train_and_validate_ewc below (per the original note: used
# for the Fisher matrix computation of the current period's model).
deepcopy = copy.deepcopy(current_model)
#-------------------------------------------------------------------------
# `track_across_runs` must still contain the previous period's result(s).
if not track_across_runs:
    # Same exception type max() would raise on an empty list, but with a hint.
    raise ValueError(
        "track_across_runs is empty: run the previous period's training cell "
        "first so that its best checkpoint can be located."
    )
best_overall = max(track_across_runs, key=lambda res: res['val_accuracy'])
best_model_path = best_overall['model_path']
del best_overall
#-------------------------------------------------------------------------
prev_checkpoint_path = os.path.normpath(best_model_path)
# map_location=device places every tensor of the checkpoint on `device`.
prev_checkpoint = torch.load(prev_checkpoint_path, map_location=device, weights_only=True)
required_keys = ['model_state_dict', 'ewc_fisher', 'ewc_params']
for key in required_keys:
    if key not in prev_checkpoint:
        # Report the file path, not the (huge) checkpoint contents.
        raise KeyError(f"Checkpoint {prev_checkpoint_path} is missing required key: '{key}'")
prev_model_dict = prev_checkpoint['model_state_dict']
ewc_fisher_dict = prev_checkpoint['ewc_fisher']  # Fisher information saved with the previous model
ewc_params_dict = prev_checkpoint['ewc_params']  # parameter snapshot saved with the previous model
print(f"Loaded previous period checkpoint from: \n\t{prev_checkpoint_path}")
# print(f"\nprev_checkpoint: \n{prev_checkpoint}\n")
#----------------------------------------------------------------------
# Start a fresh result list for this period. Done only after the previous
# checkpoint was loaded and validated, so a failure above leaves the list
# intact and the cell can simply be re-run.
track_across_runs = []
#-------------------------------------------------------------------------

#----------------------------------------<<<<<<<<<<<<<<<
# Print summary statistics of the Fisher values loaded from the checkpoint.
inspect_fisher_info(ewc_fisher_dict, label="Loaded Fisher Values (All Samples)")
#----------------------------------------<<<<<<<<<<<<<<<

#-------------------------------------------------------------------------
# --- Carry over every previous-period tensor that still fits the new model ---
print("Transferring compatible weights to the new model...")
current_model_dict = current_model.state_dict()
# A tensor is transferable when its key exists in the new model AND the shapes
# agree; a final fc layer with a different number of classes is skipped.
filtered_prev_state_dict = dict(
    (name, tensor)
    for name, tensor in prev_model_dict.items()
    if name in current_model_dict and tensor.size() == current_model_dict[name].size()
)
if filtered_prev_state_dict:
    # strict=False accepts a partial state dict; layers that were filtered out
    # keep their fresh initialization and are reported as missing keys.
    missing_keys, unexpected_keys = current_model.load_state_dict(filtered_prev_state_dict, strict=False)
    print(f"  Weights loaded. Keys missing in loaded dict (expected: fc layer): {missing_keys}")
    print(f"  Keys in loaded dict but not in model (should be empty): {unexpected_keys}")
else:
    print("Warning: No compatible weights found to transfer. The new model will start from random initialization (except layers possibly initialized by default).")
# Make sure the model sits on the target device after the state dict was loaded
current_model.to(device)
print(f"\nCurrent Period Model Structure: \n{current_model}\n")
#-------------------------------------------------------------------------
# # --- Transfer Compatible Weights from Previous Period Model to Current Period Model ---
# current_model_dict = current_model.state_dict()
# # Filter the loaded state dict: Keep only layers that exist in the current model AND have matching shapes
# # Update only parameters with matching shapes (skip final fc if dimensions differ)
# prev_model_dict = {
#     k: v for k, v in prev_model_dict.items() 
#     if k in current_model_dict and v.size() == current_model_dict[k].size()
# }
# current_model_dict.update(prev_model_dict)
# current_model.load_state_dict(current_model_dict)
# # Ensure the model is on the correct device *after* loading state dict
# current_model.to(device)
# print(f"\nCurrent Period Model Structure: \n{current_model}\n")
#-------------------------------------------------------------------------
# --- Build the EWC object from the previous period's state ---
# Fisher values and reference parameters both come from the loaded checkpoint.
print("Instantiating EWC object using loaded state...")
ewc_object = EWC(params=ewc_params_dict, fisher=ewc_fisher_dict)
# NOTE(review): the original note says these tensors live on CPU and that
# EWC.penalty moves them to the model's device; confirm against the EWC class
# and the map_location used when the checkpoint was loaded.
#-------------------------------------------------------------------------
# Loss, optimizer and (optional) learning-rate scheduler
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    current_model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
if use_scheduler:
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10)
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    scheduler_name = type(scheduler).__name__
    print(f"Using {scheduler_name} scheduler.\n")
else:
    scheduler = None

# raise Exception("Stop here!") # <<<-----------------------------------------

# Run EWC-regularised training for this period.
train_and_validate_ewc(
    model=current_model, deepcopy=deepcopy, output_size=output_size,
    criterion=criterion, optimizer=optimizer,
    X_train=X_train, y_train=y_train,
    X_val=X_val, y_val=y_val,
    scheduler=scheduler, use_scheduler=use_scheduler,
    ewc=ewc_object,          # EWC object built from the previous period
    lambda_ewc=lambda_ewc,   # EWC strength
    num_epochs=num_epochs, batch_size=batch_size,
    model_saving_folder=model_saving_folder,
    model_name=model_name,   # base name for the files saved in this period
    stop_signal_file=stop_signal_file,
)

#----------------------------------------------------------------------
# Keep only the best result of this period (per the original note it sits at
# index 0 of best_results); the next period locates its checkpoint through it.
track_across_runs.append(best_results[0])
#----------------------------------------------------------------------

# One summary line per retained result
for res in best_results:
    print(", ".join((
        f"Epoch {res['epoch']}/{num_epochs}",
        f"Train Loss: {res['train_loss']:.4f}",
        f"Val Loss: {res['val_loss']:.4f}",
        f"Val Accuracy: {res['val_accuracy'] * 100:.2f}%",
        f"Model Path: {res['model_path']}",
    )))
print(f"\nclass_gru_model with ewc (current_model): \n{current_model}")

print(f"\nunique_classes = {unique_classes}")
print(f"num_classes = {num_classes}")

# Release the large per-period objects now that training has finished.
# X_val, y_val, Number_features and num_classes are deliberately NOT removed:
# the checkpoint-evaluation cell that follows reads them, and a later
# data-loading cell overwrites them anyway.
for var in [
    "X_train", "y_train", "X_test", "y_test",
    "unique_classes",
    "current_model", "ewc_object",
    "prev_checkpoint", "prev_model_dict", "prev_checkpoint_path",
    "deepcopy", "filtered_prev_state_dict", "current_model_dict",
    "ewc_fisher_dict", "ewc_params_dict"
]:
    # pop() with a default tolerates names that were never defined and avoids
    # mutating the mapping returned by locals().
    globals().pop(var, None)
# --- Force garbage collection ---
gc.collect()
# Only touch the CUDA allocator when a GPU is actually present.
if torch.cuda.is_available():
    torch.cuda.empty_cache()


Loaded previous period checkpoint from: 
	Class_Incremental_CL\Classif_Bi_Dir_GRU_Model\Trained_models\EWC_CIL\Period_3\1st_try\BiGRUWithAttentionEWC_epoch_422.pth

--- Inspecting Loaded Fisher Values (All Samples) ---
Total Fisher Sum: 1.8979e+02
Overall Max Fisher Abs Value: 4.4269e+01
Overall Min Positive Fisher Value: 5.1933e-20
Number of layers with all-zero Fisher: 0 / 36
--- Finished Inspecting Loaded Fisher Values (All Samples) ---

Transferring compatible weights to the new model...
  Weights loaded. Keys missing in loaded dict (expected: fc layer): ['fc.weight', 'fc.bias']
  Keys in loaded dict but not in model (should be empty): []

Current Period Model Structure: 
BiGRUWithAttention(
  (gru): GRU(7, 64, num_layers=4, batch_first=True, bidirectional=True)
  (attention_fc): Linear(in_features=128, out_features=128, bias=True)
  (fc): Linear(in_features=128, out_features=5, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)

Instantiating EWC object using loaded state...


In [32]:
# ----------------------------
# Backfill val_accuracy and Val-Class-Acc for Period 4
# ----------------------------
# Requires Number_features, num_classes, X_val and y_val from the
# data-loading cell of the same period.
import glob  # used below; not listed in the notebook's import cell

# Reconstruct model with same config as the one used for training
model_template = BiGRUWithAttention(
    input_size=Number_features,
    hidden_size=64,
    output_size=num_classes,
    num_layers=4,
    dropout=0.0
)

# Prepare validation data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
X_val_tensor = torch.tensor(X_val, dtype=torch.float32).to(device)
y_val_tensor = torch.tensor(y_val, dtype=torch.long).to(device)
criterion = nn.CrossEntropyLoss()

# Folder holding the checkpoints of this period
model_folder = 'Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_4/1st_try'
# NOTE: sorted() orders the paths lexicographically, not by epoch number.
ckpt_paths = sorted(glob.glob(os.path.join(model_folder, "*.pth")))

# Loop through and update each model
for file_path in ckpt_paths:
    # Fresh copy per checkpoint so no weights carry over between iterations
    model_copy = copy.deepcopy(model_template).to(device)
    # NOTE(review): torch.load is called without `weights_only`, which is what
    # triggers the warning shown in this cell's output.
    checkpoint = torch.load(file_path, map_location=device)
    model_copy.load_state_dict(checkpoint['model_state_dict'])

    # Judging by the cell output ("Updated ... Saved to ..."), this helper
    # rewrites the checkpoint file with the computed accuracies -- TODO confirm.
    add_val_class_acc(
        model_copy=model_copy,
        output_size=num_classes,
        criterion=criterion,
        X_val=X_val_tensor,
        y_val=y_val_tensor,
        file_path=file_path,
        model_checkpoint=checkpoint,
        batch_size=64
    )

# ----------------------------
# Print val_accuracy and Val-Class-Acc of all models
# ----------------------------
show_val_accuracy_from_checkpoints(model_folder)


  checkpoint = torch.load(file_path, map_location=device)
  X_val = torch.tensor(X_val, dtype=torch.float32).to(device)
  y_val = torch.tensor(y_val, dtype=torch.long).to(device)


✅ Updated: BiGRUWithAttentionEWC_epoch_409.pth | Acc: 0.9744 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_4/1st_try\BiGRUWithAttentionEWC_epoch_409.pth
✅ Updated: BiGRUWithAttentionEWC_epoch_417.pth | Acc: 0.9744 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_4/1st_try\BiGRUWithAttentionEWC_epoch_417.pth
✅ Updated: BiGRUWithAttentionEWC_epoch_495.pth | Acc: 0.9744 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_4/1st_try\BiGRUWithAttentionEWC_epoch_495.pth
✅ Updated: BiGRUWithAttentionEWC_epoch_496.pth | Acc: 0.9744 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_4/1st_try\BiGRUWithAttentionEWC_epoch_496.pth
✅ Updated: BiGRUWithAttentionEWC_epoch_500.pth | Acc: 0.9744 | Saved to: Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_4/1st_try\BiGRUWithAttentionEWC_epoch_500.pth
✅ Updated: BiGRUWithAtten

  checkpoint = torch.load(path, map_location="cpu")


##  __Compute FWT & Model Size__

### Period 2

In [None]:
"""
- 'trend': Categorized trend values based on the detected phases:
    - 0: No trend
    - 1: Moderate negative trend
    - 2: Very strong negative trend
    - 3: Moderate positive trend
    - 4: Very strong positive trend
"""
# Load the Period 2 splits (index 1 of list_period_files_full_path) while
# discarding whatever process_and_return_splits prints. The devnull handle is
# opened in the same `with` so it is closed again when the block exits.
with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
    X_train, y_train, X_val, y_val, X_test, y_test, Number_features = process_and_return_splits(
        with_indicators_file_path = list_period_files_full_path[1], # Change
        downsampled_data_minutes = downsampled_data_minutes,
        exclude_columns = exclude_columns,
        lower_threshold = lower_threshold,
        upper_threshold = upper_threshold,
        reverse_steps = reverse_steps,
        sequence_length = sequence_length,
        sliding_interval = sliding_interval,
        trends_to_keep = {0, 1, 2}  # Default keeps all trends : {0, 1, 2, 3, 4}
        # trends_to_keep = {0, 1, 2, 3, 4}  # Default keeps all trends
    )

print(f"\nNumber_features = {Number_features}")

# NOTE(review): the class count is derived from the validation labels only;
# a class missing from y_val would not be counted -- confirm this cannot happen.
unique_classes = np.unique(y_val)
num_classes = len(unique_classes)
print(f"unique_classes = {unique_classes}")
print(f"num_classes = {num_classes}\n")

print_class_distribution(y_train, "y_train")
print_class_distribution(y_val, "y_val")
print_class_distribution(y_test, "y_test")



Number_features = 7
unique_classes = [0 1 2]
num_classes = 3

Class Distribution for 'y_train':       Class 0   Percent:  53.41% || Class 1   Percent:  36.03% || Class 2   Percent:  10.55%
Class Distribution for 'y_val':         Class 0   Percent:  52.72% || Class 1   Percent:  36.38% || Class 2   Percent:  10.90%
Class Distribution for 'y_test':        Class 0   Percent:  49.11% || Class 1   Percent:  32.27% || Class 2   Percent:  18.62%


In [None]:
# === EWC Period 2 - FWT analysis ===
# Compares the Period 1 model with a freshly initialised model of the same
# architecture on the Period 2 validation split.
input_size = Number_features
hidden_size = 64
num_layers = 4
dropout = 0.0
output_size_prev = 2  # number of classes in Period 1
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# === Validation data (same processing as the main flow) ===
# Tensors are created without an explicit device here.
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.long)

# === Load the Period 1 EWC model (checkpoint that also contains Fisher) ===
ewc_sample_tag = "All"
period_1_epoch = 1987
model_name = 'BiGRUWithAttentionEWC'
best_model_path = os.path.normpath(os.path.join(
    'Class_Incremental_CL',
    f"Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_1_weights_with_ewc/EWC_with_{ewc_sample_tag}_samples",
    f"{model_name}_epoch_{period_1_epoch}_ewc_{ewc_sample_tag}.pth",
))
checkpoint = torch.load(best_model_path, map_location=device)

# === previous_model (only model_state_dict of the checkpoint is used) ===
previous_model = BiGRUWithAttention(input_size, hidden_size, output_size_prev, num_layers, dropout).to(device)
previous_model.load_state_dict(checkpoint['model_state_dict'])
previous_model.output_size = output_size_prev
# The checkpoint is no longer needed once the weights are loaded.
del checkpoint
gc.collect()

# === Initial model (no pre-training at all) ===
# No random seed is set in this cell, so acc_init can vary between runs.
init_model = BiGRUWithAttention(input_size, hidden_size, output_size_prev, num_layers, dropout).to(device)
init_model.output_size = output_size_prev

# === Compute FWT ===
# NOTE(review): the original comment states that known_classes is [0, 1] for
# Bitcoin, but the code evaluates class 1 only -- confirm which is intended.
known_classes = [1]
fwt, acc_prev, acc_init = compute_fwt_fixed_verbose(
    previous_model,
    init_model,
    X_val_tensor,
    y_val_tensor,
    known_classes,
)

print(f"\n### Period 2:")
print(f"- FWT (EWC Period 2, old classes {known_classes}): {fwt * 100:.2f}%")
print(f"- Accuracy by previous model: {acc_prev * 100:.2f}%")
print(f"- Accuracy by init model:     {acc_init * 100:.2f}%")


  checkpoint = torch.load(best_model_path, map_location=device)


📋 Total matching sequences for known classes [1]: 454

### 🔍 FWT Debug Info:
- Total evaluated tokens: 165167
- Correct (PrevModel): 158370 / 165167 → Acc = 0.9588
- Correct (InitModel): 83192 / 165167 → Acc = 0.5037
- FWT = Acc_prev - Acc_init = 0.4552

### Period 2:
- FWT (EWC Period 2, old classes [1]): 45.52%
- Accuracy by previous model: 95.88%
- Accuracy by init model:     50.37%


In [None]:
# --- Period 2 run configuration (EWC, class-incremental) ---

# Network shape
input_size = Number_features   # number of input features
hidden_size = 64               # number of GRU units
output_size = num_classes      # one output per trend class; follows the data, up to 5
                               # (original note lists the trends as 0, 15, 25, -15, -25)
num_layers = 4                 # number of GRU layers
dropout = 0.0

# Optimisation settings
learning_rate = 1e-4           # optimizer step size (author's noted alternative: 0.00005)
weight_decay = 1e-5            # optimizer weight decay
use_scheduler = False
lambda_ewc = 40                # EWC strength (author's noted alternative: 0.4)

# Training loop settings
num_epochs = 1                 # passes over the whole training set
batch_size = 64                # sequences fed to the model at once
model_name = 'BiGRUWithAttentionEWC'   # base name used when saving models
best_results = []              # must exist before training starts (kept outside the training function)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Filesystem locations
# Creating the stop-signal file is the manual way to stop training.
stop_signal_file = os.path.normpath(
    os.path.join('Class_Incremental_CL', 'Classif_Bi_Dir_GRU_Model/stop_training.txt')
)
model_saving_folder = os.path.normpath(
    os.path.join('Class_Incremental_CL', "Classif_Bi_Dir_GRU_Model/TT/EWC_CIL/Period_2/1st_try")
)
ensure_folder(model_saving_folder)

# Instantiate the current period model. It is randomly initialised here; compatible
# weights from the previous period are transferred further down in this cell.
current_model = BiGRUWithAttention(input_size, hidden_size, output_size, num_layers, dropout).to(device)
# Independent deep copy passed to train_and_validate_ewc; per the original note it
# is used for the Fisher matrix computation of the current period's model.
# NOTE(review): the copy is taken before the previous-period weights are transferred,
# so it holds the untrained initial weights — confirm the training function expects that.
deepcopy = copy.deepcopy(current_model)
#-------------------------------------------------------------------------
period_1_epoch = 1987 # Epoch of the chosen Period 1 model
ewc_sample_tag = "All" # Options: "All" or "100" (must match the saved file)
best_model_dir = os.path.normpath(os.path.join('Class_Incremental_CL', f"Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_1_weights_with_ewc/EWC_with_{ewc_sample_tag}_samples"))
best_model_path = os.path.normpath(os.path.join(best_model_dir, f"{model_name}_epoch_{period_1_epoch}_ewc_{ewc_sample_tag}.pth"))
#-------- OR ---------
# best_model_path = r"Class_Incremental_CL\Classif_Bi_Dir_GRU_Model\Trained_models\Baseline\Period_1\1st_try\BiGRUWithAttention_epoch_1987.pth"
#----------------------------------------------------------------------
# Initialize the list to store results across runs
track_across_runs = []
#----------------------------------------------------------------------
# Load the previous period's checkpoint and verify it carries the EWC state.
prev_checkpoint_path = os.path.normpath(best_model_path)
prev_checkpoint = torch.load(prev_checkpoint_path, map_location=device, weights_only=True)
required_keys = ['model_state_dict', 'ewc_fisher', 'ewc_params']
for key in required_keys:
    if key not in prev_checkpoint:
        # Report the file path; formatting the checkpoint itself would dump every tensor.
        raise KeyError(f"Checkpoint {prev_checkpoint_path} is missing required key: '{key}'")
prev_model_dict = prev_checkpoint['model_state_dict']
# map_location=device above places the loaded tensors on `device`
# (so they are on the GPU when CUDA is available, not on the CPU).
ewc_fisher_dict = prev_checkpoint['ewc_fisher']
ewc_params_dict = prev_checkpoint['ewc_params']
print(f"Loaded previous period checkpoint from: \n\t{prev_checkpoint_path}")
# print(f"\nprev_checkpoint: \n{prev_checkpoint}\n")
#-------------------------------------------------------------------------

#-------------------------------------------------------------------------
# --- Inspect the loaded Fisher values before they are used for the EWC penalty ---
inspect_fisher_info(ewc_fisher_dict, label=f"Loaded Fisher Values ({ewc_sample_tag} Samples)")

#-------------------------------------------------------------------------
# --- Transfer Compatible Weights from Previous Period Model to Current Period Model ---
print("Transferring compatible weights to the new model...")
current_model_dict = current_model.state_dict()
# Keep only entries whose key exists in the current model AND whose shape matches;
# this skips the final fc layer when its dimensions differ between periods.
filtered_prev_state_dict = {
    k: v for k, v in prev_model_dict.items()
    if k in current_model_dict and v.size() == current_model_dict[k].size()
}
if not filtered_prev_state_dict:
    print("Warning: No compatible weights found to transfer. The new model will start from random initialization (except layers possibly initialized by default).")
else:
    # strict=False allows loading a partial state dict; keys absent from the
    # filtered dict (such as the final fc layer) are reported in missing_keys.
    missing_keys, unexpected_keys = current_model.load_state_dict(filtered_prev_state_dict, strict=False)
    print(f"  Weights loaded. Keys missing in loaded dict (expected: fc layer): {missing_keys}")
    print(f"  Keys in loaded dict but not in model (should be empty): {unexpected_keys}")
# Ensure the model is on the correct device *after* loading state dict
current_model.to(device)
print(f"\nCurrent Period Model Structure: \n{current_model}\n")
#-------------------------------------------------------------------------
# --- EWC state for this period ---
# Built from the Fisher values and reference parameters stored in the previous
# period's checkpoint.
# NOTE(review): that checkpoint is loaded with map_location=device, so these tensors
# live on `device`; how EWC.penalty handles device placement is not visible here.
print("Instantiating EWC object using loaded state...")
ewc_object = EWC(params=ewc_params_dict, fisher=ewc_fisher_dict)

# --- Loss, optimizer and (optional) learning-rate scheduler ---
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    current_model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
if use_scheduler:
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10)
    scheduler_name = type(scheduler).__name__
    print(f"Using {scheduler_name} scheduler.\n")
else:
    scheduler = None

print_model_info(current_model)

# --- Train and validate with the EWC penalty ---
train_and_validate_ewc(
    model=current_model,
    deepcopy=deepcopy,
    output_size=output_size,
    X_train=X_train, y_train=y_train,
    X_val=X_val, y_val=y_val,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    use_scheduler=use_scheduler,
    ewc=ewc_object,          # EWC object built above
    lambda_ewc=lambda_ewc,   # EWC strength
    num_epochs=num_epochs,
    batch_size=batch_size,
    model_name=model_name,   # base name for files saved in this period
    model_saving_folder=model_saving_folder,
    stop_signal_file=stop_signal_file,
)


#----------------------------------------------------------------------
# Keep only this run's best result (the original note says it sits at index 0).
# An empty list (e.g. training stopped before anything was recorded) is reported
# instead of raising IndexError, so the memory cleanup below still runs.
if best_results:
    track_across_runs.append(best_results[0])
else:
    print("Warning: best_results is empty; nothing was appended to track_across_runs.")
#----------------------------------------------------------------------

for res in best_results:
    print(f"Epoch {res['epoch']}/{num_epochs}, "
            f"Train Loss: {res['train_loss']:.4f}, "
            f"Val Loss: {res['val_loss']:.4f}, "
            f"Val Accuracy: {res['val_accuracy'] * 100:.2f}%, "
            f"Model Path: {res['model_path']}")
print(f"\nclass_gru_model with ewc (current_model): \n{current_model}")

print(f"\nunique_classes = {unique_classes}")
print(f"num_classes = {num_classes}")

# Release this period's data, model and checkpoint objects from the notebook
# namespace. globals().pop is used because the mapping returned by locals() is
# not meant to be modified; names that are already absent are ignored.
for var in [
    "X_train", "y_train", "X_val", "y_val", "X_test", "y_test",
    "Number_features", "unique_classes", "num_classes",
    "current_model", "ewc_object",
    "prev_checkpoint", "prev_model_dict", "prev_checkpoint_path",
    "deepcopy", "filtered_prev_state_dict", "current_model_dict",
    "ewc_fisher_dict", "ewc_params_dict"
]:
    globals().pop(var, None)
# --- Force garbage collection ---
gc.collect()
torch.cuda.empty_cache()


Loaded previous period checkpoint from: 
	Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_1_weights_with_ewc/EWC_with_All_samples/BiGRUWithAttentionEWC_epoch_1987_ewc_All.pth

--- Inspecting Loaded Fisher Values (All Samples) ---
Total Fisher Sum: 1.2840e+02
Overall Max Fisher Abs Value: 3.1890e+01
Overall Min Positive Fisher Value: 6.4935e-23
Number of layers with all-zero Fisher: 0 / 36
--- Finished Inspecting Loaded Fisher Values (All Samples) ---

Transferring compatible weights to the new model...
  Weights loaded. Keys missing in loaded dict (expected: fc layer): ['fc.weight', 'fc.bias']
  Keys in loaded dict but not in model (should be empty): []

Current Period Model Structure: 
BiGRUWithAttention(
  (gru): GRU(7, 64, num_layers=4, batch_first=True, bidirectional=True)
  (attention_fc): Linear(in_features=128, out_features=128, bias=True)
  (fc): Linear(in_features=128, out_features=3, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)

Instanti

### Period 3

In [None]:
"""
- 'trend': Categorized trend values based on the detected phases:
    - 0: No trend
    - 1: Moderate negative trend
    - 2: Very strong negative trend
    - 3: Moderate positive trend
    - 4: Very strong positive trend
"""
# Load the Period 3 splits while discarding everything the preprocessing prints.
# The devnull handle is opened in the same `with` statement so it is closed on
# exit instead of being left open.
with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
    X_train, y_train, X_val, y_val, X_test, y_test, Number_features = process_and_return_splits(
        with_indicators_file_path = list_period_files_full_path[2], # Change 
        downsampled_data_minutes = downsampled_data_minutes,
        exclude_columns = exclude_columns,
        lower_threshold = lower_threshold,
        upper_threshold = upper_threshold,
        reverse_steps = reverse_steps,
        sequence_length = sequence_length,
        sliding_interval = sliding_interval,
        trends_to_keep = {0, 1, 2, 3}  # Default keeps all trends : {0, 1, 2, 3, 4}
        # trends_to_keep = {0, 1, 2, 3, 4}  # Default keeps all trends
    )

print(f"\nNumber_features = {Number_features}")

# NOTE(review): the class inventory is taken from the validation labels only;
# a class present in y_train but absent from y_val would not be counted.
unique_classes = np.unique(y_val)
num_classes = len(unique_classes)
print(f"unique_classes = {unique_classes}")
print(f"num_classes = {num_classes}\n")

print_class_distribution(y_train, "y_train")
print_class_distribution(y_val, "y_val")
print_class_distribution(y_test, "y_test")



Number_features = 7
unique_classes = [0 1 2 3]
num_classes = 4

Class Distribution for 'y_train':       Class 0   Percent:  13.79% || Class 1   Percent:  37.59% || Class 2   Percent:   8.89% || Class 3   Percent:  39.73%
Class Distribution for 'y_val':         Class 0   Percent:  10.26% || Class 1   Percent:  46.19% || Class 2   Percent:   4.21% || Class 3   Percent:  39.34%
Class Distribution for 'y_test':        Class 0   Percent:  10.30% || Class 1   Percent:  43.50% || Class 2   Percent:   8.50% || Class 3   Percent:  37.70%


In [None]:
# === EWC Period 3 - forward transfer (FWT) analysis ===
# Evaluates the Period 2 model and an untrained model of the same architecture
# on the current validation split via compute_fwt_fixed_verbose.
input_size = Number_features
hidden_size = 64
output_size_prev = 3  # Period 2: {0,1,2}
num_layers = 4
dropout = 0.0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Load the Period 2 model checkpoint (only 'model_state_dict' is used) ===
prev_checkpoint_path = "Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_2/1st_try/BiGRUWithAttentionEWC_epoch_56.pth"  # <- fill in the epoch of the chosen Period 2 model
# weights_only=True restricts unpickling to tensors/primitive containers, matching
# how the training cells load their previous-period checkpoints.
checkpoint = torch.load(prev_checkpoint_path, map_location=device, weights_only=True)

previous_model = BiGRUWithAttention(input_size, hidden_size, output_size_prev, num_layers, dropout).to(device)
previous_model.load_state_dict(checkpoint['model_state_dict'])
previous_model.output_size = output_size_prev
# Drop the checkpoint as soon as the weights are copied.
del checkpoint
gc.collect()

# === init_model: same architecture, no pretrained weights ===
init_model = BiGRUWithAttention(input_size, hidden_size, output_size_prev, num_layers, dropout).to(device)
init_model.output_size = output_size_prev

# === Prepare the validation data as tensors ===
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.long)

# NOTE(review): class 0 is not listed among the known classes although the
# Period 2 model has outputs for {0,1,2} — confirm this is intended.
known_classes = [1, 2]
fwt, acc_prev, acc_init = compute_fwt_fixed_verbose(
    previous_model, init_model,
    X_val_tensor, y_val_tensor,
    known_classes
)

print("\n### Period 3:")
print(f"- FWT (EWC Period 3, old classes {known_classes}): {fwt * 100:.2f}%")
print(f"- Accuracy by previous model: {acc_prev * 100:.2f}%")
print(f"- Accuracy by init model:     {acc_init * 100:.2f}%")


  checkpoint = torch.load(prev_checkpoint_path, map_location=device)


📋 Total matching sequences for known classes [1, 2]: 454

### 🔍 FWT Debug Info:
- Total evaluated tokens: 228825
- Correct (PrevModel): 222959 / 228825 → Acc = 0.9744
- Correct (InitModel): 74446 / 228825 → Acc = 0.3253
- FWT = Acc_prev - Acc_init = 0.6490

### Period 3:
- FWT (EWC Period 3, old classes [1, 2]): 64.90%
- Accuracy by previous model: 97.44%
- Accuracy by init model:     32.53%


In [None]:
# --- Period 3 run configuration (EWC, class-incremental) ---

# Network shape
input_size = Number_features   # number of input features
hidden_size = 64               # number of GRU units
output_size = num_classes      # one output per trend class; follows the data, up to 5
                               # (original note lists the trends as 0, 15, 25, -15, -25)
num_layers = 4                 # number of GRU layers
dropout = 0.0

# Optimisation settings
learning_rate = 1e-4           # optimizer step size (author's noted alternative: 0.00005)
weight_decay = 1e-5            # optimizer weight decay
use_scheduler = False
lambda_ewc = 40                # EWC strength (author's noted alternative: 0.4)

# Training loop settings
num_epochs = 1                 # passes over the whole training set
batch_size = 64                # sequences fed to the model at once
model_name = 'BiGRUWithAttentionEWC'   # base name used when saving models
best_results = []              # must exist before training starts (kept outside the training function)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Filesystem locations
# Creating the stop-signal file is the manual way to stop training.
stop_signal_file = os.path.normpath(
    os.path.join('Class_Incremental_CL', 'Classif_Bi_Dir_GRU_Model/stop_training.txt')
)
model_saving_folder = os.path.normpath(
    os.path.join('Class_Incremental_CL', "Classif_Bi_Dir_GRU_Model/TT/EWC_CIL/Period_3/1st_try")
)
ensure_folder(model_saving_folder)

# Instantiate the current period model. It is randomly initialised here; compatible
# weights from the previous period are transferred further down in this cell.
current_model = BiGRUWithAttention(input_size, hidden_size, output_size, num_layers, dropout).to(device)
# Independent deep copy passed to train_and_validate_ewc; per the original note it
# is used for the Fisher matrix computation of the current period's model.
# NOTE(review): the copy is taken before the previous-period weights are transferred,
# so it holds the untrained initial weights — confirm the training function expects that.
deepcopy = copy.deepcopy(current_model)
#-------------------------------------------------------------------------
# Pick the previous period's checkpoint with the highest validation accuracy.
# This depends on track_across_runs still holding the previous training cell's
# result, so fail with an explicit message when it is empty.
if not track_across_runs:
    raise ValueError("track_across_runs is empty: run the previous period's training cell first so its best checkpoint can be selected.")
best_overall = max(track_across_runs, key=lambda res: res['val_accuracy'])
best_model_path = best_overall['model_path']
del best_overall
#----------------------------------------------------------------------
# Initialize the list to store results across runs
track_across_runs = []
#----------------------------------------------------------------------
# Load the previous period's checkpoint and verify it carries the EWC state.
prev_checkpoint_path = os.path.normpath(best_model_path)
prev_checkpoint = torch.load(prev_checkpoint_path, map_location=device, weights_only=True)
required_keys = ['model_state_dict', 'ewc_fisher', 'ewc_params']
for key in required_keys:
    if key not in prev_checkpoint:
        # Report the file path; formatting the checkpoint itself would dump every tensor.
        raise KeyError(f"Checkpoint {prev_checkpoint_path} is missing required key: '{key}'")
prev_model_dict = prev_checkpoint['model_state_dict']
# map_location=device above places the loaded tensors on `device`
# (so they are on the GPU when CUDA is available, not on the CPU).
ewc_fisher_dict = prev_checkpoint['ewc_fisher']
ewc_params_dict = prev_checkpoint['ewc_params']
print(f"Loaded previous period checkpoint from: \n\t{prev_checkpoint_path}")
# print(f"\nprev_checkpoint: \n{prev_checkpoint}\n")
#-------------------------------------------------------------------------

#-------------------------------------------------------------------------
# --- Inspect the loaded Fisher values before they are used for the EWC penalty ---
inspect_fisher_info(ewc_fisher_dict, label="Loaded Fisher Values (All Samples)")

#-------------------------------------------------------------------------
# --- Transfer Compatible Weights from Previous Period Model to Current Period Model ---
print("Transferring compatible weights to the new model...")
current_model_dict = current_model.state_dict()
# Keep only entries whose key exists in the current model AND whose shape matches;
# this skips the final fc layer when its dimensions differ between periods.
filtered_prev_state_dict = {
    k: v for k, v in prev_model_dict.items()
    if k in current_model_dict and v.size() == current_model_dict[k].size()
}
if not filtered_prev_state_dict:
    print("Warning: No compatible weights found to transfer. The new model will start from random initialization (except layers possibly initialized by default).")
else:
    # strict=False allows loading a partial state dict; keys absent from the
    # filtered dict (such as the final fc layer) are reported in missing_keys.
    missing_keys, unexpected_keys = current_model.load_state_dict(filtered_prev_state_dict, strict=False)
    print(f"  Weights loaded. Keys missing in loaded dict (expected: fc layer): {missing_keys}")
    print(f"  Keys in loaded dict but not in model (should be empty): {unexpected_keys}")
# Ensure the model is on the correct device *after* loading state dict
current_model.to(device)
print(f"\nCurrent Period Model Structure: \n{current_model}\n")
#-------------------------------------------------------------------------
# --- EWC state for this period ---
# Built from the Fisher values and reference parameters stored in the previous
# period's checkpoint.
# NOTE(review): that checkpoint is loaded with map_location=device, so these tensors
# live on `device`; how EWC.penalty handles device placement is not visible here.
print("Instantiating EWC object using loaded state...")
ewc_object = EWC(params=ewc_params_dict, fisher=ewc_fisher_dict)

# --- Loss, optimizer and (optional) learning-rate scheduler ---
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    current_model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
if use_scheduler:
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10)
    scheduler_name = type(scheduler).__name__
    print(f"Using {scheduler_name} scheduler.\n")
else:
    scheduler = None

print_model_info(current_model)

# --- Train and validate with the EWC penalty ---
train_and_validate_ewc(
    model=current_model,
    deepcopy=deepcopy,
    output_size=output_size,
    X_train=X_train, y_train=y_train,
    X_val=X_val, y_val=y_val,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    use_scheduler=use_scheduler,
    ewc=ewc_object,          # EWC object built above
    lambda_ewc=lambda_ewc,   # EWC strength
    num_epochs=num_epochs,
    batch_size=batch_size,
    model_name=model_name,   # base name for files saved in this period
    model_saving_folder=model_saving_folder,
    stop_signal_file=stop_signal_file,
)


#----------------------------------------------------------------------
# Keep only this run's best result (the original note says it sits at index 0).
# An empty list (e.g. training stopped before anything was recorded) is reported
# instead of raising IndexError, so the memory cleanup below still runs.
if best_results:
    track_across_runs.append(best_results[0])
else:
    print("Warning: best_results is empty; nothing was appended to track_across_runs.")
#----------------------------------------------------------------------

for res in best_results:
    print(f"Epoch {res['epoch']}/{num_epochs}, "
            f"Train Loss: {res['train_loss']:.4f}, "
            f"Val Loss: {res['val_loss']:.4f}, "
            f"Val Accuracy: {res['val_accuracy'] * 100:.2f}%, "
            f"Model Path: {res['model_path']}")
print(f"\nclass_gru_model with ewc (current_model): \n{current_model}")

print(f"\nunique_classes = {unique_classes}")
print(f"num_classes = {num_classes}")

# Release this period's data, model and checkpoint objects from the notebook
# namespace. globals().pop is used because the mapping returned by locals() is
# not meant to be modified; names that are already absent are ignored.
for var in [
    "X_train", "y_train", "X_val", "y_val", "X_test", "y_test",
    "Number_features", "unique_classes", "num_classes",
    "current_model", "ewc_object",
    "prev_checkpoint", "prev_model_dict", "prev_checkpoint_path",
    "deepcopy", "filtered_prev_state_dict", "current_model_dict",
    "ewc_fisher_dict", "ewc_params_dict"
]:
    globals().pop(var, None)
# --- Force garbage collection ---
gc.collect()
torch.cuda.empty_cache()


### Period 4

In [None]:
"""
- 'trend': Categorized trend values based on the detected phases:
    - 0: No trend
    - 1: Moderate negative trend
    - 2: Very strong negative trend
    - 3: Moderate positive trend
    - 4: Very strong positive trend
"""
# Load the Period 4 splits while discarding everything the preprocessing prints.
# The devnull handle is opened in the same `with` statement so it is closed on
# exit instead of being left open.
with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
    X_train, y_train, X_val, y_val, X_test, y_test, Number_features = process_and_return_splits(
        with_indicators_file_path = list_period_files_full_path[3], # Change 
        downsampled_data_minutes = downsampled_data_minutes,
        exclude_columns = exclude_columns,
        lower_threshold = lower_threshold,
        upper_threshold = upper_threshold,
        reverse_steps = reverse_steps,
        sequence_length = sequence_length,
        sliding_interval = sliding_interval,
        trends_to_keep = {0, 1, 2, 3, 4}  # Default keeps all trends : {0, 1, 2, 3, 4}
        # trends_to_keep = {0, 1, 2, 3, 4}  # Default keeps all trends
    )

print(f"\nNumber_features = {Number_features}")

# NOTE(review): the class inventory is taken from the validation labels only;
# a class present in y_train but absent from y_val would not be counted.
unique_classes = np.unique(y_val)
num_classes = len(unique_classes)
print(f"unique_classes = {unique_classes}")
print(f"num_classes = {num_classes}\n")

print_class_distribution(y_train, "y_train")
print_class_distribution(y_val, "y_val")
print_class_distribution(y_test, "y_test")



Number_features = 7
unique_classes = [0 1 2 3 4]
num_classes = 5

Class Distribution for 'y_train':       Class 0   Percent:   6.47% || Class 1   Percent:  39.17% || Class 2   Percent:   7.24% || Class 3   Percent:  41.06% || Class 4   Percent:   6.06%
Class Distribution for 'y_val':         Class 0   Percent:   5.35% || Class 1   Percent:  36.15% || Class 2   Percent:  14.78% || Class 3   Percent:  34.02% || Class 4   Percent:   9.70%
Class Distribution for 'y_test':        Class 0   Percent:   6.63% || Class 1   Percent:  40.53% || Class 2   Percent:   5.48% || Class 3   Percent:  41.29% || Class 4   Percent:   6.08%


In [None]:
# === EWC Period 4 - forward transfer (FWT) analysis ===
# Evaluates the Period 3 model and an untrained model of the same architecture
# on the current validation split via compute_fwt_fixed_verbose.
input_size = Number_features
hidden_size = 64
output_size_prev = 4  # Period 3: {0,1,2,3}
num_layers = 4
dropout = 0.0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Load the Period 3 model checkpoint ===
prev_checkpoint_path = "Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_3/1st_try/BiGRUWithAttentionEWC_epoch_1854.pth"
# weights_only=True restricts unpickling to tensors/primitive containers; the
# Period 4 training cell loads this same file the same way.
checkpoint = torch.load(prev_checkpoint_path, map_location=device, weights_only=True)

previous_model = BiGRUWithAttention(input_size, hidden_size, output_size_prev, num_layers, dropout).to(device)
previous_model.load_state_dict(checkpoint['model_state_dict'])
previous_model.output_size = output_size_prev
# Drop the checkpoint as soon as the weights are copied.
del checkpoint
gc.collect()

# === init_model: same architecture, no pretrained weights ===
init_model = BiGRUWithAttention(input_size, hidden_size, output_size_prev, num_layers, dropout).to(device)
init_model.output_size = output_size_prev

# === Prepare the validation data as tensors ===
X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = torch.tensor(y_val, dtype=torch.long)

# NOTE(review): class 0 is not listed among the known classes although the
# Period 3 model has outputs for {0,1,2,3} — confirm this is intended.
known_classes = [1, 2, 3]
fwt, acc_prev, acc_init = compute_fwt_fixed_verbose(
    previous_model, init_model,
    X_val_tensor, y_val_tensor,
    known_classes
)

print("\n### Period 4:")
print(f"- FWT (EWC Period 4, old classes {known_classes}): {fwt * 100:.2f}%")
print(f"- Accuracy by previous model: {acc_prev * 100:.2f}%")
print(f"- Accuracy by init model:     {acc_init * 100:.2f}%")


  checkpoint = torch.load(prev_checkpoint_path, map_location=device)


📋 Total matching sequences for known classes [1, 2, 3]: 454

### 🔍 FWT Debug Info:
- Total evaluated tokens: 385661
- Correct (PrevModel): 374165 / 385661 → Acc = 0.9702
- Correct (InitModel): 92879 / 385661 → Acc = 0.2408
- FWT = Acc_prev - Acc_init = 0.7294

### Period 4:
- FWT (EWC Period 4, old classes [1, 2, 3]): 72.94%
- Accuracy by previous model: 97.02%
- Accuracy by init model:     24.08%


In [None]:
# --- Period 4 run configuration (EWC, class-incremental) ---

# Network shape
# NOTE(review): input_size and output_size are hard-coded here, whereas the
# Period 2/3 cells read Number_features and num_classes — confirm this is deliberate.
input_size = 7                 # number of input features
hidden_size = 64               # number of GRU units
output_size = 5                # number of trend classes (original note: must be dynamic, up to 5;
                               # trends listed as 0, 15, 25, -15, -25)
num_layers = 4                 # number of GRU layers
dropout = 0.0

# Optimisation settings
learning_rate = 1e-4           # optimizer step size (author's noted alternative: 0.00005)
weight_decay = 1e-5            # optimizer weight decay
use_scheduler = False
lambda_ewc = 40                # EWC strength (author's noted alternative: 0.4)

# Training loop settings
num_epochs = 1                 # passes over the whole training set
batch_size = 64                # sequences fed to the model at once
model_name = 'BiGRUWithAttentionEWC'   # base name used when saving models
best_results = []              # must exist before training starts (kept outside the training function)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Filesystem locations
# Creating the stop-signal file is the manual way to stop training.
stop_signal_file = os.path.normpath(
    os.path.join('Class_Incremental_CL', 'Classif_Bi_Dir_GRU_Model/stop_training.txt')
)
model_saving_folder = os.path.normpath(
    os.path.join('Class_Incremental_CL', "Classif_Bi_Dir_GRU_Model/TT/EWC_CIL/Period_4/1st_try")
)
ensure_folder(model_saving_folder)

# Instantiate the current period model. It is randomly initialised here; compatible
# weights from the previous period are transferred further down in this cell.
current_model = BiGRUWithAttention(input_size, hidden_size, output_size, num_layers, dropout).to(device)
# Independent deep copy passed to train_and_validate_ewc; per the original note it
# is used for the Fisher matrix computation of the current period's model.
# NOTE(review): the copy is taken before the previous-period weights are transferred,
# so it holds the untrained initial weights — confirm the training function expects that.
deepcopy = copy.deepcopy(current_model)
#-------------------------------------------------------------------------
# Disabled alternative: select the previous period's best run automatically
# instead of using the fixed checkpoint path below.
# best_overall = max(track_across_runs, key=lambda res: res['val_accuracy'])
# best_model_path = best_overall['model_path']
# del best_overall
#----------------------------------------------------------------------
# Initialize the list to store results across runs
track_across_runs = []
#----------------------------------------------------------------------
# Load the Period 3 checkpoint and verify it carries the EWC state.
prev_checkpoint_path = os.path.normpath(r"Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_3/1st_try/BiGRUWithAttentionEWC_epoch_1854.pth")
prev_checkpoint = torch.load(prev_checkpoint_path, map_location=device, weights_only=True)
required_keys = ['model_state_dict', 'ewc_fisher', 'ewc_params']
for key in required_keys:
    if key not in prev_checkpoint:
        # Report the file path; formatting the checkpoint itself would dump every tensor.
        raise KeyError(f"Checkpoint {prev_checkpoint_path} is missing required key: '{key}'")
prev_model_dict = prev_checkpoint['model_state_dict']
# map_location=device above places the loaded tensors on `device`
# (so they are on the GPU when CUDA is available, not on the CPU).
ewc_fisher_dict = prev_checkpoint['ewc_fisher']
ewc_params_dict = prev_checkpoint['ewc_params']
print(f"Loaded previous period checkpoint from: \n\t{prev_checkpoint_path}")
# print(f"\nprev_checkpoint: \n{prev_checkpoint}\n")
#-------------------------------------------------------------------------

#-------------------------------------------------------------------------
# --- Inspect the loaded Fisher values before they are used for the EWC penalty ---
inspect_fisher_info(ewc_fisher_dict, label="Loaded Fisher Values (All Samples)")

#-------------------------------------------------------------------------
# --- Transfer Compatible Weights from Previous Period Model to Current Period Model ---
print("Transferring compatible weights to the new model...")
current_model_dict = current_model.state_dict()
# Keep only entries whose key exists in the current model AND whose shape matches;
# this skips the final fc layer when its dimensions differ between periods.
filtered_prev_state_dict = {
    k: v for k, v in prev_model_dict.items()
    if k in current_model_dict and v.size() == current_model_dict[k].size()
}
if not filtered_prev_state_dict:
    print("Warning: No compatible weights found to transfer. The new model will start from random initialization (except layers possibly initialized by default).")
else:
    # strict=False allows loading a partial state dict; keys absent from the
    # filtered dict (such as the final fc layer) are reported in missing_keys.
    missing_keys, unexpected_keys = current_model.load_state_dict(filtered_prev_state_dict, strict=False)
    print(f"  Weights loaded. Keys missing in loaded dict (expected: fc layer): {missing_keys}")
    print(f"  Keys in loaded dict but not in model (should be empty): {unexpected_keys}")
# Ensure the model is on the correct device *after* loading state dict
current_model.to(device)
print(f"\nCurrent Period Model Structure: \n{current_model}\n")
#-------------------------------------------------------------------------
# --- Build the EWC regularizer used while training this period ---
# It is constructed from the Fisher dictionary and the reference parameters that
# were restored from the previous period's checkpoint.
print("Instantiating EWC object using loaded state...")
ewc_object = EWC(
    fisher=ewc_fisher_dict,
    params=ewc_params_dict,
)
# NOTE(review): per the original author's note, the fisher/params tensors are kept on
# CPU (as saved) and EWC.penalty moves them to the model's device when computing the
# penalty -- confirm against the EWC class definition.
#-------------------------------------------------------------------------
# Loss, optimizer and (optional) learning-rate scheduler for this period
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    current_model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)  # lr=0.00005
# The scheduler stays None unless `use_scheduler` is set
scheduler = None
if use_scheduler:
    # Multiply the learning rate by 0.9 after 10 epochs without improvement of the
    # monitored quantity ('min' mode: lower is better).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=10)
    # Alternative: optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
    scheduler_name = type(scheduler).__name__
    print(f"Using {scheduler_name} scheduler.\n")

# Show the model summary before training starts
print_model_info(current_model)

# Train and validate the current period's model with the EWC regularizer
train_and_validate_ewc(
    # --- model and optimization objects ---
    model=current_model,
    deepcopy=deepcopy,
    output_size=output_size,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    use_scheduler=use_scheduler,
    # --- data ---
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val,
    # --- EWC regularization ---
    ewc=ewc_object,         # EWC object built from the previous period's state
    lambda_ewc=lambda_ewc,  # EWC strength
    # --- training schedule ---
    num_epochs=num_epochs,
    batch_size=batch_size,
    # --- checkpointing / control ---
    model_saving_folder=model_saving_folder,
    model_name=model_name,  # base name for the files saved in this period
    stop_signal_file=stop_signal_file,
)


#----------------------------------------------------------------------
# Record this period's best result for the cross-run summary.
# `best_results` is expected to hold the best entry at index 0.
if best_results:
    track_across_runs.append(best_results[0])
else:
    # Nothing was recorded (e.g. training stopped before any result was stored).
    # Warn instead of raising IndexError so the rest of the cell, including the
    # memory cleanup at its end, still runs.
    print("Warning: best_results is empty; nothing appended to track_across_runs.")
#----------------------------------------------------------------------

# Print every retained result entry (one line per entry)
for res in best_results:
    print(f"Epoch {res['epoch']}/{num_epochs}, "
            f"Train Loss: {res['train_loss']:.4f}, "
            f"Val Loss: {res['val_loss']:.4f}, "
            f"Val Accuracy: {res['val_accuracy'] * 100:.2f}%, "
            f"Model Path: {res['model_path']}")
print(f"\nclass_gru_model with ewc (current_model): \n{current_model}")

# Class bookkeeping for the current period
print(f"\nunique_classes = {unique_classes}")
print(f"num_classes = {num_classes}")

# Drop this period's large objects (data splits, models, checkpoints, EWC state)
# from the notebook namespace so their memory can be reclaimed before the next period.
for var in [
    "X_train", "y_train", "X_val", "y_val", "X_test", "y_test",
    "Number_features", "unique_classes", "num_classes",
    "current_model", "ewc_object",
    "prev_checkpoint", "prev_model_dict", "prev_checkpoint_path",
    "deepcopy", "filtered_prev_state_dict", "current_model_dict",
    "ewc_fisher_dict", "ewc_params_dict"
]:
    # Remove the binding from the module namespace explicitly. Deleting through
    # locals() only works when the code runs at top level (where locals() is the
    # module namespace); globals().pop() states the intent directly and is a no-op
    # for names that are not defined.
    globals().pop(var, None)
# --- Force garbage collection ---
gc.collect()
# Release cached GPU memory held by PyTorch's allocator
torch.cuda.empty_cache()


Loaded previous period checkpoint from: 
	Class_Incremental_CL/Classif_Bi_Dir_GRU_Model/Trained_models/EWC_CIL/Period_3/1st_try/BiGRUWithAttentionEWC_epoch_1854.pth

--- Inspecting Loaded Fisher Values (All Samples) ---
Total Fisher Sum: 2.6648e+02
Overall Max Fisher Abs Value: 5.5683e+01
Overall Min Positive Fisher Value: 1.8909e-21
Number of layers with all-zero Fisher: 0 / 36
--- Finished Inspecting Loaded Fisher Values (All Samples) ---

Transferring compatible weights to the new model...
  Weights loaded. Keys missing in loaded dict (expected: fc layer): ['fc.weight', 'fc.bias']
  Keys in loaded dict but not in model (should be empty): []

Current Period Model Structure: 
BiGRUWithAttention(
  (gru): GRU(7, 64, num_layers=4, batch_first=True, bidirectional=True)
  (attention_fc): Linear(in_features=128, out_features=128, bias=True)
  (fc): Linear(in_features=128, out_features=5, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)

Instantiating EWC object using loaded state...

---

## 📊 Comprehensive Test Results:
### (`EWC_CL_Classif_Bi_Dir_GRU_Model.ipynb`)

| Period | Model & Config                                                                                  | Validation Accuracy | Class-wise Accuracy                                                   |
|--------|--------------------------------------------------------------------------------------------------|---------------------|------------------------------------------------------------------------|
| 1      | `BiGRUWithAttention`<br>(num_layers=4)<br>saved in `'1st_try'`                                  | **98.35%**          | {0: 98.63%, 1: 97.92%}                                                |
| 2      | `BiGRUWithAttention`<br>(lambda_ewc=40, learning_rate=0.0001)<br>saved in `'1st_try'`           | **98.00%**          | {0: 99.70%, 1: 97.60%, 2: 91.06%}                                     |
| 3      | `BiGRUWithAttention`<br>(lambda_ewc=40, learning_rate=0.0001)<br>saved in `'1st_try'`           | **97.32%**          | {0: 89.81%, 1: 99.00%, 2: 90.93%, 3: 98.00%}                          |
| 4      | `BiGRUWithAttention`<br>(lambda_ewc=40, learning_rate=0.0001)<br>saved in `'1st_try'`           | **97.44%**          | {0: 95.14%, 1: 98.82%, 2: 91.17%, 3: 98.53%, 4: 99.30%}               |

---


## 📊 Summary: 

### ✔️ Finance - EWC: Validation Summary

| Period | Validation Accuracy | Class-wise Accuracy                                                   |
|--------|---------------------|------------------------------------------------------------------------|
| 1      | **98.35%**          | {0: 98.63%, 1: 97.92%}                                                |
| 2      | **98.00%**          | {0: 99.70%, 1: 97.60%, 2: 91.06%}                                     |
| 3      | **97.32%**          | {0: 89.81%, 1: 99.00%, 2: 90.93%, 3: 98.00%}                          |
| 4      | **97.44%**          | {0: 95.14%, 1: 98.82%, 2: 91.17%, 3: 98.53%, 4: 99.30%}               |


### 🧠 Continual Learning Metrics

| Period | AA_old (%) | AA_new (%) | BWT (%) | FWT (%) | FWT Classes     | Prev. Model Acc | Init Model Acc |
|--------|------------|------------|---------|---------|------------------|------------------|-----------------|
| 2      | 98.65%     | 91.06%     | +0.37%  | 45.52%  | [1]              | 95.88%           | 50.37%          |
| 3      | 93.25%     | 98.00%     | -2.87%  | 64.90%  | [1, 2]           | 97.44%           | 32.53%          |
| 4      | 95.92%     | 99.30%     | +1.48%  | 72.94%  | [1, 2, 3]        | 97.02%           | 24.08%          |


### 📦 Model Size per Period

| Period | Output Size | Total Params | Δ Params vs Prev | Δ % vs Prev | Model Size (float32) |
|--------|-------------|--------------|------------------|-------------|-----------------------|
| 1      | 2           | 268,290      | —                | —           | 1.02 MB               |
| 2      | 3           | 268,419      | +129             | +0.05%      | 1.02 MB               |
| 3      | 4           | 268,548      | +129             | +0.05%      | 1.02 MB               |
| 4      | 5           | 268,677      | +129             | +0.05%      | 1.02 MB               |

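**📈 Model Growth Rate (MGR) = (268,677 - 268,290) / (268,290 × 3) ≈ +0.05% per period** (total parameter growth from Period 1 to Period 4, relative to the Period 1 size, averaged over the 3 period transitions)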

**📈 Max trainable ratio = 100%**