Install dependencies

Collect all available crypto data and store it as compressed Parquet files.

In [None]:
!pip install ccxt yfinance pytrends tweepy praw newsapi-python aiohttp scipy wbdata tqdm pyarrow db-sqlite3 pyyaml tensorflow polars xgboost transformers torch scikit-learn

The project represents a sophisticated crypto data warehouse system with several key components that can be leveraged for advanced algorithmic trading. Let me break down how we can enhance this system using various machine learning approaches.

Data Collection and Storage Architecture

The current system efficiently collects data from multiple sources:

Market data (CCXT, YFinance)
Social sentiment (Twitter, Reddit)
Technical indicators
On-chain metrics

The ParquetManager and MetadataManager provide a robust foundation for handling large-scale data storage. We can enhance this by implementing a feature store layer for machine learning:

Enhanced Feature Store Implementation

In [None]:
from typing import Dict, List, Optional
import polars as pl
from datetime import datetime, timedelta

class MLFeatureStore:
    """Assemble ML training frames from stored feature data.

    Feature groups are loaded independently, joined on the ``timestamp``
    column, and then extended with forward-looking return/direction labels.

    NOTE(review): ``_load_sentiment_features`` and
    ``_compute_technical_features`` are called below but are not defined in
    the visible class; they are presumably provided elsewhere -- confirm.
    """

    def __init__(self, config):
        # Keep the config on the instance: _load_market_features reads
        # ``self.config.data_dir``.
        self.config = config
        self.parquet_manager = ParquetManager(config)
        self.feature_pipeline = MemoryOptimizedFeaturePipeline(config)

    async def create_training_dataset(
        self,
        lookback_window: int,
        prediction_horizon: int,
        feature_groups: List[str]
    ) -> pl.DataFrame:
        """Create a dataset optimized for ML training.

        Args:
            lookback_window: Number of lagged steps passed to each loader.
            prediction_horizon: Number of rows to look ahead for the labels.
            feature_groups: Any of ``'market'``, ``'sentiment'``,
                ``'technical'``; unknown names are ignored.

        Returns:
            The joined feature frame with ``return_<h>`` and
            ``direction_<h>`` label columns appended.

        Raises:
            ValueError: If none of the requested groups is recognised.
        """
        features = []

        # Market data features
        if 'market' in feature_groups:
            market_data = await self._load_market_features(lookback_window)
            features.append(market_data)

        # Sentiment features
        if 'sentiment' in feature_groups:
            sentiment_data = await self._load_sentiment_features(lookback_window)
            features.append(sentiment_data)

        # Technical indicators
        if 'technical' in feature_groups:
            technical_data = await self._compute_technical_features(lookback_window)
            features.append(technical_data)

        if not features:
            raise ValueError(
                f"No known feature groups requested: {feature_groups!r}"
            )

        # The groups carry different columns for the same timestamps, so they
        # are joined side by side on 'timestamp' rather than stacked row-wise.
        combined = features[0]
        for frame in features[1:]:
            combined = combined.join(frame, on='timestamp', how='inner')

        labels = self._generate_labels(combined, prediction_horizon)

        return combined.join(labels, on='timestamp')

    async def _load_market_features(self, lookback_window: int) -> pl.DataFrame:
        """Load market features and add ``lookback_window`` lagged columns.

        For every lag ``1..lookback_window`` the columns ``close``,
        ``volume`` and ``volatility`` are shifted and stored as
        ``<name>_lag_<lag>``.
        """
        market_data = await self.parquet_manager.load_data(
            self.config.data_dir / 'processed' / 'market_features.parquet'
        )

        # Build every lag expression up front and apply them in one pass.
        lag_exprs = []
        for lag in range(1, lookback_window + 1):
            lag_exprs.extend([
                pl.col('close').shift(lag).alias(f'close_lag_{lag}'),
                pl.col('volume').shift(lag).alias(f'volume_lag_{lag}'),
                pl.col('volatility').shift(lag).alias(f'volatility_lag_{lag}')
            ])

        if lag_exprs:
            market_data = market_data.with_columns(lag_exprs)

        return market_data

    def _generate_labels(
        self,
        df: pl.DataFrame,
        horizon: int
    ) -> pl.DataFrame:
        """Generate prediction labels for one horizon.

        Returns a frame with ``timestamp``, the forward return
        ``close[t + horizon] / close[t] - 1`` as ``return_<horizon>``, and a
        boolean ``direction_<horizon>`` that is true when the future close is
        higher. The last ``horizon`` rows have no future value to compare to.
        """
        future_close = pl.col('close').shift(-horizon)
        return df.select([
            'timestamp',
            (future_close / pl.col('close') - 1)
                .alias(f'return_{horizon}'),
            (future_close > pl.col('close'))
                .alias(f'direction_{horizon}')
        ])

In [None]:
# --------------------------------------------------------------
# Grab Raw Data
# --------------------------------------------------------------
import os
import pandas as pd
import pyarrow.parquet as pq
import pyarrow as pa
import requests
import ccxt
import yfinance as yf
import logging
import time
import random
import concurrent.futures
import tweepy
import praw
from newsapi import NewsApiClient
import aiohttp
import asyncio
import wbdata  # World Bank Data API

from google.colab import drive

# Define Google Drive Directory
# Root folder that the fetch functions in this cell pass to save_to_parquet
# as the output location.
GOOGLE_DRIVE_DIR = "/content/drive/MyDrive/data_warehouse"

# Set the threshold for old data (in days)
# NOTE(review): not referenced anywhere in the visible code -- confirm whether
# a clean-up step elsewhere relies on it.
DATA_EXPIRY_DAYS = 30

# --------------------------------------------------------------
# Logging Configuration
# --------------------------------------------------------------
# Append INFO-and-above records to data_fetching.log (relative to the working
# directory), one timestamped line per record.
logging.basicConfig(
    filename="data_fetching.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filemode='a'
)

# --------------------------------------------------------------
# Google Drive Mounting
# --------------------------------------------------------------
def mount_google_drive(max_retries=5):
    """Try to mount Google Drive, backing off exponentially between attempts.

    Returns True as soon as a mount succeeds and False once the final attempt
    has failed. The wait before retry ``k`` (0-based) is ``2 ** k`` seconds.
    """
    final_attempt = max_retries - 1
    attempt = 0
    while attempt < max_retries:
        try:
            drive.mount('/content/drive', force_remount=True)
            logging.info("Google Drive mounted successfully.")
        except Exception as err:
            logging.error(f"Drive mount attempt {attempt + 1} failed: {err}")
            if attempt >= final_attempt:
                logging.critical("Failed to mount Google Drive after several attempts.")
                return False
            # More attempts remain: wait before trying again.
            time.sleep(2 ** attempt)
        else:
            return True
        attempt += 1

# --------------------------------------------------------------
# Utility Functions
# --------------------------------------------------------------
def retry_on_failure(max_retries=3, delay=2, backoff=2):
    """Decorator factory that retries a function with exponential backoff.

    Args:
        max_retries: Total number of attempts. Values below 1 are treated as
            1 so the wrapped function is always called at least once.
        delay: Base wait in seconds before the first retry.
        backoff: Multiplier applied to the wait for each further retry.

    The wait before retry ``k`` (0-based) is ``delay * backoff ** k`` plus up
    to one second of random jitter. When the final attempt fails, the
    original exception is re-raised with its traceback intact.
    """
    import functools  # local import keeps this block self-contained

    attempts = max(1, max_retries)

    def decorator(func):
        # Preserve the wrapped function's __name__/__doc__ for logs and tools.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(attempts):
                try:
                    result = func(*args, **kwargs)
                except Exception as e:
                    logging.error(f"Attempt {attempt + 1} failed in function {func.__name__}: {e}")
                    if attempt >= attempts - 1:
                        logging.critical(f"Max retries reached. Function {func.__name__} failed.")
                        raise
                    sleep_time = delay * (backoff ** attempt) + random.uniform(0, 1)
                    logging.info(f"Retrying {func.__name__} after {sleep_time:.2f} seconds...")
                    time.sleep(sleep_time)
                else:
                    logging.info(f"Function {func.__name__} executed successfully on attempt {attempt + 1}.")
                    return result
        return wrapper
    return decorator

# --------------------------------------------------------------
# Saving Data to Compressed Parquet with Snappy
# --------------------------------------------------------------
@retry_on_failure()
def _write_parquet_file(df, file_path, dir_path):
    """Write *df* to *file_path* as Snappy-compressed Parquet.

    Only this I/O step is wrapped in the retry decorator, because it is the
    only part of saving that can fail transiently.
    """
    os.makedirs(dir_path, exist_ok=True)  # Ensure directory exists
    table = pa.Table.from_pandas(df)
    pq.write_table(table, file_path, compression='snappy')


def save_to_parquet(df, file_name, dir_path):
    """Save DataFrame to a compressed Parquet file with Snappy compression.

    Args:
        df: pandas DataFrame to persist.
        file_name: Name of the Parquet file to create inside *dir_path*.
        dir_path: Target directory; created if it does not exist.

    Raises:
        ValueError: If *df* is empty or contains NaN values. These checks are
            deterministic, so they fail immediately instead of being retried
            with backoff.
        Exception: Whatever the write raises once its retries are exhausted.
    """
    try:
        if df.empty:
            raise ValueError(f"DataFrame is empty. Skipping file: {file_name}")
        if df.isnull().values.any():
            raise ValueError(f"DataFrame contains NaN values. Skipping file: {file_name}")

        file_path = os.path.join(dir_path, file_name)
        _write_parquet_file(df, file_path, dir_path)

        logging.info(f"Data saved to {file_path} with Snappy compression")
        print(f"File successfully saved: {file_path}")
    except Exception as e:
        logging.error(f"Failed to save data to {file_name}: {e}")
        raise

