
Porting WEASEL 2.0 to pyts #155

Open · aglenis opened this issue Jul 5, 2023 · 3 comments

aglenis commented Jul 5, 2023

Description

I am trying to port WEASEL 2.0 (https://github.com/patrickzib/dictionary) to pyts, but I have not been able to match the accuracy results reported for WEASEL 2.0.

I don't know if it's a bug in my code or simply bad hyperparameter choices on my part.

Steps/Code to Reproduce

My code for WEASEL 2.0:

from numba import njit

from pyts.transformation import WEASEL
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

import numpy as np
from sklearn.utils import check_random_state

from pyts.datasets import *

import time



def _dilation(X, d, first_difference, do_padding=False, padding_size=10):
    if do_padding:
        padding = np.zeros((len(X), padding_size))
        X = np.concatenate((padding, X, padding), axis=1)

    # using only first order differences
    if first_difference:
        X = np.diff(X, axis=1, prepend=0)

    # adding dilation
    X_dilated = _dilation2(X, d)
    X_index = _dilation2(np.arange(X_dilated.shape[-1], dtype=np.float_)
                         .reshape(1, -1), d)[0]

    return (
        X_dilated,
        X_index,
    )


@njit(cache=True, fastmath=True)
def _dilation2(X, d):
    # dilation on actual data
    if d > 1:
        start = 0
        data = np.zeros(X.shape, dtype=np.float_)
        for i in range(0, d):
            curr = X[:, i::d]
            end = curr.shape[1]
            data[:, start : start + end] = curr
            start += end
        return data
    else:
        return X.astype(np.float_)


def doWEASEL_dilation(X_train, X_test, y_train, y_test, clf,
                      window_sizes_arg=[0.1, 0.3, 0.5, 0.7, 0.9],
                      window_steps_arg=None, word_size_arg=7, n_bins_arg=4,
                      anova_arg=False, drop_sum_arg=False, norm_mean_arg=False, norm_std_arg=False,
                      strategy_arg='uniform', alphabet_arg=None, do_chi2_arg=False,
                      chi2_threshold=2, random2=True):

    rng = check_random_state(None)

    weasels_list = []
    weasels_trend_list = []

    dilation_train_array_list = []
    dilation_test_array_list = []

    x_dilation_list = []
    x_trend_dilation_list = []

    X_train_trend = np.diff(X_train)
    X_test_trend = np.diff(X_test)
    start_fit = time.time()
    for curr_window, curr_step in zip(window_sizes_arg, window_steps_arg):

        series_length = X_train.shape[1]
        # draw a dilation factor on a log-uniform scale so the receptive field fits the series
        curr_dilation = np.maximum(
            1,
            np.int32(2 ** rng.uniform(0, np.log2((series_length - 1) / (curr_window - 1)))))

        x_dilation_list.append(curr_dilation)
        #print(str(curr_window)+','+str(curr_step)+','+str(curr_dilation))
        X_dilated,X_index = _dilation(X_train, curr_dilation, first_difference=False,do_padding = False,padding_size =10)
        weasel_curr_dilation = WEASEL(sparse=False,strategy=strategy_arg,window_sizes=[curr_window],window_steps=[curr_step],
        word_size=word_size_arg,n_bins=n_bins_arg,anova=anova_arg,alphabet=alphabet_arg,
        drop_sum=drop_sum_arg, norm_mean=norm_mean_arg, norm_std=norm_std_arg)
        X_train_array_weasel = weasel_curr_dilation.fit_transform(X_dilated, y_train)

        weasels_list.append(weasel_curr_dilation)

        series_length_trend = X_train_trend.shape[1]
        # same log-uniform dilation sampling for the first-difference (trend) series
        curr_dilation_trend = np.maximum(
            1,
            np.int32(2 ** rng.uniform(0, np.log2((series_length_trend - 1) / (curr_window - 1)))))

        x_trend_dilation_list.append(curr_dilation_trend)

        #print(str(curr_window)+','+str(curr_step)+','+str(curr_dilation_trend))

        X_dilated_trend,X_index = _dilation(X_train_trend, curr_dilation_trend, first_difference=False,do_padding = False,padding_size =10)

        weasel_curr_dilation_trend = WEASEL(sparse=False,strategy=strategy_arg,window_sizes=[curr_window],window_steps=[curr_step],
        word_size=word_size_arg,n_bins=n_bins_arg,anova=anova_arg,alphabet=alphabet_arg,
        drop_sum=drop_sum_arg, norm_mean=norm_mean_arg, norm_std=norm_std_arg)
        X_train_array_weasel_trend = weasel_curr_dilation_trend.fit_transform(X_dilated_trend, y_train)

        weasels_trend_list.append(weasel_curr_dilation_trend)

        merged_arrays = np.hstack([X_train_array_weasel,X_train_array_weasel_trend])
        dilation_train_array_list.append(merged_arrays)

    final_array_train=np.hstack(dilation_train_array_list)
    clf.fit(final_array_train,y_train)
    end_fit = time.time()

    start_predict = time.time()
    i=0
    #print('starting predict')
    for curr_window, curr_step in zip(window_sizes_arg, window_steps_arg):
        curr_dilation = x_dilation_list[i]

        #print(str(curr_window)+','+str(curr_step)+','+str(curr_dilation))
        X_dilated_test,X_index = _dilation(X_test, curr_dilation, first_difference=False,do_padding = False,padding_size =10)
        weasel_curr_dilation = weasels_list[i]
        X_test_array = weasel_curr_dilation.transform(X_dilated_test)

        curr_dilation_trend = x_trend_dilation_list[i]
        #print(str(curr_window)+','+str(curr_step)+','+str(curr_dilation_trend))
        X_test_trend_dilated,X_index_trend = _dilation(X_test_trend, curr_dilation_trend, first_difference=False,do_padding = False,padding_size =10)
        weasel_curr_dilation_trend_test = weasels_trend_list[i]
        #print(weasel_curr_dilation_trend_test)
        X_test_array_trend =weasel_curr_dilation_trend_test.transform(X_test_trend_dilated)

        merged_arrays_test = np.hstack([X_test_array,X_test_array_trend])
        dilation_test_array_list.append(merged_arrays_test)
        i+=1
    final_array_test = np.hstack(dilation_test_array_list)


    predicted = clf.predict(final_array_test)
    end_predict = time.time()

    total_fit = end_fit-start_fit
    total_predict = end_predict-start_predict
    #print(predicted)
    #print('y_test')
    #print(y_test)
    curr_score = accuracy_score(y_test, predicted)
    return total_fit,total_predict,curr_score

if __name__ == '__main__':
    from pyts.datasets import load_gunpoint

    names_list = ['Crop','FordB','FordA','NonInvasiveFetalECGThorax2','NonInvasiveFetalECGThorax1','PhalangesOutlinesCorrect','HandOutlines','TwoPatterns']

    num_trees = 64

    name = 'WEASEL12-dilation6-bigram'
    clf1 = RandomForestClassifier(n_estimators=num_trees)

    total_accuracy = 0.0
    total_time = 0.0
    #X_train,X_test, y_train, y_test = load_gunpoint(return_X_y=True)
    for curr_dataset in names_list:
        (X_train, X_test, y_train, y_test)=fetch_ucr_dataset(curr_dataset, use_cache=True, data_home='/Users/aglenis/ucr_datasets/',
                                                                             return_X_y=True)

        max_window_size=256
        n_timestamps = X_train.shape[1]
        curr_max_window = min(n_timestamps,max_window_size)

        window_inc = 24
        if n_timestamps < 100:
            window_inc = 16

        min_window = 8
        window_step = 4
        #import math
        #count = math.sqrt(curr_max_window)
        #window_inc = int((curr_max_window-min_window)/count)
        window_sizes = list(range(min_window, curr_max_window, window_inc))
        #print(window_sizes)
        window_step_list = [window_step]*len(window_sizes)

        total_fit,total_predict,curr_score = doWEASEL_dilation(X_train,X_test,y_train,y_test,clf=clf1,window_sizes_arg=window_sizes,window_steps_arg=window_step_list)
        print(name+','+curr_dataset+','+str(curr_score)+','+str(total_fit)+','+str(total_predict)+','+str(total_fit+total_predict)+','+str(window_sizes)+','+str(window_step_list))
        total_time+= (total_fit+total_predict)
        total_accuracy+=curr_score

    print('Average Accuracy : ' + str(total_accuracy / len(names_list)))
    print('Average Execution time : ' + str(total_time / len(names_list)))

Versions

NumPy 1.21.6
SciPy 1.5.4
Scikit-Learn 1.0.2
Numba 0.55.2
Pyts 0.12.0

patrickzib commented Jul 11, 2023

Thanks for taking the effort!

A couple of comments:

  1. n_bins_arg=4
     In WEASEL 2.0 the alphabet size is fixed to 2.

     See: Alphabet Size

  2. window_sizes_arg=[0.1, 0.3, 0.5, 0.7, 0.9]
     I suppose this means 0.9 times the length of the time series? Please go for much smaller numbers. I use, e.g., win_size in np.arange(4, 44) or win_size in np.arange(4, 24). In combination with dilation, this already adds up to very large receptive fields.

     See: Window sizes

  3. first_difference=False
     I do not see first_differences being used anywhere? Please randomly choose first_differences, too.

     See: Ensemble

  4. The number of parameter configurations should be between 50 and 100, each choosing from the range of window_sizes, first_differences, and dilation factors (a sketch of such a sampler follows this comment).

     See: Ensemble

  5. WEASEL 2.0 has a novel feature selection strategy based on variance.

     See: SFA with Variance

  6. strategy_arg='uniform'
     Not sure what uniform refers to? I am randomly choosing between equi-width and equi-depth.

     See: Binning Strategy (https://github.com/patrickzib/dictionary/blob/63633eeaa52680f3a1eb016ec95ea0ca2c5430b9/weasel/classification/dictionary_based/_weasel_v2.py#L125)

Hope this helps. IMO the most critical parts are alphabet_size, window size, first differences, and variance-based feature selection in SFA.
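
To make point 4 concrete, here is a minimal sketch (my own illustration, not code from the WEASEL 2.0 repository) of how a random configuration ensemble could be drawn. The function name sample_configs and the parameter n_configs are hypothetical; the ranges follow the points above (small windows in np.arange(4, 24), alphabet size fixed to 2, random first differences, random binning strategy, log-uniform dilation).

import numpy as np


def sample_configs(n_timestamps, n_configs=50, rng=None):
    """Draw random WEASEL 2.0-style parameter configurations (illustrative only)."""
    rng = np.random.default_rng(rng)
    configs = []
    for _ in range(n_configs):
        window_size = int(rng.integers(4, 24))           # point 2: small windows
        first_difference = bool(rng.integers(0, 2))      # point 3: random differences
        # log-uniform dilation so that the receptive field stays within the series
        max_exp = np.log2((n_timestamps - 1) / (window_size - 1))
        dilation = max(1, int(2 ** rng.uniform(0, max_exp)))
        binning = str(rng.choice(['equi-width', 'equi-depth']))  # point 6
        configs.append({
            'window_size': window_size,
            'first_difference': first_difference,
            'dilation': dilation,
            'binning': binning,
            'alphabet_size': 2,                           # point 1: fixed to 2
        })
    return configs

Each configuration would then drive one transform whose features are concatenated before fitting the final classifier.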

johannfaouzi (Owner) commented

Hi,

Sorry for the delayed response, I saw the notification and forgot about it...

First, thanks @aglenis for the effort and thanks @patrickzib for the feedback! I will need to look at the paper and the source code to provide a more detailed answer, but I will address a few points first.

  • Performing dilation just to get the indices sounds suboptimal to me. You can get the indices with a closed formula (see the sketch after this list).
  • The default window sizes seem to come from my implementation of WEASEL in pyts (I don't remember the default values in the original implementation of WEASEL, but in general I prefer relative values over absolute values for hyper-parameters).
  • The first difference seems to be used with X_train_trend and X_test_trend.
  • The strategy argument has different values in pyts: uniform stands for equi-width (the bins all have the same width), while quantile stands for equi-depth (the same number of values fall in each bin).
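
As a minimal sketch of the closed-formula point above (my own illustration, not code from either library): the index permutation that _dilation2 produces on np.arange(n) can be obtained directly with a stable argsort on i % d, without dilating a dummy array.

import numpy as np


def dilation_indices(n, d):
    # Order positions by their residue modulo d, keeping the original order
    # within each residue class -- the same permutation that _dilation2
    # applies to np.arange(n).
    return np.argsort(np.arange(n) % d, kind='stable')


# quick check against the dummy-array approach used in the code above
n, d = 10, 3
assert np.array_equal(dilation_indices(n, d),
                      np.concatenate([np.arange(i, n, d) for i in range(d)]))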

In general, I like having more hyper-parameters (even if their values are fixed in the original paper), because it can be useful to change them for other datasets (many people have their own datasets and don't work on the UCR/UEA archive), but I try to keep the default values as close as possible to the ones in the original publication.

I'm very interested in adding WEASEL 2.0 to pyts, so I will look further into your code and also start working on this on my own, and we'll see what we get!

TonyBagnall commented

FYI, WEASEL 2.0 is in aeon and we have run it to verify the results:
https://github.com/aeon-toolkit/aeon/blob/main/aeon/classification/dictionary_based/_weasel_v2.py
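
For reference, a minimal usage sketch against the linked aeon module. The class name WEASEL_V2 and its scikit-learn-style fit/predict interface are assumptions inferred from the file path above, and the toy data is made up; check the linked source for the exact API.

import numpy as np
from aeon.classification.dictionary_based import WEASEL_V2  # name assumed from _weasel_v2.py

# hypothetical toy data: 40 univariate series of length 150, two classes
rng = np.random.default_rng(0)
X = rng.normal(size=(40, 150))
y = np.repeat([0, 1], 20)

clf = WEASEL_V2()
clf.fit(X[:30], y[:30])
print(clf.score(X[30:], y[30:]))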
