In [1]:
import os
# Set CUDA_LAUNCH_BLOCKING=1 in this first cell, before any later cell imports torch.
# NOTE(review): per the CUDA documentation this makes kernel launches synchronous, which
# is a debugging aid and slows execution — presumably kept on so errors from the custom
# kernels are reported at the failing call; confirm it is still wanted.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


In [3]:
!pip install numerapi



In [4]:
!pip install ninja




In [5]:
import sys
# Make the checked-out erasplit-gpu repository importable. The membership check keeps the
# cell idempotent: re-running it does not pile duplicate entries onto sys.path.
ERASPLIT_REPO_DIR = "/content/erasplit-gpu/"
if ERASPLIT_REPO_DIR not in sys.path:
    sys.path.append(ERASPLIT_REPO_DIR)

In [6]:
# Mode switch read by the later cells: when True the notebook loads train.parquet, builds
# the CUDA extension, fits a model and saves it to model.pth; when False it loads
# live.parquet, loads model.pth and writes predictions.
TRAIN = False

In [7]:
import numerapi
import json


napi = numerapi.NumerAPI()

# List every dataset file and derive the distinct version prefixes (the part before "/").
# Sorting makes the printed order deterministic across runs.
all_datasets = napi.list_datasets()
dataset_versions = sorted({d.split('/')[0] for d in all_datasets})
print("Available versions:\n", dataset_versions)

# Set data version to one of the latest datasets
DATA_VERSION = "v5.0"

# Print all files available for download for our version
current_version_files = [f for f in all_datasets if f.startswith(DATA_VERSION)]
print("Available", DATA_VERSION, "files:\n", current_version_files)

# Download the feature metadata file
napi.download_dataset(f"{DATA_VERSION}/features.json")

# Read the metadata inside a context manager so the file handle is always closed,
# then show how many entries each top-level key holds.
with open(f"{DATA_VERSION}/features.json") as metadata_file:
    feature_metadata = json.load(metadata_file)
for metadata in feature_metadata:
    print(metadata, len(feature_metadata[metadata]))

# Report the size of the feature sets of interest.
feature_sets = feature_metadata["feature_sets"]
for feature_set in ["small", "medium", "all"]:
    print(feature_set, len(feature_sets[feature_set]))

Available versions:
 ['v5.0']
Available v5.0 files:
 ['v5.0/features.json', 'v5.0/live.parquet', 'v5.0/live_benchmark_models.parquet', 'v5.0/live_example_preds.csv', 'v5.0/live_example_preds.parquet', 'v5.0/meta_model.parquet', 'v5.0/train.parquet', 'v5.0/train_benchmark_models.parquet', 'v5.0/validation.parquet', 'v5.0/validation_benchmark_models.parquet', 'v5.0/validation_example_preds.csv', 'v5.0/validation_example_preds.parquet']
feature_sets 17
targets 37
small 42
medium 705
all 2376


In [8]:
import os
import pandas as pd
import numpy as np

# Use the complete ("all") feature set from the metadata.
feature_set = feature_sets["all"]

# Candidate target columns for a multi-target approach. This list is kept for reference
# only: it used to be assigned to `targets` and was then immediately overwritten, so it is
# stored under its own name to make clear that it is not what the notebook loads.
# (Alternative: targets = feature_metadata["targets"] to use every listed target.)
MULTI_TARGET_CANDIDATES = [
    #"target_nomi_v4_20", # until V16
    "target_teager2b_20", # new from V17
    #"target_jerome_v4_60", # until V16
    "target_teager2b_60", # new from V17
    #"target_jeremy_v4_60", # until V16
    "target_rowan_20",
    "target_ralph_20",
    "target_tyler_20",
    "target_victor_20",
    #"target_waldo_v4_20", # until V16
    "target_claudia_20", # new from V17
    "target_cyrusd_20"   # adding latest primary target from V16
]

# Target column(s) actually used by the cells below.
targets = ['target']

In [10]:
# Choose the parquet file and the columns to read for the current mode:
# training mode needs era + features + targets, live mode only era + features.
if TRAIN:
    data_path = f"{DATA_VERSION}/train.parquet"
    COLUMNS = ["era"] + feature_set + targets
