## Gaussian Process Regression

Trains the optimal GP model (as determined by the previous experiments) on:
- several resamples of the training data with different censoring rates (for cross-validation)
- several resamples of the training data with no censoring ("production run")

In [1]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pylab as plt
import torch

from synapse_utils import io

import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist

from pyro.infer import TraceMeanField_ELBO
from pyro.infer.util import torch_backward, torch_item

from sklearn.decomposition import PCA
from cuml import KMeans

import pickle

from collections import defaultdict

from sklearn.metrics import roc_curve, roc_auc_score

# assert pyro.__version__.startswith('1.7.0')
pyro.set_rng_seed(0)

## Train and test data

In [2]:
# --- paths (all derived from the repository root) ---
repo_root = '../..'
run_id = 'synapseclr__so3__seed_42__second_stage'
checkpoint_path = f'{repo_root}/output/checkpoint__{run_id}'
output_root = f'{checkpoint_path}/analysis/gp'

dataset_path = f'{repo_root}/data/MICrONS__L23__8_8_40__processed'
contamination_indices_path = f'{repo_root}/tables/meta_df_contamination_indices.npy'

# --- feature extraction (passed to io.load_features) ---
reload_epoch = 99
node_idx_list = [0, 1, 2, 3]
feature_hook = 'encoder.fc'
l2_normalize = False

# --- compute ---
device = torch.device('cuda')
dtype = torch.float32

# --- default experiment settings (the experiment cells below may override them) ---
perform_class_balancing = True
perform_pca = False
n_pca_features = 50
k_fold = 1
random_seed = 42
kernel_type = 'rbf'
z_jitter = 0.05
elbo_type = 'mean-field'

# initial kernel / likelihood parameters
init_gaussian_variance = 0.1
init_rbf_variance = 1.0
init_rbf_lengthscale = 0.5
init_linear_variance = 1.0
init_constant_variance = 1.0
init_laplace_variance = 1.0
init_laplace_lengthscale = 0.5

# logging / evaluation cadence, in optimizer steps
print_loss_every = 1000
eval_every = 1000

# optimizer
lr = 0.001
# +1 so that the final step index (20_000) also falls on the log / eval cadence
num_optim_steps = 20_000 + 1


