In [1]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable
import time
import collections
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

from tqdm import tqdm
import openml
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.functional import relu
from torch.nn.functional import tanh
import pandas as pd
import numpy as np

from aeon.regression.sklearn import RotationForestRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error
from aeon.datasets.tser_datasets import tser_soton
from aeon.datasets import load_regression, load_classification
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split

from preprocessing.stream_transforms import normalize_mean_std_traindata, normalize_streams, augment_time, add_basepoint_zero
from utils.utils import print_name, print_shape

np.set_printoptions(precision=3, threshold=5) # Print options



# OpenML code

In [2]:
# # Fetch the collection with ID 353
# collection = openml.study.get_suite(353)
# dataset_ids = collection.data
# metadata_list = []

# # Fetch and process each dataset
# for i, dataset_id in enumerate(dataset_ids):
#     dataset = openml.datasets.get_dataset(dataset_id)
#     X, y, categorical_indicator, attribute_names = dataset.get_data(
#         target=dataset.default_target_attribute
#     )
#     X = np.array(X)
#     y = np.array(y)[..., None]
    
#     # Extract the required metadata
#     metadata = {
#         'dataset_id': dataset.id,
#         'name': dataset.name,
#         'n_obs': int(dataset.qualities['NumberOfInstances']),
#         'n_features': int(dataset.qualities['NumberOfFeatures']),
#         '%_unique_y': len(np.unique(y))/len(y),
#         'n_unique_y': len(np.unique(y)),
#     }
    
#     metadata_list.append(metadata)
#     print(f" {i+1}/{len(dataset_ids)} Processed dataset {dataset.id}: {dataset.name}")

# # Create a DataFrame from the metadata list
# df_metadata = pd.DataFrame(metadata_list).sort_values('%_unique_y', ascending=False)
# df_metadata.sort_values('%_unique_y', ascending=True)

# Download single dataset

In [3]:
def load_openml_dataset(dataset_id, 
                        normalize_X:bool = True,
                        normalize_y:bool = True,
                        train_test_size:float = 0.7,
                        split_seed:int = 0) -> Tuple[np.ndarray, np.ndarray]:
    # Fetch dataset from OpenML by its ID
    dataset = openml.datasets.get_dataset(dataset_id)
    X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)
    X = np.array(X)
    y = np.array(y)[..., None]
    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_test_size, random_state=split_seed)

    #normalize
    if normalize_X:
        X_train, X_test = normalize_mean_std_traindata(X_train, X_test)
    if normalize_y:
        y_train, y_test = normalize_mean_std_traindata(y_train, y_test)

    return (Tensor(X_train.astype(np.float32)), 
            Tensor(X_test.astype(np.float32)), 
            Tensor(y_train.astype(np.float32)), 
            Tensor(y_test.astype(np.float32)))

#dataset_id = 44971  # Replace with the dataset ID you want
dataset_id = 44970
X_train, X_test, y_train, y_test = load_openml_dataset(dataset_id)


# nn.Module for sampled networks

In [4]:
#################################################################
##### Base classes                                          #####
##### - FittableModule: A nn.Module with .fit(X, y) support #####
##### - ResNetBase: which interatively calls .fit(X, y)     #####
#################################################################

class FittableModule(nn.Module):
    def __init__(self):
        super(FittableModule, self).__init__()
    

    def fit(self, 
            X: Optional[Tensor] = None, 
            y: Optional[Tensor] = None,
        ) -> Tuple[Optional[Tensor], Optional[Tensor]]:
        """Given neurons of the previous layer, and target labels, fit the 
        module. Returns the forwarded activations and labels [f(X), y].

        Args:
            X (Optional[Tensor]): Forward-propagated activations of training data, shape (N, d).
            y (Optional[Tensor]): Training labels, shape (N, p).
        
        Returns:
            Forwarded activations and labels [f(X), y].
        """
        return self(X), y



class ResNetBase(nn.Module):
    def __init__(self,
                upsample:FittableModule,
                blocks:List[FittableModule],
                output_layer:FittableModule,
                ):
        """Residual Network base class, with fit method for non-SGD training/initialization.

        Args:
            upsample (FittableModule): _description_
            blocks (List[FittableModule]): _description_
            output_layer (FittableModule): _description_
        """
        super(ResNetBase, self).__init__()
        self.upsample = upsample
        self.blocks = nn.ModuleList(blocks)
        self.output_layer = output_layer

    
    def fit(self, X:Tensor, y:Tensor):
        # X shape (N, d)
        # y shape (N, p)
        X, y = self.upsample.fit(X, y)
        for block in self.blocks:
            X, y = block.fit(X, y)
        X, y = self.output_layer.fit(X, y)
        return X, y

    
    def forward(self, x:Tensor) -> Tensor:
        # x shape (N, d)
        x = self.upsample(x)
        for block in self.blocks:
            x = block(x)
        x = self.output_layer(x)
        return x

