In [1]:
import timesfm
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error
import warnings
import numpy as np
import os
# Import configuration and reusable functions
from config_times_fm import TimesFmConfig

TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded Jax TimesFM.
Loaded PyTorch TimesFM.


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the TimesFM model from the shared TimesFmConfig: the hyper-parameters are
# assembled first, then the Hugging Face checkpoint the weights are loaded from.
model_hparams = timesfm.TimesFmHparams(
    backend=TimesFmConfig.BACKEND,
    per_core_batch_size=TimesFmConfig.PER_CORE_BATCH_SIZE,
    horizon_len=TimesFmConfig.HORIZON_LEN,
    num_layers=TimesFmConfig.NUM_LAYERS,
    model_dims=TimesFmConfig.MODEL_DIMS,
    context_len=TimesFmConfig.CONTEXT_LEN,
)
model_checkpoint = timesfm.TimesFmCheckpoint(
    huggingface_repo_id=TimesFmConfig.CHECKPOINT_REPO,
)
tfm = timesfm.TimesFm(hparams=model_hparams, checkpoint=model_checkpoint)

Fetching 5 files: 100%|██████████| 5/5 [00:00<?, ?it/s]


In [6]:
# Reusable chart: observed series vs. TimesFM point forecast.
def plot_forecast(actual_data, forecast_data, title, xlabel='Date', ylabel='Price ($)'):
    """Show actual and predicted values on a single line chart.

    Args:
        actual_data: Frame with a ``ds`` column (x values) and a ``y`` column (observed values).
        forecast_data: Frame with a ``ds`` column (x values) and a ``timesfm`` column (predictions).
        title: Chart title.
        xlabel: X-axis label.
        ylabel: Y-axis label.
    """
    plt.figure(figsize=(18, 6))

    # (frame, value column, line style) for each trace, drawn in this order.
    traces = (
        (actual_data, 'y', {'color': 'green', 'label': 'Actual'}),
        (forecast_data, 'timesfm', {'color': 'red', 'linestyle': '--', 'label': 'Predicted'}),
    )
    for frame, value_col, line_style in traces:
        plt.plot(frame['ds'], frame[value_col], **line_style)

    axis_font = 10
    plt.title(title, fontsize=14)
    plt.xlabel(xlabel, fontsize=axis_font)
    plt.ylabel(ylabel, fontsize=axis_font)
    plt.xticks(rotation=45)
    plt.legend(frameon=True, shadow=True)
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.show()

In [7]:
# Function to process a single dataset (plus small private helpers it uses)
def _parse_ticker_interval(file_path, default_interval='1H'):
    """Return ``(ticker, interval)`` parsed from a ``<TICKER>_<INTERVAL>.csv`` file name.

    The extension is stripped before splitting, so ``INTC.csv`` yields ticker
    ``INTC`` rather than ``INTC.csv``. The interval is read only from a two-part
    name (``TICKER_INTERVAL``); any other shape falls back to ``default_interval``.
    """
    stem = os.path.splitext(os.path.basename(file_path))[0]
    parts = stem.split('_')
    ticker = parts[0]
    interval = parts[1] if len(parts) == 2 and parts[1] else default_interval
    return ticker, interval


def _load_timesfm_frame(file_path):
    """Read a CSV with ``Datetime`` and ``Close`` columns into TimesFM's long format.

    Returns a frame with columns ``unique_id`` (constant 1), ``ds`` (naive
    ``datetime64[ns]``) and ``y`` (the ``Close`` column).
    """
    raw_df = pd.read_csv(file_path)
    # utc=True lets pandas parse timestamps whose UTC offset changes within the file
    # (e.g. -05:00 -> -04:00 across a DST switch) into one tz-aware column instead of
    # an object column. Dropping the tz afterwards yields naive UTC times, which is
    # what `.values.astype('datetime64[ns]')` produced for single-offset input.
    timestamps = pd.to_datetime(raw_df['Datetime'], utc=True).dt.tz_localize(None)
    return pd.DataFrame({
        'unique_id': [1] * len(raw_df),
        'ds': timestamps.astype('datetime64[ns]'),
        'y': raw_df['Close'],
    })


def _save_window_plot(actual_data, forecast_df, title, plot_path):
    """Plot actual vs. forecast for one backtest window, save it as PNG, show it, then close it."""
    fig, ax = plt.subplots(figsize=(18, 6))
    ax.plot(actual_data['ds'], actual_data['y'], color='green', label='Actual')
    ax.plot(forecast_df['ds'], forecast_df['timesfm'], color='red', linestyle='--',
            marker='o', markersize=3, label='Predicted')
    ax.set_title(title, fontsize=14)
    ax.set_xlabel('Date', fontsize=10)
    ax.set_ylabel('Price ($)', fontsize=10)
    ax.tick_params(axis='x', labelrotation=45)
    ax.legend(frameon=True, shadow=True)
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()
    fig.savefig(plot_path, dpi=300)
    plt.show()
    # One figure is created per window; close it so figures don't pile up when the
    # notebook runs under a non-inline backend.
    plt.close(fig)


def process_dataset(file_path, context_window=2048, forecast_horizon=128, freq='h'):
    """Backtest the module-level TimesFM model ``tfm`` on one price CSV file.

    The file name is expected to look like ``<TICKER>_<INTERVAL>.csv`` (e.g.
    ``INTC_1H.csv``). The interval part is optional and defaults to ``'1H'``.
    The CSV must contain ``Datetime`` and ``Close`` columns.

    The series is cut into windows advanced by ``forecast_horizon`` rows. For each
    window, ``context_window`` rows of history are fed to the model, and the next
    ``forecast_horizon`` rows are used as ground truth. Per-window results are
    written to ``results/timesfm_<interval>_<ticker>/``:
    ``<ticker>_window_<k>.csv`` holds the actual and forecast values, and
    ``<ticker>_window_<k>.png`` holds the chart. MAE/RMSE are printed.

    Args:
        file_path: Path to the CSV file.
        context_window: History rows passed to the model for each window.
        forecast_horizon: Rows forecast and scored per window; also the stride
            between consecutive windows. Must not exceed the model's ``horizon_len``.
        freq: Frequency string forwarded to ``tfm.forecast_on_df``.

    Returns:
        None. If the file has fewer than ``context_window + forecast_horizon``
        rows, a warning is printed and nothing is written.

    Raises:
        ValueError: If a window size is not positive, or if the model returns
            fewer forecast rows than ``forecast_horizon``.
    """
    if context_window <= 0 or forecast_horizon <= 0:
        raise ValueError('context_window and forecast_horizon must be positive')

    ticker, interval = _parse_ticker_interval(file_path)
    print(f'Processing {ticker} dataset with {interval} interval...')

    input_df = _load_timesfm_frame(file_path)

    max_start = len(input_df) - context_window - forecast_horizon
    if max_start < 0:
        print(f'Warning: {ticker} dataset too small for forecasting with current window sizes')
        return

    # Stride by forecast_horizon so consecutive forecast spans tile the series
    # without overlapping.
    backtest_starts = range(0, max_start + 1, forecast_horizon)
    n_windows = len(backtest_starts)

    ticker_results_dir = os.path.join('results', f'timesfm_{interval}_{ticker}')
    os.makedirs(ticker_results_dir, exist_ok=True)

    for window_no, start_idx in enumerate(backtest_starts, start=1):
        print(f'Processing window {window_no}/{n_windows}...')

        context_end = start_idx + context_window
        context_data = input_df.iloc[start_idx:context_end]

        # .copy() makes the truncated forecast an independent frame, so the 'ds'
        # assignment below does not write into a slice (SettingWithCopyWarning).
        forecast_df = tfm.forecast_on_df(
            context_data,
            freq=freq,
            value_name='y',
            num_jobs=-1,
        ).iloc[:forecast_horizon].copy()

        actual_data = input_df.iloc[context_end:context_end + forecast_horizon]
        if len(forecast_df) != len(actual_data):
            raise ValueError(
                f'Model returned {len(forecast_df)} forecast rows but {len(actual_data)} '
                f'are needed; forecast_horizon ({forecast_horizon}) must not exceed the '
                "model's horizon_len."
            )
        # Re-label the forecast with the real timestamps of the held-out rows.
        forecast_df['ds'] = actual_data['ds'].values

        mae = mean_absolute_error(actual_data['y'], forecast_df['timesfm'])
        rmse = np.sqrt(mean_squared_error(actual_data['y'], forecast_df['timesfm']))

        window_stem = os.path.join(ticker_results_dir, f'{ticker}_window_{window_no}')
        pd.DataFrame({
            'date': actual_data['ds'],
            'actual': actual_data['y'].values,
            'forecast': forecast_df['timesfm'].values,
        }).to_csv(f'{window_stem}.csv', index=False)

        title = f'{ticker} ({interval}) - Window {window_no} Forecast (MAE: {mae:.4f}, RMSE: {rmse:.4f})'
        _save_window_plot(actual_data, forecast_df, title, f'{window_stem}.png')

        print(f'  MAE: {mae:.4f}, RMSE: {rmse:.4f}')

## Process 1H Data

In [8]:
# Locate the hourly (1H) price files. The list is sorted because os.listdir order is
# filesystem-dependent, and a reproducible order keeps per-ticker runs and outputs
# stable across machines. A missing directory is treated as "no files" so the
# message below is printed instead of a FileNotFoundError.
data_dir_1h = os.path.join(os.getcwd(), "data", "1H")

csv_files_1h = sorted(
    os.path.join(data_dir_1h, name)
    for name in os.listdir(data_dir_1h)
    if name.endswith('.csv')
) if os.path.isdir(data_dir_1h) else []

if not csv_files_1h:
    print(f"No CSV files found in {data_dir_1h}")
else:
    print(f"Found {len(csv_files_1h)} CSV files in 1H directory:")
    for file in csv_files_1h:
        print(f"- {os.path.basename(file)}")

Found 11 CSV files in 1H directory:
- INTC_1H.csv
- IONQ_1H.csv
- MSTR_1H.csv
- MU_1H.csv
- NVDA_1H.csv
- QBTS_1H.csv
- RGTI_1H.csv
- SMCI_1H.csv
- SRPT_1H.csv
- TSLA_1H.csv
- VKTX_1H.csv


In [None]:

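# Combine all forecasts into single DataFrames
# NOTE(review): stale draft. `all_forecasts` / `all_actuals` are not defined by any
# earlier cell (process_dataset writes per-window CSVs under results/ instead), so
# this cell cannot be re-enabled as-is. Rebuild the inputs from those CSVs first.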
# full_forecast = pd.concat(all_forecasts)
# full_actual = pd.concat(all_actuals)

# Plot the aggregated forecast
# plt.figure(figsize=(18, 6))
# plt.plot(full_actual['ds'], full_actual['y'], color='#069d12', label='Actual')
# plt.plot(full_forecast['ds'], full_forecast['timesfm'], color='#e32227', linestyle='--', label='Predicted')
# plt.title('Aggregated Forecast: Actual vs Predicted ($INTC)', fontsize=16)
# plt.xlabel('Date', fontsize=12)
# plt.ylabel('Price ($)', fontsize=12)
# plt.xticks(rotation=45, ha='right')
# plt.legend()
# plt.grid(True, linestyle='--', alpha=0.6)
# plt.tight_layout()
# plt.show()