# Genome-wide selection scan (GWSS) using iHS

In [None]:
!pip install -qq malariagen_data
import malariagen_data
import numpy as np
import pandas as pd
import allel
import zarr
import matplotlib.pyplot as plt
import seaborn as sns

# Create the MalariaGEN Ag3 data resource; every analysis cell below queries
# sample metadata and iHS scans through this `ag3` object.
ag3 = malariagen_data.Ag3()
ag3  # last expression in the cell, so the notebook renders the object

In [None]:
def run_ihS(cohort, contig, window_size=200, country=None, per_cohort=False):
  """
  Compute the windowed iHS scan for one group of samples on one contig.

  Parameters
  ----------
  cohort : str
      With ``per_cohort=True``, a value of the ``cohort_admin1_year``
      metadata column (e.g. ``'GM-W_colu_2019'``); otherwise a taxon name
      (e.g. ``'coluzzii'``) combined with ``country``.
  contig : str
      Contig to scan, e.g. ``'2L'``.
  window_size : int
      Passed through to ``ag3.ihs_gwss``.
  country : str, optional
      Country used to restrict samples when ``per_cohort`` is False
      (required in that mode).
  per_cohort : bool
      Select samples by ``cohort_admin1_year`` instead of country + taxon.

  Returns
  -------
  pos, ihs
      The two arrays returned by ``ag3.ihs_gwss`` (presumably window
      positions and per-window iHS summaries — see the malariagen_data docs).

  Raises
  ------
  ValueError
      If ``per_cohort`` is False and ``country`` is None.
  """
  if per_cohort:
    sample_query = f"cohort_admin1_year=='{cohort}'"
  else:
    if country is None:
      raise ValueError("country is required unless per_cohort=True")
    # Both string literals must be closed; otherwise the query string
    # is malformed and the sample selection fails.
    sample_query = f"country=='{country}' and taxon=='{cohort}'"
  pos, ihs = ag3.ihs_gwss(
      contig=contig,
      sample_query=sample_query,
      window_size=window_size,
      analysis="gamb_colu",
      min_cohort_size=10,
      )
  return pos, ihs

def run_ihs_per_cohort(cohort_group, cohort, taxon, country, contig, window_size=200):
  """
  Run the iHS genome-wide scan for the samples of a single cohort.

  Samples are restricted to ``country`` and ``taxon`` and to rows whose
  metadata column named by ``cohort_group`` (e.g. ``'admin1_iso'``) equals
  ``cohort``. The scan uses the "gamb_colu" analysis, windows of
  ``window_size`` and a minimum cohort size of 10.

  Returns the ``(pos, ihs)`` pair produced by ``ag3.ihs_gwss``.
  """
  cohort_query = f"country=='{country}' and taxon=='{taxon}' and {cohort_group}=='{cohort}'"
  scan = ag3.ihs_gwss(
      contig=contig,
      sample_query=cohort_query,
      window_size=window_size,
      analysis="gamb_colu",
      min_cohort_size=10,
  )
  pos, ihs = scan
  return pos, ihs



# GWSS with iHS
def run_ihS_by_period(country, taxon, year, contig, window_size=200):
  """
  Run the iHS genome-wide scan for one taxon, country and collection year.

  Samples are those whose metadata match ``country`` and ``taxon`` and whose
  ``year`` equals ``year`` (compared numerically). The scan uses the
  "gamb_colu" analysis, windows of ``window_size`` and a minimum cohort
  size of 10.

  Returns the ``(pos, ihs)`` pair produced by ``ag3.ihs_gwss``.
  """
  period_query = f"country=='{country}' and taxon=='{taxon}' and year=={year}"
  scan = ag3.ihs_gwss(
      contig=contig,
      sample_query=period_query,
      window_size=window_size,
      analysis="gamb_colu",
      min_cohort_size=10,
  )
  pos, ihs = scan
  return pos, ihs

