In [1]:
import metis
import torch 
import numpy as np 
from typing import List, Tuple, Union, Dict
import matplotlib.pyplot as plt 

from metis.featurepipeline import (
    Std, 
    Clip, 
    ClipQuantiles, 
    ElementwiseOp, 
    ToQuantile, 
    Identity, 
    FeaturePipeline
)
from metis.plotting import hist_tensor, hist_df, dropdown_plot

# dir(metis)

In [2]:
# Synthetic input: the trailing axis holds three features with very different
# distributions, giving the feature pipeline something non-trivial to normalise.
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same histograms
data = torch.rand((8192, 16, 3))

# Column order must follow `feature_names` below: the two double-exponential
# columns are 'exp' (0) and 'exp2' (2), the log-transformed column is 'nlog' (1).
data[:, :, 0].exp_().exp_()                    # 'exp'  : exp(exp(u)), in place on the view
data[:, :, 1] = (1e-3 + data[:, :, 1]).log()   # 'nlog' : log(u + 1e-3), offset avoids log(0)
data[:, :, 2].exp_().exp_()                    # 'exp2' : exp(exp(u)), in place on the view

feature_names = ['exp', 'nlog', 'exp2']

# Flatten every leading axis so each row is one (exp, nlog, exp2) sample.
# Reshaping to a width that is not a multiple of 3 would interleave the
# features across the plotted columns.
dropdown_plot(
    hist_tensor(data.reshape(-1, 3).cpu(), feature_names)
).run(jupyter_mode='inline')

In [11]:
# Shape arguments shared by the ops below. Kept in one place so the axis
# layout of `data` only has to be changed once.
seq_args = {
    'normalize_shape': "() () feature",
    'input_shape': 'seq batch feature',
}

# Regex on feature name -> list of ops applied to the matching features.
# NOTE(review): the recorded match_info output assigns 'exp' and 'exp2' to
# 'exp.*' even though '.*' also matches them, so earlier patterns appear to
# take precedence -- confirm against FeaturePipeline.
fop = {
    'exp.*': [
        Std(**seq_args),
        # Clip is only given the input layout here, not a normalize_shape.
        Clip(-3, 3, input_shape=seq_args['input_shape'])],
    '.*': [
        Std(**seq_args),
        ToQuantile(**seq_args, method='gaussian')],
}

fp = FeaturePipeline(fop, feature_names, feature_pack_pattern='seq batch *').to(data.device)

# Fit on the data and plot what `fit` returns, flattened to one row per sample.
transformed_data = fp.fit(data).reshape(-1, 3)
display(fp.match_info())
print(transformed_data.shape, feature_names)
dropdown_plot(
    hist_tensor(transformed_data, feature_names)
).run(jupyter_mode='inline')

# Run the already-fitted pipeline on the same data and plot it for comparison
# with the output of `fit` above.
forward_data = fp(data).reshape(-1, 3)
dropdown_plot(
    hist_tensor(forward_data, feature_names)
).run(jupyter_mode='inline')

Clipping: torch.Size([8192, 16, 2]) tensor([-1.0481,  0.4474, -0.3684, -0.9714, -0.3401]) -3 3



Features matched by 'exp.*' are not contiguous: ['exp', 'exp2']


Regex '.*' matches 2 already-matched features: {'exp', 'exp2'}



feature,feature_index,regex_match,regex_match_index
str,i64,str,i64
"""exp""",0,"""exp.*""",0
"""nlog""",1,""".*""",1
"""exp2""",2,"""exp.*""",0


torch.Size([131072, 3]) ['exp', 'nlog', 'exp2']


Clipping: torch.Size([8192, 16, 2]) tensor([-1.0481,  0.4474, -0.3684, -0.9714, -0.3401]) -3 3


In [12]:
%%timeit
forward_data = fp(data).reshape(-1, 3)

Clipping: torch.Size([8192, 16, 2]) tensor([-1.0481,  0.4474, -0.3684, -0.9714, -0.3401]) -3 3
Clipping: torch.Size([8192, 16, 2]) tensor([-1.0481,  0.4474, -0.3684, -0.9714, -0.3401]) -3 3
Clipping: torch.Size([8192, 16, 2]) tensor([-1.0481,  0.4474, -0.3684, -0.9714, -0.3401]) -3 3
Clipping: torch.Size([8192, 16, 2]) tensor([-1.0481,  0.4474, -0.3684, -0.9714, -0.3401]) -3 3
Clipping: torch.Size([8192, 16, 2]) tensor([-1.0481,  0.4474, -0.3684, -0.9714, -0.3401]) -3 3
Clipping: torch.Size([8192, 16, 2]) tensor([-1.0481,  0.4474, -0.3684, -0.9714, -0.3401]) -3 3
Clipping: torch.Size([8192, 16, 2]) tensor([-1.0481,  0.4474, -0.3684, -0.9714, -0.3401]) -3 3
Clipping: torch.Size([8192, 16, 2]) tensor([-1.0481,  0.4474, -0.3684, -0.9714, -0.3401]) -3 3
Clipping: torch.Size([8192, 16, 2]) tensor([-1.0481,  0.4474, -0.3684, -0.9714, -0.3401]) -3 3
Clipping: torch.Size([8192, 16, 2]) tensor([-1.0481,  0.4474, -0.3684, -0.9714, -0.3401]) -3 3
Clipping: torch.Size([8192, 16, 2]) tensor([-1.048