In [None]:
# Base imports
import os
import pickle
import re

# Compute imports
import numpy as np
import pandas as pd
import scipy
from tqdm.notebook import tqdm, trange

# Plotting imports
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
from plotly import express as px

# ML import
from sklearn.decomposition import NMF
from sklearn.metrics import mean_squared_error, median_absolute_error
from sklearn.metrics.pairwise import cosine_similarity


import multiprocessing
from multiprocessing import Pool


# Figure defaults: TrueType fonts (type 42) in PDF/PS output and real text in SVG output
# keep labels editable; prefer Arial for sans-serif text.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'svg.fonttype': 'none',
    'font.sans-serif': 'Arial',
    'font.family': 'sans-serif',
})
sns.set_style('ticks')
# Applied after sns.set_style so these pure-black text/tick colours take precedence.
matplotlib.rcParams.update({
    rc_key: '#000000'
    for rc_key in ('text.color', 'axes.labelcolor', 'xtick.color', 'ytick.color')
})

In [None]:
from Bio import AlignIO
from Bio.Align import MultipleSeqAlignment
from Bio.Align.Applications import MafftCommandline
from io import StringIO
from Bio import SeqIO
import tempfile
from Bio import AlignIO
from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor
from Bio import Phylo

In [None]:
# Input paths, relative to this notebook's directory.
# Gene presence/absence (P) matrix, genes x genomes, from the sim80 CD-HIT clustering.
DF_GENES = '../../data/processed/cd-hit-results/sim80/Ebacter_strain_by_gene.pickle.gz'
# Per-genome metadata (provides genome_id and genome_status used below).
ENRICHED_METADATA = '../../data/metadata/enriched_metadata.csv'
# eggNOG functional annotations per gene (COG_category, Description, ...).
DF_EGGNOG = '../../data/processed/df_eggnog.csv'

# Accessory / rare / core partitions of the P matrix for complete genomes.
DF_ACC_COMPLETE = '../../data/processed/CAR_genomes/df_acc_complete.pickle'
DF_RARE_COMPLETE = '../../data/processed/CAR_genomes/df_rare_complete.pickle'
DF_CORE_COMPLETE = '../../data/processed/CAR_genomes/df_core_complete.pickle'
# Bakta gene annotation table.
BAKTA_ANNOTATIONS = '../../data/processed/bakta_gene_annotations.csv'

In [None]:
# Bakta annotation table indexed by its first column.
# NOTE(review): not referenced later in this notebook — confirm it is still needed.
bakta_annotations = pd.read_csv(BAKTA_ANNOTATIONS, index_col=0)

In [None]:
# Pre-computed rare / accessory / core gene partitions (genes x complete genomes).
df_rare, df_acc, df_core = (
    pd.read_pickle(partition_path)
    for partition_path in (DF_RARE_COMPLETE, DF_ACC_COMPLETE, DF_CORE_COMPLETE)
)

In [None]:
# Genome metadata indexed by its first column; dtype='object' keeps every column as text
# (presumably so genome IDs are not coerced to numbers).
metadata = pd.read_csv(ENRICHED_METADATA, index_col=0, dtype='object')

display( metadata.shape, metadata.head())

In [None]:
# Full gene presence/absence (P) matrix, genes x genomes, stored sparse.
df_genes = pd.read_pickle(DF_GENES)

# Genomes whose assembly status is 'Complete'.
metadata_complete = metadata.loc[metadata['genome_status'] == 'Complete']

# Restrict P to complete genomes, fill gaps with 0, densify as int8 (memory/compute),
# then keep only genes present in at least one complete genome.
df_genes_complete = (
    df_genes[metadata_complete.genome_id]
    .fillna(0)
    .sparse.to_dense()
    .astype('int8')
)
inCompleteseqs = df_genes_complete.sum(axis=1) > 0
df_genes_complete = df_genes_complete.loc[inCompleteseqs]

df_genes_complete.shape

In [None]:
# eggNOG annotations indexed by gene ID; empty fields become '-' (the "no COG" placeholder).
df_eggnog = pd.read_csv(DF_EGGNOG, index_col=0).fillna('-')

display(df_eggnog.shape, df_eggnog.head())

In [None]:
# Gene IDs of the core partition.
core_genes = df_core.index

In [None]:
# Gene IDs of the accessory and rare partitions.
acc_genes = df_acc.index
rare_genes = df_rare.index

# Categorization of core genes

## Size of the core genome

In [None]:
# Number of core genes (rows of the core partition).
len(df_core)

## Proportion of genes in each genome

