# Phase 0: Pre-Training Validation

This notebook validates the enriched data before FinBERT training.

**Checks performed:**
1. Feature leakage (no future-looking columns in features)
2. Target distribution (class imbalance)
3. Author bias analysis
4. Data quality metrics
5. Premarket risk assessment
6. Optional: Sentiment signal check


## Cell 1: Setup and Load Data


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Read the enriched tweet dataset produced by the enrichment step.
df = pd.read_csv("../output/15-dec-enrich7.csv")

# ---------------------------------------------------------------------------
# Global configuration (mirrors FINBERT_TRAINING_PLAN.md)
# ---------------------------------------------------------------------------

# Classification target the model is trained to predict.
TARGET_COLUMN = "label_1d_3class"

# Indicator columns allowed as numerical model inputs.
NUMERICAL_FEATURES = [
    "volatility_7d",
    "relative_volume",
    "rsi_14",
    "distance_from_ma_20",
]

# Columns that must never be fed to the model as inputs; the leakage check
# cell verifies that none of them appear in NUMERICAL_FEATURES.
FORBIDDEN_AS_FEATURES = [
    "spy_return_1d",
    "spy_return_1hr",
    "return_1hr",
    "price_1hr_after",
    "return_to_next_open",
    "price_next_open",
]

# df.shape is (row count, column count).
print(f"Loaded {df.shape[0]} rows with {df.shape[1]} columns")
print(f"\nColumns: {list(df.columns)}")


## Cell 2: Feature Leakage Check

Verify that no future-looking columns are in the feature set.


In [None]:
print("=" * 50)
print("FEATURE LEAKAGE CHECK")
print("=" * 50)

# Check 1: collect every forbidden column that appears in NUMERICAL_FEATURES.
# spy_return_1d is listed explicitly (it uses day T close) so the check keeps
# covering it even if FORBIDDEN_AS_FEATURES is edited, and the target column
# itself must never be an input. dict.fromkeys de-duplicates, keeping order.
forbidden_inputs = list(
    dict.fromkeys([*FORBIDDEN_AS_FEATURES, "spy_return_1d", TARGET_COLUMN])
)
leaked_features = [col for col in forbidden_inputs if col in NUMERICAL_FEATURES]
leakage_found = bool(leaked_features)

for col in leaked_features:
    print(f"❌ LEAK: {col} found in NUMERICAL_FEATURES!")

if not leakage_found:
    print("✓ No future columns in NUMERICAL_FEATURES")

# Check 2: fail loudly on ANY leak, not just spy_return_1d. An explicit raise
# (instead of `assert`) still fires when Python runs with -O.
if leakage_found:
    raise ValueError(
        f"LEAK: future-looking columns in NUMERICAL_FEATURES: {leaked_features}"
    )
print("✓ spy_return_1d correctly excluded from features")

# Display what IS included vs excluded
print("\n--- Features to be used (safe) ---")
for f in NUMERICAL_FEATURES:
    status = "✓" if f in df.columns else "⚠️ MISSING"
    print(f"  {status} {f}")

print("\n--- Columns excluded (future-looking) ---")
for f in FORBIDDEN_AS_FEATURES:
    status = "(in data, excluded)" if f in df.columns else "(not in data)"
    print(f"  ✓ {f} {status}")


## Cell 3: Target Distribution

Check class distribution for `label_1d_3class` (expecting ~61% HOLD based on plan).


In [None]:
print("=" * 50)
print(f"TARGET DISTRIBUTION ({TARGET_COLUMN})")
print("=" * 50)

# Per-class frequencies (value_counts drops NaN targets): raw counts and
# percentages of the labelled rows.
target_dist = df[TARGET_COLUMN].value_counts()
target_pct = df[TARGET_COLUMN].value_counts(normalize=True).mul(100)

print("\nAbsolute counts:")
print(target_dist)