def sorted_iHs_data_by_year(ihs_taxon_cohort_data):
  """
  Order an iHS results table chronologically.

  Returns a new DataFrame containing the rows of ``ihs_taxon_cohort_data``
  sorted by the ``year`` column, oldest first. Only ``year`` is used as the
  sort key; the input frame is left unchanged.
  """
  return ihs_taxon_cohort_data.sort_values('year', ascending=True)

## iHS analysis by taxon

In [None]:
# iHS scan of every cohort_admin1_year cohort of one taxon in one country.
contigs = ['2L', '2R', '3L', '3R', 'X']
taxon = 'coluzzii'
country = 'Gambia, The'
# Only use cohorts with at least 30 samples, for more reliable results.
min_samples = 30
size_info = (
    ag3.sample_metadata(sample_query=f"taxon=='{taxon}' and country=='{country}'")
    .groupby('cohort_admin1_year')
    .size()
    .reset_index()
)
size_info.columns = ['cohort_admin1_year', 'size']
large_cohorts = size_info[size_info['size'] >= min_samples]
for _, row in large_cohorts.iterrows():
  print(f"{row['cohort_admin1_year']}: {row['size']}")
cohorts = large_cohorts['cohort_admin1_year'].tolist()

ihs_frames = []
for cohort in cohorts:
  for contig in contigs:
    # Select samples by their cohort_admin1_year label (country and taxon
    # filters are redundant but harmless: the cohorts come from them).
    pos, ihs = run_ihs_per_cohort('cohort_admin1_year', cohort, taxon, country, contig, window_size=200)
    ihs_frames.append(pd.DataFrame({
        "chrom": contig,
        "Chr_pos": pos,
        # "ihs" is the last column of the ihs_gwss output; "ihs_ref" and
        # "ihs_alt" are columns 0 and 1 (names kept for the CSV schema).
        # NOTE(review): these look like per-window percentile summaries rather
        # than ref/alt alleles — confirm against the ihs_gwss docs.
        "ihs": ihs[:, -1],
        "ihs_ref": ihs[:, 0],
        "ihs_alt": ihs[:, 1],
        "Cohort": cohort,
    }))

# Concatenate once instead of growing the DataFrame inside the loop.
ihs_final_df = pd.concat(ihs_frames, ignore_index=True) if ihs_frames else pd.DataFrame()
# Save the data
ihs_final_df.to_csv(f'ihs_{taxon}_cohort_data.csv')


## iHS analysis by cohort

In [None]:
# iHS scan per cohort, where cohorts are the values of a metadata column
# chosen interactively (e.g. admin1_iso).
contigs = ['2L', '2R', '3L', '3R', 'X']
taxon = 'coluzzii'
country = 'Gambia, The'
# ihs_gwss is run with min_cohort_size=10, so only keep groups of >= 10 samples.
min_samples = 10

list_cohort_group = ['country_iso', 'admin1_name', 'admin1_iso', 'admin2_name', 'cohort_admin1_year']
print("cohort group list:")
for group_name in list_cohort_group:
  print(group_name)
# Keep asking until a listed column is entered; the cohort values below are
# taken from this same column, so every sample query matches it.
cohort_group = input("Enter the cohort group to use: ").strip()
while cohort_group not in list_cohort_group:
  print("Invalid cohort group. Please choose from the list.")
  cohort_group = input("Enter the cohort group to use: ").strip()

size_info = (
    ag3.sample_metadata(sample_query=f"taxon=='{taxon}' and country=='{country}'")
    .groupby(cohort_group)
    .size()
    .reset_index()
)
size_info.columns = [cohort_group, 'size']
large_cohorts = size_info[size_info['size'] >= min_samples]
for _, row in large_cohorts.iterrows():
  print(f"{row[cohort_group]}: {row['size']}")
cohorts = large_cohorts[cohort_group].tolist()

