# Evomics Machine Learning Workshop
The goal of today's workshop is to give you a flavor of what you can do with machine learning, introduce a few very useful algorithms, and eventually apply them to something we care about: evolutionary genetics.

Presumably you already understand the context here: our basic goal is to use existing data (a so-called training set, more on this in a moment) to make predictions about new data that we encounter in the world. Traditionally the way to make such predictions was to rely on what we call a *generative* model, i.e. a probabilistic model that describes the process from which our observed data was generated. With such a generative model in hand, we could then take new data and learn things like the values of the model parameters that produced the observations. But what if we are unsure of the model? Can we still make useful predictions? The answer is yes, and one very popular way is through machine learning.

In this workshop we will focus on a branch of machine learning (ML) called supervised ML. Supervised ML starts with a *labeled* or known dataset that we will call the training set. This training set allows us to teach an algorithm how independent variables (call them $X$) map to dependent variables ($y$). As such we will focus on the *conditional probability* $p(y | X)$, and the models we will use are said to be *discriminative* models. The key here is that the algorithms we train with our training set focus directly on the mapping $X \rightarrow y$ rather than on the structure of the model per se. This has proven extremely useful in practice, and high prediction accuracies can be achieved for very complex problems, including problems that are often too complex to write down as a full-blown probabilistic model. So let's get cracking.
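
To make the idea of modelling $p(y | X)$ directly a bit more concrete, here is a minimal sketch on made-up data. It uses logistic regression, which is my choice for illustration and not a method used later in this workshop; the toy dataset and all variable names are assumptions.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Toy labeled training set (hypothetical): one feature, two classes
rng = np.random.default_rng(0)
X = np.concatenate([rng.normal(-2, 1, 50), rng.normal(2, 1, 50)]).reshape(-1, 1)
y = np.array([0] * 50 + [1] * 50)

# A discriminative model learns the mapping X -> y from labeled examples
clf = LogisticRegression().fit(X, y)

# predict_proba returns an estimate of p(y | X), one column per class
new_X = np.array([[-3.0], [0.0], [3.0]])
probs = clf.predict_proba(new_X)
print(probs.round(3))
```

Each row of `probs` sums to one, and nowhere did we have to describe how the values of `X` themselves were generated.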


## Anderson's iris data 
For a first example let's use Anderson's classic iris dataset. This is a classic dataset because R. A. Fisher used these data in his 1936 paper entitled *The use of multiple measurements in taxonomic problems*. In that paper Fisher introduced a method called linear discriminant analysis (LDA) that is considered by some to be the first machine learning method. This dataset consists of 4 morphological measurements from 150 flowers belonging to three species, *Iris setosa*, *Iris virginica*, and *Iris versicolor*.
Let's take a quick look at the plants 

<div style="margin: 20px 0;">
    <img src=imgs/Iris_setosa.jpg style="margin-right: 20px;" align="left" />
    <img src=imgs/Iris_virginica.jpg style="margin-right: 20px;" align="left" />
    <img src=imgs/Iris_versicolor.jpg align="left" />
</div>
<div style="clear: both; margin-bottom: 20px;"></div>
    
<div style="text-align: center; font-style: italic; margin-bottom: 20px;">
From left to right: I. setosa (with its distinctive narrow petals), I. virginica (with larger, broader petals), and I. versicolor (with intermediate-sized petals).
</div>

These three iris species look quite similar to the untrained eye. Let's examine the data by first looking at a summary, then visualizing a subset. We'll use the `sklearn` package, which conveniently includes this classic iris dataset, for our analysis.


Now let's import our data and get moving.

In [None]:
import sklearn
from sklearn import datasets
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.model_selection import train_test_split
import h5py

iris = datasets.load_iris()
df_iris = pd.DataFrame(iris.data,columns=iris.feature_names)
df_iris['species'] = iris.target_names[iris.target]

# Take a look at df_iris
df_iris.head()

A really useful thing to do is to create a pairplot to examine everything. We can do this with the `sns.pairplot` function.


In [None]:
sns.pairplot(df_iris, hue='species')

So in comparing any two *features*, for instance sepal length and sepal width, the species don't look perfectly separable (i.e. we can't draw firm dividing lines between groups). That's ok, we're gonna solve this problem with ML!

## Classifying irises using support vector machines

The first ML algorithm we will look at is the Support Vector Machine (SVM). SVMs were first developed by Vladimir Vapnik and Alexey Chervonenkis in 1963, but only became really popular in the mid 1990s after a few modifications to the original method were made and computing power had advanced sufficiently. The name of the game for SVMs is to draw an optimally separating hyperplane between classes identified in training data. Such a hyperplane linearly distinguishes among groups in the training set, so that predictions can be made for new, unlabelled data. Multiple embellishments on the vanilla SVM have been made over the years which allow nonlinear discrimination and tolerate potentially mislabelled training examples (i.e. soft margins). We will use the iris dataset to try to classify individual datapoints into their respective species.
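
As a small aside, the hyperplane idea can be sketched on synthetic data. The example below is an illustration under my own assumptions (two made-up clusters from `make_blobs`, a linear kernel, `C=1.0`); it shows that for a linear SVM the decision function is just $w \cdot x + b$, and that only a handful of training points (the support vectors) pin the hyperplane down.

```python
import numpy as np
from sklearn import svm
from sklearn.datasets import make_blobs

# Two synthetic clusters (illustrative data, not the iris set)
X, y = make_blobs(n_samples=40, centers=2, cluster_std=0.6, random_state=0)

clf = svm.SVC(kernel="linear", C=1.0).fit(X, y)

# For a linear kernel the separating hyperplane is w . x + b = 0
w, b = clf.coef_[0], clf.intercept_[0]
print("w =", w, "b =", b)

# Support vectors are the training points that define the margin
print("number of support vectors:", len(clf.support_vectors_))

# The sign of w . x + b says which side of the hyperplane a point falls on
scores = clf.decision_function(X)
manual = X @ w + b
print("decision_function matches w.x + b:", np.allclose(scores, manual))
```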

In this case we will start by using the complete iris dataset as our training set. First we will fit a linear SVM to it and then visualize the decision surface. We will do this at first with only the sepal features (sepal length and width) and then will include the petal features later.

First let me import some helpful stuff

In [16]:
from sklearn import svm
from sklearn.datasets import make_blobs

# bring in some helper functions
def make_meshgrid(x, y, h=.02):
    """Create a mesh of points to plot in

    Parameters
    ----------
    x: data to base x-axis meshgrid on
    y: data to base y-axis meshgrid on
    h: stepsize for meshgrid, optional

    Returns
    -------
    xx, yy : ndarray
    """
    x_min, x_max = x.min() - 1, x.max() + 1
    y_min, y_max = y.min() - 1, y.max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    return xx, yy


def plot_contours(ax, clf, xx, yy, **params):
    """Plot the decision boundaries for a classifier.

    Parameters
    ----------
    ax: matplotlib axes object
    clf: a classifier
    xx: meshgrid ndarray
    yy: meshgrid ndarray
    params: dictionary of params to pass to contourf, optional
    """
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    out = ax.contourf(xx, yy, Z, **params)
    return out

Now let's work with the iris data

In [None]:
# Take the first two features
X = iris.data[:, :2]
y = iris.target # target here is 0, 1, 2



#Support Vector Machine
C = 100.0  # SVM regularization parameter
clf = svm.SVC(kernel='linear', C=C)
         
model = clf.fit(X, y)


X0, X1 = X[:, 0], X[:, 1]
xx, yy = make_meshgrid(X0, X1)

ax = plt.gca()
plot_contours(ax, clf, xx, yy,
                  cmap=plt.cm.coolwarm, alpha=0.8)

ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')

ax.set_xlim(xx.min(), xx.max())
ax.set_ylim(yy.min(), yy.max())
ax.set_xlabel('Sepal length')
ax.set_ylabel('Sepal width')
ax.set_xticks(())
ax.set_yticks(())
ax.set_title("linear SVM")

So above we have plotted the same data as before, but now having fit a linear SVM to the training set. The shaded regions show which species the SVM predicts in each part of feature space, and the straight boundaries between regions are the separating hyperplanes among classes.

So we can see that while setosa is easy to classify, versicolor and virginica have a bunch of overlap in the SVM we just trained. Before we move on let's quantify the accuracy of our SVM by making it do prediction. For the sake of brevity we will just do prediction from the training set we already used.

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

preds = clf.predict(X)
# normalize='true' reports each row as the fraction of that true class
cm = confusion_matrix(y, preds, normalize='true')
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=df_iris.species.unique())
disp.plot()

The table we just output is often called the *confusion matrix* in the ML world. On the rows we have the true class of the individual data points, and on the columns the predicted class. We can see that 98% of setosa examples are correctly classified, but only a much smaller percentage of versicolor and virginica examples are correctly classified. This fits with what we saw on our decision surfaces.
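
If the layout of a confusion matrix is new to you, a tiny hand-checkable example may help. The labels below are made up for illustration; the point is that in `sklearn` each row is a true class, each column is a predicted class, and the per-class accuracy is the diagonal divided by the row sums.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical true and predicted labels for 10 examples in 3 classes
y_true = [0, 0, 0, 1, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 0, 1, 2, 1, 2, 2, 1, 2]

cm = confusion_matrix(y_true, y_pred)
print(cm)

# Fraction of each true class that was predicted correctly
per_class = cm.diagonal() / cm.sum(axis=1)
print(per_class)
```

Passing `normalize='true'` to `confusion_matrix` does this row-wise division for you, so the diagonal can be read directly as per-class accuracy.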

### Moar features! Moar kernels!

To improve our accuracy a bit let's first use all of the features of the iris dataset, both the sepal measurements and the petal measurements.

In [None]:
# get all the features
X = iris.data
y = iris.target # target here is 0, 1, 2



#Support Vector Machine
C = 100.0  # SVM regularization parameter
clf = svm.SVC(kernel='linear', C=C)
model = clf.fit(X, y)

preds = clf.predict(X)
cm = confusion_matrix(y, preds, normalize='true')
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=df_iris.species.unique())
disp.plot()


So including the petal measurements improved things quite a bit. We are now seeing classification accuracies of 96% and 98% for versicolor and virginica respectively. 

Can we do even better? Well there isn't much room for improvement but a simple thing we can try is moving from a linear hyperplane for the SVM to a nonlinear decision surface. In the context of SVMs this can be done by changing what is called the *kernel*. Without getting into the maths behind this, the kernel creates an implicit mapping of the input features to a different coordinate space, one that may allow easier separation of the classes in the data. The most popular choice for a nonlinear kernel in the context of SVMs is the radial basis function kernel. Let's try it out for our classification problem.

In [None]:
#Support Vector Machine
clf = svm.SVC(kernel='rbf', C=1, gamma=10) # this version takes two parameters
model = clf.fit(X, y)

preds = clf.predict(X)
cm = confusion_matrix(y, preds, normalize='true')
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=df_iris.species.unique())
disp.plot()
 

So we've gone from a 2% misclassification rate to 0% with the change in kernel. Not bad. At this point I'd encourage you to play around with the other kernels that the `sklearn` SVM implementation has, for instance the sigmoid kernel. Does it do better or worse?
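
For the curious, the radial basis function kernel itself is short enough to write out: $K(x, z) = \exp(-\gamma \lVert x - z \rVert^2)$. The sketch below computes it by hand for two made-up flowers and compares it with `sklearn`'s `rbf_kernel` helper; the two points and the value of `gamma` are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.metrics.pairwise import rbf_kernel

# Two hypothetical flowers described by (sepal length, sepal width)
x = np.array([[5.1, 3.5]])
z = np.array([[6.2, 2.9]])
gamma = 10.0

# RBF kernel by hand: similarity decays with squared distance
manual = np.exp(-gamma * np.sum((x - z) ** 2))

# The same quantity from sklearn
library = rbf_kernel(x, z, gamma=gamma)[0, 0]

print(manual, library)
```

The kernel is 1 for identical points and falls toward 0 as points move apart; a larger `gamma` makes that fall-off steeper.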

Let's quickly look at those decision surface visualizations that we created earlier. Note at the outset that we created the earlier ones based on an SVM that we trained with two features: sepal length and sepal width. Let's step back to that two feature input but visualize what a decision surface looks like using the radial basis function kernel. I'm going to turn up a parameter called gamma, which basically controls how wiggly our decision surface will get.

In [None]:

X = iris.data[:, :2]
y = iris.target # target here is 0, 1, 2

#Support Vector Machine
clf = svm.SVC(kernel='rbf', C=1, gamma=10) # this version takes two parameters
model = clf.fit(X, y)

X0, X1 = X[:, 0], X[:, 1]
xx, yy = make_meshgrid(X0, X1)

ax = plt.gca()
plot_contours(ax, clf, xx, yy, cmap=plt.cm.coolwarm, alpha=0.8)



ax.scatter(X0, X1, c=y, cmap=plt.cm.coolwarm, s=20, edgecolors='k')

ax.set_xlim(xx.min(), xx.max())
ax.set_ylim(yy.min(), yy.max())
ax.set_xlabel('Sepal length')
ax.set_ylabel('Sepal width')
ax.set_xticks(())
ax.set_yticks(())
ax.set_title("radial basis SVM")

So we can see now that our decision surface is altered in comparison to earlier. In particular with the RBF kernel and high gamma the SVM is doing a bit of craziness in trying to define the decision surface. 
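
To get a feel for what gamma does without drawing every surface, one option is to loop over a few values and record how well the SVM fits the data it was trained on. This is a rough sketch; the list of gamma values is an arbitrary choice of mine.

```python
from sklearn import datasets, svm

iris = datasets.load_iris()
X = iris.data[:, :2]  # sepal length and sepal width only
y = iris.target

results = {}
for gamma in [0.1, 1, 10, 100]:
    clf = svm.SVC(kernel="rbf", C=1, gamma=gamma).fit(X, y)
    # accuracy on the training data itself, and how many support vectors were needed
    results[gamma] = (clf.score(X, y), int(clf.n_support_.sum()))
    print(f"gamma={gamma}: training accuracy={results[gamma][0]:.2f}, "
          f"support vectors={results[gamma][1]}")
```

You would typically expect the training accuracy to creep upward as gamma grows, because the surface is free to bend around individual points. Whether that is actually a good thing is a separate question.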

### Lies and damn lies

So far, and mostly out of laziness, we've been using the entire Iris dataset for training AND testing. This is bad ML practice, because by testing our performance on the training set itself we are getting an overestimate of how well we are doing. So let's do the right thing, shall we? Let's split our dataset, using 80% of it for training and holding out the remaining 20% for testing the performance of our trained classifier.

In [122]:
# test train split
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=666)

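We now have roughly *balanced* training and testing sets. A dataset is said to be balanced when equal numbers of examples from each class are present. The full iris dataset is perfectly balanced (50 flowers per species), so a random split will be close to balanced, although `train_test_split` only guarantees matching class proportions if you pass its `stratify` argument. Balance is a key component of training a good classifier, because an unbalanced training set can yield a falsely accurate classifier that simply guesses the most frequent class of example present.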
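
If you want the split to be exactly balanced rather than approximately so, a minimal sketch is to pass the labels to `stratify`. This is a variation on the split above, not what the workshop code does, and the `*_counts` names are mine.

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()

# stratify=iris.target keeps the class proportions the same in both splits
x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=666,
    stratify=iris.target)

# count how many examples of each species ended up in each split
train_counts = np.bincount(y_train)
test_counts = np.bincount(y_test)
print("train counts per class:", train_counts)
print("test counts per class:", test_counts)
```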

Next let's retrain our SVM on the training split of our data and then look at its accuracy on the test set.

In [None]:
#Support Vector Machine
clf = svm.SVC(kernel='rbf', C=1, gamma=10) # this version takes two parameters
model = clf.fit(x_train, y_train)

preds = clf.predict(x_test)
cm = confusion_matrix(y_test, preds, normalize='true')
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=df_iris.species.unique())
disp.plot()


As you can see, doing the proper thing of training with one part of our data and then evaluating our fit on another has affected our performance, but just a little. Seems like our SVM classifier is doing quite well on these data!
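
A single train/test split can be a bit lucky or unlucky. One common way to hedge against that, sketched here as an optional extra rather than as part of the workshop workflow, is k-fold cross-validation: the data are split into k parts and each part takes a turn as the test set. The choice of 5 folds is arbitrary.

```python
from sklearn import datasets, svm
from sklearn.model_selection import cross_val_score

iris = datasets.load_iris()

# same classifier settings as above
clf = svm.SVC(kernel="rbf", C=1, gamma=10)

# 5-fold cross-validation: returns one held-out accuracy per fold
scores = cross_val_score(clf, iris.data, iris.target, cv=5)
print("accuracy per fold:", scores.round(3))
print("mean accuracy:", scores.mean().round(3))
```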

### Using an SVM to classify genomes to country of origin
Now let's try to apply these ideas to genome data. Let's use that *Arabidopsis* dataset we had been messing with before and bring in some metadata that I have lying around that will tell us about where the samples came from.

Our goal here will be to treat the country of origin of each plant as our target and its associated genotype (genome sequence) as the data.

In [None]:
f = h5py.File("data/araTha.hdf5", 'r')
# metadata
meta = pd.read_csv("data/araTha_meta.csv")
meta.index = meta.pk

# data clean up here
#
countries = {
    'GER':'Germany',
    'US':'United States',
    'UK':'United Kingdom',
    'POR':'Portugal',
    'LIB':'Libya',
    'SUI':'Switzerland',
    'NED':'Netherlands',
    'DEN':'Denmark',
    'GRE':'Greece',
    'BUL':'Bulgaria',
    'CRO':'Croatia'
}
meta.country = meta.country.replace(countries)
meta.head()

## Where are the samples from?
let's quickly take a look at where these samples were collected

In [None]:
import folium
m = folium.Map(location=[0, 40], zoom_start=1)
for index, row in meta.iterrows():
    if not np.isnan(row.longitude):
        # print([row.longitude, row.latitude])
        folium.Marker([row.latitude, row.longitude],popup=row.country).add_to(m)
m

So there are many countries that have very few samples. Machine learning won't work well in this situation: we rely on having *big data*, i.e., lots of training examples for each class, so that our algorithm can find the decision surface to an adequate degree.

Here's what the counts look like

In [None]:
meta.country.value_counts()

so let's filter this to the top 6 countries and only keep those

In [127]:
keep_list = list(meta.country.value_counts()[:6].index)
top6 = meta[meta.country.isin(keep_list)]
keep_dict = {k: v for v, k in enumerate(keep_list)}
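
In case the filtering step above is hard to picture, here is the same pattern on a tiny made-up table. The country names and counts are hypothetical, and the example keeps the top 2 classes rather than 6.

```python
import pandas as pd

# Hypothetical metadata: 5 samples from Spain, 3 from Sweden, 1 from Italy
toy_meta = pd.DataFrame(
    {"country": ["Sweden"] * 3 + ["Spain"] * 5 + ["Italy"]})

# value_counts sorts classes from most to least frequent
keep_list = list(toy_meta.country.value_counts()[:2].index)

# keep only rows whose country is in the top-2 list
top2 = toy_meta[toy_meta.country.isin(keep_list)]

# map each kept country name to an integer code (0 = most frequent)
keep_dict = {k: v for v, k in enumerate(keep_list)}

print(keep_list)
print(keep_dict)
print(len(top2))
```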

Next let's match up the genotypes with the metadata, using the `hdf5` file we opened earlier. The basic workflow here is that we will load the genotype data into a pandas dataframe, then we will match up the samples in the genotype data with the samples in the metadata. We further thin the genotype data to every 10th SNP as I want it to run fast.

In [None]:
geno_group = f['genotype']
# artificially thinning to every 10th SNP as I want it to run fast
# for class
thin = 10
chromosomes = geno_group['col_header']['chrom'][::thin]
positions = geno_group['col_header']['pos'][::thin]
geno_df = pd.DataFrame(geno_group['matrix'][:,::thin], columns=positions, 
                       index=geno_group['row_header']['sample_ID'][:],
                      dtype='float64')
print(f"shape of geno_df: {geno_df.shape}")

# get intersection index array
sample_idx = geno_df.index.intersection(top6.index)
print(f"shape of intersection is {sample_idx.shape}")
snps = geno_df.loc[sample_idx]
print(f"shape of snp matrix {snps.shape}")
geno_df.head()
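
The thinning in the cell above relies on strided slicing (`[::thin]`), applied identically to the SNP positions and to the columns of the genotype matrix so the two stay aligned. A toy version, with made-up array sizes and positions, looks like this:

```python
import numpy as np

# Hypothetical genotype matrix: 3 samples x 20 SNPs
geno = np.arange(3 * 20).reshape(3, 20)
positions = np.arange(20) * 100  # made-up SNP positions

thin = 10
# keep every 10th column, and the matching positions
thinned = geno[:, ::thin]
kept_positions = positions[::thin]

print(thinned.shape)
print(kept_positions)
```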

next let's create the targets for our classifier. We will use the country of origin for each sample as the target, again matching up the samples in the genotype data with the samples in the metadata in our restricted set.

In [None]:
country_targets = top6.country.loc[sample_idx]
country_targets.value_counts()

okay now to use machine learning I'm going to have to numerically encode these country labels. I'm going to use that convenient dictionary I created above

In [132]:
number_targets = [keep_dict[x] for x in country_targets]

okay so I've got my targets and I've got my matrix of SNPs-- let's make the machines learn!

In [None]:
# train/test split
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(
    snps, number_targets, test_size=0.2, random_state=666)

# RBF-kernel SVM; gamma is left at sklearn's default ('scale')
clf = svm.SVC(kernel='rbf', C=1)
model = clf.fit(x_train, y_train)

preds = clf.predict(x_test)
cm = confusion_matrix(y_test, preds, normalize='true')
# keep_list is ordered to match the integer codes in keep_dict
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=keep_list)
disp.plot()

plt.xticks(rotation=70)


# Precision Recall curves
A great way to characterize the performance of a classifier is with precision and recall. So a couple of definitions

Precision = # true positives / (# true positives + # false positives)

Recall = # true positives / (# true positives + # false negatives)

recall can also be thought of as the true positive rate. We want classifiers to have high precision *and* high recall. `sklearn` will calculate all of this for us from our classifier

In [None]:
preds = clf.predict(x_test)
print(sklearn.metrics.classification_report(y_test, preds, target_names=keep_list))


So we can see that we really do have high precision and recall in most cases. This output also gives the so-called F1 score, which is simply the harmonic mean of precision and recall.
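
To make these definitions concrete, here is the arithmetic for a single class using made-up counts of true positives, false positives, and false negatives.

```python
# Hypothetical counts for one class
tp, fp, fn = 40, 10, 20

# precision: of everything we called positive, how much really was?
precision = tp / (tp + fp)

# recall: of everything that really was positive, how much did we find?
recall = tp / (tp + fn)

# F1: harmonic mean of precision and recall
f1 = 2 * precision * recall / (precision + recall)

print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")
```

Because it is a harmonic mean, F1 is pulled toward the smaller of the two values, so a classifier cannot score well by maximizing precision at the expense of recall or vice versa.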

Let's plot our Precision Recall curve for this classifier. Note the code below is a bit involved, but I want to just focus on the result. 

(note this code block is a bit slow to run ~3 minutes)

In [None]:
from sklearn.preprocessing import label_binarize
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score


# Use label_binarize to be multi-label like settings
Y = label_binarize(number_targets, classes=list(range(6)))
n_classes = Y.shape[1]

# Split into training and test
x_train, x_test, y_train, y_test = train_test_split(snps, Y, test_size=0.2,
                                                    random_state=666)

# We use OneVsRestClassifier for multi-label prediction
from sklearn.multiclass import OneVsRestClassifier

# Fit a one-vs-rest SVM with an RBF kernel
classifier = OneVsRestClassifier(svm.SVC(kernel='rbf', C=1))
classifier.fit(x_train, y_train)
y_score = classifier.decision_function(x_test)



# For each class
precision = dict()
recall = dict()
average_precision = dict()
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
                                                        y_score[:, i])
    average_precision[i] = average_precision_score(y_test[:, i], y_score[:, i])

# A "micro-average": quantifying score on all classes jointly
precision["micro"], recall["micro"], _ = precision_recall_curve(y_test.ravel(),
    y_score.ravel())
average_precision["micro"] = average_precision_score(y_test, y_score,
                                                     average="micro")

plt.figure()
plt.step(recall['micro'], precision['micro'], where='post')

plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.0])
plt.title(
    f'Average precision score, micro-averaged over all classes: AP={round(average_precision["micro"],2)}')

### NICE
so this classifier looks very good from a Precision Recall perspective. We have high precision at all recalls basically. Classifying Arabidopsis genomes turns out to be a decently easy problem for SVMs
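
If it is unclear where the points on a Precision Recall curve come from, a four-example toy case may help. The labels and scores below are made up; `precision_recall_curve` sweeps a decision threshold across the scores and reports precision and recall at each one.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Hypothetical binary labels and classifier scores
y_true = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.4, 0.35, 0.8])

precision, recall, thresholds = precision_recall_curve(y_true, scores)

# one (precision, recall) pair per threshold, plus a final end point
for p, r in zip(precision, recall):
    print(f"precision={p:.2f} recall={r:.2f}")
print("thresholds:", thresholds)
```

A low threshold calls almost everything positive (high recall, lower precision), and a high threshold calls only the most confident examples positive (high precision, lower recall). A good classifier keeps precision high across the whole range.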

# Random Forest approaches
A completely different approach to supervised machine learning is the Random Forest (RF). Rather than focus on one good predictor, RFs borrow strength among a number of so-called weak learners and combine those in an ensemble to make one good prediction. As described in the lecture slides, these weak learners are individual, randomized decision trees. We can fit RF models on the SNP data above quite easily.

In [None]:
from sklearn.ensemble import RandomForestClassifier

x_train, x_test, y_train, y_test = train_test_split(
    snps, number_targets, test_size=0.2, random_state=666)

clf = RandomForestClassifier()
model = clf.fit(x_train, y_train)

preds = clf.predict(x_test)

cm = pd.DataFrame(confusion_matrix(y_test, preds, normalize='true'), columns=keep_list, index=keep_list)
sns.heatmap(cm, annot=True,cmap="plasma")
print(sklearn.metrics.classification_report(y_test, preds, target_names=keep_list))

So RFs in this case are doing slightly worse than SVMs. Can we fix this up a bit by changing the number of trees in our ensemble?

I'll try changing two things-- the number of trees we use in the forest and the metric for splitting trees.

In [None]:
r = []
nt = [100,200,500,1000,2000,5000]
for n in nt:
    clf = RandomForestClassifier(n_estimators=n)
    model = clf.fit(x_train, y_train)
    preds = clf.predict(x_test)
    r.append(sklearn.metrics.precision_recall_fscore_support(y_test, preds, average='weighted')[2])
r2 = []
for n in nt:
   # print(n)
    clf = RandomForestClassifier(n_estimators=n, criterion='entropy')
    model = clf.fit(x_train, y_train)
    preds = clf.predict(x_test)
    r2.append(sklearn.metrics.precision_recall_fscore_support(y_test, preds, average='weighted')[2])
plt.plot(nt,r)
plt.plot(nt,r2)
plt.legend(['gini','entropy'])
plt.xlabel("number of trees")
plt.ylabel("F1-score")

so in this case it looks like Gini impurity is doing a bit better than entropy, and that 2000 trees is a good number. Let's try to use this model to classify our Arabidopsis genomes.

In [None]:
from sklearn.ensemble import RandomForestClassifier

x_train, x_test, y_train, y_test = train_test_split(
    snps, number_targets, test_size=0.2, random_state=666
)

clf = RandomForestClassifier(n_estimators=2000, criterion="gini")
model = clf.fit(x_train, y_train)

preds = clf.predict(x_test)

cm = pd.DataFrame(
    confusion_matrix(y_test, preds, normalize="true"),
    columns=keep_list,
    index=keep_list,
)
sns.heatmap(cm, annot=True, cmap="plasma")
print(sklearn.metrics.classification_report(y_test, preds, target_names=keep_list))
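
As a brief aside on the two splitting criteria compared above: both Gini impurity and entropy measure how mixed the class labels are within a node of a decision tree, and a tree prefers splits that make its child nodes less mixed. A minimal sketch of the two formulas, with helper function names of my own choosing, is:

```python
import math

def gini(p):
    """Gini impurity for a list of class proportions."""
    return 1.0 - sum(q * q for q in p)

def entropy(p):
    """Shannon entropy (in bits) for a list of class proportions."""
    return -sum(q * math.log2(q) for q in p if q > 0)

pure = [1.0, 0.0]    # node containing a single class
mixed = [0.5, 0.5]   # node split evenly between two classes

print("pure node:  gini =", gini(pure), " entropy =", entropy(pure))
print("mixed node: gini =", gini(mixed), " entropy =", entropy(mixed))
```

Both measures are zero for a pure node and largest for an evenly mixed one, which is why the two criteria often give similar forests.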

## Which SNPs are important? Feature importance

Since an RF is a tree-based model, we can get a sense of which SNPs are important by looking at the feature importance scores.
These scores are calculated by measuring how much the splits on a given SNP reduce impurity within each tree and then averaging over all of the trees in the forest.
They give us a sense of which SNPs (or more generally features) are important for the classification task.

We can get these scores from the `feature_importances_` attribute of the `RandomForestClassifier` object.




In [None]:
# get the feature importance scores
importances = model.feature_importances_

# plot the feature importance scores as a bar plot; add a histogram of the feature importance scores in a separate plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.bar(range(len(importances)), importances)
ax1.set_xlabel("SNP")
ax1.set_ylabel("Feature Importance")
ax2.hist(importances, bins=20)
ax2.set_xlabel("Feature Importance") 
ax2.set_ylabel("Count")
# rotate the x-axis labels
plt.xticks(rotation=70)
plt.tight_layout()
plt.show()


So really, very few SNPs are important for this classification task. This is encouraging, as it suggests that the RF is not simply memorizing the data, but rather is using the underlying structure of the data to make predictions.
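
To see what `feature_importances_` reports in a setting where we know the answer, here is a sketch on synthetic data in which only one of five features carries any signal. The data, the seed, and the choice of feature 2 as the informative one are all made up for illustration.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 5))
# only feature 2 determines the label; the other four are pure noise
y = (X[:, 2] > 0).astype(int)

clf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)
importances = clf.feature_importances_

print(importances.round(3))
print("most important feature:", int(np.argmax(importances)))
print("sum of importances:", importances.sum())
```

The scores are normalized to sum to one, so each can be read as the share of the forest's total impurity reduction attributable to that feature.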

Let's look at the relationship between feature importance and frequency in the matrix.

In [None]:
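# indices of the 10 most important SNPs
top10 = np.argsort(importances)[-10:]
# sum of genotype values at each SNP (proportional to its frequency in the matrix)
freqs = snps.sum(axis=0)
# plot importance against frequency for every SNP, highlighting the top 10
plt.scatter(importances, freqs, s=5)
plt.scatter(importances[top10], freqs.iloc[top10], color="red", s=20, label="top 10")
plt.xlabel("Feature Importance")
plt.ylabel("Frequency")
plt.legend()
plt.show()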


Cool! So we can see that the most important SNPs are actually at intermediate frequencies. This makes sense: a SNP that is very rare or nearly fixed takes the same value in almost every sample, so splitting on it separates very few genomes, whereas a SNP at intermediate frequency can divide the samples into sizeable groups that the RF can use to learn the underlying structure of the data.


------------------------------------

# Working with simulated training sets -- a population genetics classification problem

All of the examples above have been using empirical data (i.e. data that come from the real world and have real world labels). However in evolutionary biology, and especially population genetics, it's often impossible to get ground truth labels for our data. Consider for instance the case of "knowing" the population history of a species: we can't go back in time and know the truth. In this case we can use simulated data to train our classifier! The idea, generally speaking, is that we can simulate data from a variety of models and then use the simulated data to train our classifier. We can then use the classifier to predict the labels of new, unseen data.

One of the classic tasks in population genetics is to use genetic variation data to test among competing demographic hypotheses. For instance Tajima's famous D statistic was proposed as a frequentist test of the standard neutral model, and it is widely used to look for population expansion (or contraction). Rather than do hypothesis testing, we can frame this same sort of task as machine learning classification. In this section we will do exactly that: we will build an RF classifier to distinguish between a population with static population size (sometimes called an equilibrium population) and a population that has undergone population growth. Rather than train our classifier using empirical data from populations, we will be training using coalescent simulations.
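
Since Tajima's D will show up as one of our features, it may help to see how it is put together. The sketch below follows the standard formula from Tajima (1989): D compares two estimates of diversity, the mean number of pairwise differences ($\pi$) and the number of segregating sites $S$ scaled by $a_1 = \sum_{i=1}^{n-1} 1/i$. The function name and the example numbers are mine, and it works with totals for a region rather than per-site values.

```python
import math

def tajimas_d(n, S, pi):
    """Tajima's D for n sampled sequences, S segregating sites,
    and pi = mean number of pairwise differences."""
    if S == 0:
        return float("nan")
    a1 = sum(1.0 / i for i in range(1, n))
    a2 = sum(1.0 / i**2 for i in range(1, n))
    b1 = (n + 1) / (3.0 * (n - 1))
    b2 = 2.0 * (n**2 + n + 3) / (9.0 * n * (n - 1))
    c1 = b1 - 1.0 / a1
    c2 = b2 - (n + 2) / (a1 * n) + a2 / a1**2
    e1 = c1 / a1
    e2 = c2 / (a1**2 + a2)
    # numerator: difference between the two diversity estimates
    d = pi - S / a1
    # denominator: estimated standard deviation of that difference
    return d / math.sqrt(e1 * S + e2 * S * (S - 1))

n, S = 20, 50  # hypothetical sample size and number of segregating sites
a1 = sum(1.0 / i for i in range(1, n))

d_neutral = tajimas_d(n, S, pi=S / a1)     # the two estimates agree
d_rare = tajimas_d(n, S, pi=0.5 * S / a1)  # pairwise diversity is too low

print(d_neutral, d_rare)
```

When pairwise diversity is lower than expected given the number of segregating sites, which is what an excess of rare variants after population growth produces, D is negative.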

To do this we will need to simulate some data. We will use the `stdpopsim` package to do this, which is a package that allows us to simulate data from a variety of organisms/demographic models. You can read more about it [here](https://stdpopsim.readthedocs.io/en/stable/). 

For our purposes we will simulate data from the *Homo sapiens* species. We will use the `stdpopsim` package to simulate data from a population that has undergone a recent population expansion as well as a static population. We will then use the `tskit` package to load the simulated data into a tree sequence, and calculate summary statistics from the tree sequence which will serve as features for our classifier.

## Simulation aside -- the joys of `stdpopsim`

stdpopsim is a really useful package for simulating data from a variety of organisms. It is built on top of the `msprime` package, which is a really fast coalescent simulator and can simulate data for a wide variety of demographic models with empirical recombination maps, realistic mutation rates, and more. 

The first thing we do in the standard `stdpopsim` workflow is to get the species we want to simulate data from. Here we will get the *Homo sapiens* species object. We will then get a list of all of the demographic models that are available for this species. 

In [None]:
import stdpopsim

species = stdpopsim.get_species("HomSap")
for x in species.demographic_models:
    print(x.id)

Okay, there are a lot of models here. We will use the `Africa_1T12` model, which is for one population that has changed population size quite dramatically. We can plot what this model looks like with the `demesdraw.tubes` function.

In [None]:
import demes
import demesdraw

# Get the Homo sapiens species and the African single-population model
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("Africa_1T12")

# Get the demes graph directly from the model and draw it
graph = model.model.to_demes()
demesdraw.tubes(graph)


Next let's simulate some data from this model. We will simulate 10 diploid individuals from the population, for a 1 Mb region (1e6 sites) of chromosome 1. The simulation returns a tree sequence, from which we will calculate summary statistics that will serve as features for our classifier.

In [None]:
# Get the Homo sapiens species and the African single-population model
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("Africa_1T12")
contig = species.get_contig(
    "chr1", 
    left=19e6, 
    right=20e6, 
    mutation_rate=model.mutation_rate
)
samples = {"AFR": 10}
engine = stdpopsim.get_engine("msprime")
ts = engine.simulate(model, contig, samples)
ts

This yields a `TreeSequence` object that we can do all sorts of things with. We can plot the tree sequence, examine its nodes and edges, and more. Here we will just calculate some summary statistics from the tree sequence.

In [None]:
# Calculate these statistics in windows
# Let's use 10 windows across the sequence
windows = np.linspace(0, ts.sequence_length, num=11)
diversity_windows = ts.diversity(windows=windows)
tajimas_d_windows = ts.Tajimas_D(windows=windows)
segregating_sites_windows = ts.segregating_sites(windows=windows)

# concatenate these into a single array
stats = np.concatenate([diversity_windows, tajimas_d_windows, segregating_sites_windows])
stats


In [None]:
# let's replicate this for 1000 simulations
sims = []
for i in range(1000):
    ts = engine.simulate(model, contig, samples)
    windows = np.linspace(0, ts.sequence_length, num=11)
    diversity_windows = ts.diversity(windows=windows)
    tajimas_d_windows = ts.Tajimas_D(windows=windows)
    segregating_sites_windows = ts.segregating_sites(windows=windows)
    stats = np.concatenate([diversity_windows, tajimas_d_windows, segregating_sites_windows])
    sims.append(stats)
sims = np.array(sims)
sims.shape
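
It is worth being explicit about how each row of `sims` is laid out, because later cells index into it by column number. With 11 window breakpoints there are 10 windows, and the three statistics are concatenated one after another. The sketch below uses stand-in values (not simulation output), and the `layout` dictionary is a helper of my own.

```python
import numpy as np

# 11 breakpoints define 10 windows (1e6 is a stand-in sequence length)
windows = np.linspace(0, 1e6, num=11)
n_windows = len(windows) - 1

# Stand-ins for the three windowed statistics (hypothetical values)
diversity = np.full(n_windows, 0.001)
tajimas_d = np.full(n_windows, -1.0)
seg_sites = np.full(n_windows, 25.0)
stats = np.concatenate([diversity, tajimas_d, seg_sites])

# Column ranges occupied by each statistic in the concatenated vector
layout = {
    "diversity": slice(0, n_windows),
    "tajimas_d": slice(n_windows, 2 * n_windows),
    "segregating_sites": slice(2 * n_windows, 3 * n_windows),
}

first_window_d = stats[layout["tajimas_d"]][0]
print(stats.shape, first_window_d)
```

So the first-window value of diversity sits in column 0, Tajima's D in column 10, and segregating sites in column 20.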


Now let's do the same thing for the static population size model. First we will set up the model and then plot the demes graph side by side with the growth model. 

In [None]:
# set up the same simulation, but with the static population size model
model_static = stdpopsim.PiecewiseConstantSize(species.population_size)
samples = {"pop_0": 10}  # sample name used by the static population size model

# side by side plot of the two models
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
# draw the demes graph for each model on its own axis
demesdraw.tubes(model_static.model.to_demes(), ax=ax1)
demesdraw.tubes(model.model.to_demes(), ax=ax2)
# add a title to each plot
ax1.set_title("Static Population Size")
ax2.set_title("Population Growth")
plt.show()


now simulate 1000 replicates of the static population size model and record the summary statistics. 

In [None]:
sims_static = []
for i in range(1000):
    ts = engine.simulate(model_static, contig, samples)
    windows = np.linspace(0, ts.sequence_length, num=11)
    diversity_windows = ts.diversity(windows=windows)
    tajimas_d_windows = ts.Tajimas_D(windows=windows)
    segregating_sites_windows = ts.segregating_sites(windows=windows)
    stats = np.concatenate(
        [diversity_windows, tajimas_d_windows, segregating_sites_windows]
    )
    sims_static.append(stats)
sims_static = np.array(sims_static)
sims_static.shape

In [None]:
# let's plot the distribution of the statistics for the two models
# we can use seaborn to make this look nicer
import seaborn as sns
import matplotlib.pyplot as plt

# create a figure with two subplots side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
# plot the diversity distribution in the first window
sns.histplot(sims[:,0], label="growth", alpha=0.5, ax=ax1)
sns.histplot(sims_static[:,0], label="static", alpha=0.5, ax=ax1)
ax1.set_title("Diversity")
ax1.legend()

# plot the Tajima's D distribution in the first window
# (with 10 windows per statistic, columns 0-9 hold diversity and Tajima's D starts at column 10)
sns.histplot(sims[:,10], label="growth", alpha=0.5, ax=ax2)
sns.histplot(sims_static[:,10], label="static", alpha=0.5, ax=ax2)
ax2.set_title("Tajima's D")
ax2.legend()




The growth model produces a noticeably more skewed distribution of these statistics than the static model. We can use these statistics as features for our classifier!
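Before using the statistics as features, it helps to be explicit about which column holds which statistic. The sketch below is a small bookkeeping helper, not part of the original notebook: `feature_index` and `feature_names` are hypothetical names, and it assumes both models were summarized in the same 10 windows (`np.linspace(..., num=11)` gives 11 breakpoints) with the statistics concatenated in the order diversity, Tajima's D, segregating sites.

```python
# Layout of one feature vector: three statistics, each computed in the
# same genomic windows, concatenated one after another.
STAT_NAMES = ["diversity", "tajimas_d", "segregating_sites"]  # concatenation order
NUM_WINDOWS = 10  # assumption: 11 breakpoints from np.linspace, i.e. 10 windows


def feature_index(stat_name, window, num_windows=NUM_WINDOWS):
    """Column of `stat_name` in `window` (0-based) within the concatenated vector."""
    if not 0 <= window < num_windows:
        raise ValueError(f"window must be in [0, {num_windows})")
    return STAT_NAMES.index(stat_name) * num_windows + window


def feature_names(num_windows=NUM_WINDOWS):
    """Human-readable name for every column, in column order."""
    return [f"{stat}_w{w}" for stat in STAT_NAMES for w in range(num_windows)]


print(feature_index("diversity", 0))
print(feature_index("tajimas_d", 0))
print(len(feature_names()))
```

Under these assumptions each simulation contributes 30 features, and names like these can later be attached to feature importances from a tree-based classifier.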

First we need to create labels for our data. We will label simulations from the growth model `0` and simulations from the static model `1`.

In [None]:
# create a label for our data, 0 for growth and 1 for static
labels = np.concatenate([np.zeros(1000), np.ones(1000)])
labels

# now lets concatenate our statistics into one array
data = np.concatenate([sims, sims_static])
data.shape

# let's split this into training and testing data
x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=666)

# finally let's examine the shapes of our training and testing data
print(f"training data shape: {x_train.shape}")
print(f"testing data shape: {x_test.shape}")
print(f"training labels shape: {y_train.shape}")
print(f"testing labels shape: {y_test.shape}")


Great! We have our training and testing data. Now we can use a random forest classifier to classify the data.

In [None]:
# fit a random forest classifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix

clf = RandomForestClassifier()
clf.fit(x_train, y_train)

# make predictions on the testing data
preds = clf.predict(x_test)

# evaluate the classifier
print(classification_report(y_test, preds))

# let's plot the confusion matrix, normalized so that each row (true class) sums to 1
cm = confusion_matrix(y_test, preds, normalize='true')
ax = sns.heatmap(cm, annot=True, cmap="plasma", xticklabels=["growth", "static"], yticklabels=["growth", "static"])
ax.set_xlabel("Predicted")
ax.set_ylabel("True")

That worked well! We can see that the classifier is able to distinguish between the two models with high precision and recall.
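As a reminder of what precision and recall mean, here is a minimal sketch that computes them from a confusion matrix of raw counts. The function name `precision_recall` and the counts are hypothetical, chosen only for illustration; they are not results from the classifier above.

```python
def precision_recall(cm):
    """Per-class (precision, recall) from a confusion matrix of raw counts.

    cm[i][j] is the number of examples whose true class is i and whose
    predicted class is j (the convention scikit-learn uses).
    """
    n = len(cm)
    results = []
    for k in range(n):
        true_pos = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum
        actual_k = sum(cm[k])                          # row sum
        precision = true_pos / predicted_k if predicted_k else 0.0
        recall = true_pos / actual_k if actual_k else 0.0
        results.append((precision, recall))
    return results


# Hypothetical counts for a 400-example test set (not output from the notebook)
cm = [[190, 10],   # true growth: 190 called growth, 10 called static
      [20, 180]]   # true static: 20 called growth, 180 called static
for name, (p, r) in zip(["growth", "static"], precision_recall(cm)):
    print(f"{name}: precision={p:.3f} recall={r:.3f}")
```

Note that when a confusion matrix is normalized by row (`normalize='true'` in scikit-learn), its diagonal entries are exactly the per-class recalls.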

# A regression problem

Let's move on to a regression problem. In this case we aren't interested in classifying the data, but rather in predicting a continuous variable. We will use a similar setup as above, relying on `stdpopsim` to simulate data from a two-population isolation with migration (IM) model. We will aim to predict the time of the split between the two populations.

In [None]:
# isolation with migration model

species = stdpopsim.get_species("HomSap")
samples = {"pop1": 10, "pop2": 10}

# here's how we set up the IM model
model = stdpopsim.IsolationWithMigration(
    NA=10000, # ancestral population size
    N1=10000, # population 1 size
    N2=10000, # population 2 size
    T=1, # time of split, in generations
    M12=1, # migration rate from pop1 to pop2
    M21=1 # migration rate from pop2 to pop1
)
contig = species.get_contig(
    "chr1", 
    left=10e6, 
    right=20e6, 
    mutation_rate=model.mutation_rate
)
ts = engine.simulate(model, contig, samples)
ts

Okay, let's set up replicates of this simulation, varying the time of the split. We will simulate 1000 random split times between 0 and 10000 generations. For summary statistics we will use the joint [allele frequency spectrum](https://en.wikipedia.org/wiki/Allele_frequency_spectrum).

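The allele frequency spectrum (AFS) is a way to summarize the genetic variation in a sample. For a sample of $N$ diploid individuals ($2N$ chromosomes), it is a vector of length $2N + 1$ whose $k$-th entry counts the number of sites at which the derived allele appears in exactly $k$ of the $2N$ chromosomes ($k = 0, \ldots, 2N$). With two populations we can calculate the joint allele frequency spectrum, a matrix of size $(2N_1 + 1) \times (2N_2 + 1)$ whose entry $(i, j)$ counts the sites with $i$ derived copies in the first population and $j$ in the second. With 10 diploid samples per population this is a $21 \times 21$ matrix.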
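To make the definition concrete, here is a minimal pure-Python sketch that builds a joint spectrum from per-site derived-allele counts. The function `joint_afs` and the toy data are hypothetical illustrations; in the notebook itself the spectrum comes from the tree sequence, which handles details this sketch ignores.

```python
def joint_afs(derived_counts, n1, n2):
    """Joint allele frequency spectrum for two populations.

    derived_counts: list of (k1, k2) pairs, one per site, giving the number of
        derived-allele copies seen in population 1 and population 2.
    n1, n2: number of sampled chromosomes in each population
        (2 * number of diploid individuals).
    Returns an (n1 + 1) x (n2 + 1) table whose [k1][k2] entry counts the sites
    with k1 derived copies in population 1 and k2 in population 2.
    """
    afs = [[0] * (n2 + 1) for _ in range(n1 + 1)]
    for k1, k2 in derived_counts:
        afs[k1][k2] += 1
    return afs


# Toy data: 2 diploids per population (4 chromosomes each), 6 variable sites
sites = [(1, 0), (1, 0), (0, 2), (4, 4), (2, 2), (1, 0)]
afs = joint_afs(sites, n1=4, n2=4)
flat = [x for row in afs for x in row]  # flattened, as one would for ML features
print(len(afs), len(afs[0]), len(flat))
print(afs[1][0])  # sites private to population 1 at a count of one copy
```

Sites private to one population fall along the first row or column of the table, while shared variation sits in the interior, which is why the spectrum carries information about how long two populations have been separated.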

We can calculate the joint allele frequency spectrum for our tree sequence using the `allele_frequency_spectrum` method and make a plot of it. We will do this for two different times of the split.


In [None]:
from matplotlib.colors import LogNorm


# comparison of allele frequency spectra

# T = 1
model1 = stdpopsim.IsolationWithMigration(
    NA=1000,  # ancestral population size
    N1=1000,  # population 1 size
    N2=1000,  # population 2 size
    T=1,  # time of split, in generations
    M12=5e-4,  # migration rate from pop1 to pop2
    M21=5e-4,  # migration rate from pop2 to pop1
)
ts1 = engine.simulate(model1, contig, samples)

jAFS1 = ts1.allele_frequency_spectrum(
    sample_sets=[
        ts1.samples(population=0), 
        ts1.samples(population=1)
    ],
    span_normalise=False,
    polarised=True,
)
# flip the y axis
jAFS1 = jAFS1[::-1, :]

# T = 1000
model2 = stdpopsim.IsolationWithMigration(
    NA=1000,  # ancestral population size
    N1=1000,  # population 1 size
    N2=1000,  # population 2 size
    T=1000,  # time of split, in generations
    M12=5e-4,  # migration rate from pop1 to pop2
    M21=5e-4,  # migration rate from pop2 to pop1
)
ts2 = engine.simulate(model2, contig, samples)
jAFS2 = ts2.allele_frequency_spectrum(
    sample_sets=[ts2.samples(population=0), ts2.samples(population=1)],
    span_normalise=False,
    polarised=True,
)
# flip the y axis
jAFS2 = jAFS2[::-1, :]

# plot the joint allele frequency spectra side by side
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))


ax1.imshow(jAFS1, cmap="viridis", norm=LogNorm())
ax1.set_title("T = 1")

ax2.imshow(jAFS2, cmap="viridis", norm=LogNorm())
ax2.set_title("T = 1000")

# Create tick positions and labels that increment by 5
tick_positions = np.arange(0, jAFS1.shape[0], 5)
tick_labels = np.arange(jAFS1.shape[0])[::-1][::5]

# Set ticks for both axes
for ax in [ax1, ax2]:
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_positions)

# Add colorbars
plt.colorbar(ax1.images[0], ax=ax1)
plt.colorbar(ax2.images[0], ax=ax2)

plt.show()

These look very different! Populations with short split times have very similar allele frequency spectra, while populations with long split times have very different allele frequency spectra.

Let's use these allele frequency spectra as features for our regression model. We will simulate 1000 random times of the split between 0 and 10000 generations, generate the joint allele frequency spectra for each simulation, flatten it into a vector, and then use that as our feature vector. 



In [None]:
# simulate 1000 random times of the split between 0 and 10000 generations
times = np.random.randint(0, 10000, size=1000)

# simulate the data for each time
sims_stats = []
for t in times:
    model = stdpopsim.IsolationWithMigration(NA=1000, N1=1000, N2=1000, T=t, M12=5e-4, M21=5e-4)
    ts = engine.simulate(model, contig, samples)
    # calculate the joint allele frequency spectrum
    jAFS = ts.allele_frequency_spectrum(
        sample_sets=[
            ts.samples(population=0),
            ts.samples(population=1),
        ],
        span_normalise=False,
        polarised=True,
    )
    # flatten the joint allele frequency spectrum
    jAFS = jAFS.flatten()
    sims_stats.append(jAFS)
sims_stats = np.array(sims_stats)
sims_stats.shape
    



In [None]:
# let's fit a linear regression model to the data
from sklearn.linear_model import LinearRegression

# do a train test split
x_train, x_test, y_train, y_test = train_test_split(sims_stats, times, test_size=0.2, random_state=666)

# fit a linear regression model to the training data
model = LinearRegression()
model.fit(x_train, y_train)

# make predictions on the testing data
preds = model.predict(x_test)

# Create a DataFrame with actual and predicted values
results_df = pd.DataFrame({"Actual": y_test, "Predicted": preds})

# plot the predictions vs actual values
plt.figure(figsize=(8, 6))
sns.scatterplot(data=results_df, x="Actual", y="Predicted")
plt.plot([0, 10000], [0, 10000], "r--")  # Add perfect prediction line
plt.xlabel("Actual Time")
plt.ylabel("Predicted Time")
plt.title("Predicted vs Actual Split Times")
plt.show()

# Calculate and print various regression metrics
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
import numpy as np

r2 = r2_score(y_test, preds)
rmse = np.sqrt(mean_squared_error(y_test, preds))
mae = mean_absolute_error(y_test, preds)

print(f"R² Score: {r2:.3f}")
print(f"Root Mean Square Error: {rmse:.3f}")
print(f"Mean Absolute Error: {mae:.3f}")


Not bad! We can see that the model is able to predict the time of the split with a reasonable degree of accuracy. Let's see if we can do better with a different model.


We will use a support vector machine (SVM) regressor to predict the time of the split, with the same flattened joint allele frequency spectra as features. SVMs are sensitive to the scale of the features, so we standardize them first.

In [None]:
# fit a SVM regressor to the data
from sklearn.svm import SVR
from sklearn.preprocessing import StandardScaler

# 1. First scale your data
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(x_train)
X_test_scaled = scaler.transform(x_test)

# 2. Then fit the SVR with a linear kernel (default hyperparameters, no tuning)
model = SVR(kernel="linear")
model.fit(X_train_scaled, y_train)

# 3. Make predictions
preds = model.predict(X_test_scaled)

# plot the predictions vs actual values
plt.figure(figsize=(8, 6))
# plot the actual values, untransformed
sns.scatterplot(x=y_test, y=preds)
plt.plot([0, 10000], [0, 10000], "r--")  # Add perfect prediction line
plt.xlabel("Actual Time")
plt.ylabel("Predicted Time")
plt.title("Predicted vs Actual Split Times")
plt.show()

# 4. Calculate metrics
r2 = r2_score(y_test, preds)
rmse = np.sqrt(mean_squared_error(y_test, preds))
mae = mean_absolute_error(y_test, preds)

print(f"R² Score: {r2:.3f}")
print(f"Root Mean Square Error: {rmse:.3f}")
print(f"Mean Absolute Error: {mae:.3f}")


This is doing slightly better than the linear regression model, which itself was already doing pretty well!
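When comparing models like this, it is worth remembering what the three printed metrics measure. The sketch below computes them by hand; `regression_metrics` is a hypothetical helper and the numbers are made up for illustration, not output from the models above.

```python
import math


def regression_metrics(actual, predicted):
    """Return (r2, rmse, mae) for two equal-length sequences of numbers."""
    n = len(actual)
    mean_actual = sum(actual) / n
    errors = [p - a for a, p in zip(actual, predicted)]
    ss_res = sum(e * e for e in errors)                   # residual sum of squares
    ss_tot = sum((a - mean_actual) ** 2 for a in actual)  # total sum of squares
    r2 = 1.0 - ss_res / ss_tot     # fraction of variance explained
    rmse = math.sqrt(ss_res / n)   # penalizes large errors more heavily
    mae = sum(abs(e) for e in errors) / n  # typical error, in generations
    return r2, rmse, mae


# Hypothetical split times in generations (not output from the notebook)
actual = [1000, 3000, 5000, 7000]
predicted = [1200, 2800, 5400, 6600]
r2, rmse, mae = regression_metrics(actual, predicted)
print(f"R2={r2:.3f} RMSE={rmse:.1f} MAE={mae:.1f}")
```

RMSE and MAE are in the units of the target (generations here), so they are the easier ones to interpret biologically, while $R^2$ is unitless and convenient for comparing across targets.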

Next let's see if we can predict both the time of the split and the migration rate between the two populations.


In [None]:
# draw 1000 random split times between 0 and 10000 generations
times = np.random.randint(0, 10000, size=1000)
# draw 1000 random (symmetric) migration rates between 0 and 0.05
migration_rates = np.random.uniform(0, 0.05, size=1000)
# simulate the data for each (time, migration rate) pair
sims_stats = []
for t, m in zip(times, migration_rates):
    model = stdpopsim.IsolationWithMigration(
        NA=1000, N1=1000, N2=1000, T=t, M12=m, M21=m
    )
    ts = engine.simulate(model, contig, samples)
    # calculate the joint allele frequency spectrum
    jAFS = ts.allele_frequency_spectrum(
        sample_sets=[
            ts.samples(population=0),
            ts.samples(population=1),
        ],
        span_normalise=False,
        polarised=True,
    )
    # flatten the joint allele frequency spectrum
    jAFS = jAFS.flatten()
    sims_stats.append(jAFS)
sims_stats = np.array(sims_stats)
sims_stats.shape

Now we have two targets, the time of the split and the migration rate between the two populations. We can predict both at once with a multi-output approach: scikit-learn's `MultiOutputRegressor` fits one independent regressor per target. Again we will use an SVM regressor.
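The two targets live on very different scales (thousands of generations versus rates below 0.05), which is why the next cell log-transforms and rescales them before fitting. Here is a minimal sketch of just the log step using the standard library; `to_log` and `from_log` are hypothetical helper names and the target values are made up.

```python
import math


def to_log(targets):
    """log1p-transform each (time, migration_rate) row; log1p(0) == 0, so zeros are safe."""
    return [[math.log1p(v) for v in row] for row in targets]


def from_log(log_targets):
    """Undo to_log with expm1, the inverse of log1p."""
    return [[math.expm1(v) for v in row] for row in log_targets]


# Hypothetical targets: (split time in generations, migration rate)
targets = [[0, 0.0], [2500, 0.01], [9999, 0.05]]
logged = to_log(targets)
recovered = from_log(logged)

for original, row in zip(targets, logged):
    print(original, "->", [round(v, 4) for v in row])
```

The transform compresses the split times strongly (9999 maps to about 9.21) but leaves the small migration rates almost unchanged (0.05 maps to about 0.0488), so the later rescaling step still matters for the second target. Predictions made on the transformed scale have to be mapped back with the inverse, as in `from_log`.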


In [None]:
from sklearn.preprocessing import RobustScaler
from sklearn.svm import SVR
from sklearn.multioutput import MultiOutputRegressor

# First do a train/test split
X_train, X_test, y_train, y_test = train_test_split(
    sims_stats,
    np.column_stack((times, migration_rates)),
    test_size=0.2,
    random_state=666,
)

# Use RobustScaler instead of StandardScaler
scaler_X = RobustScaler()
X_train_scaled = scaler_X.fit_transform(X_train)
X_test_scaled = scaler_X.transform(X_test)

# Log transform the targets; log1p(y) = log(1 + y), so zero values are safe
y_train_log = np.log1p(y_train)

# Scale the log-transformed targets
scaler_y = RobustScaler()
y_train_scaled = scaler_y.fit_transform(y_train_log)

# Fit one RBF-kernel SVR per target (these are scikit-learn's default hyperparameters)
model = MultiOutputRegressor(SVR(kernel="rbf", C=1.0, epsilon=0.1, gamma="scale"))
model.fit(X_train_scaled, y_train_scaled)

# Make predictions and inverse transform
preds_scaled = model.predict(X_test_scaled)
preds_log = scaler_y.inverse_transform(preds_scaled)
preds = np.expm1(preds_log)  # Inverse of log1p

# Plot results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

ax1.scatter(y_test[:, 0], preds[:, 0])
ax1.plot([0, 10000], [0, 10000], "r--")
ax1.set_xlabel("Actual Time")
ax1.set_ylabel("Predicted Time")
ax1.set_title("Predicted vs Actual Split Times")

ax2.scatter(y_test[:, 1], preds[:, 1])
ax2.plot([0, 0.05], [0, 0.05], "r--")
ax2.set_xlabel("Actual Migration Rate")
ax2.set_ylabel("Predicted Migration Rate")
ax2.set_title("Predicted vs Actual Migration Rates")

plt.tight_layout()
plt.show()

# Calculate metrics
r2_times = r2_score(y_test[:, 0], preds[:, 0])
r2_migration_rates = r2_score(y_test[:, 1], preds[:, 1])

print(f"R² Score for Time: {r2_times:.3f}")
print(f"R² Score for Migration Rate: {r2_migration_rates:.3f}")

rmse_times = np.sqrt(mean_squared_error(y_test[:, 0], preds[:, 0]))
rmse_migration_rates = np.sqrt(mean_squared_error(y_test[:, 1], preds[:, 1]))

print(f"Root Mean Square Error for Time: {rmse_times:.3f}")
print(f"Root Mean Square Error for Migration Rate: {rmse_migration_rates:.3f}")


---------------------

# Locator -- a deep learning approach to geolocation

In 2020 we released a machine learning package called [Locator](https://github.com/kr-colab/locator) that uses deep learning to predict the geographic location of individuals from their DNA sequence data. 

<!-- insert the locator paper image here -->
<img src="imgs/locator_paper.jpg" alt="Locator paper image" width="500">


This paper introduced a new way to do geolocation that was much more accurate than previous approaches. In particular, it used a fully connected neural network (a.k.a. a multi-layer perceptron) to learn the relationship between DNA sequence data and geographic location.

The idea is to take a training set of DNA sequences sampled from known geographic locations and train a neural network to map genotypes to locations.
This is similar to what we did above using the SVM and RF models, but here the task is regression rather than classification: we are predicting the continuous variables latitude and longitude.

Here's what the neural network architecture looks like:

<!-- insert the locator neural network architecture image here -->
<img src="imgs/locator_network.jpg" alt="Locator neural network architecture image" width="500">

A key idea here is that recombination breaks up the genealogical relationships along the chromosomes, such that some portions of the genome will have a set of ancestors in some locations and other portions will have a set of ancestors in other locations. We can use this information to derive the _uncertainty_ in the location of an individual, along with predicting where certain portions of the genome are derived from.

Let's see how this works in practice. 



## Locator in action -- predicting the location of Arabidopsis thaliana

We will use the `locator` package to predict the location of Arabidopsis thaliana. 

I'll add that we are using a brand new version of the `locator` package that I've been working on that is much more modular than the current release. The release version is available on github at https://github.com/kr-colab/locator.git and it is a standalone script that you can run from the command line. 

The version we will work with is also available on github at https://github.com/kr-colab/locator.git@module and it is a python package that you can import into your python scripts, although user beware, it is still under active development and the API may change. 

We start by importing the package and some of the functions we will need. 


In [105]:
from locator import Locator, plot_predictions, plot_error_summary

Next let's set up the Arabidopsis thaliana data that we will use to train the model. This is the same data that we used above to train the SVM and RF models, but we will use a smaller subset of the data for this example due to time constraints.

In [None]:
# arabidopsis data
import h5py
import pandas as pd

f = h5py.File("data/araTha.hdf5", "r")
# metadata
meta = pd.read_csv("data/araTha_meta.csv")
meta.index = meta.pk

# data clean up here
countries = {
    "GER": "Germany",
    "US": "United States",
    "UK": "United Kingdom",
    "POR": "Portugal",
    "LIB": "Libya",
    "SUI": "Switzerland",
    "NED": "Netherlands",
    "DEN": "Denmark",
    "GRE": "Greece",
    "BUL": "Bulgaria",
    "CRO": "Croatia",
}
meta.country = meta.country.replace(countries)

# keep subset of european countries only
euro_countries_list = [
    "France",   
    "United Kingdom",
    "Czech Republic",
    "Austria",
    "Sweden",
    "Germany",
    "Belgium",
    "Netherlands",
    "Norway",
    "Switzerland",
    "Denmark",
]


keep_list = euro_countries_list
euro_meta = meta[meta.country.isin(keep_list)]
keep_dict = {k: v for v, k in enumerate(keep_list)}
geno_group = f["genotype"]



# artificially thinning to every 20th SNP (see `thin` below) so that it runs fast
# for class
thin = 20
chromosomes = geno_group["col_header"]["chrom"][::thin]
positions = geno_group["col_header"]["pos"][::thin]
geno_df = pd.DataFrame(
    geno_group["matrix"][:, ::thin],
    columns=positions,
    index=geno_group["row_header"]["sample_ID"][:],
    dtype="float64",
)
print(f"shape of geno_df: {geno_df.shape}")

# get intersection index array
sample_idx = geno_df.index.intersection(euro_meta.index)
print(f"shape of intersection is {sample_idx.shape}")

euro_meta_with_snps = euro_meta.loc[sample_idx]
# First cap each country at 50 samples
capped_meta = []
for country in euro_meta_with_snps.country.unique():
    country_data = euro_meta_with_snps[euro_meta_with_snps.country == country]
    if len(country_data) > 50:
        country_data = country_data.sample(n=50)
    capped_meta.append(country_data)

capped_euro_meta = pd.concat(capped_meta)

# Then do inverse frequency sampling on the capped data
freq = capped_euro_meta.country.value_counts()
inv_freq = 1 / freq
probs = inv_freq / inv_freq.sum()
sample_weights = capped_euro_meta.country.map(probs)
euro_meta_with_snps = capped_euro_meta.sample(n=250, weights=sample_weights, replace=False)

# reset sample index
sample_idx = euro_meta_with_snps.index
snps = geno_df.loc[sample_idx]
print(f"shape of snp matrix {snps.shape}")
snps.head()

The above matrix is the first 5 rows of the SNP matrix for Arabidopsis thaliana. The row indexes are the sample IDs and the column indexes are the SNP positions. 

Next we will set up the targets for the model. We will use the latitude and longitude of the samples as the targets. 

In [None]:
coords = pd.DataFrame(
    {
        "sampleID": sample_idx.astype(str).values,  # Add .values to get plain array
        "x": euro_meta_with_snps.longitude.loc[sample_idx],
        "y": euro_meta_with_snps.latitude.loc[sample_idx],
    }
).reset_index(drop=True)
coords.head()

`locator` uses the labels `x` and `y` to refer to the longitude and latitude of the samples. The sample IDs are used to match the samples to the SNP data. 

Next we will set up a `Locator` object. The `Locator` object is the main class in the `locator` package and it is used to store configuration information, the data, and the model.

We will also load the genotype and sample data into the `Locator` object using the `load_genotypes` method.

In [None]:
locator = Locator(
    {
        "out": "araTha",
        "sample_data": coords,
        "genotype_data": snps,
        "keras_verbose": 0, # suppress training output
    }
)
genotypes, samples = locator.load_genotypes()
locator

This displays configuration information about the model, including the number of layers, the number of neurons in each layer, and the batch size. It also includes status information about the model, the sample data, and the genotype data.

Next we will train a simple model. We will start with this configuration, and hold out 10 samples for testing. There is a convenience function `train_holdout` that will do this for us. It hides all the details of the training process, including the test/train split, the training loop, and the evaluation of the model. 




In [None]:
locator.train_holdout(genotypes=genotypes, samples=samples, k=10)
locator

That was pretty quick! Let's see how well the model did. The `locator` object automatically outputs a quick summary of the training process, showing the final validation loss and the loss history. When we use the `train_holdout` function, it also returns the sample IDs that were held out, which we can use to evaluate the model.




Let's get predictions for the held out samples. We will use the `predict_holdout` function to get the predictions. The `locator` object automatically saves which samples were held out, so we can use those to evaluate the model.

In [None]:
locator.predict_holdout(return_df=True)

This returns a dataframe with the predicted coordinates, as well as a little plot of the predictions in comparison to the true coordinates. 

# Improve the model

Let's next change the model architecture to see if we can improve the model. 


In [None]:
locator = Locator(
    {
        "out": "araTha",
        "sample_data": coords,
        "genotype_data": snps,
        "keras_verbose": 0,  # suppress training output
        "nlayers": 24,
        "width": 128,
        "batch_size": 64,
        "min_mac": 1,
    }
)
genotypes, samples = locator.load_genotypes()
locator

In [None]:
locator.train_holdout(genotypes=genotypes, samples=samples, k=10)
locator

So perhaps that's not a big improvement, but it's doing a bit better. We can definitely play with this if there is time during the workshop.

Next let's run this model on the full dataset. To do this we will use the `run_holdouts` function. This function will hold out a set number of samples at a time and then predict the coordinates for those samples. It then will loop through the held out samples and predict the coordinates for each set, until all samples have been predicted without using them during training. This will take a few minutes to run. 
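The rotation described above can be sketched in a few lines of plain Python. This is only an illustration of the idea: `holdout_rounds` is a hypothetical helper, and how `run_holdouts` actually partitions and orders the samples internally is an assumption here.

```python
import random


def holdout_rounds(sample_ids, k, seed=None):
    """Split sample_ids into consecutive test sets of (at most) k samples.

    Yields (train_ids, test_ids) pairs; every sample lands in exactly one
    test set, so each prediction comes from a model that never saw it.
    """
    ids = list(sample_ids)
    random.Random(seed).shuffle(ids)  # shuffle so test sets are random groups
    for start in range(0, len(ids), k):
        test = ids[start:start + k]
        train = ids[:start] + ids[start + k:]
        yield train, test


# 250 samples held out 50 at a time gives 5 rounds of training
rounds = list(holdout_rounds(range(250), k=50, seed=1))
print(len(rounds))
for train, test in rounds:
    print(len(train), len(test))
```

Each round trains a fresh model, so the run time grows with the number of rounds; a larger `k` means fewer rounds but less training data per model.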


In [None]:
locator = Locator(
    {
        "out": "araTha",
        "sample_data": coords,
        "genotype_data": snps,
        "keras_verbose": 0,
        "nlayers": 24,
        "width": 128,
        "batch_size": 32,
        "min_mac": 1,
    }
)


# load genotypes from geno_df
genotypes, samples = locator.load_genotypes()
ho_preds = locator.run_holdouts(genotypes, samples, k=50, return_df=True)

Every sample has now been predicted by a model that never saw it during training. Let's plot the results.

To do this we will use the `plot_error_summary` function. This function will plot a histogram of the errors, showing the mean error, the median error, and the standard deviation of the error.

It will also output a map of the predictions, with a point at each true location connected by a line to the predicted location. Predictions are colored by the size of the error, with red being the largest errors and blue being the smallest errors.

In [None]:
plot_error_summary(
    predictions=ho_preds,
    sample_data=coords,
    plot_map=True,
)

Not bad. We see that the model is doing a pretty good job of predicting the location of the samples, with most samples being predicted within dozens of kilometers of the true location. However, some samples are predicted to be quite far from their true location. Some of these could be due to recent migration events; others could be due to the model not being able to capture the full signal, since we have thinned out a lot of the data and have only used a small subset of the genome.
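An error in kilometers is the distance on the Earth's surface between the true and predicted coordinates. One standard way to compute it is the haversine formula, sketched below. Whether `plot_error_summary` uses exactly this formula is an assumption, and the coordinates in the example are hypothetical.

```python
import math

EARTH_RADIUS_KM = 6371.0  # mean Earth radius


def haversine_km(lon1, lat1, lon2, lat2):
    """Great-circle distance in km between two (longitude, latitude) points in degrees."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlam = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2
    return 2 * EARTH_RADIUS_KM * math.asin(math.sqrt(a))


# Hypothetical true and predicted (x = longitude, y = latitude) for one sample
true_xy = (13.40, 52.52)
pred_xy = (13.90, 52.30)
error_km = haversine_km(true_xy[0], true_xy[1], pred_xy[0], pred_xy[1])
print(f"prediction error: {error_km:.1f} km")
```

Because a degree of longitude shrinks toward the poles, comparing raw differences in `x` and `y` would overstate east-west errors at European latitudes, which is one reason to report errors as distances.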

# Uncertainty in the predictions

Next let's look at uncertainty in our predictions. There are a few ways to do this that we have implemented in the `locator` package. The fastest way is to use "jackknife" resampling. For `locator` this is done by training the model on all the samples once, and then predicting by sampling the genotypes with replacement. Each sample is then associated with multiple predictions from resampling its genotypes. 

    

In [None]:
locator = Locator(
    {
        "out": "araTha",
        "sample_data": coords,
        "genotype_data": snps,
        "keras_verbose": 0,
        "nlayers": 24,
        "width": 128,
        "batch_size": 32,
        "min_mac": 1,
    }
)


# load genotypes from geno_df
genotypes, samples = locator.load_genotypes()
jacknife_preds = locator.run_jacknife_holdouts(
    genotypes,
    samples,
    return_df=True,
    k=10,
    jacknife_prop=0.25,
    n_replicates=500,
    )



Jackknifing is done. Now let's plot the results. For this we will use the `plot_predictions` function. This function will plot the predictions in comparison to the true coordinates, and it will also plot the uncertainty in the predictions.

In [None]:
jacknife_preds


In [None]:
plot_predictions(
    predictions=jacknife_preds, 
    locator=locator,
    out_prefix="jacknife_example",
    plot_map=True,
)


## Windowed predictions

Next let's look at windowed predictions. Windowed predictions are done by dividing the genome into windows of a given size, and then predicting the coordinates for each window. This is perhaps a better way of getting a sense of the uncertainty in the predictions, as it allows us to see how the model performs across the genome, capturing the uncertainty in both the estimator (the neural network) and the evolutionary process (the random sampling of ancestors going back in time).
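The windowing step itself is simple: SNP columns are grouped by which genomic interval their position falls in. Here is a minimal sketch for a single chromosome; `window_indices` is a hypothetical helper and the positions are made up, and how `locator` handles multiple chromosomes or sparse windows is not something this sketch claims to reproduce.

```python
def window_indices(positions, window_size):
    """Group SNP column indices by genomic window.

    positions: SNP positions in base pairs on a single chromosome.
    Returns a dict mapping window number -> list of column indices whose
    position falls in [w * window_size, (w + 1) * window_size).
    """
    windows = {}
    for idx, pos in enumerate(positions):
        windows.setdefault(int(pos // window_size), []).append(idx)
    return windows


# Hypothetical SNP positions (bp) and a 2.5 Mb window size
positions = [100_000, 2_400_000, 2_600_000, 4_900_000, 7_600_000]
groups = window_indices(positions, window_size=2.5e6)
print(groups)
```

A model is then trained and evaluated on the columns of each window separately. Windows with no SNPs (window 2 in this toy example) simply do not appear, and windows with very few SNPs will give noisier predictions.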

The `run_windows_holdouts` function will do this for us, again holding out a set number of samples for testing, training on the remaining samples, but doing so in windows across the genome. 

At the end of the function call, it will return a dataframe with the predictions for each window.

In [None]:
window_preds = locator.run_windows_holdouts(
    genotypes=genotypes,
    samples=samples,
    k=10,
    return_df=True,
    save_full_pred_matrix=True,
    window_size=2.5e6,
)

Let's look quickly at the predictions.

In [None]:
window_preds

In [None]:
plot_predictions(
    predictions=window_preds,
    locator=locator,
    out_prefix="windowed_example",
    plot_map=True,
)