# Ancestry analysis of genomic regions under selective sweeps

In [None]:
!pip install -qq malariagen_data
import malariagen_data
import numpy as np
import pandas as pd
import allel
import zarr
import matplotlib.pyplot as plt
import seaborn as sns


# Instantiate the malariagen_data Ag3 client; `ag3` is the module-level handle
# that get_snp_data() uses to fetch SNP calls and sample metadata.
ag3 = malariagen_data.Ag3()
# Bare expression: displays the client's representation when run as a notebook cell.
ag3

## Workflow to prepare SNP data and plot PCoA and UMAP

In [None]:
#!pip install -qq umap
#!pip install -qq umap-learn

!pip install -qq  scipy
!pip install -qq scikit-bio
!pip install -q umap-learn

import umap.umap_ as umap

In [None]:
# SNP data
from dask.diagnostics.progress import ProgressBar

def get_snp_data(region, sample_query, filter):
    """Load SNP calls and sample metadata, keeping only well-sampled populations.

    Uses the module-level ``ag3`` client. As a side effect the retained sample
    metadata is written to ``taxon_sample.csv`` in the working directory.

    Parameters
    ----------
    region
        Region specification, passed unchanged to ``ag3.snp_calls``.
    sample_query
        Sample query string, passed to both ``ag3.snp_calls`` and
        ``ag3.sample_metadata``.
    filter
        Suffix selecting the ``variant_filter_pass_<filter>`` site mask.
        The name shadows the builtin but is kept so existing callers that
        pass it by keyword keep working.

    Returns
    -------
    tuple
        ``(gt_filtered, gm_snps_selected, taxon_sample)``:
        ``gt_filtered`` is an ``allel.GenotypeDaskArray`` holding the genotype
        calls at sites where the chosen site mask is true;
        ``gm_snps_selected`` is the SNP dataset restricted to the retained
        samples (sites are NOT filtered in it);
        ``taxon_sample`` is the metadata table of the retained samples with an
        added ``pop`` column (``admin1_iso`` + ``'_'`` + ``taxon``).
    """
    print(f"Processing SNP data for region: {region}")
    gm_snps = ag3.snp_calls(region=region, sample_query=sample_query)

    # Population label combines the admin1 ISO code and the taxon,
    # e.g. "GM-N_gamb", "GM-U_arab".
    taxon_sample = ag3.sample_metadata(sample_query=sample_query)
    taxon_sample['pop'] = taxon_sample['admin1_iso'] + '_' + taxon_sample['taxon']

    # Keep populations with at least 10 samples. The counts are computed once,
    # and .copy() detaches the result so that later column assignments by
    # callers do not raise SettingWithCopyWarning.
    pop_sizes = taxon_sample['pop'].value_counts()
    well_sampled = pop_sizes[pop_sizes >= 10].index
    taxon_sample = taxon_sample[taxon_sample['pop'].isin(well_sampled)].copy()
    taxon_sample.to_csv('taxon_sample.csv', sep=',', index=True)

    # Positions, in dataset order, of the samples that survived the size filter.
    is_retained = np.isin(gm_snps['sample_id'].values, taxon_sample['sample_id'].values)
    gm_snps_selected = gm_snps.isel(samples=np.flatnonzero(is_retained))

    # Apply the site filter and wrap the genotype calls for scikit-allel.
    filt_val = gm_snps_selected[f"variant_filter_pass_{filter}"].values
    gt_filtered = allel.GenotypeDaskArray(gm_snps_selected["call_genotype"][filt_val].data)

    return gt_filtered, gm_snps_selected, taxon_sample


