In [1]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable
import time
import collections
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

from tqdm import tqdm
import openml
import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as lax
from jaxtyping import Array, Float, Int, PRNGKeyArray
import aeon
import pandas as pd
from preprocessing.timeseries_augmentation import normalize_mean_std_traindata, normalize_streams, augment_time, add_basepoint_zero
from aeon.regression.sklearn import RotationForestRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error
from aeon.datasets.tser_datasets import tser_soton
from aeon.datasets import load_regression, load_classification
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split

from features.sig import SigTransform, LogSigTransform
from features.base import TimeseriesFeatureTransformer, TabularTimeseriesFeatures, RandomGuesser, RandomProjectionFeatures
from features.sig_neural import RandomizedSignature, TimeInhomogenousRandomizedSignature
from features.SWIM_controlled_resnet import SampledControlledResNet
from features.efficient_SCRN import memory_efficient_SCRN
from features.rocket_wrappers import RocketWrapper
from utils.utils import print_name, print_shape

jax.config.update('jax_platform_name', 'gpu') # Used to set the platform (cpu, gpu, etc.)
np.set_printoptions(precision=3, threshold=5) # Print options

2024-10-10 15:28:09.517220: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


# OpenML code

In [41]:
# Fetch the collection with ID 353
collection = openml.study.get_suite(353)
dataset_ids = collection.data
metadata_list = []

# Fetch and process each dataset
for i, dataset_id in enumerate(dataset_ids):
    dataset = openml.datasets.get_dataset(dataset_id)
    X, y, categorical_indicator, attribute_names = dataset.get_data(
        target=dataset.default_target_attribute
    )
    X = np.array(X)
    y = np.array(y)
    
    # Extract the required metadata
    metadata = {
        'dataset_id': dataset.id,
        'name': dataset.name,
        'n_obs': int(dataset.qualities['NumberOfInstances']),
        'n_features': int(dataset.qualities['NumberOfFeatures']),
        '%_unique_y': len(np.unique(y))/len(y),
        'n_unique_y': len(np.unique(y)),
    }
    
    metadata_list.append(metadata)
    print(f" {i+1}/{len(dataset_ids)} Processed dataset {dataset.id}: {dataset.name}")

# Create a DataFrame from the metadata list
df_metadata = pd.DataFrame(metadata_list).sort_values('%_unique_y', ascending=False)

 1/35 Processed dataset 44956: abalone
 2/35 Processed dataset 44957: airfoil_self_noise
 3/35 Processed dataset 44958: auction_verification
 4/35 Processed dataset 44959: concrete_compressive_strength
 5/35 Processed dataset 44963: physiochemical_protein
 6/35 Processed dataset 44964: superconductivity
 7/35 Processed dataset 44965: geographical_origin_of_music
 8/35 Processed dataset 44966: solar_flare
 9/35 Processed dataset 44969: naval_propulsion_plant
 10/35 Processed dataset 44971: white_wine
 11/35 Processed dataset 44972: red_wine
 12/35 Processed dataset 44973: grid_stability
 13/35 Processed dataset 44974: video_transcoding
 14/35 Processed dataset 44975: wave_energy
 15/35 Processed dataset 44976: sarcos
 16/35 Processed dataset 44977: california_housing
 17/35 Processed dataset 44978: cpu_activity
 18/35 Processed dataset 44979: diamonds
 19/35 Processed dataset 44980: kin8nm
 20/35 Processed dataset 44981: pumadyn32nh
 21/35 Processed dataset 44983: miami_housing
 22/35 P

In [42]:
df_metadata.sort_values('%_unique_y', ascending=True)

Unnamed: 0,dataset_id,name,n_obs,n_features,%_unique_y,n_unique_y
9,44971,white_wine,4898,12,0.001429,7
26,44993,health_insurance,22272,12,0.003367,75
10,44972,red_wine,1599,12,0.003752,6
8,44969,naval_propulsion_plant,11934,15,0.004274,51
0,44956,abalone,4177,9,0.006703,28
16,44978,cpu_activity,8192,22,0.006836,56
27,45012,fifa,19178,29,0.006935,133
7,44966,solar_flare,1066,11,0.007505,8
31,44967,student_performance_por,649,31,0.026194,17
6,44965,geographical_origin_of_music,1059,117,0.029273,31


# Download single dataset

In [43]:
def load_openml_dataset(dataset_id, 
                        normalize_X:bool = True,
                        normalize_y:bool = True,
                        train_test_size:float = 0.7,
                        split_seed:int = 0) -> Tuple[np.ndarray, np.ndarray]:
    # Fetch dataset from OpenML by its ID
    dataset = openml.datasets.get_dataset(dataset_id)
    X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
    X = np.array(X)
    y = np.array(y)
    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_test_size, random_state=split_seed)

    #normalize
    if normalize_X:
        X_train, X_test = normalize_mean_std_traindata(X_train, X_test)
    if normalize_y:
        y_train, y_test = normalize_mean_std_traindata(y_train, y_test)

    return (jnp.array(X_train.astype(np.float32)), 
            jnp.array(y_train.astype(np.float32)), 
            jnp.array(X_test.astype(np.float32)), 
            jnp.array(y_test.astype(np.float32)))

