# Leg 2: News Model Training

This notebook is an optimized version of `02_train_leg2_news.ipynb`.


**Optimizations**:
1. **Device**: Uses `mps` (Metal Performance Shaders) for GPU acceleration when available, falling back to CUDA and then CPU
2. **Batch Size**: Increased to 128
3. **Epochs**: Increased to 10
4. **Data Loading**: Increased `num_workers` to 8 and enabled `pin_memory`

**Workflow**:
1. Load preprocessed news data
2. Wrap the PyTorch model (`Leg2HANWrapper`) with MPS support
3. Run Walk-Forward Cross-Validation
4. Save raw OOF predictions (`/kaggle/working/leg2_oof_preds.pkl`)
5. Evaluate raw performance
6. Upload results to Kaggle Hub

In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import ast
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr, spearmanr

# Added for Kaggle Upload
import kagglehub
from datetime import datetime

# Location of the project repository and the directory it is checked out into.
REPO_URL = "https://github.com/brianrp09232000/multimodal-eq-sizing.git"
REPO_DIR = "/kaggle/working/multimodal-eq-sizing"

# Lines starting with `!` are IPython shell escapes; `{NAME}` is replaced by the
# value of the Python variable NAME before the command runs.
# Clone on the first run; on later runs update the existing checkout in place.
if not os.path.exists(REPO_DIR):
    print(f"Cloning {REPO_URL} into {REPO_DIR}...")
    !git clone {REPO_URL} {REPO_DIR}
else:
    print("Repository already exists. Pulling latest changes...")
    !cd {REPO_DIR} && git pull

# NOTE(review): the protobuf pin presumably works around an incompatibility with
# a preinstalled package — confirm and record the reason here.
# NOTE(review): both installs run after the imports at the top of this cell, so a
# package that is already imported keeps its loaded version until the kernel restarts.
print("Fixing protobuf version...")
!pip install "protobuf==3.20.3" 

# Install the dependencies listed in the cloned repository.
print("Installing requirements...")
!pip install -r {REPO_DIR}/requirements.txt

def _ensure_on_sys_path(path):
    """Append ``path`` to ``sys.path`` unless it is already listed (safe to re-run)."""
    path = str(path)
    if path not in sys.path:
        sys.path.append(path)


# Make the cloned repository's `src` folder importable.
src_path = os.path.join(REPO_DIR, "src")
if os.path.exists(src_path):
    print(f"src folder found at: {src_path}")
    _ensure_on_sys_path(src_path)
else:
    print("folder MISSING. Check git clone output.")

# Add source to path for imports
current_dir = Path(os.getcwd())
# Try standard Kaggle repo path first, then local fallback
kaggle_repo_path = Path(REPO_DIR) / "src"
local_repo_path = current_dir.parent / 'src'

if kaggle_repo_path.exists():
    # The repo root enables `from src.models ...`; `src` itself enables the
    # `from models ...` fallback used by the project imports.
    _ensure_on_sys_path(kaggle_repo_path.parent)
    _ensure_on_sys_path(kaggle_repo_path)
elif local_repo_path.exists():
    _ensure_on_sys_path(local_repo_path)
# Project Imports
try:
    from src.models.HAN_l2 import FinbertHAN
    from src.data.loaders import NewsDataset, han_collate_fn
    from src.utils.cv import generate_yearly_oof
except ImportError:
    # Fallback if src is directly in path
    from models.HAN_l2 import FinbertHAN
    from data.loaders import NewsDataset, han_collate_fn
    from utils.cv import generate_yearly_oof

# Accelerator preference: Apple MPS first, then CUDA, otherwise plain CPU.
DEVICE, _device_banner = torch.device("cpu"), "Using CPU"
if torch.backends.mps.is_available():
    DEVICE, _device_banner = torch.device("mps"), "Using Apple MPS acceleration"
elif torch.cuda.is_available():
    DEVICE, _device_banner = torch.device("cuda"), "Using CUDA acceleration"
print(_device_banner)

# Training hyperparameters and DataLoader settings used by the cells below.
SEED, BATCH_SIZE, EPOCHS = 42, 128, 10
LR, MAX_GRAD_NORM = 1e-4, 1.0
NUM_WORKERS, PIN_MEMORY = 8, True

