# Cross-validation of embedding methods

Perform cross-validation of an embedding method to tune its hyperparameters and to evaluate the accuracy of clade classifications based on Euclidean distances in the embedded space.

## Analysis outline

1. Load genome sequences from a FASTA file
1. Load clade membership annotations for each genome from a separate metadata file
1. Calculate a pairwise distance matrix from the genome sequences
1. For each fold in a *k-fold* analysis
  1. Identify the training and validation sequences for the fold
  1. Subset the distance matrix to only the training sequences for the fold
  1. Apply the current embedding method (e.g., MDS, t-SNE, UMAP) to the distance matrix
  1. Calculate the pairwise Euclidean distance between sequences in the embedding
  1. Calculate and store the Pearson's correlation between genetic and Euclidean distances for all pairs in the embedding
  1. Calculate and store a distance threshold below which any pair of sequences is assigned to the same clade
  1. Apply the current embedding method to the subset of the distance matrix corresponding to the validation data for the current fold
  1. Calculate the pairwise Euclidean distance between sequences in the validation embedding
  1. Assign all pairs of sequences in the validation set to estimated "within" or "between" clade statuses based on their distances
  1. Calculate the confusion matrix from the estimated and observed clade identities
  1. Calculate and store accuracy, Matthews correlation coefficient, etc. from the confusion matrix
1. Plot the distribution of Pearson's correlations across all *k* folds
1. Plot the distribution of accuracies, etc. across all *k* folds
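
The scoring steps in the outline (Pearson's correlation, threshold-based assignment, confusion matrix, accuracy, and Matthews correlation coefficient) are not implemented in the cells below. The following is a minimal sketch of how they might look for a single fold, assuming one entry per pair of sequences. The toy arrays and the `threshold` value are hypothetical and are not results from this analysis.

```python
import numpy as np
from scipy.stats import pearsonr
from sklearn.metrics import accuracy_score, confusion_matrix, matthews_corrcoef

# Hypothetical toy data with one entry per pair of sequences.
genetic_distances = np.array([1, 2, 8, 9, 3, 10])
embedding_distances = np.array([0.5, 1.0, 4.0, 4.5, 2.5, 5.0])
observed_within_clade = np.array([1, 1, 0, 0, 0, 0])

# Pearson's correlation between genetic and Euclidean distances.
correlation, _ = pearsonr(genetic_distances, embedding_distances)

# Pairs closer than the (hypothetical) threshold are predicted to share a clade.
threshold = 3.0
predicted_within_clade = (embedding_distances < threshold).astype(int)

# Rows are observed classes and columns are predicted classes, ordered 0 then 1.
matrix = confusion_matrix(
    observed_within_clade,
    predicted_within_clade,
    labels=[0, 1]
)
accuracy = accuracy_score(observed_within_clade, predicted_within_clade)
mcc = matthews_corrcoef(observed_within_clade, predicted_within_clade)

print(f"Pearson's r: {correlation:.3f}")
print(f"Confusion matrix:\n{matrix}")
print(f"Accuracy: {accuracy:.3f}, MCC: {mcc:.3f}")
```

Most pairs of sequences come from different clades, so accuracy alone can look high for a classifier that rarely predicts "within"; reporting the Matthews correlation coefficient alongside it helps account for that imbalance.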

## Define inputs, outputs, and parameters

In [3]:
sequences_path = "../seasonal-flu-nextstrain/results/variable_sites.fasta"
clades_path = "../seasonal-flu-nextstrain/results/clades.json"

## Imports

In [31]:
from augur.utils import read_node_data
import Bio.SeqIO
from collections import OrderedDict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial.distance import squareform
from sklearn import svm
from sklearn.manifold import TSNE
from sklearn.model_selection import KFold

%matplotlib inline

## Load genome sequences

In [10]:
sequences_by_name = OrderedDict()

for sequence in Bio.SeqIO.parse(sequences_path, "fasta"):
    sequences_by_name[sequence.id] = str(sequence.seq)
    
sequence_names = list(sequences_by_name.keys())

## Load clade membership annotations

In [5]:
node_data = read_node_data(clades_path)

In [26]:
clade_annotations = pd.DataFrame([
    {"strain": strain, "clade": annotations["clade_membership"]}
    for strain, annotations in node_data["nodes"].items()
    if strain in sequences_by_name
])

In [27]:
len(clade_annotations)

707

In [28]:
clade_annotations.head()

| | strain | clade |
|---|---|---|
| 0 | A/Alaska/04/2019 | 3c3.A |
| 1 | A/Alaska/35/2019 | A1b/197R |
| 2 | A/Alaska/38/2019 | A1b/197R |
| 3 | A/Alaska/46/2018 | A2/re |
| 4 | A/Alaska/47/2018 | A2/re |


## Calculate distance matrix

In [None]:
# TODO: define `calculate_distance_matrix` function to make this cell run.
# The function should take an OrderedDict of sequences by name and return
# a distance matrix with rows and columns in the same order.
distance_matrix = calculate_distance_matrix(
    sequences_by_name
)
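
One possible definition of `calculate_distance_matrix`, which would need to run before the cell above, is sketched here. It assumes the sequences are aligned and of equal length (as the variable sites FASTA suggests) and uses the Hamming distance, that is, the number of mismatched positions. The choice of Hamming distance is an assumption; the notebook does not specify the metric.

```python
from collections import OrderedDict

import numpy as np
from scipy.spatial.distance import pdist, squareform


def calculate_distance_matrix(sequences_by_name):
    """Return a square matrix of pairwise Hamming distances (mismatch counts).

    Rows and columns follow the iteration order of `sequences_by_name`.
    Assumes all sequences are aligned to the same length.
    """
    sequences = list(sequences_by_name.values())
    if len({len(sequence) for sequence in sequences}) > 1:
        raise ValueError("All sequences must have the same length.")

    # Encode each character as an integer so SciPy can compare positions.
    encoded = np.array([[ord(base) for base in sequence] for sequence in sequences])

    # SciPy's "hamming" metric returns the proportion of mismatched positions,
    # so multiply by the sequence length to get the number of mismatches.
    condensed = pdist(encoded, metric="hamming") * encoded.shape[1]

    # Convert the condensed vector of distances into a square matrix.
    return squareform(condensed)


# Hypothetical toy sequences to show the expected shape of the output.
toy_sequences = OrderedDict([
    ("a", "ACGT"),
    ("b", "ACGA"),
    ("c", "TCGA"),
])
toy_distance_matrix = calculate_distance_matrix(toy_sequences)
print(toy_distance_matrix)
```

This sketch counts every mismatch, including gaps and ambiguous bases such as `N`; a real analysis may want to skip those positions.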

## Assign sequences to k-fold groups

For each of *k* different folds, partition sequences into training and validation sets.

In [11]:
sequence_names[:5]

['U26830.1',
 'A/Alaska/04/2019',
 'A/Georgia/32/2018',
 'A/Kentucky/35/2018',
 'A/Montana/35/2019']

In [18]:
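fold_factory = KFold(n_splits=5, shuffle=True)

# `split` returns a generator that can only be consumed once.
# Store the folds in a list so that later cells can iterate over them again.
folds = list(fold_factory.split(sequence_names))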
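
As a small illustration of how `KFold` partitions the data, the sketch below uses ten hypothetical strain names in place of `sequence_names`. The `random_state` argument is an assumption added here so that the shuffle is reproducible; the notebook itself does not set one.

```python
import numpy as np
from sklearn.model_selection import KFold

# Hypothetical stand-in for `sequence_names`.
names = [f"strain_{i}" for i in range(10)]

# `random_state` is an assumption added to make the shuffle reproducible.
fold_factory = KFold(n_splits=5, shuffle=True, random_state=0)

# `split` returns a generator, so store the folds in a list to reuse them.
folds = list(fold_factory.split(names))

# Count how many times each sequence appears in a validation set.
validation_counts = np.zeros(len(names), dtype=int)
for training_index, validation_index in folds:
    validation_counts[validation_index] += 1
    print(f"train: {training_index}, validate: {validation_index}")
```

Each sequence appears in exactly one validation set and in the training set of every other fold.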

## Analyze each fold

For each fold, use the training indices to subset the distance matrix to just those columns and rows that belong in the training data. Apply a given embedding method to the distance matrix subset, identify the classification threshold for clade membership, and validate that threshold on the subset of the distance matrix corresponding to the validation indices.
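
Subsetting both rows and columns of a NumPy matrix with index arrays needs some care. The toy example below (with a hypothetical 4x4 distance matrix) shows that passing the same index array twice selects only diagonal elements, whereas `np.ix_` selects the full square submatrix.

```python
import numpy as np

# Hypothetical symmetric distance matrix for four sequences.
distance_matrix = np.array([
    [0, 1, 2, 3],
    [1, 0, 4, 5],
    [2, 4, 0, 6],
    [3, 5, 6, 0],
])
training_index = np.array([0, 2, 3])

# Pairing the index arrays elementwise picks entries (0, 0), (2, 2), and (3, 3).
diagonal_only = distance_matrix[training_index, training_index]

# `np.ix_` builds an open mesh, keeping every row/column combination.
training_distance_matrix = distance_matrix[np.ix_(training_index, training_index)]

print(diagonal_only.shape)
print(training_distance_matrix)
```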

In [20]:
for k, (train_index, validate_index) in enumerate(folds):
    print(f"fold {k}")
    print(f"Training index: {train_index}")
    print(f"Validation index: {validate_index}")
    print()

fold 0
Training index: [  0   1   2   3   6   7   8   9  11  12  13  14  15  17  18  19  20  21
  22  23  25  26  27  29  30  31  32  33  34  35  37  38  39  40  41  42
  43  44  45  47  48  49  50  51  52  53  54  55  56  57  58  59  60  62
  63  64  65  66  67  69  70  71  73  74  75  76  77  78  79  80  81  83
  84  86  87  88  89  91  92  95  96  97  98  99 100 101 102 105 107 110
 111 112 113 114 115 116 117 118 119 120 123 126 127 128 129 130 131 132
 134 135 136 137 139 140 142 144 145 146 148 149 150 152 153 154 156 157
 158 159 160 161 162 164 165 167 169 170 171 172 173 175 176 177 178 179
 180 183 184 185 186 187 188 190 193 195 196 197 199 200 201 202 203 204
 205 207 208 211 212 213 214 215 216 217 218 220 222 223 224 225 226 227
 228 229 230 232 233 234 235 236 237 238 240 241 242 243 244 245 246 248
 249 250 251 252 253 255 256 257 258 259 260 261 262 264 265 266 268 269
 270 271 272 273 274 275 276 277 278 279 280 281 282 283 285 286 287 288
 290 291 292 294 295 296 297

In [34]:
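# TODO: In practice we will want to iterate over all combinations of parameters
# for all embedding methods and also support PCA (which does not use a distance matrix as input).
embedding_class = TSNE
embedding_parameters = {
    # Treat the input as a precomputed distance matrix instead of a feature matrix.
    "metric": "precomputed",
    # scikit-learn's t-SNE cannot use PCA initialization with precomputed distances.
    "init": "random",
    "perplexity": 30,
    "learning_rate": 500
}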
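
One way to enumerate the parameter combinations mentioned in the TODO above is scikit-learn's `ParameterGrid`. This is only a sketch: the candidate values below are hypothetical and are not tuned for this dataset.

```python
from sklearn.model_selection import ParameterGrid

# Hypothetical candidate values for t-SNE.
parameter_grid = ParameterGrid({
    "perplexity": [15, 30, 100],
    "learning_rate": [100, 500],
})

# Each item is a dictionary that could be passed as `embedding_parameters`.
for candidate_parameters in parameter_grid:
    print(candidate_parameters)
```

Each dictionary could then be run through the same cross-validation loop, storing the accuracies per parameter combination.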

In [None]:
# TODO: consider adding support for an exhaustive grid search of parameters 
# for different embeddings and the classifier:
# https://scikit-learn.org/stable/modules/grid_search.html#exhaustive-grid-search
accuracies = []
for k, (training_index, validation_index) in enumerate(folds):    
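    # Subset distance matrix to training indices.
    # Use `np.ix_` to select the full square submatrix of rows and columns;
    # indexing with the same array twice would return only diagonal elements.
    training_distance_matrix = distance_matrix[
        np.ix_(training_index, training_index)
    ]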
    
    # Embed training distance matrix.
    embedder = embedding_class(**embedding_parameters)
    training_embedding = embedder.fit_transform(training_distance_matrix)
    
    # Calculate Euclidean distance between pairs of samples in the embedding.
    # The output should be a data frame of distances between pairs.
    training_embedding_distances = calculate_euclidean_distance_for_embedding(training_embedding)
    
    # Assign a binary class to each pair of samples based on their clade memberships.
    # Samples from different clades are assigned 0, samples from the same clade are assigned 1.
    # This vector of binary values will be the target to fit a classifier to.
    # These pairs should be in the same order as the embedding distances above.
    training_clade_status_for_pairs = assign_clade_status_to_pairs(
        training_embedding_distances,
        clade_annotations
    )
    
    # Use a support vector machine classifier to identify an optimal threshold
    # to distinguish between within and between class pairs.
    # See also: https://scikit-learn.org/stable/modules/svm.html#svm-classification
    classifier = svm.LinearSVC()
    classifier.fit(training_embedding_distances, training_clade_status_for_pairs)
    
    # Subset distance matrix to validation indices.
    # Use `np.ix_` to select the full square submatrix of rows and columns.
    validation_distance_matrix = distance_matrix[
        np.ix_(validation_index, validation_index)
    ]
    
    # Embed validation distance matrix.
    validation_embedding = embedder.fit_transform(validation_distance_matrix)
    
    # Calculate Euclidean distance between pairs of samples in the embedding.
    # The output should be a data frame of distances between pairs.
    validation_embedding_distances = calculate_euclidean_distance_for_embedding(validation_embedding)
    
    # Assign a binary class to each pair of samples based on their clade memberships.
    # Samples from different clades are assigned 0, samples from the same clade are assigned 1.
    # This vector of binary values holds the true labels used to score the classifier.
    # These pairs should be in the same order as the embedding distances above.
    validation_clade_status_for_pairs = assign_clade_status_to_pairs(
        validation_embedding_distances,
        clade_annotations
    )

    # Predict and score clade status from embedding distances and the trained classifier.
    # The first argument is the set of samples to predict labels for. The second argument
    # is the list of true labels. The return value is the mean accuracy of the predicted
    # labels.
    # https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html#sklearn.svm.LinearSVC.score
    accuracy = classifier.score(
        validation_embedding_distances,
        validation_clade_status_for_pairs
    )
    
    accuracies.append(accuracy)
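
The loop above relies on two helper functions that are not defined in this notebook. The sketch below shows one way they might be written. It assumes that the helper for embedding distances also receives the strain names in the same order as the rows of the embedding, which differs from the single-argument call used in the loop; the `strains` argument and the column names `strain_a`, `strain_b`, and `distance` are hypothetical.

```python
from itertools import combinations

import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist


def calculate_euclidean_distance_for_embedding(embedding, strains):
    """Return one row per unordered pair of strains with their Euclidean distance.

    `strains` (an assumption of this sketch) names the rows of `embedding`.
    """
    # `pdist` and `combinations` both enumerate pairs (i, j) with i < j
    # in the same row-major order, so names and distances stay aligned.
    pairs = list(combinations(strains, 2))
    distances = pdist(embedding, metric="euclidean")
    return pd.DataFrame({
        "strain_a": [strain_a for strain_a, _ in pairs],
        "strain_b": [strain_b for _, strain_b in pairs],
        "distance": distances,
    })


def assign_clade_status_to_pairs(pair_distances, clade_annotations):
    """Return 1 for pairs from the same clade and 0 otherwise, in row order."""
    clade_by_strain = clade_annotations.set_index("strain")["clade"]
    clades_a = pair_distances["strain_a"].map(clade_by_strain).to_numpy()
    clades_b = pair_distances["strain_b"].map(clade_by_strain).to_numpy()
    return (clades_a == clades_b).astype(int)


# Hypothetical toy embedding of three strains in two dimensions.
toy_embedding = np.array([[0.0, 0.0], [3.0, 4.0], [0.0, 1.0]])
toy_strains = ["A", "B", "C"]
toy_clades = pd.DataFrame({
    "strain": ["A", "B", "C"],
    "clade": ["x", "y", "x"],
})

toy_pair_distances = calculate_euclidean_distance_for_embedding(toy_embedding, toy_strains)
toy_status = assign_clade_status_to_pairs(toy_pair_distances, toy_clades)
print(toy_pair_distances.assign(same_clade=toy_status))
```

With helpers shaped like these, the classifier would need a two-dimensional numeric input, for example `toy_pair_distances[["distance"]]`, because scikit-learn estimators expect a samples-by-features matrix and cannot use the strain name columns as features.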

## Plot and summarize accuracies

In [None]:
# Print the mean accuracy and stddev
print(f"Accuracy: {np.mean(accuracies)} +/- {np.std(accuracies)}")

In [33]:
np.linspace(0, 1, 20)

array([0.        , 0.05263158, 0.10526316, 0.15789474, 0.21052632,
       0.26315789, 0.31578947, 0.36842105, 0.42105263, 0.47368421,
       0.52631579, 0.57894737, 0.63157895, 0.68421053, 0.73684211,
       0.78947368, 0.84210526, 0.89473684, 0.94736842, 1.        ])

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
bins = np.arange(0, 1.01, 0.1)

ax.hist(accuracies, bins=bins)

ax.set_xlabel("Accuracy of classifier")
ax.set_ylabel("Number of cross-validation folds")