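## Gaussian Process Regression

This notebook runs sparse variational Gaussian process (GP) experiments that predict synapse traits (continuous and categorical) from learned synapse representations. It saves the per-experiment evaluation metrics and loss curves for further analysis.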

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pylab as plt
import torch

from synapse_utils import io

import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist

from pyro.infer import TraceMeanField_ELBO
from pyro.infer.util import torch_backward, torch_item

from sklearn.decomposition import PCA

import pickle

from collections import defaultdict

from sklearn.metrics import roc_curve, roc_auc_score
# from sklearn.cluster import KMeans
from cuml import KMeans

# assert pyro.__version__.startswith('1.7.0')
# (the commented-out check above presumably records the Pyro version this
# notebook was developed against — confirm before re-enabling)

# Seed Pyro's global RNG state once, when this cell runs.
pyro.set_rng_seed(0)

## Configuration: data, train/test splits, GP model, and traits

In [None]:
# --- data and checkpoint locations -------------------------------------------
repo_root = '../..'
checkpoint_path = '../../output/checkpoint__synapseclr__so3__second_stage'
output_root = '../../output/gp'
dataset_path = '../../data/MICrONS__L23__8_8_40__processed'

# row indices used by io.load_features to filter the metadata table
contamination_indices_path = os.path.join(
    checkpoint_path, 'indices', 'contamination_meta_df_row_indices.npy')

# --- which representation to load ---------------------------------------------
node_idx_list = [0, 1, 2, 3]
reload_epoch = 99
feature_hook = 'encoder.fc'
l2_normalize = False

# --- torch placement ------------------------------------------------------------
device = torch.device('cuda')
dtype = torch.float32

# --- train/test splitting and preprocessing -------------------------------------
training_fraction = 0.9
perform_class_balancing = True
perform_pca = True
n_pca_features = 50
k_fold = 3
random_seed = 42

# --- GP model -------------------------------------------------------------------
kernel_type = 'rbf'          # 'rbf', 'linear' (linear + constant) or 'laplace'
n_inducing_points = 1000
z_jitter = 0.1               # std of Gaussian noise added to the inducing points
elbo_type = 'mean-field'     # 'mean-field' (TraceMeanField_ELBO) or 'map' (Trace_ELBO)

# initial kernel / likelihood hyperparameters
init_gaussian_variance = 0.5
init_rbf_variance, init_rbf_lengthscale = 1.0, 0.5
init_linear_variance = 1.0
init_constant_variance = 1.0
init_laplace_variance, init_laplace_lengthscale = 1.0, 0.5

# --- optimization and logging ---------------------------------------------------
lr = 1e-3
# presumably 10k steps plus one so the final iteration (10_000) also lands on
# a multiple of print_loss_every / eval_every
num_optim_steps = 10_001
print_loss_every = 1000
eval_every = 1000

# --- traits to predict ----------------------------------------------------------
# One row per trait: (metadata column, trait type, number of categories for
# categorical traits, control column). When a control column is given, train and
# test rows are restricted to those whose control column equals 1. The four
# parallel lists used downstream are unzipped from this single table.
(trait_key_list,
 trait_type_list,
 trait_num_categories_list,
 trait_control_list) = map(list, zip(
    ('cleft_size_log1p_zscore',           'continuous',  None, None),
    ('presyn_soma_dist_log1p_zscore',     'continuous',  None, None),
    ('postsyn_soma_dist_log1p_zscore',    'continuous',  None, None),
    ('mito_size_pre_vx_log1p_zscore_zi',  'continuous',  None, 'has_mito_pre'),
    ('mito_size_post_vx_log1p_zscore_zi', 'continuous',  None, 'has_mito_post'),
    ('pre_and_post_cell_types',           'categorical', 4,    None),
    ('pre_cell_type',                     'categorical', 2,    None),
    ('post_cell_type',                    'categorical', 2,    None),
    ('has_mito_pre',                      'categorical', 2,    None),
    ('has_mito_post',                     'categorical', 2,    None),
))

def get_augmented_table(meta_ext_df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``meta_ext_df`` with a combined pre/post cell-type column.

    The added column ``pre_and_post_cell_types`` encodes each
    ``(pre_cell_type, post_cell_type)`` pair as a single label:
    (0, 0) -> 0, (0, 1) -> 1, (1, 0) -> 2, (1, 1) -> 3.
    Pairs not in this table are mapped to ``None``. The input frame is not modified.
    """
    combined_label_of = {
        (0, 0): 0,
        (0, 1): 1,
        (1, 0): 2,
        (1, 1): 3,
    }

    cell_type_pairs = zip(meta_ext_df['pre_cell_type'].values,
                          meta_ext_df['post_cell_type'].values)
    combined_labels = np.asarray(
        [combined_label_of.get(pair) for pair in cell_type_pairs])

    augmented_df = meta_ext_df.copy()
    augmented_df['pre_and_post_cell_types'] = combined_labels
    return augmented_df

def generate_manifest(var_dict: dict) -> dict:
    """Snapshot the experiment-defining settings from a namespace.

    Args:
        var_dict: mapping from variable name to value. In this notebook it is
            ``locals()`` of a module-level cell, i.e. the notebook namespace.
            It must contain every name listed in ``attributes`` below.

    Returns:
        A new dict holding only those settings. The experiment loop writes
        these values back into the notebook namespace before running the
        experiment and pickles the dict next to the results. Every setting
        that any experiment wave varies, or that the run depends on, must be
        listed here. Otherwise the run is neither restored correctly nor
        reproducible from the saved manifest.

    Raises:
        KeyError: if ``var_dict`` lacks one of the required settings.
    """
    attributes = [
        'experiment_prefix',
        'checkpoint_path',
        'reload_epoch',
        'feature_hook',
        'l2_normalize',
        'k_fold',
        'perform_class_balancing',
        'perform_pca',
        'n_pca_features',
        'n_inducing_points',
        'z_jitter',
        'init_rbf_variance',
        'init_rbf_lengthscale',
        'init_gaussian_variance',
        'init_linear_variance',
        'init_constant_variance',
        'kernel_type',
        'elbo_type',
        'lr',
        'num_optim_steps',
        'trait_key_list',
        'trait_type_list',
        'trait_num_categories_list',
        'trait_control_list',
        # Appended after the original keys so existing key order is unchanged.
        # node_idx_list is changed by some waves.
        'node_idx_list',
        # The remaining keys are needed to reproduce a run from its manifest.
        'dataset_path',
        'contamination_indices_path',
        'training_fraction',
        'random_seed',
        # Hyperparameters of the 'laplace' kernel; previously not recorded.
        'init_laplace_variance',
        'init_laplace_lengthscale',
    ]
    return {attribute: var_dict[attribute] for attribute in attributes}

In [3]:
# first wave: RBF kernel, no PCA. Every feature hook / normalization combination
# is run with 1000 inducing points and then with 50.
# NOTE: checkpoint_path, node_idx_list and reload_epoch are taken from whatever
# is currently in the notebook namespace (later wave cells overwrite them).
experiment_manifest_list = []

experiment_prefix = 'first_wave'
kernel_type = 'rbf'
perform_pca = False

for n_inducing_points in (1000, 50):
    for feature_hook, l2_normalize in (
            ('encoder.fc', False),
            ('projector.mlp.0', False),
            ('projector.mlp.3', False),
            ('projector.mlp.3', True)):
        # snapshot of the current namespace, see generate_manifest
        manifest = generate_manifest(locals())
        experiment_manifest_list.append(manifest)

In [11]:
# second wave: encoder features without PCA, sweeping the number of inducing
# points first for the laplace kernel and then for the linear kernel.
experiment_manifest_list = []

experiment_prefix = 'second_wave'
perform_pca = False
feature_hook, l2_normalize = 'encoder.fc', False

kernel_type = 'laplace'
for n_inducing_points in (10, 20, 50, 100, 200, 500):
    manifest = generate_manifest(locals())
    experiment_manifest_list.append(manifest)

kernel_type = 'linear'
for n_inducing_points in (10, 5):
    manifest = generate_manifest(locals())
    experiment_manifest_list.append(manifest)

In [17]:
# third wave: RBF kernel on PCA-reduced encoder features, sweeping the PCA
# dimensionality (outer loop) and the number of inducing points (inner loop).
experiment_manifest_list = []

experiment_prefix = 'third_wave'
kernel_type = 'rbf'
feature_hook, l2_normalize = 'encoder.fc', False
perform_pca = True

for n_pca_features in (50, 100):
    for n_inducing_points in (50, 500):
        manifest = generate_manifest(locals())
        experiment_manifest_list.append(manifest)

In [14]:
# fourth wave: RBF kernel on raw encoder features (no PCA), sweeping the number
# of inducing points. The list order determines each run's experiment index and
# therefore its output file name, so it is kept exactly as originally run.
experiment_manifest_list = []

experiment_prefix = 'fourth_wave'
kernel_type = 'rbf'
feature_hook, l2_normalize = 'encoder.fc', False
perform_pca = False

for n_inducing_points in (10, 20, 50, 100, 200, 5, 300, 400, 500, 600, 700, 800, 900, 1000):
    manifest = generate_manifest(locals())
    experiment_manifest_list.append(manifest)

In [51]:
# fifth wave: the inducing-point sweep of the RBF kernel, but on features from
# a different checkpoint (node 0, epoch 0).
# NOTE(review): this cell overwrites checkpoint_path, node_idx_list and
# reload_epoch in the notebook namespace. Re-running an earlier wave cell after
# this one will silently pick up these values. contamination_indices_path is
# not recomputed here and still points under the checkpoint path from the
# config cell. Confirm that is intended.
experiment_manifest_list = []

experiment_prefix = 'fifth_wave'
checkpoint_path = '/home/jupyter/dev/data/checkpoint__medicalnet__resnet18_23_dataset'
node_idx_list, reload_epoch = [0], 0
kernel_type = 'rbf'
feature_hook, l2_normalize = 'encoder.fc', False
perform_pca = False

for n_inducing_points in (5, 10, 20, 50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000):
    manifest = generate_manifest(locals())
    experiment_manifest_list.append(manifest)

In [68]:
# sixth wave: same sweep as the fifth wave, on features from yet another
# checkpoint directory (node 0, epoch 0).
# NOTE(review): like the fifth wave, this cell overwrites checkpoint_path,
# node_idx_list and reload_epoch in the notebook namespace, and leaves
# contamination_indices_path pointing under the config-cell checkpoint.
experiment_manifest_list = []

experiment_prefix = 'sixth_wave'
checkpoint_path = '/home/jupyter/dev/data/checkpoint__random'
node_idx_list, reload_epoch = [0], 0
kernel_type = 'rbf'
feature_hook, l2_normalize = 'encoder.fc', False
perform_pca = False

for n_inducing_points in (5, 10, 20, 50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000):
    manifest = generate_manifest(locals())
    experiment_manifest_list.append(manifest)

In [69]:
save_results = True
notebook_mode = True

# The experiment loop runs manifests in the half-open index range
# [start_experiment_index, end_experiment_index). Interactively that is every
# manifest. When exported and run as a script, both bounds come from the
# command line.
if notebook_mode:
    start_experiment_index = 0
    end_experiment_index = len(experiment_manifest_list)
else:
    start_experiment_index = int(sys.argv[1])
    end_experiment_index = int(sys.argv[2])

In [None]:
for experiment_index, manifest in list(
        enumerate(experiment_manifest_list))[start_experiment_index:end_experiment_index]:

    # Push this experiment's settings into the notebook namespace; everything
    # below reads them as ordinary module-level variables.
    globals().update(manifest)

    # Reseed Pyro's RNGs (per the Pyro docs: torch, numpy and random) for every
    # experiment. Without this, the torch draws used for the inducing-point
    # jitter depend on how many experiments ran earlier in the same process.
    # A full notebook pass and a command-line run of the same index would
    # then disagree.
    pyro.set_rng_seed(random_seed)

    # basic checks: the per-trait lists must be aligned
    n_traits = len(trait_key_list)
    assert len(trait_type_list) == n_traits
    assert len(trait_num_categories_list) == n_traits
    assert len(trait_control_list) == n_traits

    # announce
    print(f'Starting experiment {experiment_index} ...')
    print(manifest)
    print()

    # load features
    features_nf, meta_df, meta_ext_df = io.load_features(
        checkpoint_path,
        node_idx_list,
        reload_epoch,
        feature_hook=feature_hook,
        dataset_path=dataset_path,
        l2_normalize=l2_normalize,
        contamination_indices_path=contamination_indices_path)

    # scale by a single global standard deviation
    features_nf = features_nf / np.std(features_nf)

    if perform_pca:
        features_nf = PCA(n_pca_features).fit_transform(features_nf)

    x_dim = features_nf.shape[-1]

    # add combined columns to the table (if necessary)
    meta_ext_df = get_augmented_table(meta_ext_df)

    # --- train/test splits -------------------------------------------------
    # Each "fold" k is an independent random train/test split drawn from `rng`
    # (repeated random sub-sampling), not a partition into disjoint folds.
    # The order of rng calls below must stay fixed for reproducibility.
    rng = np.random.RandomState(random_seed)

    # synapse id -> row in meta_df; that row is used to index features_nf below
    synapse_ids_to_meta_df_row_idx_map = {
        synapse_id: row_idx
        for row_idx, synapse_id in enumerate(meta_df['synapse_id'].values)}

    train_meta_ext_df_dict = dict()
    test_meta_ext_df_dict = dict()

    for i in range(n_traits):

        trait_key = trait_key_list[i]
        trait_type = trait_type_list[i]
        trait_num_categories = trait_num_categories_list[i]

        if trait_type == 'categorical':
            per_category_indices = [
                np.nonzero(meta_ext_df[trait_key].values == category_index)[0]
                for category_index in range(trait_num_categories)]
        else:
            per_category_indices = None

        for k in range(k_fold):

            if trait_type == 'categorical' and perform_class_balancing:

                # Overall train/test budgets, split evenly across categories.
                n_annotated = len(meta_ext_df)
                n_train = int(n_annotated * training_fraction)
                n_test = n_annotated - n_train
                n_train_per_category = n_train // trait_num_categories
                n_test_per_category = n_test // trait_num_categories

                train_indices = []
                test_indices = []

                for category_index, category_indices in enumerate(per_category_indices):

                    # partition this category's rows into disjoint train/test pools
                    n_category = len(category_indices)
                    n_category_train = int(n_category * training_fraction)
                    assert n_category_train > 0, (
                        f'{trait_key}: no training rows for category {category_index}')
                    assert n_category - n_category_train > 0, (
                        f'{trait_key}: no test rows for category {category_index}')

                    perm = rng.permutation(n_category)
                    category_train_pool = category_indices[perm[:n_category_train]]
                    category_test_pool = category_indices[perm[n_category_train:]]

                    # resample with replacement so every category contributes
                    # the same number of rows
                    train_indices += rng.choice(
                        category_train_pool, replace=True, size=n_train_per_category).tolist()
                    test_indices += rng.choice(
                        category_test_pool, replace=True, size=n_test_per_category).tolist()

            elif trait_type in ('continuous', 'categorical'):

                # plain random split: continuous traits need no balancing, and
                # categorical traits reach here only when balancing is disabled
                n_annotated = len(meta_ext_df)
                n_train = int(n_annotated * training_fraction)
                perm = rng.permutation(n_annotated)
                train_indices = perm[:n_train]
                test_indices = perm[n_train:]

            else:
                raise ValueError(f'Unknown trait type: {trait_type!r}')

            rng.shuffle(train_indices)
            rng.shuffle(test_indices)

            train_meta_ext_df_dict[(i, k)] = meta_ext_df.iloc[train_indices].copy().reset_index(drop=True)
            test_meta_ext_df_dict[(i, k)] = meta_ext_df.iloc[test_indices].copy().reset_index(drop=True)

    # --- fit and evaluate one GP per (fold, trait) --------------------------
    eval_container_dict = dict()
    loss_container_dict = dict()

    for k in range(k_fold):
        for trait_index in range(n_traits):

            # setup
            trait_key = trait_key_list[trait_index]
            trait_type = trait_type_list[trait_index]
            trait_num_categories = trait_num_categories_list[trait_index]
            trait_control = trait_control_list[trait_index]

            train_meta_ext_df = train_meta_ext_df_dict[(trait_index, k)]
            test_meta_ext_df = test_meta_ext_df_dict[(trait_index, k)]

            assert trait_type in {'continuous', 'categorical'}

            print(f'Running GP for {trait_key}, type = {trait_type}, fold = {k}, control = {trait_control}')

            # restrict to rows where the control column is 1 (if a control is set)
            if trait_control is not None:
                train_meta_ext_df = train_meta_ext_df[train_meta_ext_df[trait_control] == 1]
                test_meta_ext_df = test_meta_ext_df[test_meta_ext_df[trait_control] == 1]
                assert len(train_meta_ext_df) > 0
                assert len(test_meta_ext_df) > 0

            train_trait_values_n = torch.tensor(
                train_meta_ext_df[trait_key].values,
                device=device, dtype=dtype)

            test_trait_values_n = torch.tensor(
                test_meta_ext_df[trait_key].values,
                device=device, dtype=dtype)

            print(f'Number of training data points: {len(train_trait_values_n)}')
            print(f'Number of test data points: {len(test_trait_values_n)}')

            # select the corresponding representations
            train_indices = list(map(synapse_ids_to_meta_df_row_idx_map.get, train_meta_ext_df['synapse_id'].values))
            test_indices = list(map(synapse_ids_to_meta_df_row_idx_map.get, test_meta_ext_df['synapse_id'].values))
            train_z_nf = torch.tensor(
                features_nf[train_indices],
                device=device, dtype=dtype)
            test_z_nf = torch.tensor(
                features_nf[test_indices],
                device=device, dtype=dtype)

            # inducing inputs: a random subset of all synapse representations
            # plus Gaussian jitter (torch RNG, reseeded per experiment above)
            Xu = torch.tensor(
                features_nf[rng.permutation(len(features_nf))[:n_inducing_points]],
                device=device, dtype=dtype)
            Xu = Xu + z_jitter * torch.randn_like(Xu)

            # covariates (X) are the representations, readout (y) is the trait
            X = train_z_nf
            y = train_trait_values_n

            # initialize the kernel, likelihood, and model
            pyro.clear_param_store()

            if trait_type == 'continuous':
                likelihood = gp.likelihoods.Gaussian(
                    variance=torch.tensor(init_gaussian_variance))
                latent_shape = None
            elif trait_type == 'categorical':
                likelihood = gp.likelihoods.MultiClass(num_classes=trait_num_categories)
                latent_shape = (trait_num_categories,)
            else:
                raise ValueError(f'Unknown trait type: {trait_type!r}')

            if kernel_type == 'rbf':
                kernel = gp.kernels.RBF(
                    input_dim=x_dim,
                    variance=torch.tensor(init_rbf_variance),
                    lengthscale=torch.tensor(init_rbf_lengthscale))
            elif kernel_type == 'linear':
                linear_kernel = gp.kernels.Linear(
                    input_dim=x_dim,
                    variance=torch.tensor(init_linear_variance))
                constant_kernel = gp.kernels.Constant(
                    input_dim=x_dim,
                    variance=torch.tensor(init_constant_variance))
                kernel = gp.kernels.Sum(linear_kernel, constant_kernel)
            elif kernel_type == 'laplace':
                kernel = gp.kernels.Exponential(
                    input_dim=x_dim,
                    variance=torch.tensor(init_laplace_variance),
                    lengthscale=torch.tensor(init_laplace_lengthscale))
            else:
                raise ValueError(f'Unknown kernel type: {kernel_type!r}')

            kernel = kernel.to(device)
            vsgp = gp.models.VariationalSparseGP(
                X, y, kernel,
                Xu=Xu,
                likelihood=likelihood,
                whiten=True,
                jitter=1e-3,
                latent_shape=latent_shape).to(device)

            optimizer = torch.optim.Adam(vsgp.parameters(), lr=lr)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_optim_steps)

            if elbo_type == 'mean-field':
                loss_fn = pyro.infer.TraceMeanField_ELBO().differentiable_loss
            elif elbo_type == 'map':
                loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
            else:
                raise ValueError(f'Unknown ELBO type: {elbo_type!r}')

            def closure():
                optimizer.zero_grad()
                loss = loss_fn(vsgp.model, vsgp.guide)
                torch_backward(loss)
                return loss

            loss_history = loss_container_dict[(trait_index, k)] = []
            eval_records = eval_container_dict[(trait_index, k)] = defaultdict(list)

            for i_iter in range(num_optim_steps):

                # optimizer step
                loss = optimizer.step(closure)
                loss_value = torch_item(loss)

                # log (lr is read before this iteration's scheduler step)
                if i_iter % print_loss_every == 0:
                    print(f'iter: {i_iter}, lr: {scheduler.get_last_lr()[0]:.5f}, loss: {loss_value}')

                loss_history.append((i_iter, loss_value))

                scheduler.step()

                if i_iter % eval_every != 0:
                    continue

                # Fixed order (train, then test) so logs are deterministic; a
                # set literal's iteration order can vary between processes.
                for eval_set, X_eval, y_eval in (
                        ('train', train_z_nf, train_trait_values_n),
                        ('test', test_z_nf, test_trait_values_n)):

                    with torch.no_grad():
                        y_pred_mean, _ = vsgp(X_eval, full_cov=False)

                    if trait_type == 'continuous':
                        residual_variance = torch.var(y_pred_mean - y_eval).item()
                        total_variance = torch.var(y_eval).item()
                        explained_variance = 1. - residual_variance / total_variance
                        eval_records[f'{eval_set}_explained_variance'].append((i_iter, explained_variance))
                        print(f'\t[{eval_set} eval] explained variance: {explained_variance:.3f}')

                    elif trait_type == 'categorical':
                        # category scores are along dim 0 (latent_shape=(trait_num_categories,))
                        y_pred_prob = torch.softmax(y_pred_mean, dim=0)
                        y_pred_soft = y_pred_prob.cpu().numpy()
                        y_pred_hard = y_pred_prob.argmax(dim=0).cpu().numpy()
                        y_eval_hard = y_eval.type(torch.int).cpu().numpy()

                        # confusion matrix: rows are actual, columns predicted categories
                        confusion_matrix = np.zeros((trait_num_categories, trait_num_categories))
                        for actual_category, pred_category in zip(y_eval_hard, y_pred_hard):
                            confusion_matrix[actual_category, pred_category] += 1
                        eval_records[f'{eval_set}_confusion_matrix'].append((i_iter, confusion_matrix))

                        # one-vs-rest ROC curve and AUCROC per category
                        for i_category in range(trait_num_categories):
                            scores = y_pred_soft[i_category, :]
                            actual = (y_eval_hard == i_category).astype(int)
                            fpr, tpr, thresholds = roc_curve(actual, scores)
                            auc = roc_auc_score(actual, scores)
                            eval_records[f'{eval_set}_{i_category}_roc_fpr'].append((i_iter, fpr))
                            eval_records[f'{eval_set}_{i_category}_roc_tpr'].append((i_iter, tpr))
                            eval_records[f'{eval_set}_{i_category}_roc_thresholds'].append((i_iter, thresholds))
                            eval_records[f'{eval_set}_{i_category}_roc_auc'].append((i_iter, auc))
                            print(f'\t[{eval_set} eval] category {i_category} AUCROC: {auc:.3f}')

                    else:
                        raise ValueError(f'Unknown trait type: {trait_type!r}')

    # save the results
    if save_results:

        os.makedirs(output_root, exist_ok=True)
        output_file_name = f'experiment__{experiment_prefix}__{experiment_index}.pkl'

        # Three objects are pickled back-to-back into one file; readers must
        # call pickle.load three times, in this order.
        with open(os.path.join(output_root, output_file_name), 'wb') as f:
            pickle.dump(manifest, f)
            pickle.dump(eval_container_dict, f)
            pickle.dump(loss_container_dict, f)

Starting experiment 0 ...
{'experiment_prefix': 'sixth_wave', 'checkpoint_path': '/home/jupyter/dev/data/checkpoint__random', 'reload_epoch': 0, 'feature_hook': 'encoder.fc', 'l2_normalize': False, 'k_fold': 3, 'perform_class_balancing': True, 'perform_pca': False, 'n_pca_features': 50, 'n_inducing_points': 5, 'z_jitter': 0.1, 'init_rbf_variance': 1.0, 'init_rbf_lengthscale': 0.5, 'init_gaussian_variance': 0.5, 'init_linear_variance': 1.0, 'init_constant_variance': 1.0, 'kernel_type': 'rbf', 'elbo_type': 'mean-field', 'lr': 0.001, 'num_optim_steps': 10001, 'trait_key_list': ['cleft_size_log1p_zscore', 'presyn_soma_dist_log1p_zscore', 'postsyn_soma_dist_log1p_zscore', 'mito_size_pre_vx_log1p_zscore_zi', 'mito_size_post_vx_log1p_zscore_zi', 'pre_and_post_cell_types', 'pre_cell_type', 'post_cell_type', 'has_mito_pre', 'has_mito_post'], 'trait_type_list': ['continuous', 'continuous', 'continuous', 'continuous', 'continuous', 'categorical', 'categorical', 'categorical', 'categorical', 'cate