# Seed the torch and NumPy global random number generators.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Echo the effective run configuration.
print(
    f"Device: {DEVICE}",
    f"Batch Size: {BATCH_SIZE}",
    f"Epochs: {EPOCHS}",
    sep="\n",
)

## 1. Define Model Wrapper

We update the wrapper to use `pin_memory` and `NUM_WORKERS` in the DataLoader.

In [None]:
class Leg2HANWrapper:
    """
    Scikit-learn style (``fit`` / ``predict``) wrapper for the FinbertHAN PyTorch model.

    Produces raw regression outputs. DataLoader behaviour (worker count, pinned
    memory) is configurable per instance and defaults to the notebook-level settings.

    Parameters
    ----------
    batch_size : int
        Mini-batch size used for both training and inference.
    epochs : int
        Number of passes over the training data in ``fit``.
    lr : float
        Learning rate passed to AdamW.
    device : torch.device or str
        Device the model and batches are moved to.
    num_workers : int
        Number of DataLoader worker processes.
    pin_memory : bool
        Request pinned host memory in the DataLoader. Only applied when
        ``device`` is a CUDA device; on other devices it is switched off
        because PyTorch ignores it there and emits a warning.
    """
    def __init__(self,
                 batch_size=BATCH_SIZE,
                 epochs=EPOCHS,
                 lr=LR,
                 device=DEVICE,
                 num_workers=NUM_WORKERS,
                 pin_memory=PIN_MEMORY):
        self.batch_size = batch_size
        self.epochs = epochs
        self.lr = lr
        self.device = device
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.model = None              # created in fit()
        self.train_loss_history = []   # sample-weighted mean loss per epoch

    def _make_loader(self, frame: pd.DataFrame, shuffle: bool) -> DataLoader:
        """Wrap ``frame`` in a NewsDataset and return a DataLoader over it."""
        # str() works for both torch.device objects and plain strings.
        use_pinned = bool(self.pin_memory) and str(self.device).startswith("cuda")
        return DataLoader(
            NewsDataset(frame),
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=han_collate_fn,
            num_workers=self.num_workers,
            pin_memory=use_pinned,
        )

    def _forward(self, batch):
        """Move one collated batch to the device and return the model's predictions."""
        input_ids = batch['input_ids'].to(self.device)
        att_mask = batch['attention_mask'].to(self.device)
        doc_lens = batch['doc_lengths']  # passed through as produced by the collate function
        time_gaps = batch['time_gaps'].to(self.device)
        aux_feats = batch['aux_features'].to(self.device)
        news_mask = batch['news_mask'].to(self.device)
        # The model returns three values; only the first (predictions) is used here.
        preds, _, _ = self.model(input_ids, att_mask, doc_lens, time_gaps, aux_feats, news_mask)
        return preds

    def fit(self, X: pd.DataFrame, y: pd.Series = None):
        """
        Train a fresh FinbertHAN on ``X``.

        If ``y`` is given it is written to a ``'target'`` column of a copy of ``X``
        (``X`` itself is not modified). Any previously fitted model is replaced.
        Returns ``self``.
        """
        # Prepare data
        train_df = X.copy()
        if y is not None:
            # Positional assignment: ignores y's index, accepts Series or arrays.
            train_df['target'] = np.asarray(y)

        loader = self._make_loader(train_df, shuffle=True)

        self.model = FinbertHAN(aux_dim=4).to(self.device)
        # Only parameters with requires_grad=True are handed to the optimizer.
        optimizer = optim.AdamW(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.lr)
        criterion = nn.MSELoss()

        self.model.train()
        self.train_loss_history = []

        for epoch in range(self.epochs):
            epoch_loss = 0.0
            n_seen = 0
            for batch in loader:
                targets = batch['targets'].to(self.device)

                optimizer.zero_grad()
                preds = self._forward(batch)
                # NOTE(review): assumes preds and targets have identical shapes;
                # MSELoss would broadcast mismatched shapes — confirm against FinbertHAN.
                loss = criterion(preds, targets)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), MAX_GRAD_NORM)
                optimizer.step()

                # Weight each batch loss by its size so the epoch figure is a per-sample mean.
                epoch_loss += loss.item() * targets.size(0)
                n_seen += targets.size(0)

            avg_loss = epoch_loss / max(1, n_seen)
            self.train_loss_history.append(avg_loss)
            print(f"Epoch {epoch+1}/{self.epochs} - Loss: {avg_loss:.6f}")
        return self

    def predict(self, X: pd.DataFrame) -> np.ndarray:
        """
        Return a flat array of raw model outputs for the rows of ``X``, in row order.

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        """
        if self.model is None:
            raise RuntimeError("Leg2HANWrapper.predict() called before fit().")
        if len(X) == 0:
            # Nothing to score; np.concatenate would fail on an empty batch list.
            return np.empty(0, dtype=np.float32)

        self.model.eval()
        test_df = X.copy()
        if 'target' not in test_df.columns:
            # Placeholder column added when X has no 'target'; its values are not used below.
            test_df['target'] = 0.0

        loader = self._make_loader(test_df, shuffle=False)

        all_preds = []
        with torch.no_grad():
            for batch in loader:
                all_preds.append(self._forward(batch).cpu().numpy())

        return np.concatenate(all_preds).flatten()