else:
    data_path = f"{DATA_VERSION}/live.parquet"
    COLUMNS = ["era"] + feature_set

# Download the selected file (if not already downloaded)
napi.download_dataset(data_path)

# Read only the selected columns. The frame is called `train` in both modes because the
# later cells refer to it by that name.
train = pd.read_parquet(data_path, columns=COLUMNS)


In [11]:
import torch
# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [12]:
# Delete the directory ~/.cache/torch_extensions.
# NOTE(review): presumably this forces the CUDA extension loaded in the next cell to be
# rebuilt from scratch instead of reusing a cached build — confirm; it runs even when TRAIN is False.
!rm -rf ~/.cache/torch_extensions


In [13]:
from torch.utils.cpp_extension import load
import os
if TRAIN:
    # Optional, but recommended: restrict the build to a single CUDA architecture
    # so the JIT compilation takes less time.
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0"

    # Directory holding the extension sources inside the erasplit-gpu checkout.
    PATH = '/content/erasplit-gpu/histogram_ext/'
    extension_sources = [
        PATH + file_name
        for file_name in ('histogram.cpp', 'best_split_kernel.cu', 'histogram_kernel.cu')
    ]
    nvcc_flags = ['-O3', '-gencode=arch=compute_80,code=sm_80', '-DTORCH_USE_CUDA_DSA']

    # JIT-compile and import the extension; verbose=True prints the build log inline.
    histogram_ext = load(
        name='histogram_ext',
        sources=extension_sources,
        extra_cuda_cflags=nvcc_flags,
        extra_cflags=['-O3'],
        verbose=True,
    )


In [14]:
import torch
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from tqdm import tqdm
from typing import List  # Needed for TorchScript list type annotation

# TorchScript function for a single tree traversal.
@torch.jit.script
def _predict_tree(flat_tree: torch.Tensor, X_batch: torch.Tensor, bin_edges: torch.Tensor, max_depth: int) -> torch.Tensor:
    N = X_batch.shape[0]
    node_indices = torch.zeros(N, dtype=torch.long, device=X_batch.device)

    # Unpack flat_tree columns.
    is_leaf = flat_tree[:, 0] > 0       # [num_nodes]
    features = flat_tree[:, 1].long()   # [num_nodes]
    thresholds = flat_tree[:, 2]        # [num_nodes]
    lefts = flat_tree[:, 3].long()      # [num_nodes]
    rights = flat_tree[:, 4].long()     # [num_nodes]
    values = flat_tree[:, 5]            # [num_nodes]

    for _ in range(max_depth + 1):
        leaf_mask = is_leaf[node_indices]
        if leaf_mask.all():
            break

        active_idx = torch.nonzero(~leaf_mask).squeeze(1)
        cur_nodes = node_indices[active_idx]
        cur_features = features[cur_nodes]
        cur_thresholds = thresholds[cur_nodes]

        # For each active sample, get the feature value.
        x = X_batch[active_idx].gather(1, cur_features.unsqueeze(1)).squeeze(1)
        boundaries = bin_edges[cur_features]
        bin_idx = (x.unsqueeze(1) > boundaries).sum(dim=1)
        next_nodes = torch.where(bin_idx <= cur_thresholds, lefts[cur_nodes], rights[cur_nodes])
        node_indices[active_idx] = next_nodes

    return values[node_indices]

# TorchScript function to predict the forest in a loop over trees.
@torch.jit.script
def _predict_forest(flat_forests: List[torch.Tensor], X_batch: torch.Tensor, bin_edges: torch.Tensor, max_depth: int, learning_rate: float, base_prediction: float) -> torch.Tensor:
    """Sum the contributions of a list of flattened trees, one tree at a time.

    Starts from ``base_prediction`` for every row of ``X_batch`` and adds
    ``learning_rate`` times the output of ``_predict_tree`` for each tree.

    Returns:
        1-D float32 tensor with one prediction per row of ``X_batch``.
    """
    num_rows = X_batch.size(0)
    total = torch.full((num_rows,), base_prediction, device=X_batch.device, dtype=torch.float32)
    for tree in flat_forests:
        contribution = _predict_tree(tree, X_batch, bin_edges, max_depth)
        total = total + learning_rate * contribution
    return total

