# Machine Transfer Learning

Dit notebook bevat de analyse en training waarbij een model wordt getraind op machine 1 en vervolgens wordt getest en gevalideerd op machine 2 en 3.

In [None]:
import os
import sys
from pathlib import Path
import random
import json
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pywt
from tqdm.notebook import tqdm


# Make sure figures are large enough
plt.rcParams['figure.figsize'] = [12, 8]
sns.set_style('whitegrid')

# Directory that holds the exported metadata and receives the saved figures.
export_dir = 'export'

# Build the path with os.path.join: the former literal 'export\measurement...'
# relied on '\m' (an invalid escape sequence) and on the Windows separator,
# so it could not be resolved on POSIX systems.
df = pd.read_csv(os.path.join(export_dir, 'measurement_files_metadata.csv'))
display(df)


## Verdeling van data over machines

Laten we eerst de verdeling van de data over de machines bekijken.

In [None]:
# Distribution of samples over machines: one row per (machine, operation),
# one column per class, plus a 'Total' column with the row sum.
machine_dist = pd.crosstab(
    [df['machine'], df['operation']],
    df['class']
).assign(Total=lambda x: x.sum(axis=1)).sort_index()

# To inspect the table itself, run: display(machine_dist)

# Create the visualisation. The axes are created explicitly and handed to
# pandas; without `ax=` DataFrame.plot opens its own figure, which left the
# (15, 10) figure empty and the requested size unused.
fig, ax = plt.subplots(figsize=(15, 10))
totals_per_operation = machine_dist.reset_index().pivot(
    index='operation', columns='machine', values='Total'
)
totals_per_operation.plot(kind='bar', ax=ax, rot=0)
ax.set_title('Aantal samples per machine en bewerking')
ax.set_ylabel('Aantal samples')
ax.legend(title='Machine')
fig.tight_layout()
fig.savefig(os.path.join(export_dir, 'machine_distribution.png'), dpi=300)

In [None]:
# Register a custom font file if you have one (optional)
# from matplotlib import font_manager
# font_manager.fontManager.addfont('/path/to/OpenSans-Regular.ttf')

# Global styling. The font family and its fallback list are passed through
# `rc` because sns.set_theme re-applies seaborn's own style parameters
# (which include 'font.sans-serif'); values written to plt.rcParams before
# the call would be overwritten, whereas `rc` entries are applied last.
sns.set_theme(
    style='whitegrid',
    palette='pastel',
    rc={
        'font.family': 'sans-serif',
        'font.sans-serif': ['Open Sans', 'Lato', 'Arial', 'DejaVu Sans'],
        'axes.titlesize': 16,
        'axes.titleweight': 'bold',
        'axes.labelsize': 14,
        'xtick.labelsize': 12,
        'ytick.labelsize': 12,
        'font.size': 12,
    },
)

# Tally the samples per class; the 'good' class is drawn green and every
# other class red.
class_counts = df['class'].value_counts()
total = class_counts.sum()
colors = []
for cls in class_counts.index:
    colors.append('tab:green' if cls == 'good' else 'tab:red')

# One bar per class.
fig, ax = plt.subplots(figsize=(8, 6))
bar_style = {'edgecolor': 'gray', 'linewidth': 1.2, 'width': 0.6}
bars = ax.bar(class_counts.index, class_counts.values, color=colors, **bar_style)

# Write the absolute count and its share of the total just above each bar.
label_offset = total * 0.005
for bar in bars:
    height = bar.get_height()
    pct = height / total * 100
    x_center = bar.get_x() + bar.get_width() / 2
    ax.text(
        x_center,
        height + label_offset,
        f"{height:,}\n({pct:.1f}%)",
        ha='center',
        va='bottom',
        fontweight='medium',
    )

# Titles, axis range (15 % headroom for the labels) and a lighter frame.
ax.set(
    title='Distribution of Classes',
    xlabel='Class',
    ylabel='Count',
    ylim=(0, class_counts.values.max() * 1.15),
)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
ax.grid(axis='y', linestyle='--', alpha=0.7)

plt.tight_layout()
plt.savefig(os.path.join(export_dir, 'class_distribution.png'), dpi=300)
plt.show()

# Signal analysis

In [None]:


# Working directory at the moment this cell runs.
current_dir = os.getcwd()