In [None]:
# Per-genome counts of core, accessory and rare genes (column sums of each partition).
results = [
    {
        "Genome": genome_id,
        "Core": df_core[genome_id].sum(),
        "Acc": df_acc[genome_id].sum(),
        "Rare": df_rare[genome_id].sum(),
    }
    for genome_id in df_core.columns
]
df_results = pd.DataFrame(results)

In [None]:
# Stacked area chart of core / accessory / rare gene counts per genome,
# with genomes ordered by decreasing core-gene count.
df = df_results.sort_values('Core', ascending=False)

# Per-genome totals and proportions (kept on `df` for inspection; the plot shows raw counts).
df['Total'] = df['Core'] + df['Acc'] + df['Rare']
df['Core_Proportion'] = df['Core'] / df['Total']
df['Accessory_Proportion'] = df['Acc'] / df['Total']
df['Rare_Proportion'] = df['Rare'] / df['Total']

fig, ax = plt.subplots(figsize=(12, 5))
df.plot(
    x='Genome',
    y=['Core', 'Acc', 'Rare'],
    kind='area',
    stacked=True,
    ax=ax,
    color=['#B7B2D8', '#FF8A1A', '#66A9D0'],
)

# With a text x column the n genomes sit at x = 0 .. n-1, so clip the axis there to fill it
# edge to edge (max(..., 1) avoids identical limits when there is a single genome).
ax.set_xlim(0, max(len(df) - 1, 1))

ax.set_ylabel('Gene Count')
ax.set_title('Number of Core, Accessory, and Rare Genes in Each Genome')
ax.tick_params(axis='x', labelrotation=45)
fig.tight_layout()
fig.savefig('../../images/core_genome_figs/gene_proportions.svg')
plt.show()


## COG categories of the core genome

In [None]:
# Hex colour per single-letter COG functional category.
# NOTE(review): not referenced elsewhere in the visible notebook — presumably kept for
# other figures; confirm before removing.
cog_colors = {
    'J': '#ff0000',
    'A': '#c2af58',
    'K': '#ff9900',
    'L': '#ffff00',
    'B': '#ffc600',
    'D': '#99ff00',
    'Y': '#493126',
    'V': '#ff008a',
    'T': '#0000ff',
    'M': '#9ec928',
    'N': '#006633',
    'Z': '#660099',
    'W': '#336699',
    'U': '#33cc99',
    'O': '#00ffff',
    'C': '#9900ff',
    'G': '#805642',
    'E': '#ff00ff',
    'F': '#99334d',
    'H': '#727dcc',
    'I': '#5c5a1b',
    'P': '#0099ff',
    'Q': '#ffcc99',
    'R': '#ff9999',
    'S': '#d6aadf'
}

In [None]:
# Count COG letters across core genes. A multi-letter assignment (e.g. 'KL') contributes
# one count per letter, and the "no COG" placeholder '-' is counted as 'S'.
cogs = df_eggnog.loc[core_genes].COG_category.values

unique_cogs = set([x for x in cogs if len(x) == 1 and x != '-'])
cog_counts = {x: 0 for x in unique_cogs}
for cog in cogs:
    for char in cog:
        if char == '-':
            char = 'S'
        # .get(): a letter may only ever appear inside multi-letter assignments (or 'S' may
        # never appear alone), so it need not be a key yet; indexing would raise KeyError.
        cog_counts[char] = cog_counts.get(char, 0) + 1

In [None]:
import plotly.graph_objects as go
import plotly.express as px
import pandas as pd
from plotly.subplots import make_subplots

# One sunburst of COG-category counts per gene partition (core, accessory, rare),
# laid out on a 2 x 2 grid whose bottom-right cell is left empty.
fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=['Core Genome', 'Acc Genome', 'Rare Genome'],
    specs=[[{'type': 'sunburst'}, {'type': 'sunburst'}],
           [{'type': 'sunburst'}, None]],  # None for the bottom-right empty slot
    horizontal_spacing=0.05,  # Decrease space between columns
    vertical_spacing=0,  # Decrease space between rows
)

# Colour per COG supercategory (inner ring of each sunburst)
colors = {'METABOLISM': 'green', 'CELLULAR PROCESSES AND SIGNALING': 'gold', 
          'INFORMATION STORAGE AND PROCESSING': 'red', 'POORLY CHARACTERIZED': 'grey'}

# Supercategory -> member letters, inverted on the next line into letter -> supercategory
categories = {
    'INFORMATION STORAGE AND PROCESSING': ['J', 'A', 'K', 'L', 'B'],
    'CELLULAR PROCESSES AND SIGNALING': 'D Y V T M N Z W U O'.split(),
    'METABOLISM': 'C G E F H I P Q'.split(),
    'POORLY CHARACTERIZED': 'R S'.split()
}
categories = {x: key for key, value in categories.items() for x in value}