import torch
from typing import Tuple

@torch.jit.script
def _predict_forest_vectorized(
    flat_forests: torch.Tensor,  # shape: (n_trees, max_nodes, 6)
    X_batch: torch.Tensor,       # shape: (N, F)
    bin_edges: torch.Tensor,     # shape: (F, num_bins)
    max_depth: int,
    learning_rate: float,
    base_prediction: float
) -> torch.Tensor:
    n_trees = flat_forests.shape[0]
    N = X_batch.shape[0]

    # Initialize node indices for every tree and every sample.
    node_indices = torch.zeros((n_trees, N), dtype=torch.long, device=X_batch.device)

    # Pre-extract the tree info (each flat tree row has 6 columns):
    # column 0: is_leaf flag, column 1: feature, column 2: threshold,
    # column 3: left child, column 4: right child, column 5: value.
    is_leaf = flat_forests[:, :, 0] > 0         # (n_trees, max_nodes)
    features = flat_forests[:, :, 1].long()       # (n_trees, max_nodes)
    thresholds = flat_forests[:, :, 2]            # (n_trees, max_nodes)
    lefts = flat_forests[:, :, 3].long()          # (n_trees, max_nodes)
    rights = flat_forests[:, :, 4].long()         # (n_trees, max_nodes)
    values = flat_forests[:, :, 5]                # (n_trees, max_nodes)

    # Expand X_batch so that we have one copy per tree.
    X_exp = X_batch.unsqueeze(0).expand(n_trees, -1, -1)  # shape: (n_trees, N, F)

    for _ in range(max_depth + 1):
        # For each tree and sample, check if the current node is a leaf.
        current_leaf = torch.gather(is_leaf, 1, node_indices)
        # If all are leaves, break out.
        if current_leaf.all():
            break

        # For every (tree, sample), get the feature index used at the current node.
        cur_features = torch.gather(features, 1, node_indices)  # shape: (n_trees, N)

        # Now, for each (tree, sample), grab the feature value from X.
        # We do this by treating cur_features as column indices.
        # The gathered x will have shape (n_trees, N)
        x = X_exp.gather(2, cur_features.unsqueeze(2)).squeeze(2)

        # Get the corresponding bin edges for each (tree, sample).
        # Here bin_edges is indexed by feature. The resulting shape is (n_trees, N, num_bins)
        boundaries = bin_edges[cur_features]

        # Compute the bin index by comparing x to the boundaries.
        # (x.unsqueeze(2) > boundaries) creates a boolean tensor which we sum along dim=2.
        bin_idx = (x.unsqueeze(2) > boundaries).sum(dim=2)

        # Get the threshold, left, and right child values for the current node.
        cur_thresholds = torch.gather(thresholds, 1, node_indices)
        cur_lefts = torch.gather(lefts, 1, node_indices)
        cur_rights = torch.gather(rights, 1, node_indices)

        # Decide next node index for each (tree, sample).
        next_node = torch.where(bin_idx <= cur_thresholds, cur_lefts, cur_rights)

        # Only update those samples that are not at a leaf.
        node_indices = torch.where(current_leaf, node_indices, next_node)

    # Now, gather the leaf values from all trees.
    final_values = torch.gather(values, 1, node_indices)  # shape: (n_trees, N)

    # Sum over trees (with learning rate) and add base prediction.
    preds = base_prediction + learning_rate * final_values.sum(dim=0)
    return preds