# Its parent directory is treated as the root for imports.
root_dir = Path(current_dir).parent

# Make the root importable so Python can find the utils module.
sys.path.append(os.fspath(root_dir))
print(f"Added {root_dir} to Python path")

# Move the working directory one level up as well.
# NOTE: every execution of this cell moves up one more level.
os.chdir(root_dir)

In [None]:
import sys
# Working directory at the moment this cell runs.
current_dir = os.getcwd()

# Walk two directory levels up from the working directory.
parent_dir = os.path.dirname(current_dir)
root_dir = os.path.dirname(parent_dir)

# Add that directory to sys.path so Python can find the utils module.
sys.path.append(root_dir)

from utils.data_loader_utils import datafile_read

In [None]:
# One example recording per machine (operation OP07, class "bad").
file_paths = {
    'M01': 'data/M01/OP07/bad/M01_Aug_2019_OP07_000.h5',
    'M02': 'data/M02/OP07/bad/M02_Aug_2019_OP07_000.h5',
    'M03': 'data/M03/OP07/bad/M03_Aug_2019_OP07_000.h5'
}

# Read every file that is present on disk; report the ones that are not.
_loaded_recordings = {}
for machine, path in file_paths.items():
    if not os.path.exists(path):
        print(f"Warning: File not found: {path}")
        continue
    _loaded_recordings[machine] = datafile_read(path, axes=[0])
    print(f"Loaded {path}")

# Expose the per-machine results; a machine whose file was missing stays None.
OP07_BAD_M01 = _loaded_recordings.get('M01')
OP07_BAD_M02 = _loaded_recordings.get('M02')
OP07_BAD_M03 = _loaded_recordings.get('M03')

In [None]:
# Instantiate the PyWavelets wavelet object named 'coif8'.
w = pywt.Wavelet('coif8')

In [None]:
# Sample the scaling function (phi) and wavelet function (psi) of 'coif8'
# on the grid x, using 7 refinement levels.
wavelet = pywt.Wavelet('coif8')
phi, psi, x = wavelet.wavefun(level=7)

# Two panels side by side: scaling function on the left, wavelet on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# (axes, values, line colour, title, y-label) for each panel; both panels
# share exactly the same styling, so it is applied in a single loop.
panels = (
    (ax1, phi, '#1f77b4', "Coiflet 8 – Scaling Function (φ)", "φ(t)"),
    (ax2, psi, '#ff7f0e', "Coiflet 8 – Wavelet Function (ψ)", "ψ(t)"),
)
for panel_ax, values, line_color, panel_title, y_label in panels:
    panel_ax.plot(x, values, linewidth=1.5, color=line_color)
    panel_ax.set_title(panel_title, fontsize=14, fontweight='bold')
    panel_ax.set_xlabel("t", fontsize=12)
    panel_ax.set_ylabel(y_label, fontsize=12)
    panel_ax.tick_params(axis='both', which='major', labelsize=11)
    panel_ax.grid(True, linestyle='--', alpha=0.7)
    panel_ax.spines['top'].set_visible(False)
    panel_ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.show()

# Save the figure to the export folder.
# NOTE: this rebinds export_dir for every cell executed afterwards.
export_dir = "../../export"
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(export_dir, exist_ok=True)

fig_path = os.path.join(export_dir, "coif8.png")
fig.savefig(fig_path, dpi=600, bbox_inches='tight')
print(f"Figure saved to {fig_path}")


In [None]:
def wavelet_transform(signal, wavelet='coif8', maxlevel=3, mode='symmetric'):
    """Build a wavelet packet decomposition of *signal*.

    Args:
        signal: Data handed to ``pywt.WaveletPacket`` as ``data``.
        wavelet: Wavelet name or object forwarded unchanged.
        maxlevel: Maximum decomposition level forwarded unchanged.
        mode: Signal extension mode forwarded unchanged.

    Returns:
        The ``pywt.WaveletPacket`` instance created from the arguments.
    """
    packet_tree = pywt.WaveletPacket(
        data=signal,
        wavelet=wavelet,
        mode=mode,
        maxlevel=maxlevel,
    )
    return packet_tree


