PCA based on scikit-allel, with some code taken from http://alimanfoo.github.io/2015/09/28/fast-pca.html 

#TODO: I have changed the matplotlib to inline from widget - this has broken the legend for the plots, so that has to be fixed.

In [1]:
#Initial configuration, probably overkill in imports.
import sys, os, re
import numpy as np
import allel
import zarr
import dask
import numcodecs
import warnings
from pathlib import Path

#os.environ["MODIN_ENGINE"] = "ray"

#import modin.pandas as pd
import pandas as pd

%matplotlib widget
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('retina', 'png')
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
sns.set()
sns.set_theme()
sns.set_style("white")
sns.set_context("notebook")

In [2]:
meta_data_samples = pd.read_table("../data/meta_data_samples.txt", sep = " ")
meta_data_samples

Unnamed: 0.1,Unnamed: 0,PGDP_ID,Provider_ID,Provider,Genus,Species,Origin,Sex,address,longitude,latitude,callset_index
0,1,PD_0199,09SNF1101115,Knauf/Chuma/Roos,Papio,anubis,"Serengeti, Tanzania",F,"Serengeti, Mara, Lake Zone, Tanzania",34.742544,-1.996626,0
1,2,PD_0200,11SNF1101115,Knauf/Chuma/Roos,Papio,anubis,"Serengeti, Tanzania",F,"Serengeti, Mara, Lake Zone, Tanzania",34.742544,-1.996626,1
2,3,PD_0201,19SNM1131115,Knauf/Chuma/Roos,Papio,anubis,"Serengeti, Tanzania",M,"Serengeti, Mara, Lake Zone, Tanzania",34.742544,-1.996626,2
3,4,PD_0202,20SNF1131115,Knauf/Chuma/Roos,Papio,anubis,"Serengeti, Tanzania",F,"Serengeti, Mara, Lake Zone, Tanzania",34.742544,-1.996626,3
4,5,PD_0203,21SNF1151115,Knauf/Chuma/Roos,Papio,anubis,"Serengeti, Tanzania",F,"Serengeti, Mara, Lake Zone, Tanzania",34.742544,-1.996626,4
...,...,...,...,...,...,...,...,...,...,...,...,...
155,212,PD_0789,34417_BZ11064,Rogers/Jolly/Phillips-Conroy,Papio,kindae,"Chunga, Zambia",F,"Chunga, Mumbwa District, Central Province, Zambia",26.005210,-15.053557,155
156,213,PD_0790,34418_BZ11065,Rogers/Jolly/Phillips-Conroy,Papio,kindae,"Chunga, Zambia",F,"Chunga, Mumbwa District, Central Province, Zambia",26.005210,-15.053557,156
157,214,PD_0791,34419_BZ11066,Rogers/Jolly/Phillips-Conroy,Papio,kindae,"Chunga, Zambia",F,"Chunga, Mumbwa District, Central Province, Zambia",26.005210,-15.053557,157
158,215,PD_0792,34420_BZ11067,Rogers/Jolly/Phillips-Conroy,Papio,kindae,"Chunga, Zambia",F,"Chunga, Mumbwa District, Central Province, Zambia",26.005210,-15.053557,158


In [3]:
#Opening the zarr data
callset = zarr.open_group('/faststorage/project/primatediversity/people/kmt/baboon_flagship/steps/callset.zarr', mode='r')
callset.tree(expand=False)