class ErasplitGBDT(BaseEstimator, RegressorMixin):
    """Gradient-boosted regression trees on binned features.

    Training bins every feature, builds per-node gradient/hessian histograms and picks
    splits through the external ``histogram_ext`` extension, and grows
    ``n_estimators`` trees of depth at most ``max_depth``. Each leaf stores the mean
    residual of its samples; predictions are ``base_prediction`` plus ``learning_rate``
    times the sum of the leaf values reached in every tree.

    Note: ``fit`` accepts era ids and stores the unique eras, but the code in this class
    does not pass era information to the split search.
    """

    def __init__(self, num_bins=10, max_depth=3, learning_rate=0.1, n_estimators=100):
        """Store hyper-parameters and allocate the small output buffers used by the split kernel.

        Args:
            num_bins: number of bins per feature.
            max_depth: maximum depth of each tree.
            learning_rate: shrinkage applied to every leaf value.
            n_estimators: number of trees to grow.
        """
        self.num_bins = num_bins
        self.max_depth = max_depth
        self.learning_rate = learning_rate
        self.n_estimators = n_estimators
        self.forest = None
        self.flat_forest = None  # Flattened trees, filled in by fit().
        self.bin_edges = None
        self.base_prediction = None
        self.unique_eras = None
        # Use CUDA when present; fall back to CPU so that constructing the estimator
        # (which allocates tensors right below) does not crash on a machine without a GPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Despite the name, this holds the running prediction for every training sample
        # (initialised to the target mean, then incremented by each leaf).
        self.gradients = None
        self.root_node_indices = None
        self.bin_indices = None
        self.Y_gpu = None
        self.num_features = None
        self.num_samples = None
        # One-element buffers that the split kernel writes its result into.
        self.out_feature = torch.zeros(1, device=self.device, dtype=torch.int32)
        self.out_bin = torch.zeros(1, device=self.device, dtype=torch.int32)

    def fit(self, X, y, era_id):
        """Bin the data, grow the forest and cache the flattened trees.

        Args:
            X: numpy array of features (cast to int8 during preprocessing).
            y: numpy array of targets (cast to float32).
            era_id: numpy array of era ids (cast to int32).

        Returns:
            self
        """
        self.bin_indices, _era_indices, self.bin_edges, self.unique_eras, self.Y_gpu = self.preprocess_gpu_data(X, y, era_id)
        self.gradients = torch.zeros_like(self.Y_gpu)
        self.root_node_indices = torch.arange(self.num_samples, device=self.device)
        self.base_prediction = self.Y_gpu.mean().item()
        self.gradients += self.base_prediction  # Start every prediction at the target mean.
        self.forest = self.grow_forest()
        # Pre-flatten trees for fast inference.
        self.flat_forest = [self.flatten_tree(tree) for tree in self.forest]
        return self

    def predict(self, X):
        """Predict targets for ``X`` in batches using the vectorised forest routine.

        Args:
            X: numpy array of features; it is cast to int8 before use.

        Returns:
            1-D numpy array with one prediction per row of ``X``.
        """
        X_tensor = torch.from_numpy(X).to(torch.int8).to(self.device)
        num_rows = X_tensor.shape[0]
        if num_rows == 0:
            # Nothing to predict; torch.cat on an empty list would raise.
            return np.empty((0,), dtype=np.float32)

        # Reuse the trees flattened by fit(); only flatten here when that cache is
        # missing (e.g. an object restored from an older pickle).
        flat_trees = getattr(self, "flat_forest", None)
        if flat_trees is None:
            flat_trees = [self.flatten_tree(tree) for tree in self.forest]
        if len(flat_trees) == 0:
            # A forest without trees predicts the base value everywhere.
            return np.full((num_rows,), self.base_prediction, dtype=np.float32)

        # Stack the trees into one (n_trees, max_nodes, 6) tensor. Shorter trees are
        # padded with all-zero rows, which no child pointer ever references.
        max_nodes = max(tree.shape[0] for tree in flat_trees)
        first = flat_trees[0]
        flat_forests_ts = torch.zeros(
            (len(flat_trees), max_nodes, first.shape[1]), device=first.device, dtype=first.dtype
        )
        for tree_pos, tree in enumerate(flat_trees):
            flat_forests_ts[tree_pos, : tree.shape[0]] = tree

        batch_size = 10_000  # Adjust based on your GPU memory.
        preds = []
        for batch_start in tqdm(range(0, num_rows, batch_size), desc="Predicting batches"):
            batch_end = min(batch_start + batch_size, num_rows)
            X_batch = X_tensor[batch_start:batch_end]
            preds.append(_predict_forest_vectorized(flat_forests_ts, X_batch, self.bin_edges,
                                                    self.max_depth, self.learning_rate, self.base_prediction))
        preds = torch.cat(preds, dim=0)
        return preds.cpu().numpy()

    def compute_quantile_bins(self, X, num_bins):
        """Compute ``num_bins - 1`` bin boundaries for every column of ``X``.

        Per feature:
          * exactly ``num_bins`` consecutive integer values -> midpoints between them;
          * fewer than ``num_bins`` distinct values -> evenly spaced points between min and max;
          * otherwise -> empirical quantiles taken from the sorted column.

        All branches produce floating-point edges on ``X``'s device so that they can be
        stacked into a single tensor.

        Returns:
            Contiguous tensor of shape (F, num_bins - 1).
        """
        N, F = X.shape
        # Integer inputs (e.g. int8) are promoted to float32 so that differences and
        # midpoints cannot overflow; floating inputs keep their own dtype.
        work_dtype = X.dtype if X.is_floating_point() else torch.float32
        # Row positions of the interior quantiles; identical for every feature.
        quantiles = torch.linspace(0, 1, num_bins + 1)[1:-1]
        quantile_idx = (quantiles * (N - 1)).long().clamp(max=N - 1).to(X.device)

        bin_edges_list = []
        for f in tqdm(range(F), desc="Computing bin edges", leave=False):
            feature = X[:, f]
            unique_vals = torch.unique(feature).to(work_dtype)
            if unique_vals.numel() == num_bins and torch.all((unique_vals[1:] - unique_vals[:-1]) == 1):
                edges = (unique_vals[:-1] + unique_vals[1:]) / 2.0
            elif unique_vals.numel() < num_bins:
                edges = torch.linspace(
                    float(unique_vals.min()), float(unique_vals.max()),
                    steps=num_bins + 1, dtype=work_dtype, device=X.device,
                )[1:-1]
            else:
                sorted_feature, _ = torch.sort(feature)
                edges = sorted_feature[quantile_idx].to(work_dtype)
            bin_edges_list.append(edges)
        bin_edges = torch.stack(bin_edges_list, dim=0)
        return bin_edges.contiguous()

    def preprocess_gpu_data(self, X_np, Y_np, era_id_np):
        """Move the numpy inputs to ``self.device`` and bin every feature.

        Also records ``self.num_samples`` and ``self.num_features``.

        Returns:
            (bin_indices, era_indices, bin_edges, unique_eras, Y_gpu) where
            ``bin_indices`` is an int8 tensor shaped like ``X_np`` and ``era_indices``
            maps every sample to its position in ``unique_eras``.
        """
        self.num_samples, self.num_features = X_np.shape
        X = torch.from_numpy(X_np).type(torch.int8).to(self.device)
        Y_gpu = torch.from_numpy(Y_np).type(torch.float32).to(self.device)
        era_id_gpu = torch.from_numpy(era_id_np).type(torch.int32).to(self.device)
        bin_edges = self.compute_quantile_bins(X, self.num_bins).type(torch.float32).contiguous()
        bin_edges = bin_edges.to(self.device)
        # Allocate the result on the target device so each column is written in place
        # instead of being copied back to host memory one feature at a time.
        bin_indices = torch.empty((self.num_samples, self.num_features), dtype=torch.int8, device=self.device)
        for f in tqdm(range(self.num_features), desc='Bucketizing...'):
            bin_indices[:, f] = torch.bucketize(X[:, f], bin_edges[f], right=False).type(torch.int8)
        bin_indices = bin_indices.contiguous()
        unique_eras, era_indices = torch.unique(era_id_gpu, return_inverse=True)
        return bin_indices, era_indices, bin_edges, unique_eras, Y_gpu

    def compute_histograms(self, bin_indices_sub, gradients):
        """Fill gradient and hessian histograms for a subset of samples.

        Allocates two zeroed (num_features, num_bins) float32 tensors and passes them,
        with the int32-cast bin indices and the given gradients, to
        ``histogram_ext.compute_histogram``.
        # NOTE(review): the kernel is assumed to accumulate into the two buffers in place — confirm.

        Returns:
            (grad_hist, hess_hist)
        """
        grad_hist = torch.zeros((self.num_features, self.num_bins), device=self.device, dtype=torch.float32)
        hess_hist = torch.zeros((self.num_features, self.num_bins), device=self.device, dtype=torch.float32)
        histogram_ext.compute_histogram(
            bin_indices_sub.to(torch.int32),
            gradients,
            grad_hist,
            hess_hist,
            self.num_bins
        )
        return grad_hist, hess_hist

    def find_best_split(self, gradient_histogram, hessian_histogram):
        """Ask ``histogram_ext.compute_split`` for the best (feature, bin) split.

        The result is read back from ``self.out_feature`` / ``self.out_bin``.
        # NOTE(review): the constants 0.0, 1.0, 1e-6 are passed through unchanged; their
        # meaning is defined by the extension (presumably regularisation / minimum-size /
        # epsilon settings) — confirm against the kernel signature.

        Returns:
            Tuple ``(feature, bin)`` of Python ints; callers treat ``feature == -1`` as "no split".
        """
        histogram_ext.compute_split(
            gradient_histogram.contiguous(),
            hessian_histogram.contiguous(),
            self.num_features, self.num_bins,
            0.0, 1.0, 1e-6,
            self.out_feature,
            self.out_bin
        )
        f = int(self.out_feature[0])
        b = int(self.out_bin[0])
        return (f, b)

    def _make_leaf(self, node_indices):
        """Create a leaf for ``node_indices`` and fold it into the running predictions.

        The leaf value is the mean of (target - current prediction) over the node's
        samples; the predictions of those samples are then advanced by
        ``learning_rate * leaf_value``.
        """
        leaf_value = (self.Y_gpu[node_indices] - self.gradients[node_indices]).mean()
        self.gradients[node_indices] += self.learning_rate * leaf_value
        return {"leaf_value": leaf_value, "samples": node_indices.numel()}

    def grow_tree(self, gradient_histogram, hessian_histogram, node_indices, depth):
        """Recursively grow one (sub)tree for the samples in ``node_indices``.

        A leaf is produced when ``depth`` reaches ``max_depth``, when the split search
        reports no split (feature ``-1``), or when the chosen split would leave one side empty.

        Returns:
            Nested dict: leaves are ``{"leaf_value", "samples"}``, internal nodes are
            ``{"feature", "bin", "left", "right"}``.
        """
        if depth == self.max_depth:
            return self._make_leaf(node_indices)

        best_feature, best_bin = self.find_best_split(gradient_histogram, hessian_histogram)
        if best_feature == -1:
            return self._make_leaf(node_indices)

        split_mask = self.bin_indices[node_indices, best_feature] <= best_bin
        left_indices = node_indices[split_mask]
        right_indices = node_indices[~split_mask]
        left_size, right_size = left_indices.numel(), right_indices.numel()

        if left_size == 0 or right_size == 0:
            return self._make_leaf(node_indices)

        # Build the histogram only for the smaller child; the sibling's histogram is
        # the parent's minus the one just computed.
        if left_size < right_size:
            gradient_histogram_left, hessian_histogram_left = self.compute_histograms(self.bin_indices[left_indices], self.residual[left_indices])
            gradient_histogram_right = gradient_histogram - gradient_histogram_left
            hessian_histogram_right = hessian_histogram - hessian_histogram_left
        else:
            gradient_histogram_right, hessian_histogram_right = self.compute_histograms(self.bin_indices[right_indices], self.residual[right_indices])
            gradient_histogram_left = gradient_histogram - gradient_histogram_right
            hessian_histogram_left = hessian_histogram - hessian_histogram_right

        new_depth = depth + 1
        left_child = self.grow_tree(gradient_histogram_left, hessian_histogram_left, left_indices, new_depth)
        right_child = self.grow_tree(gradient_histogram_right, hessian_histogram_right, right_indices, new_depth)

        return {
            "feature": best_feature,
            "bin": best_bin,
            "left": left_child,
            "right": right_child
        }

    def grow_forest(self):
        """Grow ``n_estimators`` trees sequentially on the current residuals.

        Before each tree the residual (target minus running prediction) and the root
        histograms are recomputed; growing the tree updates the running prediction.

        Returns:
            List of nested tree dicts, in the order they were grown.
        """
        forest = []
        for _ in tqdm(range(self.n_estimators), desc="Growing trees"):
            self.residual = self.Y_gpu - self.gradients
            self.root_gradient_histogram, self.root_hessian_histogram = self.compute_histograms(self.bin_indices, self.residual)
            tree = self.grow_tree(
                self.root_gradient_histogram,
                self.root_hessian_histogram,
                self.root_node_indices,
                depth=0
            )
            forest.append(tree)
        return forest

    def flatten_tree(self, tree):
        """Convert a nested tree dict into a (num_nodes, 6) float32 tensor in pre-order.

        Row layout: (is_leaf, feature, bin threshold, left child row, right child row, value).
        Leaves carry -1 in the feature/threshold/child columns; internal nodes carry -1.0
        in the value column.
        """
        flat_nodes = []

        def recurse(node):
            """Append ``node`` and its subtree; return the number of rows added."""
            idx = len(flat_nodes)
            if "leaf_value" in node:
                # Leaf values are stored as 0-d tensors; convert explicitly to a Python
                # float so the nested list holds plain numbers only.
                flat_nodes.append([1, -1, -1.0, -1, -1, float(node["leaf_value"])])
                return 1
            flat_nodes.append([0, node["feature"], float(node["bin"]), -1, -1, -1.0])
            left_count = recurse(node["left"])
            right_count = recurse(node["right"])
            # Pre-order layout: the left child directly follows its parent, the right
            # child follows the whole left subtree.
            flat_nodes[idx][3] = idx + 1
            flat_nodes[idx][4] = idx + 1 + left_count
            return 1 + left_count + right_count

        recurse(tree)
        return torch.tensor(flat_nodes, device=self.device, dtype=torch.float32)

    def predict_forest_batch(self, X_batch):
        """Predict one batch with the tree-by-tree TorchScript routine.

        Uses the flattened trees cached by ``fit`` (``self.flat_forest``).
        """
        flat_forests_ts: List[torch.Tensor] = self.flat_forest
        return _predict_forest(flat_forests_ts, X_batch, self.bin_edges, self.max_depth, self.learning_rate, self.base_prediction)


