# Neural Nets on Financial Time Seris Forecasting on Text with JAX

In [None]:
!pip install datasets

In [2]:
from datasets import load_dataset

bigdata_train = load_dataset("TheFinAI/flare-sm-bigdata", split="train")
bigdata_valid = load_dataset("TheFinAI/flare-sm-bigdata", split="validation")
bigdata_test = load_dataset("TheFinAI/flare-sm-bigdata", split="test")

bigdata_train_df = bigdata_train.to_pandas()[['gold', 'text']] # 0: rise, 1: fall
bigdata_valid_df = bigdata_valid.to_pandas()[['gold', 'text']]
bigdata_test_df = bigdata_test.to_pandas()[['gold', 'text']]

In [None]:
bigdata_train_df.head()

Unnamed: 0,gold,text
0,1,"date,open,high,low,close,adj-close,inc-5,inc-1..."
1,1,"date,open,high,low,close,adj-close,inc-5,inc-1..."
2,1,"date,open,high,low,close,adj-close,inc-5,inc-1..."
3,1,"date,open,high,low,close,adj-close,inc-5,inc-1..."
4,1,"date,open,high,low,close,adj-close,inc-5,inc-1..."


In [None]:
acl_train = load_dataset("TheFinAI/flare-sm-acl", split="train")
cikm_train = load_dataset("TheFinAI/flare-sm-cikm", split="train")
acl_test = load_dataset("TheFinAI/flare-sm-acl", split="test")
acl_valid = load_dataset("TheFinAI/flare-sm-acl", split="valid")
cikm_valid = load_dataset("TheFinAI/flare-sm-cikm", split="valid")
cikm_test = load_dataset("TheFinAI/flare-sm-cikm", split="test")


acl_train_df = acl_train.to_pandas()[['gold', 'text']]
cikm_train_df = cikm_train.to_pandas()[['gold', 'text']]

acl_valid_df = acl_valid.to_pandas()[['gold', 'text']]
cikm_valid_df = cikm_valid.to_pandas()[['gold', 'text']]


acl_test_df = acl_test.to_pandas()[['gold', 'text']]
cikm_test_df = cikm_test.to_pandas()[['gold', 'text']]

## Preprocess data

In [4]:
from sentence_transformers import SentenceTransformer
import numpy as np
model = SentenceTransformer('all-MiniLM-L6-v2') # load pre-trained model for text embedding

# Text embedding with batch processing
def batch_encode(texts, batch_size=32):
    def get_sbert_embeddings(texts):
        embeddings = model.encode(texts, convert_to_numpy=True)
        return embeddings
    embeddings_list = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        batch_embeddings = get_sbert_embeddings(batch)
        embeddings_list.append(batch_embeddings)
    return np.vstack(embeddings_list)

In [5]:
import jax
import jax.numpy as jnp

# training data processing
texts = acl_train_df['text'].tolist()
X_train_embeddings = batch_encode(texts)
X_train_embeddings = np.array(X_train_embeddings, dtype=np.float32)
acl_train_embedded = np.concatenate([acl_train_df[['gold']], X_train_embeddings], axis=1)
acl_train_embedded_jax = jnp.array(acl_train_embedded, dtype=jnp.float32)

In [6]:
# validation data processing
valid_texts = acl_valid_df['text'].tolist()
X_valid_embeddings = batch_encode(valid_texts)
X_valid_embeddings = np.array(X_valid_embeddings, dtype=np.float32)
acl_valid_embedded = np.concatenate([acl_valid_df[['gold']], X_valid_embeddings], axis=1)
acl_valid_embedded_jax = jnp.array(acl_valid_embedded, dtype=jnp.float32)

# test data processing
test_texts = acl_test_df['text'].tolist()
X_test_embeddings = batch_encode(test_texts)
X_test_embeddings = np.array(X_test_embeddings, dtype=np.float32)
acl_test_embedded = np.concatenate([acl_test_df[['gold']], X_test_embeddings], axis=1)
acl_test_embedded_jax = jnp.array(acl_test_embedded, dtype=jnp.float32)

In [8]:
print(acl_train_embedded[:3])
print(acl_train_embedded_jax[:3])
print(acl_train_embedded.shape)
print(acl_valid_embedded.shape)