# Build one sunburst per partition and place it in its grid cell
for i, (category, genes) in enumerate(zip(['core', 'acc', 'rare'], [core_genes, acc_genes, rare_genes])):
    cogs = df_eggnog.loc[genes].COG_category.values
    unique_cogs = "J A K L B D Y V T M N Z W U O C G E F H I P Q R S".split()
    
    # Count every letter of (possibly multi-letter) assignments; '-' (no COG) counts as 'S'.
    # NOTE(review): a letter outside unique_cogs (e.g. 'X') would raise KeyError here.
    cog_counts = {x: 0 for x in unique_cogs}
    for cog in cogs:
        for char in cog:
            if char == '-': char = 'S'  # Replace '-' with 'S'
            cog_counts[char] += 1
    
    # Counts table in canonical COG order, with zero-count letters removed
    COG_order = "J A K L B D Y V T M N Z W U O C G E F H I P Q R S".split()
    df = pd.DataFrame(list(cog_counts.items()), columns=['COG', 'Count']).set_index('COG').loc[[x for x in COG_order if x in cog_counts.keys()]]
    df = df.reset_index()
    df = df.drop(df[df.Count == 0].index)
    
    # Assign the supercategory (and its colour) to each COG letter
    df['meta_category'] = df['COG'].map(categories)
    df['color'] = df['meta_category'].apply(lambda x: colors[x])

    # Sunburst: supercategory ring, then COG letters, sized by gene count
    sunburst = px.sunburst(df, path=['meta_category', 'COG'], 
                           values='Count', color='meta_category', 
                           color_discrete_map=colors)
    
    # Black, size-8 labels showing letter, percent of parent entry and count
    sunburst.update_traces(
        textinfo='label+percent entry+value', 
        insidetextorientation='horizontal',
        textfont=dict(size=8, color='black'),  # Set font size to 8 and color to black
        hoverinfo='label+percent entry+value'  # Ensures text appears on hover as well
    )

    # NOTE(review): only sunburst.data[0] (the trace) is copied into `fig` below, so these
    # per-figure layout settings do not reach the combined figure.
    sunburst.update_layout(
        title=category.upper() + ' Genome COG Distribution',
        title_x=0.5,
        width=800,  # Increased width of each plot for better visibility
        height=700,  # Increased height of each plot for better visibility
        showlegend=False  # Hide legend for individual plots
    )
    
    # Place the trace in its grid cell: core top-left, acc top-right, rare bottom-left
    if category == 'core':
        fig.add_trace(sunburst.data[0], row=1, col=1)
    elif category == 'acc':
        fig.add_trace(sunburst.data[0], row=1, col=2)
    elif category == 'rare':
        fig.add_trace(sunburst.data[0], row=2, col=1)

# Combined-figure layout: title, legend placement, tight margins and overall size
fig.update_layout(
    title='COG Category Distribution by Genome',  # Overall title
    showlegend=True,  # Show legend for color categories
    legend=dict(
        x=1.05,  # Position legend further to the right
        y=0,  # Align the bottom of the legend with the plot
        traceorder='normal',
        bgcolor='rgba(255, 255, 255, 0)',  # Transparent background for the legend
        borderwidth=0  # No border for the legend
    ),
    margin=dict(t=50, b=50, l=50, r=100),  # Reduce margins to make plots closer together
    height=1200,  # Overall height increased to fit larger plots
    width=1000,  # Overall width increased to fit larger plots
    uniformtext=dict(minsize=8, mode='show')
)


# Save the figure as an SVG file (static export through plotly's write_image)
fig.write_image('../../images/supplemental/cogs_grid.svg', format='svg')

# Show the final plot
fig.show()


## Looking at AMRFinder Results

In [None]:
# AMRFinder hits, re-indexed by gene ID: each protein (allele) ID is cut at its first 'A'.
amr = (
    pd.read_csv('../../data/processed/amrfinder/output', sep='\t')
    .assign(**{
        'Protein identifier': lambda hits: hits['Protein identifier'].map(lambda pid: pid.split('A')[0])
    })
    .set_index('Protein identifier')
)

In [None]:
# Number of distinct AMR genes carried by each complete genome, ascending.
# amr's index holds gene IDs collapsed from allele IDs, so a gene hit through several
# alleles repeats; de-duplicate so each gene counts once per strain.
hist_df = df_genes.loc[amr.index.unique(), df_core.columns].sum().sort_values()

In [None]:
# Distribution of AMR-gene counts across strains.
ax = sns.violinplot(hist_df)
ax.set_title('Violin Plot of AMR Genes Per Strain')
ax.set_ylabel('Number of AMR Genes');