print("\nPercentages:")
for class_label, class_share in target_pct.items():
    print(f"  {class_label}: {class_share:.1f}%")

# Rows whose target value is missing.
missing_target = df[TARGET_COLUMN].isna().sum()
print(f"\nMissing targets: {missing_target} ({100*missing_target/len(df):.1f}%)")

# One colour per class, shared by the bar and pie charts; unknown labels
# fall back to blue.
colors = {'SELL': '#e74c3c', 'HOLD': '#95a5a6', 'BUY': '#27ae60'}
class_colors = [colors.get(x, '#3498db') for x in target_dist.index]

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
bar_ax, pie_ax = axes

target_dist.plot(kind='bar', ax=bar_ax, color=class_colors)
bar_ax.set_title(f'{TARGET_COLUMN} Distribution')
bar_ax.set_xlabel('Label')
bar_ax.set_ylabel('Count')
bar_ax.tick_params(axis='x', rotation=0)

target_dist.plot(kind='pie', ax=pie_ax, autopct='%1.1f%%', colors=class_colors)
pie_ax.set_title('Class Balance')
pie_ax.set_ylabel('')

plt.tight_layout()
plt.show()

# Warn when the majority class exceeds 60% of labelled rows.
max_class_pct = target_pct.max()
if max_class_pct > 60:
    print(f"\n⚠️  WARNING: Class imbalance detected ({max_class_pct:.1f}% in majority class)")
    print("   → Use class weights in training loss function")
else:
    print("\n✓ Class distribution is reasonably balanced")


## Cell 4: Author Bias Analysis

Check author concentration (the plan notes that the top 2 authors account for 65% of the data).


In [None]:
print("=" * 50)
print("AUTHOR BIAS ANALYSIS")
print("=" * 50)

# Tweet count per author and each author's share of all tweets (in %).
author_dist = df["author"].value_counts()
author_pct = df["author"].value_counts(normalize=True) * 100

print(f"\nTotal unique authors: {df['author'].nunique()}")
print("\nTop 5 authors:")
for rank, (author, count) in enumerate(author_dist.head(5).items(), start=1):
    print(f"  {rank}. {author}: {count} tweets ({author_pct[author]:.1f}%)")

# Combined share of the two most prolific authors.
top2_pct = author_pct.head(2).sum()
print(f"\nTop 2 authors combined: {top2_pct:.1f}%")

# Horizontal bar chart of the 10 most prolific authors, largest on top.
fig, ax = plt.subplots(figsize=(10, 5))
author_dist.head(10).plot(kind='barh', ax=ax, color='#3498db')
ax.set_title('Top 10 Authors by Tweet Count')
ax.set_xlabel('Number of Tweets')
ax.invert_yaxis()
plt.tight_layout()
plt.show()

# Concentration verdict: more than half the data from two authors is flagged.
if top2_pct <= 50:
    print("\n✓ Author distribution is reasonably diverse")
else:
    print(f"\n⚠️  WARNING: Author concentration detected ({top2_pct:.1f}% from top 2)")
    print("   → Add author embedding to model to control for author bias")
    print("   → Consider stratified sampling by author")

# Per-author label mix for the three most prolific authors.
print("\n--- Label distribution by top 3 authors ---")
for author in author_dist.index[:3]:
    author_labels = (
        df.loc[df["author"] == author, TARGET_COLUMN].value_counts(normalize=True) * 100
    )
    print(f"\n{author}:")
    for class_label, class_share in author_labels.items():
        print(f"  {class_label}: {class_share:.1f}%")


## Cell 5: Data Quality Metrics

Check total samples, reliable labels, and valid 1-day labels.


In [None]:
print("=" * 50)
print("DATA QUALITY METRICS")
print("=" * 50)

total_samples = len(df)
# Denominator for percentages; `or 1` avoids ZeroDivisionError on an empty file.
pct_base = total_samples or 1