# --------------------------------------------------------------
# API Initialization Functions (Twitter, Reddit, NewsAPI)
# --------------------------------------------------------------
def init_twitter_api():
    """Build an app-authenticated tweepy API client.

    Credentials are read from the ``TWITTER_API_KEY`` and
    ``TWITTER_API_SECRET_KEY`` environment variables instead of being
    embedded in the source.

    NOTE(review): any key that has ever been committed to the notebook should
    be treated as compromised and rotated.

    Raises:
        RuntimeError: If either environment variable is missing or empty.
    """
    api_key = os.environ.get("TWITTER_API_KEY")
    api_secret_key = os.environ.get("TWITTER_API_SECRET_KEY")
    if not api_key or not api_secret_key:
        raise RuntimeError(
            "Twitter credentials are not configured; set the TWITTER_API_KEY "
            "and TWITTER_API_SECRET_KEY environment variables."
        )
    auth = tweepy.AppAuthHandler(api_key, api_secret_key)
    api = tweepy.API(auth)
    return api

def init_reddit_api():
    """Build a praw Reddit client.

    Credentials are read from the ``REDDIT_CLIENT_ID`` and
    ``REDDIT_CLIENT_SECRET`` environment variables instead of being embedded
    in the source. The user agent stays ``"Sentiment"``.

    NOTE(review): any secret that has ever been committed to the notebook
    should be treated as compromised and rotated.

    Raises:
        RuntimeError: If either environment variable is missing or empty.
    """
    client_id = os.environ.get("REDDIT_CLIENT_ID")
    client_secret = os.environ.get("REDDIT_CLIENT_SECRET")
    if not client_id or not client_secret:
        raise RuntimeError(
            "Reddit credentials are not configured; set the REDDIT_CLIENT_ID "
            "and REDDIT_CLIENT_SECRET environment variables."
        )
    user_agent = "Sentiment"
    reddit = praw.Reddit(client_id=client_id, client_secret=client_secret, user_agent=user_agent)
    return reddit

# --------------------------------------------------------------
# Data Fetching Functions (API-specific)
# --------------------------------------------------------------
@retry_on_failure()
def fetch_twitter_data(keywords, max_tweets=100):
    """Fetch Twitter data for given keywords.

    Each keyword is searched (English tweets only, at most *max_tweets*) and
    written to ``twitter_<keyword>.parquet`` under ``GOOGLE_DRIVE_DIR``.

    This is best-effort: failures are logged, never raised. Each keyword is
    handled independently, so one failing keyword no longer stops the rest.
    Because every exception is handled here, the retry decorator never sees
    a failure to retry.
    """
    try:
        api = init_twitter_api()
    except Exception as e:
        logging.error(f"Failed to fetch Twitter data for {keywords}: {e}")
        return

    for keyword in keywords:
        try:
            tweets = tweepy.Cursor(api.search_tweets, q=keyword, lang="en").items(max_tweets)
            tweet_data = [[tweet.created_at, tweet.text, tweet.user.screen_name] for tweet in tweets]
            df = pd.DataFrame(tweet_data, columns=['created_at', 'text', 'user'])
            file_name = f"twitter_{keyword}.parquet"
            save_to_parquet(df, file_name, GOOGLE_DRIVE_DIR)
        except Exception as e:
            logging.error(f"Failed to fetch Twitter data for {keyword}: {e}")

@retry_on_failure()
def fetch_reddit_data(subreddits, max_posts=100):
    """Fetch Reddit data from specified subreddits.

    For each subreddit the top *max_posts* posts are collected (creation
    time, title, self text, score) and written to
    ``reddit_<subreddit>.parquet`` under ``GOOGLE_DRIVE_DIR``.

    This is best-effort: failures are logged, never raised. Each subreddit
    is handled independently, so one failing subreddit no longer stops the
    rest. Because every exception is handled here, the retry decorator never
    sees a failure to retry.
    """
    try:
        reddit = init_reddit_api()
    except Exception as e:
        logging.error(f"Failed to fetch Reddit data for {subreddits}: {e}")
        return

    for subreddit_name in subreddits:
        try:
            subreddit = reddit.subreddit(subreddit_name)
            posts = subreddit.top(limit=max_posts)
            post_data = [[post.created_utc, post.title, post.selftext, post.score] for post in posts]
            df = pd.DataFrame(post_data, columns=['created_utc', 'title', 'selftext', 'score'])
            # created_utc arrives as epoch seconds.
            df['created_utc'] = pd.to_datetime(df['created_utc'], unit='s')
            file_name = f"reddit_{subreddit_name}.parquet"
            save_to_parquet(df, file_name, GOOGLE_DRIVE_DIR)
        except Exception as e:
            logging.error(f"Failed to fetch Reddit data for {subreddit_name}: {e}")

@retry_on_failure()
def fetch_alpha_vantage_data(symbol, interval='1min', outputsize='compact'):
    """Fetch intraday data from Alpha Vantage for a given symbol.

    The API key is read from the ``ALPHA_VANTAGE_API_KEY`` environment
    variable instead of being embedded in the source. The result is written
    to ``alpha_vantage_<symbol>_<interval>.parquet`` under
    ``GOOGLE_DRIVE_DIR``.

    This is best-effort: failures are logged, never raised, including a
    missing key, an HTTP error, or a response without the expected
    ``Time Series (<interval>)`` section.
    """
    try:
        api_key = os.environ.get("ALPHA_VANTAGE_API_KEY")
        if not api_key:
            raise RuntimeError("ALPHA_VANTAGE_API_KEY environment variable is not set")

        # Let requests build and escape the query string.
        response = requests.get(
            "https://www.alphavantage.co/query",
            params={
                'function': 'TIME_SERIES_INTRADAY',
                'symbol': symbol,
                'interval': interval,
                'outputsize': outputsize,
                'apikey': api_key,
            },
            timeout=30,
        )
        response.raise_for_status()
        data = response.json()

        series_key = f'Time Series ({interval})'
        time_series = data.get(series_key, {})
        if not time_series:
            # Without this guard the column assignment below fails with an
            # unrelated length-mismatch error.
            raise ValueError(f"Response has no '{series_key}' section; keys present: {list(data)}")

        df = pd.DataFrame(time_series).transpose().reset_index()
        df.columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
        df['timestamp'] = pd.to_datetime(df['timestamp'])
        file_name = f"alpha_vantage_{symbol}_{interval}.parquet"
        save_to_parquet(df, file_name, GOOGLE_DRIVE_DIR)
    except Exception as e:
        logging.error(f"Failed to fetch Alpha Vantage data for {symbol}: {e}")

@retry_on_failure()
def fetch_world_bank_data(indicator, country='all', start_date='2000', end_date='2023'):
    """Download one World Bank indicator and store it as Parquet.

    The indicator code is used as both the lookup key and the column label.
    The frame is written to ``world_bank_<indicator>.parquet`` under
    ``GOOGLE_DRIVE_DIR``. Failures are logged and swallowed.
    """
    try:
        indicator_map = {indicator: indicator}
        date_range = (start_date, end_date)
        frame = wbdata.get_dataframe(indicator_map, country=country, data_date=date_range)
        # Turn the index into ordinary columns before saving.
        frame.reset_index(inplace=True)
        save_to_parquet(frame, f"world_bank_{indicator}.parquet", GOOGLE_DRIVE_DIR)
    except Exception as err:
        logging.error(f"Failed to fetch World Bank data for {indicator}: {err}")

@retry_on_failure()
def fetch_imf_data(database_id, indicator, country='all', start_year='2000', end_year='2023'):
    """Download an IMF series and store its observations as Parquet.

    The observations under ``CompactData -> DataSet -> Series -> Obs`` are
    written to ``imf_<indicator>.parquet`` under ``GOOGLE_DRIVE_DIR``.
    Failures are logged and swallowed.
    """
    try:
        # NOTE(review): the endpoint path ends in ".html"; confirm it really
        # returns the JSON structure parsed below.
        url = f"http://www.bd-econ.com/imfapi1.html?db={database_id}&series={indicator}&start_year={start_year}&end_year={end_year}&country={country}"
        payload = requests.get(url).json()
        observations = payload['CompactData']['DataSet']['Series']['Obs']
        frame = pd.DataFrame(observations)
        save_to_parquet(frame, f"imf_{indicator}.parquet", GOOGLE_DRIVE_DIR)
    except Exception as err:
        logging.error(f"Failed to fetch IMF data for {indicator}: {err}")

# --------------------------------------------------------------
# Asynchronous Fetching for CoinMarketCap and CoinGecko
# --------------------------------------------------------------
async def fetch_async(url, headers=None, params=None):
    """GET *url* and return the decoded JSON body, or None on a non-200 status.

    A fresh client session is opened for the single request. A non-200
    response is reported on stdout.
    """
    async with aiohttp.ClientSession() as session, \
            session.get(url, headers=headers, params=params) as response:
        if response.status != 200:
            print(f"Failed to fetch data: HTTP {response.status}")
            return None
        return await response.json()

