# Haplotype cluster analysis

In [None]:
!pip install -qq malariagen_data
import malariagen_data
import numpy as np
import pandas as pd
import allel
import zarr
import matplotlib.pyplot as plt
import seaborn as sns

# Create the Ag3 data-access object; later cells call its methods
# (sample_metadata, haplotype_pairwise_distances, genome_features, ...).
ag3 = malariagen_data.Ag3()
# Bare expression: as the last statement of a notebook cell it displays the
# object's representation.
ag3

In [None]:
# Candidate regions detected as being under a selective sweep.
regions = [
    "2R:28,430,000-28,615,000",
    "3R:28,400,000-28,600,000",
    "3R:32,000,000-32,080,000",
    "X:15,130,000-15,324,000",
]

# Load the sample metadata for The Gambia, restricted to the three taxa compared here.
sample_metadata = ag3.sample_metadata(sample_query="country=='Gambia, The' and taxon in ['coluzzii', 'gambiae', 'bissau']")

# --- Per-taxon subsets -------------------------------------------------------
# coluzzii: at most 20 individuals per admin1 region.
# The groups are sampled explicitly (instead of groupby(...).apply(...)) so the
# grouping column "admin1_iso" is always kept in the result, and a fixed seed
# makes the draw reproducible, like the other seeded subsample below.
coluzzii_subset = pd.concat(
    [
        group.sample(n=min(20, len(group)), random_state=42)
        for _, group in sample_metadata.query("taxon=='coluzzii'").groupby("admin1_iso")
    ],
    ignore_index=True,
)
# Show how many coluzzii individuals were kept per admin1 region (a bare
# expression in the middle of a cell would display nothing).
print(coluzzii_subset.groupby("admin1_iso").size())

# gambiae: every available individual.
gambiae_subset = sample_metadata.query("taxon=='gambiae'")

# bissau: 50 randomly drawn individuals from GM-N, plus every individual from
# the remaining admin1 regions except GM-U.
bissau_subset_GM_N = sample_metadata.query("taxon=='bissau'and admin1_iso == 'GM-N'").sample(n=50, random_state=42)
bissau_subset = sample_metadata.query("taxon=='bissau' and admin1_iso not in ['GM-U', 'GM-N']")

# Combine the subsets; ignore_index avoids duplicated row labels coming from
# the separately built frames.
samples_df = pd.concat(
    [coluzzii_subset, gambiae_subset, bissau_subset_GM_N, bissau_subset],
    ignore_index=True,
)
# Population label used later to name haplotypes: "<taxon>_<admin1_iso>".
samples_df["sample_pop"] = samples_df["taxon"] + "_" + samples_df["admin1_iso"]
# Sample identifiers, interpolated into the sample query below.
samples_list = samples_df["sample_id"].tolist()

# Select a region for analysis.
region = regions[2]
# Compute the pairwise haplotype distances for the selected samples and region.
dist, phased_sample, n_snp = ag3.haplotype_pairwise_distances(
    region=region,
    analysis="gamb_colu",
    sample_query=f"country=='Gambia, The' and sample_id in {samples_list}",
)


In [None]:
import scipy.cluster.hierarchy as sch
from scipy.spatial.distance import squareform
import matplotlib.patches as mpatches

# Align sample metadata with the phased samples: one row per sample, in the
# same order as `phased_sample`.
df_samples_phased = (
    samples_df.set_index("sample_id").loc[phased_sample.tolist()].reset_index()
)

# Keep both representations of the distances: the square matrix feeds the
# heatmap, the condensed vector feeds the hierarchical clustering.
dist_array = np.asarray(dist)
if dist_array.ndim == 1:  # condensed (1D) input
    condensed_dist = dist_array
    distance_matrix_sq = squareform(dist_array)
else:  # already square
    distance_matrix_sq = dist_array
    condensed_dist = squareform(dist_array)