ihs_frames = []
for cohort in cohorts:
  print(f"{cohort} ihs analysis")
  for contig in contigs:
    print(f"{contig} contig analysis...")
    pos, ihs = run_ihs_per_cohort(cohort_group, cohort, taxon, country, contig, window_size=200)
    print(f"{contig} contig analysis done.")
    ihs_frames.append(pd.DataFrame({
        "chrom": contig,
        "Chr_pos": pos,
        # Last column -> "ihs"; columns 0 and 1 -> "ihs_ref"/"ihs_alt"
        # (names kept for the CSV schema used by the plotting cells).
        "ihs": ihs[:, -1],
        "ihs_ref": ihs[:, 0],
        "ihs_alt": ihs[:, 1],
        "Cohort": cohort,
        "Taxon": taxon,
        "Country": country,
    }))

# Concatenate once instead of growing the DataFrame inside the loop.
ihs_final_df = pd.concat(ihs_frames, ignore_index=True) if ihs_frames else pd.DataFrame()
# NOTE: same file name as the by-taxon cell; running both overwrites it.
ihs_final_df.to_csv(f'ihs_{taxon}_cohort_data.csv')


## iHS interactive plot

In [None]:
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

# Load the iHS table to plot (an export of one of the analysis cells, with
# chrom, Chr_pos, ihs and Cohort columns). Other exports used with this cell
# include /content/ihs_gambiae_data.csv and /content/ihs_coluzzii_data.csv.
ihs_csv_path = "/content/ihs_bissau_data.csv"
ihs_df_init = pd.read_csv(ihs_csv_path)
high_ihs_dict = {}
cohorts = ihs_df_init['Cohort'].unique()
chromosome = "3R"
for cohort in cohorts:
  print(cohort)
  ihs_df = ihs_df_init.query(f'chrom=="{chromosome}" and Cohort=="{cohort}"')
  # Explicit copy: adding columns to a column subset of ihs_df would
  # otherwise trigger pandas' SettingWithCopyWarning.
  data = ihs_df[['Chr_pos', 'ihs']].copy()

  # Absolute iHS values
  data['abs_iHS'] = np.abs(data['ihs'])

  # log|iHS| is kept for alternative plots; |iHS| == 0 gives -inf, so the
  # divide-by-zero warning is silenced.
  with np.errstate(divide='ignore'):
    data['ihs_log'] = np.log(data['abs_iHS'])

  if data.empty:
    # Still record the (empty) result so high_ihs / high_ihs_dict always
    # refer to the current cohort, but there is nothing to plot.
    print(f"No rows for chromosome {chromosome} in cohort {cohort}; skipping plot.")
    high_ihs = data
    high_ihs_dict[cohort] = high_ihs
    continue

  # Top 1% of |iHS| as threshold. Alternatives: a fixed cut-off of 4, or
  # mean + 3 * std of |iHS|.
  threshold = data['abs_iHS'].quantile(0.99)

  # Points above the threshold
  high_ihs = data[data['abs_iHS'] > threshold]
  high_ihs_dict[cohort] = high_ihs

  fig = go.Figure()

  # All points
  fig.add_trace(go.Scatter(
      x=data['Chr_pos'],
      y=data['abs_iHS'],
      mode='markers',
      marker=dict(color='blue', size=6),
      name='All Points'
  ))

  # High iHS points
  fig.add_trace(go.Scatter(
      x=high_ihs['Chr_pos'],
      y=high_ihs['abs_iHS'],
      mode='markers+text',
      marker=dict(color='red', size=8),
      name='High iHS (> Threshold)',
      textposition="top center"
  ))

  # Threshold line across the plotted range
  fig.add_trace(go.Scatter(
      x=[data['Chr_pos'].min(), data['Chr_pos'].max()],
      y=[threshold, threshold],
      mode='lines',
      line=dict(color='black', dash='dash'),
      name=f'Threshold: {threshold:.2f}'
  ))

  fig.update_layout(
      title=f'iHS Plot for Chromosome {chromosome} of {cohort}',
      xaxis_title='Position on Chromosome',
      yaxis_title='|iHS|',
      legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
      template="plotly_white"
  )

  fig.show()