def snp_filtering_proccess(gt_filtered, gm_snps_selected, taxon_sample):
    """Keep segregating biallelic sites and return alternate-allele counts.

    Parameters
    ----------
    gt_filtered
        Genotype array (as returned by ``get_snp_data``) whose sample axis
        follows the sample order of ``gm_snps_selected``.
    gm_snps_selected
        SNP dataset providing the ``sample_id`` order of ``gt_filtered``.
    taxon_sample
        Metadata table; only samples whose ``sample_id`` appears here are used.

    Returns
    -------
    The result of ``to_n_alt()`` on the retained sites, restricted to the
    samples listed in ``taxon_sample``.

    Notes
    -----
    A site is kept when the highest allele index observed is 1 and both the
    reference and the first alternate allele are seen more than once; this
    drops invariant sites and singletons. Missing calls are not filtered here.
    """
    dataset_sample_ids = gm_snps_selected['sample_id'].values
    pop_idx = np.flatnonzero(np.isin(dataset_sample_ids, taxon_sample['sample_id'].values))

    # Subset the samples once and use the SAME subset for both the allele
    # counts and the returned matrix, so the site filter and the output always
    # describe the same set of samples.
    gt_pop = gt_filtered.take(pop_idx, axis=1)
    with ProgressBar():
        ac = gt_pop.count_alleles(max_allele=3).compute()

    filter_ac = (ac.max_allele() == 1) & (ac[:, :2].min(axis=1) > 1)
    gt_segregating = gt_pop.compress(filter_ac, axis=0)
    geno_alt = gt_segregating.to_n_alt()

    return geno_alt


In [None]:
# Perform LD pruning
import dask.array as da
import allel
import numpy as np
def plot_ld(gn, title):
    """
    Plot pairwise linkage disequilibrium for the SNPs in ``gn``.

    LD is the square of the Rogers-Huff r statistic returned by
    ``allel.rogers_huff_r``; the resulting axes are labelled with ``title``.
    """
    r = allel.rogers_huff_r(gn)
    r_squared = r ** 2
    axes = allel.plot_pairwise_ld(r_squared)
    axes.set_title(title)
def ld_prune(gn, size, step, threshold=.1, n_iter=1):
    """
    Perform iterative LD pruning on genotype data.

    Parameters
    ----------
    gn
        Alternate-allele count matrix with variants on axis 0. A dask array is
        loaded into memory once before pruning starts.
    size, step, threshold
        Window size, window step and LD threshold, forwarded to
        ``allel.locate_unlinked``.
    n_iter
        Maximum number of pruning passes.

    Returns
    -------
    The in-memory matrix restricted to the variants retained after pruning.
    """
    # The conversion does not depend on the iteration, so do it once up front.
    if isinstance(gn, da.Array):
        gn = gn.compute()

    for i in range(n_iter):
        loc_unlinked = allel.locate_unlinked(gn, size=size, step=step, threshold=threshold)
        n = np.count_nonzero(loc_unlinked)
        n_remove = gn.shape[0] - n
        print('iteration', i + 1, 'retaining', n, 'removing', n_remove, 'variants')
        gn = gn.compress(loc_unlinked, axis=0)
        # A pass that removes nothing leaves the input unchanged, so every
        # further pass would give the same result; stop early.
        if n_remove == 0:
            break
    return gn

In [None]:
# Configure the analysis and build the filtered genotype matrix.
# Sweep regions, all passed together in a single call.
regions = [
    "2R:28,430,000-28,615,000",
    "3R:28,400,000-28,600,000",
    "X:15,130,000-15,324,000",
    "3R:32,000,000-32,080,000",
]
# Alternative taxon set used for a reduced analysis: coluzzii, gambiae, bissau.
taxons = ["coluzzii", "gambiae", "bissau", "arabiensis", "melas"]
country = "Gambia, The"
sample_query = f"country=='{country}' and taxon in {taxons}"
# Site-filter suffix; 'gamb_colu' is the alternative used with the reduced
# taxon set. (The name shadows the builtin `filter` for the rest of the session.)
filter = 'gamb_colu_arab'
region = regions

gt_filtered, gm_snps_selected, taxon_sample = get_snp_data(
    region=region,
    sample_query=sample_query,
    filter=filter,
)
geno_alt = snp_filtering_proccess(
    gt_filtered=gt_filtered,
    gm_snps_selected=gm_snps_selected,
    taxon_sample=taxon_sample,
)

geno_alt.shape