In [None]:
# Core genes with at least one AMRFinder hit (each gene listed once, even when several of
# its alleles were hit), shown with their eggNOG annotations.
core_amr = df_core.loc[[x for x in amr.index.unique() if x in df_core.index]]
df_eggnog.loc[core_amr.index]

In [None]:
# Accessory genes with at least one AMRFinder hit (each gene listed once), with annotations.
acc_amr = df_acc.loc[[x for x in amr.index.unique() if x in df_acc.index]]
df_eggnog.loc[acc_amr.index]

In [None]:
# Rare genes with an AMRFinder hit (each gene once, so no gene is histogrammed twice),
# and the distribution of how many genomes carry each of them.
amr_rare = df_rare.loc[[x for x in amr.index.unique() if x in df_rare.index]]
amr_rare.sum(axis=1).sort_values().hist()

## Metabolic Core Genes

In [None]:
# Core genes whose (possibly multi-letter) COG assignment contains C, E, F, G, H, I or P.
# The author's note recorded 1042 such genes.
# NOTE(review): 'Q' is not matched although the METABOLISM supercategory used in the
# figures includes it — confirm the exclusion is intentional.
isMetabolic = df_eggnog.loc[df_core.index, 'COG_category'].str.contains('[CEFGHIP]')
core_metabolic = df_eggnog.loc[df_core.index].loc[isMetabolic]

display(core_metabolic)

## Core motility genes

In [None]:
# Core genes related to motility (author's note: 42 genes): COG letter N, or an eggNOG
# description containing 'pilus' or 'pili' (the latter also matches e.g. 'pilin').
isMotility1, isMotility2, isMotility3 = (
    df_eggnog.loc[df_core.index, column].str.contains(pattern)
    for column, pattern in (('COG_category', 'N'), ('Description', 'pilus'), ('Description', 'pili'))
)

core_motility = df_eggnog.loc[df_core.index].loc[isMotility1 | isMotility2 | isMotility3]
core_motility

## Number of alleles per core gene

In [None]:
# Step 1: read the allele-level presence/absence matrix (alleles x genomes) and keep only
# the genome columns of the core partition.
P_allele = pd.read_pickle(
    '../../data/processed/cd-hit-results/sim80/Ebacter_strain_by_allele.pickle.gz'
)[df_core.columns]
display(P_allele.shape, P_allele.head(), P_allele.dtypes)

In [None]:
# Keep alleles observed (value 1) in at least one of these genomes.
P_allele = P_allele.loc[:, df_core.columns].fillna(0)
mask = (P_allele.values == 1).any(axis=1)
P_allele = P_allele.loc[mask]

In [None]:
# Alleles whose gene prefix (allele ID up to its first 'A') is a core gene.
relevant_alleles = list(filter(lambda allele: allele.split('A')[0] in df_core.index, P_allele.index))
len(relevant_alleles)

In [None]:
# Restrict to core-gene alleles and cast to integers (presumably 0/1 presence values).
P_allele = P_allele.loc[relevant_alleles].fillna(0).astype(int)

In [None]:
# For every core gene, find its most widespread allele and classify it as "dominant"
# (in > DOMINANT_FRACTION of genomes) and/or "super-dominant" (> SUPER_DOMINANT_FRACTION).
DOMINANT_FRACTION = 0.5
SUPER_DOMINANT_FRACTION = 0.95

# Step 1: map each core gene to its alleles. Allele IDs are assumed to be '<gene>A<n>'
# (as in the split('A') filtering above), so one pass grouping by the prefix replaces
# the former gene x allele substring scan, which was quadratic.
print('Mapping alleles')
allele_mapping = {gene: [] for gene in df_core.index}
for allele in P_allele.index:
    owner = allele.split('A')[0]
    if owner in allele_mapping:
        allele_mapping[owner].append(allele)

# Step 2: number of genomes carrying each allele
print('Summing allele presence')
sum_df = P_allele[df_core.columns].fillna(0).sum(axis=1)

# Step 3: classify each gene by the prevalence of its most common allele
n_genomes = df_core.shape[1]
dominant_alleles = []
super_dominant_alleles = []

for gene in tqdm(df_core.index):
    alleles = allele_mapping.get(gene, [])
    if not alleles:
        continue

    allele_sum = sum_df.loc[alleles]
    top_allele, top_count = allele_sum.idxmax(), allele_sum.max()

    if top_count > DOMINANT_FRACTION * n_genomes:
        dominant_alleles.append(top_allele)

    if top_count > SUPER_DOMINANT_FRACTION * n_genomes:
        super_dominant_alleles.append(top_allele)

In [None]:
# Number of alleles per core gene (0 if the gene has none in P_allele). Counting allele
# IDs by gene prefix ('<gene>A<n>') in one pass replaces the per-gene rescan of all alleles.
alleles_per_gene = pd.Series(P_allele.index).str.split('A').str[0].value_counts()
core_allele_counts = {gene: int(alleles_per_gene.get(gene, 0)) for gene in df_core.index}

