# Genome-wide selection scan (GWSS) using windowed pairwise $F_{ST}$

In [None]:
# Setup cell: install the extra packages, import libraries and create the
# MalariaGEN data client. The `!pip` lines are IPython shell escapes, so this
# cell only runs inside Jupyter/Colab, not as a plain Python script.
!pip install -qq malariagen_data
!pip install -q petl


import malariagen_data
import numpy as np
import pandas as pd
import allel
import zarr
import dask
import dask.array as da
# silence some dask warnings
dask.config.set(**{'array.slicing.split_large_chunks': True})
from dask.diagnostics.progress import ProgressBar

# NOTE(review): random, zarr, scipy, Counter, petl (ptl), dask.array (da),
# mpl, ListedColormap, GridSpec and matplotlib_venn (venn) appear unused in the
# cells of this notebook -- consider removing them.
import random
import functools
import petl as ptl
import itertools
import scipy
from collections import Counter

# plotting setup
# NOTE(review): matplotlib_venn is not installed by the pip lines above;
# make sure it is available in the runtime or this import fails.
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from matplotlib.gridspec import GridSpec
import matplotlib_venn as venn
import seaborn as sns

# Client for the MalariaGEN Ag3 data resource; its sample_metadata() and
# snp_calls() methods are used in the cells below.
ag3 = malariagen_data.Ag3()
# Last expression of the cell: the notebook renders it as the cell output.
ag3

In [None]:
# Step 1: define the populations to compare. Each population is the set of
# samples of one taxon (a member of the An. gambiae complex) collected in one
# first-level administrative unit (admin1_iso) of the chosen country.
taxons = ["coluzzii", "gambiae", "bissau"]  # taxa that can be used for `taxon`
taxon = "bissau"
country = "Gambia, The"
region = "3R"  # chromosome arm / region scanned in this notebook
sample_query = f"country=='{country}' and taxon=='{taxon}'"

taxon_sample = ag3.sample_metadata(sample_query=sample_query)
taxon_sample['pop'] = taxon_sample['admin1_iso'] + '_' + taxon_sample['taxon']
pops = taxon_sample['pop'].unique()

# Map each population to the *positions* (0-based row numbers) of its samples.
# These positions index the samples axis of the genotype array below, so they
# must not depend on the DataFrame's index labels, which need not be 0..n-1.
pop_labels = taxon_sample['pop'].to_numpy()
taxon_pop = {pop: np.flatnonzero(pop_labels == pop) for pop in pops}
print("Populations:", list(taxon_pop))

# Step 2: load the SNP calls for exactly these samples.
samples_list = list(taxon_sample['sample_id'])
sample_query = f"sample_id in {samples_list}"
# NOTE: `chrom2L_gt` is a historical name; it holds the SNP calls for `region`.
chrom2L_gt = ag3.snp_calls(region=region, sample_query=sample_query)

# Keep only sites passing the chosen site filter. The mask is computed once and
# reused for genotypes and positions so that the two stay aligned.
filt = 'gamb_colu'
filt_val = chrom2L_gt[f"variant_filter_pass_{filt}"].values
gt_filtered = allel.GenotypeDaskArray(chrom2L_gt["call_genotype"][filt_val].data)
filtered_pos = allel.SortedIndex(chrom2L_gt['variant_position'][filt_val].values)

# The positional indices in `taxon_pop` assume the SNP calls list the samples in
# the same order as `taxon_sample`; at minimum make sure the counts agree.
if gt_filtered.shape[1] != len(taxon_sample):
    raise ValueError(
        f"SNP calls contain {gt_filtered.shape[1]} samples but the metadata "
        f"query returned {len(taxon_sample)}"
    )
print("Filtered genotype array shape:", gt_filtered.shape)
print("Filtered positions:", filtered_pos.shape)

# Step 3: allele counts (alleles 0-3) per population over the filtered sites.
ac_pop_taxon = {}
with ProgressBar():
    for pop, sample_idx in taxon_pop.items():
        print(f"Computing allele count for {pop}")
        ac_pop_taxon[pop] = (
            gt_filtered.take(sample_idx, axis=1).count_alleles(max_allele=3).compute()
        )

# Step 4: save the allele-count dict to disk.
import pickle
with open('ac_pop_taxon.pkl', 'wb') as f:
    pickle.dump(ac_pop_taxon, f)

In [None]:
def load_ac(pop, ac_dict=None):
    """Return the allele-counts array stored for population *pop*.

    The lookup happens at call time, either in *ac_dict* when given or in the
    notebook-level ``ac_pop_taxon`` dict. The result is deliberately not cached.
    A cache keyed on the population name kept returning arrays from an earlier
    run after ``ac_pop_taxon`` was recomputed (e.g. for another region), and a
    dict lookup is already O(1).

    Parameters
    ----------
    pop : str
        Population name (a key of the allele-count dict).
    ac_dict : dict, optional
        Mapping of population name to allele counts; defaults to
        ``ac_pop_taxon``.

    Returns
    -------
    The value stored under *pop*.

    Raises
    ------
    KeyError
        If *pop* is not in the dict.
    """
    source = ac_pop_taxon if ac_dict is None else ac_dict
    return source[pop]

