# Neural Nets on Financial Time Seris Forecasting on Text with JAX

In [1]:
#!pip install datasets

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.wh

In [38]:
from datasets import load_dataset

bigdata_train = load_dataset("TheFinAI/flare-sm-bigdata", split="train")
bigdata_valid = load_dataset("TheFinAI/flare-sm-bigdata", split="validation")
bigdata_test = load_dataset("TheFinAI/flare-sm-bigdata", split="test")

bigdata_train_df = bigdata_train.to_pandas()[['gold', 'text']] # 0: rise, 1: fall
bigdata_valid_df = bigdata_valid.to_pandas()[['gold', 'text']]
bigdata_test_df = bigdata_test.to_pandas()[['gold', 'text']]

In [39]:
# embedding model
from sentence_transformers import SentenceTransformer

# Initialize the model (using a compact and efficient model)
model = SentenceTransformer('all-MiniLM-L6-v2')

# Function to get sentence embeddings
def get_sbert_embeddings(texts):
    embeddings = model.encode(texts, convert_to_numpy=True)
    return embeddings

# Combine embedded text features with non-text features
import numpy as np
texts = bigdata_train_df['text'].tolist()
X_text_embeddings = get_sbert_embeddings(texts)
X_text_embeddings = np.array(X_text_embeddings, dtype=np.float32)

bigdata_train_embedded = np.concatenate([bigdata_train_df[['gold']], X_text_embeddings], axis=1)

In [40]:
texts = bigdata_valid_df['text'].tolist()
X_emb = get_sbert_embeddings(texts)
X_emb = np.array(X_emb, dtype=np.float32)
bigdata_valid_embedded = np.concatenate([bigdata_valid_df[['gold']], X_emb], axis=1)

texts = bigdata_test_df['text'].tolist()
X_emb = get_sbert_embeddings(texts)
X_emb = np.array(X_emb, dtype=np.float32)
bigdata_test_embedded = np.concatenate([bigdata_test_df[['gold']], X_emb], axis=1)

In [41]:
# Convert the combined data to JAX arrays if needed
import jax
import jax.numpy as jnp
from jax import grad, jit, random
import flax
from flax import linen as nn
from jax import random

bigdata_train_embedded_jax = jnp.array(bigdata_train_embedded, dtype=jnp.float32)
print(bigdata_train_embedded_jax.shape)

bigdata_test_embedded_jax = jnp.array(bigdata_test_embedded, dtype=jnp.float32)
bigdata_valid_embedded_jax = jnp.array(bigdata_valid_embedded, dtype=jnp.float32)

(4897, 385)


In [42]:
# initialize random parameters
def init_params(layer_sizes, key):
    """Initialize parameters for a simple MLP model."""
    params = []
    for i in range(len(layer_sizes) - 1):
        # Initialize weights with a small random value and biases as zeros
        w_key, b_key = random.split(key)
        w = random.normal(w_key, (layer_sizes[i], layer_sizes[i+1])) * jnp.sqrt(2.0 / layer_sizes[i])
        b = jnp.zeros((layer_sizes[i+1],))
        params.append((w, b))
    return params

# Define the neural nets with enhancements
def leaky_relu(x, alpha=0.01):
    """Leaky ReLU activation function."""
    return jnp.where(x > 0, x, alpha * x)

def mlp(params, X, dropout_key=None, dropout_rate=0.2):
    """A feedforward MLP with dropout."""
    for i, (w, b) in enumerate(params[:-1]):
        X = jnp.dot(X, w) + b
        X = leaky_relu(X)  # Use LeakyReLU activation
        if dropout_key is not None:
            # Apply dropout during training
            dropout_key, subkey = random.split(dropout_key)
            mask = random.bernoulli(subkey, p=1 - dropout_rate, shape=X.shape)
            X = X * mask / (1 - dropout_rate)
    w, b = params[-1]
    return jnp.dot(X, w) + b  # Linear output layer (regression)