def get_augmented_table(meta_ext_df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``meta_ext_df`` with a combined pre/post cell-type column.

    The added column ``pre_and_post_cell_types`` encodes each
    (``pre_cell_type``, ``post_cell_type``) pair as one category:
    (0, 0) -> 0, (0, 1) -> 1, (1, 0) -> 2, (1, 1) -> 3.
    Pairs not in this table are encoded as ``None``.
    The input frame is left untouched.
    """
    combined_code_of = {
        (0, 0): 0,
        (0, 1): 1,
        (1, 0): 2,
        (1, 1): 3,
    }

    cell_type_pairs = zip(meta_ext_df['pre_cell_type'].values,
                          meta_ext_df['post_cell_type'].values)
    combined_codes = np.asarray(
        [combined_code_of.get(pair) for pair in cell_type_pairs])

    augmented_df = meta_ext_df.copy()
    augmented_df['pre_and_post_cell_types'] = combined_codes
    return augmented_df

from typing import Tuple

def get_censored_table(
        meta_ext_df: pd.DataFrame,
        rng: np.random.RandomState,
        censored_fraction: float) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Randomly split ``meta_ext_df`` into kept rows and censored (held-out) rows.

    Args:
        meta_ext_df: table to split; it is not modified.
        rng: random state; exactly one ``rng.permutation(len(meta_ext_df))`` is drawn.
        censored_fraction: fraction of rows to censor, in ``[0, 1]``. The number of
            censored rows is ``ceil(len(meta_ext_df) * censored_fraction)``.

    Returns:
        ``(kept_df, censored_df)``, each a copy with a fresh ``RangeIndex``.

    Raises:
        ValueError: if ``censored_fraction`` is outside ``[0, 1]`` (or NaN). A
            negative fraction would otherwise silently censor almost everything
            via ``perm[:negative]``.
    """
    if not 0.0 <= censored_fraction <= 1.0:
        raise ValueError(
            f'censored_fraction must be in [0, 1], got {censored_fraction!r}')

    n_entries = len(meta_ext_df)
    # round before ceil so floating-point noise (e.g. 0.07 * 100 == 7.000000000000001)
    # does not censor one extra row
    n_censored_entries = int(np.ceil(round(n_entries * censored_fraction, 9)))
    perm = rng.permutation(n_entries)
    censored_indices = perm[:n_censored_entries]
    kept_indices = perm[n_censored_entries:]

    return (
        meta_ext_df.iloc[kept_indices].copy().reset_index(drop=True),
        meta_ext_df.iloc[censored_indices].copy().reset_index(drop=True)
    )

def generate_manifest(var_dict: dict) -> dict:
    """Snapshot the experiment settings found in ``var_dict`` into a manifest dict.

    ``var_dict`` is the notebook namespace (the cells call ``generate_manifest(locals())``
    at top level). Every required setting must be present. The optional settings
    are recorded only when they are defined, so older namespaces still work.

    Raises:
        KeyError: if a required setting is missing from ``var_dict``.
    """
    required_attributes = [
        'experiment_prefix',
        'experiment_desc',
        'experiment_output_root',
        'checkpoint_path',
        'reload_epoch',
        'feature_hook',
        'l2_normalize',
        'k_fold',
        'perform_class_balancing',
        'perform_pca',
        'n_pca_features',
        'z_jitter',
        'init_rbf_variance',
        'init_rbf_lengthscale',
        'init_gaussian_variance',
        'init_linear_variance',
        'init_constant_variance',
        'kernel_type',
        'elbo_type',
        'lr',
        'num_optim_steps',
        'trait_key_list',
        'trait_type_list',
        'trait_num_categories_list',
        'trait_control_list',
        'n_inducing_points_list',
        'censored_fraction',
        'random_seed'
    ]
    # Settings the training loop in this notebook also reads by name. Recording them
    # keeps each manifest self-contained instead of relying on whatever value the
    # most recently executed config cell left in the namespace.
    optional_attributes = [
        'perform_kmeans',
        'training_fraction',
        'impute_split_size',
        'init_laplace_variance',
        'init_laplace_lengthscale',
    ]
    manifest = {attribute: var_dict[attribute] for attribute in required_attributes}
    manifest.update(
        (attribute, var_dict[attribute])
        for attribute in optional_attributes
        if attribute in var_dict)
    return manifest

In [3]:
# start from an empty queue of experiment manifests; the experiment cells below append to it
experiment_manifest_list = list()

In [4]:
experiment_prefix = 'seventh_wave'
experiment_desc = 'synapse_simclr_consensus'

experiment_output_root = os.path.join(output_root, experiment_desc)
os.makedirs(experiment_output_root, exist_ok=True)

# model / training settings for this wave
perform_class_balancing = True
perform_pca = False
perform_kmeans = True
feature_hook = 'encoder.fc'
l2_normalize = False
z_jitter = 0.05
training_fraction = 1.0
n_inducing_points = 50  # not used for training: the training loop reads `n_inducing_points_list` per trait
random_seed = 42  # placeholder; overwritten by the seed loop below
kernel_type = 'rbf'
elbo_type = 'mean-field'
num_optim_steps = 20_000 + 1
impute_split_size = 10000
censored_fraction = 0.1  # placeholder; overwritten by the loop below


trait_key_list = [
    'pre_cell_type',
    'post_cell_type',
]

trait_type_list = [
    'categorical',
    'categorical',
]

trait_num_categories_list = [
    2,
    2,
]

trait_control_list = [
    None,
    None,
]

n_inducing_points_list = [
    100,
    100,
]

# presumably: total number of synapses / number of annotated synapses
n_total = 94874
n_annotations = 5623
labeled_fraction_list = [0.01, 0.02, 0.03, 0.04, 0.05]

for labeled_fraction in labeled_fraction_list:
    for random_seed in [46, 47, 48, 49, 50]:
        # censor annotations so that `labeled_fraction * n_total` of them remain for training
        censored_fraction = (n_annotations - labeled_fraction * n_total) / n_annotations
        assert 0.0 <= censored_fraction < 1.0, (
            f'labeled_fraction={labeled_fraction} yields invalid censored_fraction={censored_fraction}')
        print(f'Fraction of censored data: {censored_fraction:.3f}')
        manifest = generate_manifest(locals())
        experiment_manifest_list.append(manifest)

Fraction of censored data: 0.831
Fraction of censored data: 0.831
Fraction of censored data: 0.831
Fraction of censored data: 0.831
Fraction of censored data: 0.831
Fraction of censored data: 0.663
Fraction of censored data: 0.663
Fraction of censored data: 0.663
Fraction of censored data: 0.663
Fraction of censored data: 0.663
Fraction of censored data: 0.494
Fraction of censored data: 0.494
Fraction of censored data: 0.494
Fraction of censored data: 0.494
Fraction of censored data: 0.494
Fraction of censored data: 0.325
Fraction of censored data: 0.325
Fraction of censored data: 0.325
Fraction of censored data: 0.325
Fraction of censored data: 0.325
Fraction of censored data: 0.156
Fraction of censored data: 0.156
Fraction of censored data: 0.156
Fraction of censored data: 0.156
Fraction of censored data: 0.156


In [5]:
# Number of experiment manifests queued so far (shown as the cell output);
# the seventh wave adds 5 labeled fractions x 5 seeds = 25.
len(experiment_manifest_list)

25

In [6]:
experiment_prefix = 'eighth_wave'
experiment_desc = 'synapse_simclr_production'

experiment_output_root = os.path.join(output_root, experiment_desc)
os.makedirs(experiment_output_root, exist_ok=True)

# model / training settings for this wave
perform_class_balancing = True
perform_pca = False
perform_kmeans = True
feature_hook = 'encoder.fc'
l2_normalize = False
z_jitter = 0.05
training_fraction = 1.0
n_inducing_points = 50  # not used for training: the training loop reads `n_inducing_points_list` per trait
random_seed = 42  # placeholder; overwritten by the seed loop below
kernel_type = 'rbf'
elbo_type = 'mean-field'
num_optim_steps = 20_000 + 1
impute_split_size = 10000

# production run: no annotations are censored
censored_fraction = 0.0

# One row per trait: (key, type, number of categories, control column, number of inducing points).
# A single table keeps the per-trait settings aligned; the parallel lists the
# manifest expects are derived from it below.
trait_spec_table = [
    ('cleft_size_log1p_zscore',           'continuous',  None, None,            400),
    ('presyn_soma_dist_log1p_zscore',     'continuous',  None, None,            100),
    ('postsyn_soma_dist_log1p_zscore',    'continuous',  None, None,            200),
    ('mito_size_pre_vx_log1p_zscore_zi',  'continuous',  None, 'has_mito_pre',  300),
    # NOTE(review): 10 inducing points is far below every other trait -- confirm intentional
    ('mito_size_post_vx_log1p_zscore_zi', 'continuous',  None, 'has_mito_post',  10),
    ('pre_and_post_cell_types',           'categorical', 4,    None,            100),
    ('pre_cell_type',                     'categorical', 2,    None,            100),
    ('post_cell_type',                    'categorical', 2,    None,            100),
    ('has_mito_pre',                      'categorical', 2,    None,            200),
    ('has_mito_post',                     'categorical', 2,    None,            300),
]

trait_key_list = [spec[0] for spec in trait_spec_table]
trait_type_list = [spec[1] for spec in trait_spec_table]
trait_num_categories_list = [spec[2] for spec in trait_spec_table]
trait_control_list = [spec[3] for spec in trait_spec_table]
n_inducing_points_list = [spec[4] for spec in trait_spec_table]

# NOTE(review): seed 44 is skipped here (and in the tenth wave) -- confirm intentional
for random_seed in [40, 41, 42, 43, 45]:
    manifest = generate_manifest(locals())
    experiment_manifest_list.append(manifest)

In [7]:
# Number of experiment manifests queued so far (shown as the cell output);
# the eighth wave adds 5 production seeds on top of the previous 25, giving 30.
len(experiment_manifest_list)

30

In [8]:
experiment_prefix = 'tenth_wave'
experiment_desc = 'synapse_simclr_consensus'

experiment_output_root = os.path.join(output_root, experiment_desc)
os.makedirs(experiment_output_root, exist_ok=True)

# model / training settings for this wave
perform_class_balancing = True
perform_pca = False
perform_kmeans = True
feature_hook = 'encoder.fc'
l2_normalize = False
z_jitter = 0.05
training_fraction = 1.0
n_inducing_points = 50  # not used for training: the training loop reads `n_inducing_points_list` per trait
random_seed = 42  # placeholder; overwritten by the seed loop below
kernel_type = 'rbf'
elbo_type = 'mean-field'
num_optim_steps = 20_000 + 1
impute_split_size = 10000
censored_fraction = 0.1  # placeholder; overwritten by the loop below


trait_key_list = [
    'pre_cell_type',
    'post_cell_type',
]

trait_type_list = [
    'categorical',
    'categorical',
]

trait_num_categories_list = [
    2,
    2,
]

trait_control_list = [
    None,
    None,
]

n_inducing_points_list = [
    100,
    100,
]

# presumably: total number of synapses / number of annotated synapses
n_total = 94874
n_annotations = 5623
labeled_fraction_list = [0.01, 0.02, 0.03, 0.04, 0.05]

# NOTE(review): seed 44 is skipped -- confirm intentional
for labeled_fraction in labeled_fraction_list:
    for random_seed in [40, 41, 42, 43, 45]:
        # censor annotations so that `labeled_fraction * n_total` of them remain for training
        censored_fraction = (n_annotations - labeled_fraction * n_total) / n_annotations
        assert 0.0 <= censored_fraction < 1.0, (
            f'labeled_fraction={labeled_fraction} yields invalid censored_fraction={censored_fraction}')
        print(f'Fraction of censored data: {censored_fraction:.3f}')
        manifest = generate_manifest(locals())
        experiment_manifest_list.append(manifest)

Fraction of censored data: 0.831
Fraction of censored data: 0.831
Fraction of censored data: 0.831
Fraction of censored data: 0.831
Fraction of censored data: 0.831
Fraction of censored data: 0.663
Fraction of censored data: 0.663
Fraction of censored data: 0.663
Fraction of censored data: 0.663
Fraction of censored data: 0.663
Fraction of censored data: 0.494
Fraction of censored data: 0.494
Fraction of censored data: 0.494
Fraction of censored data: 0.494
Fraction of censored data: 0.494
Fraction of censored data: 0.325
Fraction of censored data: 0.325
Fraction of censored data: 0.325
Fraction of censored data: 0.325
Fraction of censored data: 0.325
Fraction of censored data: 0.156
Fraction of censored data: 0.156
Fraction of censored data: 0.156
Fraction of censored data: 0.156
Fraction of censored data: 0.156


In [9]:
# Number of experiment manifests queued so far (shown as the cell output);
# the tenth wave adds 5 labeled fractions x 5 seeds on top of the previous 30, giving 55.
len(experiment_manifest_list)

55

In [None]:
# For every queued manifest: censor / resample the annotations, train one sparse
# variational GP per trait, impute the trait for every synapse in `meta_df`, and
# write the imputed table plus the kept / censored annotation tables to disk.
for experiment_index, manifest in enumerate(experiment_manifest_list):

    # Expose this experiment's settings as notebook globals; the code below reads
    # them by name. Settings absent from the manifest keep their current value.
    globals().update(manifest)

    # `rng` drives censoring, train/test resampling and random inducing-point
    # selection. Reseeding pyro (python `random`, numpy and torch global RNGs) makes
    # the stochastic ELBO optimization reproducible per experiment, instead of
    # depending on how many experiments ran before this one.
    rng = np.random.RandomState(random_seed)
    pyro.set_rng_seed(random_seed)

    # basic checks: the per-trait settings are parallel lists indexed by trait
    n_traits = len(trait_key_list)
    assert len(trait_type_list) == n_traits
    assert len(trait_num_categories_list) == n_traits
    assert len(trait_control_list) == n_traits
    assert len(n_inducing_points_list) == n_traits

    # announce
    print(f'Starting experiment {experiment_index} ...')
    print(manifest)
    print()

    # load features; `features_nf` is indexed below by `meta_df` row position
    features_nf, meta_df, meta_ext_df = io.load_features(
        checkpoint_path,
        node_idx_list,
        reload_epoch,
        feature_hook=feature_hook,
        dataset_path=dataset_path,
        l2_normalize=l2_normalize,
        contamination_indices_path=contamination_indices_path)

    # add combined columns to the table (if necessary)
    meta_ext_df = get_augmented_table(meta_ext_df)

    # censor / keep: only the kept rows are used for training
    meta_ext_df, censored_meta_ext_df = get_censored_table(
        meta_ext_df, rng, censored_fraction)

    # synapse id -> row position in `meta_df` (and hence in `features_nf`)
    synapse_ids_to_meta_df_row_idx_map = {
        synapse_id: row_idx
        for row_idx, synapse_id in enumerate(meta_df['synapse_id'].values)}

    # pre-processing
    if perform_pca:
        features_nf = PCA(n_pca_features).fit_transform(features_nf)

    # CPU tensor of all features, reused for imputation of every trait
    all_features_t = torch.tensor(features_nf)

    # ---- build per-trait train / test tables ----
    train_meta_ext_df_dict = dict()
    test_meta_ext_df_dict = dict()

    for i in range(n_traits):

        trait_key = trait_key_list[i]
        trait_type = trait_type_list[i]
        trait_num_categories = trait_num_categories_list[i]
        trait_control = trait_control_list[i]

        if trait_type not in ('continuous', 'categorical'):
            raise ValueError(f'Unknown trait type {trait_type!r} for {trait_key}')

        # make a train dataframe from the annotations that survived censoring
        train_meta_ext_df = meta_ext_df.copy()

        # keep only rows whose control column equals 1
        # (presumably: only synapses that have the structure being measured)
        if trait_control is not None:
            train_meta_ext_df = train_meta_ext_df[train_meta_ext_df[trait_control] == 1]

        if trait_type == 'categorical' and perform_class_balancing:

            # positional row indices of each category within `train_meta_ext_df`
            per_category_indices = [
                np.nonzero(train_meta_ext_df[trait_key].values == category_index)[0]
                for category_index in range(trait_num_categories)]

            n_annotated = len(train_meta_ext_df)
            n_train = int(n_annotated * training_fraction)
            n_test = n_annotated - n_train
            # every category contributes the same number of (resampled) rows
            n_train_per_category = n_train // trait_num_categories
            n_test_per_category = n_test // trait_num_categories

            train_indices = []
            test_indices = []

            for category_index in range(trait_num_categories):
                category_rows = per_category_indices[category_index]

                # partition this category's annotations into disjoint train / test pools
                n_category = len(category_rows)
                n_category_train = int(n_category * training_fraction)
                perm = rng.permutation(n_category)
                category_train_pool = category_rows[perm[:n_category_train]]
                category_test_pool = category_rows[perm[n_category_train:]]

                if len(category_train_pool) == 0 and n_train_per_category > 0:
                    raise ValueError(
                        f'No training annotations for category {category_index} of {trait_key}')
                if len(category_test_pool) == 0 and n_test_per_category > 0:
                    raise ValueError(
                        f'No test annotations for category {category_index} of {trait_key}')

                # sample with replacement so each category contributes exactly
                # n_*_per_category rows (smaller categories get up-sampled)
                train_indices += rng.choice(
                    category_train_pool,
                    replace=True,
                    size=n_train_per_category).tolist()

                test_indices += rng.choice(
                    category_test_pool,
                    replace=True,
                    size=n_test_per_category).tolist()

        else:
            # continuous traits (and categorical traits without balancing):
            # a plain random train / test split
            n_annotated = len(train_meta_ext_df)
            n_train = int(n_annotated * training_fraction)
            perm = rng.permutation(n_annotated)
            train_indices = perm[:n_train]
            test_indices = perm[n_train:]

        rng.shuffle(train_indices)
        rng.shuffle(test_indices)

        train_meta_ext_df_dict[i] = train_meta_ext_df.iloc[train_indices].copy().reset_index(drop=True)
        test_meta_ext_df_dict[i] = train_meta_ext_df.iloc[test_indices].copy().reset_index(drop=True)

    # ---- train one GP per trait and impute it for every synapse ----
    y_pred_dict = dict()

    for trait_index in range(n_traits):

        # setup
        trait_key = trait_key_list[trait_index]
        trait_type = trait_type_list[trait_index]
        trait_num_categories = trait_num_categories_list[trait_index]
        trait_control = trait_control_list[trait_index]

        train_meta_ext_df = train_meta_ext_df_dict[trait_index]
        test_meta_ext_df = test_meta_ext_df_dict[trait_index]

        print(f'Running GP for {trait_key}, type = {trait_type}, control = {trait_control}')

        train_trait_values_n = torch.tensor(
            train_meta_ext_df[trait_key].values,
            device=device, dtype=dtype)

        test_trait_values_n = torch.tensor(
            test_meta_ext_df[trait_key].values,
            device=device, dtype=dtype)

        print(f'Number of training data points: {len(train_trait_values_n)}')
        print(f'Number of test data points: {len(test_trait_values_n)}')

        # select the corresponding representations
        train_indices = list(map(synapse_ids_to_meta_df_row_idx_map.get, train_meta_ext_df['synapse_id'].values))
        test_indices = list(map(synapse_ids_to_meta_df_row_idx_map.get, test_meta_ext_df['synapse_id'].values))
        train_z_nf = torch.tensor(
            features_nf[train_indices],
            device=device, dtype=dtype)
        test_z_nf = torch.tensor(
            features_nf[test_indices],
            device=device, dtype=dtype)

        # --- inducing inputs ---
        x_dim = features_nf.shape[-1]

        n_inducing_points = n_inducing_points_list[trait_index]
        print(f'Number of inducing points for {trait_key}: {n_inducing_points}')

        if perform_kmeans:
            # k-means centers over all synapse features (no labels are used)
            print('Performing k-means ...')
            Xu_init_kf = KMeans(n_clusters=n_inducing_points).fit(features_nf).cluster_centers_
            print('Done!')

        else:
            # a random subset of synapse representations + random jitter
            Xu_init_kf = torch.tensor(
                features_nf[rng.permutation(len(features_nf))[:n_inducing_points]],
                device=device, dtype=dtype)
            Xu_init_kf = Xu_init_kf + z_jitter * torch.randn_like(Xu_init_kf)
            Xu_init_kf = Xu_init_kf.detach().cpu().numpy()

        Xu = torch.tensor(Xu_init_kf, device=device, dtype=dtype)

        # covariates (X) are the representations, readout (y) is the trait
        X = train_z_nf
        y = train_trait_values_n

        # --- likelihood, kernel and model ---
        pyro.clear_param_store()

        if trait_type == 'continuous':
            likelihood = gp.likelihoods.Gaussian(
                variance=torch.tensor(init_gaussian_variance))
            latent_shape = None

        elif trait_type == 'categorical':
            likelihood = gp.likelihoods.MultiClass(num_classes=trait_num_categories)
            latent_shape = (trait_num_categories,)

        else:
            raise ValueError(f'Unknown trait type {trait_type!r}')

        if kernel_type == 'rbf':
            kernel = gp.kernels.RBF(
                input_dim=x_dim,
                variance=torch.tensor(init_rbf_variance),
                lengthscale=torch.tensor(init_rbf_lengthscale))

        elif kernel_type == 'linear':
            linear_kernel = gp.kernels.Linear(
                input_dim=x_dim,
                variance=torch.tensor(init_linear_variance))
            constant_kernel = gp.kernels.Constant(
                input_dim=x_dim,
                variance=torch.tensor(init_constant_variance))
            kernel = gp.kernels.Sum(linear_kernel, constant_kernel)

        elif kernel_type == 'laplace':
            kernel = gp.kernels.Exponential(
                input_dim=x_dim,
                variance=torch.tensor(init_laplace_variance),
                lengthscale=torch.tensor(init_laplace_lengthscale))

        else:
            raise ValueError(f'Unknown kernel type {kernel_type!r}')

        kernel = kernel.to(device)
        vsgp = gp.models.VariationalSparseGP(
            X, y, kernel,
            Xu=Xu,
            likelihood=likelihood,
            whiten=True,
            jitter=1e-4,
            latent_shape=latent_shape).to(device)

        optimizer = torch.optim.Adam(vsgp.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_optim_steps)

        if elbo_type == 'mean-field':
            loss_fn = pyro.infer.TraceMeanField_ELBO().differentiable_loss

        elif elbo_type == 'map':
            loss_fn = pyro.infer.Trace_ELBO().differentiable_loss

        else:
            raise ValueError(f'Unknown ELBO type {elbo_type!r}')

        def closure():
            optimizer.zero_grad()
            loss = loss_fn(vsgp.model, vsgp.guide)
            torch_backward(loss)
            return loss

        # --- optimization ---
        for i_iter in range(num_optim_steps):

            # optimizer step
            loss = optimizer.step(closure)

            # log
            if i_iter % print_loss_every == 0:
                print(f'iter: {i_iter}, lr: {scheduler.get_last_lr()[0]:.5f}, loss: {torch_item(loss)}')

            # scheduler step
            scheduler.step()

            # evaluate (tuple, not set, so the report order is deterministic)
            if i_iter % eval_every == 0:

                for eval_set in ('train', 'test'):

                    if eval_set == 'test':
                        X_test = test_z_nf
                        y_test = test_trait_values_n
                    else:
                        X_test = train_z_nf
                        y_test = train_trait_values_n

                    if len(X_test) == 0:
                        continue

                    with torch.no_grad():
                        y_test_pred_mean, y_test_pred_var = vsgp(X_test, full_cov=False)

                    if trait_type == 'continuous':
                        residual_variance = torch.var(y_test_pred_mean - y_test).item()
                        total_variance = torch.var(y_test).item()
                        explained_variance = 1. - residual_variance / total_variance
                        print(f'\t[{eval_set} eval] explained variance: {explained_variance:3f}')

                    elif trait_type == 'categorical':
                        # latent mean is (num_categories, n); softmax over categories
                        y_test_pred_prob = torch.softmax(y_test_pred_mean, dim=0).cpu().numpy()
                        y_test_pred_hard = y_test_pred_prob.argmax(axis=0)
                        y_test_hard = y_test.type(torch.int).cpu().numpy()

                        # confusion matrix (rows: actual, columns: predicted)
                        confusion_matrix = np.zeros(
                            (trait_num_categories, trait_num_categories), dtype=int)
                        np.add.at(confusion_matrix, (y_test_hard, y_test_pred_hard), 1)
                        print(f'\t[{eval_set} eval] confusion matrix (rows: actual, columns: predicted):')
                        print(confusion_matrix)

                        # one-vs-rest AUCROC per category
                        for i_category in range(trait_num_categories):
                            actual = (y_test_hard == i_category).astype(int)
                            if actual.min() == actual.max():
                                # AUCROC is undefined when only one class is present
                                print(f'\t[{eval_set} eval] category {i_category} AUCROC: undefined (single class)')
                                continue
                            auc = roc_auc_score(actual, y_test_pred_prob[i_category, :])
                            print(f'\t[{eval_set} eval] category {i_category} AUCROC: {auc:3f}')

                    else:
                        raise ValueError(f'Unknown trait type {trait_type!r}')

        # --- impute for every synapse, in chunks ---
        y_pred_mean_list = []
        y_pred_std_list = []

        print("Imputing ...")
        for split_features_nf in torch.split(all_features_t, impute_split_size):

            with torch.no_grad():
                y_pred_mean, y_pred_var = vsgp(
                    split_features_nf.to(device).type(dtype),
                    full_cov=False)
                y_pred_std = y_pred_var.sqrt()

                if trait_type == 'continuous':
                    y_pred_mean_list.append(y_pred_mean.cpu().numpy())
                    y_pred_std_list.append(y_pred_std.cpu().numpy())

                elif trait_type == 'categorical':
                    # transpose (num_categories, n) -> (n, num_categories)
                    y_pred_mean_list.append(y_pred_mean.cpu().numpy().T)
                    y_pred_std_list.append(y_pred_std.cpu().numpy().T)

                else:
                    raise ValueError(f'Unknown trait type {trait_type!r}')

        y_pred_mean = np.concatenate(y_pred_mean_list, axis=0)
        y_pred_std = np.concatenate(y_pred_std_list, axis=0)

        if trait_type == 'continuous':
            y_pred_dict[f'imputed__{trait_key}__mean'] = y_pred_mean
            y_pred_dict[f'imputed__{trait_key}__std'] = y_pred_std

        elif trait_type == 'categorical':
            y_pred_mean = torch.softmax(torch.tensor(y_pred_mean), dim=-1).cpu().numpy()
            for i_category in range(trait_num_categories):
                y_pred_dict[f'imputed__{trait_key}__class_{i_category}'] = y_pred_mean[:, i_category]

        else:
            raise ValueError(f'Unknown trait type {trait_type!r}')

        print('Done!')

    # ---- save ----
    imputed_meta_df = meta_df.copy()

    for k, v in y_pred_dict.items():
        imputed_meta_df[k] = v

    imputed_meta_df = imputed_meta_df.drop(
        ['n_cutout_sections',
         'filename',
         'post_synaptic_volume',
         'pre_synaptic_volume',
         'synaptic_cleft_volume'], axis=1)

    # NOTE: `n_inducing_points` in the file names is the value used for the LAST
    # trait of this experiment (kept unchanged so downstream readers still match).
    imputed_meta_df.to_csv(
        os.path.join(experiment_output_root, f'imputed_meta__{kernel_type}__{n_inducing_points}__c={censored_fraction:.3f}__s={random_seed}.csv'))
    meta_ext_df.to_csv(
        os.path.join(experiment_output_root, f'training_meta_ext__{kernel_type}__{n_inducing_points}__c={censored_fraction:.3f}__s={random_seed}.csv'))
    censored_meta_ext_df.to_csv(
        os.path.join(experiment_output_root, f'censored_meta_ext__{kernel_type}__{n_inducing_points}__c={censored_fraction:.3f}__s={random_seed}.csv'))