## 2. Load Data

Same loading logic as the standard notebook.

In [None]:
# Directories scanned, in this order, when locating the dataset files.
SEARCH_DIRS = ["/kaggle/input", "/kaggle/working"]  # Kaggle locations
SEARCH_DIRS.append("../src/data/datasets")  # Fallback for local runs

def find_file(filename, search_dirs):
    """
    Locate ``filename`` under the given directories.

    Directories are tried in order. For each one, ``<directory>/<filename>`` is
    checked first, then the directory tree is searched recursively. Only regular
    files count as matches; when the recursive search finds several, the
    lexicographically smallest path is returned so the choice does not depend on
    filesystem enumeration order.

    Parameters
    ----------
    filename : str
        File name (or glob pattern) to look for.
    search_dirs : iterable of str or Path
        Directories to search; missing directories are skipped.

    Returns
    -------
    pathlib.Path or None
        Path of the first match, or ``None`` when nothing is found.
    """
    print(f"Searching for {filename}...")
    for directory in search_dirs:
        root = Path(directory)
        try:
            # Check direct path
            direct = root / filename
            if direct.is_file():
                return direct
            # Check recursive
            if not root.is_dir():
                continue
            matches = sorted(p for p in root.rglob(filename) if p.is_file())
        except OSError as e:
            # Unreadable directory: report it and move on to the next candidate.
            print(f"Access error in {directory}: {e}")
            continue
        if matches:
            return matches[0]
    return None

# File names of the two input datasets.
NEWS_FILENAME, PRICES_FILENAME = "filtered_news_dataset.csv", "prices_dataset.csv"

# Resolve each file name against SEARCH_DIRS (news first, then prices);
# a path is None when the file could not be found.
NEWS_PATH, PRICES_PATH = [
    find_file(dataset_name, SEARCH_DIRS)
    for dataset_name in (NEWS_FILENAME, PRICES_FILENAME)
]

def robust_literal_eval(val):
    """
    Parse a CSV cell that should hold a Python container literal (e.g. ``"['a', 'b']"``).

    Returns
    -------
    - ``val`` unchanged when it is already a list/tuple/set/dict;
    - ``[]`` for missing values: NaN/None, empty or whitespace-only text, or the
      text ``"nan"`` (case-insensitive);
    - the parsed object when the text is a list/tuple/set/dict literal;
    - ``[str(val)]`` otherwise — the text is not a valid literal, or it parses
      to a scalar (e.g. ``"2020"``), so the raw text is kept as a single item.
    """
    containers = (list, tuple, set, dict)
    # Already parsed: return as-is (pd.isna would yield an array for containers).
    if isinstance(val, containers):
        return val
    if pd.isna(val):
        return []

    text = str(val).strip()
    if not text or text.lower() == 'nan':
        return []

    try:
        # Strings are parsed from the stripped text; other values are passed through
        # and rejected by literal_eval, which lands in the fallback below.
        parsed = ast.literal_eval(text if isinstance(val, str) else val)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return [str(val)]

    if isinstance(parsed, containers):
        return parsed
    # A scalar literal (number, quoted string, None, ...) is treated like plain text.
    return [str(val)]