In [None]:
def compute_windowed_pair_fst(pops, filt_pos, size, start=None, stop=None,
                              step=None, union_segregating=True):
    """Compute windowed Hudson's F_ST between two populations.

    Parameters
    ----------
    pops : pair of str
        The two population names; allele counts are fetched with ``load_ac``.
        Hudson's F_ST is symmetric, so the order does not matter.
    filt_pos : allel.SortedIndex
        Variant positions, aligned with the rows of the allele-count arrays.
    size : int
        Window size in bp.
    start, stop : int, optional
        Interval tiled by the windows. They default to the first and last entry
        of *filt_pos*, not of the SNPs retained for this pair. This gives every
        pair the same windows, so per-pair results can be put side by side.
    step : int, optional
        Distance between window starts; ``None`` means non-overlapping windows.
    union_segregating : bool, default True
        True: keep SNPs that are biallelic and segregating in the two
        populations pooled together, so sites fixed for different alleles
        (the most differentiated sites) are retained. False: the former
        behaviour, which requires each population to be segregating on its own.

    Returns
    -------
    fst_wind, windows, counts
        As returned by ``allel.windowed_hudson_fst``.
    """
    pop1, pop2 = pops
    ac_pop1 = load_ac(pop=pop1)
    ac_pop2 = load_ac(pop=pop2)
    biallelic = (ac_pop1.max_allele() <= 1) & (ac_pop2.max_allele() <= 1)
    if union_segregating:
        # Pool the counts: a site fixed for different alleles in the two
        # populations is segregating overall and must not be filtered out.
        ac_pooled = allel.AlleleCountsArray(ac_pop1 + ac_pop2)
        segregating = ac_pooled.is_segregating()
    else:
        segregating = ac_pop1.is_segregating() & ac_pop2.is_segregating()
    loc_pass = segregating & biallelic
    # Positions of the selected SNPs
    pos = filt_pos[loc_pass]
    if start is None:
        start = filt_pos[0]
    if stop is None:
        stop = filt_pos[-1]
    fst_wind, windows, counts = allel.windowed_hudson_fst(
        pos, ac_pop1[loc_pass], ac_pop2[loc_pass],
        size=size, start=start, stop=stop, step=step,
    )
    return fst_wind, windows, counts

def compute_windowed_pairwise_fst(coh_list, pos_filt, size):
    """Compute windowed Hudson's F_ST for every unordered pair of cohorts.

    Parameters
    ----------
    coh_list : iterable of str
        Cohort / population names.
    pos_filt : allel.SortedIndex
        Variant positions forwarded to ``compute_windowed_pair_fst``.
    size : int
        Window size in bp.

    Returns
    -------
    (dict, dict)
        ``"fst_<a>_<b>"`` -> per-window F_ST values, and ``"wind_<a>_<b>"`` ->
        the matching window bounds. ``a``/``b`` appear in the order yielded by
        ``itertools.combinations(coh_list, 2)``.
    """
    fst_by_pair = {}
    windows_by_pair = {}
    for first, second in itertools.combinations(coh_list, 2):
        # The pair is sorted for the computation; only the keys keep input order.
        fst_values, window_bounds, _counts = compute_windowed_pair_fst(
            pops=tuple(sorted((first, second))),
            filt_pos=pos_filt,
            size=size,
        )
        suffix = f"{first}_{second}"
        fst_by_pair["fst_" + suffix] = fst_values
        windows_by_pair["wind_" + suffix] = window_bounds
    return fst_by_pair, windows_by_pair

In [None]:
# Compute windowed pairwise F_ST for every population pair and tabulate it.
import json

fst_dict_ag, wind_fst_dict_ag = compute_windowed_pairwise_fst(
    coh_list=pops, pos_filt=filtered_pos, size=100000)
if not fst_dict_ag:
    raise ValueError(
        f"Pairwise F_ST needs at least two populations; got {list(pops)}")

# Save the raw per-pair F_ST arrays (ndarrays are not JSON-serialisable, so
# each one is converted to a list first).
fst_dict_ag_cpy = {pair: values.tolist() for pair, values in fst_dict_ag.items()}
with open('fst_gwss_taxon.json', 'w') as f:
    json.dump(fst_dict_ag_cpy, f)

# Build one wide table: a `wind_mean` column (window midpoint, bp) plus one
# F_ST column per pair. Each pair is placed on its OWN window midpoints and the
# pairs are outer-joined. Previously every column was labelled with the first
# pair's windows, which misaligns pairs whose windows differ. A window missing
# for a pair becomes NaN.
per_pair = []
for fst_key, fst_values in fst_dict_ag.items():
    # Each "fst_<pair>" key has a matching "wind_<pair>" key.
    wind_key = 'wind_' + fst_key[len('fst_'):]
    midpoints = np.mean(wind_fst_dict_ag[wind_key], axis=1)
    per_pair.append(pd.Series(fst_values, index=midpoints, name=fst_key))
ag_fst = (
    pd.concat(per_pair, axis=1)
    .sort_index()
    .rename_axis('wind_mean')
    .reset_index()
)

