In [1]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable
import time
import collections
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

from tqdm import tqdm
import openml
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import relu
from torch.nn.functional import tanh
import pandas as pd
import numpy as np

from aeon.regression.sklearn import RotationForestRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error
from aeon.datasets.tser_datasets import tser_soton
from aeon.datasets import load_regression, load_classification
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split

from preprocessing.stream_transforms import normalize_mean_std_traindata, normalize_streams, augment_time, add_basepoint_zero
from utils.utils import print_name, print_shape

np.set_printoptions(precision=3, threshold=5) # Print options



# OpenML code

In [2]:
# # Fetch the collection with ID 353
# collection = openml.study.get_suite(353)
# dataset_ids = collection.data
# metadata_list = []

# # Fetch and process each dataset
# for i, dataset_id in enumerate(dataset_ids):
#     dataset = openml.datasets.get_dataset(dataset_id)
#     X, y, categorical_indicator, attribute_names = dataset.get_data(
#         target=dataset.default_target_attribute
#     )
#     X = np.array(X)
#     y = np.array(y)[..., None]
    
#     # Extract the required metadata
#     metadata = {
#         'dataset_id': dataset.id,
#         'name': dataset.name,
#         'n_obs': int(dataset.qualities['NumberOfInstances']),
#         'n_features': int(dataset.qualities['NumberOfFeatures']),
#         '%_unique_y': len(np.unique(y))/len(y),
#         'n_unique_y': len(np.unique(y)),
#     }
    
#     metadata_list.append(metadata)
#     print(f" {i+1}/{len(dataset_ids)} Processed dataset {dataset.id}: {dataset.name}")

# # Create a DataFrame from the metadata list
# df_metadata = pd.DataFrame(metadata_list).sort_values('%_unique_y', ascending=False)
# df_metadata.sort_values('%_unique_y', ascending=True)

# Download single dataset

In [3]:
def load_openml_dataset(dataset_id, 
                        normalize_X:bool = True,
                        normalize_y:bool = True,
                        train_test_size:float = 0.7,
                        split_seed:int = 0) -> Tuple[np.ndarray, np.ndarray]:
    # Fetch dataset from OpenML by its ID
    dataset = openml.datasets.get_dataset(dataset_id)
    X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
    X = np.array(X)
    y = np.array(y)[..., None]
    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_test_size, random_state=split_seed)

    #normalize
    if normalize_X:
        X_train, X_test = normalize_mean_std_traindata(X_train, X_test)
    if normalize_y:
        y_train, y_test = normalize_mean_std_traindata(y_train, y_test)

    return (Tensor(X_train.astype(np.float32)), 
            Tensor(X_test.astype(np.float32)), 
            Tensor(y_train.astype(np.float32)), 
            Tensor(y_test.astype(np.float32)))

dataset_id = 44971  # Replace with the dataset ID you want
X_train, X_test, y_train, y_test = load_openml_dataset(dataset_id)


# nn.Module for sampled networks

In [4]:
#################################################################
##### Base classes                                          #####
##### - FittableModule: A nn.Module with .fit(X, y) support #####
##### - ResNetBase: which interatively calls .fit(X, y)     #####
#################################################################