# Repeat metadata so there is one row per haplotype (two consecutive rows per
# sample). Row selection keeps the column dtypes, unlike repeating `.values`.
# NOTE(review): assumes the distance matrix orders haplotypes as two
# consecutive entries per phased sample — confirm against the API docs.
df_haps = df_samples_phased.iloc[
    np.repeat(np.arange(len(df_samples_phased)), 2)
].reset_index(drop=True)

# Heatmap labels: the population label ("<taxon>_<admin1_iso>") of the sample
# each haplotype belongs to. These labels are intentionally not unique.
haplotype_ids = df_haps["sample_pop"].tolist()

# Validate that the matrix and the label list describe the same haplotypes.
n_haplotypes = len(haplotype_ids)
if distance_matrix_sq.shape != (n_haplotypes, n_haplotypes):
    raise ValueError(
        f"Mismatch: Distance matrix shape {distance_matrix_sq.shape} vs. haplotype list {n_haplotypes}"
    )

# Hierarchical clustering on the condensed distances. Passing a square matrix
# to linkage() would make SciPy treat it as raw observation vectors, so the
# condensed form is always used here.
# NOTE(review): SciPy documents Ward linkage as well defined only for
# Euclidean distances — confirm the metric used for `dist` is appropriate.
linkage_matrix = sch.linkage(condensed_dist, method='ward')

# Distance matrix as a labelled DataFrame (rows and columns share the labels).
distance_df = pd.DataFrame(distance_matrix_sq, index=haplotype_ids, columns=haplotype_ids)

# Colour assigned to each taxon.
taxon_palette = {
    "coluzzii": "orange",
    "gambiae": "blue",
    "bissau": "purple"
}

# Convert taxon labels to colours, one per haplotype.
df_haps["color"] = df_haps["taxon"].map(taxon_palette)

# Side-bar colours for the heatmap; the index matches the heatmap labels and
# the column name is the caption shown next to the colour bar.
row_colors = pd.DataFrame({"Taxon": df_haps["color"].to_numpy()}, index=haplotype_ids)

# Legend text per taxon (mathtext markup for bold rendering).
legend_label_map = {
    "gambiae": r"$\boldsymbol{An.\ gambiae\ s.s}$ ",
    "coluzzii": r"$\boldsymbol{An.\ coluzzii}$",
    "bissau": r"$\mathbf{Bissau}$",
}

legend_handles = [
    mpatches.Patch(color=color, label=legend_label_map.get(label, label))
    for label, color in taxon_palette.items()
]

# Plot the heatmap; the same linkage orders both rows and columns.
g = sns.clustermap(
    distance_df,
    row_linkage=linkage_matrix,
    col_linkage=linkage_matrix,
    cmap="viridis",
    figsize=(17, 15),
    row_colors=row_colors,
    col_colors=row_colors,
    cbar_pos=(0, 0.8, 0.02, 0.18)  # (x, y, width, height)
)

# Add the taxon legend outside the heatmap area.
legend = g.ax_heatmap.legend(handles=legend_handles, title='Population',
                            bbox_to_anchor=(1.2, 1.2), loc='upper left',
                            fontsize=18, title_fontsize=20)
plt.setp(legend.get_title(), fontweight='bold')

g.fig.suptitle(fr"Haplotype Heatmap-tree using {region} region",
               fontsize=20, fontweight='bold', y=1.02)

# Save with high quality.
# NOTE(review): the region strings used in this notebook contain ':' and ',',
# which end up in the file name; ':' is not a valid file-name character on
# every platform.
g.fig.savefig(f'heatmap_tree_{region}.png', dpi=600, bbox_inches='tight')

# Show the plot
plt.show()

The Bissau molecular form displays a unique cluster, with no clusters shared with its sister taxa, in the region 3R:32,000,000-32,080,000, where the Ors genes are located.

## Diplotype clustering

In [None]:
# Load the Ag3 genome feature table, requested for chromosome arm 3R only.
# This table is what gets passed to get_gene_name() in the diplotype cell, so
# lookups are presumably limited to features located on 3R.
genome_features = ag3.genome_features(region="3R")