dataset_id = 44971  # Replace with the dataset ID you want
X_train, X_test, y_train, y_test = load_openml_dataset(dataset_id)

# SWIM tabular model

In [None]:
from typing import Tuple, List, Union, Any, Optional, Dict, Set, Literal, Callable
from abc import ABC, abstractmethod
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as lax
from jaxtyping import Array, Float, Int, PRNGKeyArray

from features.base import TimeseriesFeatureTransformer



def init_single_SWIM_layer(
        seed: PRNGKeyArray,
        X: Float[Array, "N  d"],
        y: Float[Array, "N  p"],
        n_features: int,
        sampling_method: Literal["uniform", "gradient-weighted"] = "uniform",
    ) -> Tuple[Float[Array, "d  n_features"], Float[Array, "n_features"]]:
    """
    Fits the weights for a single layer of the SWIM model.

    Args:
        seed (PRNGKeyArray): Random seed for the weights and biases.
        X (Float[Array, "N  d"]): Previous layer's output.
        y (Float[Array, "N  p"]): Target training data.
        n_features (int): Next hidden layer size.
        sampling_method (str): Uniform or gradient-weighted pair sampling.
    Returns:
        Weights (d, n_features) and biases (1, n_features) for the next layer.
    """
    seed_idxs, seed_sample = jax.random.split(seed, 2)
    N, d = X.shape
    EPS = 1e-06

    #obtain pair indices
    n = 3*N
    idx1 = jnp.arange(0, n) % N
    delta = jax.random.randint(seed_idxs, shape=(n,), minval=1, maxval=N)
    idx2 = (idx1 + delta) % N
    
    if sampling_method=="gradient-weighted":
        #calculate 'gradients'
        dx = X[idx2] - X[idx1]
        dy = y[idx2] - y[idx1]
        dists = jnp.maximum(EPS, jnp.linalg.norm(dx, axis=1, keepdims=True) )
        gradients = (jnp.linalg.norm(dy, axis=1, keepdims=True) / dists, ).reshape(-1) #NOTE paper uses ord=inf instead of ord=2
        p = gradients/gradients.sum()
    elif sampling_method=="uniform":
        p = None
    else:
        raise ValueError(f"sampling_method must be 'uniform' or 'gradient-weighted'. Given: {sampling_method}")

    #sample pairs
    selected_idx = jax.random.choice(
        seed_sample, 
        n,
        shape=(n_features,), 
        replace=True,
        p=p
        )
    idx1 = idx1[selected_idx]
    dx = dx[selected_idx]
    dists = dists[selected_idx]

    #define weights and biases
    weights = (dx / dists**2).T
    biases = -jnp.sum(weights * X[idx1].T, axis=0, keepdims=True) - 0.5  # NOTE experiment with this. also +-0.5 ?
    return weights, biases



def forward_1_layer(
        X: Float[Array, "N  d"],
        weights: Float[Array, "d  n_features"],
        biases: Float[Array, "1  n_features"],
        add_residual: bool,
        activation = lambda x : jnp.maximum(0,x+0.5), # jnp.tanh,
        scaling_factor: float = 1.0,
    ) -> Float[Array, "N  n_features"]:
    """
    Forward pass for a single layer of the SWIM model.
    """
    d, D = weights.shape
    X1 = activation(X @ weights + biases)
    if add_residual:
        return scaling_factor*X1 + X
    else:
        return X1



def SWIM_all_layers(
        seed: PRNGKeyArray,
        X0: Float[Array, "N  d"],
        y: Float[Array, "N  p"],
        n_features: int,
        activation: Callable,
        n_layers: int,
        add_residual: bool,
        residual_scaling_factor: float = 1.0,
        sampling_method: Literal["uniform", "gradient-weighted"] = "gradient-weighted",
    ):
    """
    Fits the weights for the SWIM model, iteratively layer by layer

    Args:
        seed (PRNGKeyArray): Random seed for the weights and biases.
        X0 (Float[Array, "N  d"]): First layer input.
        y (Float[Array, "N  p"]): Target training data.
        n_features (int): Hidden layer size.
        activation (Callable): Activation function for the network.
        n_layers (int): Number of layers in the network.
        add_residual (bool): Whether to use residual connections.
        residual_scaling_factor (float): Scaling factor for the residual connections.
        sampling_method (str): Uniform or gradient-weighted pair sampling for weight initialization.

    Returns:
        Weights (d, n_features) and biases (1, n_features) for the next layer.
    """

    def scan_body(carry, seed): # (carry, x) -> (carry, y)
        X, y = carry
        w, b = init_single_SWIM_layer(seed, X, y, n_features, sampling_method)
        X = forward_1_layer(X, w, b, add_residual, activation, residual_scaling_factor)
        return (X, y), (w, b)

    init_carry = (X0, y)
    carry, WaB = lax.scan(
        scan_body,
        init_carry,
        xs=jax.random.split(seed, n_layers),
    )
    return WaB