# Reliable labels. Compare with `== True` so NaN flags count as unreliable.
has_reliable_flag = "is_reliable_label" in df.columns
if has_reliable_flag:
    reliable_mask = df["is_reliable_label"] == True  # noqa: E712
    df_reliable = df[reliable_mask]
else:
    reliable_mask = pd.Series(True, index=df.index)
    df_reliable = df
    print("⚠️  'is_reliable_label' column not found, using all samples")
reliable_count = len(df_reliable)

# Valid 1-day labels
has_1d_label = df[TARGET_COLUMN].notna()
valid_1d_labels = has_1d_label.sum()

# Rows usable for training need BOTH a reliable flag and a 1-day label;
# reliable_count alone overstates this when reliable rows lack a target.
trainable_count = int((reliable_mask & has_1d_label).sum())

# Summary table
print(f"\n{'Metric':<30} {'Count':>10} {'Percentage':>12}")
print("-" * 55)
print(f"{'Total samples':<30} {total_samples:>10} {'100.0%':>12}")
print(f"{'Reliable 1hr labels':<30} {reliable_count:>10} {100*reliable_count/pct_base:>11.1f}%")
print(f"{'With 1-day labels':<30} {valid_1d_labels:>10} {100*valid_1d_labels/pct_base:>11.1f}%")
print(f"{'Reliable + 1-day label':<30} {trainable_count:>10} {100*trainable_count/pct_base:>11.1f}%")

# Check numerical features for missing values (and report absent columns
# instead of silently skipping them).
print("\n--- Missing values in numerical features ---")
for col in NUMERICAL_FEATURES:
    if col not in df.columns:
        print(f"  ⚠️ {col}: not in data")
        continue
    missing = df[col].isna().sum()
    pct = 100 * missing / pct_base
    status = "✓" if pct < 5 else "⚠️"
    print(f"  {status} {col}: {missing} missing ({pct:.1f}%)")

# Text quality
print("\n--- Text quality ---")
if "text" in df.columns:
    empty_text = (df["text"].isna() | (df["text"].str.strip() == "")).sum()
    avg_text_len = df["text"].str.len().mean()
    print(f"  Empty/missing text: {empty_text}")
    print(f"  Average text length: {avg_text_len:.0f} chars")

# Recommendation: only suggest filtering on columns that actually exist.
print("\n=== RECOMMENDATION ===")
print(f"Use {trainable_count} reliable samples with 1-day labels for training")
if has_reliable_flag:
    print(f"Filter: df_train = df[df['is_reliable_label'] == True].dropna(subset=['{TARGET_COLUMN}'])")
else:
    print(f"Filter: df_train = df.dropna(subset=['{TARGET_COLUMN}'])")


## Cell 6: Premarket Risk Assessment

Premarket tweets carry the highest leakage risk: their technical indicators (RSI, volatility) may be computed from day T's close, which is not yet known when the tweet is posted.


In [None]:
print("=" * 50)
print("PREMARKET RISK ASSESSMENT")
print("=" * 50)

if "session" not in df.columns:
    # Without a session column the premarket share cannot be measured.
    print("⚠️  'session' column not found in data")
    print("   Cannot assess premarket risk")