In [None]:
# LD-prune the genotype matrix, then plot pairwise LD for the first 1000
# variants before and after pruning.
gnu = ld_prune(gn=geno_alt, size=500, step=200, threshold=.1, n_iter=10)
for genotypes, caption in (
    (geno_alt, 'Figure 1. Pairwise LD before LD pruning.'),
    (gnu, 'Figure 2. Pairwise LD after LD pruning.'),
):
    plot_ld(genotypes[:1000], caption)

In [None]:
# Shape of the LD-pruned matrix returned by ld_prune (displayed as the cell output).
gnu.shape

In [None]:
# Save the unpruned and LD-pruned genotype matrices, then download both
# (google.colab is only available inside a Colab runtime).
from google.colab import files

genotype_outputs = {
    'filtered_geno_alt.npy': geno_alt,
    'filtered_geno_alt_pruned.npy': gnu,
}
for npy_path, genotypes in genotype_outputs.items():
    np.save(npy_path, genotypes)

# Download each saved file exactly once. The original cell fetched the unpruned
# file twice and never fetched the pruned one.
for npy_path in genotype_outputs:
    files.download(npy_path)

In [None]:
# Principal coordinates analysis (PCoA) of the filtered genotype matrix.
# NOTE(review): this uses geno_alt (before LD pruning), not the pruned `gnu`
# computed earlier -- confirm that is intended.
from skbio.stats.ordination import pcoa

# Pairwise distances with the 'cityblock' metric.
# Presumably computed between samples (columns of geno_alt) -- TODO confirm
# against the scikit-allel documentation.
matrix_distance = allel.pairwise_distance(geno_alt, metric='cityblock')
# Run scikit-bio's PCoA on the distances; `pcoa_results` is reused by the
# plotting and UMAP cells further down.
pcoa_results = pcoa(matrix_distance)

# Bare expression: displays the ordination results as the cell output.
pcoa_results

In [None]:

import altair as alt

# Reload the per-sample metadata from the CSV file in the working directory.
taxon_sample = pd.read_csv('taxon_sample.csv')

# Metadata columns used to label the ordination plots. The explicit copy lets
# later cells modify pop_info without assigning onto a slice of taxon_sample
# (which raises SettingWithCopyWarning).
pop_info = taxon_sample[
    ['sample_id', 'taxon', 'pop', 'location', 'admin1_iso', 'admin1_name', 'admin2_name', 'cohort_admin1_quarter']
].copy()

# Percentage of variance per PCoA axis as a two-column table: PC, variance_explained.
variance_explained = pd.DataFrame(pcoa_results.proportion_explained * 100)
variance_explained.columns = ['variance_explained']
variance_explained.index.name = 'PC'
variance_explained = variance_explained.reset_index()
# The file name (including its spelling) is unchanged so existing outputs still match.
variance_explained.to_csv('variance_explained_all_reigion.csv', sep=',', index=True)

# Ten axes with the largest share of variance.
variance_explained_sorted = variance_explained.sort_values(by='variance_explained', ascending=False)
top_10_variance_explained = variance_explained_sorted.head(10)

# Bar chart: axis on x (ordered by decreasing variance), variance on y, with
# hover tooltips for both fields.
chart = alt.Chart(top_10_variance_explained).mark_bar().encode(
    x=alt.X('PC', sort='-y', title='Principal Component'),
    y=alt.Y('variance_explained', title='Variance Explained'),
    tooltip=['PC', 'variance_explained']
).properties(
    title='Variance Explained by Top 10 Principal Components'
)
chart

In [None]:
# Join the first 10 PCoA axes with the sample metadata.
# .copy() detaches the selection from pcoa_results.samples, so adding a column
# below does not assign onto a slice (SettingWithCopyWarning).
pcoa_results_df = pcoa_results.samples.iloc[:, 0:10].copy()

# Row i of the ordination is labelled with row i of pop_info; this relies on
# both following the same sample order.
pcoa_results_df['sample_id'] = pop_info['sample_id'].values
pcoa_results_df = pcoa_results_df.merge(pop_info, on='sample_id', how='left')

# Save the joined table.
pcoa_results_df.to_csv('pcoa_taxon_all_region_PC1_vs_PC2_all_taxon.csv', sep=',', index=True)