async def fetch_and_save_async(url, headers, params, file_name, dir_path):
    """Asynchronous fetch and save data to Parquet.

    Accepts both response shapes used by the callers in this file: a JSON
    object whose records sit under the ``'data'`` key, and a bare JSON array
    of records. Nothing is saved when the fetch returns no data.

    Raises:
        KeyError: If the response is a JSON object without a ``'data'`` key.
    """
    data = await fetch_async(url, headers=headers, params=params)
    if not data:
        return

    # A list payload is already the record set; indexing it with 'data'
    # would raise TypeError.
    records = data if isinstance(data, list) else data['data']
    df = pd.DataFrame(records)
    save_to_parquet(df, file_name, dir_path)

async def fetch_coinmarketcap_data_async():
    """Fetch the first 250 CoinMarketCap listings in USD and save them.

    The result goes to ``coinmarketcap_data.parquet`` under
    ``GOOGLE_DRIVE_DIR``.
    """
    endpoint = "https://pro-api.coinmarketcap.com/v1/cryptocurrency/listings/latest"
    query = {'start': '1', 'limit': '250', 'convert': 'USD'}
    request_headers = {
        'Accepts': 'application/json',
        # Replace with your actual API key
        'X-CMC_PRO_API_KEY': 'your_api_key_here',
    }
    await fetch_and_save_async(
        endpoint,
        request_headers,
        query,
        "coinmarketcap_data.parquet",
        GOOGLE_DRIVE_DIR,
    )

async def fetch_coingecko_data_async():
    """Fetch the top 250 CoinGecko markets by market cap (USD) and save them.

    The result goes to ``coingecko_data.parquet`` under ``GOOGLE_DRIVE_DIR``.
    No request headers are sent.
    """
    endpoint = "https://api.coingecko.com/api/v3/coins/markets"
    query = {
        'vs_currency': 'usd',
        'order': 'market_cap_desc',
        'per_page': '250',
        'page': '1',
        'sparkline': 'false',
    }
    await fetch_and_save_async(
        endpoint,
        None,
        query,
        "coingecko_data.parquet",
        GOOGLE_DRIVE_DIR,
    )