In [None]:
# Allele count per core gene as a Series, sorted ascending.
core_alleles = pd.Series(core_allele_counts).sort_values()

In [None]:
# Histogram of allele counts per core gene, overlaying genes whose most common allele is
# dominant (> 50% of genomes) or hyper-dominant (> 95%).

# NOTE(review): this changes the global font family for every later figure as well.
plt.rcParams['font.family'] = 'DejaVu Sans'

bins = np.linspace(0, 260, 20)  # NOTE(review): genes with > 260 alleles fall outside the bins
allele_counts = pd.DataFrame(core_alleles)

fig, ax = plt.subplots(figsize=(12, 6))  # wide figure; adjust width/height as needed
ax.hist(allele_counts, bins=bins, label='All Genes', edgecolor='black', alpha=0.8)
ax.hist(allele_counts.loc[[x.split('A')[0] for x in dominant_alleles]],
        bins=bins, color='red', label='Dominant Alleles', edgecolor='black', alpha=1)
ax.hist(allele_counts.loc[[x.split('A')[0] for x in super_dominant_alleles]],
        bins=bins, color='purple', label='Hyper-Dominant Alleles', edgecolor='black', alpha=1)

ax.legend()
ax.set_title('Number of Alleles Per Core Gene')
ax.set_xlabel('Number of Alleles')
ax.set_ylabel('Number of Genes')

# Same ../../images root as the other saved figures and the ../../data inputs.
fig.savefig('../../images/core_genome_figs/core_allele_dist.svg', format='svg')

plt.show()

In [None]:
# COG distribution (first letter of each assignment) of core genes WITHOUT a dominant
# allele, so that the core / dominant / super-dominant groups compared below are disjoint.
# dominant_alleles holds allele IDs: convert them to gene IDs before subtracting (subtracting
# allele IDs from gene IDs removed nothing).
dominant_genes = {allele.split('A')[0] for allele in dominant_alleles}
core_cogs = (
    df_eggnog.loc[sorted(set(df_core.index) - dominant_genes)]
    .COG_category.apply(lambda x: x[0])
    .value_counts()
)
core_proportion = core_cogs / core_cogs.sum()

In [None]:
# COG distribution (first letter) of genes whose top allele is dominant but not super-dominant.
dominant_cogs = (
    df_eggnog.loc[
        [allele.split('A')[0] for allele in set(dominant_alleles) - set(super_dominant_alleles)],
        'COG_category',
    ]
    .map(lambda cog: cog[0])
    .value_counts()
)
dominant_proportion = dominant_cogs.div(dominant_cogs.sum())

In [None]:
# COG distribution (first letter) of genes whose top allele is super-dominant.
super_dominant_cogs = (
    df_eggnog.loc[[allele.split('A')[0] for allele in super_dominant_alleles], 'COG_category']
    .map(lambda cog: cog[0])
    .value_counts()
)
super_dominant_proportion = super_dominant_cogs.div(super_dominant_cogs.sum())

In [None]:
# Per-COG gene counts for the three groups. Building from a dict takes the UNION of the
# COG letters across groups (assigning into an empty frame would align every column to
# core_cogs and silently drop letters absent from it); missing counts become 0.
df = pd.DataFrame({
    'core': core_cogs,
    'dominant': dominant_cogs,
    'super_dominant': super_dominant_cogs,
})
df['COG'] = df.index
df = df.fillna(0)

# Normalize counts to within-group proportions
df['core_proportion'] = df['core'] / df['core'].sum()
df['dominant_proportion'] = df['dominant'] / df['dominant'].sum()
df['super_dominant_proportion'] = df['super_dominant'] / df['super_dominant'].sum()

In [None]:
# Raw per-COG gene counts that feed the proportion z-tests below.
df[['core', 'dominant', 'super_dominant']]

In [None]:
from statsmodels.stats.proportion import proportions_ztest
from statsmodels.stats.multitest import multipletests

# Two-sample proportion z-tests: for each COG letter (row of `df`), does its share of genes
# differ between two gene groups? Uses the raw per-group counts in `df`.
comparisons = [('core', 'dominant'), ('core', 'super_dominant'), ('dominant', 'super_dominant')]
# Total genes per group (loop-invariant)
group_totals = {group: df[group].sum() for group in ('core', 'dominant', 'super_dominant')}