In [None]:
# Replace raw taxon codes with display names. assign() builds a new frame, so
# this never assigns onto a slice of another DataFrame.
pop_info = pop_info.assign(
    taxon=pop_info['taxon'].replace(
        {
            "coluzzii": "An. coluzzii",
            "gambiae": "An. gambiae s.s",
            "bissau": "Bissau",
            "arabiensis": "An. arabiensis",
            "melas": "An. melas"
        }
    )
)
# One label per sample, in the same row order as pcoa_results.samples.
population_labels = pop_info['taxon'].tolist()

plt.figure(figsize=(12, 10))

unique_populations = np.unique(population_labels)

# Fixed colour per taxon so figures are comparable across runs.
population_color_map = {
    'Bissau': 'purple',
    'An. gambiae s.s': 'blue',
    'An. coluzzii': 'orange',
    'An. arabiensis': 'green',
    'An. melas': 'red'
}

# Axes to plot (1-based PCoA axis numbers).
x = 2
y = 3
# Convert the labels once instead of once per population.
label_array = np.array(population_labels)
for pop in unique_populations:
    mask = label_array == pop
    plt.scatter(
        pcoa_results.samples[f'PC{x}'][mask],
        pcoa_results.samples[f'PC{y}'][mask],
        label=pop,
        edgecolor='white',     # white border around each marker
        s=200,
        linewidths=1.5,
        color=population_color_map[pop])

legend = plt.legend(title='Population', loc='best', fontsize=21, bbox_to_anchor=(1.45, 1))
plt.setp(legend.get_title(), fontweight='bold')

# All legend entries are bold, size 22; species names are also italic,
# while the 'Bissau' form stays upright.
for text in legend.get_texts():
    pop_name = text.get_text()
    text.set_fontweight('bold')
    text.set_fontsize(22)
    if pop_name != 'Bissau':
        text.set_fontstyle('italic')

# Report the number of SNPs actually used for the ordination instead of a
# hard-coded count; the thousands separator is a space, as in the original title.
n_snps = geno_alt.shape[0]
n_snps_label = f'{n_snps:,}'.replace(',', ' ')
plt.title(f'PCoA of Anopheles gambiae complex using sweep regions ({n_snps_label} SNPs)', fontsize=18, fontweight='bold')
# .iloc makes the positional lookup explicit; integer keys passed to [] on a
# label-indexed Series rely on a deprecated positional fallback.
plt.xlabel(f'PC{x} ({pcoa_results.proportion_explained.iloc[x - 1]*100:.2f}% Variance)', fontsize=16, fontweight='bold')
plt.ylabel(f'PC{y} ({pcoa_results.proportion_explained.iloc[y - 1]*100:.2f}% Variance)', fontsize=16, fontweight='bold')

# Bold tick labels.
plt.xticks(fontsize=14, fontweight='bold')
plt.yticks(fontsize=14, fontweight='bold')

# Thicker axes lines.
ax = plt.gca()
for spine in ax.spines.values():
    spine.set_linewidth(3)

# Save figure.
plt.savefig(f'pcoa_taxon_PC{x}_vs_PC{y}_all_region_all_taxon.png', bbox_inches='tight', dpi=600)
plt.show()

### Bissau molecular form only

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.font_manager as font_manager

# One admin1 ISO label per sample, in the same row order as pcoa_results.samples.
population_labels = pop_info['admin1_iso'].tolist()

# Axes to plot (1-based PCoA axis numbers).
x = 1
y = 2

# Sorted so that plotting and legend order are stable.
unique_populations = sorted(np.unique(population_labels))

# Fallback colormap for automatic colouring. pyplot.get_cmap is used because
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9 and importing it there
# makes the whole cell fail.
cmap = plt.get_cmap('tab20')  # 'Set3' is another good option for soft, distinct colors

# Fixed colour per admin1 ISO code.
color_map = {
    'GM-L': 'blue',
    'GM-M': 'green',
    'GM-W': 'orange',
    'GM-N': 'purple',
    'GM-U': 'red'
}
# Colours keyed by place name; not used by the plot below.
region_color_map = {
    "Basse": "orange",
    "Kanifing": "blue",
    "Janjanbureh" : "red",
    "Kuntaur" : "yellow",
    "Mansakonko" : "#33ECF5",
    "Brikama" : "green",
    "Kerewan" : "purple"
}

