# Inference on Your Own Data

This notebook applies a pre-trained TD2C model to your own multivariate time series (for example, data loaded from a CSV file).

### Input Requirements
1.  **Format:** A CSV file where **columns are variables** and **rows are time steps**.
2.  **Stationarity:** The data should ideally be stationary. If it is not, consider first differencing, i.e. replacing $X_t$ with $X_t - X_{t-1}$.
3.  **Normalization:** The pipeline handles standardization internally.

### 1. Load the Pre-trained Model
We first check whether a saved model exists on disk. If it does not, we train one on synthetic data on the fly (about a minute) and save it for later runs.

In [None]:
import pandas as pd
import numpy as np
import joblib
import os
from imblearn.ensemble import BalancedRandomForestClassifier
from td2c.data_generation.builder import TSBuilder
from td2c.descriptors import D2C, DataLoader

MODEL_PATH = "pretrained_td2c.joblib"

# Settings for the synthetic training data used when no saved model exists.
# TRAIN_MAXLAGS should match the `maxlags` passed to D2CWrapper at inference time.
TRAIN_N_VARIABLES = 5
TRAIN_MAXLAGS = 2
# Columns of the descriptor frame that are identifiers / labels rather than features.
NON_FEATURE_COLUMNS = ["graph_id", "edge_source", "edge_dest", "is_causal"]


def _train_model():
    """Train a BalancedRandomForestClassifier on TD2C descriptors of synthetic series."""
    # NOTE(review): no seed is passed to TSBuilder here, so a freshly trained
    # model may differ between runs — check whether TSBuilder accepts one.
    builder = TSBuilder(
        n_variables=TRAIN_N_VARIABLES,
        maxlags=TRAIN_MAXLAGS,
        observations_per_time_series=200,
        time_series_per_process=20,
        verbose=False,
    )
    builder.build()
    loader = DataLoader(n_variables=TRAIN_N_VARIABLES, maxlags=TRAIN_MAXLAGS)
    loader.from_tsbuilder(builder)

    engine = D2C(
        loader.get_observations(),
        loader.get_dags(),
        n_variables=TRAIN_N_VARIABLES,
        maxlags=TRAIN_MAXLAGS,
        n_jobs=-1,
        full=True,
        dynamic=True,
        mb_estimator="ts",
    )
    engine.initialize()
    df_train = engine.get_descriptors_df()

    X = df_train.drop(columns=NON_FEATURE_COLUMNS)
    y = df_train["is_causal"]

    clf = BalancedRandomForestClassifier(n_estimators=200, random_state=42, n_jobs=-1)
    clf.fit(X, y)
    return clf


def get_model(model_path=MODEL_PATH):
    """Return the TD2C classifier, loading it from disk or training and caching it.

    Parameters
    ----------
    model_path : str or path-like, default MODEL_PATH
        File the fitted classifier is loaded from, or saved to after training.

    Returns
    -------
    The object stored at ``model_path`` if the file exists; otherwise a newly
    fitted BalancedRandomForestClassifier, which is also written to ``model_path``.
    """
    if os.path.exists(model_path):
        print(f"Loading model from {model_path}...")
        # joblib.load unpickles the file and can execute arbitrary code:
        # only load model files you created yourself or otherwise trust.
        return joblib.load(model_path)

    print("Pre-trained model not found. Training a robust model now (this takes ~1 min)...")
    clf = _train_model()
    joblib.dump(clf, model_path)
    print("Model trained and saved.")
    return clf


model = get_model()

### 2. Load Your Data
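To use your own data, point the loading code below at your CSV file. For this demo, we generate a small synthetic dataset in which `Var_A` drives `Var_B` at lag 1 and `Var_C` is independent noise.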

In [None]:
# --- User Input Section ---
# Set DATA_PATH to your CSV file (columns = variables, rows = time steps).
# Leave it as None to run the demo on synthetic data.
DATA_PATH = None
SEED = 42  # makes the synthetic demo data reproducible

if DATA_PATH is not None:
    df_user = pd.read_csv(DATA_PATH)
else:
    rng = np.random.default_rng(SEED)
    T = 100
    # Var_A causes Var_B with lag 1; Var_C is independent noise.
    A = rng.standard_normal(T)
    B = np.roll(A, 1) + rng.normal(0, 0.2, T)
    C = rng.standard_normal(T)
    # np.roll wraps A[-1] into B[0], so the first row is not a valid lag-1 pair: drop it.
    df_user = pd.DataFrame({'Var_A': A, 'Var_B': B, 'Var_C': C}).iloc[1:]

# Fail early with a clear message instead of an obscure error during inference.
non_numeric = df_user.select_dtypes(exclude="number").columns.tolist()
if non_numeric:
    raise ValueError(
        f"Non-numeric columns found: {non_numeric}. Drop them or use them as the "
        "index (e.g. pd.read_csv(path, index_col=0))."
    )
if df_user.isna().to_numpy().any():
    raise ValueError("The data contains missing values; impute or drop them before inference.")

print("Data Preview:")
df_user.head()

### 3. Run Inference
We use `D2CWrapper` to compute descriptors for the data and predict causal links. Links with a predicted probability above 0.5 are kept as significant.

In [None]:
from td2c.benchmark import D2CWrapper

# Must match the `maxlags` the classifier was trained with (see get_model).
MAXLAGS = 2
# Minimum predicted probability for a candidate link to be reported as causal.
PROBABILITY_THRESHOLD = 0.5

# Rows are time steps and columns are variables (see Input Requirements).
# Casting to float fails loudly if any column is not numeric.
data_array = df_user.to_numpy(dtype=float)
n_vars = data_array.shape[1]
var_names = df_user.columns.tolist()

if n_vars < 2:
    raise ValueError(f"Need at least 2 variables to infer causal links, got {n_vars}.")
if data_array.shape[0] <= MAXLAGS:
    raise ValueError(
        f"Need more than {MAXLAGS} time steps for maxlags={MAXLAGS}, got {data_array.shape[0]}."
    )

print(f"Analyzing {n_vars} variables...")

# Run Inference
wrapper = D2CWrapper(
    ts_list=[data_array],
    model=model,
    n_variables=n_vars,
    maxlags=MAXLAGS,
    mb_estimator="ts",
)
wrapper.run()

# Presumably one result frame per entry of ts_list; we passed a single series.
results = wrapper.get_causal_dfs()[0]
significant_links = results[results['probability'] > PROBABILITY_THRESHOLD].copy()

print(f"\nFound {len(significant_links)} causal links.")
# Display the kept links, most probable first (significant_links itself is not reordered).
significant_links.sort_values('probability', ascending=False)

### 4. Visualize the Graph

In [None]:
import networkx as nx
import matplotlib.pyplot as plt

G = nx.DiGraph()
G.add_nodes_from(var_names)

# Group significant links by (source, target) variable pair. The same pair can be
# significant at several lags, and DiGraph.add_edge on an existing edge overwrites
# its attributes — adding links one by one would silently keep only the last lag.
pair_links = {}
for _, row in significant_links.iterrows():
    # TD2C output format: indices 0..N-1 are the variables at time t,
    # N..2N-1 the variables at t-1, and so on; divmod recovers (lag, variable).
    lag, source_var_idx = divmod(int(row['from']), n_vars)
    target_idx = int(row['to'])
    if not 0 <= target_idx < n_vars:
        raise ValueError(
            f"Unexpected target index {target_idx}: expected a variable at time t "
            f"(0..{n_vars - 1})."
        )
    pair = (var_names[source_var_idx], var_names[target_idx])
    pair_links.setdefault(pair, []).append((lag, float(row['probability'])))

for (source_name, target_name), links in pair_links.items():
    lags = sorted({link_lag for link_lag, _ in links})
    G.add_edge(
        source_name,
        target_name,
        lag=lags[0],  # smallest lag, kept as a single-value attribute
        lags=lags,    # every lag at which this pair was found significant
        weight=max(prob for _, prob in links),
    )

# Plot
fig, ax = plt.subplots(figsize=(8, 6))
pos = nx.spring_layout(G, seed=42)
nx.draw(G, pos, ax=ax, with_labels=True, node_color='lightblue', node_size=2000, font_size=12, arrowsize=20)
edge_labels = {
    (u, v): "Lag " + ", ".join(str(edge_lag) for edge_lag in edge_lags)
    for u, v, edge_lags in G.edges(data='lags')
}
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, ax=ax)
ax.set_title("Inferred Causal Graph")
plt.show()