# Mutual information

## MI in Time domain

In [None]:
# Time-domain features for the six axis names; the summary=True MI result is
# drawn as a bar chart.
features = load_time_domain_features(['ax', 'ay', 'az', 'bx', 'by', 'bz'])
mi = calc_mutual_information(features, TD_COLUMNS, summary=True)
mi.plot.bar(
    xlabel='Feature',
    ylabel='Mutual information',
    figsize=(8, 5),
    grid=True,
    legend=False,
)
plt.show()

### MI between each feature (per axis) and the target fault state

In [None]:
# Heatmap of the summary=False MI result for the time-domain features.
# NOTE(review): this cell reads `sel.TD_COLUMNS` while the bar-chart cell uses a
# bare `TD_COLUMNS` — confirm both names refer to the same column list.
fig, ax = plt.subplots(figsize=(10, 5))
sb.heatmap(
    calc_mutual_information(features, sel.TD_COLUMNS, summary=False),
    annot=True,
    cmap="Greens",
    ax=ax,
)
plt.show()

## MI in Frequency domain
- By fft window length
- By measurement point: {(ax, ay, az), (bx, by, bz)}

In [None]:
WINDOW_SIZES = (2**8, 2**10, 2**12, 2**14, 2**16)

def show_freq_domain_mutual_info(features, cols):
    """Draw one MI bar chart per FFT window length for the selected axes.

    Args:
        features: DataFrame with 'fft_window_length' and 'axis' columns; the
            filtered rows are passed to calc_mutual_information together
            with FD_COLUMNS.
        cols: Values of the 'axis' column to keep, e.g. ['ax', 'ay', 'az'].

    The figure is only built here; the caller decides when to call plt.show().
    """
    # One panel per window size: the grid follows WINDOW_SIZES instead of a
    # hard-coded 5. squeeze=False keeps `axes` 2-D even for a single panel.
    panel_count = len(WINDOW_SIZES)
    fig, axes = plt.subplots(
        1, panel_count, figsize=(4 * panel_count, 5), squeeze=False)

    for win, o in zip(WINDOW_SIZES, axes.flat):
        x = features[
                (features['fft_window_length'] == win) &
                (features['axis'].isin(cols))
            ].dropna()
        print('FFT:', win, 'Number of rows:', len(x))
        # FD_COLUMNS (as in mi_among_fault_and_axis) replaces the global
        # `columns`, which is only assigned later in the wavelet section.
        mi = calc_mutual_information(x, FD_COLUMNS, summary=True)

        # Bar heights are the first column of the summary table.
        o.bar(mi.index, mi.values.T[0])
        # Stylize bar graph
        o.grid(True)
        o.set_xlabel('Feature')
        o.set_ylabel('MI')
        o.set_title(f'FFT: {win}')
        # Rotate x labels by 45 deg (ticks are pinned first so the labels
        # stay attached to fixed tick positions)
        o.set_xticks(o.get_xticks())
        o.set_xticklabels(o.get_xticklabels(), rotation=45, ha='right')

In [None]:
# Load the frequency-domain features and chart MI per FFT window length for
# the rows whose 'axis' is one of 'ax', 'ay', 'az'.
features = load_frequency_domain_features()
show_freq_domain_mutual_info(features, ['ax', 'ay', 'az'])
plt.show()

In [None]:
# Same chart for the rows whose 'axis' is one of 'bx', 'by', 'bz'.
show_freq_domain_mutual_info(features, ['bx', 'by', 'bz'])
plt.show()

### Mutual information between each feature (per axis) and the individual faults (predicted variable)

In [None]:
def mi_among_fault_and_axis(features, cols):
    """Draw one MI heatmap per FFT window length, stacked vertically.

    Args:
        features: DataFrame with 'fft_window_length' and 'axis' columns; the
            filtered rows are passed to calc_mutual_information together
            with FD_COLUMNS (summary=False).
        cols: Values of the 'axis' column to keep.

    The figure is only built here; the caller decides when to call plt.show().
    """
    # One row per window size: the grid follows WINDOW_SIZES instead of a
    # hard-coded 5. squeeze=False keeps `axes` 2-D even for a single row.
    row_count = len(WINDOW_SIZES)
    fig, axes = plt.subplots(
        row_count, 1, figsize=(8, 4 * row_count), squeeze=False)

    for win, o in zip(WINDOW_SIZES, axes.flat):
        x = features[
            (features['fft_window_length'] == win) &
            (features['axis'].isin(cols))
        ].dropna()
        mi = calc_mutual_information(x, FD_COLUMNS, summary=False)
        sb.heatmap(mi, annot=True, ax=o, cmap="Greens")
        o.set_title(f'FFT: {win}')

# Read the frequency-domain features from CSV, cast the two grouping columns
# to categorical in one step, and draw the per-window heatmaps for every axis.
AXIS = ['ax', 'ay', 'az', 'bx', 'by', 'bz']
features = pd.read_csv(FREQ_FEATURES_PATH).astype(
    {'fault': 'category', 'fft_window_length': 'category'}
)

mi_among_fault_and_axis(features, AXIS)
plt.show()