# --------------------------------------------------------------
# CCXT Data Fetching with Subfolder Management
# --------------------------------------------------------------
@retry_on_failure()
def fetch_data_for_exchange(exchange_id, timeframes=['1m', '5m', '15m', '1h', '1d'], limit=1000, max_failures=10):
    """Fetch OHLCV data for all markets on a specified exchange using CCXT.

    Each (symbol, timeframe) pair is saved to
    ``<GOOGLE_DRIVE_DIR>/CCXT/<exchange>/<symbol>/<timeframe>/``. Pairs whose
    file already exists are skipped.

    Args:
        exchange_id: Name of the ccxt exchange class to instantiate.
        timeframes: Candle intervals to download for every symbol.
        limit: Maximum number of candles per request.
        max_failures: Consecutive failed downloads tolerated before the
            exchange is abandoned.

    This is best-effort: failures are logged, never raised, so the retry
    decorator never sees a failure to retry.
    """
    try:
        exchange = getattr(ccxt, exchange_id)()
        markets = exchange.load_markets()

        # Capability is per exchange, so check it once rather than for every
        # symbol/timeframe pair.
        if not exchange.has.get('fetchOHLCV'):
            logging.info(f"{exchange_id} does not support fetchOHLCV; skipping.")
            return

        failure_count = 0

        for symbol in markets:
            if failure_count >= max_failures:
                logging.warning(f"Reached {max_failures} consecutive failures for {exchange_id}. Moving to next exchange.")
                break

            safe_symbol = symbol.replace("/", "_")
            for timeframe in timeframes:
                try:
                    dir_path = os.path.join(GOOGLE_DRIVE_DIR, "CCXT", exchange_id, safe_symbol, timeframe)
                    file_name = f"{exchange_id}_{safe_symbol}_{timeframe}.parquet"
                    if os.path.exists(os.path.join(dir_path, file_name)):
                        continue  # Skip if file already exists
                    ohlcv = exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
                    df = pd.DataFrame(ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
                    # Timestamps are converted from epoch milliseconds.
                    df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
                    save_to_parquet(df, file_name, dir_path)
                    failure_count = 0
                except Exception as e:
                    # Include the cause so the log is actionable.
                    logging.error(f"Failed to fetch data from {exchange_id} for {symbol} with {timeframe} timeframe: {e}")
                    failure_count += 1
                    if failure_count >= max_failures:
                        logging.warning(f"Skipping {exchange_id} after {failure_count} consecutive failures.")
                        break

        # NOTE(review): this pause runs once after all symbols, not between
        # requests -- confirm whether per-request throttling was intended.
        time.sleep(1)
    except Exception as e:
        logging.error(f"Failed to load markets for {exchange_id}: {e}")

def fetch_all_ccxt_data_parallel(timeframes=['1m', '5m', '15m', '1h', '1d'], limit=1000, max_failures=10, max_workers=10):
    """Run fetch_data_for_exchange for every ccxt exchange on a thread pool.

    One task is submitted per entry in ``ccxt.exchanges``. An exception
    raised by a task is logged and does not stop the remaining tasks.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [
            pool.submit(fetch_data_for_exchange, exchange_id, timeframes, limit, max_failures)
            for exchange_id in ccxt.exchanges
        ]
        for finished in concurrent.futures.as_completed(pending):
            try:
                finished.result()
            except Exception as err:
                logging.error(f"An error occurred during parallel execution: {err}")

# --------------------------------------------------------------
# YFinance Data Fetching (Dynamic Tickers)
# --------------------------------------------------------------
@retry_on_failure()
def fetch_yfinance_data():
    """Fetch financial data from Yahoo Finance with dynamic tickers and subfoldering.

    Every combination of symbol, period and interval is downloaded and saved
    to ``<GOOGLE_DRIVE_DIR>/YFinance/<asset_type>/<symbol>/``. A combination
    is skipped only when its own Parquet file already exists. Empty
    downloads are not saved. Failures are logged per combination and do not
    stop the run.
    """
    asset_types = {
        'indices': '^GSPC ^IXIC ^DJI ^RUT ^VIX ^FTSE ^N225 ^HSI ^SSEC ^DAX ^FCHI ^BSESN ^STI ^AORD ^KS11 ^TWII ^GSPTSE'.split(),
        'commodities': 'GC=F SI=F CL=F NG=F HG=F PL=F PA=F ZW=F ZC=F ZS=F LE=F HE=F'.split(),
        'forex': 'EURUSD=X JPY=X GBPUSD=X AUDUSD=X NZDUSD=X USDCHF=X USDCAD=X'.split(),
        'bonds': 'TNX TLT BND'.split(),
        'mutual_funds': 'VFINX VTSMX FBGRX'.split(),
        'etfs': 'SPY QQQ DIA GLD VXX'.split(),
        'stocks': 'AAPL MSFT GOOGL AMZN TSLA FB NVDA JPM V MA DIS HD'.split()
    }

    periods = ["1mo", "3mo", "6mo", "1y", "5y", "max"]
    intervals = ["1m", "5m", "15m", "1h", "1d", "1wk", "1mo"]

    for asset_type, symbols in asset_types.items():
        for symbol in symbols:
            target_dir = os.path.join(GOOGLE_DRIVE_DIR, "YFinance", asset_type, symbol)
            for period in periods:
                for interval in intervals:
                    try:
                        file_name = f"{asset_type}_{symbol}_{period}_{interval}.parquet"
                        # Test the file, not the symbol directory: the
                        # directory exists after the first save and would
                        # cause every later combination to be skipped.
                        if os.path.exists(os.path.join(target_dir, file_name)):
                            continue
                        data = yf.download(symbol, period=period, interval=interval)
                        if not data.empty:
                            data.reset_index(inplace=True)
                            save_to_parquet(data, file_name, target_dir)
                    except Exception as e:
                        logging.error(f"Failed to fetch data for {symbol} ({asset_type}) with period {period} and interval {interval}: {e}")

# --------------------------------------------------------------
# Main Execution
# --------------------------------------------------------------
# Run the fetchers only when the Drive mount succeeded, since they all write
# their Parquet output under GOOGLE_DRIVE_DIR.
if mount_google_drive():
    # Fetch CCXT data in parallel
    fetch_all_ccxt_data_parallel()

    # Fetch Yahoo Finance data dynamically with subfoldering
    fetch_yfinance_data()

    # Fetch Twitter and Reddit posts for the listed keywords / subreddits
    fetch_twitter_data(keywords=['Bitcoin', 'Cryptocurrency', 'BTC'])
    fetch_reddit_data(subreddits=['Bitcoin', 'CryptoCurrency', 'Crypto', 'BTC'])

    # Fetch from Alpha Vantage, World Bank and IMF
    fetch_alpha_vantage_data(symbol='BTCUSD')
    fetch_world_bank_data(indicator='NY.GDP.MKTP.CD')
    fetch_imf_data(database_id='WEO', indicator='NGDP_RPCH')

    # Asynchronous CoinMarketCap and CoinGecko data fetching
    # Disabled because the author marked these calls as buggy in Colab.
    # NOTE(review): presumably asyncio.run() fails because the notebook
    # already runs an event loop -- confirm before re-enabling.
    # asyncio.run(fetch_coinmarketcap_data_async())
    # asyncio.run(fetch_coingecko_data_async())



Machine Learning Integration

We can enhance the existing LSTM implementation by creating a multi-model ensemble system that leverages different architectures:

In [None]:
import numpy as np
from typing import Dict, List
import tensorflow as tf
from transformers import AutoModelForSequenceClassification
import xgboost as xgb
from sklearn.ensemble import RandomForestClassifier
import torch

class EnsemblePredictor:
    """Weighted ensemble over an LSTM, a transformer, XGBoost and a random forest.

    ``models`` maps a model name to its estimator and ``weights`` maps the
    same names to blending weights. Both are empty until
    ``initialize_models`` is awaited.
    """

    def __init__(self, config):
        self.config = config
        self.models = {}
        self.weights = {}

    async def initialize_models(self):
        """Initialize all models in the ensemble with equal weights."""
        # LSTM for time series
        self.models['lstm'] = self._create_lstm_model()

        # Transformer for sequence modeling
        self.models['transformer'] = self._create_transformer_model()

        # XGBoost for tabular data
        self.models['xgboost'] = xgb.XGBRegressor(
            objective='reg:squarederror',
            tree_method='gpu_hist',
            max_depth=6,
            learning_rate=0.1
        )

        # Random Forest for robust predictions
        # NOTE(review): this is a classifier blended with regressors in
        # predict(); confirm a regressor was not intended.
        self.models['rf'] = RandomForestClassifier(
            n_estimators=100,
            max_depth=10,
            n_jobs=-1
        )

        # Initialize weights equally
        model_count = len(self.models)
        self.weights = {name: 1.0/model_count for name in self.models}

    def _create_lstm_model(self) -> tf.keras.Model:
        """Create LSTM model with attention.

        NOTE(review): MultiHeadAttention is called with query and value
        tensors, while Sequential passes a single input to each layer;
        verify this stack builds as written.
        """
        return tf.keras.Sequential([
            tf.keras.layers.LSTM(128, return_sequences=True),
            tf.keras.layers.MultiHeadAttention(
                num_heads=4, key_dim=32
            ),
            tf.keras.layers.GlobalAveragePooling1D(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(1)
        ])

    def _create_transformer_model(self) -> torch.nn.Module:
        """Load the configured pretrained transformer with a single output label."""
        config = self.config.model.transformer
        return AutoModelForSequenceClassification.from_pretrained(
            config.base_model,
            num_labels=1
        )

    @staticmethod
    def _as_vector(pred) -> np.ndarray:
        """Return *pred* as a float array, dropping a trailing length-1 axis.

        Predictions shaped ``(n, 1)`` and ``(n,)`` would otherwise broadcast
        to an ``(n, n)`` matrix when summed.
        """
        arr = np.asarray(pred, dtype=float)
        if arr.ndim == 2 and arr.shape[1] == 1:
            arr = arr[:, 0]
        return arr

    async def predict(self, features: Dict[str, np.ndarray]) -> np.ndarray:
        """Generate ensemble predictions.

        Args:
            features: Must contain ``'sequence_data'`` for the LSTM and
                transformer, and ``'tabular_data'`` for the other models.

        Returns:
            The weighted sum of the per-model predictions.

        Raises:
            RuntimeError: If no models have been initialized.
            KeyError: If a model has no entry in ``self.weights``.
        """
        if not self.models:
            raise RuntimeError(
                "No models are loaded; call initialize_models() first."
            )

        predictions = {}

        # Get predictions from each model
        # NOTE(review): _get_transformer_predictions is not defined in the
        # visible class; presumably provided elsewhere -- confirm.
        for model_name, model in self.models.items():
            if model_name == 'lstm':
                pred = model.predict(features['sequence_data'])
            elif model_name == 'transformer':
                pred = self._get_transformer_predictions(
                    features['sequence_data']
                )
            else:  # XGBoost and RF
                pred = model.predict(features['tabular_data'])
            predictions[model_name] = self._as_vector(pred)

        # Weighted ensemble
        final_prediction = sum(
            self.weights[name] * pred
            for name, pred in predictions.items()
        )

        return final_prediction

    async def update_weights(
        self,
        performance_metrics: Dict[str, float]
    ):
        """Update model weights in proportion to their performance scores.

        An empty mapping leaves the current weights unchanged. If the scores
        sum to zero, the listed models share the weight equally instead of
        dividing by zero.
        """
        if not performance_metrics:
            return

        total_score = sum(performance_metrics.values())
        if total_score == 0:
            equal_share = 1.0 / len(performance_metrics)
            self.weights = {name: equal_share for name in performance_metrics}
            return

        self.weights = {
            name: score/total_score
            for name, score in performance_metrics.items()
        }

Feature Engineering and Processing

The existing FeaturePipeline can be enhanced to support more advanced features:

In [None]:
import polars as pl
from typing import List, Dict
import numpy as np
from scipy import stats
import talib

class AdvancedFeatureEngineer:
    """Derive volatility, regime, order-book and correlation features.

    The input frame is expected to provide the columns referenced by the
    expressions below: ``returns``, ``high``, ``low``, ``close``, ``bid``,
    ``ask``, ``mid``, ``bid_size``, ``ask_size``, ``market_returns`` and
    ``sector_returns``.
    """

    def __init__(self, config):
        self.config = config

    def compute_advanced_features(self, df: pl.DataFrame) -> pl.DataFrame:
        """Compute advanced trading features.

        Features are added in two passes. The market-regime features read
        the ``realized_vol`` column, and every expression in a single
        ``with_columns`` call sees only the columns that existed before the
        call.
        """
        first_pass = []

        # Volatility features
        first_pass.extend(self._volatility_features())

        # Order book features
        first_pass.extend(self._orderbook_features())

        # Cross-asset correlation features
        first_pass.extend(self._correlation_features())

        df = df.with_columns(first_pass)

        # Market regime features (need 'realized_vol' from the first pass)
        return df.with_columns(self._market_regime_features())

    def _volatility_features(self) -> List[pl.Expr]:
        """Complex volatility indicators over a 20-row window."""
        # Squared log high/low range, the building block of the Parkinson
        # estimator.
        log_range_sq = (pl.col('high').log() - pl.col('low').log()).pow(2)
        return [
            # Realized volatility, scaled by sqrt(252)
            (pl.col('returns').rolling_std(window_size=20) * np.sqrt(252))
                .alias('realized_vol'),

            # Parkinson volatility: sqrt(mean(ln(H/L)^2) / (4 ln 2))
            (log_range_sq.rolling_mean(window_size=20) / (4 * np.log(2)))
                .sqrt()
                .alias('parkinson_vol'),

            # GARCH volatility estimation
            pl.col('returns').map_batches(self._estimate_garch)
                .alias('garch_vol')
        ]

    def _market_regime_features(self) -> List[pl.Expr]:
        """Market regime identification features.

        NOTE(review): ``_identify_vol_regime`` is not defined in the visible
        class; presumably provided elsewhere -- confirm.
        """
        def _hurst_column(series: pl.Series) -> pl.Series:
            # The exponent is one number for the whole column. Repeat it so
            # the result has the same length as the input.
            value = self._compute_hurst_exponent(series.to_numpy())
            return pl.Series(np.full(len(series), value, dtype=float))

        return [
            # Trend strength
            pl.col('close').map_batches(_hurst_column)
                .alias('hurst_exponent'),

            # Volatility regime
            pl.col('realized_vol')
                .map_batches(self._identify_vol_regime)
                .alias('vol_regime'),

            # Market efficiency ratio
            (pl.col('close').diff().abs().rolling_sum(20) /
             (pl.col('high').rolling_max(20) -
              pl.col('low').rolling_min(20)))
                .alias('efficiency_ratio')
        ]

    def _orderbook_features(self) -> List[pl.Expr]:
        """Order book derived features."""
        return [
            # Bid-ask spread
            ((pl.col('ask') - pl.col('bid')) / pl.col('mid'))
                .alias('relative_spread'),

            # Order book imbalance
            ((pl.col('bid_size') - pl.col('ask_size')) /
             (pl.col('bid_size') + pl.col('ask_size')))
                .alias('ob_imbalance'),

            # Market depth
            (pl.col('bid_size') + pl.col('ask_size'))
                .alias('market_depth')
        ]

    def _correlation_features(self) -> List[pl.Expr]:
        """Cross-asset correlation features over a 20-row window."""
        # Rolling correlation is a top-level polars function taking both
        # inputs, not a method on an expression.
        return [
            # Rolling correlation with market
            pl.rolling_corr(
                pl.col('returns'),
                pl.col('market_returns'),
                window_size=20
            ).alias('market_correlation'),

            # Sector correlation
            pl.rolling_corr(
                pl.col('returns'),
                pl.col('sector_returns'),
                window_size=20
            ).alias('sector_correlation')
        ]

    @staticmethod
    def _estimate_garch(returns: np.ndarray) -> np.ndarray:
        """Estimate GARCH(1,1) volatility with fixed parameters.

        Uses ``omega=1e-6``, ``alpha=0.1``, ``beta=0.8``. The first value is
        seeded with the standard deviation of the non-missing returns. A
        missing return contributes a zero shock, so one NaN does not
        propagate through the whole series. Empty input yields an empty
        array.
        """
        omega = 0.000001
        alpha = 0.1
        beta = 0.8

        series = np.asarray(returns, dtype=float)
        vol = np.zeros_like(series)
        if series.size == 0:
            return vol

        observed = series[~np.isnan(series)]
        vol[0] = np.std(observed) if observed.size else 0.0
        shocks = np.nan_to_num(series, nan=0.0)

        for t in range(1, len(series)):
            vol[t] = np.sqrt(
                omega +
                alpha * shocks[t-1]**2 +
                beta * vol[t-1]**2
            )

        return vol

    @staticmethod
    def _compute_hurst_exponent(prices: np.ndarray, lags: int = 20) -> float:
        """Compute Hurst exponent for trend strength.

        The standard deviation of ``lag``-step price differences scales as
        ``lag ** H``, so H is the slope of ``log(std)`` against ``log(lag)``
        for lags ``2 .. lags - 1``. Taking a further square root of the
        standard deviation would halve the slope.

        Returns NaN when there are too few observations, or when a lagged
        difference has zero spread (for example a constant series).
        """
        series = np.asarray(prices, dtype=float)
        series = series[~np.isnan(series)]

        lag_values = np.arange(2, lags)
        # Need at least two lags for a slope, and at least two differences
        # at the largest lag.
        if lag_values.size < 2 or series.size < lag_values[-1] + 2:
            return float('nan')

        tau = np.array([
            np.std(series[lag:] - series[:-lag]) for lag in lag_values
        ])
        if np.any(tau <= 0):
            return float('nan')

        slope = np.polyfit(np.log(lag_values), np.log(tau), 1)[0]
        return float(slope)  # Hurst exponent

Integration with Ollama and LLMs

We can enhance the system by incorporating large language models for sentiment analysis and market context:

In [None]:
import requests
from typing import List, Dict
import aiohttp
import json

class MarketContextAnalyzer:
    def __init__(self, config):
        """Keep the configuration and the URL of the local Ollama generate API."""
        endpoint = "http://localhost:11434/api/generate"
        self.config, self.ollama_endpoint = config, endpoint

    async def analyze_market_context(
        self,
        market_data: Dict,
        news_data: List[str],
        social_data: List[str]
    ) -> Dict:
        """Analyze market context using LLM.

        Builds the analysis prompt, posts it to the Ollama endpoint as a
        non-streaming ``llama2`` request, and returns the parsed
        ``'response'`` field.

        Raises:
            aiohttp.ClientResponseError: If the endpoint answers with an
                HTTP error status.
            KeyError: If the JSON reply has no ``'response'`` field.
        """
        prompt = self._construct_analysis_prompt(
            market_data, news_data, social_data
        )
        payload = {
            "model": "llama2",
            "prompt": prompt,
            "stream": False
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(
                self.ollama_endpoint,
                json=payload
            ) as response:
                # Report an HTTP failure directly instead of as a missing
                # 'response' key further down.
                response.raise_for_status()
                result = await response.json()

        return self._parse_llm_response(result['response'])

    def _construct_analysis_prompt(
        self,
        market_data: Dict,
        news_data: List[str],
        social_data: List[str]
    ) -> str:
        """Render the LLM prompt from market figures, news and social posts.

        Reads ``price``, ``change_24h`` and ``volume`` from *market_data*.
        The news and social sections come from the instance's formatting
        helpers.
        """
        price = market_data['price']
        change_24h = market_data['change_24h']
        volume = market_data['volume']
        news_section = self._format_news(news_data)
        social_section = self._format_social(social_data)
        return f"""
        Analyze the following market context:

        Market Data:
        - Current Price: {price}
        - 24h Change: {change_24h}%
        - Volume: {volume}

        Recent News:
        {news_section}

        Social Sentiment:
        {social_section}

        Provide analysis of:
        1. Market sentiment
        2. Key drivers
        3. Risk factors
        4. Trading opportunities

        Format response as JSON with these keys.
        """

    def _parse_llm_response(self, response: str) -> Dict:
        """Parse and validate LLM response."""
        try:
            analysis = json.loads(response)
            required_keys = [
                'market_sentiment',
                'key_drivers',

Generate the Metadata Catalogue Index for /CCXT

In [None]:
# --------------------------------------------------------------
# Generate metadata catalogue
# --------------------------------------------------------------

import os
import sqlite3
import pyarrow.parquet as pq
import datetime
from tqdm import tqdm
from pathlib import Path
import logging
import yaml
import time
import shutil
import json
from concurrent.futures import ProcessPoolExecutor

# Check if the script is running in Colab
try:
    from google.colab import drive
    drive.mount('/content/drive')
except ModuleNotFoundError:
    logging.info("Not running in Colab, skipping Google Drive mount.")

# Path to config file
CONFIG_FILE = "/content/drive/MyDrive/data_warehouse/config.yaml"

# Default configuration (also the fallback for keys missing from the file)
default_config = {
    'batch_size': 1000,
    'log_level': 'INFO',
    'google_drive_dir': '/content/drive/MyDrive/data_warehouse/',
    'db_file': '/content/drive/MyDrive/data_warehouse/parquet_metadata_catalog.db',
    'backup_dir': '/content/drive/MyDrive/data_warehouse/backups/'
}

# Check if config.yaml exists, and if not, create it
if not os.path.exists(CONFIG_FILE):
    # Create the directory if it doesn't exist
    os.makedirs(os.path.dirname(CONFIG_FILE), exist_ok=True)

    # Write the default configuration to the file
    with open(CONFIG_FILE, 'w') as yaml_file:
        yaml.dump(default_config, yaml_file)
    print(f"Default config.yaml created at {CONFIG_FILE}")

# Load configuration from YAML
with open(CONFIG_FILE, 'r') as f:
    # safe_load returns None for an empty file. Fall back to an empty mapping
    # so the .get() lookups below resolve to the defaults instead of failing.
    config = yaml.safe_load(f) or {}

# Configuration
BATCH_SIZE = config.get('batch_size', default_config['batch_size'])
LOG_LEVEL = config.get('log_level', default_config['log_level']).upper()
GOOGLE_DRIVE_DIR = Path(config.get('google_drive_dir', default_config['google_drive_dir']))
DB_FILE = Path(config.get('db_file', default_config['db_file']))
BACKUP_DIR = Path(config.get('backup_dir', default_config['backup_dir']))

# Setup logging
logging.basicConfig(level=LOG_LEVEL, format='%(asctime)s - %(levelname)s - %(message)s')
# basicConfig does nothing once the root logger has handlers (for example
# after the logging.info call above), so apply the configured level directly.
logging.getLogger().setLevel(LOG_LEVEL)

# Ensure directory exists for the DB and backup
DB_FILE.parent.mkdir(parents=True, exist_ok=True)
BACKUP_DIR.mkdir(parents=True, exist_ok=True)

# Backup function
def backup_database(db_file, backup_dir):
    """Copy *db_file* into *backup_dir* under a timestamped name.

    The copy is named ``parquet_metadata_backup_<YYYYmmdd_HHMMSS>.db`` and
    *backup_dir* must support the ``/`` path operator.
    """
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    destination = backup_dir / f"parquet_metadata_backup_{stamp}.db"
    shutil.copy(db_file, destination)
    logging.info(f"Database backup successful: {destination}")

# Function to extract metadata from Parquet files
def extract_parquet_metadata(file_path):
    """Build the catalogue record for one Parquet file.

    Args:
        file_path: Path-like object exposing ``name``, ``stem``, ``parent`` and
            ``stat()`` (a ``pathlib.Path`` in the visible callers).

    Returns:
        A dict with keys ``file_name``, ``source`` (file name without
        extension), ``timeframe`` (name of the containing folder), ``schema``
        (JSON object mapping column name to Parquet physical type) and
        ``last_updated`` (file mtime as a naive local ``datetime``), or
        ``None`` if the file could not be read. Failures are logged, not raised.
    """
    try:
        reader = pq.ParquetFile(file_path)
        column_types = {column.name: column.physical_type for column in reader.schema}
        schema_json = json.dumps(column_types)
        modified_at = datetime.datetime.fromtimestamp(file_path.stat().st_mtime)
        return {
            "file_name": file_path.name,
            "source": file_path.stem,  # filename without extension
            "timeframe": file_path.parent.name,  # folder name doubles as timeframe
            "schema": schema_json,
            "last_updated": modified_at
        }
    except pq.lib.ArrowInvalid as e:
        logging.error(f"Invalid Parquet file {file_path}: {e}")
        return None
    except Exception as e:
        logging.error(f"Error processing {file_path}: {e}")
        return None

# Initialize SQLite connection with timeout
def initialize_db(db_file):
    """Open the SQLite catalogue and make sure its schema exists.

    Connects with a 10 second lock timeout, creates the ``metadata`` and
    ``processed_files`` tables and the ``idx_processed_file_name`` index if they
    are missing, and commits.

    Args:
        db_file: Database location accepted by ``sqlite3.connect``.

    Returns:
        A ``(connection, cursor)`` tuple; the caller owns and must close the
        connection.
    """
    ddl_statements = (
        '''
        CREATE TABLE IF NOT EXISTS metadata (
            id INTEGER PRIMARY KEY,
            file_name TEXT,
            source TEXT,
            timeframe TEXT,
            schema TEXT,
            last_updated TIMESTAMP
        )
    ''',
        '''
        CREATE TABLE IF NOT EXISTS processed_files (
            file_name TEXT PRIMARY KEY
        )
    ''',
        'CREATE INDEX IF NOT EXISTS idx_processed_file_name ON processed_files(file_name)',
    )

    conn = sqlite3.connect(db_file, timeout=10)
    cursor = conn.cursor()
    # Tables first, then the index that refers to processed_files.
    for statement in ddl_statements:
        cursor.execute(statement)
    conn.commit()
    return conn, cursor

# Process Parquet files concurrently and return the metadata
def process_files_concurrently(files):
    """Extract metadata for every file using a pool of four worker processes.

    Args:
        files: Sized collection of Parquet file paths (``len()`` is used for the
            progress bar total).

    Returns:
        The metadata dicts in input order, with falsy results (files whose
        extraction returned ``None``) removed.
    """
    with ProcessPoolExecutor(max_workers=4) as pool:  # Adjust max_workers if necessary
        pending = pool.map(extract_parquet_metadata, files)
        collected = list(tqdm(pending, total=len(files), desc="Processing files"))
        # filter(None, ...) keeps only truthy entries, i.e. drops failed files.
        return list(filter(None, collected))

def _metadata_row(metadata):
    """Return one metadata dict as a tuple in the INSERT column order.

    A ``datetime`` in ``last_updated`` is rendered with ``isoformat(" ")`` --
    the same text sqlite3's built-in datetime adapter stores -- so the insert
    does not rely on that adapter, which is deprecated as of Python 3.12.
    """
    last_updated = metadata["last_updated"]
    if isinstance(last_updated, datetime.datetime):
        last_updated = last_updated.isoformat(" ")
    return (
        metadata["file_name"],
        metadata["source"],
        metadata["timeframe"],
        metadata["schema"],
        last_updated,
    )


# Main function to process Parquet files and insert metadata
def process_parquet_files():
    """Catalogue every Parquet file under GOOGLE_DRIVE_DIR into the SQLite DB.

    Steps:
      1. Open/initialise the database at DB_FILE.
      2. Recursively collect ``*.parquet`` files and extract their metadata in
         parallel.
      3. Retry, serially in this process, every file whose name is absent from
         the extracted results (one extra attempt per file).
      4. Insert the metadata in batches of BATCH_SIZE, committing per batch,
         and record each file name in ``processed_files``.
      5. Create the ``source``/``timeframe`` indexes and VACUUM.

    A KeyboardInterrupt is logged and swallowed; any other exception
    propagates. In every case the connection is closed and the database is
    backed up to BACKUP_DIR afterwards.
    """
    conn, cursor = initialize_db(DB_FILE)

    logging.info("Collecting Parquet files...")
    all_files = list(tqdm(GOOGLE_DRIVE_DIR.rglob("*.parquet"), desc="Collecting Parquet files"))
    logging.info(f"Found {len(all_files)} Parquet files to process.")

    try:
        # Process files concurrently and collect metadata
        metadata_list = process_files_concurrently(all_files)

        # Retry logic for failed files (if any). The concurrent pass drops
        # failures without saying which file they belonged to, so failures are
        # recovered by name. If two files share a base name and only one of
        # them failed, that one is not detected here.
        extracted_names = {m["file_name"] for m in metadata_list}
        failed_files = [path for path in all_files if path.name not in extracted_names]
        if failed_files:
            logging.warning(f"Retrying failed files: {failed_files}")
            for file in failed_files:
                metadata = extract_parquet_metadata(file)
                if metadata:
                    metadata_list.append(metadata)

        # Batch insert metadata into the database
        for i in range(0, len(metadata_list), BATCH_SIZE):
            batch = metadata_list[i:i + BATCH_SIZE]
            cursor.executemany('''
                INSERT INTO metadata (file_name, source, timeframe, schema, last_updated)
                VALUES (?, ?, ?, ?, ?)
            ''', [_metadata_row(m) for m in batch])
            conn.commit()
            logging.info(f"Inserted {len(batch)} records into the database.")

        # Mark files as processed
        cursor.executemany(
            'INSERT OR IGNORE INTO processed_files (file_name) VALUES (?)',
            [(m["file_name"],) for m in metadata_list],
        )
        conn.commit()

        # Create indexes for better query performance
        logging.info("Creating indexes on source and timeframe columns...")
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_source ON metadata(source)')
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_timeframe ON metadata(timeframe)')
        conn.commit()

        # Optimize the database. VACUUM cannot run inside an open transaction,
        # hence the commit immediately above.
        logging.info("Running VACUUM to optimize the database...")
        cursor.execute('VACUUM')

    except KeyboardInterrupt:
        logging.warning("Script interrupted by user.")
    finally:
        # Close connection gracefully
        conn.close()
        logging.info("Database connection closed.")
        # Backup the database after processing
        backup_database(DB_FILE, BACKUP_DIR)

if __name__ == "__main__":
    # Time the whole catalogue run (wall clock) for a simple performance metric.
    started_at = time.time()
    process_parquet_files()
    elapsed = time.time() - started_at

    logging.info(f"Total processing time: {elapsed} seconds")


Train models

In [None]:
#!/usr/bin/env python3
"""
Advanced Crypto Data Aggregator, Preprocessor, and Model Trainer using Polars
with a Pre-Computed File Index for Fast Recursive Lookups

This script performs the following tasks:
1. Loads a metadata catalogue from a SQLite database.
2. Recursively scans the ROOT_DIR once to build an index of all Parquet files.
3. Uses that index to quickly locate files for each metadata record.
4. Aggregates data from thousands of Parquet files (hundreds of exchanges/pairs).
5. Applies advanced feature engineering (technical indicators, lag features,
   one-hot encoding of categorical variables).
6. Creates sequences for LSTM training.
7. Splits data chronologically.
8. Trains multiple models (baseline RandomForest/XGBoost with GridSearchCV tuning,
   plus an LSTM model) and compares performance (MAE/MSE).
9. Saves the best-performing model.
10. Provides extensive logging and error handling.

Dependencies:
    - polars, pandas, numpy, pyarrow, tqdm
    - scikit-learn, xgboost, tensorflow
    - sqlite3, pickle (built-in)
    - concurrent.futures (built-in)
"""

import os
import sys
import sqlite3
import logging
import json
import datetime
import time
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed

import polars as pl
import numpy as np
from tqdm import tqdm

# For model training (we switch to Pandas/NumPy as needed)
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split, GridSearchCV

import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import LSTM, Dense, Dropout

import pickle

# ------------------- Logging Configuration ------------------- #
# Root logger at INFO with timestamped lines. Note that logging.basicConfig
# does nothing if the root logger already has handlers (for example when an
# earlier notebook cell configured logging in the same session).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by every function below.
logger = logging.getLogger(__name__)

# ------------------- Global Configuration ------------------- #
# Update these paths according to your Google Drive structure.
# DATA_FOLDER is a placeholder path; no code visible in this section reads it.
DATA_FOLDER = Path("/content/drive/MyDrive/YourDataFolder")  # (e.g., Folder with id: 1zk392ymU_pBtRgua-OWCq_wZyJ5ejnlF)
# ROOT_DIR is the directory tree that is scanned recursively for .parquet files.
ROOT_DIR = Path("/content/drive/MyDrive/data_warehouse")       # (e.g., Folder with id: 1ACX5jdVuRwCUo_0L9rVUC3HOdGz06dPb)
# METADATA_DB is the SQLite catalogue whose `metadata` table lists the files to load.
METADATA_DB = Path("/content/drive/MyDrive/data_warehouse/parquet_metadata_catalog.db")  # (your metadata catalogue)

# ------------------- 0. File Indexing ------------------- #
def index_all_files(root_dir: Path) -> dict:
    """
    Walk root_dir once and index every file whose name ends in '.parquet'.

    Returns a plain dict mapping a bare file name
    (e.g., "binance_BTC_USDT_1h.parquet") to the list of full Path objects
    carrying that name, in the order os.walk encountered them. Building the
    index up front turns later per-file lookups into dictionary hits.
    """
    logger.info(f"Indexing all Parquet files under {root_dir} ...")
    file_index = {}
    # os.walk is used for speed (uses os.scandir internally)
    for dirpath, _dirnames, filenames in os.walk(root_dir):
        folder = Path(dirpath)
        parquet_names = (name for name in filenames if name.endswith('.parquet'))
        for name in parquet_names:
            file_index.setdefault(name, []).append(folder / name)
    logger.info(f"Indexed {len(file_index)} unique Parquet file names.")
    return file_index

# Global index variable (to be built once)
# NOTE: this executes when the module/cell is loaded, so the full recursive
# walk of ROOT_DIR happens before main() is ever called.
FILE_INDEX = index_all_files(ROOT_DIR)

def find_parquet_file(file_name: str, file_index: dict) -> Path:
    """
    Resolve file_name to a path through the pre-built file_index.

    When several paths share the name, the first one recorded in the index is
    returned. A name missing from the index produces a warning and None.
    """
    if file_name not in file_index:
        logger.warning(f"File {file_name} not found in index.")
        return None
    candidates = file_index[file_name]
    return candidates[0]  # Return first occurrence

# ------------------- 1. Data Loader Functions (using Polars) ------------------- #
def load_metadata(db_path: Path):
    """
    Load every row of the `metadata` table from the SQLite database.

    Args:
        db_path: Location of the SQLite catalogue.

    Returns:
        A pandas DataFrame containing the metadata records.

    Any failure (pandas missing, database unreadable, table absent) is logged
    and terminates the process via sys.exit(1), as before. The connection is
    now closed on the failure path as well, so a failed query no longer leaks
    an open handle on the database file.
    """
    try:
        import pandas as pd  # Use pandas here for SQL convenience
        conn = sqlite3.connect(str(db_path))
        try:
            df_meta = pd.read_sql_query("SELECT * FROM metadata", conn)
        finally:
            # Runs whether or not the query raised.
            conn.close()
    except Exception as e:
        logger.error(f"Error loading metadata: {e}")
        sys.exit(1)
    logger.info(f"Loaded {len(df_meta)} metadata records.")
    return df_meta

def load_single_parquet_polars(file_path: Path) -> pl.DataFrame:
    """
    Read one Parquet file with Polars and tag it with its origin.

    Three constant columns are appended:
      - exchange: first '_'-separated token of the file stem
      - pair: second and third tokens joined by '_'
      - timeframe: name of the folder containing the file
    Stems with fewer than three tokens get "unknown" for exchange and pair.
    On any error the failure is logged and an empty DataFrame is returned.
    """
    try:
        frame = pl.read_parquet(str(file_path))
        tokens = file_path.stem.split('_')  # e.g., "binance_BTC_USDT_1h"
        exchange, pair = "unknown", "unknown"
        if len(tokens) >= 3:
            exchange = tokens[0]
            pair = "_".join(tokens[1:3])
        labels = {
            "exchange": exchange,
            "pair": pair,
            "timeframe": file_path.parent.name,  # Use parent folder name as timeframe
        }
        return frame.with_columns([pl.lit(value).alias(column) for column, value in labels.items()])
    except Exception as e:
        logger.error(f"Error loading {file_path}: {e}")
        return pl.DataFrame()

def aggregate_data(metadata_df, root_dir: Path, file_index: dict) -> pl.DataFrame:
    """
    Load and concatenate all Parquet files listed in the metadata DataFrame.

    Args:
        metadata_df: pandas DataFrame with a "file_name" column.
        root_dir: Unused; kept so existing callers keep working.
        file_index: Mapping of file name -> list of paths (see index_all_files).

    Returns:
        One Polars DataFrame with every non-empty file stacked together. If a
        "timestamp" column exists it is cast to Datetime, rows with a null
        timestamp are dropped, and the frame is sorted by it.

    Exits the process with status 1 when no file yields any rows.

    Each distinct file name is loaded only once, even if the catalogue lists it
    on several rows; loading it repeatedly would duplicate its rows in the
    aggregated data.
    """
    file_paths = []
    # dict.fromkeys de-duplicates while preserving first-seen order.
    for file_name in dict.fromkeys(metadata_df["file_name"]):
        fp = find_parquet_file(file_name, file_index)
        if fp is not None:
            file_paths.append(fp)
    logger.info(f"Found {len(file_paths)} Parquet files to load (via index).")

    df_list = []
    with ProcessPoolExecutor(max_workers=8) as executor:
        futures = {executor.submit(load_single_parquet_polars, fp): fp for fp in file_paths}
        for future in tqdm(as_completed(futures), total=len(futures), desc="Loading Parquet Files"):
            df_result = future.result()
            # The loader returns an empty frame on failure; skip those.
            if df_result.height > 0:
                df_list.append(df_result)

    if not df_list:
        logger.error("No data loaded from Parquet files.")
        sys.exit(1)

    aggregated_df = pl.concat(df_list)
    if "timestamp" in aggregated_df.columns:
        # DataFrame.with_column has been removed from Polars; with_columns is
        # the supported spelling and accepts a single expression.
        aggregated_df = aggregated_df.with_columns(pl.col("timestamp").cast(pl.Datetime))
        aggregated_df = aggregated_df.drop_nulls(subset=["timestamp"]).sort("timestamp")
    logger.info(f"Aggregated DataFrame shape: {aggregated_df.shape}")
    return aggregated_df

# ------------------- 2. Preprocessing Pipeline Functions (using Polars) ------------------- #
def calculate_technical_indicators(df: pl.DataFrame) -> pl.DataFrame:
    """
    Compute technical indicators on the 'close' price.

    Adds EMA_8 and EMA_21 (approximated with simple rolling means, as before),
    SMA_10, and RSI_14 (100 - 100 / (1 + avg_gain / avg_loss) over 14 rows).

    All rolling windows and the price difference are evaluated per series:
    whichever of the columns exchange / pair / timeframe exist are used as
    grouping keys. The aggregated frame interleaves many markets sorted by
    timestamp, so an ungrouped window would average prices of unrelated
    instruments. If none of those columns exist the frame is treated as a
    single series. Row order is unchanged.

    NOTE(review): a window with no gains and no losses gives 0/0, i.e. a NaN
    RSI that drop_nulls() does not remove; decide how flat windows should be
    handled before feeding this to a model that rejects NaN.
    """
    group_keys = [name for name in ("exchange", "pair", "timeframe") if name in df.columns]

    def per_series(expr):
        # Restrict a window expression to one market at a time when possible.
        return expr.over(group_keys) if group_keys else expr

    close = pl.col("close")
    df = df.with_columns([
        per_series(close.rolling_mean(window_size=8)).alias("EMA_8"),
        per_series(close.rolling_mean(window_size=21)).alias("EMA_21"),
        per_series(close.rolling_mean(window_size=10)).alias("SMA_10"),
        # RSI calculation starts from the per-series price change.
        per_series(close.diff()).alias("delta"),
    ])
    df = df.with_columns([
        pl.when(pl.col("delta") > 0).then(pl.col("delta")).otherwise(0).alias("gain"),
        pl.when(pl.col("delta") < 0).then(-pl.col("delta")).otherwise(0).alias("loss")
    ])
    df = df.with_columns([
        per_series(pl.col("gain").rolling_mean(window_size=14)).alias("avg_gain"),
        per_series(pl.col("loss").rolling_mean(window_size=14)).alias("avg_loss")
    ])
    df = df.with_columns([
        (pl.col("avg_gain") / pl.col("avg_loss")).alias("rs")
    ])
    df = df.with_columns([
        (100 - (100 / (1 + pl.col("rs")))).alias("RSI_14")
    ])
    # Drop the intermediate RSI columns; only the indicators remain.
    df = df.drop(["delta", "gain", "loss", "avg_gain", "avg_loss", "rs"])
    return df

def add_lag_features_and_target(df: pl.DataFrame, lag: int = 1) -> pl.DataFrame:
    """
    Add a lagged close feature and the prediction target.

    Args:
        df: Frame with a "close" column.
        lag: Number of rows to look back for the lag feature. The column is
            named "close_lag_<lag>" ("close_lag_1" for the default).

    Returns:
        The frame with the lag column and "target" (the next closing price)
        appended, after dropping every row that contains a null in any column.

    Shifts are evaluated per series, grouped by whichever of exchange / pair /
    timeframe exist, so a row's lag and target come from the same market and
    not from a neighbouring row of another one.
    """
    group_keys = [name for name in ("exchange", "pair", "timeframe") if name in df.columns]

    def per_series(expr):
        # Keep shifts inside one market when the grouping columns are present.
        return expr.over(group_keys) if group_keys else expr

    close = pl.col("close")
    # with_columns replaces DataFrame.with_column, which Polars has removed.
    df = df.with_columns([
        per_series(close.shift(lag)).alias(f"close_lag_{lag}"),
        per_series(close.shift(-1)).alias("target"),
    ])
    return df.drop_nulls()

def encode_categorical_features(df: pl.DataFrame) -> pl.DataFrame:
    """
    Replace the exchange, pair and timeframe columns with one-hot columns.

    Delegates to Polars' DataFrame.to_dummies for those three columns; all
    other columns are passed through as returned by that call.
    """
    categorical_columns = ["exchange", "pair", "timeframe"]
    return df.to_dummies(columns=categorical_columns)

def preprocess_data(df: pl.DataFrame) -> pl.DataFrame:
    """
    Run the preprocessing stages in order and return the resulting frame:
      1. Technical indicators.
      2. Lag feature and target (lag of 1).
      3. One-hot encoding of the categorical columns.
    Each stage is announced in the log, followed by the final shape.
    """
    stages = (
        ("Calculating technical indicators...", calculate_technical_indicators),
        ("Adding lag features and target variable...",
         lambda frame: add_lag_features_and_target(frame, lag=1)),
        ("Encoding categorical features...", encode_categorical_features),
    )
    for announcement, stage in stages:
        logger.info(announcement)
        df = stage(df)
    logger.info(f"Preprocessed DataFrame shape: {df.shape}")
    return df

def split_data(df: pl.DataFrame, test_ratio: float = 0.2):
    """
    Split the frame chronologically: the earliest rows train, the latest test.

    The frame is sorted by "timestamp"; the first int(n * (1 - test_ratio))
    rows form the training set. Features are every column except "timestamp"
    and "target".

    Returns X_train, X_test (pandas DataFrames) and y_train, y_test (1-D numpy
    arrays), in that order.
    """
    ordered = df.sort("timestamp")
    cut = int(ordered.height * (1 - test_ratio))

    # Convert once, then slice both halves at the same row position.
    X = ordered.drop(["timestamp", "target"]).to_pandas()
    y = ordered.select("target").to_pandas().values.ravel()

    X_train, X_test = X.iloc[:cut], X.iloc[cut:]
    y_train, y_test = y[:cut], y[cut:]
    logger.info(f"Split data: Training set: {X_train.shape}, Test set: {X_test.shape}")
    return X_train, X_test, y_train, y_test

# ------------------- 3. Sequence Creation for LSTM ------------------- #
def create_sequences(X, y, sequence_length: int = 10):
    """
    Create sliding-window sequences for LSTM training.

    Args:
        X: Feature rows in chronological order (pandas DataFrame, or anything
            np.asarray can turn into a 2-D array).
        y: Targets aligned row-for-row with X (y[j] belongs to row j).
        sequence_length: Number of consecutive rows per window (>= 1).

    Returns:
        (X_seq, y_seq) where X_seq has shape
        (n_windows, sequence_length, n_features) and y_seq has shape
        (n_windows,), with n_windows = max(len(X) - sequence_length + 1, 0).

    Raises:
        ValueError: if sequence_length is smaller than 1.

    Window i covers rows i .. i+sequence_length-1 and is paired with the target
    of its LAST row. Because y is already the next-step value for each row,
    using y[i + sequence_length] (as this function used to) skipped one step
    and made the result incomparable with the row-wise models. When there are
    too few rows for a single window, correctly shaped empty arrays are
    returned instead of a shapeless (0,) array.
    """
    if sequence_length < 1:
        raise ValueError("sequence_length must be at least 1")

    X_values = X.to_numpy() if hasattr(X, "to_numpy") else np.asarray(X)
    y_values = np.asarray(y)
    n_windows = len(X_values) - sequence_length + 1

    if n_windows <= 0:
        X_seq = np.empty((0, sequence_length) + X_values.shape[1:], dtype=X_values.dtype)
        y_seq = np.empty((0,), dtype=y_values.dtype)
    else:
        X_seq = np.stack([X_values[i:i + sequence_length] for i in range(n_windows)])
        last_row = sequence_length - 1
        y_seq = y_values[last_row:last_row + n_windows]

    logger.info(f"Created sequences: X_seq shape = {X_seq.shape}, y_seq shape = {y_seq.shape}")
    return X_seq, y_seq

# ------------------- 4. Model Training Functions ------------------- #
def train_rf_model(X_train, y_train, X_test, y_test):
    """
    Train a RandomForestRegressor with hyperparameter tuning via GridSearchCV.

    Args:
        X_train, y_train: Training features/targets in chronological order.
        X_test, y_test: Held-out features/targets used only for the final score.

    Returns:
        (best_model, mae, mse): the refit best estimator and its MAE/MSE on the
        test set.

    Tuning uses TimeSeriesSplit(n_splits=3): every validation fold lies
    strictly after the rows it was trained on. The previous cv=3 used plain
    K-fold, which on chronologically ordered data scores candidates that were
    trained on the future.
    """
    # Local import keeps this function self-contained; the file's top-level
    # import from sklearn.model_selection does not include TimeSeriesSplit.
    from sklearn.model_selection import TimeSeriesSplit

    rf = RandomForestRegressor(random_state=42)
    param_grid = {"n_estimators": [100, 200], "max_depth": [5, 10, None]}
    grid = GridSearchCV(rf, param_grid, cv=TimeSeriesSplit(n_splits=3),
                        scoring="neg_mean_absolute_error", n_jobs=-1)
    grid.fit(X_train, y_train)
    best_rf = grid.best_estimator_
    y_pred = best_rf.predict(X_test)
    mae = mean_absolute_error(y_test, y_pred)
    mse = mean_squared_error(y_test, y_pred)
    logger.info(f"RandomForest best params: {grid.best_params_}")
    logger.info(f"RandomForest - MAE: {mae:.4f}, MSE: {mse:.4f}")
    return best_rf, mae, mse

def train_xgb_model(X_train, y_train, X_test, y_test):
    """
    Train an XGBRegressor with hyperparameter tuning via GridSearchCV.

    Args:
        X_train, y_train: Training features/targets in chronological order.
        X_test, y_test: Held-out features/targets used only for the final score.

    Returns:
        (best_model, mae, mse): the refit best estimator and its MAE/MSE on the
        test set.

    Tuning uses TimeSeriesSplit(n_splits=3): every validation fold lies
    strictly after the rows it was trained on. The previous cv=3 used plain
    K-fold, which on chronologically ordered data scores candidates that were
    trained on the future.
    """
    # Local import keeps this function self-contained; the file's top-level
    # import from sklearn.model_selection does not include TimeSeriesSplit.
    from sklearn.model_selection import TimeSeriesSplit

    xgb = XGBRegressor(random_state=42)
    param_grid = {"n_estimators": [100, 200], "learning_rate": [0.05, 0.1], "max_depth": [5, 10]}
    grid = GridSearchCV(xgb, param_grid, cv=TimeSeriesSplit(n_splits=3),
                        scoring="neg_mean_absolute_error", n_jobs=-1)
    grid.fit(X_train, y_train)
    best_xgb = grid.best_estimator_
    y_pred = best_xgb.predict(X_test)
    mae = mean_absolute_error(y_test, y_pred)
    mse = mean_squared_error(y_test, y_pred)
    logger.info(f"XGBoost best params: {grid.best_params_}")
    logger.info(f"XGBoost - MAE: {mae:.4f}, MSE: {mse:.4f}")
    return best_xgb, mae, mse

def train_lstm_model(X_train, y_train, X_test, y_test, epochs=20, batch_size=32, sequence_length=10):
    """
    Fit a two-layer LSTM regressor on windowed data and score it.

    Train and test sets are each converted to sequences of `sequence_length`
    rows. The network is LSTM(64) -> Dropout(0.2) -> LSTM(64) -> Dropout(0.2)
    -> Dense(1), trained with Adam on mean squared error. The test sequences
    are also passed to fit() as validation data.

    Returns (model, mae, mse), with the metrics computed on the test sequences.
    """
    train_windows, train_targets = create_sequences(X_train, y_train, sequence_length)
    test_windows, test_targets = create_sequences(X_test, y_test, sequence_length)

    # Input shape is (timesteps, features) of a single training window.
    timesteps, n_features = train_windows.shape[1], train_windows.shape[2]

    model = Sequential()
    model.add(LSTM(64, return_sequences=True, input_shape=(timesteps, n_features)))
    model.add(Dropout(0.2))
    model.add(LSTM(64))
    model.add(Dropout(0.2))
    model.add(Dense(1))
    model.compile(optimizer="adam", loss="mean_squared_error")

    model.fit(
        train_windows,
        train_targets,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(test_windows, test_targets),
        verbose=1,
    )

    predictions = model.predict(test_windows).flatten()
    mae = mean_absolute_error(test_targets, predictions)
    mse = mean_squared_error(test_targets, predictions)
    logger.info(f"LSTM - MAE: {mae:.4f}, MSE: {mse:.4f}")
    return model, mae, mse

def compare_model_results(results_dict):
    """
    Log each model's MAE/MSE and return the key of the model with the lowest MAE.

    Ties keep the model seen first (a later model must be strictly better).
    Returns None when no model has an MAE below infinity, e.g. for an empty
    dict.
    """
    leader = (float("inf"), None)  # (best MAE so far, its model key)
    for model_name, metrics in results_dict.items():
        mae, mse = metrics["MAE"], metrics["MSE"]
        logger.info(f"Model {model_name}: MAE = {mae:.4f}, MSE = {mse:.4f}")
        if mae < leader[0]:
            leader = (mae, model_name)
    best_mae, best_model_key = leader
    logger.info(f"Best model: {best_model_key} with MAE {best_mae:.4f}")
    return best_model_key

# ------------------- 5. Orchestration: Main Execution ------------------- #
def main():
    """
    Run the whole pipeline: load the catalogue, aggregate and preprocess the
    data, split it chronologically, train RandomForest, XGBoost and LSTM
    models, pick the one with the lowest MAE and save it to the working
    directory (.h5 for the LSTM, pickle otherwise). Total runtime is logged.
    """
    overall_start = time.time()

    # Steps 1-4: catalogue -> aggregated frame -> features -> train/test split
    metadata_df = load_metadata(METADATA_DB)
    aggregated_df = aggregate_data(metadata_df, ROOT_DIR, FILE_INDEX)
    preprocessed_df = preprocess_data(aggregated_df)
    X_train, X_test, y_train, y_test = split_data(preprocessed_df, test_ratio=0.2)

    # Step 5: Train models, in this order, each on the same split
    trainers = (
        ("RandomForest", train_rf_model, {}),
        ("XGBoost", train_xgb_model, {}),
        ("LSTM", train_lstm_model, {"epochs": 20, "batch_size": 32, "sequence_length": 10}),
    )
    results = {}
    for label, trainer, extra_args in trainers:
        fitted, mae, mse = trainer(X_train, y_train, X_test, y_test, **extra_args)
        results[label] = {"model": fitted, "MAE": mae, "MSE": mse}

    # Step 6: Compare models and select the best one based on MAE
    best_model_key = compare_model_results(results)

    # Step 7: Save the best-performing model
    if best_model_key != "LSTM":
        save_path = f"best_{best_model_key}_model.pkl"
        with open(save_path, "wb") as f:
            pickle.dump(results[best_model_key]["model"], f)
    else:
        save_path = "best_lstm_model.h5"
        results["LSTM"]["model"].save(save_path)
    logger.info(f"Best model saved to {save_path}")

    overall_end = time.time()
    logger.info(f"Total execution time: {overall_end - overall_start:.2f} seconds")

# Run the pipeline only when this file/cell is executed directly, not when it
# is imported as a module.
if __name__ == "__main__":
    main()


100 years later