# FinBERT Tweet Classifier Training - Google Colab

This notebook allows you to train the FinBERT-based multi-modal tweet classifier on Google Colab with GPU acceleration.

**Before running:**
1. Go to `Runtime` → `Change runtime type` → Select `GPU` (T4 recommended)
2. Upload your training data CSV file when prompted

**Requirements:**
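- Enriched CSV data file (e.g., `15-dec-enrich7.csv`) containing the `text`, `label_1d_3class`, `author`, `category`, `market_regime`, `sector`, and `market_cap_bucket` columns, plus the numerical feature columns listed in the Configuration section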
- Google Colab with GPU enabled


## 1. Setup Environment


In [None]:
# Check GPU availability by running the NVIDIA system-management tool in a shell.
# If the command is missing or lists no device, the runtime presumably has no GPU
# attached (see the "Before running" steps at the top of the notebook).
!nvidia-smi


In [None]:
# Install required packages
%pip install -q transformers>=4.30.0 accelerate>=0.26.0 datasets>=2.14.0 scikit-learn>=1.3.0 torch>=2.0.0 seaborn


In [None]:
# Import libraries
import json
import logging
import hashlib
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset
from transformers import BertModel, BertTokenizer, Trainer, TrainingArguments

# Configure notebook-wide logging: timestamp, level, logger name, message
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Pick the compute device (CUDA when present, CPU otherwise) and report it
_cuda_ready = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_ready else torch.device("cpu")
print(f"Using device: {device}")
if _cuda_ready:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    _gpu_props = torch.cuda.get_device_properties(0)
    # total_memory is reported in bytes; shown here in decimal gigabytes
    print(f"Memory: {_gpu_props.total_memory / 1e9:.2f} GB")


## 2. Configuration


In [None]:
# =============================================================================
# CONFIGURATION - Modify these settings as needed
# =============================================================================

# --- Target -------------------------------------------------------------------
TARGET_COLUMN = "label_1d_3class"  # 1-day labels
LABEL_MAP = {"SELL": 0, "HOLD": 1, "BUY": 2}
# Inverse lookup and class count are derived from LABEL_MAP so the three
# settings cannot drift apart when the label set is edited.
LABEL_MAP_INV = {class_id: name for name, class_id in LABEL_MAP.items()}
NUM_CLASSES = len(LABEL_MAP)

# --- Features -----------------------------------------------------------------
# Core indicators (baseline)
_CORE_INDICATORS = ["volatility_7d", "relative_volume", "rsi_14", "distance_from_ma_20"]
# Phase 2: Multi-period momentum
_MOMENTUM_FEATURES = ["return_5d", "return_20d"]
# Phase 2: Trend confirmation
_TREND_FEATURES = ["above_ma_20", "slope_ma_20"]
# Phase 2: Shock/Gap features
_SHOCK_FEATURES = ["gap_open", "intraday_range"]

# Column order here is the feature order fed to the scaler and the model.
NUMERICAL_FEATURES = (
    _CORE_INDICATORS + _MOMENTUM_FEATURES + _TREND_FEATURES + _SHOCK_FEATURES
)

TEXT_COLUMN = "text"

# --- Model --------------------------------------------------------------------
FINBERT_MODEL_NAME = "yiyanghkust/finbert-tone"
MAX_TEXT_LENGTH = 128

# Embedding dimensions (one per categorical input) and numerical branch width
AUTHOR_EMBEDDING_DIM = 16
CATEGORY_EMBEDDING_DIM = 8
MARKET_REGIME_EMBEDDING_DIM = 4
SECTOR_EMBEDDING_DIM = 8
MARKET_CAP_EMBEDDING_DIM = 4
NUMERICAL_HIDDEN_DIM = 32

# --- Training defaults - ADJUST THESE FOR COLAB --------------------------------
DEFAULT_BATCH_SIZE = 16  # Increase to 32 if you have a T4/V100
DEFAULT_LEARNING_RATE = 2e-5
DEFAULT_NUM_EPOCHS = 5
DEFAULT_DROPOUT = 0.3
DEFAULT_WARMUP_RATIO = 0.1
DEFAULT_WEIGHT_DECAY = 0.01

# --- Data split -----------------------------------------------------------------
DEFAULT_TEST_SIZE = 0.15
DEFAULT_VAL_SIZE = 0.15
RANDOM_SEED = 42


## 3. Upload Data


In [None]:
# Option 1: Upload from local machine
from google.colab import files

print("Upload your enriched CSV file:")
uploaded = files.upload()
if not uploaded:
    # An empty result means no file came back from the upload widget (for
    # example the dialog was dismissed); fail with an actionable message
    # instead of an IndexError on the lookup below.
    raise ValueError("No file was uploaded. Re-run this cell and select your enriched CSV file.")

# The result is keyed by file name; the first uploaded file is used as the dataset.
DATA_FILE = next(iter(uploaded))
print(f"\nUploaded: {DATA_FILE}")


In [None]:
# Option 2: Mount Google Drive (uncomment if you prefer this method)
# from google.colab import drive
# drive.mount('/content/drive')
# DATA_FILE = '/content/drive/MyDrive/path/to/your/15-dec-enrich7.csv'


## 4. Data Loading & Processing Utilities


In [None]:
def load_enriched_data(file_path: str) -> pd.DataFrame:
    """Read the enriched tweet CSV at ``file_path`` into a DataFrame.

    The number of rows read is reported through the module logger at INFO level.
    """
    frame = pd.read_csv(file_path)
    row_count = len(frame)
    logger.info(f"Loaded {row_count} rows from {file_path}")
    return frame


def filter_reliable(df: pd.DataFrame) -> pd.DataFrame:
    """Keep only rows usable for training.

    A row is kept when its target label (``TARGET_COLUMN``) is present and its
    text (``TEXT_COLUMN``) is present with a length greater than zero. The
    input frame is not modified; a new frame is returned.
    """
    # Rows missing either the label or the text are dropped first ...
    required_columns = [TARGET_COLUMN, TEXT_COLUMN]
    present = df.dropna(subset=required_columns)

    # ... then rows whose text is the empty string.
    has_text = present[TEXT_COLUMN].str.len() > 0
    return present.loc[has_text].copy()


def split_by_hash(
    df: pd.DataFrame,
    test_size: float = DEFAULT_TEST_SIZE,
    val_size: float = DEFAULT_VAL_SIZE,
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Split rows into train/val/test deterministically from a per-tweet hash.

    Each row is assigned a bucket in ``[0, 1000)`` computed from the first
    8 hex characters of its ``tweet_hash``. Buckets below ``test_size * 1000``
    go to test, the next ``val_size * 1000`` buckets go to validation, and the
    remainder go to train. Rows with the same hash always land in the same
    split, independent of row order.

    Args:
        df: Input rows. If it has no ``tweet_hash`` column, one is created from
            the MD5 hex digest of ``TEXT_COLUMN`` (and is present in the
            returned frames).
        test_size: Fraction of hash buckets assigned to the test split.
        val_size: Fraction of hash buckets assigned to the validation split.

    Returns:
        ``(df_train, df_val, df_test)``. The input frame is not modified.
    """
    df = df.copy()
    if "tweet_hash" not in df.columns:
        # Create hash from text if not present
        df["tweet_hash"] = df[TEXT_COLUMN].apply(
            lambda text: hashlib.md5(str(text).encode()).hexdigest()
        )

    # Bucket id per row, kept in a separate Series so no helper column has to
    # be added to (and later dropped from) the frame.
    buckets = df["tweet_hash"].apply(lambda digest: int(digest[:8], 16) % 1000)

    # round() instead of int(): fractions are binary floats, and a sum such as
    # 0.7 + 0.1 is stored slightly below 0.8, which truncation would turn into
    # a boundary one bucket too low.
    test_thresh = round(test_size * 1000)
    val_thresh = round((test_size + val_size) * 1000)

    in_test = buckets < test_thresh
    in_val = (buckets >= test_thresh) & (buckets < val_thresh)
    in_train = buckets >= val_thresh

    return df[in_train], df[in_val], df[in_test]


def compute_class_weights(
    labels: pd.Series,
    num_classes: Optional[int] = None,
) -> Dict[int, float]:
    """Compute inverse-frequency class weights for every class id.

    For a class that occurs in ``labels`` the weight is
    ``len(labels) / (n_present * count)``, where ``n_present`` is the number of
    distinct classes that actually occur. A class that does not occur gets the
    neutral weight ``1.0``.

    Args:
        labels: Class labels, either integer ids or label names. Non-numeric
            labels are translated to ids through ``LABEL_MAP``.
        num_classes: Number of class ids to produce weights for
            (``0 .. num_classes - 1``). Defaults to ``NUM_CLASSES``.

    Returns:
        Mapping ``class id -> weight`` with exactly ``num_classes`` entries.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    # Any non-numeric dtype (object, string, categorical names) holds label
    # names rather than ids and must be mapped first.
    if not pd.api.types.is_numeric_dtype(labels):
        labels = labels.map(LABEL_MAP)

    counts = labels.value_counts()
    total = len(labels)
    n_present = len(counts)

    # Iterate over the full class range, not over the number of classes seen:
    # when a class is absent from the data, the highest class id would
    # otherwise get no entry and the weight vector would be too short for the
    # model's output.
    weights: Dict[int, float] = {}
    for cls in range(num_classes):
        if cls in counts.index:
            weights[cls] = float(total / (n_present * counts[cls]))
        else:
            weights[cls] = 1.0

    return weights


def weights_to_tensor(weights: Dict[int, float]) -> torch.Tensor:
    """Lay the class weights out as a float32 vector indexed by class id.

    The mapping is expected to have the keys ``0 .. len(weights) - 1``;
    position ``i`` of the result holds ``weights[i]``.
    """
    class_count = len(weights)
    ordered = [weights[class_id] for class_id in range(class_count)]
    return torch.tensor(ordered, dtype=torch.float32)


## 5. Dataset Class


In [None]:
class TweetDataset(Dataset):
    """Map-style dataset pairing tokenized tweet text with tabular features.

    All texts are tokenized once in the constructor (padded/truncated to
    ``max_length``), and every feature column is stored as a tensor, so
    ``__getitem__`` only indexes pre-built tensors.
    """

    def __init__(
        self,
        texts: Union[pd.Series, List[str]],
        numerical_features: Union[pd.DataFrame, np.ndarray],
        author_indices: Union[pd.Series, np.ndarray],
        category_indices: Union[pd.Series, np.ndarray],
        market_regime_indices: Union[pd.Series, np.ndarray],
        sector_indices: Union[pd.Series, np.ndarray],
        market_cap_indices: Union[pd.Series, np.ndarray],
        labels: Union[pd.Series, np.ndarray],
        tokenizer,
        max_length: int = MAX_TEXT_LENGTH,
    ):
        # --- text ---------------------------------------------------------------
        text_list = texts.tolist() if isinstance(texts, pd.Series) else texts
        self.encodings = tokenizer(
            text_list,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )

        # --- numerical features ---------------------------------------------------
        if isinstance(numerical_features, pd.DataFrame):
            numerical_features = numerical_features.values
        self.numerical = torch.tensor(numerical_features, dtype=torch.float32)

        # --- categorical indices (one long tensor per embedding table) ------------
        self.author_idx = self._index_tensor(author_indices)
        self.category_idx = self._index_tensor(category_indices)
        self.market_regime_idx = self._index_tensor(market_regime_indices)
        self.sector_idx = self._index_tensor(sector_indices)
        self.market_cap_idx = self._index_tensor(market_cap_indices)

        # --- labels ---------------------------------------------------------------
        label_values = self._unwrap(labels)
        # Label names are detected from the first element and translated to
        # class ids through LABEL_MAP.
        if isinstance(label_values[0], str):
            label_values = np.array([LABEL_MAP[name] for name in label_values])
        self.labels = torch.tensor(label_values, dtype=torch.long)

    @staticmethod
    def _unwrap(values):
        """Return the NumPy array behind a pandas Series; pass anything else through."""
        return values.values if isinstance(values, pd.Series) else values

    @classmethod
    def _index_tensor(cls, values) -> torch.Tensor:
        """Convert a column of category indices to a ``torch.long`` tensor."""
        return torch.tensor(cls._unwrap(values), dtype=torch.long)

    def __len__(self) -> int:
        """Number of samples, taken from the label tensor."""
        return len(self.labels)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the sample at ``idx`` as a dict of per-sample tensors."""
        sample = {
            "input_ids": self.encodings["input_ids"][idx],
            "attention_mask": self.encodings["attention_mask"][idx],
        }
        sample["numerical"] = self.numerical[idx]
        sample["author_idx"] = self.author_idx[idx]
        sample["category_idx"] = self.category_idx[idx]
        sample["market_regime_idx"] = self.market_regime_idx[idx]
        sample["sector_idx"] = self.sector_idx[idx]
        sample["market_cap_idx"] = self.market_cap_idx[idx]
        sample["labels"] = self.labels[idx]
        return sample


In [None]:
def create_categorical_encodings(df: pd.DataFrame) -> Dict[str, Any]:
    """Build value -> index vocabularies for the five categorical columns.

    Indices follow the order in which distinct values first appear in ``df``.
    Missing values are replaced before the vocabulary is built for
    ``market_regime`` ("calm"), ``sector`` ("Other") and ``market_cap_bucket``
    ("unknown"); ``author`` and ``category`` are used as they are.

    Returns:
        A dict holding the five ``*_to_idx`` mappings followed by the five
        ``num_*`` vocabulary sizes.
    """
    # (source column, fill value or None, mapping key, size key)
    specs = (
        ("author", None, "author_to_idx", "num_authors"),
        ("category", None, "category_to_idx", "num_categories"),
        ("market_regime", "calm", "market_regime_to_idx", "num_market_regimes"),
        ("sector", "Other", "sector_to_idx", "num_sectors"),
        ("market_cap_bucket", "unknown", "market_cap_to_idx", "num_market_caps"),
    )

    vocab_values = {}
    for column, fill_value, mapping_key, _ in specs:
        series = df[column] if fill_value is None else df[column].fillna(fill_value)
        vocab_values[mapping_key] = series.unique().tolist()

    # Mappings first, sizes second, to keep the key order of the result stable.
    result: Dict[str, Any] = {}
    for _, _, mapping_key, _ in specs:
        result[mapping_key] = {
            value: position for position, value in enumerate(vocab_values[mapping_key])
        }
    for _, _, mapping_key, size_key in specs:
        result[size_key] = len(vocab_values[mapping_key])
    return result


def encode_categorical(
    df: pd.DataFrame,
    author_to_idx: Dict[str, int],
    category_to_idx: Dict[str, int],
    market_regime_to_idx: Dict[str, int],
    sector_to_idx: Dict[str, int],
    market_cap_to_idx: Dict[str, int],
) -> pd.DataFrame:
    """Return a copy of ``df`` with an integer ``*_idx`` column per categorical.

    Missing ``market_regime``, ``sector`` and ``market_cap_bucket`` values are
    first replaced with "calm", "Other" and "unknown" respectively. A value
    that is not a key of its mapping is encoded as index 0.
    """
    encoded = df.copy()

    # Fill the columns that have a designated placeholder for missing values.
    placeholders = (
        ("market_regime", "calm"),
        ("sector", "Other"),
        ("market_cap_bucket", "unknown"),
    )
    for column, placeholder in placeholders:
        encoded[column] = encoded[column].fillna(placeholder)

    # (new index column, source column, value -> index table)
    lookups = (
        ("author_idx", "author", author_to_idx),
        ("category_idx", "category", category_to_idx),
        ("market_regime_idx", "market_regime", market_regime_to_idx),
        ("sector_idx", "sector", sector_to_idx),
        ("market_cap_idx", "market_cap_bucket", market_cap_to_idx),
    )
    for index_column, source_column, table in lookups:
        # table is bound as a default argument so each lambda keeps its own mapping
        encoded[index_column] = encoded[source_column].map(
            lambda value, table=table: table.get(value, 0)
        )

    return encoded


def create_dataset_from_df(
    df: pd.DataFrame,
    tokenizer,
    encodings: Dict[str, Any],
    scaler: Optional[StandardScaler] = None,
    fit_scaler: bool = False,
) -> Tuple[TweetDataset, StandardScaler]:
    """Turn a raw DataFrame split into a ``TweetDataset``.

    Categorical columns are encoded with the supplied ``encodings``, missing
    numerical values are replaced by 0, and the numerical block is standardized.

    Args:
        df: Rows of one data split.
        tokenizer: Tokenizer handed to ``TweetDataset``.
        encodings: Output of ``create_categorical_encodings``.
        scaler: Scaler to apply. When ``None`` a new ``StandardScaler`` is
            created and fitted on ``df`` regardless of ``fit_scaler``.
        fit_scaler: Fit ``scaler`` on ``df`` before transforming.

    Returns:
        ``(dataset, scaler)`` where ``scaler`` is the one that was applied.
    """
    mapping_keys = (
        "author_to_idx",
        "category_to_idx",
        "market_regime_to_idx",
        "sector_to_idx",
        "market_cap_to_idx",
    )
    encoded = encode_categorical(df, *(encodings[key] for key in mapping_keys))

    features = encoded[NUMERICAL_FEATURES].fillna(0).values

    # Without a scaler there is nothing to reuse, so one is always fitted here.
    if scaler is None:
        scaler, fit_scaler = StandardScaler(), True

    scaled = scaler.fit_transform(features) if fit_scaler else scaler.transform(features)

    dataset = TweetDataset(
        texts=encoded[TEXT_COLUMN],
        numerical_features=scaled,
        author_indices=encoded["author_idx"],
        category_indices=encoded["category_idx"],
        market_regime_indices=encoded["market_regime_idx"],
        sector_indices=encoded["sector_idx"],
        market_cap_indices=encoded["market_cap_idx"],
        labels=encoded[TARGET_COLUMN],
        tokenizer=tokenizer,
    )
    return dataset, scaler


## 6. Model Definition


In [None]:
class FinBERTMultiModal(nn.Module):
    """FinBERT text encoder fused with numerical and categorical features.

    The [CLS] vector of the BERT encoder is concatenated with an MLP encoding
    of the numerical features and one learned embedding per categorical input
    (author, category, market regime, sector, market cap); the concatenation is
    classified by a two-layer MLP head.
    """

    def __init__(
        self,
        num_numerical_features: int,
        num_authors: int,
        num_categories: int,
        num_market_regimes: int = 5,
        num_sectors: int = 12,
        num_market_caps: int = 5,
        num_classes: int = NUM_CLASSES,
        finbert_model: str = FINBERT_MODEL_NAME,
        freeze_bert: bool = False,
        dropout: float = DEFAULT_DROPOUT,
    ):
        super().__init__()

        # Constructor arguments are kept so get_config() can reproduce them.
        self.num_numerical_features = num_numerical_features
        self.num_authors = num_authors
        self.num_categories = num_categories
        self.num_market_regimes = num_market_regimes
        self.num_sectors = num_sectors
        self.num_market_caps = num_market_caps
        self.num_classes = num_classes
        self.finbert_model_name = finbert_model
        self.freeze_bert = freeze_bert
        self.dropout_prob = dropout

        # Text branch: pretrained encoder, optionally excluded from gradient updates.
        self.bert = BertModel.from_pretrained(finbert_model)
        if freeze_bert:
            self.bert.requires_grad_(False)

        # Categorical branch: one embedding table per categorical input.
        self.author_embedding = nn.Embedding(num_authors, AUTHOR_EMBEDDING_DIM)
        self.category_embedding = nn.Embedding(num_categories, CATEGORY_EMBEDDING_DIM)
        self.market_regime_embedding = nn.Embedding(num_market_regimes, MARKET_REGIME_EMBEDDING_DIM)
        self.sector_embedding = nn.Embedding(num_sectors, SECTOR_EMBEDDING_DIM)
        self.market_cap_embedding = nn.Embedding(num_market_caps, MARKET_CAP_EMBEDDING_DIM)

        # Numerical branch: small MLP projecting to NUMERICAL_HIDDEN_DIM.
        self.numerical_encoder = nn.Sequential(
            nn.Linear(num_numerical_features, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, NUMERICAL_HIDDEN_DIM),
            nn.ReLU(),
        )

        # Width of the concatenated vector; the summands follow the
        # concatenation order used in forward().
        fusion_size = sum(
            (
                self.bert.config.hidden_size,  # 768 per the original author's note
                NUMERICAL_HIDDEN_DIM,
                AUTHOR_EMBEDDING_DIM,
                CATEGORY_EMBEDDING_DIM,
                MARKET_REGIME_EMBEDDING_DIM,
                SECTOR_EMBEDDING_DIM,
                MARKET_CAP_EMBEDDING_DIM,
            )
        )

        # Classification head over the fused representation.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(fusion_size, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        numerical: torch.Tensor,
        author_idx: torch.Tensor,
        category_idx: torch.Tensor,
        market_regime_idx: torch.Tensor,
        sector_idx: torch.Tensor,
        market_cap_idx: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run all branches, fuse them and classify.

        Returns:
            ``{"logits": ...}``, plus ``"loss"`` (unweighted cross-entropy
            against ``labels``) when ``labels`` is given.
        """
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Branch outputs, listed in concatenation order.
        branches = [
            encoder_out.last_hidden_state[:, 0, :],  # vector at the first ([CLS]) position
            self.numerical_encoder(numerical),
            self.author_embedding(author_idx),
            self.category_embedding(category_idx),
            self.market_regime_embedding(market_regime_idx),
            self.sector_embedding(sector_idx),
            self.market_cap_embedding(market_cap_idx),
        ]
        fused = torch.cat(branches, dim=1)

        logits = self.classifier(fused)

        if labels is None:
            return {"logits": logits}
        return {"logits": logits, "loss": F.cross_entropy(logits, labels)}

    def get_config(self) -> Dict[str, Any]:
        """Return the constructor arguments as a dict keyed by parameter name."""
        # config key (constructor parameter) -> attribute that stores it
        stored_under = {
            "num_numerical_features": "num_numerical_features",
            "num_authors": "num_authors",
            "num_categories": "num_categories",
            "num_market_regimes": "num_market_regimes",
            "num_sectors": "num_sectors",
            "num_market_caps": "num_market_caps",
            "num_classes": "num_classes",
            "finbert_model": "finbert_model_name",
            "freeze_bert": "freeze_bert",
            "dropout": "dropout_prob",
        }
        return {key: getattr(self, attribute) for key, attribute in stored_under.items()}


## 7. Trainer & Metrics


In [None]:
class WeightedTrainer(Trainer):
    """Trainer whose loss is cross-entropy weighted per class."""

    def __init__(self, class_weights: torch.Tensor, *args, **kwargs):
        """Store ``class_weights``; every other argument is forwarded to ``Trainer``."""
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the class-weighted loss for one batch.

        ``labels`` is removed from ``inputs`` before the model is called, so
        the model does not compute its own loss. When ``num_items_in_batch``
        is given, the per-sample losses are summed and divided by it;
        otherwise the mean reduction of ``F.cross_entropy`` is used.
        """
        targets = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs["logits"]

        # The weight vector must live on the same device as the logits.
        weight = self.class_weights.to(logits.device)

        if num_items_in_batch is None:
            loss = F.cross_entropy(logits, targets, weight=weight, reduction="mean")
        else:
            summed = F.cross_entropy(logits, targets, weight=weight, reduction="sum")
            loss = summed / num_items_in_batch

        if return_outputs:
            return loss, outputs
        return loss


def compute_metrics(eval_pred):
    """Score predictions with accuracy, macro F1 and weighted F1.

    ``eval_pred`` unpacks into ``(predictions, labels)``; the predicted class
    is the argmax over axis 1 of ``predictions``.
    """
    scores, references = eval_pred
    predicted = np.argmax(scores, axis=1)

    metrics = {}
    metrics["accuracy"] = accuracy_score(references, predicted)
    metrics["f1_macro"] = f1_score(references, predicted, average="macro")
    metrics["f1_weighted"] = f1_score(references, predicted, average="weighted")
    return metrics


## 8. Training Function


In [None]:
def train(
    data_path: str,
    output_dir: str = "./model_output",
    num_epochs: int = DEFAULT_NUM_EPOCHS,
    batch_size: int = DEFAULT_BATCH_SIZE,
    learning_rate: float = DEFAULT_LEARNING_RATE,
    freeze_bert: bool = False,
    dropout: float = DEFAULT_DROPOUT,
):
    """Train the FinBERT multi-modal tweet classifier end to end.

    Steps: load and filter the CSV, split it by tweet hash, fit the categorical
    encodings and the scaler on the training split only, build the datasets,
    compute class weights, fine-tune with ``WeightedTrainer``, save the
    artifacts, and evaluate on the validation split.

    Artifacts written under ``output_dir``: ``final/`` (model weights via
    ``trainer.save_model``), ``scaler.pkl``, ``encodings.pkl`` and
    ``model_config.json``, in addition to the Trainer's own checkpoints.

    Args:
        data_path: Path of the enriched CSV file.
        output_dir: Directory for checkpoints and artifacts (created if needed).
        num_epochs: Number of training epochs.
        batch_size: Per-device training batch size (evaluation uses twice that).
        learning_rate: Optimizer learning rate.
        freeze_bert: Keep the BERT encoder's parameters fixed.
        dropout: Dropout probability used in the model's MLP layers.

    Returns:
        ``(model, trainer, encodings, scaler, df_test)``; ``df_test`` is the
        untouched test split for later evaluation.

    Raises:
        ValueError: If the hash split leaves the train or validation split empty.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Seed the RNGs up front: the embeddings and the classifier head are
    # randomly initialised when the model is constructed below, which happens
    # before the Trainer exists and could apply any seed of its own.
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)

    # Step 1: Load and filter data
    logger.info(f"Loading data from {data_path}")
    df = load_enriched_data(data_path)
    df_reliable = filter_reliable(df)
    logger.info(f"After filtering: {len(df_reliable)} reliable samples")

    # Step 2: Split by hash
    logger.info("Splitting data...")
    df_train, df_val, df_test = split_by_hash(df_reliable)
    logger.info(f"Train: {len(df_train)}, Val: {len(df_val)}, Test: {len(df_test)}")
    if df_train.empty or df_val.empty:
        # Fail here with a clear message rather than deep inside dataset
        # construction, which indexes the first label of each split.
        raise ValueError(
            f"Hash split produced an empty partition (train={len(df_train)}, "
            f"val={len(df_val)}); more labelled rows are needed."
        )

    # Step 3: Create categorical encodings (training split only, so validation
    # and test values never influence the vocabularies)
    logger.info("Creating categorical encodings...")
    encodings = create_categorical_encodings(df_train)
    logger.info(f"Authors: {encodings['num_authors']}, Categories: {encodings['num_categories']}")

    # Step 4: Initialize tokenizer
    logger.info(f"Loading tokenizer from {FINBERT_MODEL_NAME}")
    tokenizer = BertTokenizer.from_pretrained(FINBERT_MODEL_NAME)

    # Step 5: Create datasets; the scaler is fitted on train and reused for val
    logger.info("Creating training dataset...")
    train_dataset, scaler = create_dataset_from_df(df_train, tokenizer, encodings, fit_scaler=True)
    logger.info(f"Training dataset: {len(train_dataset)} samples")

    logger.info("Creating validation dataset...")
    val_dataset, _ = create_dataset_from_df(df_val, tokenizer, encodings, scaler=scaler, fit_scaler=False)
    logger.info(f"Validation dataset: {len(val_dataset)} samples")

    # Step 6: Compute class weights
    logger.info("Computing class weights...")
    class_weights = compute_class_weights(df_train[TARGET_COLUMN])
    class_weights_tensor = weights_to_tensor(class_weights)
    for cls, weight in class_weights.items():
        logger.info(f"  Class {LABEL_MAP_INV[cls]}: weight={weight:.3f}")

    # Step 7: Initialize model
    logger.info("Initializing FinBERTMultiModal model...")
    model = FinBERTMultiModal(
        num_numerical_features=len(NUMERICAL_FEATURES),
        num_authors=encodings["num_authors"],
        num_categories=encodings["num_categories"],
        num_market_regimes=encodings["num_market_regimes"],
        num_sectors=encodings["num_sectors"],
        num_market_caps=encodings["num_market_caps"],
        freeze_bert=freeze_bert,
        dropout=dropout,
    )

    if freeze_bert:
        logger.info("BERT parameters are FROZEN")
    else:
        logger.info("BERT parameters are TRAINABLE (full fine-tuning)")

    # Step 8: Create training arguments
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        num_train_epochs=num_epochs,
        weight_decay=DEFAULT_WEIGHT_DECAY,
        warmup_ratio=DEFAULT_WARMUP_RATIO,
        load_best_model_at_end=True,
        metric_for_best_model="f1_macro",
        greater_is_better=True,
        fp16=torch.cuda.is_available(),  # Enable mixed precision on GPU
        logging_steps=50,
        save_total_limit=2,
        report_to="none",
        # The dataset yields extra feature keys that must reach model.forward()
        remove_unused_columns=False,
        seed=RANDOM_SEED,
    )

    # Step 9: Initialize trainer
    logger.info("Initializing WeightedTrainer...")
    trainer = WeightedTrainer(
        class_weights=class_weights_tensor,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )

    # Step 10: Train!
    logger.info("Starting training...")
    trainer.train()

    # Step 11: Save artifacts
    logger.info(f"Saving model to {output_dir}")
    trainer.save_model(str(output_dir / "final"))

    logger.info("Saving preprocessing artifacts...")
    # Save scaler
    joblib.dump(scaler, output_dir / "scaler.pkl")
    # Save encodings
    joblib.dump(encodings, output_dir / "encodings.pkl")
    # Save model config
    with open(output_dir / "model_config.json", "w") as f:
        json.dump(model.get_config(), f, indent=2)

    # Step 12: Final evaluation
    logger.info("Running final evaluation on validation set...")
    eval_results = trainer.evaluate()
    logger.info("Validation results:")
    for key, value in eval_results.items():
        logger.info(f"  {key}: {value:.4f}")

    logger.info("Training complete!")
    return model, trainer, encodings, scaler, df_test


In [None]:
# Training parameters - modify as needed.
# The keys are the keyword arguments of train().
TRAINING_CONFIG = dict(
    data_path=DATA_FILE,
    output_dir="./finbert_tweet_classifier",
    num_epochs=5,
    batch_size=16,  # Increase to 32 for T4/V100 GPU
    learning_rate=2e-5,
    freeze_bert=False,  # Set to True for faster training (but lower accuracy)
    dropout=0.3,
)

# Echo the effective settings, one per line
print("Training Configuration:")
print("\n".join(f"  {name}: {setting}" for name, setting in TRAINING_CONFIG.items()))


In [None]:
# Run training. train() returns the trained model, the Trainer, the fitted
# categorical encodings and scaler, and the held-out test split; the test-set
# evaluation cell below uses trainer, encodings, scaler and df_test.
model, trainer, encodings, scaler, df_test = train(**TRAINING_CONFIG)


## 10. Evaluate on Test Set


In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Class ids and names in index order. Passing them explicitly keeps the report
# and the confusion matrix at one row/column per class even when a class does
# not occur in the test labels or predictions; otherwise the report raises on
# the name/class count mismatch and the heatmap axes would be mislabelled.
class_ids = sorted(LABEL_MAP_INV)
class_names = [LABEL_MAP_INV[class_id] for class_id in class_ids]

# Create test dataset (scaler and encodings fitted on the training split are reused)
tokenizer = BertTokenizer.from_pretrained(FINBERT_MODEL_NAME)
test_dataset, _ = create_dataset_from_df(df_test, tokenizer, encodings, scaler=scaler, fit_scaler=False)

# Get predictions
predictions = trainer.predict(test_dataset)
preds = np.argmax(predictions.predictions, axis=1)
labels = predictions.label_ids

# Classification report
print("\nClassification Report:")
print(
    classification_report(
        labels,
        preds,
        labels=class_ids,
        target_names=class_names,
        zero_division=0,  # a class that is never predicted scores 0 instead of warning
    )
)

# Confusion matrix (rows: actual class, columns: predicted class)
cm = confusion_matrix(labels, preds, labels=class_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.title("Confusion Matrix")
plt.show()


## 11. Download Trained Model


In [None]:
# Zip and download the model
import shutil

# Imported here as well so this cell does not depend on the upload cell
# (Option 1) having been run; with the Drive option `files` is otherwise undefined.
from google.colab import files

output_dir = "./finbert_tweet_classifier"
zip_path = "finbert_tweet_classifier.zip"

# make_archive appends the ".zip" suffix to the base name itself
shutil.make_archive("finbert_tweet_classifier", "zip", output_dir)
print(f"Model zipped to {zip_path}")

# Download
files.download(zip_path)


In [None]:
# Alternative: Save to Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# !cp -r ./finbert_tweet_classifier /content/drive/MyDrive/

# ============================================================
# Load trained model for inference (example function)
# ============================================================

def load_trained_model(model_dir: str):
    """Load a trained model and its preprocessing artifacts for inference.

    Expects the layout written by ``train()``: ``model_config.json``,
    ``encodings.pkl``, ``scaler.pkl`` and a ``final/`` directory holding the
    weights as either ``model.safetensors`` or ``pytorch_model.bin``.

    Args:
        model_dir: Directory that was used as ``output_dir`` during training.

    Returns:
        ``(model, encodings, scaler)`` with the model in eval mode on CPU.

    Raises:
        FileNotFoundError: If no weight file is found under ``final/``.
    """
    model_dir = Path(model_dir)

    # Load config
    with open(model_dir / "model_config.json") as f:
        config = json.load(f)

    # NOTE: joblib.load unpickles arbitrary objects; only point this at
    # artifacts you produced yourself or otherwise trust.
    encodings = joblib.load(model_dir / "encodings.pkl")
    scaler = joblib.load(model_dir / "scaler.pkl")

    # num_classes / finbert_model are part of the saved config; forward them
    # when present so a non-default head size or encoder is rebuilt correctly.
    optional_kwargs = {
        key: config[key] for key in ("num_classes", "finbert_model") if key in config
    }
    model = FinBERTMultiModal(
        num_numerical_features=config["num_numerical_features"],
        num_authors=config["num_authors"],
        num_categories=config["num_categories"],
        num_market_regimes=config["num_market_regimes"],
        num_sectors=config["num_sectors"],
        num_market_caps=config["num_market_caps"],
        freeze_bert=config["freeze_bert"],
        dropout=config["dropout"],
        **optional_kwargs,
    )

    # Load weights. Depending on the transformers version / save_safetensors
    # setting, Trainer.save_model stores a plain nn.Module's state dict as
    # model.safetensors or as pytorch_model.bin, so both are accepted.
    final_dir = model_dir / "final"
    safetensors_file = final_dir / "model.safetensors"
    bin_file = final_dir / "pytorch_model.bin"
    if safetensors_file.exists():
        # transformers' own checkpoint reader handles the safetensors format
        from transformers.modeling_utils import load_state_dict as hf_load_state_dict

        state_dict = hf_load_state_dict(str(safetensors_file))
    elif bin_file.exists():
        # weights_only restricts unpickling to tensors and plain containers
        state_dict = torch.load(bin_file, map_location="cpu", weights_only=True)
    else:
        # Returning here would hand back a model whose fusion head and
        # embeddings are still randomly initialised, so fail loudly instead.
        raise FileNotFoundError(
            f"No model weights found in {final_dir} "
            "(expected model.safetensors or pytorch_model.bin)"
        )
    model.load_state_dict(state_dict)

    model.eval()
    return model, encodings, scaler


# Example usage (pass the directory that was used as output_dir for training):
# model, encodings, scaler = load_trained_model("./finbert_tweet_classifier")
print("Notebook complete! Use load_trained_model() to reload a saved model.")
