In [1]:
import numpy as np
import pickle
import glob
import pandas as pd
import os
from sklearn.linear_model import LinearRegression

from scipy.stats import zscore, t

from src.utils import get_markov_chain

In [2]:
def make_dfs(data_dir, first_n=1000):
    dfs = [
        (file, pd.read_csv(file)
        # .sort_values("epoch")
        .sort_values("step")
        .reset_index(drop=True)
        .head(first_n)
        )  
        for file in glob.glob(data_dir + "*")
    ]
    file_names, dfs = zip(*dfs)
    return file_names, dfs

def make_hmm_data(dfs, cols):
    dfs = [df[cols] for df in dfs]
    
    data = np.vstack(
        [np.apply_along_axis(zscore, 0, df.to_numpy()) for df in dfs]
    )
    return data

def break_list_by_lengths(lst, lengths):
    result = []
    start_index = 0
    
    for length in lengths:
        sublist = lst[start_index:start_index + length]
        result.append(sublist)
        start_index += length
    
    return result

def save_markov_chain(model_pth, data_dir, output_pth, n_components, trim=True):
    with open(model_pth, 'rb') as f:
        model = pickle.load(f)['best_models'][n_components-1]

    if trim:
        file_names, dfs = make_dfs(data_dir)
        transitions = set()

        for file_name, df in zip(file_names, dfs):
            predictions = df['latent_state'].to_numpy().astype(int)
            transitions.update(zip(predictions[:-1], predictions[1:]))
            # print(file_name, transitions)

        # create a list of all existing transitions
        transmat = np.zeros((n_components, n_components))
        # print(transitions)
        for i, j in transitions:
            transmat[i, j] = model.transmat_[i, j]
    else:
        transmat = model.transmat_
    dot = get_markov_chain(np.round(transmat, decimals=3))
    dot.render(f'{output_pth}_{n_components}', format='png')

def get_count_vector(prediction, n_components):
    count_vector = np.zeros(n_components)
    for i in range(n_components):
        count_vector[i] = np.sum(prediction == i)
    return count_vector


def save_predictions(model_pth, data_dir, output_pth, n_components, cols_to_keep, first_n=1000):
    file_names, dfs = make_dfs(data_dir, first_n)
    data = make_hmm_data(dfs, cols_to_keep)
    lengths =  [len(df) for df in dfs]

    with open(model_pth, 'rb') as f:
        models = pickle.load(f)

    model = models['best_models'][n_components-1]
    print(model.score(data, lengths=lengths))
    best_predictions = break_list_by_lengths(model.predict(data, lengths=lengths), lengths)
    # print(best_predictions)

    dfs = [df.assign(latent_state=prediction) for df, prediction in zip(dfs, best_predictions)]

    # make output path if does not exist
    if not os.path.exists(output_pth):
        os.makedirs(output_pth)
    
    for file_name, df in zip(file_names, dfs):
        base_name = file_name.split('/')[-1]
        df.to_csv(f'{output_pth}{n_components}_{base_name}')


def get_convergence_epochs(dfs, column, threshold=0.5):
    conv = []
    max_len = max([len(df) for df in dfs])
    for i, df in enumerate(dfs):
        first_index = df[df[column] > threshold].index.tolist()

        if first_index:
            conv.append(first_index[0])
        else:
            conv.append(max_len)
            
    return conv

def get_state_correlations(file_paths, column, threshold):
    file_names, dfs = make_dfs(file_paths)
    num_components = int(file_names[0].split('/')[-1].split('_')[0])
    convergence_epochs = get_convergence_epochs(dfs, column, threshold)

    corrs = []
    binary_states = []
    t_stats = []
    p_vals = []
    for i in range(num_components):
        has_state = [1 if i in df['latent_state'].values else 0 for df in dfs]
        corr = np.corrcoef(has_state, convergence_epochs)[0, 1]
        corrs.append(corr)
        binary_states.append(has_state)

        n = len(has_state)
        df = n - 2
        t_stat = corr * np.sqrt(df) / np.sqrt(1 - corr**2)
        t_stats.append(t_stat)
        p_val = 2 * (1 - t.cdf(t_stat, df=df))
        p_vals.append(p_val)


    return corrs, t_stats, p_vals, convergence_epochs, binary_states


In [4]:
# checking the QQP results, because the auto-generated graph is hard to read

