In [1]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

import numpy as np
import jax
import jax.numpy as jnp
import jax.lax as lax
from jaxtyping import Array, Float, Int, PRNGKeyArray
import aeon
import torch
import torch.nn as nn
import torch.functional as F

# from features.sig_trp import SigVanillaTensorizedRandProj, SigRBFTensorizedRandProj
# from features.sig import SigTransform, LogSigTransform
# from features.base import TimeseriesFeatureTransformer, TabularTimeseriesFeatures, RandomNoInformation
# from features.sig_neural import RandomizedSignature
from utils.utils import print_name, print_shape

jax.config.update('jax_platform_name', 'cpu') # Used to set the platform (cpu, gpu, etc.)
np.set_printoptions(precision=3, threshold=5) # Print options

In [2]:
# from aeon.transformations.collection.convolution_based import Rocket, MultiRocket
# import numpy as np

# rocket = Rocket(num_kernels=10000, random_state=0)
# N=700
# D=3
# T=100
# X = np.random.randn(N, D, T)
# rocket.fit(X)

In [3]:
# weights, lengths, biases, dilations, paddings, num_channel_indices, channel_indices = rocket.kernels

# print_name(weights)
# print_name(lengths)
# print_name(biases)
# print_name(dilations)
# print_name(paddings)
# print_name(num_channel_indices)
# print_name(channel_indices)

# # 2 options: 
# # either use kernels of shape (D, 9), 
# # or sample random channels and use kernels of shape (D, 1)

(152172,) weights
[ 1.387  0.123 -0.022 ... -1.093  0.271  0.541] 

(10000,) lengths
[7 9 7 ... 9 7 7] 

(10000,) biases
[-0.497 -0.587 -0.29  ...  0.105  0.049 -0.999] 

(10000,) dilations
[ 1  1  9 ...  1  2 14] 

(10000,) paddings
[3 0 0 ... 4 0 0] 

(10000,) num_channel_indices
[3 1 1 ... 2 3 1] 

(16972,) channel_indices
[2 1 0 ... 0 2 2] 



# Create my own here

In [4]:
T = 113
input_length = T
D = 3
N= 2
n_features = 10_000
kernel_length = 8


X = torch.randn( (N, D, T) )


conv = nn.Conv1d(
    in_channels=D, 
    out_channels=n_features, 
    kernel_size=kernel_length,
    dilation = 16,
    padding = 0,
    bias=False)
output = conv(X)
print_shape(X)
print_shape(output)


max_exponent = np.floor(np.log2((input_length - 1) / (kernel_length- 1))).astype(np.int64)
dilations = 2**np.arange(max_exponent + 1)
print_name(max_exponent)
print_name(dilations)

torch.Size([2, 3, 113]) X 

torch.Size([2, 10000, 1]) output 

() max_exponent
4 

(5,) dilations
[ 1  2  4  8 16] 



In [91]:
from torch import Tensor


def apply_chunked(fn: Callable, X: Tensor, chunk_size: int=1000):
    """Applies a function to a tensor in chunks.

    Args:
        fn (Callable): Function to apply.
        X (Tensor): Input tensor of shape (N, ...).
        chunk_size (int): Chunk size. Defaults to 1000.

    Returns:
        Tensor: Output tensor.
    """
    batches = torch.split(X, chunk_size)
    output = [fn(batch) for batch in batches]
    return torch.cat(output, dim=0)


class MultiRocket(nn.Module):
    def __init__(self, D, T, n_features, kernel_length=9, seed=0):
        super(MultiRocket, self).__init__()

        max_exponent = np.floor(np.log2((T - 1) / (kernel_length- 1))).astype(np.int64)
        dilations = 2**np.arange(max_exponent + 1)
        n_kernels_per_dilation = n_features // dilations[-1] *2 #// 4

        self.convs = nn.ModuleList(
            [nn.Conv1d(
                in_channels=D, 
                out_channels=n_kernels_per_dilation, 
                kernel_size=kernel_length,
                dilation = dilation,
                padding = "same",
                bias=True) 
             for dilation in dilations]
        )

    
    def init_biases(self, X: Tensor, chunk_size: int=1000):
        """Initializes the biases of the convolutional layers,
        using the quantiles of the data. Assumes the data to
        be shuffled.

        WARNING: Slow even for 10 points, 10 000 kernels, T=113. (1.3s elapsed)

        Args:
            X (Tensor): Shape (N, D, T).
            chunk_size (int): Batch size for computations
        """
        with torch.no_grad():
            # first set the biases to zero
            for conv in self.convs:
                conv.bias.data.zero_()

            #obtain output
            out_per_conv = [apply_chunked(conv, X, chunk_size) for conv in self.convs]

            #initalize bias using random quantiles
            for out, conv in zip(out_per_conv, self.convs):
                #out: (N, n_kernels_per_dilation, T)
                n_ker_per_dil = out.shape[1]
                quantiles = 0.8 * torch.rand(n_ker_per_dil) + 0.1
                q = torch.quantile(out.permute(0,2,1).reshape(-1, n_ker_per_dil), quantiles, dim=0)
                conv.bias.data = torch.diag(q)

        return self

    
    def forward(self, x):
        # x: (N, D, T)
        x = [conv(x) for conv in self.convs]
        x = torch.cat(x, dim=1)
        x = torch.mean((x>0), dim=-1, dtype=x.dtype)
        #TODO add rest ...
        return x
    

    
T = 113
input_length = T
D = 3
N= 100
n_features = 10_000
kernel_length = 8


X = torch.randn( (N, D, T) )


rocket = MultiRocket(
    D, T, n_features
)
output = rocket(X)

rocket.init_biases(X[0:10])

MultiRocket(
  (convs): ModuleList(
    (0): Conv1d(3, 2500, kernel_size=(9,), stride=(1,), padding=same)
    (1): Conv1d(3, 2500, kernel_size=(9,), stride=(1,), padding=same, dilation=(2,))
    (2): Conv1d(3, 2500, kernel_size=(9,), stride=(1,), padding=same, dilation=(4,))
    (3): Conv1d(3, 2500, kernel_size=(9,), stride=(1,), padding=same, dilation=(8,))
  )
)

In [43]:
out = torch.randn( (5, 2500, 113) )

quantiles = torch.rand(2500)
print_shape(quantiles)
print_shape(out)
quantiles = torch.quantile(out.permute(0,2,1).reshape(-1, 2500), quantiles, dim=0)
print_name(quantiles)

torch.Size([2500]) quantiles 

torch.Size([5, 2500, 113]) out 

torch.Size([2500, 2500]) quantiles
tensor([[-0.8175, -0.7173, -0.6654,  ..., -0.7986, -0.6466, -0.8051],
        [-0.9599, -0.9850, -0.7997,  ..., -1.0100, -0.8602, -0.9626],
        [-0.4113, -0.2979, -0.3206,  ..., -0.3810, -0.3319, -0.3042],
        ...,
        [-1.8916, -1.9357, -1.7681,  ..., -1.8416, -1.7164, -1.9571],
        [-0.6908, -0.5830, -0.5854,  ..., -0.6887, -0.5153, -0.6877],
        [ 1.7925,  1.8698,  2.0595,  ...,  1.7197,  1.7675,  1.8403]]) 



### What do i need to do?

( x | x | x | x | x | x )