In [None]:
# One figure per axis index (0..2); each figure overlays the 'aaa' wavelet
# packet node of the same OP07 "bad" recording from the three machines.
for axis_idx in range(3):
    plt.figure(figsize=(15, 5))
    for machine_id in ('M01', 'M02', 'M03'):
        h5_path = f'data/{machine_id}/OP07/bad/{machine_id}_Aug_2019_OP07_000.h5'
        machine_data = datafile_read(h5_path, axes=[axis_idx], plotting=False)
        # Only one axis was requested, so it is read from column 0.
        wp_data = wavelet_transform(machine_data[:, 0])
        plt.plot(wp_data['aaa'].data, label=f'Machine {machine_id}')
    plt.title(f'Wavelet Packet Decomposition - AAA Node - Axis {axis_idx}')
    plt.xlabel('Sample')
    plt.ylabel('Amplitude')
    plt.grid(True)
    plt.legend()

In [None]:
# Wavelet packet decomposition of column 0 of each machine's recording.
OP07_BAD_M01_WP, OP07_BAD_M02_WP, OP07_BAD_M03_WP = (
    wavelet_transform(recording[:, 0])
    for recording in (OP07_BAD_M01, OP07_BAD_M02, OP07_BAD_M03)
)

In [None]:
# Calculate cross-correlation between signals of the three machines
import numpy as np
import seaborn as sns
from scipy import signal
import matplotlib.pyplot as plt

def _normalized_cross_correlation(a, b):
    """Return the 'same'-mode cross-correlation of *a* and *b*, divided by
    the square root of the product of their energies.

    For real-valued signals the zero-lag autocorrelation equals the dot
    product of the signal with itself, so the energies are computed directly
    instead of running a full autocorrelation to read a single sample.
    """
    energy_scale = np.sqrt(np.dot(a, a) * np.dot(b, b))
    return signal.correlate(a, b, mode='same') / energy_scale


# 'aaa' wavelet packet node of each machine's decomposition
m1_data = OP07_BAD_M01_WP['aaa'].data
m2_data = OP07_BAD_M02_WP['aaa'].data
m3_data = OP07_BAD_M03_WP['aaa'].data

# Truncate all signals to the shortest one so they can be compared
min_length = min(len(m1_data), len(m2_data), len(m3_data))
m1_data = m1_data[:min_length]
m2_data = m2_data[:min_length]
m3_data = m3_data[:min_length]

# Pairwise normalised cross-correlations
cross_corr_m1_m2 = _normalized_cross_correlation(m1_data, m2_data)
cross_corr_m1_m3 = _normalized_cross_correlation(m1_data, m3_data)
cross_corr_m2_m3 = _normalized_cross_correlation(m2_data, m3_data)

# Integer lag axis: in 'same' mode the zero-lag sample sits at index
# min_length // 2. The former arange(-min_length/2, min_length/2) produced
# half-integer lags (shifted by 0.5) whenever min_length was odd.
lags = np.arange(min_length) - min_length // 2

# Plot cross-correlations
plt.figure(figsize=(10, 8))
plt.subplot(2, 1, 1)
plt.plot(lags, cross_corr_m1_m2, label='M01-M02')
plt.plot(lags, cross_corr_m1_m3, label='M01-M03')
plt.plot(lags, cross_corr_m2_m3, label='M02-M03')
plt.title('Cross-Correlation between Machine Signals - Axis 0')
plt.xlabel('Lag')
plt.ylabel('Correlation')
plt.legend()
plt.grid(True)

# Symmetric matrix with the maximum cross-correlation of every machine pair;
# the diagonal is fixed to 1.
peak_m1_m2 = np.max(cross_corr_m1_m2)
peak_m1_m3 = np.max(cross_corr_m1_m3)
peak_m2_m3 = np.max(cross_corr_m2_m3)
max_corr_values = np.array([
    [1, peak_m1_m2, peak_m1_m3],
    [peak_m1_m2, 1, peak_m2_m3],
    [peak_m1_m3, peak_m2_m3, 1],
])

plt.subplot(2, 1, 2)
sns.heatmap(max_corr_values, annot=True, cmap='coolwarm',
            xticklabels=['M01', 'M02', 'M03'],
            yticklabels=['M01', 'M02', 'M03'])
plt.title('Maximum Cross-Correlation Values')
plt.tight_layout()

In [None]:
# Raw M01 signal (column 0) next to the 'aaa' node of its wavelet packet
# decomposition.
axis = 0
original_signal = OP07_BAD_M01[:, axis]
decomposed_signal = OP07_BAD_M01_WP['aaa'].data