[[ 1.         -0.00104391 -0.0153452  ... -0.06044361 -0.04032699
   0.04581903]
 [ 1.         -0.01484502 -0.01614808 ... -0.05998765 -0.04143057
   0.04037762]
 [ 0.         -0.00950441 -0.01646534 ... -0.06522585 -0.03886713
   0.03963047]]
[[ 1.         -0.00104391 -0.0153452  ... -0.06044361 -0.04032699
   0.04581903]
 [ 1.         -0.01484502 -0.01614808 ... -0.05998765 -0.04143057
   0.04037762]
 [ 0.         -0.00950441 -0.01646534 ... -0.06522585 -0.03886713
   0.03963047]]
(20781, 385)
(2555, 385)


## Train the models

### 1. Logistic regression

Prepare data for training. Split training data into X and response Y.

In [9]:
y_train = acl_train_embedded[:, 0].reshape(-1, 1)
X_train = acl_train_embedded[:, 1:]
y_valid = acl_valid_embedded[:, 0].reshape(-1, 1)
X_valid = acl_valid_embedded[:, 1:]

Define functions

In [16]:
from sklearn.metrics import classification_report
from sklearn.linear_model import LogisticRegression

# Train logistic regression model
maxiter = 30
lr_model = LogisticRegression(max_iter=maxiter, random_state=42)
lr_model.fit(X_train, y_train)

# Evaluate on validation set
y_pred = lr_model.predict(X_valid)

# Print performance metrics
print("\nLogistic Regression Results:")
print(f"Validation Accuracy: {lr_model.score(X_valid, y_valid):.10f}")

  y = column_or_1d(y, warn=True)



Logistic Regression Results:
Validation Accuracy: 0.4845401174


### 2. Train a MLP with JAX

In [18]:
import jax
import jax.numpy as jnp
from jax import random
from typing import List, Tuple, Any
import optax  # For Adam optimizer
from functools import partial

# Prepare data for modelling
y_train = acl_train_embedded_jax[:, 0].reshape(-1, 1)
X_train = acl_train_embedded_jax[:, 1:]
y_valid = acl_valid_embedded_jax[:, 0].reshape(-1, 1)
X_valid = acl_valid_embedded_jax[:, 1:]
input_dim = X_train.shape[1]  # Get input dimension from training data

In [19]:
# Define MLP in JAX
def init_mlp_params(layer_sizes: List[int], key: Any) -> List[Tuple[jnp.ndarray, jnp.ndarray]]:
    params = []
    keys = random.split(key, len(layer_sizes))
    for in_dim, out_dim, k in zip(layer_sizes[:-1], layer_sizes[1:], keys):
        w_key, b_key = random.split(k)
        W = random.normal(w_key, (in_dim, out_dim)) * jnp.sqrt(2. / in_dim)
        b = jnp.zeros((out_dim,))
        params.append((W, b))
    return params

def mlp_forward(params: List[Tuple[jnp.ndarray, jnp.ndarray]], x: jnp.ndarray, dropout_rate: float = 0.0,
               train: bool = False, key: Any = None) -> jnp.ndarray:
    """Forward pass with dropout support"""
    for i, (W, b) in enumerate(params[:-1]):
        x = jnp.dot(x, W) + b
        x = jax.nn.relu(x)

        # Apply dropout during training
        if train and dropout_rate > 0:
            if key is None:
                raise ValueError("Random key required for dropout")
            dropout_key = random.fold_in(key, i)  # Different key for each layer
            mask = random.bernoulli(dropout_key, p=1-dropout_rate, shape=x.shape)
            x = x * mask / (1 - dropout_rate)  # Scale to maintain expected value

    W_last, b_last = params[-1]
    logits = jnp.dot(x, W_last) + b_last
    return logits