# with open('/scratch/myh2014/modeling-training/data/model_selection/32/glue_singletons/QQP-full-base.pkl', 'rb') as f:
#     data = pickle.load(f)

In [3]:
cols_to_keep = [
    "l1",
    "l2",
    "trace",
    "spectral",
    "code_sparsity",
    "computational_sparsity",
    "mean_lambda",
    "variance_lambda",
    "mean_w",
    "median_w",
    "var_w",
    "mean_b",
    "median_b",
    "var_b",
]

# cols_to_keep = [
#     "l1",
#     "l2",
#     "trace",
#     "spectral",
#     "code_sparsity",
#     "computational_sparsity",
#     "mean_singular_value",
#     "var_singular_value",
#     "mean_w",
#     "median_w",
#     "var_w",
#     "mean_b",
#     "median_b",
#     "var_b",
# ]

In [None]:
# make and save predictions
# no need to make predictions for GLUE, MNLI, QQP, since they have only 1 path through the markov chain

# sparse parities
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/parities/parities_v3-full-base.pkl',
    data_dir = '/scratch/myh2014/modeling-training/data/training_runs/parities_v3/',
    output_pth='/scratch/myh2014/modeling-training/results/parities/',
    cols_to_keep=cols_to_keep,
    metric_column = 'eval_accuracy',
    threshold = 0.9
)


In [9]:
# # MNIST
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/mnist_250-full-base.pkl',
    data_dir = '/scratch/myh2014/modeling-training/data/training_runs/mnist_v3/',
    output_pth='/scratch/myh2014/modeling-training/results/mnist_250/',
    cols_to_keep=cols_to_keep,
    metric_column = 'eval_accuracy',
    threshold = 0.97
)

[60, 61, 57, 67, 62, 59, 60, 65, 59, 57, 63, 60, 57, 57, 68, 59, 64, 59, 63, 65, 67, 63, 58, 65, 62, 67, 68, 60, 62, 62, 63, 55, 61, 70, 65, 63, 55, 66, 67, 59]


In [4]:
# modular addition
# remember to turn off sorting b/c there's no step or epoch (old file)
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/modular/modular_v3-full-base.pkl',
    data_dir = '/scratch/myh2014/modeling-training/data/training_runs/modular_v3/',
    output_pth='/scratch/myh2014/modeling-training/results/modular/',
    cols_to_keep=cols_to_keep,
    metric_column = 'eval_accuracy',
    threshold=0.9
)

[145, 225, 213, 361, 296, 360, 185, 292, 303, 309, 396, 293, 256, 184, 487, 306, 228, 227, 165, 225, 176, 1000, 650, 261, 670, 390, 479, 330, 403, 234, 110, 229, 253, 215, 225, 310, 282, 210, 350, 329]


In [6]:
# MultiBERTs
# remember to change "step" --> "epoch" (old file)
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/multiberts/multiberts_diag.pkl',
    data_dir = '/scratch/myh2014/modeling-training/data/training_runs/multiberts/',
    output_pth='/scratch/myh2014/modeling-training/results/multiberts_diag/',
    # n_components=5,
    cols_to_keep=cols_to_keep,
    metric_column = 'train_loss',
    threshold=1.6
)

[0, 0, 0, 0, 0]


In [51]:
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/mnist_cnn--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/data/training_runs/mnist_cnn/',
    output_pth='/scratch/myh2014/modeling-training/results/mnist_cnn/',
    n_components=5,
    cols_to_keep=cols_to_keep,
    first_n=200
)

140388.41684841784


In [59]:
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/cifar100_cnn--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/data/training_runs/cifar100_cnn/',
    output_pth='/scratch/myh2014/modeling-training/results/cifar100_cnn/',
    n_components=5,
    cols_to_keep=cols_to_keep,
    first_n=300
)

325435.1311040439


In [12]:
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/cifar100_cnn_variants/--use_batch_norm/--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/data/training_runs/cifar100_cnn_variants/--use_batch_norm/',
    output_pth='/scratch/myh2014/modeling-training/results/cifar100_cnn_variants/--use_batch_norm/',
    n_components=6,
    cols_to_keep=cols_to_keep,
    first_n=300
)

295927.78972488735


In [13]:
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/cifar100_cnn_variants/--use_residual/--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/data/training_runs/cifar100_cnn_variants/--use_residual/',
    output_pth='/scratch/myh2014/modeling-training/results/cifar100_cnn_variants/--use_residual/',
    n_components=6,
    cols_to_keep=cols_to_keep,
    first_n=300
)