plt.figure(figsize=(12, 10))

# Convert the labels once instead of once per population.
label_array = np.array(population_labels)
for pop in unique_populations:
    mask = label_array == pop
    plt.scatter(
        pcoa_results.samples[f'PC{x}'][mask],
        pcoa_results.samples[f'PC{y}'][mask],
        label=pop,
        edgecolor='k',
        s=90,
        linewidths=1.5,
        color=color_map[pop]
    )

# Report the number of SNPs actually used for the ordination instead of a
# hard-coded count; the thousands separator is a space, as in the original title.
n_snps = geno_alt.shape[0]
n_snps_label = f'{n_snps:,}'.replace(',', ' ')
plt.title(f'PCoA of Anopheles gambiae complex in sweep regions ({n_snps_label} SNPs)')
# .iloc makes the positional lookup explicit; integer keys passed to [] on a
# label-indexed Series rely on a deprecated positional fallback.
plt.xlabel(f'PC{x} ({pcoa_results.proportion_explained.iloc[x - 1]*100:.2f}% Variance)')
plt.ylabel(f'PC{y} ({pcoa_results.proportion_explained.iloc[y - 1]*100:.2f}% Variance)')
plt.legend(title='Population', loc='best',   bbox_to_anchor=(1, 1), prop=font_manager.FontProperties(weight='bold', size=16))

# Save (PNG and PDF) and show.
plt.savefig(f'pcoa_bissau_all_region_PC{x}_vs_PC{y}.png', bbox_inches='tight', dpi=300)
plt.savefig(f'pcoa_bissau_all_region_PC{x}_vs_PC{y}.pdf', bbox_inches='tight', dpi=300)
plt.show()

## Perform PCoA-UMAP


## Optional: adding UMAP for Supervised Dimension Reduction and Metric Learning step

In [None]:
!pip uninstall -y umap
!pip install -q umap-learn
import umap.umap_ as umap

In [None]:
#import umap
import seaborn as sns
import matplotlib.pyplot as plt

# First 20 PCoA axes, one row per sample, used as the UMAP input.
pcoa_coords = pcoa_results.samples.iloc[:, :20]

# UMAP on the PCoA coordinates; min_dist=0.5 spreads points so that fine
# structure stays visible.
umap_model = umap.UMAP(n_neighbors=15, min_dist=0.5, n_components=2, metric="euclidean")
umap_results = umap_model.fit_transform(pcoa_coords)

umap_df = pd.DataFrame(umap_results, columns=['UMAP1', 'UMAP2'])

# Label every embedded sample with its taxon. The labels must match the keys
# of pop_color_map below. The earlier version labelled points with the combined
# 'pop' value but coloured them with an admin1-keyed map, so no label had a colour.
# Rows of taxon_sample are assumed to follow the same order as the PCoA samples.
sample_to_pop = dict(zip(taxon_sample['sample_id'], taxon_sample['taxon']))
sample_ids = taxon_sample['sample_id'].values
umap_df['population'] = [sample_to_pop[sample_id] for sample_id in sample_ids]

# Drop samples without a population label.
umap_df = umap_df.dropna(subset=['population'])

# Fixed colour per taxon.
pop_color_map = {
    'coluzzii': 'orange',
    'bissau': 'purple',
    'gambiae': 'blue',
    'arabiensis': 'green',
    'melas': 'red'
}

# Store the colour of each point alongside its coordinates.
umap_df['color'] = umap_df['population'].map(pop_color_map)

# Scatter plot coloured by taxon.
plt.figure(figsize=(10, 8))
sns.scatterplot(x="UMAP1", y="UMAP2", hue="population", palette=pop_color_map, data=umap_df, alpha=0.8)

plt.xlabel("UMAP 1")
plt.ylabel("UMAP 2")
plt.title("PCoA-UMAP in sweep region")

