In [1]:
import numpy as np
import pandas as pd
import glob
import pickle

from scipy.stats import zscore

In [2]:
from hmmlearn import _hmmc

In [3]:
# config

np.set_printoptions(precision=3, suppress=True)

In [7]:
# data munging functions

def make_dfs(data_dir, first_n=1000):
    dfs = [
        (file, pd.read_csv(file)
        .sort_values("epoch")
        # .sort_values("step")
        .reset_index(drop=True)
        .head(first_n)
        )  
        for file in glob.glob(data_dir + "*")
    ]
    file_names, dfs = zip(*dfs)
    return file_names, dfs

def make_hmm_data(dfs, cols):
    dfs = [df[cols] for df in dfs]
    
    data = np.vstack(
        [np.apply_along_axis(zscore, 0, df.to_numpy()) for df in dfs]
    )
    return data

def break_list_by_lengths(lst, lengths):
    result = []
    start_index = 0
    
    for length in lengths:
        sublist = lst[start_index:start_index + length]
        result.append(sublist)
        start_index += length
    
    return result


def munge_data(model_pth, data_dir, cols, n_components, first_n=1000):
    '''
    Convert raw CSV data into expected format for analysis
    '''
    file_names, dfs = make_dfs(data_dir, first_n=first_n)
    data = make_hmm_data(dfs, cols)
    lengths =  [len(df) for df in dfs]

    with open(model_pth, 'rb') as f:
        models = pickle.load(f)

    model = models['best_models'][n_components-1]
    print(model.score(data, lengths=lengths))
    best_predictions = break_list_by_lengths(model.predict(data, lengths=lengths), lengths)
    
    return model, data, best_predictions, lengths

In [5]:
# derivative computation functions

def softmax_with_overflow(logits):
    '''
    log-sum-exp
    '''
    exp_logits = np.exp(logits - np.max(logits))
    return exp_logits / exp_logits.sum()

def find_i_followed_by_j(lst, i, j):
    '''
    Find transition in the estimated hidden state
    '''
    indexes = [index for index in range(len(lst) - 1) if lst[index] == i and lst[index + 1] == j]
    return indexes

def get_derivatives(X, model):
    '''
    Compute the derivative d/dz_t p(s_t = k | z_{1:t}) for the entire forward lattice.
    '''
    derivatives = []
    
    log_frameprob = model._compute_log_likelihood(X)
    log_probij, fwdlattice = _hmmc.forward_log(
                model.startprob_, model.transmat_, log_frameprob)
    n_components = fwdlattice.shape[1] # can be computed another way
    covars = [np.linalg.inv(model.covars_[i]) for i in range(n_components)]
    
    for i in range(len(X)):
        derivatives_i = []
        probs = softmax_with_overflow(fwdlattice[i]) 
        Z = np.sum([probs[j] * covars[j] @ (model.means_[j] - X[i]) for j in range(n_components)])
        
        for component in range(n_components):
            derivatives_i.append(
                covars[component] @ (model.means_[component] - X[i]) - Z 
            )
            
        derivatives.append(derivatives_i)

    return derivatives

def get_features_for_transition(model, data, best_predictions, lengths, phase_1, phase_2):
    '''
    For each time a transition (phase_1 -> phase_2) happens, compute the derivatives for each feature.
    
    This computation is slightly inefficient, in that it computes the entire forward lattice of derivatives.
    In practice, this inefficiency doesn't seem to be an issue in terms of runtime.
    '''
    features = []
    for (i, datum) in enumerate(break_list_by_lengths(data, lengths)):
        preds = best_predictions[i]
        indexes = find_i_followed_by_j(preds, phase_1, phase_2)
        if indexes != []:
            derivatives = np.array(get_derivatives(datum, model))
            for idx in indexes:
                features.append(derivatives[idx, phase_2])
    return features

def get_difference_bt_means(model, phase_1, phase_2):
    return model.means_[phase_2] - model.means_[phase_1]

def characterize_transition(model, data, best_predictions, cols, lengths, i, j):
    '''
    Compute the average derivative for each feature, sort features by highest absolute value
    '''
    features = get_features_for_transition(model, data, best_predictions, lengths, i, j)
    print(f"Number of times transition happened: {len(features)}")
    features = np.mean(features, axis=0)
    order = np.argsort(np.abs(features))[::-1]
    print(cols[order])

    feature_changes = np.array(get_difference_bt_means(model, i, j))
    print(feature_changes[order])

In [18]:
# example usage

# CIFAR 100
cols = np.array([
    "l1",
    "l2",
    "trace",
    "spectral",
    "code_sparsity",
    "computational_sparsity",
    "mean_singular_value",
    "var_singular_value",
    "mean_w",
    "median_w",
    "var_w",
    "mean_b",
    "median_b",
    "var_b",
])

model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/cifar100_v3/True_True/--base.pkl'
data_dir='/scratch/myh2014/modeling-training/data/training_runs/cifar100_v3/True_True/'
n_components=5

model, data, best_predictions, lengths = munge_data(model_pth, data_dir, cols, n_components)

print('CHANGE 4 -> 1')

characterize_transition(model, data, best_predictions, cols, lengths, 4, 1)

print('CHANGE 1 -> 3')

characterize_transition(model, data, best_predictions, cols, lengths, 1, 3)

print('CHANGE 3 -> 2')

characterize_transition(model, data, best_predictions, cols, lengths, 3, 2)

print('CHANGE 2 -> 0')

characterize_transition(model, data, best_predictions, cols, lengths, 2, 0)

338166.5589622997
CHANGE 4 -> 1
Number of times transition happened: 40
['l1' 'mean_singular_value' 'var_singular_value' 'l2' 'var_w' 'median_w'
 'spectral' 'trace' 'var_b' 'median_b' 'computational_sparsity' 'mean_b'
 'mean_w' 'code_sparsity']
[ 0.627  0.563  0.705  0.644  0.429 -0.938  0.842 -0.768 -0.02   0.346
 -1.338  0.137 -0.712 -1.825]
CHANGE 1 -> 3
Number of times transition happened: 40
['mean_singular_value' 'l2' 'l1' 'var_w' 'var_singular_value' 'spectral'
 'code_sparsity' 'median_w' 'var_b' 'computational_sparsity' 'median_b'
 'mean_b' 'trace' 'mean_w']
[ 0.755  0.767  0.762  0.716  0.723  0.751 -0.699 -0.859  0.063 -0.652
 -0.056 -0.376 -0.722 -0.81 ]
CHANGE 3 -> 2
Number of times transition happened: 40
['l2' 'mean_singular_value' 'var_singular_value' 'var_w' 'median_w'
 'computational_sparsity' 'median_b' 'var_b' 'mean_b' 'trace'
 'code_sparsity' 'mean_w' 'spectral' 'l1']
[ 0.804  0.825  0.774  0.86  -0.709 -0.57  -0.114  0.248 -0.801 -0.791
 -0.436 -0.788  0.73   0.807