293187.5448208586


In [14]:
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/cifar100_cnn_variants/nothing/--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/data/training_runs/cifar100_cnn_variants/nothing/',
    output_pth='/scratch/myh2014/modeling-training/results/cifar100_cnn_variants/nothing/',
    n_components=3,
    cols_to_keep=cols_to_keep,
    first_n=300
)

123601.67532493956


In [6]:
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/cifar100_v3/False_False/--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/data/training_runs/cifar100_v3/False_False/',
    output_pth='/scratch/myh2014/modeling-training/results/cifar100_v3/False_False/',
    n_components=5,
    cols_to_keep=cols_to_keep,
    first_n=600
)

303606.83113125904


In [7]:
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/cifar100_v3/True_True/--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/data/training_runs/cifar100_v3/True_True/',
    output_pth='/scratch/myh2014/modeling-training/results/cifar100_v3/True_True/',
    n_components=5,
    cols_to_keep=cols_to_keep,
    first_n=300
)

330989.21925717127


In [22]:
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/parities_adam_ln-full-base.pkl',
    data_dir='/scratch/myh2014/modeling-training/data/training_runs/parities_adam_ln/',
    output_pth='/scratch/myh2014/modeling-training/results/parities_adam_ln/',
    n_components=2,
    cols_to_keep=cols_to_keep,
    first_n=30
)

-7053.536360660506


In [None]:
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/parities_adam_ln-full-base.pkl',
    data_dir='/scratch/myh2014/modeling-training/data/training_runs/parities_adam_ln/',
    output_pth='/scratch/myh2014/modeling-training/results/parities_adam_ln/',
    n_components=2,
    cols_to_keep=cols_to_keep,
    first_n=30
)

In [25]:
save_predictions(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/mnist_v2--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/data/training_runs/mnist_v2/',
    output_pth='/scratch/myh2014/modeling-training/results/mnist_v2/',
    n_components=6,
    cols_to_keep=cols_to_keep,
    first_n=250
)

288493.800843404


In [8]:
# generate the Markov chains

# sparse parities
# save_markov_chain(
#     model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/parities/parities_v3-full-base.pkl',
#     data_dir='/scratch/myh2014/modeling-training/results/parities/',
#     output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/parities/full',
#     n_components=6 
# )

# # modular addition
# save_markov_chain(
#     model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/modular/modular_v3-full-base.pkl',
#     data_dir='/scratch/myh2014/modeling-training/results/modular/',
#     output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/modular/full',
#     n_components=4 
# )


save_markov_chain(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/modular_v3-diag-ablation2.pkl',
    data_dir='/scratch/myh2014/modeling-training/results/modular/',
    output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/modular/ablation2',
    n_components=8,
    trim=False
)

# # MNIST
# save_markov_chain(
#     model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/mnist_250-full-base.pkl',
#     data_dir='/scratch/myh2014/modeling-training/results/mnist_250/',
#     output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/mnist/full',
#     n_components=6
# )

# # MultiBERTs
# save_markov_chain(
#     model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/multiberts/multiberts_diag.pkl',
#     data_dir='/scratch/myh2014/modeling-training/results/multiberts_diag/',
#     output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/multiberts_diag',
#     n_components=5 
# )

In [6]:
save_markov_chain(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/parities_v3_swap--modular_best.pkl',
    data_dir='/scratch/myh2014/modeling-training/results/parities/',
    output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/parities/modular_swap',
    n_components=8,
    trim=False
)

save_markov_chain(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/modular_v3_swap--parities_best.pkl',
    data_dir='/scratch/myh2014/modeling-training/results/modular/',
    output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/modular/parities_swap',
    n_components=8,
    trim=False
)

In [60]:
# save_markov_chain(
#     model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/mnist_cnn--base.pkl',
#     data_dir='/scratch/myh2014/modeling-training/results/mnist_cnn/',
#     output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/mnist_cnn/',
#     n_components=5,
#     trim=True
# )

save_markov_chain(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/cifar100_cnn--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/results/cifar100_cnn/',
    output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/cifar100/',
    n_components=5,
    trim=True
)

In [15]:
save_markov_chain(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/cifar100_cnn_variants/--use_batch_norm/--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/results/cifar100_cnn_variants/--use_batch_norm/',
    output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/cifar100_cnn_variants/--use_batch_norm/',
    n_components=6,
    trim=True
)