In [5]:
##############################
######## Dense Layer ########
#############################


class Dense(FittableModule):
    def __init__(self,
                 generator: torch.Generator,
                 in_dim: int,
                 out_dim: int,
                 ):
        """Dense MLP layer with LeCun weight initialization,
        Gaussan bias initialization."""
        super(Dense, self).__init__()
        self.generator = generator
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dense = nn.Linear(in_dim, out_dim)
    
    def fit(self, X:Tensor, y:Tensor):
        with torch.no_grad():
            nn.init.normal_(self.dense.weight, mean=0, std=self.in_dim**-0.5, generator=self.generator)
            nn.init.normal_(self.dense.bias, mean=0, std=self.in_dim**-0.5, generator=self.generator)
            return self(X), y
    
    def forward(self, X):
        return self.dense(X)
    

class Identity(FittableModule):
    def __init__(self):
        super(Identity, self).__init__()
    
    def fit(self, X:Tensor, y:Tensor):
        return X, y
    
    def forward(self, X):
        return X


D = X_train.shape[1]
g1 = torch.Generator().manual_seed(0)
net = Dense(g1, D, 3)
net.fit(X_train, y_train)
out = net(X_test)
print_name(out)
print(net)

torch.Size([273, 3]) out torch.float32
tensor([[ 1.0063e+00,  6.0632e-01, -1.3344e+00],
        [-7.7373e-01,  2.5015e+00,  1.6009e+00],
        [ 9.0851e-02,  1.0703e+00, -8.2686e-01],
        [ 4.3493e-01,  1.8349e+00,  5.2403e-02],
        [ 1.3343e+00,  7.0151e-01,  5.1928e-01],
        [ 2.5770e+00,  4.4369e-01, -1.7169e+00],
        [-4.1187e-01,  3.9454e-01, -3.6120e-01],
        [-1.7308e-01,  1.0798e+00, -4.3126e-01],
        [ 7.6029e-01,  7.0146e-02, -1.6973e+00],
        [ 1.0815e+00,  2.5981e-03, -5.3142e-01],
        [-9.2615e-01,  4.2799e-03,  2.0159e+00],
        [ 1.1285e+00,  4.1362e-01,  1.4112e+00],
        [-1.6260e+00,  1.5503e+00,  1.2666e+00],
        [ 1.1246e+00,  6.6066e-01,  2.3194e+00],
        [ 6.8546e-01,  1.1508e+00, -6.2143e-01],
        [ 3.5718e-01,  3.8035e-01, -9.0341e-02],
        [-2.5088e+00,  2.6201e+00,  3.6691e+00],
        [ 1.2671e+00,  1.7701e+00, -9.2772e-01],
        [-1.2412e+00,  1.9267e+00,  3.5902e-01],
        [-1.4199e+00,  2.6209e

In [6]:
###############################
#### Pair Sampled Networks ####
###############################


class PairSampledLinear(FittableModule):
    def __init__(self,
                 generator: torch.Generator,
                 in_dim: int, 
                 out_dim: int,
                 sampling_method: Literal['uniform', 'gradient'] = 'gradient'
                 ):
        """Dense MLP layer with pair sampled weights.

        Args:
            generator (torch.Generator): PRNG object.
            in_dim (int): Input dimension.
            out_dim (int): Output dimension.
            sampling_method (str): Pair sampling method. Uniform or gradient-weighted.
        """
        super(PairSampledLinear, self).__init__()
        self.generator = generator
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.dense = nn.Linear(in_dim, out_dim)
        self.sampling_method = sampling_method


    def fit(self, 
            X: Tensor, 
            y: Tensor,
        ) -> Tuple[Tensor, Tensor]:
        """Given forward-propagated training data X at the previous 
        hidden layer, and supervised target labels y, fit the weights
        iteratively by letting rows of the weight matrix be given by
        pairs of samples from X. See paper for more details.

        Args:
            X (Tensor): Forward-propagated activations of training data, shape (N, d).
            y (Tensor): Training labels, shape (N, p).
        
        Returns:
            Forwarded activations and labels [f(X), y].
        """
        with torch.no_grad():
            N, d = X.shape
            dtype = X.dtype
            device = X.device
            EPS = torch.tensor(0.1, dtype=dtype, device=device)

            #obtain pair indices
            n = 5*N
            idx1 = torch.arange(0, n, dtype=torch.int32, device=device) % N
            delta = torch.randint(1, N, size=(n,), dtype=torch.int32, device=device, generator=self.generator)
            idx2 = (idx1 + delta) % N
            dx = X[idx2] - X[idx1]
            dists = torch.linalg.norm(dx, axis=1, keepdims=True)
            dists = torch.maximum(dists, EPS)
            
            if self.sampling_method=="gradient":
                #calculate 'gradients'
                dy = y[idx2] - y[idx1]
                y_norm = torch.linalg.norm(dy, axis=1, keepdims=True) #NOTE 2023 paper uses ord=inf instead of ord=2
                grad = (y_norm / dists).reshape(-1) 
                p = grad/grad.sum()
            elif self.sampling_method=="uniform":
                p = torch.ones(n, dtype=dtype, device=device) / n
            else:
                raise ValueError(f"sampling_method must be 'uniform' or 'gradient'. Given: {self.sampling_method}")

            #sample pairs
            selected_idx = torch.multinomial(
                p,
                self.out_dim,
                replacement=True,
                generator=self.generator
                )
            idx1 = idx1[selected_idx]
            dx = dx[selected_idx]
            dists = dists[selected_idx]

            #define weights and biases
            weights = dx / (dists**2)
            biases = -torch.einsum('ij,ij->i', weights, X[idx1]) - 0.5
            self.dense.weight.data = weights
            self.dense.bias.data = biases
            return self(X), y
    

    def forward(self, X):
        return self.dense(X)
    
    
D = X_train.shape[1]
g1 = torch.Generator().manual_seed(0)
net = PairSampledLinear(g1, D, 3)
net.fit(X_train, y_train)
out = net(X_test)
print_name(out)
print(net)

torch.Size([273, 3]) out torch.float32
tensor([[ 3.3456e-01, -2.2509e-01,  1.5428e-01],
        [-1.5937e-01, -2.1131e-01,  5.6211e-01],
        [ 3.1890e-01, -3.4543e-01,  4.5053e-01],
        [-6.3202e-02, -2.9029e-01,  1.0104e+00],
        [ 1.4667e-01,  3.2821e-01,  3.7209e-01],
        [ 4.2502e-01,  5.4385e-01,  6.1908e-01],
        [ 5.0000e-01, -3.7438e-01, -1.7110e-01],
        [ 3.0951e-01, -3.5543e-01,  3.0490e-01],
        [ 5.6117e-01, -1.8869e-01,  2.8079e-01],
        [ 3.5847e-01,  1.4138e-02,  3.2554e-01],
        [ 1.2324e-01, -5.1429e-01, -4.1025e-01],
        [-1.8248e-01, -8.1484e-02,  6.2416e-01],
        [ 1.7651e-01, -5.7704e-01, -1.1708e-01],
        [-7.4602e-01, -5.1382e-01,  1.9618e+00],
        [ 2.7526e-01, -1.4222e-01,  2.0107e-01],
        [ 4.2758e-01,  3.0626e-02, -7.8444e-02],
        [-4.3642e-01, -9.2550e-01,  3.6025e-01],
        [ 1.0645e-01,  6.3015e-02,  9.9766e-01],
        [ 1.4542e-01, -6.9116e-01,  3.0739e-01],
        [ 5.3052e-01, -5.9615e

In [7]:
###################################
#### Sampled Bottleneck ResNet ####
###################################


class SampledResBlock(FittableModule):
    def __init__(self,
                 generator: torch.Generator,
                 hidden_dim: int, 
                 activation_dim: int,
                 activation: nn.Module = nn.Tanh(),
                 sampling_method: Literal['uniform', 'gradient'] = 'gradient'
                 ):
        """A sampled layer followed by activation and linear layer.
        Equivalent to a 1-hidden-layer Sampled Neural Network.

        Args:
            generator (torch.Generator): PRNG object.
            in_dim (int): Input dimension.
            out_dim (int): Output dimension.
            activation (nn.Module): Activation function.
            sampling_method (str): Pair sampling method. Uniform or gradient-weighted.
        """
        super(SampledResBlock, self).__init__()
        self.generator = generator
        self.sampled_linear = PairSampledLinear(generator, hidden_dim, activation_dim, sampling_method)
        self.activation = activation
        self.upscale = Dense(generator, activation_dim, hidden_dim)
    

    def fit(self, X: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
        with torch.no_grad():
            X0 = X
            X, y = self.sampled_linear.fit(X, y)
            X = self.activation(X)
            X, y = self.upscale.fit(X, y)
            return X0 + X, y

    
    def forward(self, X):
        X0 = X
        X = self.sampled_linear(X)
        X = self.activation(X)
        X = self.upscale(X)
        return X0 + X
    
    
D = X_train.shape[1]
g1 = torch.Generator().manual_seed(0)
net = SampledResBlock(g1, D, 3)
net.fit(X_train, y_train)
out = net(X_test)
print_name(out)
print(net)

torch.Size([273, 6]) out torch.float32
tensor([[-0.0473, -0.7883,  0.1466,  3.5619, -1.3800, -0.7621],
        [ 1.3652,  1.8986,  0.6492,  0.0685, -1.1746, -1.7752],
        [-1.1149, -0.4199, -0.5711,  0.0608, -1.3669, -0.4968],
        ...,
        [ 0.4023, -3.2661,  0.1615,  0.3883, -1.4053,  0.8282],
        [-1.3197,  0.0145, -0.5040,  0.0637, -1.2957,  0.0842],
        [-0.7870, -1.4006,  1.4091,  0.2247, -1.4208,  0.0865]],
       grad_fn=<AddBackward0>) 

SampledResBlock(
  (sampled_linear): PairSampledLinear(
    (dense): Linear(in_features=6, out_features=3, bias=True)
  )
  (activation): Tanh()
  (upscale): Dense(
    (dense): Linear(in_features=3, out_features=6, bias=True)
  )
)


In [8]:
#####################
### RidgeCV Layer ###
#####################

class RidgeCVModule(FittableModule):
    def __init__(self, alphas=np.logspace(-1, 3, 10)):
        """RidgeCV layer using sklearn's RidgeCV. TODO dont use sklearn"""
        super(RidgeCVModule, self).__init__()
        self.ridge = RidgeCV(alphas=alphas)

    def fit(self, X: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
        """Fit the RidgeCV model. TODO dont use sklearn"""
        X_np = X.detach().cpu().numpy().astype(np.float64)
        y_np = y.detach().cpu().squeeze().numpy().astype(np.float64)
        self.ridge.fit(X_np, y_np)
        return self(X), y

    def forward(self, X: Tensor) -> Tensor:
        """Forward pass through the RidgeCV model. TODO dont use sklearn"""
        X_np = X.detach().cpu().numpy().astype(np.float64)
        y_pred_np = self.ridge.predict(X_np)
        return torch.tensor(y_pred_np, dtype=X.dtype, device=X.device).unsqueeze(1) #TODO unsqueeze???


D = X_train.shape[1]
g1 = torch.Generator()
net = RidgeCVModule()
out_train, _ = net.fit(X_train, y_train)
out = net(X_test)
print(net)

print("rmse test", mean_squared_error(y_test.detach().cpu().numpy(), out.detach().cpu().numpy()))
print("rmse train", mean_squared_error(y_train.detach().cpu().numpy(), out_train.detach().cpu().numpy()))

RidgeCVModule()
rmse test 0.45788816
rmse train 0.38987496


In [9]:
class SampledResNet(ResNetBase):
    def __init__(self,
                 generator: torch.Generator,
                 in_dim: int,
                 hidden_dim: int,
                 activation_dim: int, #rename to bottleneck dim?
                 n_blocks: int,
                 activation: nn.Module = nn.Tanh(),
                 upsample_module: Literal['dense', 'sampled', 'identity'] = 'dense',
                 sampling_method: Literal['uniform', 'gradient'] = 'gradient'
                 ):
        """A ResNet with sampled layers as bottleneck layers.
        """
        if upsample_module=="dense":
            upsample = Dense(generator, in_dim, hidden_dim)
        elif upsample_module=="sampled":
            upsample = PairSampledLinear(generator, in_dim, hidden_dim, sampling_method)
        elif upsample_module=="identity":
            upsample = Identity()
        else:
            raise ValueError(f"upsample_module must be 'dense', 'sampled' or 'identity'. Given: {upsample_module}")

        blocks = [SampledResBlock(generator, 
                                hidden_dim, 
                                activation_dim,
                                activation,
                                sampling_method
                                ) for _ in range(n_blocks)]
        ridge = RidgeCVModule()
        super(SampledResNet, self).__init__(upsample, blocks, ridge)


D = X_train.shape[1]
g1 = torch.Generator().manual_seed(int(time.time()*10))
net = SampledResNet(g1, D, 100*D, 100*D, 6, upsample_module='sampled', sampling_method='uniform')
out_train, _ = net.fit(X_train, y_train)
out = net(X_test)
print(net)

print("rmse test", mean_squared_error(y_test.detach().cpu().numpy(), out.detach().cpu().numpy()))
print("rmse train", mean_squared_error(y_train.detach().cpu().numpy(), out_train.detach().cpu().numpy()))
print(net.output_layer.ridge.alpha_)

SampledResNet(
  (upsample): PairSampledLinear(
    (dense): Linear(in_features=6, out_features=600, bias=True)
  )
  (blocks): ModuleList(
    (0-5): 6 x SampledResBlock(
      (sampled_linear): PairSampledLinear(
        (dense): Linear(in_features=600, out_features=600, bias=True)
      )
      (activation): Tanh()
      (upscale): Dense(
        (dense): Linear(in_features=600, out_features=600, bias=True)
      )
    )
  )
  (output_layer): RidgeCVModule()
)
rmse test 0.37741318
rmse train 0.28207216
16.68100537200059


In [12]:
class SampledAndActivation(FittableModule):
    def __init__(self,
                 generator: torch.Generator,
                 in_dim: int,
                 out_dim: int, 
                 activation: nn.Module = nn.Tanh(),
                 sampling_method: Literal['uniform', 'gradient'] = 'gradient'
                 ):
        """TODO
        """
        super(SampledAndActivation, self).__init__()
        self.generator = generator
        self.sampled_linear = PairSampledLinear(generator, in_dim, out_dim, sampling_method)
        self.activation = activation
    

    def fit(self, X: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
        with torch.no_grad():
            X, y = self.sampled_linear.fit(X, y)
            X = self.activation(X)
            return X, y

    
    def forward(self, X):
        X = self.sampled_linear(X)
        X = self.activation(X)
        return X


class SampledODEBlock(FittableModule):
    def __init__(self,
                 generator: torch.Generator,
                 hidden_dim: int, 
                 activation: nn.Module = nn.Tanh(),
                 sampling_method: Literal['uniform', 'gradient'] = 'gradient'
                 ):
        """TODO

        Args:
            generator (torch.Generator): PRNG object.
            hidden_dim (int): Hidden size.
            activation (nn.Module): Activation function.
            sampling_method (str): Pair sampling method. Uniform or gradient-weighted.
        """
        super(SampledODEBlock, self).__init__()
        self.generator = generator
        self.sampled_linear = PairSampledLinear(generator, hidden_dim, hidden_dim, sampling_method)
        self.activation = activation
    

    def fit(self, X: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]:
        with torch.no_grad():
            X0 = X
            X, y = self.sampled_linear.fit(X, y)
            X = self.activation(X)
            return X0 + X, y

    
    def forward(self, X):
        X0 = X
        X = self.sampled_linear(X)
        X = self.activation(X)
        return X0 + X
    
    
D = X_train.shape[1]
g1 = torch.Generator().manual_seed(0)
net = SampledResBlock(g1, D, 3)
net.fit(X_train, y_train)
out = net(X_test)
print_name(out)
print(net)

torch.Size([273, 6]) out torch.float32
tensor([[-0.0473, -0.7883,  0.1466,  3.5619, -1.3800, -0.7621],
        [ 1.3652,  1.8986,  0.6492,  0.0685, -1.1746, -1.7752],
        [-1.1149, -0.4199, -0.5711,  0.0608, -1.3669, -0.4968],
        ...,
        [ 0.4023, -3.2661,  0.1615,  0.3883, -1.4053,  0.8282],
        [-1.3197,  0.0145, -0.5040,  0.0637, -1.2957,  0.0842],
        [-0.7870, -1.4006,  1.4091,  0.2247, -1.4208,  0.0865]],
       grad_fn=<AddBackward0>) 

SampledResBlock(
  (sampled_linear): PairSampledLinear(
    (dense): Linear(in_features=6, out_features=3, bias=True)
  )
  (activation): Tanh()
  (upscale): Dense(
    (dense): Linear(in_features=3, out_features=6, bias=True)
  )
)


In [23]:
class SampledEulerODE(ResNetBase):
    def __init__(self,
                 generator: torch.Generator,
                 in_dim: int,
                 hidden_dim: int,
                 n_blocks: int,
                 activation: nn.Module = nn.Tanh(),
                 upsample_module: Literal['dense', 'sampled', 'identity'] = 'dense',
                 sampling_method: Literal['uniform', 'gradient'] = 'gradient'
                 ):
        """A ResNet with sampled layers as bottleneck layers.
        """
        if upsample_module=="dense":
            upsample = Dense(generator, in_dim, hidden_dim)
        elif upsample_module=="sampled":
            upsample = SampledAndActivation(generator, in_dim, hidden_dim, activation, sampling_method)
        elif upsample_module=="identity":
            upsample = Identity()
        else:
            raise ValueError(f"upsample_module must be 'dense', 'sampled' or 'identity'. Given: {upsample_module}")
        
        blocks = [SampledODEBlock(generator,
                                hidden_dim,
                                activation,
                                sampling_method
                                ) for _ in range(n_blocks)]
        ridge = RidgeCVModule()
        super(SampledEulerODE, self).__init__(upsample, blocks, ridge)


D = X_train.shape[1]
g1 = torch.Generator().manual_seed(int(time.time()*10))
net = SampledEulerODE(g1, D, 100*D, 6, upsample_module='sampled', sampling_method='gradient')
out_train, _ = net.fit(X_train, y_train)
out = net(X_test)
print(net)

print("rmse test", mean_squared_error(y_test.detach().cpu().numpy(), out.detach().cpu().numpy()))
print("rmse train", mean_squared_error(y_train.detach().cpu().numpy(), out_train.detach().cpu().numpy()))
print(net.output_layer.ridge.alpha_)

SampledEulerODE(
  (upsample): SampledAndActivation(
    (sampled_linear): PairSampledLinear(
      (dense): Linear(in_features=6, out_features=600, bias=True)
    )
    (activation): Tanh()
  )
  (blocks): ModuleList(
    (0-5): 6 x SampledODEBlock(
      (sampled_linear): PairSampledLinear(
        (dense): Linear(in_features=600, out_features=600, bias=True)
      )
      (activation): Tanh()
    )
  )
  (output_layer): RidgeCVModule()
)
rmse test 0.3939432
rmse train 0.30253765
46.41588833612777


# SWIM tabular model

In [None]:
def neuron_distribution_for_each_layer(
        X_train: Array,
        y_train: Array,
        X_test: Array,
        hidden_size: int,
        n_layers: int,
        ) -> Tuple[Array, Array]:
    """Looks at the distribution of neurons for each layer of a neural network model
    (used to compare SWIM, residual sampling, and random feature networks).
    """
    
    # Initialize the arrays to store the neuron distribution
    train_layers= []
    test_layers = []

    model = SWIM_MLP(
        jax.random.PRNGKey(0), 
        hidden_size, 
        n_layers,
        add_residual=True,
        sampling_method="gradient-weighted",
        #activation = jnp.tanh,
        #residual_scaling_factor=1.0,
        
        )

    model.fit(X_train, y_train)

    feat_test  = model.transform(X_train, only_last=False).reshape(n_layers, -1)
    feat_train = model.transform(X_test, only_last=False).reshape(n_layers, -1)
    
    print(feat_test[1]-feat_test[1])

    #features are shape (n_layers, n_samples, n_features)
    import matplotlib.pyplot as plt
    plt.figure(figsize=(10, 6))
    plt.hist(X_train.flatten(), bins=50, alpha=0.5, label='Train', density=True)
    plt.hist(X_test.flatten(), bins=50, alpha=0.5, label='Test', density=True)
    plt.title('Input Data Distribution')
    plt.xlabel('Input Feature Value')
    plt.ylabel('Probability Density')
    plt.legend()
    plt.show()

    for layer in range(n_layers):
        plt.figure(figsize=(10, 6))
        plt.hist(feat_train[layer], bins=50, alpha=0.5, label='Train', density=True)
        plt.hist(feat_test[layer], bins=50, alpha=0.5, label='Test', density=True)
        plt.title(f'Layer {layer + 1} Neuron Distribution')
        plt.xlabel('Neuron Activation')
        plt.ylabel('Probability Density')
        plt.legend()
        plt.show()

    print(feat_test.shape)

neuron_distribution_for_each_layer(X_train, y_train, X_test, 128, 100)

I want to look at the distribution of weights (eigenvalues? absolute values of rows? distribution of (assuming iid) matrix entries?)

distribution of neurons at each layer

This is for both SWIM, Residual SWIM, random features, 