else:
    # Tweet count per session and each session's share of all tweets (in %).
    session_dist = df["session"].value_counts()
    session_pct = df["session"].value_counts(normalize=True) * 100

    print("\nTweets by session:")
    for session, count in session_dist.items():
        risk = "⚠️ HIGH RISK" if session == "premarket" else ""
        print(f"  {session}: {count} ({session_pct[session]:.1f}%) {risk}")

    # Bar chart coloured by session; unknown sessions are grey.
    colors_session = {'premarket': '#e74c3c', 'market': '#27ae60', 'afterhours': '#3498db'}
    session_colors = [colors_session.get(name, '#95a5a6') for name in session_dist.index]
    fig, ax = plt.subplots(figsize=(8, 4))
    session_dist.plot(kind='bar', ax=ax, color=session_colors)
    ax.set_title('Tweets by Session')
    ax.set_xlabel('Session')
    ax.set_ylabel('Count')
    ax.tick_params(axis='x', rotation=0)
    plt.tight_layout()
    plt.show()

    # Share of premarket tweets among all rows (0 when the session never occurs).
    premarket_count = session_dist.get("premarket", 0)
    premarket_pct = 100 * premarket_count / len(df)

    print("\n=== PREMARKET ANALYSIS ===")
    print(f"Premarket tweets: {premarket_count} ({premarket_pct:.1f}%)")

    if premarket_pct <= 10:
        print("✓ Low premarket exposure, acceptable for training")
    else:
        print(f"\n⚠️  WARNING: {premarket_pct:.1f}% premarket tweets")
        print("   For these, technical indicators (RSI, volatility) use day T close (FUTURE DATA)")
        print("   Options:")
        print("   1. CONSERVATIVE: df_clean = df[df['session'] != 'premarket']")
        print("   2. ACCEPT: Minor leakage (~1%), indicators change slowly day-to-day")


## Cell 7: Optional - Sentiment Signal Check

Quick check for correlation between off-the-shelf FinBERT sentiment and the next-open return (`return_to_next_open`) on a random sample of up to 100 tweets.

**Note:** Requires `transformers` library and may take a few minutes to run.


In [None]:
# Set to True to run sentiment analysis (requires transformers, takes ~5 mins)
RUN_SENTIMENT_CHECK = False

# Maximum number of tweets scored by the optional check.
SENTIMENT_SAMPLE_SIZE = 100
# |correlation| above which the sentiment signal is reported as detectable.
SENTIMENT_CORR_THRESHOLD = 0.05


def sample_for_sentiment(frame, n_rows=SENTIMENT_SAMPLE_SIZE, seed=42):
    """Return a reproducible random sample of rows with text and a next-open return.

    Args:
        frame: DataFrame with ``text`` and ``return_to_next_open`` columns.
        n_rows: Maximum number of rows to return.
        seed: ``random_state`` passed to ``DataFrame.sample``.

    Returns:
        Up to ``n_rows`` rows, drawn without replacement from the rows where
        neither column is missing.
    """
    complete = frame.dropna(subset=["text", "return_to_next_open"])
    # Cap by the FILTERED row count: capping by len(frame) makes sample()
    # raise ValueError when dropna leaves fewer than n_rows rows.
    return complete.sample(min(n_rows, len(complete)), random_state=seed)


if RUN_SENTIMENT_CHECK:
    try:
        from transformers import pipeline

        print("Loading FinBERT sentiment model...")
        sentiment = pipeline("sentiment-analysis", model="yiyanghkust/finbert-tone")

        # Sample tweets for a quick check
        df_sample = sample_for_sentiment(df)

        print(f"Analyzing {len(df_sample)} sample tweets...")

        def get_sentiment_score(text):
            """Map the model's top label for ``text`` to +1 / 0 / -1.

            Best-effort: any failure (model error, unexpected output shape)
            scores 0 so a single bad tweet does not abort the whole check.
            """
            try:
                # Truncates to 512 characters (not tokens) before inference.
                result = sentiment(str(text)[:512])[0]
                label_map = {'positive': 1, 'neutral': 0, 'negative': -1}
                return label_map.get(result['label'].lower(), 0)
            except Exception:
                return 0

        df_sample["sentiment_score"] = df_sample["text"].apply(get_sentiment_score)

        # Pearson correlation; NaN when either column is constant or too few rows.
        corr = df_sample["sentiment_score"].corr(df_sample["return_to_next_open"])

        print("\n=== SENTIMENT SIGNAL CHECK ===")
        print(f"Sentiment-return correlation: {corr:.4f}")

        if pd.isna(corr):
            print("⚠️  Correlation undefined (constant scores/returns or too few samples)")
        elif abs(corr) > SENTIMENT_CORR_THRESHOLD:
            # |corr| is tested, so report the sign rather than calling it "positive".
            direction = "positive" if corr > 0 else "negative"
            print(f"✓ Signal detected ({direction} correlation = {corr:.4f})")
        else:
            print(f"⚠️  Weak signal (correlation = {corr:.4f})")
            print("   This is expected - FinBERT needs fine-tuning on this data")

    except ImportError:
        print("⚠️  transformers library not installed")
        print("   Run: pip install transformers")
    except Exception as e:
        # Notebook top-level boundary: report and keep the rest of the run alive.
        print(f"Error during sentiment check: {e}")