In [17]:
save_markov_chain(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/cifar100_cnn_variants/--use_residual/--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/results/cifar100_cnn_variants/--use_residual/',
    output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/cifar100_cnn_variants/--use_residual/',
    n_components=6,
    trim=True
)

save_markov_chain(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/cifar100_cnn_variants/nothing/--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/results/cifar100_cnn_variants/nothing/',
    output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/cifar100_cnn_variants/nothing/',
    n_components=3,
    trim=True
)

In [8]:
save_markov_chain(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/cifar100_v3/False_False/--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/results/cifar100_v3/False_False/',
    output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/cifar100_v3/False_False/',
    n_components=5,
    trim=True
)

save_markov_chain(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/cifar100_v3/True_True/--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/results/cifar100_v3/True_True/',
    output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/cifar100_v3/True_True/',
    n_components=5,
    trim=True
)


In [21]:
save_markov_chain(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/parities_adam_ln-full-base.pkl',
    data_dir='/scratch/myh2014/modeling-training/results/parities_adam_ln/',
    output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/parities_adam_ln/',
    n_components=2,
    trim=True
)

In [27]:
save_markov_chain(
    model_pth='/scratch/myh2014/modeling-training/data/model_selection/32/mnist_v2--base.pkl',
    data_dir='/scratch/myh2014/modeling-training/results/mnist_v2/',
    output_pth='/scratch/myh2014/modeling-training/figures/markov_chains/mnist_v2/',
    n_components=6,
    trim=True
)

In [6]:
# correlate states with convergence time to find detour states
# nan means the states are obligatory

# sparse parities
corrs, t_stats, p_val, convergence_epochs, binary_states = get_state_correlations(
    file_paths='/scratch/myh2014/modeling-training/results/parities/',
    column='eval_accuracy',
    threshold=0.9
)
print(corrs)
print(p_val)

[0.37528728341108014, nan, nan, nan, nan, nan]
[0.017023075499225326, nan, nan, nan, nan, nan]


  c /= stddev[:, None]
  c /= stddev[None, :]


In [6]:
# modular addition
corrs, t_stats, p_val, convergence_epochs, binary_states = get_state_correlations(
    file_paths='/scratch/myh2014/modeling-training/results/modular/',
    column='eval_accuracy',
    threshold=0.5
)
print(corrs)
print(p_val)

[-0.8203864115241618, nan, -0.21177564173644814, -0.7080030048781754, -0.7239012669101329, 0.4869906827595854]
[1.9999999999076166, nan, 1.8104300160431606, 1.9999996775966193, 1.9999998699558523, 0.0014387457313806973]


  c /= stddev[:, None]
  c /= stddev[None, :]


In [104]:
convergence_epoch = np.array(convergence_epoch)
binary_states = np.array(binary_states).astype(bool)

In [114]:
for i in [0,1,2,4]:
    print(np.mean(convergence_epoch[binary_states[i]]))

276.81081081081084
302.6756756756757
276.81081081081084
456.625


In [29]:
# MNIST
corrs, t_stats, p_val, convergence_epochs, binary_states = get_state_correlations(
    file_paths="/scratch/myh2014/modeling-training/results/mnist/",
    column='eval_accuracy',
    threshold=0.98
)
print(corrs)
print(p_val)

[0.0953948567951651, -0.0054531581093004685, 0.033178795860150774, nan, nan, nan, nan]
[0.558187098266139, 1.0266406977873728, 0.8389451808018655, nan, nan, nan, nan]


  c /= stddev[:, None]
  c /= stddev[None, :]


In [121]:
# MultiBERTs
# remember, need to flip sign
get_state_correlations(
    file_paths="/scratch/myh2014/modeling-training/results/multiberts_diag/",
    column='train_loss',
    threshold=1.7
)

0 -0.6123724356957947
1 nan
2 nan
3 0.6123724356957947
4 nan


  c /= stddev[:, None]
  c /= stddev[None, :]


([-0.6123724356957947, nan, nan, 0.6123724356957947, nan],
 [19, 19, 22, 19, 19],
 [[0, 1, 0, 1, 1],
  [1, 1, 1, 1, 1],
  [1, 1, 1, 1, 1],
  [1, 0, 1, 0, 0],
  [1, 1, 1, 1, 1]])

In [None]:
# TODO: move over the plotting code. Detour states should finish with a presentation of detours.