# Fail fast: make sure both inputs were located before the (slow) news parse starts.
if not NEWS_PATH:
    raise FileNotFoundError(f"News file not found in {SEARCH_DIRS}")
if not PRICES_PATH:
    raise FileNotFoundError(f"Prices file not found in {SEARCH_DIRS}")

# --- News ---
print(f"Loading News from {NEWS_PATH}...")
# Columns stored as stringified Python literals are parsed while reading.
news_converters = {
    'sentences': robust_literal_eval,
    'Article_title': robust_literal_eval,
    'entities': robust_literal_eval,
    'entities_today': robust_literal_eval
}
df_news = pd.read_csv(NEWS_PATH, converters=news_converters)

# Accept the alternative column names 'sentences' / 'Stock_symbol'.
if 'sentences' in df_news.columns and 'Article_title' not in df_news.columns:
    df_news = df_news.rename(columns={'sentences': 'Article_title'})
if 'Stock_symbol' in df_news.columns and 'ticker' not in df_news.columns:
    df_news = df_news.rename(columns={'Stock_symbol': 'ticker'})
df_news['Date'] = pd.to_datetime(df_news['Date'], utc=True)
print(f"News Loaded. Shape: {df_news.shape}.")

# --- Prices ---
print(f"Loading Prices from {PRICES_PATH}...")
df_prices = pd.read_csv(PRICES_PATH)
df_prices['Date'] = pd.to_datetime(df_prices['Date'], utc=True)
print(f"Prices Loaded. Shape: {df_prices.shape}")

# --- Merge ---
# Left join keeps every price row; rows without matching news get NaN news columns.
# NOTE(review): the join matches on exact UTC timestamps — assumes news 'Date'
# values carry no intraday time component; confirm against the dataset.
df_full = pd.merge(df_prices, df_news, on=['Date', 'ticker'], how='left')
df_full = df_full.sort_values('Date').reset_index(drop=True)

# --- Features / target ---
target_col = 'excess_return'

if target_col not in df_full.columns:
    raise KeyError(f"Target column '{target_col}' not found.")

# Drop rows without a target, then reset the index so X, y and dates share a
# contiguous 0..n-1 index (label-based and positional access then agree).
df_full = df_full.dropna(subset=[target_col]).reset_index(drop=True)
X = df_full.drop(columns=[target_col])
y = df_full[target_col]
dates = df_full['Date']
print(f"Data Ready. Final Shape: {df_full.shape}")

## 3. Walk-Forward Cross Validation (OOF Generation)

Uses `generate_yearly_oof` with the optimized factory.

In [None]:
# Factory for the optimized model: each call builds a new, unfitted wrapper.
def model_factory():
    """Return a new ``Leg2HANWrapper`` configured from the notebook-level hyperparameters."""
    settings = dict(batch_size=BATCH_SIZE, epochs=EPOCHS, lr=LR, device=DEVICE)
    return Leg2HANWrapper(**settings)

# Generate out-of-fold predictions.
# Performance gains are expected from the faster model training itself, so n_jobs stays at 1.
if 'X' in locals():
    oof_result = generate_yearly_oof(
        model_factory=model_factory, X=X, y=y, dates=dates,
        min_train_years=2, n_jobs=1,
    )
    # The routine returns three values: predictions, targets and per-fold statistics.
    oof_preds_raw, oof_targets, fold_stats = oof_result

    n_generated = len(oof_preds_raw)
    print(f"OOF Prediction Complete. Generated {n_generated} predictions.")
    print(pd.DataFrame(fold_stats))

## 4. Save Raw OOF Predictions

Saves to `/kaggle/working/leg2_oof_preds.pkl`, overwriting any existing file. If that directory is not writable, the file is written to `../data/leg2_oof_preds.pkl` instead.