plt.legend(title="Population")
plt.savefig('pcoa_umap_PC1_vs_PC2_all_region.png')
plt.show()

# Save the embedding table.
umap_df.to_csv('umap_taxon__PC1_vs_PC2_all_region.csv', sep=',', index=True)


## Perform UMAP on genotype data

In [None]:

# UMAP computed directly on the genotype matrix, transposed so that each row
# handed to UMAP is one sample.
umap_model = umap.UMAP(
    n_neighbors=15,
    min_dist=0.5,
    n_components=2,
    metric="euclidean",
)
umap_results = umap_model.fit_transform(geno_alt.T)

# Sample identifiers and their taxon, taken from the metadata table.
sample_ids = taxon_sample['sample_id'].values
sample_to_pop = dict(zip(taxon_sample['sample_id'], taxon_sample['taxon']))

# Embedding table: coordinates, taxon label and sample identifier per row.
umap_df = pd.DataFrame(umap_results, columns=['UMAP1', 'UMAP2'])
umap_df['population'] = [sample_to_pop[sid] for sid in sample_ids]
umap_df['sample_id'] = sample_ids

# Discard rows that have no taxon label, then switch to display names.
umap_df = umap_df.dropna(subset=['population'])
display_names = {
    "coluzzii": "An. coluzzii",
    "gambiae": "An. gambiae s.s",
    "bissau": "Bissau",
    "arabiensis": "An. arabiensis",
    "melas": "An. melas"
}
umap_df['population'] = umap_df['population'].replace(display_names)

# Write the embedding table to disk.
umap_df.to_csv('umap_taxon_all_region_PC1_vs_PC2_all_taxon.csv', sep=',', index=True)

In [None]:
# Scatter plot of the UMAP embedding, one series per population.
plt.figure(figsize=(12, 10))

# Populations in order of first appearance in the table.
unique_pops = umap_df['population'].unique()

# Fixed colour per population label.
pop_color_map = {
    'An. coluzzii': 'orange',
    'Bissau': 'purple',
    'An. gambiae s.s': 'blue',
    'An. arabiensis': 'green',
    'An. melas': 'red'
}

# Marker styling shared by every series: large markers with a white outline.
marker_style = {'edgecolor': 'white', 's': 200, 'linewidths': 1.5}
for pop in unique_pops:
    subset = umap_df.loc[umap_df['population'] == pop]
    plt.scatter(
        subset['UMAP1'],
        subset['UMAP2'],
        label=pop,
        color=pop_color_map[pop],
        **marker_style
    )

# Legend placed to the right of the axes, with a bold title.
legend = plt.legend(title='Population', fontsize=21, bbox_to_anchor=(1.45, 1))
plt.setp(legend.get_title(), fontweight='bold')

# Species names are italicised; the 'Bissau' entry stays upright.
# Every entry is bold at size 22.
for text in legend.get_texts():
    pop_name = text.get_text()
    if pop_name != 'Bissau':
        text.set_fontstyle('italic')
    text.set_fontweight('bold')
    text.set_fontsize(22)

# Axis labels and tick labels in bold.
label_font = {'fontsize': 16, 'fontweight': 'bold'}
tick_font = {'fontsize': 14, 'fontweight': 'bold'}
plt.xlabel("UMAP 1", **label_font)
plt.ylabel("UMAP 2", **label_font)
plt.xticks(**tick_font)
plt.yticks(**tick_font)

# Heavier frame around the plot area.
ax = plt.gca()
for spine in ax.spines.values():
    spine.set_linewidth(3)

plt.title("UMAP Projection in sweep regions", fontsize=18, fontweight='bold')

# Write the figure to disk, then display it.
plt.savefig('umap_PC1_vs_PC2_all_regions_all_taxon.png', bbox_inches='tight', dpi=600)
plt.show()


In [None]:
# Download every PNG figure and CSV table found in the working directory
# (google.colab is only available inside a Colab runtime).
from google.colab import files
import glob

png_files = glob.glob('*.png')
csv_files = glob.glob('*.csv')

# Figures first, then tables.
for file in [*png_files, *csv_files]:
    files.download(file)