results = []
for _, row in df.iterrows():
    for group1, group2 in comparisons:
        count = [row[group1], row[group2]]                   # genes in this COG, per group
        nobs = [group_totals[group1], group_totals[group2]]  # genes in each group

        # With no (or only) successes in both groups the pooled variance is 0 and the test
        # is undefined; record NaN explicitly and keep it out of the FDR step below.
        if sum(count) == 0 or sum(count) == sum(nobs):
            z_stat, p_value = np.nan, np.nan
        else:
            z_stat, p_value = proportions_ztest(count, nobs)

        results.append({
            'COG': row['COG'],
            'Comparison': f'{group1} vs {group2}',
            'Z-statistic': z_stat,
            'P-value': p_value
        })

results_df = pd.DataFrame(results)

# Benjamini-Hochberg FDR over the defined tests only: BH's cumulative minimum propagates
# NaN, so a single NaN p-value would make every corrected p-value NaN.
tested = results_df['P-value'].notna()
results_df['Corrected P-value'] = np.nan
if tested.any():
    results_df.loc[tested, 'Corrected P-value'] = multipletests(
        results_df.loc[tested, 'P-value'], method='fdr_bh'
    )[1]

# Undefined tests (NaN) are reported as not significant
results_df['Significance'] = np.where(
    results_df['Corrected P-value'] < 0.05, 'Significant', 'Not Significant'
)

display(results_df)


In [None]:
# Grouped bar chart of each COG letter's share within the core, dominant and
# super-dominant gene groups; a red star marks letters with at least one significant
# pairwise difference (BH-corrected p < 0.05) from the z-tests above.

# Long format: one row per (COG, group)
plot_df = pd.melt(
    df,
    id_vars=['COG'],
    value_vars=['core_proportion', 'dominant_proportion', 'super_dominant_proportion'],
    var_name='Group',
    value_name='Proportion'
)
# 'core_proportion' -> 'Core', 'super_dominant_proportion' -> 'Super_dominant', ...
plot_df['Group'] = plot_df['Group'].str.replace('_proportion', '').str.capitalize()

# Significant comparisons. .copy() so the new column is not written into a slice of
# results_df; the vectorised label also works when no row is significant.
significant_results = results_df[results_df['Corrected P-value'] < 0.05].copy()
significant_results['Significance'] = '* ' + significant_results['Comparison']

# Canonical COG order, restricted to the letters present in the data
COG_order = "J A K L B D Y V T M N Z W U O C G E F H I P Q R S -".split()
present_cogs = plot_df['COG'].unique()
filtered_cog_order = [cog for cog in COG_order if cog in present_cogs]

fig, ax = plt.subplots(figsize=(12, 6))
sns.barplot(
    data=plot_df,
    x='COG',
    y='Proportion',
    hue='Group',
    palette='viridis',
    order=filtered_cog_order,
    ax=ax,
)

# One star per COG letter: several significant comparisons for the same letter would
# otherwise be drawn on top of each other at the same position.
for cog in significant_results['COG'].unique():
    if cog not in filtered_cog_order:
        continue  # letter outside COG_order has no bar on the x-axis
    x_pos = filtered_cog_order.index(cog)
    max_y_pos = plot_df.loc[plot_df['COG'] == cog, 'Proportion'].max()
    ax.text(x=x_pos, y=max_y_pos + 0.01, s='*', ha='center', color='red')

ax.set_title("Proportion of COG Categories Across Groups")
ax.set_xlabel("COG Category")
ax.set_ylabel("Proportion")
ax.legend(title="Group")
fig.tight_layout()

# Same ../../images root as the other saved figures and the ../../data inputs.
fig.savefig('../../images/core_genome_figs/Conservation_and_COG_Categories.svg', format='svg')
plt.show()

## Groups of genes of interest in the core genome

In [None]:
# Directory with one sub-directory per genome, each holding <genome>.gff3 (read below).
bakta_files = '../../data/processed/bakta/'
# Pickled mapping from sequence header (looked up by locus tag below) to allele ID.
header_to_allele = '../../data/processed/cd-hit-results/header_to_allele_80.pickle.gz'
# Text file of CD-HIT representative-sequence headers, one per line.
cd_hit_headers = '../../data/processed/cd-hit-results/rep_headers.txt'
PATH_TO_DATA = '../../data/'

In [None]:
# Header -> allele ID mapping (indexed with [] and iterated with .items() below).
df_h2a = pd.read_pickle(header_to_allele)

In [None]:
# Representative-sequence IDs: drop the leading '>' and keep the next 12 characters.
# NOTE(review): assumes fixed-width 12-character IDs — confirm against the header format.
# Reading inside `with` closes the file handle (the original left it open).
with open(cd_hit_headers) as header_file:
    headers = [line[1:13] for line in header_file]

In [None]:
# Invert the header -> allele mapping into gene cluster -> list of headers, where the
# cluster ID is the allele ID up to its first 'A'. Headers keep their iteration order.
cluster_to_alleles = {}
for header, allele_id in tqdm(df_h2a.items()):
    cluster_to_alleles.setdefault(allele_id.split('A')[0], []).append(header)