else:
    print("Sentiment check skipped (set RUN_SENTIMENT_CHECK = True to enable)")
    print("Note: This requires the transformers library and takes ~5 minutes")


## Summary

Final validation summary and recommendations.


In [None]:
print("=" * 60)
print("PHASE 0 VALIDATION SUMMARY")
print("=" * 60)

# Collect (check name, displayed result, passed?) tuples.
checks = []

# 1. Leakage check
leakage_ok = all(col not in NUMERICAL_FEATURES for col in FORBIDDEN_AS_FEATURES)
checks.append(("Feature Leakage", "PASS" if leakage_ok else "FAIL", leakage_ok))

# 2. Target availability: more than half the rows must carry a 1-day label.
target_present = df[TARGET_COLUMN].notna()
target_frac = target_present.mean() if len(df) else 0.0
target_available = target_frac > 0.5
checks.append(("Target Availability", f"{100*target_frac:.0f}%", target_available))

# 3. Class balance: flagged when one class holds >= 70% of labelled rows.
max_class = df[TARGET_COLUMN].value_counts(normalize=True).max()
balance_ok = max_class < 0.70
checks.append(("Class Balance", f"{100*max_class:.0f}% max", balance_ok))

# 4. Author concentration: informational only, always marked OK because the
#    plan controls for it with author embeddings.
top2_authors = df["author"].value_counts(normalize=True).head(2).sum()
checks.append(("Author Embedding Needed", f"{100*top2_authors:.0f}% top 2", True))

# 5. Data quality. `== True` so NaN reliability flags count as unreliable;
#    without the column every row is treated as reliable.
has_reliable_flag = "is_reliable_label" in df.columns
if has_reliable_flag:
    reliable_mask = df["is_reliable_label"] == True  # noqa: E712
else:
    reliable_mask = pd.Series(True, index=df.index)
reliable_pct = reliable_mask.mean() if len(df) else 0.0
quality_ok = reliable_pct > 0.5
checks.append(("Reliable Labels", f"{100*reliable_pct:.0f}%", quality_ok))

# Print summary table
print(f"\n{'Check':<25} {'Result':>15} {'Status':>10}")
print("-" * 52)
for name, result, ok in checks:
    status = "✓" if ok else "⚠️"
    print(f"{name:<25} {result:>15} {status:>10}")

# Overall verdict
all_pass = all(ok for _, _, ok in checks)
print("\n" + "=" * 52)
if all_pass:
    print("✓ ALL CHECKS PASSED - Data is ready for training!")
else:
    print("⚠️  Some checks need attention (see warnings above)")

# Recommended filtering. The sample count is computed exactly from the joint
# mask; multiplying the two marginal rates assumed they were independent.
trainable_count = int((reliable_mask & target_present).sum())
print("\n=== RECOMMENDED DATA FILTERING ===")
if has_reliable_flag:
    print(f"df_train = df[df['is_reliable_label'] == True].dropna(subset=['{TARGET_COLUMN}'])")
else:
    print(f"df_train = df.dropna(subset=['{TARGET_COLUMN}'])")
print(f"Expected training samples: {trainable_count}")
