In [1]:
%load_ext autoreload
%autoreload 2

In [43]:
import h5py
from pprint import pprint
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import entr
from drift_ml.datasets.bosch_cnc_machining.utils.utils import augment_xyz_samples
from drift_ml.datasets.bosch_cnc_machining.utils.dataloader import RawBoschCNCDataloader, STFTBoschCNCDataloader, NPYBoschCNCDataLoader

In [30]:
base_loader = NPYBoschCNCDataLoader(metadata_path="/home/tbiegel/nico_files/drift_ml/src/drift_ml/datasets/bosch_cnc_machining/extracted_features/metadata_ws4096.pkl")
base_loader.load_data(
    sample_data_x_path="/home/tbiegel/nico_files/drift_ml/src/drift_ml/datasets/bosch_cnc_machining/extracted_features/npy/sample_data_x_raw_ws4096.npy",
    sample_data_y_path="/home/tbiegel/nico_files/drift_ml/src/drift_ml/datasets/bosch_cnc_machining/extracted_features/npy/sample_data_y_raw_ws4096.npy"
    )
base_loader.generate_datasets_by_size(train_size=.4,val_size=.2,test_size=.4)

In [31]:
stft_base_loader = base_loader.get_windowed_samples_as_stft_dataloader(transform_fn=lambda x: x)
X_train_base, X_val_base, X_test_base = stft_base_loader.get_standardized_train_val_test()

  0%|          | 0/42978 [00:00<?, ?it/s]

In [62]:
from torch import tensor
from drift_ml.datasets.bosch_cnc_machining.utils.evaluation import Metrics
from drift_ml.datasets.bosch_cnc_machining.models.nnclassifier import NNEnsembleClassifier

metric_calculator = Metrics()
ensemble = NNEnsembleClassifier(n_ensemble=5)

In [6]:
ensemble.fit(
    fit_args=[
        X_train_base, 
        stft_base_loader.y_train[:, np.newaxis], 
        X_val_base, 
        stft_base_loader.y_val[:, np.newaxis]
    ],
    fit_kwargs={
        "lrate": 1e-2,
        "epochs": 20
    }
)

  tensor(X).to(self.device).float(),
  tensor(X).to(self.device).float(),
DEBUG:root:Final val. performance: AUROC 0.97, AURPC 0.95, F1 0.92
DEBUG:root:Final val. performance: AUROC 0.96, AURPC 0.85, F1 0.81
DEBUG:root:Final val. performance: AUROC 0.97, AURPC 0.94, F1 0.90
DEBUG:root:Final val. performance: AUROC 0.98, AURPC 0.96, F1 0.93
DEBUG:root:Final val. performance: AUROC 0.94, AURPC 0.86, F1 0.85


In [60]:
y_base_scores = ensemble.predict_proba(X_test_base)

base_entropies = entr(y_base_scores)
metric_calculator.print(y_base_scores, stft_base_loader.y_test[:, np.newaxis])

   BinaryAUROC  BinaryAveragePrecision  BinaryF1Score  BinaryMatthewsCorrCoef
0      0.90591                0.833986       0.859228                0.860729


In [None]:
shift = 40

print(f"Evaluating {shift} degrees pitch shift")
stft_loader = base_loader.get_windowed_samples_as_stft_dataloader(
    transform_fn=lambda x: augment_xyz_samples(x, pitch_deg=shift, yaw_deg=shift)
)
X_test_scaled = stft_base_loader.standardize_datasets([stft_loader.X_test])[0]

In [63]:
y_shift_scores = ensemble.predict_proba(X_test_base)

shift_entropies = entr(y_shift_scores)
avg_entropy = np.mean(shift_entropies)

metric_calculator.print(y_shift_scores, stft_loader.y_test[:, np.newaxis])

   BinaryAUROC  BinaryAveragePrecision  BinaryF1Score  BinaryMatthewsCorrCoef
0     0.754016                0.252417       0.054845                0.165389