In [None]:
# Locus-tag prefix per genome, taken from the first 'locus_tag=' attribute within the first
# 10,000 characters of the genome's GFF3.
# NOTE(review): assumes a 6-character prefix — confirm against the Bakta locus-tag format.
genome_to_tag = {}
for genome in tqdm(os.listdir(bakta_files)):
    gff_path = os.path.join(bakta_files, genome, genome + '.gff3')
    if not os.path.isfile(gff_path):
        continue  # stray entry in the directory (no <genome>/<genome>.gff3)
    with open(gff_path) as gff_handle:
        text = gff_handle.read(10000)
    loc = text.find('locus_tag=')
    # find() returns -1 when absent; slicing from there would yield unrelated characters.
    genome_to_tag[genome] = text[loc + 10:loc + 16] if loc != -1 else None

In [None]:
def _get_attr(attributes, attr_id, ignore=False):
    """
    Helper function for parsing GFF annotations

    Parameters
    ----------
    attributes : str
        Attribute string
    attr_id : str
        Attribute ID
    ignore : bool
        If true, ignore errors if ID is not in attributes (default: False)

    Returns
    -------
    str, optional
        Value of attribute
    """

    try:
        return re.search(attr_id + "=(.*?)(;|$)", attributes).group(1)
    except AttributeError:
        if ignore:
            return None
        else:
            raise ValueError("{} not in attributes: {}".format(attr_id, attributes))

def gff2pandas(gff_file, feature=["CDS"], index=None):
    """
    Converts GFF file(s) to a Pandas DataFrame
    Parameters
    ----------
    gff_file : str or list
        Path(s) to GFF file
    feature: str or list
        Name(s) of features to keep (default = "CDS")
    index : str, optional
        Column or attribute to use as index

    Returns
    -------
    df_gff: ~pandas.DataFrame
        GFF formatted as a DataFrame
    """

    # Argument checking
    if isinstance(gff_file, str):
        gff_file = [gff_file]

    if isinstance(feature, str):
        feature = [feature]

    result = []

    for gff in gff_file:
        with open(gff, "r") as f:
            lines = f.readlines()

        # Get lines to skip
        skiprow = [i for i, line in enumerate(lines) if line.startswith("#")]
       
        # Read GFF
        names = [
            "accession",
            "source",
            "feature",
            "start",
            "end",
            "score",
            "strand",
            "phase",
            "attributes",
        ]
        DF_gff = pd.read_csv(gff, sep="\t", skiprows=skiprow, names=names, header=None, low_memory=False)
        
        region = DF_gff[DF_gff.feature == 'region']
        region_len = int(region.iloc[0].end)

        # Filter for CDSs
        DF_cds = DF_gff[DF_gff.feature.isin(feature)]

        # Sort by start position
        # DF_cds = DF_cds.sort_values("start")

        DF_cds = DF_cds.copy() # get rid of copy warning
        
        # Extract attribute information
        DF_cds["locus_tag"] = DF_cds.attributes.apply(_get_attr, attr_id="locus_tag")

        result.append(DF_cds)
        
    DF_gff = pd.concat(result)
    
    if index:
        if DF_gff[index].duplicated().any():
            logging.warning("Duplicate {} detected. Dropping duplicates.".format(index))
            DF_gff = DF_gff.drop_duplicates(index)
        DF_gff.set_index("locus_tag", drop=True, inplace=True)

    return DF_gff[['accession', 'start', 'end', 'locus_tag', 'strand']], region_len

In [None]:
# Per-strain CDS table from each complete genome's GFF3, with every locus tag mapped to its
# gene cluster, sorted by contig accession then start coordinate.
strain_vectors = {}


def h2a(x):
    """Return the gene (cluster) ID for header/locus tag ``x``, or None if it has no allele."""
    try:
        return df_h2a[x].split('A')[0]
    except (KeyError, AttributeError):
        # KeyError: header not in the mapping; AttributeError: non-string (e.g. NaN) value.
        # Narrowed from a bare except so genuine errors (and interrupts) are not swallowed.
        return None


for strain in tqdm(metadata_complete.genome_id):
    DF_gff, size = gff2pandas(f'{PATH_TO_DATA}/processed/bakta/{strain}/{strain}.gff3')
    DF_gff['gene'] = DF_gff['locus_tag'].map(h2a)
    strain_vectors[strain] = DF_gff.sort_values(by=['accession', 'start'])

In [None]:
# TODO: determine how to assess the contiguity of these regions

# Core gene length vs. COG category