df_wind = ag_fst[['wind_mean']]
df_fst = ag_fst.drop(columns='wind_mean')
df_fst.to_csv('fst_gwss_taxon.csv')
df_wind.to_csv('wind_mean_3R.csv')

# Negative F_ST estimates are truncated to 0 (NaN windows stay NaN).
ag_fst = ag_fst.clip(lower=0)
# The wide table gets its own file. It used to be written to the same path as
# the long table below and was immediately overwritten.
ag_fst.to_csv('fst_gwss_bissau_3R_wide.csv')

# Long format (one row per window x pair), as expected by plot_fst_gwss.
data_fst_ag = ag_fst.melt('wind_mean', var_name='taxon_Pop', value_name='pair_fst')
data_fst_ag.to_csv('fst_gwss_bissau_3R.csv')
xlim1 = list(ag_fst.wind_mean)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def plot_fst_gwss(data=None, figsize=(12, 4), ax=None, title=None, xlab=None,
                  ylab=None, legend_loc=None, fig_name=None, xlim=None,
                  ylim=(0, 0.25), ytick_step=0.05):
    """Plot windowed pairwise F_ST along a chromosome, one line per pair.

    Parameters
    ----------
    data : pandas.DataFrame, optional
        Long-format table with columns ``wind_mean`` (window midpoint, bp),
        ``pair_fst`` and ``taxon_Pop`` (pair label). Defaults to the
        notebook-level ``data_fst_ag``, looked up at call time. The former
        default was bound once, when the function was defined, so it went
        stale after the data cell was re-run.
    figsize : tuple
        Figure size used only when a new figure is created (``ax`` is None).
    ax : matplotlib Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.
    title, xlab, ylab : str, optional
        Axes title and axis labels.
    legend_loc : tuple, optional
        ``bbox_to_anchor`` for the legend; the legend is redrawn only if given.
    fig_name : str, optional
        If given, the figure is saved as ``"<fig_name>.png"`` at 300 dpi.
    xlim : (float, float), optional
        X range; defaults to the min/max of ``data['wind_mean']``.
    ylim : (float, float)
        Y range, default (0, 0.25).
    ytick_step : float
        Spacing of the y ticks within ``ylim``.

    Returns
    -------
    The Axes that was drawn on.
    """
    from matplotlib.ticker import FuncFormatter

    if data is None:
        data = data_fst_ag

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
        sns.despine(ax=ax, offset=5)
    else:
        # Needed for saving below; previously `fig` was undefined in this case.
        fig = ax.figure

    # Plot FST
    sns.lineplot(x='wind_mean', y='pair_fst', hue='taxon_Pop', data=data, ax=ax, linewidth=0.5)

    # Set axis titles
    if title:
        ax.set_title(title)
    if xlab:
        ax.set_xlabel(xlab)
    if ylab:
        ax.set_ylabel(ylab)

    # Set legend
    if legend_loc:
        ax.legend(loc='best', bbox_to_anchor=legend_loc)

    # Set axis limits
    if xlim is None:
        xlim = (data['wind_mean'].min(), data['wind_mean'].max())
    ax.set_xlim(*xlim)
    # Ticks first, limits second: set_yticks expands the view to include every
    # tick, which previously (ticks up to 1.0) undid the requested ylim.
    yticks = np.arange(ylim[0], ylim[1] + ytick_step / 2, ytick_step)
    ax.set_yticks(np.round(yticks, 10))
    ax.set_ylim(*ylim)

    # Thousands separators on x tick labels; a formatter stays correct when the
    # tick locations change, unlike fixed labels set via set_xticklabels.
    ax.xaxis.set_major_formatter(FuncFormatter(lambda value, _pos: f'{int(value):,}'))

    # Save figure
    if fig_name:
        fig.savefig(f"{fig_name}.png", dpi=300, bbox_inches='tight')

    return ax



In [None]:
## Plot windowed pairwise F_ST along 3R for the populations defined above

fig, ax = plt.subplots(figsize=(12, 4))
sns.despine(ax=ax, offset=5)
plot_fst_gwss(data=data_fst_ag, ax=ax, xlab='Chrom 3R (bp)', ylab='Pairwise $F_{ST}$ in 100kb windows',
              legend_loc=(1, 1))

# Disabled example region annotations kept for reference (coordinates were
# written for another plot; re-check them before re-enabling):
# ax.annotate('$vgsc$ region', xy=(2431617, 0.7), xytext=(600000, 1),
#             color='darkred', arrowprops=dict(arrowstyle="->", color='slategrey'))
# ax.annotate('$OBP41$ region', xy=(11551855.5, 0.3), xytext=(15000000, 0.5),
#             color='darkred', arrowprops=dict(arrowstyle="->", color='slategrey'))
# ax.annotate('$Ors$ region', xy=(32080000, 0.05), xytext=(31000000, 0.10),
#             color='darkred', arrowprops=dict(arrowstyle="->", color='slategrey'))

# Save and display the figure (plt.show works with the notebook inline backend,
# where Figure.show only emits a warning).
fig.savefig("3R_window_bissau_fst.png", dpi=300, bbox_inches='tight')
plt.show()