Tree(nodes=(Node(disabled=True, name='/', nodes=(Node(disabled=True, name='chr1', nodes=(Node(disabled=True, n…

In [4]:
#Functions from http://alimanfoo.github.io/2015/09/28/fast-pca.html 
def plot_ld(gn, title):
    m = allel.rogers_huff_r(gn) ** 2
    ax = allel.plot_pairwise_ld(m)
    ax.set_title(title)

def ld_prune(gn, size, step, threshold=.1, n_iter=1):
    for i in range(n_iter):
        loc_unlinked = allel.locate_unlinked(gn, size=size, step=step, threshold=threshold)
        n = np.count_nonzero(loc_unlinked)
        n_remove = gn.shape[0] - n
        print('iteration', i+1, 'retaining', n, 'removing', n_remove, 'variants')
        gn = gn.compress(loc_unlinked, axis=0)
    return gn


In [5]:
#Setting up a function to do a PCA for a specific input
def pruning_and_pca(chrom, IDs, subsampling_n, size, n_iter):
    print("Investigating {} with {} individuals".format(chrom, len(IDs)))
    gt_zarr = allel.GenotypeChunkedArray(callset["{}/calldata/GT".format(chrom)]) #Loading in the zarr dataset
    gt_zarr = gt_zarr.take(IDs, axis=1)
    ac = gt_zarr.count_alleles()[:] #Allele counts for each pos
    flt = (ac.max_allele() == 1) & (ac[:, :2].min(axis=1) > 1) #Filtering for biallelic 
    #and at least two individuals in the alt state
    gf = gt_zarr.compress(flt, axis=0) #Applying filter
    gn = gf.to_n_alt() #Transform genotype to number of non-ref alleles
    if subsampling_n > len(gn):
        subsampling_n = len(gn)-1
    vidx = np.random.choice(gn.shape[0], subsampling_n, replace=False) #Random subsampling
    vidx.sort()
    gnr = gn.take(vidx, axis=0) #Applying the random subsample
    gnu = ld_prune(gnr, size=size, step=200, threshold=.1, n_iter=n_iter) #Pruning based on LD
    if len(gnu) > 150000:
        print("Too large dataset for pca")
        return "Too large dataset for pca"
    gnu = gnu[:] #Taking it out of chunked storage
    coords1, model1 = allel.pca(gnu, n_components=10, scaler='patterson') #Running the pca
    #The plotting solution in the example is inflexible, so I will try to make it through seaborn.
    pca_df = pd.DataFrame()
    c_transposed = coords1.transpose()
    for i in range(len(c_transposed)):
        pc = "pc{}".format(i+1)
        pca_df[pc] = c_transposed[i]
    pca_df_meta = pd.concat([pca_df, meta_data_samples.iloc[IDs].reset_index()], axis = 1, ignore_index=False)
    return pca_df_meta

In [6]:
IDs = meta_data_samples.loc[(meta_data_samples.Sex == "F") & (meta_data_samples.PGDP_ID != "PD_0202")].callset_index.values
pca_df = pruning_and_pca("chrX", IDs, 3000000, 500, 2) #Make sure that enough variants are removed.

Investigating chrX with 63 individuals
iteration 1 retaining 105204 removing 1525739 variants
iteration 2 retaining 17080 removing 88124 variants


In [7]:
scatter_sns = sns.scatterplot(data = pca_df, x="pc1", y="pc2", hue="Origin")
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
fig = scatter_sns.get_figure()
fig.savefig("../results/pca_plot_origin")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [8]:
scatter_sns = sns.scatterplot(data = pca_df, x="pc1", y="pc2", hue="Species")
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
fig = scatter_sns.get_figure()
fig.savefig("../results/pca_plot_species")

In [9]:
sns.scatterplot(data = pca_df, x="pc3", y="pc4", hue="Species", style="Sex")
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

<matplotlib.legend.Legend at 0x2acfcb03d210>

In [10]:
#Only taking 10 individuals per species.
females = meta_data_samples.loc[(meta_data_samples.Sex == "F") & (meta_data_samples.PGDP_ID != "PD_0202")]
IDs = []

for species in meta_data_samples.Species.unique():
    species_IDs = females.loc[females.Species == species]
    n = min(10, len(species_IDs))
    print("Taking {} from species {}".format(n, species))
    IDs.extend(np.random.choice(species_IDs.callset_index.values, n, replace=False))

Taking 10 from species anubis
Taking 10 from species cynocephalus
Taking 5 from species papio
Taking 3 from species ursinus (grayfoot)
Taking 5 from species hamadryas
Taking 7 from species kindae


In [11]:
pca_df = pruning_and_pca("chrX", IDs, 3000000, 250, 2)

Investigating chrX with 40 individuals
iteration 1 retaining 113168 removing 1297662 variants
iteration 2 retaining 16787 removing 96381 variants


In [12]:
scatter_sns = sns.scatterplot(data = pca_df, x="pc1", y="pc2", hue="Species")
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
fig = scatter_sns.get_figure()
fig.savefig("../results/pca_plot_species_5_IDs")

In [13]:
scatter_sns = sns.scatterplot(data = pca_df, x="pc1", y="pc2", hue="Origin")
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
fig = scatter_sns.get_figure()
fig.savefig("../results/pca_plot_origin_5_IDs")

In [14]:
#Only taking 10 individuals per species, here on chr7 and with males.
IDs = []

for species in meta_data_samples.Species.unique():
    species_IDs = meta_data_samples.loc[meta_data_samples.Species == species]
    n = min(10, len(species_IDs))
    print("Taking {} from species {}".format(n, species))
    IDs.extend(np.random.choice(species_IDs.callset_index.values, n, replace=False))

Taking 10 from species anubis
Taking 10 from species cynocephalus
Taking 10 from species papio
Taking 4 from species ursinus (grayfoot)
Taking 10 from species hamadryas
Taking 10 from species kindae


In [15]:
pca_df = pruning_and_pca("chr7", IDs, 3000000, 250, 2)

Investigating chr7 with 54 individuals
iteration 1 retaining 254932 removing 2466334 variants
iteration 2 retaining 43278 removing 211654 variants


In [16]:
scatter_sns = sns.scatterplot(data = pca_df, x="pc1", y="pc2", hue="Species")
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
fig = scatter_sns.get_figure()
fig.savefig("../results/pca_plot_species_10_IDs_chr7")

In [17]:
scatter_sns = sns.scatterplot(data = pca_df, x="pc1", y="pc2", hue="Origin")
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
fig = scatter_sns.get_figure()
fig.savefig("../results/pca_plot_origin_10_IDs_chr7")

In [18]:
scatter_sns = sns.scatterplot(data = pca_df, x="pc3", y="pc4", hue="Species")
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

<matplotlib.legend.Legend at 0x2acfca2d3750>