# Loss function to binary cross-entropy
def binary_cross_entropy_loss(params, X, y):
    """Compute the binary cross-entropy loss."""
    logits = mlp(params, X)
    preds = jax.nn.sigmoid(logits)  # Apply sigmoid for binary classification
    return -jnp.mean(y * jnp.log(preds + 1e-8) + (1 - y) * jnp.log(1 - preds + 1e-8))

# compute gradients
grad_loss_fn = grad(binary_cross_entropy_loss)

In [43]:
# define a training step
@jit
def train_step(params, X, y, key, learning_rate=0.001, dropout_rate=0.2):
    """Perform one step of gradient descent with dropout."""
    dropout_key, subkey = random.split(key)
    grads = grad_loss_fn(params, X, y)
    new_params = [(w - learning_rate * dw, b - learning_rate * db)
                  for (w, b), (dw, db) in zip(params, grads)]
    return new_params, dropout_key

# Training data (replace with your own)
X_train = bigdata_train_embedded_jax[:, 1:]
print(X_train[:3,:])
print(X_train.shape)
y_train = bigdata_train_embedded_jax[:, 0]
print(y_train[:10])
print(y_train.shape)

X_valid = bigdata_valid_embedded_jax[:, 1:]
y_valid = bigdata_valid_embedded_jax[:, 0]

X_test = bigdata_test_embedded_jax[:, 1:]
y_test = bigdata_test_embedded_jax[:, 0]

[[ 0.01446111 -0.03051838  0.00084274 ... -0.08114446 -0.029073
   0.0523665 ]
 [ 0.03013896 -0.00629956  0.00309164 ... -0.09023841 -0.05489946
   0.03478851]
 [ 0.01602308 -0.03069994 -0.00132307 ... -0.07940799 -0.03194
   0.04441951]]
(4897, 384)
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
(4897,)


In [44]:
# training loop
# Initialize model parameters
key = random.PRNGKey(0)
layer_sizes = [X_train.shape[1], 128, 64, 32, 1]
params = init_params(layer_sizes, key)
dropout_key = random.PRNGKey(1)

# Training loop
num_epochs = 1000
for epoch in range(num_epochs):
    params, dropout_key = train_step(params, X_train, y_train, dropout_key, learning_rate=0.005, dropout_rate=0.2)
    if epoch % 100 == 0 or epoch == num_epochs - 1:
        loss = binary_cross_entropy_loss(params, X_train, y_train)
        print(f"Epoch {epoch}, Loss: {loss}")

Epoch 0, Loss: 0.6922121644020081
Epoch 100, Loss: 0.6917664408683777
Epoch 200, Loss: 0.6916472911834717
Epoch 300, Loss: 0.6916123032569885
Epoch 400, Loss: 0.6916014552116394
Epoch 500, Loss: 0.6915980577468872
Epoch 600, Loss: 0.6915969848632812
Epoch 700, Loss: 0.6915965676307678
Epoch 800, Loss: 0.6915964484214783
Epoch 900, Loss: 0.6915963888168335
Epoch 999, Loss: 0.6915963888168335


In [45]:
# Evaluate on validation set with sigmoid activation
valid_logits = mlp(params, X_valid)
valid_preds = jax.nn.sigmoid(valid_logits)
valid_loss = binary_cross_entropy_loss(params, X_valid, y_valid)
print(f"Validation Loss: {valid_loss:.10f}")

# Calculate accuracy
valid_preds_binary = (valid_preds > 0.5).astype(jnp.float32)
accuracy = jnp.mean(valid_preds_binary == y_valid)
print(f"Validation Accuracy: {accuracy:.10f}")

Validation Loss: 0.6958872676
Validation Accuracy: 0.4887217879


In [48]:
# Predict on test set
test_preds = mlp(params, X_test)
test_preds_binary = (test_preds > 0.5).astype(jnp.float32)

test_preds_binary

Array([[0.],
       [0.],
       [0.],
       ...,
       [0.],
       [0.],
       [0.]], dtype=float32)