In [None]:
# Per core gene: the lengths of its alleles (from the non-redundant FASTA), their median,
# the first letter of its COG assignment ('-' mapped to 'S') and its allele count.
fasta_file = "../../data/processed/cd-hit-results/sim80/Ebacter_nr.faa"

gene_lengths = pd.DataFrame(index=df_core.index, columns=['lens', 'median_len'])
gene_lengths['lens'] = [[] for _gene in df_core.index]

# Only records that are rows of P_allele (alleles of core genes) are recorded.
for record in tqdm(SeqIO.parse(fasta_file, "fasta")):
    if record.id not in P_allele.index:
        continue
    owner_gene = record.id.split('A')[0]
    gene_lengths.loc[owner_gene, 'lens'].append(len(record.seq))

gene_lengths['median_len'] = gene_lengths['lens'].apply(lambda lens: np.median(lens))
gene_lengths['cog'] = gene_lengths.apply(
    lambda row: df_eggnog.loc[row.name, 'COG_category'][0], axis=1
)
# Note: replaces '-' in every column of the frame, not only 'cog'.
gene_lengths = gene_lengths.replace('-', 'S')
gene_lengths['alleles'] = gene_lengths['lens'].apply(lambda lens: len(lens))

def get_cog_super(cog):
    """Map a single-letter COG category to its functional supercategory name.

    'B' (chromatin structure) is grouped with information storage and processing, matching
    the supercategory mapping used for the COG sunburst figures. Any letter not in the three
    named groups (including 'R', 'S' and unexpected values) maps to 'POORLY CHARACTERIZED'.
    """
    if cog in ['J', 'A', 'K', 'L', 'B']:
        return 'INFORMATION STORAGE AND PROCESSING'
    elif cog in 'D Y V T M N Z W U O'.split():
        return 'CELLULAR PROCESSES AND SIGNALING'
    elif cog in 'C G E F H I P Q'.split():
        return 'METABOLISM'
    else:
        return 'POORLY CHARACTERIZED'

# COG supercategory for each core gene, from its first COG letter.
gene_lengths['cog_super'] = gene_lengths['cog'].map(get_cog_super)

In [None]:
# Stacked histogram of median allele length per core gene, coloured by COG supercategory.
colors = {'METABOLISM': 'green', 'CELLULAR PROCESSES AND SIGNALING': 'gold',
          'INFORMATION STORAGE AND PROCESSING': 'red', 'POORLY CHARACTERIZED': 'grey'}

fig, ax = plt.subplots(1, 1, figsize=(7, 5))

# A dict palette ties each supercategory to its colour. A list palette is assigned to hue
# levels in order of appearance in the data, which need not match the legend built below.
sns.histplot(data=gene_lengths, x='median_len', hue='cog_super', multiple='stack',
             palette=colors, ax=ax)

# Legend labels list the COG letters actually observed in each supercategory, so they
# always agree with how gene_lengths was classified.
cog_order = "J A K L B D Y V T M N Z W U O C G E F H I P Q R S".split()
letters_by_super = gene_lengths.groupby('cog_super')['cog'].agg(
    lambda letters: sorted(set(letters),
                           key=lambda c: (cog_order.index(c) if c in cog_order else len(cog_order), c))
)

handles = []
labels = []
for category in gene_lengths.cog_super.unique():
    labels.append(category + '\n' + ', '.join(letters_by_super[category]))
    handles.append(plt.Line2D([0], [0], color=colors[category], lw=4))

ax.legend(handles, labels, title='COG Supercategories', loc='upper right', bbox_to_anchor=(1, 1), ncol=1)

# Same ../../images root as the other saved figures and the ../../data inputs.
fig.savefig('../../images/core_genome_figs/gene_lengths_histogram.svg', format='svg', dpi=600, bbox_inches='tight')
plt.show()

In [None]:
# Core genes whose MOST COMMON allele is carried by more than CONSERVATION_FRACTION of the
# genomes that carry the gene. The previous version tested the first-listed allele
# (.iloc[0]), which need not be the most common one; the max matches how dominant alleles
# are chosen above. Alleles are grouped by gene prefix ('<gene>A<n>') in one pass
# instead of rescanning all alleles for every gene.
CONSERVATION_FRACTION = 0.99

allele_presence = P_allele[df_core.columns].sum(axis=1)
top_allele_count = allele_presence.groupby(allele_presence.index.str.split('A').str[0]).max()

conserved_genes = [
    gene for gene in tqdm(core_alleles.index)
    # genes without any allele in P_allele are skipped (previously an IndexError)
    if gene in top_allele_count.index
    and top_allele_count[gene] > df_core.loc[gene].sum() * CONSERVATION_FRACTION
]

In [None]:
# Number of genes in conserved_genes.
len(conserved_genes)