In [None]:
# Pick the high-|iHS| windows inside a candidate region of the chromosome
# plotted above. `high_ihs` is the table of the LAST cohort in that loop;
# use high_ihs_dict[<cohort>] to choose another cohort.
region_start, region_end = 2.6e+7, 3e+7
# Tag rows with the chromosome that was actually plotted.
# assign() returns a new frame, so nothing is written into a query() slice.
selected = (
    high_ihs.query("@region_start < Chr_pos < @region_end")
    .assign(contig=chromosome)
)
print(f"{len(selected)} high-|iHS| windows in {chromosome}:{region_start:,.0f}-{region_end:,.0f}")
if selected.empty:
  raise ValueError("No high-|iHS| windows in the selected region; adjust region_start/region_end.")

start = int(selected['Chr_pos'].min())
end = int(selected['Chr_pos'].max())
print(start, end)

In [None]:
# List the genome features overlapping the selected high-|iHS| region.
# Uses `selected`, `start` and `end` from the selection cell above; the region
# is on the same contig as the selected windows.
region_contig = selected['contig'].iloc[0]
dataset_label = "bissau"  # only used to name the exported files
region = f"{region_contig}:{start:,}-{end:,}"
print(region)
genes_list = (
    ag3.genome_features(region=region)
    .query("type!='chromosome'")
    [["contig", "ID", "type", "start", "end", "Name", "description"]]
    .set_index("ID")
)
genes_list["window"] = f"{start}-{end}"
genes_list.reset_index(inplace=True)
genes_list.set_index("window", inplace=True)
print(f"{len(genes_list)} features in {region}")
genes_list_file = f"genes_list_ihs_{dataset_label}_{region_contig}.xlsx"
genes_list.to_excel(genes_list_file)

# Features that carry a gene name.
genes_name_df = genes_list[genes_list['Name'].notna()]
genes_name_df.to_excel(f"ihs_gene_list_{dataset_label}_{region_contig}.xlsx")

# Features that carry a description.
genes_describe = genes_list[genes_list['description'].notna()]
genes_describe_file = f"genes_describe_{dataset_label}_{region_contig}.xlsx"
genes_describe.to_excel(genes_describe_file)

# Download exactly the files written above.
from google.colab import files
files.download(genes_list_file)
files.download(genes_describe_file)

In [None]:
import pandas as pd

# Annotate each selected high-|iHS| window position with the features from
# genes_list whose [start, end] interval contains it.
snp_data = selected[['contig', 'Chr_pos']]
print(f"{len(snp_data)} selected positions")

# reset_index() returns a copy, so genes_list is left untouched and this cell
# can be re-run safely (an in-place reset would change genes_list each run).
gene_data = genes_list.reset_index()[['contig', 'type', 'start', 'end', 'ID', 'Name', 'description']]

# Pair every position with every feature on the same contig, then keep the
# pairs where the position lies inside the feature (bounds inclusive).
merged_data = pd.merge(snp_data, gene_data, on="contig", how="left")
filtered_data = merged_data[merged_data['Chr_pos'].between(merged_data['start'], merged_data['end'])]
print(f"{len(filtered_data)} position/feature overlaps")

# Keep only relevant columns
result = filtered_data[['contig', 'Chr_pos', 'ID', 'Name', 'description']]
result.to_excel("SNP_gene_list_coluzzii.xlsx")
from google.colab import files
files.download("SNP_gene_list_coluzzii.xlsx")


## iHS plot

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# iHS exports per taxon; only the selected one is read, so a missing export
# for another taxon does not stop the cell.
ihs_csv_paths = {
    'gambiae': '/content/ihs_gambiae_data.csv',
    'coluzzii': '/content/ihs_coluzzii_cohort_data_with10.csv',
    'bissau': '/content/ihs_bissau_cohort_data.csv',
}
taxon = 'coluzzii'
ihs_final_df = pd.read_csv(ihs_csv_paths[taxon])