# Loss and training step
def binary_cross_entropy_loss(logits: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    preds = jax.nn.sigmoid(logits)
    return -jnp.mean(labels * jnp.log(preds + 1e-7) + (1 - labels) * jnp.log(1 - preds + 1e-7))

# Improved training step with Adam optimizer
@partial(jax.jit, static_argnums=(4, 5))
def train_step(params, X_batch, y_batch, opt_state, dropout_rate=0.2, train=True):
    """Single training step with Adam optimizer and dropout"""
    key = random.PRNGKey(0)  # For reproducibility

    def loss_fn(p):
        logits = mlp_forward(p, X_batch, dropout_rate=dropout_rate, train=train, key=key)
        return binary_cross_entropy_loss(logits, y_batch)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

In [25]:
# Hyperparameter tuning
def tune_hyperparameters():
    best_accuracy = 0.0
    best_params = None
    best_config = {}

    # Define hyperparameter search space
    learning_rates = [1e-4]
    hidden_layer_configs = [
        [128, 64]
    ]
    dropout_rates = [0.0, 0.2, 0.3, 0.4, 0.5, 0.6, .7, .8, .9]
    batch_sizes = [32]  # For mini-batch training

    results = []

    for lr in learning_rates:
        for hidden_layers in hidden_layer_configs:
            for dropout_rate in dropout_rates:
                for batch_size in batch_sizes:
                    print(f"\nTrying: lr={lr}, layers={hidden_layers}, dropout={dropout_rate}, batch_size={batch_size}")

                    # Initialize model
                    key = random.PRNGKey(42)
                    layer_sizes = [input_dim] + hidden_layers + [1]
                    params = init_mlp_params(layer_sizes, key)

                    # Initialize optimizer
                    global optimizer  # Make it accessible in train_step
                    optimizer = optax.adam(learning_rate=lr)
                    opt_state = optimizer.init(params)

                    # Mini-batch training
                    num_batches = max(1, len(X_train) // batch_size)

                    for epoch in range(50):  # Fewer epochs for tuning
                        # Shuffle data
                        perm = random.permutation(key, len(X_train))
                        key = random.fold_in(key, epoch)  # Update key for next epoch

                        # Mini-batch updates
                        total_loss = 0.0
                        for i in range(num_batches):
                            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                            X_batch = X_train[batch_idx]
                            y_batch = y_train[batch_idx]

                            params, opt_state, loss = train_step(
                                params, X_batch, y_batch, opt_state,
                                dropout_rate=dropout_rate, train=True
                            )
                            total_loss += loss

                        avg_loss = total_loss / num_batches
                        if epoch % 10 == 0:
                            print(f"Epoch {epoch+1} | Avg Loss: {avg_loss:.4f}")

                    # Evaluate on validation set
                    val_accuracy = evaluate(params, X_valid, y_valid, dropout_rate=0.0, train=False)
                    print(f"Validation Accuracy: {val_accuracy:.4f}")

                    # Record result
                    config = {
                        'learning_rate': lr,
                        'hidden_layers': hidden_layers,
                        'dropout_rate': dropout_rate,
                        'batch_size': batch_size,
                        'val_accuracy': val_accuracy
                    }
                    results.append(config)

                    # Update best model
                    if val_accuracy > best_accuracy:
                        best_accuracy = val_accuracy
                        best_params = params
                        best_config = config

    print("\n=== Hyperparameter Tuning Results ===")
    for i, res in enumerate(sorted(results, key=lambda x: x['val_accuracy'], reverse=True)):
        print(f"{i+1}. Accuracy: {res['val_accuracy']:.4f} - LR: {res['learning_rate']}, "
              f"Layers: {res['hidden_layers']}, Dropout: {res['dropout_rate']}, "
              f"Batch Size: {res['batch_size']}")

    print(f"\nBest Configuration: {best_config}")
    return best_params, best_config

# Modified evaluation function for hyperparameter tuning
def evaluate(params, X, y, dropout_rate=0.0, train=False) -> float:
    """Evaluate model accuracy with optional dropout"""
    key = random.PRNGKey(99) if train else None
    logits = mlp_forward(params, X, dropout_rate=dropout_rate, train=train, key=key)
    preds = jax.nn.sigmoid(logits)
    binary_preds = (preds > 0.5).astype(jnp.float32)
    accuracy = jnp.mean(binary_preds == y)
    return float(accuracy)

In [26]:
# Run hyperparameter tuning
print("\n=== Starting Hyperparameter Tuning ===")
best_params, best_config = tune_hyperparameters()

# Train final model with best hyperparameters
print("\n=== Training Final Model with Best Hyperparameters ===")
key = random.PRNGKey(0)
layer_sizes = [input_dim] + best_config['hidden_layers'] + [1]
params = init_mlp_params(layer_sizes, key)

# Initialize optimizer with best learning rate
optimizer = optax.adam(learning_rate=best_config['learning_rate'])
opt_state = optimizer.init(params)


=== Starting Hyperparameter Tuning ===

Trying: lr=0.0001, layers=[128, 64], dropout=0.0, batch_size=32
Epoch 1 | Avg Loss: 0.6912
Epoch 11 | Avg Loss: 0.6760
Epoch 21 | Avg Loss: 0.6739
Epoch 31 | Avg Loss: 0.6727
Epoch 41 | Avg Loss: 0.6715
Validation Accuracy: 0.4975

Trying: lr=0.0001, layers=[128, 64], dropout=0.2, batch_size=32
Epoch 1 | Avg Loss: 0.6920
Epoch 11 | Avg Loss: 0.6777
Epoch 21 | Avg Loss: 0.6752
Epoch 31 | Avg Loss: 0.6739
Epoch 41 | Avg Loss: 0.6730
Validation Accuracy: 0.5045

Trying: lr=0.0001, layers=[128, 64], dropout=0.3, batch_size=32
Epoch 1 | Avg Loss: 0.6922
Epoch 11 | Avg Loss: 0.6781
Epoch 21 | Avg Loss: 0.6756
Epoch 31 | Avg Loss: 0.6745
Epoch 41 | Avg Loss: 0.6734
Validation Accuracy: 0.5072

Trying: lr=0.0001, layers=[128, 64], dropout=0.4, batch_size=32
Epoch 1 | Avg Loss: 0.6926
Epoch 11 | Avg Loss: 0.6788
Epoch 21 | Avg Loss: 0.6768
Epoch 31 | Avg Loss: 0.6752
Epoch 41 | Avg Loss: 0.6738
Validation Accuracy: 0.5025

Trying: lr=0.0001, layers=[128,

In [27]:
# Training with mini-batches
num_epochs = 100
batch_size = best_config['batch_size']
num_batches = max(1, len(X_train) // batch_size)

for epoch in range(num_epochs):
    # Shuffle data
    perm = random.permutation(key, len(X_train))
    key = random.fold_in(key, epoch)

    # Mini-batch updates
    total_loss = 0.0
    for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        X_batch = X_train[batch_idx]
        y_batch = y_train[batch_idx]

        params, opt_state, loss = train_step(
            params, X_batch, y_batch, opt_state,
            dropout_rate=best_config['dropout_rate'], train=True
        )
        total_loss += loss

    avg_loss = total_loss / num_batches
    if epoch % 10 == 0 or epoch == num_epochs - 1:
        train_acc = evaluate(params, X_train, y_train, dropout_rate=0.0, train=False)
        val_acc = evaluate(params, X_valid, y_valid, dropout_rate=0.0, train=False)
        print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

Epoch 1 | Loss: 0.6973 | Train Acc: 0.5278 | Val Acc: 0.4669
Epoch 11 | Loss: 0.6830 | Train Acc: 0.5286 | Val Acc: 0.4908
Epoch 21 | Loss: 0.6801 | Train Acc: 0.5248 | Val Acc: 0.5311
Epoch 31 | Loss: 0.6792 | Train Acc: 0.5377 | Val Acc: 0.5119
Epoch 41 | Loss: 0.6784 | Train Acc: 0.5381 | Val Acc: 0.5292
Epoch 51 | Loss: 0.6776 | Train Acc: 0.5284 | Val Acc: 0.5432
Epoch 61 | Loss: 0.6767 | Train Acc: 0.5268 | Val Acc: 0.5495
Epoch 71 | Loss: 0.6763 | Train Acc: 0.5474 | Val Acc: 0.5374
Epoch 81 | Loss: 0.6756 | Train Acc: 0.5254 | Val Acc: 0.5487
Epoch 91 | Loss: 0.6751 | Train Acc: 0.5418 | Val Acc: 0.5374
Epoch 100 | Loss: 0.6749 | Train Acc: 0.5270 | Val Acc: 0.5499


In [28]:
# Final evaluation
final_accuracy = evaluate(params, X_valid, y_valid, dropout_rate=0.0, train=False)
print(f"\nFinal Validation Accuracy: {final_accuracy:.10f}")


Final Validation Accuracy: 0.5499022007
