# Takahashi Diffusion Example Usage

This notebook demonstrates how to use Takahashi Diffusion to generate synthetic time series data.

## Steps
1. Generate mock price data from a stochastic exponential process (Geometric Brownian Motion).
2. Preprocess the data with the preprocessing utilities (convert prices to log returns and build sliding windows).
3. Train the Takahashi Diffusion model (it expects its input to already be log returns).
4. Generate synthetic samples.

**Note:** Takahashi Diffusion generates samples by combining wavelet transforms with a denoising diffusion probabilistic model (DDPM).
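For rough intuition about how those two pieces can fit together, the sketch below applies one level of an orthonormal Haar wavelet transform to a window of log returns and then runs the standard DDPM forward (noising) process on the resulting coefficients. It is a hypothetical illustration in plain NumPy: `haar_dwt`, `haar_idwt`, and `ddpm_forward` are illustrative helpers, and the wavelet family, number of decomposition levels, and noise schedule that `TakahashiDiffusion` actually uses are assumptions here and may differ.

```python
import numpy as np


def haar_dwt(x):
    """One level of the orthonormal Haar transform; len(x) must be even."""
    even, odd = x[0::2], x[1::2]
    approx = (even + odd) / np.sqrt(2.0)  # low-frequency (trend) coefficients
    detail = (even - odd) / np.sqrt(2.0)  # high-frequency (fluctuation) coefficients
    return approx, detail


def haar_idwt(approx, detail):
    """Exact inverse of haar_dwt."""
    x = np.empty(2 * len(approx))
    x[0::2] = (approx + detail) / np.sqrt(2.0)
    x[1::2] = (approx - detail) / np.sqrt(2.0)
    return x


def ddpm_forward(x0, t, betas, rng):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I)."""
    alpha_bar = np.cumprod(1.0 - betas)[t]  # abar_t = prod_{s<=t} (1 - beta_s)
    noise = rng.standard_normal(x0.shape)
    x_t = np.sqrt(alpha_bar) * x0 + np.sqrt(1.0 - alpha_bar) * noise
    return x_t, noise  # a denoising network would be trained to predict `noise`


rng = np.random.default_rng(0)
window = rng.normal(0.0, 0.016, size=100)  # stand-in for one window of log returns
approx, detail = haar_dwt(window)
coeffs = np.concatenate([approx, detail])

betas = np.linspace(1e-4, 0.02, 1000)  # linear schedule from the original DDPM paper
x_t, eps = ddpm_forward(coeffs, t=500, betas=betas, rng=rng)
```

Because the Haar transform is orthonormal, it preserves energy and is exactly invertible, so a sample denoised in the wavelet domain can be mapped back to a log-return window with `haar_idwt`.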


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import sys

project_root = Path().resolve().parents[1]  # two levels above the notebook's working directory
sys.path.append(str(project_root))

from src.models.non_parametric.takahashi import TakahashiDiffusion
from src.utils.preprocessing_utils import (
    LogReturnTransformation,
    preprocess_non_parametric,
    create_dataloaders
)


## Step 1: Generate Mock Stochastic Exponential Time Series Data

We'll generate a Geometric Brownian Motion (GBM) time series, which is commonly used for modeling stock prices.
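For reference, the GBM dynamics and the distribution of its log returns follow from Itô's lemma; this is standard textbook material rather than anything specific to this repository:

$$
dS_t = \mu S_t\,dt + \sigma S_t\,dW_t
\quad\Longrightarrow\quad
S_t = S_0 \exp\!\Big(\big(\mu - \tfrac{1}{2}\sigma^2\big)t + \sigma W_t\Big),
$$

$$
r_t = \log\frac{S_{t+\Delta t}}{S_t} \sim \mathcal{N}\!\Big(\big(\mu - \tfrac{1}{2}\sigma^2\big)\Delta t,\; \sigma^2 \Delta t\Big).
$$

Because the log-space update in `generate_gbm_prices` samples exactly this Gaussian, the simulation has no discretization error. With the parameters used below ($\mu = 0.1$, $\sigma = 0.25$, $\Delta t = 1/252$), each daily log return has mean $(0.1 - 0.03125)/252 \approx 2.73 \times 10^{-4}$ and standard deviation $0.25/\sqrt{252} \approx 0.0157$. Assuming `LogReturnTransformation` computes plain one-step log returns, these are a reference point for the statistics printed in Step 2; note that over 10,000 draws the sample mean is still fairly noisy (standard error $\approx 0.0157/100 \approx 1.6 \times 10^{-4}$).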


In [None]:
def generate_gbm_prices(
    initial_price: float = 100.0,
    mu: float = 0.05,  # Annual drift
    sigma: float = 0.2,  # Annual volatility
    num_days: int = 1000,
    dt: float = 1/252,  # Daily time step (252 trading days per year)
    seed: int = 42
) -> torch.Tensor:
    """
    Generate Geometric Brownian Motion price series.
    
    Args:
        initial_price: Starting price
        mu: Annual drift rate
        sigma: Annual volatility
        num_days: Number of days to simulate
        dt: Time step (default: daily)
        seed: Random seed
    
    Returns:
        Price series as a torch.Tensor of length num_days + 1
        (the initial price followed by one price per simulated day)
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    
    # Generate random shocks
    Z = torch.randn(num_days)
    
    # GBM: dS = mu*S*dt + sigma*S*dW
    # In log space: d(log S) = (mu - 0.5*sigma^2)*dt + sigma*dW
    log_returns = (mu - 0.5 * sigma**2) * dt + sigma * np.sqrt(dt) * Z
    
    # Convert log returns to prices
    log_prices = torch.cumsum(torch.cat([torch.tensor([np.log(initial_price)]), log_returns]), dim=0)
    prices = torch.exp(log_prices)
    
    return prices

# Generate synthetic price data
original_prices = generate_gbm_prices(
    initial_price=100.0,
    mu=0.1,  # 10% annual return
    sigma=0.25,  # 25% annual volatility
    num_days=10000,
    seed=42
)

print(f"Generated {len(original_prices)} daily prices ({len(original_prices) - 1} simulated days)")
print(f"Price range: [{original_prices.min():.2f}, {original_prices.max():.2f}]")

# Visualize the generated data
plt.figure(figsize=(12, 5))
plt.plot(original_prices.numpy())
plt.title('Generated GBM Price Series')
plt.xlabel('Days')
plt.ylabel('Price')
plt.grid(True)
plt.show()


## Step 2: Preprocess Data

Convert prices to log returns, then cut the log-return series into overlapping fixed-length (sliding) windows, the input format used by the non-parametric models.

**Note:** Takahashi Diffusion expects its input to already be log returns, so the price-to-log-return conversion happens here rather than inside the model.
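The windowing step can be pictured with the NumPy sketch below. It is a hypothetical reimplementation for illustration only: `make_windows` and `chronological_split` are not part of `src.utils.preprocessing_utils`, and whether `preprocess_non_parametric` splits the series before or after windowing, and whether it orders the splits chronologically, is an assumption here.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def make_windows(series, seq_length, stride=1):
    """Return overlapping windows of shape (num_windows, seq_length)."""
    windows = sliding_window_view(series, window_shape=seq_length)
    return windows[::stride]  # keep every `stride`-th window start


def chronological_split(windows, valid_ratio=0.1, test_ratio=0.1):
    """Split windows in time order: earliest for training, latest for testing."""
    n = len(windows)
    n_test = int(n * test_ratio)
    n_valid = int(n * valid_ratio)
    n_train = n - n_valid - n_test
    train = windows[:n_train]
    valid = windows[n_train:n_train + n_valid]
    test = windows[n_train + n_valid:]
    return train, valid, test


series = np.arange(10_000, dtype=float)  # stand-in for 10,000 log returns
windows = make_windows(series, seq_length=100, stride=1)
train, valid, test = chronological_split(windows, valid_ratio=0.1, test_ratio=0.1)
print(windows.shape, train.shape, valid.shape, test.shape)
# → (9901, 100) (7921, 100) (990, 100) (990, 100)
```

With `stride=1`, consecutive windows overlap in `seq_length - 1` points, so windows on either side of a split boundary share data; splitting the raw series before windowing is one common way to avoid that leakage.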


In [None]:
# Convert prices to log returns
scaler = LogReturnTransformation()
log_returns, initial_value = scaler.transform(original_prices)

print(f"Log returns shape: {log_returns.shape}")
print(f"Log returns stats: mean={log_returns.mean():.6f}, std={log_returns.std():.6f}")

# Preprocess for non-parametric model
seq_length = 100  # Window length: time steps per training sample
train_data, valid_data, test_data, train_initial, valid_initial, test_initial = preprocess_non_parametric(
    ori_data=log_returns,
    original_prices=original_prices,
    seq_length=seq_length,
    valid_ratio=0.1,
    test_ratio=0.1,
    stride=1
)

print("\nPreprocessed data shapes:")
print(f"Train: {train_data.shape}")
print(f"Valid: {valid_data.shape}")
print(f"Test: {test_data.shape}")

# Visualize a few training windows
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
axes = axes.flatten()
for i in range(4):
    axes[i].plot(train_data[i].numpy())
    axes[i].set_title(f'Training Window {i+1}')
    axes[i].set_xlabel('Time Step')
    axes[i].set_ylabel('Log Return')
    axes[i].grid(True)
plt.tight_layout()
plt.show()
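The preprocessing returns initial values (`initial_value`, `train_initial`, `valid_initial`, `test_initial`) alongside the log returns, which suggests they are kept so that log-return windows, including generated ones, can be mapped back to prices. The inversion below is a minimal NumPy sketch assuming the standard definition $r_t = \log(S_t / S_{t-1})$; `log_returns_to_prices` is a hypothetical helper, not the `LogReturnTransformation` API.

```python
import numpy as np


def log_returns_to_prices(log_returns, initial_price):
    """Invert r_t = log(S_t / S_{t-1}) along the last axis.

    Returns prices S_0, S_1, ..., S_n, i.e. one more point than there are
    returns. `initial_price` may be a scalar or have one entry per row.
    """
    r = np.asarray(log_returns, dtype=float)
    # Cumulative log return relative to S_0, starting at 0 for S_0 itself
    zeros = np.zeros(r.shape[:-1] + (1,))
    cum = np.concatenate([zeros, np.cumsum(r, axis=-1)], axis=-1)
    s0 = np.asarray(initial_price, dtype=float)[..., None]  # broadcast over time
    return s0 * np.exp(cum)


prices = np.array([100.0, 101.0, 99.5, 102.0])
returns = np.diff(np.log(prices))
print(log_returns_to_prices(returns, prices[0]))  # recovers prices up to float error

# Batch of windows, each with its own starting price
batch = np.stack([returns, -returns])
starts = np.array([100.0, 50.0])
print(log_returns_to_prices(batch, starts).shape)  # (2, 4)
```

Exponentiating a cumulative sum keeps every reconstructed price positive, which is one reason log returns are a convenient modeling space for price series.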