# Order rows by cohort label so subplots appear in a reproducible order; the
# stable sort keeps the original row order inside each cohort.
ihs_final_df = ihs_final_df.sort_values('Cohort', kind='stable')

palette = plt.cm.tab20.colors

# One figure per chromosome, one subplot per cohort.
for chromosome in ihs_final_df['chrom'].unique():
  data = ihs_final_df.loc[
      ihs_final_df['chrom'] == chromosome, ['chrom', 'Chr_pos', 'ihs', 'Cohort']
  ].copy()
  data['abs_iHS'] = np.abs(data['ihs'])

  cohorts = data['Cohort'].unique()
  # `data` holds one chromosome, so this is a single entry; kept so the
  # colour/legend logic also works if several chromosomes are passed in.
  chroms_in_data = data['chrom'].unique()
  color_map = {chrom: palette[i % len(palette)] for i, chrom in enumerate(chroms_in_data)}

  nrows = len(cohorts)
  # squeeze=False always yields a 2-D array, so a single cohort needs no
  # special case.
  fig, axes = plt.subplots(nrows=nrows, figsize=(12, 4 * nrows), sharex=False, squeeze=False)
  axes = axes[:, 0]

  # One handle per chromosome for the figure-level legend.
  legend_handles = {}

  for ax, cohort in zip(axes, cohorts):
    cohort_data = data[data['Cohort'] == cohort]
    # 99th percentile of |iHS| in this cohort; computed once per cohort so it
    # is always defined for the threshold line.
    threshold = cohort_data['abs_iHS'].quantile(0.99)

    for chrom in chroms_in_data:
      chrom_data = cohort_data[cohort_data['chrom'] == chrom]
      if chrom_data.empty:
        continue

      default_color = mcolors.to_hex(color_map[chrom])
      # Red for values above the threshold, chromosome colour otherwise.
      point_colors = np.where(chrom_data['abs_iHS'] > threshold, 'red', default_color)
      scatter = ax.scatter(
          chrom_data['Chr_pos'], chrom_data['abs_iHS'],
          color=point_colors, s=10, label=f'Chrom {chrom}'
      )
      legend_handles.setdefault(chrom, scatter)

    ax.axhline(y=threshold, color='black', linestyle='--', linewidth=1, label='99% Threshold')
    ax.set_ylabel('|iHS|')
    ax.set_title(f'Cohort: {cohort}')

  fig.supxlabel('Chromosome Position', fontsize=14)
  fig.legend(handles=list(legend_handles.values()), labels=list(legend_handles.keys()),
             loc='upper right', fontsize=12, title="Chromosome", title_fontsize=13)
  plt.tight_layout(rect=[0, 0, 1, 0.95])

  # To save instead: plt.savefig(f'{taxon}_ihs_{chromosome}.png', dpi=300)
  plt.show()


## iHS by year, per taxon