# Two panels stacked vertically: raw signal on top, decomposition below.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 5))

ax1.plot(original_signal, label='Original Signal (M01, Axis 0)', color='tab:blue')
ax1.set(title='Original Signal', ylabel='Amplitude')
# ax1.legend()

ax2.plot(decomposed_signal, label='Wavelet Packet AAA Node', color='tab:orange')
ax2.set(
    title='Decomposed Signal - Wavelet Packet AAA Node',
    xlabel='Sample',
    ylabel='Amplitude',
)
# ax2.legend()

plt.tight_layout()
plt.show()

# Write the figure into the export directory and report where it went.
comparison_path = f'{export_dir}/wavelet_decomposition_comparison.png'
fig.savefig(comparison_path, dpi=300, bbox_inches='tight')
print(f"Figure saved to {comparison_path}")

In [None]:
# Absolute, symlink-resolved form of the current working directory.
current_dir = Path.cwd().resolve()

# The project root is taken to be one level above the working directory.
project_root = current_dir.parent

# Register the project root on sys.path so the utils package can be imported.
sys.path.append(os.fspath(project_root))

from utils.load_data import load_data
from utils.feature_extraction import transform_data

X, y, y_binary = load_data()

# Split the samples per machine. Each label in y is a string; the part before
# the first "_" selects the machine, and a "_good" suffix maps to class 0
# (anything else to class 1).
machines = ['M01', 'M02', 'M03']
machine_data = {}

for m in machines:
    prefix = f"{m}_"
    selected = [(xi, yi) for xi, yi in zip(X, y) if yi.startswith(prefix)]
    X_m = [sample for sample, _ in selected]
    y_m = [0 if label.endswith("_good") else 1 for _, label in selected]

    # X_m stays a plain list of arrays (no forced 2-D stacking); only the
    # labels are converted to a NumPy array.
    machine_data[m] = (X_m, np.array(y_m))

# Convenience aliases per machine
X_M01, y_M01 = machine_data['M01']
X_M02, y_M02 = machine_data['M02']
X_M03, y_M03 = machine_data['M03']

# Run every machine's samples through the project's transform_data helper
# (presumably feature extraction, given its module name - confirm in utils).
X_M01_tr, y_M01_tr = transform_data(*machine_data['M01'], label_type='binary')
X_M02_tr, y_M02_tr = transform_data(*machine_data['M02'], label_type='binary')
X_M03_tr, y_M03_tr = transform_data(*machine_data['M03'], label_type='binary')

# Print an overview of all subsets: total size plus the good/bad split
# per machine (label 0 = good, label 1 = bad).
print("Dataset Overview:")
print(f"Total samples: {len(X)}")
print("\nMachine-specific breakdown:")
for machine in machines:
    X_m, y_m = machine_data[machine]
    labels = np.asarray(y_m)
    good_samples = int(np.count_nonzero(labels == 0))
    bad_samples = int(np.count_nonzero(labels == 1))
    total_samples = len(labels)

    # A machine without any samples would otherwise raise ZeroDivisionError;
    # report its shares as 0 % instead.
    good_share = good_samples / total_samples if total_samples else 0.0
    bad_share = bad_samples / total_samples if total_samples else 0.0

    print(f"\n{machine} Dataset:")
    print(f"  Total samples: {total_samples}")
    print(f"  Good samples: {good_samples} ({good_share:.2%})")
    print(f"  Bad samples: {bad_samples} ({bad_share:.2%})")

# Visualize the class distribution: one subplot per machine, side by side.
plt.figure(figsize=(12, 6))

for i, machine in enumerate(machines, 1):
    plt.subplot(1, len(machines), i)
    _, y_m = machine_data[machine]
    # minlength=2 guarantees a (good, bad) pair even when one class is absent;
    # the integer cast keeps bincount working for an empty label array.
    counts = np.bincount(np.asarray(y_m, dtype=int), minlength=2)
    bars = plt.bar(['Good', 'Bad'], counts, color=['green', 'red'])

    # Add the count as a label on top of each bar
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{int(height)}',
                 ha='center', va='bottom')

    plt.title(f'{machine} Class Distribution')
    plt.ylabel('Count')

plt.tight_layout()
plt.show()