In [16]:
# ---------------------------
# Example usage:
# ---------------------------
if __name__ == "__main__":
    # Feature matrix as int8; `train` holds either the training or the live frame.
    X = train[feature_set].values.astype(np.int8)

    # Targets and era ids are only built in training mode.
    if TRAIN:
        y = train[targets].values.astype(np.float32)
        era = train[['era']].values.astype(int)
    else:
        y = None
        era = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [17]:
if not TRAIN:
    # Inference mode: restore the saved estimator and score the loaded frame.
    # NOTE: weights_only=False unpickles arbitrary Python objects, so only load a
    # model.pth that comes from a trusted source.
    with torch.serialization.safe_globals([ErasplitGBDT]):
        model = torch.load("model.pth", weights_only=False)
    preds = model.predict(X)
    train['prediction'] = preds
else:
    # Training mode: fit on the entire dataset (full batch) and persist the whole object.
    hyperparams = dict(num_bins=5, max_depth=5, learning_rate=0.01, n_estimators=2000)
    model = ErasplitGBDT(**hyperparams)
    model.fit(X, y, era)
    torch.save(model, 'model.pth')


Predicting batches: 100%|██████████| 1/1 [00:00<00:00,  6.35it/s]


In [18]:
if not TRAIN:
    # Replace the raw scores with their percentile rank, then write only the
    # prediction column (indexed like `train`) to live.parquet.
    ranked_prediction = train['prediction'].rank(pct=True)
    train["prediction"] = ranked_prediction
    prediction_frame = train[['prediction']]
    prediction_frame.to_parquet('live.parquet')


In [19]:
# Display the prediction column. The column is only created by the inference branch
# above, so this cell raises a KeyError when TRAIN is True.
train[['prediction']]

Unnamed: 0_level_0,prediction
id,Unnamed: 1_level_1
n00051870d69760e,0.221368
n0010dbaaf8a002c,0.060261
n00149a3d96b5e0a,0.164796
n0029c4792309ee2,0.789700
n003b65d301107cc,0.799078
...,...
nffb9db73af1c9cf,0.978017
nffbe447250b5f67,0.795849
nffcdf7ca23f25fd,0.732360
nffd19077f7fe6cc,0.473636
