# Kronos-India Stock Price Prediction Demo

This notebook demonstrates how to use the fine-tuned Kronos model for Indian stock market predictions.

## Overview

The Kronos-India model is a fine-tuned version of the Kronos Foundation Model specifically adapted for Indian stock market data (NSE/BSE). This demo shows:

1. Loading the fine-tuned model and components
2. Making predictions for Indian stocks
3. Evaluating prediction accuracy
4. Visualizing results

## Setup and Imports

In [None]:
# Import necessary libraries
import sys
import os
import torch
import numpy as np
import pandas as pd
import pickle
import json
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
import yfinance as yf
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import warnings
warnings.filterwarnings('ignore')

# Add Kronos to path
sys.path.append('/home/z/my-project/Kronos')

# Import Kronos model
from model.kronos import Kronos

# Set up plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
%matplotlib inline

print("Setup completed successfully!")

## Configuration

In [None]:
# Configuration
CONFIG = {
    'model_path': '/home/z/my-project/indian_market/checkpoints/fine_tuned/Kronos-India-small_best.pth',
    'tokenizer_path': '/home/z/my-project/indian_market/datasets/processed/kronos_tokenizer.pkl',
    'scalers_path': '/home/z/my-project/indian_market/datasets/processed/scalers.pkl',
    'data_path': '/home/z/my-project/indian_market/datasets/processed/ohlcv_data.csv',
    'symbols': ['RELIANCE.NS', 'TCS.NS', 'INFY.NS', 'HDFCBANK.NS'],
    'sequence_length': 512,
    'prediction_steps': 10,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

print(f"Using device: {CONFIG['device']}")
print(f"Available symbols: {CONFIG['symbols']}")
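
Every path in `CONFIG` is hard-coded to one machine, so it may help to check them before loading anything. The sketch below is an assumption, not part of the original workflow: `missing_paths` is a hypothetical helper that reports which `*_path` entries do not exist on disk, and `example_config` is a made-up stand-in for `CONFIG`.

```python
import os

def missing_paths(config):
    """Return the config keys ending in '_path' whose targets do not exist."""
    return [
        key for key, value in config.items()
        if key.endswith('_path') and not os.path.exists(value)
    ]

# Hypothetical example config; in the notebook, pass CONFIG instead.
example_config = {
    'model_path': '/nonexistent/model.pth',
    'here_path': os.getcwd(),  # the current working directory exists
    'sequence_length': 512,    # ignored: not a path entry
}

print(missing_paths(example_config))
```

Running `missing_paths(CONFIG)` before the loading cell would surface a wrong path as a short list of keys rather than as a `FileNotFoundError` halfway through loading.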

## Load Model and Components

In [None]:
class KronosPredictor:
    """Prediction class for fine-tuned Kronos model."""
    
    def __init__(self, config):
        self.config = config
        self.model = None
        self.tokenizer = None
        self.scalers = None
    
    def load_model(self):
        """Load the fine-tuned model."""
        print("Loading model...")
        checkpoint = torch.load(self.config['model_path'], map_location=self.config['device'])
        self.model = Kronos()
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.config['device'])
        self.model.eval()
        print(f"Model loaded successfully! Best validation loss: {checkpoint.get('val_loss', 'N/A')}")
    
    def load_tokenizer(self):
        """Load the tokenizer."""
        print("Loading tokenizer...")
        with open(self.config['tokenizer_path'], 'rb') as f:
            self.tokenizer = pickle.load(f)
        print("Tokenizer loaded successfully!")
    
    def load_scalers(self):
        """Load the scalers."""
        print("Loading scalers...")
        with open(self.config['scalers_path'], 'rb') as f:
            self.scalers = pickle.load(f)
        print("Scalers loaded successfully!")

# Initialize predictor
predictor = KronosPredictor(CONFIG)

# Load all components
predictor.load_model()
predictor.load_tokenizer()
predictor.load_scalers()

print("\nAll components loaded successfully!")
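
`load_model` assumes the checkpoint is a dictionary with a `model_state_dict` key. Checkpoints are sometimes saved as a bare state dict instead, so the sketch below shows one way to accept both layouts. `extract_state_dict` is a hypothetical helper, and plain dictionaries stand in for real checkpoints so the example runs without PyTorch.

```python
def extract_state_dict(checkpoint, key='model_state_dict'):
    """Return the weights mapping from either checkpoint layout."""
    if not isinstance(checkpoint, dict):
        raise TypeError(f"Expected a dict checkpoint, got {type(checkpoint).__name__}")
    # Wrapped layout: weights stored under `key`, next to metadata such as 'val_loss'.
    if key in checkpoint:
        return checkpoint[key]
    # Bare layout: the checkpoint itself is the state dict.
    return checkpoint

# Stand-in checkpoints (hypothetical contents) for the two layouts.
wrapped = {'model_state_dict': {'layer.weight': [0.1, 0.2]}, 'val_loss': 0.05}
bare = {'layer.weight': [0.1, 0.2]}

print(extract_state_dict(wrapped) == extract_state_dict(bare))
```

In `load_model`, the returned mapping would be what gets passed to `self.model.load_state_dict(...)`.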

## Load and Explore Data

In [None]:
# Load processed data
print("Loading processed data...")
data = pd.read_csv(CONFIG['data_path'])
data['timestamp'] = pd.to_datetime(data['timestamp'])

print(f"Data shape: {data.shape}")
print(f"Date range: {data['timestamp'].min()} to {data['timestamp'].max()}")
print(f"Available symbols: {data['symbol'].unique()}")
print(f"Columns: {list(data.columns)}")

# Display first few rows
print("\nFirst few rows:")
data.head()

In [None]:
# Data summary by symbol
print("Data summary by symbol:")
summary = data.groupby('symbol').agg({
    'timestamp': ['count', 'min', 'max'],
    'close': ['mean', 'std', 'min', 'max']
}).round(2)
summary

## Make Predictions

In [None]:
def calculate_technical_indicators(data):
    """Calculate technical indicators for the data."""
    data = data.copy()
    
    # Simple Moving Averages
    data['sma_5'] = data['close'].rolling(window=5).mean()
    data['sma_20'] = data['close'].rolling(window=20).mean()
    
    # Exponential Moving Averages
    data['ema_12'] = data['close'].ewm(span=12).mean()
    data['ema_26'] = data['close'].ewm(span=26).mean()
    
    # RSI
    delta = data['close'].diff()
    gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss
    data['rsi'] = 100 - (100 / (1 + rs))
    
    # MACD
    data['macd'] = data['ema_12'] - data['ema_26']
    data['macd_signal'] = data['macd'].ewm(span=9).mean()
    data['macd_histogram'] = data['macd'] - data['macd_signal']
    
    # Bollinger Bands
    data['bb_middle'] = data['close'].rolling(window=20).mean()
    bb_std = data['close'].rolling(window=20).std()
    data['bb_upper'] = data['bb_middle'] + (bb_std * 2)
    data['bb_lower'] = data['bb_middle'] - (bb_std * 2)
    
    return data.dropna()

def prepare_input_sequence(data, symbol, sequence_length=512):
    """Prepare input sequence for prediction."""
    # Filter data for the symbol
    symbol_data = data[data['symbol'] == symbol].copy()
    symbol_data = symbol_data.sort_values('timestamp')
    
    # Calculate technical indicators if needed
    if 'sma_5' not in symbol_data.columns:
        symbol_data = calculate_technical_indicators(symbol_data)
    
    # Get the last sequence_length records
    if len(symbol_data) < sequence_length:
        sequence_length = len(symbol_data)
    
    recent_data = symbol_data.tail(sequence_length)
    
    # Extract features
    feature_columns = [
        'open', 'high', 'low', 'close', 'volume',
        'sma_5', 'sma_20', 'ema_12', 'ema_26', 'rsi',
        'macd', 'macd_signal', 'macd_histogram',
        'bb_middle', 'bb_upper', 'bb_lower'
    ]
    
    # Filter available columns
    available_columns = [col for col in feature_columns if col in recent_data.columns]
    
    # Create input sequence
    input_sequence = recent_data[available_columns].values
    
    # Convert to tensor
    input_tensor = torch.FloatTensor(input_sequence).unsqueeze(0)
    input_tensor = input_tensor.to(CONFIG['device'])
    
    return input_tensor, recent_data

def make_predictions(symbol, steps=10):
    """Make predictions for a specific symbol."""
    print(f"Making predictions for {symbol}...")
    
    # Prepare input sequence
    input_sequence, recent_data = prepare_input_sequence(
        data, symbol, CONFIG['sequence_length']
    )
    
    # Make predictions
    with torch.no_grad():
        predictions = predictor.model(input_sequence, steps=steps)
    
    # Extract close price predictions
    close_predictions = predictions.cpu().numpy()[0, :, 3]  # Close price is at index 3
    
    # Inverse transform if scalers are available
    if predictor.scalers and symbol in predictor.scalers:
        close_scaler = predictor.scalers[symbol]['close']
        close_predictions = close_scaler.inverse_transform(close_predictions.reshape(-1, 1)).flatten()
    
    return close_predictions, recent_data

# Make predictions for all symbols
predictions = {}
for symbol in CONFIG['symbols']:
    try:
        pred, recent_data = make_predictions(symbol, CONFIG['prediction_steps'])
        predictions[symbol] = {
            'predictions': pred,
            'last_known_price': recent_data['close'].iloc[-1],
            'last_date': recent_data['timestamp'].iloc[-1]
        }
        print(f"✓ {symbol}: Last price = {recent_data['close'].iloc[-1]:.2f}")
    except Exception as e:
        print(f"✗ {symbol}: Error - {str(e)}")
        predictions[symbol] = None
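
The RSI in `calculate_technical_indicators` uses plain rolling means of gains and losses rather than Wilder's smoothing. To make that arithmetic concrete, here is a dependency-free sketch of the same calculation for the most recent window. `simple_rsi` is a hypothetical name, and the edge-case handling is an assumption chosen to mirror what the pandas expression yields (100 for a loss-free window, an undefined value for a flat one).

```python
def simple_rsi(closes, window=14):
    """RSI from unsmoothed average gain and loss over the last `window` price changes."""
    if len(closes) < window + 1:
        raise ValueError(f"Need at least {window + 1} closing prices")
    deltas = [later - earlier for earlier, later in zip(closes[:-1], closes[1:])]
    recent = deltas[-window:]
    avg_gain = sum(d for d in recent if d > 0) / window
    avg_loss = sum(-d for d in recent if d < 0) / window
    if avg_gain == 0 and avg_loss == 0:
        return None   # flat window: 0/0 is undefined (NaN in pandas, dropped by dropna)
    if avg_loss == 0:
        return 100.0  # no losses: RS is infinite, so RSI saturates at 100
    rs = avg_gain / avg_loss
    return 100 - 100 / (1 + rs)

# Toy series: alternating +2 / -1 moves, i.e. average gain 1.0 and average loss 0.5.
closes = [100.0]
for change in [2, -1] * 7:
    closes.append(closes[-1] + change)

print(round(simple_rsi(closes), 2))
```

With an average gain twice the average loss, RS is 2 and the RSI is 100 - 100/3, or about 66.67.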

## Visualize Predictions

In [None]:
# Plot predictions for each symbol
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
axes = axes.flatten()

for i, symbol in enumerate(CONFIG['symbols']):
    if predictions[symbol] is not None:
        pred_data = predictions[symbol]
        
        # Generate prediction dates
        last_date = pred_data['last_date']
        pred_dates = []
        current_date = last_date
        for j in range(CONFIG['prediction_steps']):
            current_date += timedelta(days=1)
            while current_date.weekday() >= 5:  # Skip weekends
                current_date += timedelta(days=1)
            pred_dates.append(current_date)
        
        # Plot historical data (last 30 rows, sorted chronologically)
        symbol_data = data[data['symbol'] == symbol].sort_values('timestamp').tail(30)
        axes[i].plot(symbol_data['timestamp'], symbol_data['close'], 
                    'b-', label='Historical', linewidth=2)
        
        # Plot predictions
        axes[i].plot(pred_dates, pred_data['predictions'], 
                    'r--', label='Predicted', linewidth=2, marker='o')
        
        # Add vertical line at prediction start
        axes[i].axvline(x=pred_data['last_date'], color='gray', 
                       linestyle=':', alpha=0.7, label='Prediction Start')
        
        axes[i].set_title(f'{symbol} - Price Prediction')
        axes[i].set_xlabel('Date')
        axes[i].set_ylabel('Price')
        axes[i].legend()
        axes[i].grid(True, alpha=0.3)
        
        # Rotate x-axis labels
        axes[i].tick_params(axis='x', rotation=45)
    else:
        axes[i].text(0.5, 0.5, f'No predictions available for {symbol}', 
                     ha='center', va='center', transform=axes[i].transAxes)
        axes[i].set_title(f'{symbol} - No Data')

plt.tight_layout()
plt.show()
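
The prediction dates in the plot come from stepping forward one day at a time and skipping Saturdays and Sundays. The same logic can be pulled into a small function, sketched below with only the standard library. `next_weekdays` is a hypothetical name, and, like the loop above, it does not account for exchange holidays, so some generated dates may fall on days when NSE/BSE are closed.

```python
from datetime import date, timedelta

def next_weekdays(last_date, steps):
    """Return the next `steps` Monday-to-Friday dates after `last_date`."""
    dates = []
    current = last_date
    while len(dates) < steps:
        current += timedelta(days=1)
        if current.weekday() < 5:  # 0-4 are Monday to Friday
            dates.append(current)
    return dates

# 2024-01-05 is a Friday, so the next three weekdays start on Monday 2024-01-08.
for d in next_weekdays(date(2024, 1, 5), 3):
    print(d.isoformat())
```

The function also works with `pandas.Timestamp` values such as `pred_data['last_date']`, since they support `weekday()` and `timedelta` arithmetic.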

## Prediction Summary

In [None]:
# Create prediction summary
summary_data = []

for symbol in CONFIG['symbols']:
    if predictions[symbol] is not None:
        pred_data = predictions[symbol]
        
        # Calculate prediction statistics
        pred_prices = pred_data['predictions']
        last_price = pred_data['last_known_price']
        
        # Calculate percentage changes
        pct_changes = [(pred - last_price) / last_price * 100 for pred in pred_prices]
        
        summary_data.append({
            'Symbol': symbol,
            'Last Price': last_price,
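            'Predicted Day 1': pred_prices[0],
            # Label the final horizon from the data so it stays correct if prediction_steps changes
            f"Predicted Day {len(pred_prices)}": pred_prices[-1],
            'Day 1 Change (%)': pct_changes[0],
            f"Day {len(pred_prices)} Change (%)": pct_changes[-1],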
            'Avg Predicted Price': np.mean(pred_prices),
            'Min Predicted Price': np.min(pred_prices),
            'Max Predicted Price': np.max(pred_prices)
        })

summary_df = pd.DataFrame(summary_data)
summary_df = summary_df.round(2)

print("Prediction Summary:")
summary_df

## Model Performance Analysis

In [None]:
# Load evaluation results if available
eval_results_path = '/home/z/my-project/indian_market/results/evaluations/metrics_report.json'

if os.path.exists(eval_results_path):
    with open(eval_results_path, 'r') as f:
        eval_results = json.load(f)
    
    # Create performance metrics dataframe
    performance_data = []
    
    for symbol, metrics in eval_results.items():
        performance_data.append({
            'Symbol': symbol,
            'RMSE': metrics['rmse'],
            'MAE': metrics['mae'],
            'R²': metrics['r2'],
            'MAPE': metrics['mape']
        })
    
    performance_df = pd.DataFrame(performance_data)
    performance_df = performance_df.round(4)
    
    print("Model Performance Metrics:")
    # A bare expression inside an if-block is not auto-displayed, so render it explicitly
    display(performance_df)
else:
    print("Evaluation results not found. Run the evaluation script first.")
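
The report is expected to hold four metrics per symbol: `rmse`, `mae`, `r2`, and `mape`. How the evaluation script computes them is not shown in this notebook, so the sketch below uses the standard textbook definitions as an assumption. `regression_metrics` is a hypothetical helper, and the toy prices are made up for illustration.

```python
import math

def regression_metrics(actual, predicted):
    """Compute RMSE, MAE, R^2, and MAPE (in percent) for two equal-length sequences."""
    if not actual or len(actual) != len(predicted):
        raise ValueError("actual and predicted must be non-empty and of equal length")
    n = len(actual)
    errors = [a - p for a, p in zip(actual, predicted)]
    ss_res = sum(e * e for e in errors)
    mean_actual = sum(actual) / n
    ss_tot = sum((a - mean_actual) ** 2 for a in actual)
    return {
        'rmse': math.sqrt(ss_res / n),
        'mae': sum(abs(e) for e in errors) / n,
        # Assumes the actual values are not all identical (ss_tot > 0).
        'r2': 1 - ss_res / ss_tot,
        # Assumes no actual value is zero, which holds for prices.
        'mape': sum(abs(e) / abs(a) for e, a in zip(errors, actual)) / n * 100,
    }

# Made-up closing prices and forecasts, each off by exactly 1.
actual = [100.0, 102.0, 104.0, 106.0]
predicted = [101.0, 101.0, 105.0, 105.0]

for name, value in regression_metrics(actual, predicted).items():
    print(f"{name}: {value:.4f}")
```

Because every forecast here misses by exactly 1, both RMSE and MAE equal 1, while R² is 1 - 4/20 = 0.8.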

## Interactive Prediction Example

In [None]:
# Interactive prediction function
def predict_stock(symbol, steps=10):
    """Make predictions for any stock symbol."""
    try:
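        # Download fresh data; 5 years (roughly 250 trading days per year) leaves
        # enough rows to fill CONFIG['sequence_length'] after the indicator warm-up
        print(f"Downloading fresh data for {symbol}...")
        stock = yf.Ticker(symbol)
        fresh_data = stock.history(period='5y')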
        
        if fresh_data.empty:
            print(f"No data found for {symbol}")
            return
        
        # Prepare data
        fresh_data = fresh_data.reset_index()
        fresh_data = fresh_data.rename(columns={
            'Date': 'timestamp',
            'Open': 'open',
            'High': 'high',
            'Low': 'low',
            'Close': 'close',
            'Volume': 'volume'
        })
        fresh_data['timestamp'] = pd.to_datetime(fresh_data['timestamp'])
        fresh_data['symbol'] = symbol
        
        # Calculate technical indicators
        fresh_data = calculate_technical_indicators(fresh_data)
        
        # Make predictions
        input_sequence, recent_data = prepare_input_sequence(
            fresh_data, symbol, CONFIG['sequence_length']
        )
        
        with torch.no_grad():
            predictions = predictor.model(input_sequence, steps=steps)
        
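        close_predictions = predictions.cpu().numpy()[0, :, 3]  # Close price is at index 3
        
        # Inverse transform as in make_predictions, if a scaler exists for this symbol
        if predictor.scalers and symbol in predictor.scalers:
            close_scaler = predictor.scalers[symbol]['close']
            close_predictions = close_scaler.inverse_transform(close_predictions.reshape(-1, 1)).flatten()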
        
        # Plot results
        plt.figure(figsize=(12, 6))
        
        # Plot historical data (last 30 days)
        recent_30 = fresh_data.tail(30)
        plt.plot(recent_30['timestamp'], recent_30['close'], 
                'b-', label='Historical', linewidth=2)
        
        # Plot predictions
        last_date = recent_data['timestamp'].iloc[-1]
        pred_dates = []
        current_date = last_date
        for i in range(steps):
            current_date += timedelta(days=1)
            while current_date.weekday() >= 5:
                current_date += timedelta(days=1)
            pred_dates.append(current_date)
        
        plt.plot(pred_dates, close_predictions, 
                'r--', label='Predicted', linewidth=2, marker='o')
        
        plt.axvline(x=last_date, color='gray', linestyle=':', alpha=0.7)
        plt.title(f'{symbol} - Interactive Prediction')
        plt.xlabel('Date')
        plt.ylabel('Price')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()
        
        # Print prediction values
        print(f"\n{symbol} Predictions:")
        print(f"Last Known Price: {recent_data['close'].iloc[-1]:.2f}")
        for i, (date, pred) in enumerate(zip(pred_dates, close_predictions)):
            pct_change = (pred - recent_data['close'].iloc[-1]) / recent_data['close'].iloc[-1] * 100
            print(f"Day {i+1} ({date.strftime('%Y-%m-%d')}): {pred:.2f} ({pct_change:+.2f}%)")
        
    except Exception as e:
        print(f"Error predicting {symbol}: {str(e)}")

# Example usage (uncomment to test):
# predict_stock('RELIANCE.NS', steps=10)
print("Interactive prediction function ready! Use predict_stock('SYMBOL', steps=10) to test.")

## Conclusion

This notebook demonstrated the complete workflow for using the Kronos-India model:

1. **Model Loading**: Loaded the fine-tuned Kronos model, the tokenizer, and the scalers
2. **Data Preparation**: Loaded and explored the processed Indian stock market data
3. **Predictions**: Made predictions for major Indian stocks (RELIANCE, TCS, INFY, HDFCBANK)
4. **Visualization**: Plotted predicted closing prices next to the most recent historical prices
5. **Performance Analysis**: Loaded precomputed RMSE, MAE, R², and MAPE values from the evaluation report, when it is available
6. **Interactive Features**: Defined `predict_stock` to run predictions on fresh data downloaded with yfinance

### Key Insights:
- The fine-tuned model is intended to capture temporal patterns in Indian stock market data; the evaluation metrics indicate how well it does so
- Predictions are generated for multiple time steps ahead (10 in this demo)
- `predict_stock` accepts any NSE/BSE ticker for which yfinance returns data
- RMSE, MAE, R², and MAPE quantify prediction accuracy for each symbol

### Next Steps:
1. **Backtesting**: Implement comprehensive backtesting with historical data
2. **Risk Management**: Add risk assessment and confidence intervals
3. **Multi-asset Modeling**: Extend to predict multiple assets simultaneously
4. **Real-time Integration**: Connect to live market data feeds
5. **Deployment**: Package as a production-ready API service