## TODO: MI in Freq domain: Rank order of features averaged among all window sizes

## MI in Wavelets

In [None]:
features = pd.read_csv(WPD_FEATURES_PATH)

WPD_AXIS = 'ax'
# More axes at once significantly reduce MI, so keep a single axis.
# .copy() makes the filtered frame independent of the frame read from CSV, so
# the column assignment below does not raise SettingWithCopyWarning.
features = features[features['axis'] == WPD_AXIS].copy()
features['fault'] = features['fault'].astype('category')
# To analyse one measurement position instead of one axis, filter with
# features['axis'].isin(['ax', 'ay', 'az']).

# Every column except these bookkeeping columns is passed on as an MI input.
columns = [col for col in features.columns
           if col not in ('fault', 'severity', 'seq', 'rpm', 'axis', 'feature')]
features.head()

In [None]:
# Keep only the 'energy' rows, report how many there are, and chart the first
# 30 entries of the summary MI table.
features_energy = features.loc[features['feature'] == 'energy']
print(len(features_energy))

mi = calc_mutual_information(features_energy, columns, summary=True)
mi.iloc[:30].plot.bar(
    title='WPD Energy',
    ylabel='MI',
    figsize=(20, 4),
    grid=True,
)
plt.show()

In [None]:
def plot_wpd_energy_ratio_per_level(features, wpd_axis):
    """Plot MI bar charts of the 'energy_ratio' rows, one panel per level 1-6.

    Args:
        features: DataFrame with 'axis' and 'feature' columns.
        wpd_axis: Values of the 'axis' column to keep.

    Reads the module-level `columns` list; for each level only the names
    starting with 'L<level>' are passed to calc_mutual_information. The
    figure is shown before returning.
    """
    keep = features['axis'].isin(wpd_axis) & (features['feature'] == 'energy_ratio')
    energy_ratio_rows = features[keep]

    fig, panels = plt.subplots(6, 1, figsize=(15, 20))
    # Loop-invariant: the candidate names as an array, so np.char can test the prefix.
    all_names = np.array(columns)

    for level, panel in enumerate(panels.flatten(), start=1):
        level_names = all_names[np.char.startswith(all_names, f'L{level}')]
        mi = calc_mutual_information(energy_ratio_rows, level_names, summary=True)

        # Bar heights are the first column of the summary table.
        panel.bar(mi.index, mi.values.T[0])
        panel.grid(True)
        panel.set_xlabel('Feature')
        panel.set_ylabel('MI')

        # Tilt the tick labels by 45 degrees; ticks are pinned first.
        panel.set_xticks(panel.get_xticks())
        panel.set_xticklabels(panel.get_xticklabels(), rotation=45, ha='right')

    fig.suptitle(f'WPD energy ratio: Axis "{wpd_axis}"', fontsize=16, y=0.9)
    plt.show()

In [None]:
# Energy-ratio MI per level for the rows whose 'axis' is 'ax'.
plot_wpd_energy_ratio_per_level(features, ['ax'])

In [None]:
# Energy-ratio MI per level for the rows whose 'axis' is 'ax', 'ay' or 'az'.
# NOTE(review): the WPD loading cell narrows `features` to WPD_AXIS ('ax'), so
# 'ay' and 'az' rows appear to be gone by the time this runs — confirm whether
# the unfiltered frame should be passed here instead.
plot_wpd_energy_ratio_per_level(features, ['ax', 'ay', 'az'])

In [None]:
# Keep only the 'negentropy' rows, report how many there are, and chart the
# first 30 entries of the summary MI table.
features_entropy = features.loc[features['feature'] == 'negentropy']
print(len(features_entropy))

mi = calc_mutual_information(features_entropy, columns, summary=True)
mi.iloc[:30].plot.bar(
    title='WPD Negentropy',
    ylabel='MI',
    figsize=(20, 4),
    grid=True,
)
plt.show()

In [None]:
# Keep only the 'kurtosis' rows, report how many there are, and chart the
# first 30 entries of the summary MI table.
features_kurtosis = features[features['feature'] == 'kurtosis']
print(len(features_kurtosis))

# MI is computed on the kurtosis rows; this call previously received
# `features_entropy`, so the chart repeated the negentropy result.
mi = calc_mutual_information(features_kurtosis, columns, summary=True)
mi.iloc[:30].plot.bar(figsize=(20, 4), grid=True, ylabel='MI', title='WPD Kurtosis')
plt.show()

In [None]:
def level_to_frequency_bands(level, fs):
    """Print the 2 ** level equal-width frequency bands covering 0 .. fs / 2.

    Each band is printed on its own line as
    'L<level>_<index> = [<lower>; <upper>] Hz'.

    Args:
        level: Exponent that sets the number of bands (2 ** level).
        fs: Value whose half is the upper edge of the last band.
    """
    band_count = 2 ** level
    band_width = (fs / 2) / band_count
    for band in range(band_count):
        lower = band * band_width
        upper = lower + band_width
        print(f'L{level}_{band} = [{lower}; {upper}] Hz')

level_to_frequency_bands(level=4, fs=50000)