In [None]:
def manhanthan_plot(ihs_taxon_cohort_data, gap=50, output_path='ihs_per_year.png'):
  """
  Draw one |iHS| panel per chromosome, with each year's windows placed side
  by side along the x-axis.

  Parameters
  ----------
  ihs_taxon_cohort_data : pandas.DataFrame
      iHS table with at least ``chrom``, ``Chr_pos``, ``ihs`` and ``year``
      columns (the layout written by the per-year analysis cell).
  gap : float
      Spacing inserted between consecutive years on the x-axis, in the same
      units as ``Chr_pos``.
  output_path : str or None
      File the figure is saved to; ``None`` skips saving.

  A dashed line in each panel marks the 99th percentile of |iHS| for that
  chromosome. Returns None.
  """
  print('Sorting data')
  ihs_sorted = sorted_iHs_data_by_year(ihs_taxon_cohort_data)

  chromosomes = ihs_sorted['chrom'].unique()
  if len(chromosomes) == 0:
    print('No data to plot')
    return

  # One colour per year for the whole figure, so a year keeps the same colour
  # in every panel even if it is missing from some chromosomes.
  palette = plt.cm.tab20.colors
  years = ihs_sorted['year'].unique()
  year_colors = {year: palette[i % len(palette)] for i, year in enumerate(years)}

  ncols = 2
  nrows = (len(chromosomes) + ncols - 1) // ncols
  fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(16, 7 * nrows),
                           sharex=False, sharey=False, squeeze=False)
  axes = axes.flatten()

  for ax, chrom in zip(axes, chromosomes):
    # Explicit copy so adding abs_iHS does not write into a slice.
    data = ihs_sorted.loc[ihs_sorted['chrom'] == chrom, ['year', 'Chr_pos', 'ihs']].copy()
    data['abs_iHS'] = np.abs(data['ihs'])

    x_ticks = []
    x_labels = []
    offset = 0
    # Years come out in chronological order because the table is sorted.
    for year in data['year'].unique():
      year_data = data[data['year'] == year]
      x = year_data['Chr_pos'] + offset
      ax.scatter(x, year_data['abs_iHS'], color=year_colors[year], s=10, label=f'{year}')
      x_ticks.append(x.median())
      x_labels.append(f'{year}')
      # Start the next year's block after this year's last position.
      offset += year_data['Chr_pos'].max() + gap

    threshold = data['abs_iHS'].quantile(0.99)  # top 1% as threshold
    ax.axhline(y=threshold, color='black', linestyle='--', linewidth=1)

    ax.set_title(f'Manhattan Plot of |iHS| Values for Chromosome {chrom}', fontsize=12)
    ax.set_ylabel('|iHS|')
    ax.set_xticks(ticks=x_ticks)
    ax.set_xticklabels(labels=x_labels, rotation=45, ha='right')

  # Remove unused panels when the chromosome count is odd.
  for ax in axes[len(chromosomes):]:
    fig.delaxes(ax)

  fig.supxlabel('year', fontsize=14)
  plt.tight_layout()
  if output_path is not None:
    plt.savefig(output_path)
  plt.show()

In [None]:
# iHS scan per collection year for one taxon in one country.
contigs = ['2L', '2R', '3L', '3R', 'X']
taxon = 'bissau'
country = "Gambia, The"
# Only use years with at least 30 samples, for more reliable results.
min_samples = 30
size_info = (
    ag3.sample_metadata(sample_query=f"taxon=='{taxon}' and country=='{country}'")
    .groupby('year')
    .size()
    .reset_index()
)
size_info.columns = ['year', 'size']
large_years = size_info[size_info['size'] >= min_samples]
for _, row in large_years.iterrows():
  print(f"{row['year']}: {row['size']}")
cohorts = large_years['year'].tolist()

# A per-year comparison needs at least two years.
if len(cohorts) > 1:
  ihs_frames = []
  for year in cohorts:
    for contig in contigs:
      pos, ihs = run_ihS_by_period(country, taxon, year, contig, window_size=200)
      ihs_frames.append(pd.DataFrame({
          "chrom": contig,
          "Chr_pos": pos,
          # Last column -> "ihs"; columns 0 and 1 -> "ihs_ref"/"ihs_alt"
          # (names kept for the CSV schema).
          "ihs": ihs[:, -1],
          "ihs_ref": ihs[:, 0],
          "ihs_alt": ihs[:, 1],
          "year": year,
          "taxon": taxon,
      }))
  # Concatenate once instead of growing the DataFrame inside the loop.
  ihs_final_df = pd.concat(ihs_frames, ignore_index=True)
  # NOTE: same file name pattern as the per-cohort cells; running them for
  # the same taxon overwrites each other's output.
  ihs_final_df.to_csv(f'ihs_{taxon}_cohort_data.csv')
else:
  print('Should have at least 2 cohorts!')

In [None]:
# Plot the per-year |iHS| Manhattan panels for the table built by the cell
# above (it must contain a `year` column).
# NOTE(review): if that cell found fewer than two years, ihs_final_df still
# holds whatever an earlier cell assigned to it.
manhanthan_plot(ihs_final_df)