def all_forward(
        X: Float[Array, "N  d"], 
        w1: Float[Array, "d  D"],
        b1: Float[Array, "1  D"], 
        weights: Float[Array, "n_layers-1  d  D"],
        biases: Float[Array, "n_layers-1  1  D"], 
        n_layers:int,
        add_residual: bool,
        activation = lambda x : jnp.maximum(0,x+0.5), # jnp.tanh,
    ):
    """
    Forward pass for the SWIM model.

    Args:
        X (Float[Array, "N  d"]): Input to the model.
        w1 (Float[Array, "d  D"]): Weights for the first layer.
        b1 (Float[Array, "1  D"]): Biases for the first layer.
        weights (Float[Array, "n_layers-1  d  D"]): Weights for the remaining layers.
        biases (Float[Array, "n_layers-1  1  D"]): Biases for the remaining layers.
        n_layers (int): Number of layers in the network.
        add_residual (bool): Whether to use residual connections
        activation (Callable): Activation function for the network.
    Returns:
        Output of the model of shape (N, D).
    """




class SWIM_MLP(TimeseriesFeatureTransformer):
    def __init__(
            self,
            seed: PRNGKeyArray,
            n_features: int = 512,
            n_layers: int = 3,
            add_residual: bool = False,
            max_batch: int = 512,
            activation = lambda x : jnp.maximum(0,x+0.5), # jnp.tanh,
        ):
        """Implementation of the original paper's SWIM model
        https://gitlab.com/felix.dietrich/swimnetworks-paper/,
        but with support for residual connections.

        Args:
            seed (PRNGKeyArray): Random seed for matrices, biases, initial value.
            n_features (int): Hidden layer dimension.
            n_layers (int): Number of layers in the network.
            add_residual (bool): Whether to use residual connections.
            max_batch (int): Max batch size for computations.
            activation (Callable): Activation function for the network.
        """
        super().__init__(max_batch)
        self.n_features = n_features
        self.n_layers = n_layers
        self.seed = seed
        self.add_residual = add_residual
        self.activation = activation
        self.w1 = None
        self.b1 = None
        self.weights = None
        self.biases = None


    def fit(
            self, 
            X: Float[Array, "N  D"], 
            y: Float[Array, "N  d"]
        ):
        """
        Initializes MLP weights and biases, using SWIM algorithm.

        Args:
            X (Float[Array, "N  D"]): Input training data.
            y (Float[Array, "N  d"]): Target training data.
        """
        #TODO TODO TODO do this, add new args to all the functions and class init, such as activation and residual_scale_factor TODO
        # Get shape, dtype
        N, D = X.shape
        seed1, seedrest = jax.random.split(self.seed, 2)

        #first do first layer, which cannot always be done in a scan loop
        self.w1, self.b1 = init_single_SWIM_layer(
            seed1, X, y, self.n_features
            )
        X = forward_1_layer(X, self.w1, self.b1, self.add_residual)
        
        #rest of the layers
        if self.n_layers > 1:
            self.weights, self.biases = SWIM_all_layers(
                X, y, self.n_features, self.n_layers-1, self.add_residual, seedrest
                )

        return self


    def transform(self, X: Float[Array, "N  D"]) -> Float[Array, "N  n_features"]:
        #TODO TODO TODO do this
        #First hidden layer
        X = forward_1_layer(X, w1, b1, add_residual, activation)
        if n_layers == 1:
            return X
        #subsequent layers in a scan loop
        else:
            def scan_body(carry, t):
                X = carry
                w, b = weights[t], biases[t]
                return forward_1_layer(X, w, b, add_residual, activation), None

            X, _ = lax.scan(scan_body, X, xs=jnp.arange(n_layers-1))
            return X

In [None]:

def neuron_distribution_for_each_layer(
        X_train: Array, 
        y_train: Array, 
        X_test: Array, 
        model: Callable, 
        hidden_size: int, 
        n_layers: int, 
        random_seed: int) -> Tuple[Array, Array]:
    """Looks at the distribution of neurons for each layer of a neural network model
    (used to compare SWIM, residual sampling, and random feature networks).
    """
    
    # Initialize the arrays to store the neuron distribution
    train_layers= []
    test_layers = []
    
    # for each layer
    for t in range(n_layers):
        
        # Train the model
        model.fit(X_train, y_train)
    
    
    return train_neuron_distribution, test_neuron_distribution

I want to look at the distribution of weights (eigenvalues? absolute values of rows? distribution of (assuming iid) matrix entries?)

distribution of neurons at each layer

This is for both SWIM, Residual SWIM, random features, 