def get_gene_name(genome_features, ID):
    """Return a display name for the parent (gene) of a feature.

    The feature ``ID`` is looked up in ``genome_features`` to obtain its
    ``Parent`` identifier, and the parent is then looked up to obtain its
    ``Name``.

    Parameters
    ----------
    genome_features : pandas.DataFrame
        Feature table with at least the columns ``ID``, ``Parent`` and
        ``Name``.
    ID : str
        Identifier of the feature (e.g. a transcript) whose parent name is
        wanted.

    Returns
    -------
    str or None
        The parent's ``Name``; the parent's identifier when the parent has no
        usable name (missing, empty, or parent not present in the table);
        ``None`` when ``ID`` is not in the table or has no parent.
    """

    def _lookup(feature_id, column):
        # Value of `column` for the first row whose ID matches, or None when
        # there is no such row or the value is missing/empty. Missing values
        # are usually NaN, which is truthy, so a plain truthiness test is not
        # enough to detect them.
        matches = genome_features.loc[genome_features["ID"] == feature_id, column]
        if matches.empty:
            return None
        value = matches.iloc[0]
        if pd.isna(value) or not value:
            return None
        return value

    parent_id = _lookup(ID, "Parent")
    if parent_id is None:
        return None

    gene_name = _lookup(parent_id, "Name")
    # Fall back to the parent identifier when no name is available.
    return gene_name if gene_name is not None else parent_id

In [None]:
# Advanced diplotype clustering: one figure per transcript, annotated with the
# frequencies of that transcript's non-synonymous SNPs.
snp_transcript = [
    "AGAP009390-RA", "AGAP009391-RA", "AGAP009392-RA",
    "AGAP009393-RA", "AGAP009394-RA", "AGAP009395-RA",
    "AGAP009396-RA", "AGAP009397-RA", "AGAP009398-RA",
]

sample_metadata = ag3.sample_metadata(sample_query="country=='Gambia, The' and taxon in ['coluzzii', 'gambiae', 'bissau']")

# coluzzii: 100 randomly drawn individuals (seeded for reproducibility).
coluzzii_subset = sample_metadata.query("taxon=='coluzzii'").sample(n=100, random_state=42)
# gambiae: every available individual.
gambiae_subset = sample_metadata.query("taxon=='gambiae'")
# bissau: at most 20 individuals per admin1 region, excluding GM-U.
# The groups are sampled explicitly (instead of groupby(...).apply(...)) so the
# grouping column "admin1_iso" is always kept, and the seed makes the draw
# reproducible like the coluzzii one.
bissau_subset = pd.concat(
    [
        group.sample(n=min(20, len(group)), random_state=42)
        for _, group in sample_metadata.query("taxon=='bissau' and admin1_iso != 'GM-U'").groupby("admin1_iso")
    ],
    ignore_index=True,
)

samples_df = pd.concat([coluzzii_subset, gambiae_subset, bissau_subset], ignore_index=True)
samples_list = samples_df['sample_id'].tolist()

# The sample query is identical for every transcript, so build it once.
diplotype_sample_query = f"country=='Gambia, The' and sample_id in {samples_list}"

for transcript in snp_transcript:
    print(transcript)
    gene_name = get_gene_name(genome_features, transcript)
    print(gene_name)
    # Use the transcript ID in the title when no usable gene name was found,
    # so the title never reads "None" or "nan".
    gene_label = gene_name if isinstance(gene_name, str) and gene_name else transcript
    ag3.plot_diplotype_clustering_advanced(
        region="3R:32000000-32080000",
        site_mask="gamb_colu",
        sample_query=diplotype_sample_query,
        heterozygosity=True,
        snp_transcript=transcript,
        snp_filter_min_maf=0.10,
        snp_query="effect == 'NON_SYNONYMOUS_CODING'",
        color='taxon',
        marker_size=6,
        linkage_method="complete",
        title=f"<i>Anopheles gambiae</i> complex Diplotype clustering with {gene_label} SNPs frequencies",
    )