NOTE: The min_cohort_size and max_cohort_size parameters are not supported by snp_allele_counts in the current release, so the code that uses them is disabled in this notebook.

In [1]:
!pip install malariagen_data

Collecting malariagen_data
  Downloading malariagen_data-7.13.0-py3-none-any.whl (133 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m133.2/133.2 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting BioPython (from malariagen_data)
  Downloading biopython-1.81-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m58.8 MB/s[0m eta [36m0:00:00[0m
Collecting dash (from malariagen_data)
  Downloading dash-2.14.1-py3-none-any.whl (10.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.4/10.4 MB[0m [31m101.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dash-cytoscape (from malariagen_data)
  Downloading dash_cytoscape-0.3.0-py3-none-any.whl (3.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m87.2 MB/s[0m eta [36m0:00:00[0m
Collecting igv-notebook>=0.2.3 (from malariagen_data)
  Downloading igv_n

In [2]:
import itertools

import allel
import malariagen_data
import numpy as np
import pandas as pd
import plotly.express as px

In [3]:
# Shared Ag3 API handle. The functions defined in this notebook read it as the
# global name `ag3` (for snp_allele_counts and sample_metadata), so this cell
# must run before any of them is called.
ag3 = malariagen_data.Ag3()

In [4]:
def average_fst(
    region,
    cohort1_query,
    cohort2_query,
    cohort_size=10,
    n_jack=200,
    site_mask="gamb_colu_arab",
    site_class=None,
    random_seed=42,
    sample_sets=None,
):
    """Estimate average Hudson's Fst between two cohorts over a genome region.

    Allele counts are computed for each cohort with ``ag3.snp_allele_counts``
    and passed to ``allel.blockwise_hudson_fst``, which also provides a
    block-jackknife standard error.

    Parameters
    ----------
    region : str
        Genome region passed through to ``ag3.snp_allele_counts``.
    cohort1_query, cohort2_query : str
        Sample queries selecting the two cohorts.
    cohort_size : int
        Cohort size passed through to ``ag3.snp_allele_counts``.
    n_jack : int
        Number of jackknife blocks; the block length is ``n_sites // n_jack``.
    site_mask : str
        Site mask passed through to ``ag3.snp_allele_counts``.
    site_class : str or None
        Site class passed through to ``ag3.snp_allele_counts``.
    random_seed : int
        Random seed passed through to ``ag3.snp_allele_counts``.
    sample_sets : optional
        Sample sets passed through to ``ag3.snp_allele_counts``. Defaults to
        None, which is the value that was previously hard-coded.

    Returns
    -------
    tuple
        ``(fst, se)``: the Fst estimate and its standard error.

    Raises
    ------
    ValueError
        If ``n_jack`` is smaller than 1, or if the region has fewer sites than
        ``n_jack`` (the block length would be zero).
    """
    # TODO: re-enable min_cohort_size / max_cohort_size once snp_allele_counts
    # supports them again.
    if n_jack < 1:
        raise ValueError(f"n_jack must be a positive integer, got {n_jack!r}")

    # arguments shared by both allele-count requests
    count_kwargs = dict(
        region=region,
        sample_sets=sample_sets,
        cohort_size=cohort_size,
        site_mask=site_mask,
        site_class=site_class,
        random_seed=random_seed,
    )

    # calculate allele counts for each cohort
    cohort1_counts = ag3.snp_allele_counts(sample_query=cohort1_query, **count_kwargs)
    cohort2_counts = ag3.snp_allele_counts(sample_query=cohort2_query, **count_kwargs)

    # number of sites in each jackknife block
    n_sites = cohort1_counts.shape[0]
    block_length = n_sites // n_jack
    if block_length < 1:
        # a zero block length would otherwise fail inside allel with an
        # unrelated-looking error
        raise ValueError(
            f"region contains {n_sites} sites, too few for n_jack={n_jack} "
            "jackknife blocks; use a larger region or a smaller n_jack"
        )

    # the per-block values and jackknife replicates are not needed here
    fst_hudson, se_hudson, _, _ = allel.blockwise_hudson_fst(
        cohort1_counts, cohort2_counts, blen=block_length
    )

    return fst_hudson, se_hudson

In [23]:
# NOTE: this cell reads fst_hudson / se_hudson, which are assigned by the
# average_fst(...) call on the two Kati 2014 cohorts; run that call first.

# check data
assert isinstance(fst_hudson, np.float64)
assert isinstance(se_hudson, np.float64)

# check dimensions
assert np.isscalar(fst_hudson)
assert np.isscalar(se_hudson)

# check some values
# (the tolerance used to be rtol=1e5, which accepts any value; the expected
# numbers are the outputs recorded in this notebook for random_seed=42)
np.testing.assert_allclose(fst_hudson, 0.07666678164848241, rtol=1e-5)
np.testing.assert_allclose(se_hudson, 0.0035530631065316047, rtol=1e-5)

In [5]:
# average Fst between the ML-2_Kati_colu_2014 and ML-2_Kati_gamb_2014 cohorts
fst_hudson, se_hudson = average_fst(
    region="3L:15,000,000-16,000,000",
    cohort1_query="cohort_admin2_year == 'ML-2_Kati_colu_2014'",
    cohort2_query="cohort_admin2_year == 'ML-2_Kati_gamb_2014'",
    n_jack=200,
    site_mask="gamb_colu",
)

Compute SNP allele counts:   0%|          | 0/44 [00:00<?, ?it/s]

Compute SNP allele counts:   0%|          | 0/24 [00:00<?, ?it/s]

In [6]:
# Fst estimate returned by the average_fst call above
fst_hudson

0.07666678164848241

In [7]:
# standard error returned alongside the Fst estimate
se_hudson

0.0035530631065316047

In [10]:
def setup_cohorts(
    cohorts,
    sample_sets = None,
    sample_query = None,
    cohort_size = 15,
    #min_cohort_size = None,
):
    """Build a mapping from cohort label to sample query, keeping only
    cohorts with at least `cohort_size` samples.

    `cohorts` is either a dict mapping cohort labels to pandas queries, or
    the name of a cohort column in the sample metadata (the "cohort_" prefix
    may be omitted). When `sample_query` is given, it is and-ed onto every
    cohort query. Cohorts that are too small are reported with print() and
    left out of the returned dict.

    Raises TypeError if `cohorts` is neither dict nor str, and ValueError if
    the named cohort column is not present in the sample metadata.
    """
    # reject unsupported argument types before doing any work
    if not isinstance(cohorts, (dict, str)):
        raise TypeError("cohorts parameter should be dict or str")

    if isinstance(cohorts, dict):
        # custom mapping of cohort label -> query, used as supplied
        queries = cohorts
    else:
        # predefined cohort set: read its labels from the sample metadata
        metadata = ag3.sample_metadata(
            sample_sets=sample_sets, sample_query=sample_query
        )

        # accept both "cohort_xyz" and the abbreviated "xyz"
        column = cohorts if cohorts.startswith("cohort_") else "cohort_" + cohorts
        if column not in metadata.columns:
            raise ValueError(f"{column!r} is not a known cohort set")

        queries = {}
        for label in sorted(metadata[column].dropna().unique()):
            queries[label] = f"{column} == '{label}'"

    # restrict every cohort query by the caller's sample query
    if sample_query is not None:
        queries = {
            label: f"({query}) and ({sample_query})"
            for label, query in queries.items()
        }

    # keep only cohorts with enough samples
    retained = dict()
    for label, query in queries.items():
        n_samples = len(
            ag3.sample_metadata(sample_sets=sample_sets, sample_query=query)
        )
        #if min_cohort_size is not None:
            #cohort_size = min_cohort_size
        if n_samples >= cohort_size:
            retained[label] = query
            continue
        print(
            f"cohort ({label}) has insufficient samples ({n_samples}) for requested cohort size ({cohort_size}), dropping"
        )
    return retained

In [11]:
# admin1/year cohorts restricted to Mali gambiae samples, size-checked with
# the default cohort_size
cohorts_checked = setup_cohorts(
    cohorts='cohort_admin1_year',
    sample_query='country == "Mali" and taxon =="gambiae"',
)

cohort (ML-4_gamb_2004) has insufficient samples (1) for requested cohort size (15), dropping


In [12]:
# cohort label -> combined sample query for the cohorts that passed the size check
cohorts_checked

{'ML-2_gamb_2004': '(cohort_admin1_year == \'ML-2_gamb_2004\') and (country == "Mali" and taxon =="gambiae")',
 'ML-2_gamb_2014': '(cohort_admin1_year == \'ML-2_gamb_2014\') and (country == "Mali" and taxon =="gambiae")',
 'ML-3_gamb_2012': '(cohort_admin1_year == \'ML-3_gamb_2012\') and (country == "Mali" and taxon =="gambiae")'}

In [13]:
def pairwise_average_fst(
    region,
    cohorts,
    sample_sets=None,
    sample_query=None,
    cohort_size=10,
    n_jack=200,
    site_mask="gamb_colu_arab",
    site_class=None,
    random_seed=42,
):
    """Compute average Fst for every pair of cohorts.

    Cohorts are resolved and size-checked by ``setup_cohorts``; each unordered
    pair of the surviving cohorts is then passed to ``average_fst``.

    Parameters
    ----------
    region : str
        Genome region passed through to ``average_fst``.
    cohorts : dict or str
        Cohort definition passed through to ``setup_cohorts``.
    sample_sets, sample_query : optional
        Passed through to ``setup_cohorts``.
    cohort_size : int
        Minimum number of samples for a cohort to be kept; also passed through
        to ``average_fst``.
    n_jack, site_mask, site_class, random_seed
        Passed through to ``average_fst``.

    Returns
    -------
    pandas.DataFrame
        One row per cohort pair with columns ``cohort1``, ``cohort2``, ``fst``
        and ``se``. Negative Fst estimates are reported as 0.0. The frame is
        empty (same columns) when fewer than two cohorts survive.
    """
    # TODO: re-enable min_cohort_size / max_cohort_size once snp_allele_counts
    # supports them again.
    # NOTE(review): sample_sets only affects the cohort-size check below; it is
    # not forwarded to average_fst.
    cohorts_checked = setup_cohorts(
        cohorts,
        sample_sets=sample_sets,
        sample_query=sample_query,
        cohort_size=cohort_size,
    )

    records = []
    # combinations() yields pairs in the same (i, j) with i < j order as a
    # nested index loop over the cohort list
    for (id1, query1), (id2, query2) in itertools.combinations(
        cohorts_checked.items(), 2
    ):
        fst_hudson, se_hudson = average_fst(
            region=region,
            cohort1_query=query1,
            cohort2_query=query2,
            cohort_size=cohort_size,
            n_jack=n_jack,
            site_mask=site_mask,
            site_class=site_class,
            random_seed=random_seed,
        )
        # report negative estimates as zero; use a float so the column dtype
        # does not depend on how many values were clipped
        if fst_hudson < 0:
            fst_hudson = 0.0
        records.append((id1, id2, fst_hudson, se_hudson))

    fst_df = pd.DataFrame(records, columns=["cohort1", "cohort2", "fst", "se"])
    # keep the numeric columns as float even when all values were clipped or
    # there are no rows
    return fst_df.astype({"fst": float, "se": float})

In [15]:
# Pairwise Fst between the Mali gambiae admin1/year cohorts; the cohort_size
# argument is not given, so the function's default is used for the size check.
pairwise_fst_df = pairwise_average_fst(
    region="3L:15,000,000-16,000,000",
    cohorts="cohort_admin1_year",
    sample_query="country == 'Mali' and taxon == 'gambiae'",
    n_jack=200,
    site_mask="gamb_colu",
)

cohort (ML-4_gamb_2004) has insufficient samples (1) for requested cohort size (10), dropping


Compute SNP allele counts:   0%|          | 0/44 [00:00<?, ?it/s]

Compute SNP allele counts:   0%|          | 0/24 [00:00<?, ?it/s]

Compute SNP allele counts:   0%|          | 0/44 [00:00<?, ?it/s]

Compute SNP allele counts:   0%|          | 0/44 [00:00<?, ?it/s]

Compute SNP allele counts:   0%|          | 0/24 [00:00<?, ?it/s]

Compute SNP allele counts:   0%|          | 0/44 [00:00<?, ?it/s]

In [16]:
# one row per cohort pair: cohort1, cohort2, fst, se
pairwise_fst_df

Unnamed: 0,cohort1,cohort2,fst,se
0,ML-2_gamb_2004,ML-2_gamb_2014,0.03725,0.002041
1,ML-2_gamb_2004,ML-3_gamb_2012,0.041542,0.00237
2,ML-2_gamb_2014,ML-3_gamb_2012,0.0,0.000814


In [19]:
# expected pairwise results, one (cohort1, cohort2, fst, se) record per pair
test_df = pd.DataFrame.from_records(
    [
        ("ML-2_gamb_2004", "ML-2_gamb_2014", 0.037249514094550934, 0.0020406887352541958),
        ("ML-2_gamb_2004", "ML-3_gamb_2012", 0.04154191785684654, 0.002369740033208285),
        ("ML-2_gamb_2014", "ML-3_gamb_2012", 0.0, 0.0008138514674580574),
    ],
    columns=["cohort1", "cohort2", "fst", "se"],
)

In [20]:
# check data
assert isinstance(pairwise_fst_df, pd.DataFrame)

# check some values
# (the tolerance used to be rtol=1e5, which makes the comparison pass for any
# numeric values; 1e-5 actually compares against the expected frame)
pd.testing.assert_frame_equal(pairwise_fst_df, test_df, rtol=1e-5)
assert np.all(pairwise_fst_df["fst"] <= 1)
# pairwise_average_fst reports negative estimates as 0, so 0 is the lower bound
assert np.all(pairwise_fst_df["fst"] >= 0)

In [39]:
def plot_pairwise_average_fst(
    fst_df,
    annotate_se=True,
    zmin=0,
    zmax=1,
    width=None,
    height=None,
    text_auto=True,
    color_continuous_scale="gray",
    title=None,
    **kwargs,
):
    """Plot pairwise Fst values as an annotated heatmap.

    A square cohort-by-cohort matrix is built from ``fst_df``. For each row,
    the Fst value goes in the cell (row=cohort2, column=cohort1), which is
    below the diagonal when pairs are listed with cohort1 before cohort2.
    When ``annotate_se`` is true the standard error goes in the mirrored cell
    (row=cohort1, column=cohort2). Diagonal cells, and the mirrored cells when
    ``annotate_se`` is false, are left empty.

    Parameters
    ----------
    fst_df : pandas.DataFrame
        Frame with columns ``cohort1``, ``cohort2``, ``fst`` and ``se``.
    annotate_se : bool
        Whether to show standard errors opposite the Fst values.
    zmin, zmax, width, height, text_auto, color_continuous_scale, title
        Passed through to ``plotly.express.imshow``. Cell labels are always
        replaced afterwards by the values formatted to 3 decimal places.
    **kwargs
        Further keyword arguments for ``plotly.express.imshow``.

    Returns
    -------
    The figure returned by ``plotly.express.imshow``, after styling.
    """
    # cohort labels in order of first appearance
    labels = list(pd.unique(fst_df[["cohort1", "cohort2"]].to_numpy().ravel()))

    # numeric matrix used for the colours; NaN marks cells with no value
    values = pd.DataFrame(np.nan, index=labels, columns=labels, dtype=float)
    for cohort1, cohort2, fst, se in zip(
        fst_df["cohort1"], fst_df["cohort2"], fst_df["fst"], fst_df["se"]
    ):
        values.loc[cohort2, cohort1] = fst
        if annotate_se:
            # SE goes in the cell mirrored across the diagonal
            values.loc[cohort1, cohort2] = se

    # text labels: 3 decimal places, blank where there is no value
    text = [
        ["" if np.isnan(value) else f"{value:.3f}" for value in row]
        for row in values.to_numpy()
    ]

    # create plot from the numeric matrix so colours map to real numbers
    fig = px.imshow(
        img=values,
        zmin=zmin,
        zmax=zmax,
        width=width,
        height=height,
        text_auto=text_auto,
        color_continuous_scale=color_continuous_scale,
        title=title,
        aspect="auto",
        **kwargs,
    )
    fig.update_traces(text=text, texttemplate="%{text}")
    fig.update_layout(plot_bgcolor="rgba(0,0,0,0)")
    fig.update_yaxes(showgrid=False, linecolor="black")
    fig.update_xaxes(showgrid=False, linecolor="black")

    return fig

In [40]:
# heatmap of the pairwise results with standard errors annotated
plot_pairwise_average_fst(pairwise_fst_df, annotate_se=True)

In [42]:
# same heatmap without the standard-error annotation
plot_pairwise_average_fst(pairwise_fst_df, annotate_se=False)