In [None]:
if 'oof_preds_raw' in locals() and len(oof_preds_raw) > 0:
    # 1. Reconstruct Metadata (Date, ticker)
    # Validation years are every calendar year after the first `min_train_years`.
    # NOTE(review): this must mirror how generate_yearly_oof orders its output
    # (year by year, rows in df_full order) — confirm against that function.
    unique_years = sorted(dates.dt.year.unique())
    min_train_years = 2
    val_years = unique_years[min_train_years:]

    row_years = dates.dt.year
    oof_dfs = [
        df_full.loc[row_years == val_year, ['Date', 'ticker', target_col]].copy()
        for val_year in val_years
    ]
    oof_df = pd.concat(oof_dfs)

    # 2. Attach predictions
    # Check lengths BEFORE assigning, so a mismatch is reported by this message
    # rather than by a pandas error (or hidden by index alignment).
    assert len(oof_df) == len(oof_preds_raw), (
        "Mismatch between OOF metadata and predictions length "
        f"({len(oof_df)} vs {len(oof_preds_raw)})"
    )
    # np.asarray forces positional assignment even if the predictions are a Series.
    oof_df['prediction'] = np.asarray(oof_preds_raw)

    print(f"OOF Metadata Shape: {oof_df.shape}")
    print(f"OOF Preds Length: {len(oof_preds_raw)}")

    # Non-fatal sanity check: the rebuilt targets should equal the ones the CV
    # routine returned; a difference means rows and predictions are misaligned.
    if 'oof_targets' in locals() and len(oof_targets) == len(oof_df):
        rebuilt_targets = oof_df[target_col].to_numpy(dtype=float)
        returned_targets = np.asarray(oof_targets, dtype=float).ravel()
        if not np.allclose(rebuilt_targets, returned_targets, equal_nan=True):
            print("WARNING: reconstructed targets differ from oof_targets; "
                  "predictions may be misaligned with Date/ticker.")

    # 3. Save to Kaggle Working Directory
    SAVE_PATH = "/kaggle/working/leg2_oof_preds.pkl"
    if not os.access(os.path.dirname(SAVE_PATH), os.W_OK):
        # Fallback to local
        SAVE_PATH = "../data/leg2_oof_preds.pkl"

    os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)

    oof_df.to_pickle(SAVE_PATH)
    print(f"Saved OOF predictions to: {SAVE_PATH}")
    print(oof_df.head())

## 5. Raw Performance Metrics

In [None]:
if 'oof_preds_raw' in locals() and len(oof_preds_raw) > 0:
    # Work on flat NumPy arrays so the integer sampling below is positional
    # (indexing a pandas Series with integers would be label-based).
    preds_arr = np.asarray(oof_preds_raw).ravel()
    targets_arr = np.asarray(oof_targets).ravel()

    mse_raw = mean_squared_error(targets_arr, preds_arr)
    # Information coefficient: Pearson correlation between predictions and targets.
    ic_raw, _ = pearsonr(preds_arr, targets_arr)

    print("Leg 2 Results")
    print(f"MSE Raw: {mse_raw:.6f}")
    print(f"IC Raw: {ic_raw:.4f}")

    plt.figure(figsize=(8, 6))
    # Plot at most 5000 randomly chosen points to keep the figure light.
    indices = np.random.choice(len(targets_arr), min(5000, len(targets_arr)), replace=False)

    sns.regplot(x=preds_arr[indices], y=targets_arr[indices],
                scatter_kws={'alpha':0.3, 's': 10}, line_kws={'color':'red'})
    plt.title(f"Raw Scores vs Excess Return (IC={ic_raw:.3f})")
    plt.xlabel("Raw NN Output")
    plt.ylabel("True Excess Return")
    plt.show()

In [None]:
# Upload the contents of the working directory as a new version of the Kaggle dataset.
handle = "iinarixf0x/leg2-news-model"
local_dataset_dir = "/kaggle/working/"
current_date = datetime.today().strftime("%Y-%m-%d")

# The git clone of the project lives inside the upload directory; exclude it so
# only generated artifacts are published. A pattern ending in "/" ignores the
# whole directory.
# NOTE(review): `ignore_patterns` needs a kagglehub release that supports it; on
# an older release the call fails and is reported by the handler below.
upload_ignore_patterns = [f"{os.path.basename(REPO_DIR.rstrip('/'))}/"]

try:
    kagglehub.dataset_upload(
        handle,
        local_dataset_dir,
        version_notes=f"Dataset {current_date}",
        ignore_patterns=upload_ignore_patterns,
    )
    print("Upload successful")
except Exception as e:
    # Best effort: a failed upload is reported but does not stop the notebook.
    print(f"Upload failed: {e}")