class FittableModule(nn.Module):
    def __init__(self):
        super(FittableModule, self).__init__()
    

    def fit(self, 
            X: Optional[Tensor] = None, 
            y: Optional[Tensor] = None,
        ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """Given neurons of the previous layer, and target labels, fit the 
        module. Returns the forwarded activations and labels [f(X), y].

        Args:
            X (Optional[Tensor]): Forward-propagated activations of training data, shape (N, d).
            y (Optional[Tensor]): Training labels, shape (N, p).
        
        Returns:
            Forwarded activations and labels [f(X), y].
        """
        return self(X), y



class ResNetBase(nn.Module):
    def __init__(self,
                upsample:FittableModule,
                blocks:List[FittableModule],
                output_layer:FittableModule,
                ):
        """Residual Network base class, with fit method for non-SGD training/initialization.

        Args:
            upsample (FittableModule): _description_
            blocks (List[FittableModule]): _description_
            output_layer (FittableModule): _description_
        """
        super(ResNetBase, self).__init__()
        self.upsample = upsample
        self.blocks = nn.ModuleList(blocks)
        self.output_layer = output_layer

    
    def fit(self, X:Tensor, y:Tensor):
        # X shape (N, d)
        # y shape (N, p)
        X, y = self.upsample.fit(X, y)
        print(X)
        for block in self.blocks:
            X, y = block.fit(X, y)
            print(X)
        X, y = self.output_layer.fit(X, y)
        print(X)
        return X, y

    
    def forward(self, x:Tensor) -> Tensor:
        # x shape (N, d)
        x = self.upsample(x)
        for block in self.blocks:
            x = block(x) + x
        x = self.output_layer(x)
        return x

# I need:
# - ResNetBase
# - Sampled Layer
# - 1 layer Sampled Network
# - RidgeCV and TODO other classifiers/regressors. Should I implement this in pytorch or use sklearn?

In [47]:
##############################
######## Dense Layer ########
#############################


class Dense(FittableModule):
    def __init__(self,
                 generator: torch.Generator,
                 in_dim: int,
                 out_dim: int,
                 ):
        """Dense MLP layer with LeCun weight initialization,
        Gaussan bias initialization."""
        super(Dense, self).__init__()
        self.generator = generator
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dense = nn.Linear(in_dim, out_dim)
    
    def fit(self, X:Tensor, y:Tensor):
        with torch.no_grad():
            nn.init.normal_(self.dense.weight, mean=0, std=self.in_dim**-0.5, generator=self.generator)
            nn.init.normal_(self.dense.bias, mean=0, std=0.1, generator=self.generator) #arbitrary
            return self(X), y
    
    def forward(self, X):
        return self.dense(X)
    

class Identity(FittableModule):
    def __init__(self):
        super(Identity, self).__init__()
    
    def fit(self, X:Tensor, y:Tensor):
        return X, y
    
    def forward(self, X):
        return X


D = X_train.shape[1]
g1 = torch.Generator().manual_seed(0)
net = Dense(g1, D, 3)
net.fit(X_train, y_train)
out = net(X_test)
print_name(out)
print(net)

torch.Size([1470, 3]) out torch.float32
tensor([[ 0.2244,  0.2429, -0.3403],
        [-0.1479, -0.3924,  0.4136],
        [ 0.8625,  0.9489, -0.4612],
        ...,
        [ 2.3646, -0.8617, -0.9985],
        [ 1.1491, -1.4283,  0.4724],
        [-0.7699,  1.0977,  0.5348]], grad_fn=<AddmmBackward0>) 

Dense(
  (dense): Linear(in_features=11, out_features=3, bias=True)
)


In [6]:
###############################
#### Pair Sampled Networks ####
###############################


class PairSampledLinear(FittableModule):
    def __init__(self,
                 generator: torch.Generator,
                 in_dim: int, 
                 out_dim: int,
                 sampling_method: Literal['uniform', 'gradient'] = 'gradient'
                 ):
        """Dense MLP layer with pair sampled weights.

        Args:
            generator (torch.Generator): PRNG object.
            in_dim (int): Input dimension.
            out_dim (int): Output dimension.
            sampling_method (str): Pair sampling method. Uniform or gradient-weighted.
        """
        super(PairSampledLinear, self).__init__()
        self.generator = generator
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dense = nn.Linear(in_dim, out_dim)
        self.sampling_method = sampling_method


    def fit(self, 
            X: Tensor, 
            y: Tensor,
        ) -> Tuple[Tensor, Tensor]:
        """Given forward-propagated training data X at the previous 
        hidden layer, and supervised target labels y, fit the weights
        iteratively by letting rows of the weight matrix be given by
        pairs of samples from X. See paper for more details.

        Args:
            X (Tensor): Forward-propagated activations of training data, shape (N, d).
            y (Tensor): Training labels, shape (N, p).
        
        Returns:
            Forwarded activations and labels [f(X), y].
        """
        with torch.no_grad():
            N, d = X.shape
            dtype = X.dtype
            device = X.device
            EPS = torch.tensor(0.1, dtype=dtype, device=device)

            #obtain pair indices
            n = 5*N
            idx1 = torch.arange(0, n, dtype=torch.int32, device=device) % N
            delta = torch.randint(1, N, size=(n,), dtype=torch.int32, device=device, generator=self.generator)
            idx2 = (idx1 + delta) % N
            dx = X[idx2] - X[idx1]
            dists = torch.linalg.norm(dx, axis=1, keepdims=True)
            dists = torch.maximum(dists, EPS)
            
            if self.sampling_method=="gradient":
                #calculate 'gradients'
                dy = y[idx2] - y[idx1]
                y_norm = torch.linalg.norm(dy, axis=1, keepdims=True) #NOTE 2023 paper uses ord=inf instead of ord=2
                grad = (y_norm / dists).reshape(-1) 
                p = grad/grad.sum()
            elif self.sampling_method=="uniform":
                p = torch.ones(n, dtype=dtype, device=device) / n
            else:
                raise ValueError(f"sampling_method must be 'uniform' or 'gradient'. Given: {self.sampling_method}")

            #sample pairs
            selected_idx = torch.multinomial(
                p,
                self.out_dim,
                replacement=True,
                generator=self.generator
                )
            idx1 = idx1[selected_idx]
            dx = dx[selected_idx]
            dists = dists[selected_idx]

            #define weights and biases
            weights = dx / (dists**2)
            biases = -torch.einsum('ij,ij->i', weights, X[idx1]) - 0.5
            self.dense.weight.data = weights
            self.dense.bias.data = biases
            return self(X), y
    

    def forward(self, X):
        return self.dense(X)
    
    
D = X_train.shape[1]
g1 = torch.Generator().manual_seed(0)
net = PairSampledLinear(g1, D, 3)
net.fit(X_train, y_train)
out = net(X_test)
print_name(out)
print(net)

torch.Size([1470, 3]) out torch.float32
tensor([[-0.1170, -0.2263, -0.4405],
        [-0.3022, -0.4827, -0.1504],
        [ 0.0027,  0.0866, -0.4289],
        ...,
        [ 0.0508,  0.1066,  0.0260],
        [-0.2116, -0.5603,  0.3439],
        [ 0.0938, -1.5502, -0.9001]], grad_fn=<AddmmBackward0>) 

PairSampledLinear(
  (dense): Linear(in_features=11, out_features=3, bias=True)
)


In [7]:
###################################
#### Sampled Bottleneck ResNet ####
###################################


class SampledBlock(FittableModule):
    def __init__(self,
                 generator: torch.Generator,
                 hidden_dim: int, 
                 activation_dim: int,
                 activation: nn.Module = nn.Tanh(),
                 sampling_method: Literal['uniform', 'gradient'] = 'gradient'
                 ):
        """A sampled layer followed by activation and linear layer.
        Equivalent to a 1-hidden-layer Sampled Neural Network.

        Args:
            generator (torch.Generator): PRNG object.
            in_dim (int): Input dimension.
            out_dim (int): Output dimension.
            activation (nn.Module): Activation function.
            sampling_method (str): Pair sampling method. Uniform or gradient-weighted.
        """
        super(SampledBlock, self).__init__()
        self.generator = generator
        self.sampled_linear = PairSampledLinear(generator, hidden_dim, activation_dim, sampling_method)
        self.activation = activation
        self.upscale = Dense(generator, activation_dim, hidden_dim)
    

    def fit(self, X: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
        with torch.no_grad():
            X, y = self.sampled_linear.fit(X, y)
            X = self.activation(X)
            X, y = self.upscale.fit(X, y)
            return X, y

    
    def forward(self, X):
        X = self.sampled_linear(X)
        X = self.activation(X)
        X = self.upscale(X)
        return X
    
    
D = X_train.shape[1]
g1 = torch.Generator().manual_seed(0)
net = SampledBlock(g1, D, 3)
net.fit(X_train, y_train)
out = net(X_test)
print_name(out)
print(net)

torch.Size([1470, 11]) out torch.float32
tensor([[ 2.9955e-01, -2.2911e-01, -1.0834e-01,  ...,  4.5266e-01,
         -3.3509e-01,  4.6695e-01],
        [ 7.5369e-02, -5.1524e-01, -7.1613e-03,  ...,  3.7685e-01,
         -6.7468e-02,  2.6877e-01],
        [ 3.0491e-01, -1.8205e-02, -1.2761e-01,  ...,  2.8464e-01,
         -3.5406e-01,  3.7600e-01],
        ...,
        [-6.5020e-02, -1.0969e-01, -1.5286e-01,  ...,  3.6806e-04,
         -3.4603e-02, -4.7464e-02],
        [-3.4870e-01, -6.3833e-01, -8.5864e-02,  ...,  1.1641e-01,
          2.6838e-01, -1.8195e-01],
        [ 4.6800e-01, -3.4668e-01, -4.9589e-01,  ...,  1.1188e+00,
         -7.2387e-01,  9.4487e-01]], grad_fn=<AddmmBackward0>) 

SampledBlock(
  (sampled_linear): PairSampledLinear(
    (dense): Linear(in_features=11, out_features=3, bias=True)
  )
  (activation): Tanh()
  (upscale): Dense(
    (dense): Linear(in_features=3, out_features=11, bias=True)
  )
)


In [8]:
#####################
### RidgeCV Layer ###
#####################

class RidgeCVModule(FittableModule):
    def __init__(self, alphas=np.logspace(-1, 3, 10)):
        """RidgeCV layer using sklearn's RidgeCV. TODO dont use sklearn"""
        super(RidgeCVModule, self).__init__()
        self.ridge = RidgeCV(alphas=alphas)

    def fit(self, X: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
        """Fit the RidgeCV model. TODO dont use sklearn"""
        X_np = X.detach().cpu().numpy().astype(np.float64)
        y_np = y.detach().cpu().squeeze().numpy().astype(np.float64)
        self.ridge.fit(X_np, y_np)
        return self(X), y

    def forward(self, X: Tensor) -> Tensor:
        """Forward pass through the RidgeCV model. TODO dont use sklearn"""
        X_np = X.detach().cpu().numpy().astype(np.float64)
        y_pred_np = self.ridge.predict(X_np)
        return torch.tensor(y_pred_np, dtype=X.dtype, device=X.device).unsqueeze(1) #TODO unsqueeze???


D = X_train.shape[1]
g1 = torch.Generator()
net = RidgeCVModule()
net.fit(X_train, y_train)
out = net(X_test)
print_name(out)
print_name(y_test)
print(net)
print("alpha", net.ridge.alpha_)
print("rmse", mean_squared_error(y_test.detach().cpu().numpy(), out.detach().cpu().numpy(), squared=False))

torch.Size([1470, 1]) out torch.float32
tensor([[-0.3634],
        [-0.4966],
        [ 0.2163],
        ...,
        [ 0.5423],
        [ 0.1570],
        [ 0.1244]]) 

torch.Size([1470, 1]) y_test torch.float32
tensor([[-1.0209],
        [ 0.1168],
        [ 1.2545],
        ...,
        [ 0.1168],
        [ 0.1168],
        [ 0.1168]]) 

RidgeCVModule()
alpha 2.1544346900318834
rmse 0.88653886




In [91]:
class SampledResNet(ResNetBase):
    def __init__(self,
                 generator: torch.Generator,
                 in_dim: int,
                 hidden_dim: int,
                 activation_dim: int, #rename to bottleneck dim?
                 n_blocks: int,
                 activation: nn.Module = nn.Tanh(),
                 upsample_method: Literal['dense', 'sampled', 'identity'] = 'dense',
                 sampling_method: Literal['uniform', 'gradient'] = 'gradient'
                 ):
        """A ResNet with sampled layers as bottleneck layers.
        """
        if upsample_method=="dense":
            upsample = Dense(generator, in_dim, hidden_dim)
        elif upsample_method=="sampled":
            upsample = PairSampledLinear(generator, in_dim, hidden_dim, sampling_method)
        elif upsample_method=="identity":
            upsample = Identity()
        else:
            raise ValueError(f"upsample_method must be 'dense', 'sampled' or 'identity'. Given: {upsample_method}")

        blocks = [SampledBlock(generator, 
                               hidden_dim, 
                               activation_dim,
                               activation,
                               sampling_method
                               ) for _ in range(n_blocks)]
        ridge = RidgeCVModule()
        super(SampledResNet, self).__init__(upsample, blocks, ridge)


D = X_train.shape[1]
g1 = torch.Generator().manual_seed(int(time.time()*10))
net = SampledResNet(g1, D, 3*D, 2*D, 0, upsample_method='sampled', sampling_method='uniform')
net.fit(X_train, y_train)
out = net(X_test)
print_name(out)
print(y_test)
print(net)
print("rmse", mean_squared_error(y_test.detach().cpu().numpy(), out.detach().cpu().numpy()))
print(net.output_layer.ridge.alpha_)

tensor([[ 0.8281, -0.5969, -0.3491,  ..., -0.4852, -0.1246,  0.0519],
        [-0.0388,  0.5599,  0.1400,  ...,  0.0787, -0.1044, -0.1064],
        [ 0.9605, -0.2854, -0.0918,  ..., -0.6064, -0.0846,  0.0346],
        ...,
        [ 1.0064, -1.2030,  0.0122,  ..., -1.2521, -0.1361,  0.4100],
        [ 1.1758,  0.3042,  0.3017,  ..., -0.2136,  0.3592, -0.0513],
        [ 0.2833,  0.4097,  0.0871,  ..., -0.0392, -0.1120, -0.0608]])
tensor([[-0.7473],
        [ 0.6416],
        [-0.0361],
        ...,
        [-0.8562],
        [ 0.1393],
        [-0.9555]])
torch.Size([1470, 1]) out torch.float32
tensor([[-0.3910],
        [-0.4883],
        [ 0.1868],
        ...,
        [ 0.5399],
        [ 0.2303],
        [ 0.0784]]) 

tensor([[-1.0209],
        [ 0.1168],
        [ 1.2545],
        ...,
        [ 0.1168],
        [ 0.1168],
        [ 0.1168]])
SampledResNet(
  (upsample): PairSampledLinear(
    (dense): Linear(in_features=11, out_features=33, bias=True)
  )
  (blocks): ModuleList()

# SWIM tabular model

In [10]:
from typing import Tuple, List, Union, Any, Optional, Dict, Set, Literal, Callable
from abc import ABC, abstractmethod
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as lax
from jaxtyping import Array, Float, Int, PRNGKeyArray


def init_single_SWIM_layer(
        generator: torch.Generator,
        X: Tensor,
        y: Tensor,
        hidden_size: int,
        sampling_method: Literal["uniform", "gradient"] = "gradient",
    ):
    """
    Fits the weights for a single layer of the SWIM model.

    Args:
        generator (torch.Generator): PRNG object.
        X (Tensor): Previous layer's output, shape (N, d).
        y (Tensor): Target training data, shape (N, p).
        hidden_size (int): Next hidden layer size.
        sampling_method (str): Uniform or gradient-weighted pair sampling.
    Returns:
        Weights (d, hidden_size) and biases (1, hidden_size) for the next layer.
    """
    N, d = X.shape
    EPS = 0.1

    #obtain pair indices
    n = 5*N
    idx1 = jnp.arange(0, n) % N
    delta = torch.randint(low=1, high=N, shape=(n,), generator=generator)
    idx2 = (idx1 + delta) % N
    dx = X[idx2] - X[idx1]
    dists = torch.linalg.norm(dx, axis=1, keepdims=True)
    dists = torch.maximum(EPS, dists)
    
    if sampling_method=="gradient":
        #calculate 'gradients'
        dy = y[idx2] - y[idx1]
        y_norm = torch.linalg.norm(dy, axis=1, keepdims=True)
        grad = (y_norm / dists).reshape(-1) #NOTE 2023 paper uses ord=inf instead of ord=2
        p = grad/grad.sum()
    elif sampling_method=="uniform":
        p = None
    else:
        raise ValueError(f"sampling_method must be 'uniform' or 'gradient'. Given: {sampling_method}")

    #sample pairs
    selected_idx = torch.multinomial(
        p,
        hidden_size,
        replacement=True,
        generator=generator
        )
    idx1 = idx1[selected_idx]
    dx = dx[selected_idx]
    dists = dists[selected_idx]

    #define weights and biases
    weights = (dx / dists**2).T
    biases = -torch.sum(weights * X[idx1].T, axis=0, keepdims=True) - 0.5  # NOTE experiment with this. also +-0.5 ?
    return weights, biases



def forward_1_layer(
        X: Tensor, # shape (N, d)
        weights: Tensor, # shape (d, hidden_size)
        biases: Tensor, # shape (1, hidden_size)
        add_residual: bool,
        activation: Callable,
        scaling_factor: float = 1.0,
    ) -> Float[Array, "N  n_features"]:
    """
    Forward pass for a single layer of the SWIM model.
    """
    d, D = weights.shape
    X1 = activation(X @ weights + biases)
    if add_residual and d==D:
        print("residual")
        return scaling_factor*X1 + X
    else:
        print("not residual")
        return X1



def SWIM_all_layers(
        generator: torch.Generator,
        X0: Tensor, # shape (N, d)
        y: Tensor, #shape (N, p)
        hidden_size: int,
        activation: Callable,
        n_layers: int,
        add_residual: bool,
        residual_scaling_factor: float = 1.0,
        sampling_method: Literal["uniform", "gradient"] = "gradient",
    ):
    """
    Fits the weights for the SWIM model, iteratively layer by layer

    Args:
        generator (torch.Generator): PRNG object.
        X0 (Float[Array, "N  d"]): First layer input.
        y (Float[Array, "N  p"]): Target training data.
        hidden_size (int): Hidden layer size.
        activation (Callable): Activation function for the network.
        n_layers (int): Number of layers in the network.
        add_residual (bool): Whether to use residual connections.
        residual_scaling_factor (float): Scaling factor for the residual connections.
        sampling_method (str): Uniform or gradient-weighted pair sampling for weight initialization.

    Returns:
        Weights (d, n_features) and biases (1, n_features) for the next layer.
    """

    def scan_body(carry, seed): # (carry, x) -> (carry, y)
        X, y = carry
        w, b = init_single_SWIM_layer(seed, X, y, n_features, sampling_method)
        X = forward_1_layer(X, w, b, add_residual, activation, residual_scaling_factor)
        return (X, y), (w, b)

    init_carry = (X0, y)
    carry, WaB = lax.scan(
        scan_body,
        init_carry,
        xs=jax.random.split(seed, n_layers),
    )
    return WaB



class SWIM_MLP():
    def __init__(
            self,
            seed: PRNGKeyArray,
            n_features: int = 512,
            n_layers: int = 3,
            add_residual: bool = False,
            sampling_method: Literal["uniform", "gradient-weighted"] = "gradient-weighted",
            activation = lambda x : jnp.maximum(0,x+0.5), # jnp.tanh,
            residual_scaling_factor: Optional[float] = None,
        ):
        """Implementation of the original paper's SWIM model
        https://gitlab.com/felix.dietrich/swimnetworks-paper/,
        but with support for residual connections.

        Args:
            seed (PRNGKeyArray): Random seed for matrices, biases, initial value.
            n_features (int): Hidden layer dimension.
            n_layers (int): Number of layers in the network.
            add_residual (bool): Whether to use residual connections.
            sampling_method (str): Uniform or gradient-weighted pair sampling for weight initialization.
            activation (Callable): Activation function for the network.
            residual_scaling_factor (Optional[float]): Scaling factor for the residual skip connections.
        """
        self.n_features = n_features
        self.n_layers = n_layers
        self.seed = seed
        self.add_residual = add_residual
        self.sampling_method = sampling_method
        self.activation = activation
        self.w1 = None
        self.b1 = None
        self.weights = None
        self.biases = None
        self.residual_scaling_factor = 1/self.n_layers if residual_scaling_factor is None else residual_scaling_factor


    def fit(
            self, 
            X: Float[Array, "N  D"], 
            y: Float[Array, "N  p"]
        ):
        """
        Initializes MLP weights and biases, using SWIM algorithm.

        Args:
            X (Float[Array, "N  D"]): Input training data.
            y (Float[Array, "N  p"]): Target training data.
        """
        # Get shape, dtype
        seed1, seedrest = jax.random.split(self.seed, 2)

        #first do first layer, which cannot always be done in a scan loop
        self.w1, self.b1 = init_single_SWIM_layer(
            seed1, X, y, self.n_features, self.sampling_method
            )
        X = forward_1_layer(X, self.w1, self.b1, self.add_residual)
        
        #rest of the layers
        if self.n_layers > 1:
            self.weights, self.biases = SWIM_all_layers(
                seedrest, X, y, self.n_features, self.activation, self.n_layers-1, 
                self.add_residual, self.residual_scaling_factor, self.sampling_method
                )

        return self


    def transform(self, X: Float[Array, "N  D"], only_last=True) -> Float[Array, "N  n_features"]:
        #First hidden layer
        X = forward_1_layer(X, self.w1, self.b1, self.add_residual, self.activation, self.residual_scaling_factor)
        if self.n_layers == 1:
            return X
        #subsequent layers in a scan loop
        else:
            def scan_body(carry, inputs):
                X = carry
                w, b = inputs
                return forward_1_layer(X, w, b, self.add_residual, self.activation, self.residual_scaling_factor), X #TODO temporarily return all laters

            last_X, stacked_X = lax.scan(scan_body, X, xs=(self.weights, self.biases))
            if only_last:
                return last_X
            else:
                return jnp.concat([stacked_X, last_X[None]], axis=0)

In [11]:
def neuron_distribution_for_each_layer(
        X_train: Array,
        y_train: Array,
        X_test: Array,
        hidden_size: int,
        n_layers: int,
        ) -> Tuple[Array, Array]:
    """Looks at the distribution of neurons for each layer of a neural network model
    (used to compare SWIM, residual sampling, and random feature networks).
    """
    
    # Initialize the arrays to store the neuron distribution
    train_layers= []
    test_layers = []

    model = SWIM_MLP(
        jax.random.PRNGKey(0), 
        hidden_size, 
        n_layers,
        add_residual=True,
        sampling_method="gradient-weighted",
        #activation = jnp.tanh,
        #residual_scaling_factor=1.0,
        
        )

    model.fit(X_train, y_train)

    feat_test  = model.transform(X_train, only_last=False).reshape(n_layers, -1)
    feat_train = model.transform(X_test, only_last=False).reshape(n_layers, -1)
    
    print(feat_test[1]-feat_test[1])

    #features are shape (n_layers, n_samples, n_features)
    import matplotlib.pyplot as plt
    plt.figure(figsize=(10, 6))
    plt.hist(X_train.flatten(), bins=50, alpha=0.5, label='Train', density=True)
    plt.hist(X_test.flatten(), bins=50, alpha=0.5, label='Test', density=True)
    plt.title('Input Data Distribution')
    plt.xlabel('Input Feature Value')
    plt.ylabel('Probability Density')
    plt.legend()
    plt.show()

    for layer in range(n_layers):
        plt.figure(figsize=(10, 6))
        plt.hist(feat_train[layer], bins=50, alpha=0.5, label='Train', density=True)
        plt.hist(feat_test[layer], bins=50, alpha=0.5, label='Test', density=True)
        plt.title(f'Layer {layer + 1} Neuron Distribution')
        plt.xlabel('Neuron Activation')
        plt.ylabel('Probability Density')
        plt.legend()
        plt.show()

    print(feat_test.shape)

neuron_distribution_for_each_layer(X_train, y_train, X_test, 128, 100)

2024-10-18 16:51:30.330419: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error


RuntimeError: Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices. (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

I want to look at the distribution of weights (eigenvalues? absolute values of rows? distribution of (assuming iid) matrix entries?)

distribution of neurons at each layer

This is for both SWIM, Residual SWIM, random features, 