In [None]:
# Base imports
import logging
import os
import pickle
import re

# Compute imports
from collections import Counter
import numpy as np
import pandas as pd
import scipy
from tqdm.notebook import tqdm, trange

# Plotting imports
import matplotlib
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import pyplot as plt
import seaborn as sns
from plotly import express as px

# ML import
from sklearn.decomposition import NMF
from sklearn.metrics import mean_squared_error, median_absolute_error
from sklearn.metrics.pairwise import cosine_similarity
from scipy import sparse
from sklearn.cluster import KMeans


# Font settings for exported figures (PDF/PS font type 42, SVG text kept as text, Arial sans-serif).
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'svg.fonttype': 'none',
    'font.sans-serif': 'Arial',
    'font.family': 'sans-serif',
})
# The seaborn style is applied between the two groups of settings, exactly as before,
# so the black text/axis colours below are set after it.
sns.set_style('ticks')
matplotlib.rcParams.update({
    key: '#000000'
    for key in ('text.color', 'axes.labelcolor', 'xtick.color', 'ytick.color')
})

In [None]:
# Relative path to the project data directory (note the trailing slash).
PATH_TO_DATA = '../../data/'

In [None]:
# Strain-by-gene matrix (CD-HIT clustering, sim80), metadata table and eggNOG annotations.
DF_GENES = '../../data/processed/cd-hit-results/sim80/Ebacter_strain_by_gene.pickle.gz'
ENRICHED_METADATA = '../../data/metadata/enriched_metadata.csv'
DF_EGGNOG = '../../data/processed/df_eggnog.csv'

# Core / accessory / rare gene matrices.
DF_CORE_COMPLETE = '../../data/processed/CAR_genomes/df_core_complete.pickle'
DF_ACC_COMPLETE = '../../data/processed/CAR_genomes/df_acc_complete.pickle'
DF_RARE_COMPLETE = '../../data/processed/CAR_genomes/df_rare_complete.pickle'

# NMF outputs: L and A matrices, raw and binarized.
L_BINARIZED = '../../data/processed/nmf-outputs/L_binarized.csv'
A_BINARIZED = '../../data/processed/nmf-outputs/A_binarized.csv'
L_MATRIX = '../../data/processed/nmf-outputs/L.csv'
A_MATRIX = '../../data/processed/nmf-outputs/A.csv'

In [None]:
# Load the binarized A and L matrices; the first CSV column becomes the index.
A_binarized, L_binarized = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (A_BINARIZED, L_BINARIZED)
)

In [None]:
# Rare, accessory and core gene matrices, loaded in that order.
df_rare, df_acc, df_core = (
    pd.read_pickle(pickle_path)
    for pickle_path in (DF_RARE_COMPLETE, DF_ACC_COMPLETE, DF_CORE_COMPLETE)
)

In [None]:
# Every metadata column is read as 'object' so values are not coerced to numbers.
metadata = pd.read_csv(
    ENRICHED_METADATA,
    dtype='object',
    index_col=0,
)

display(metadata.shape, metadata.head())

In [None]:
# Full strain-by-gene matrix (the "P matrix").
df_genes = pd.read_pickle(DF_GENES)

# Metadata rows whose genome_status is 'Complete'.
metadata_complete = metadata[metadata.genome_status == 'Complete']

# Restrict P to those genomes, treat missing entries as 0, densify and store as int8.
df_genes_complete = (
    df_genes[metadata_complete.genome_id]
    .fillna(0)
    .sparse.to_dense()
    .astype('int8')
)

# Keep only genes with a non-zero row sum, i.e. present in at least one complete genome.
inCompleteseqs = df_genes_complete.sum(axis=1) > 0
df_genes_complete = df_genes_complete[inCompleteseqs]

df_genes_complete.shape

In [None]:
# Display order of all phylons, including the four uncharacterized ones.
phylon_order = [
    'hormaechei-xiangfangensis',
    'hormaechei-oharae',
    'hormaechei-steigerwaltii-2',
    'hormaechei-steigerwaltii-1',
    'hormaechei-steigerwaltii-3',
    'hormaechei-hormaechei',
    'hormaechei-hoffmannii-1',
    'hormaechei-hoffmannii-2',
    'unchar-1',
    'unchar-2',
    'unchar-3',
    'unchar-4',
    'roggenkampii',
    'asburiae',
    'kobei',
    'bugandensis',
    'cancerogenous',
    'ludwigii',
    'cloacae',
]

# Same order with the 'unchar-*' phylons left out.
characterized_order = [
    phylon for phylon in phylon_order if not phylon.startswith('unchar-')
]

In [None]:
# eggNOG annotations; missing values become '-'.
df_eggnog = pd.read_csv(DF_EGGNOG, index_col=0).fillna('-')
# Keep only the first character of each COG_category entry.
df_eggnog['COG_category'] = df_eggnog['COG_category'].map(lambda category: category[0])
display(df_eggnog.shape, df_eggnog.head())

## Analysis of Rare Genome


In [None]:
# Row labels of df_rare, and the eggNOG rows for those labels.
# NOTE(review): 'rare_geneome_eggnog' looks like a typo for 'rare_genome_eggnog'; not renamed
# here because later cells outside this view may reference it.
rare_genes = df_rare.index
rare_geneome_eggnog = df_eggnog.loc[rare_genes]

In [None]:
# Bakta gene annotation table; the first CSV column becomes the index.
BAKTA_ANNOTATIONS = '../../data/processed/bakta_gene_annotations.csv'
bakta_annotations = pd.read_csv(BAKTA_ANNOTATIONS, index_col=0)

In [None]:
# df_rare_annot is currently just an alias of df_rare (same object, no copy).
# The optional filter below, which would drop genes whose Bakta product contains
# 'ypothetical', is disabled.
df_rare_annot = df_rare
# to_drop = bakta_annotations[(bakta_annotations.Product.str.contains('ypothetical'))].index
# to_drop = [x for x in to_drop if x in df_rare.index]
# df_rare_annot = df_rare.drop(to_drop).astype(int)

# Extract gene locations for determination of gene proximity

In [None]:
# Directory read per genome below as <genome>/<genome>.gff3.
bakta_files = '../../data/processed/bakta/'
# Pickled header -> allele mapping, and the CD-HIT representative header list.
header_to_allele = '../../data/processed/cd-hit-results/header_to_allele_80.pickle.gz'
cd_hit_headers = '../../data/processed/cd-hit-results/rep_headers.txt'

In [None]:
# Header -> allele mapping. Later cells index it by locus tag and iterate it with .items().
# NOTE(review): read_pickle executes pickled code; only load files from a trusted source.
df_h2a = pd.read_pickle(header_to_allele)

In [None]:
# Keep characters 1..12 of every line (the first character of each line is skipped).
# The file is opened in a context manager so the handle is closed after reading.
with open(cd_hit_headers) as headers_handle:
    headers = [line[1:13] for line in headers_handle]

In [None]:
# Invert the header -> allele mapping into gene cluster -> list of keys.
# The cluster ID is the part of the mapped value before its first 'A'.
cluster_to_alleles = {}

for allele, cluster in tqdm(df_h2a.items()):
    cluster = cluster.split('A')[0]
    # setdefault creates the list the first time a cluster is seen
    cluster_to_alleles.setdefault(cluster, []).append(allele)



In [None]:
# Map each genome directory to the 6-character locus-tag prefix found in its GFF3 file.
genome_to_tag = {}
for genome in tqdm(os.listdir(bakta_files)):
    gff_path = os.path.join(bakta_files, genome, genome + '.gff3')
    # Only the first 10,000 characters are inspected; the context manager closes the
    # handle for every genome instead of leaking one descriptor per file.
    with open(gff_path) as gff_handle:
        text = gff_handle.read(10000)
    loc = text.find('locus_tag=')
    # str.find returns -1 when the key is missing; slicing from that offset would store
    # unrelated text, so record None instead.
    genome_to_tag[genome] = text[loc + 10:loc + 16] if loc != -1 else None

In [None]:
def _get_attr(attributes, attr_id, ignore=False):
    """
    Helper function for parsing GFF annotations

    Parameters
    ----------
    attributes : str
        Attribute string
    attr_id : str
        Attribute ID
    ignore : bool
        If true, ignore errors if ID is not in attributes (default: False)

    Returns
    -------
    str, optional
        Value of attribute
    """

    try:
        return re.search(attr_id + "=(.*?)(;|$)", attributes).group(1)
    except AttributeError:
        if ignore:
            return None
        else:
            raise ValueError("{} not in attributes: {}".format(attr_id, attributes))

def gff2pandas(gff_file, feature=["CDS"], index=None):
    """
    Converts GFF file(s) to a Pandas DataFrame
    Parameters
    ----------
    gff_file : str or list
        Path(s) to GFF file
    feature: str or list
        Name(s) of features to keep (default = "CDS")
    index : str, optional
        Column or attribute to use as index

    Returns
    -------
    df_gff: ~pandas.DataFrame
        GFF formatted as a DataFrame
    """

    # Argument checking
    if isinstance(gff_file, str):
        gff_file = [gff_file]

    if isinstance(feature, str):
        feature = [feature]

    result = []

    for gff in gff_file:
        with open(gff, "r") as f:
            lines = f.readlines()

        # Get lines to skip
        skiprow = [i for i, line in enumerate(lines) if line.startswith("#")]
       
        # Read GFF
        names = [
            "accession",
            "source",
            "feature",
            "start",
            "end",
            "score",
            "strand",
            "phase",
            "attributes",
        ]
        DF_gff = pd.read_csv(gff, sep="\t", skiprows=skiprow, names=names, header=None, low_memory=False)
        
        region = DF_gff[DF_gff.feature == 'region']
        region_len = int(region.iloc[0].end)

        # Filter for CDSs
        DF_cds = DF_gff[DF_gff.feature.isin(feature)]

        # Sort by start position
        # DF_cds = DF_cds.sort_values("start")

        DF_cds = DF_cds.copy() # get rid of copy warning
        
        # Extract attribute information
        DF_cds["locus_tag"] = DF_cds.attributes.apply(_get_attr, attr_id="locus_tag")

        result.append(DF_cds)
        
    DF_gff = pd.concat(result)
    
    if index:
        if DF_gff[index].duplicated().any():
            logging.warning("Duplicate {} detected. Dropping duplicates.".format(index))
            DF_gff = DF_gff.drop_duplicates(index)
        DF_gff.set_index("locus_tag", drop=True, inplace=True)

    return DF_gff[['accession', 'start', 'end', 'locus_tag', 'strand']], region_len

In [None]:
strain_vectors = {}
def h2a(x):
    """
    Map a locus tag to its gene cluster ID.

    Looks ``x`` up in ``df_h2a`` and returns the part of the mapped value before its
    first 'A'. Returns None when the tag is not in the mapping or the mapped value is
    not a string; any other error now propagates instead of being silently swallowed.
    """
    try:
        return df_h2a[x].split('A')[0]
    except (KeyError, AttributeError, TypeError):
        return None

# One feature table per complete genome, covering every accession in the GFF3 file
# and ordered by accession, then start coordinate.
for strain in tqdm(metadata_complete.genome_id):
    DF_gff, size = gff2pandas(f'{PATH_TO_DATA}/processed/bakta/{strain}/{strain}.gff3')
    DF_gff['gene'] = DF_gff.locus_tag.apply(h2a)

    strain_vectors[strain] = DF_gff.sort_values(by=['accession', 'start'])

In [None]:
rare_genes = set(df_rare.index)

# strain -> list of contiguous rare-gene regions (each region is a list of gene IDs)
strain_contigs = {}

# Per rare gene: sizes of the regions it occurs in, the IDs of those regions
# ('<strain>_<region index>'), and every gene that shares a region with it.
rare_genome_stats = pd.DataFrame(index = list(rare_genes), columns = ['sizes', 'genomes', 'all_neighbors'])
for gene in rare_genes:
    rare_genome_stats.loc[gene] = [[], [], set()]

# strain -> list of region sizes
strain_contig_sizes = pd.Series({strain:[] for strain in metadata_complete.genome_id})


def _close_rare_region(strain, region):
    """Record one finished run of consecutive rare genes for ``strain``."""
    strain_contigs[strain].append(region)
    strain_contig_sizes.loc[strain].append(len(region))
    region_id = strain + '_' + str(len(strain_contigs[strain])-1)
    for member in region:
        rare_genome_stats.loc[member, 'sizes'].append(len(region))
        rare_genome_stats.loc[member, 'genomes'].append(region_id)
        rare_genome_stats.loc[member, 'all_neighbors'].update(set(region) - set([member]))


for strain, df in tqdm(strain_vectors.items()):
    strain_contigs[strain] = []
    curr_contig = []
    curr_accession = ""
    for _, row in df.iterrows():
        same_accession = curr_accession == "" or curr_accession == row.accession
        is_rare = row.gene in rare_genes

        # A run ends at the first non-rare gene or when the accession changes.
        if curr_contig and not (same_accession and is_rare):
            _close_rare_region(strain, curr_contig)
            curr_contig = []

        # A rare gene either extends the open run or starts a new one.
        if is_rare:
            curr_contig.append(row.gene)

        curr_accession = row.accession

    # Flush a run that reaches the end of the table.
    if curr_contig:
        _close_rare_region(strain, curr_contig)
        curr_contig = []

# Repeated for only main chromosomes

In [None]:
# Same tables as strain_vectors, restricted to the accession with the most retained
# features in each genome (taken here as the main chromosome) and sorted by start.
# h2a is reused from the earlier cell that builds strain_vectors; the duplicate definition
# with its bare `except:` has been removed from this cell.
strain_vectors_chrom = {}

for strain in tqdm(metadata_complete.genome_id):
    DF_gff, size = gff2pandas(f'{PATH_TO_DATA}/processed/bakta/{strain}/{strain}.gff3')
    DF_gff['gene'] = DF_gff.locus_tag.apply(h2a)
    # value_counts sorts by frequency, so index[0] is the most common accession
    DF_gff = DF_gff[DF_gff.accession == DF_gff.accession.value_counts().index[0]]
    DF_gff = DF_gff[['accession','gene','start', 'end', 'strand']]

    strain_vectors_chrom[strain] = DF_gff.sort_values('start')

In [None]:
rare_genes_chrom = set(df_rare.index)

# strain -> list of contiguous rare-gene regions on the main chromosome
strain_contigs_chrom = {}

# Per rare gene: region sizes, region IDs ('<strain>_<region index>') and co-located genes.
rare_genome_stats_chrom = pd.DataFrame(index = list(rare_genes_chrom), columns = ['sizes', 'genomes', 'all_neighbors'])
for gene in rare_genes_chrom:
    rare_genome_stats_chrom.loc[gene] = [[], [], set()]

# strain -> list of region sizes
strain_contig_sizes_chrom = pd.Series({strain:[] for strain in metadata_complete.genome_id})


def _close_rare_region_chrom(strain, region):
    """Record one finished run of consecutive chromosomal rare genes for ``strain``."""
    strain_contigs_chrom[strain].append(region)
    strain_contig_sizes_chrom.loc[strain].append(len(region))
    region_id = strain + '_' + str(len(strain_contigs_chrom[strain])-1)
    for member in region:
        rare_genome_stats_chrom.loc[member, 'sizes'].append(len(region))
        rare_genome_stats_chrom.loc[member, 'genomes'].append(region_id)
        rare_genome_stats_chrom.loc[member, 'all_neighbors'].update(set(region) - set([member]))


for strain, df in tqdm(strain_vectors_chrom.items()):
    # This genome is excluded from the chromosomal analysis (reason not recorded here).
    if strain == '550.624':
        continue

    strain_contigs_chrom[strain] = []
    curr_contig = []
    curr_accession = ""
    for _, row in df.iterrows():
        same_accession = curr_accession == "" or curr_accession == row.accession
        is_rare = row.gene in rare_genes_chrom

        # A run ends at the first non-rare gene or when the accession changes.
        if curr_contig and not (same_accession and is_rare):
            _close_rare_region_chrom(strain, curr_contig)
            curr_contig = []

        # A rare gene either extends the open run or starts a new one.
        if is_rare:
            curr_contig.append(row.gene)

        curr_accession = row.accession

    # Flush a run that reaches the end of the table.
    if curr_contig:
        _close_rare_region_chrom(strain, curr_contig)
        curr_contig = []

In [None]:
# Keep only rare genes that were seen in at least one chromosomal region
# (an empty 'sizes' list means the gene never appeared on a main chromosome).
seen_on_chromosome = rare_genome_stats_chrom['sizes'].apply(len) > 0
rare_genome_stats_chrom = rare_genome_stats_chrom[seen_on_chromosome]
# Remove the genome that the chromosomal scan skipped. errors='ignore' keeps this cell
# re-runnable: a second run no longer fails with KeyError once the label is gone.
strain_contig_sizes_chrom = strain_contig_sizes_chrom.drop('550.624', errors='ignore')

In [None]:
# Histogram of how many chromosomal regions each rare gene occurs in (log-scaled counts).
regions_per_gene = rare_genome_stats_chrom.sort_values('sizes')['sizes'].apply(len)
regions_per_gene.hist()
plt.yscale('log')

# Plots

In [None]:
# Strain-by-allele matrix, restricted to the genome columns of df_core.
P_allele = pd.read_pickle(
    '../../data/processed/cd-hit-results/sim80/Ebacter_strain_by_allele.pickle.gz'
).loc[:, df_core.columns]

In [None]:
# Fill missing entries with 0, then keep alleles present (== 1) in at least one genome.
P_allele = P_allele.loc[:, df_core.columns].fillna(0)
mask = (P_allele.values == 1).any(axis=1)
P_allele = P_allele.loc[mask]

In [None]:
# Alleles whose gene cluster (the text before the first 'A') is a rare gene.
relevant_alleles = [
    allele_id
    for allele_id in P_allele.index
    if allele_id.split('A')[0] in df_rare.index
]
len(relevant_alleles)

In [None]:
# Restrict the allele matrix to alleles of rare genes and store it as integers.
P_allele = P_allele.loc[relevant_alleles].fillna(0).astype(int)

In [None]:
# Path to the FASTA file
from Bio import SeqIO

fasta_file = "../../data/processed/cd-hit-results/sim80/Ebacter_nr.faa"

# One row per rare gene; 'lens' collects the sequence length of each of its alleles.
gene_lengths = pd.DataFrame(index = df_rare.index, columns = ['lens', 'median_len'])
gene_lengths['lens'] = [[] for _ in range(len(df_rare))]

# Only records whose ID is a row of P_allele contribute a length.
for record in tqdm(SeqIO.parse(fasta_file, "fasta")):
    if record.id not in P_allele.index:
        continue
    parent_gene = record.id.split('A')[0]
    gene_lengths.loc[parent_gene, 'lens'].append(len(record.seq))

gene_lengths['median_len'] = gene_lengths['lens'].apply(lambda x: np.median(x))
gene_lengths['cog'] = gene_lengths.apply(lambda x: df_eggnog.loc[x.name, 'COG_category'][0], axis=1)
# '-' (no annotation) is relabelled as 'S'
gene_lengths = gene_lengths.replace('-','S')
gene_lengths['alleles'] = gene_lengths['lens'].apply(lambda x: len(x))

# COG functional category letter -> top-level COG group.
# 'B' (chromatin structure and dynamics) and 'X' (mobilome) were previously missing
# and therefore reported as poorly characterized.
_COG_SUPER_BY_LETTER = {
    **dict.fromkeys('J A K L B'.split(), 'INFORMATION STORAGE AND PROCESSING'),
    **dict.fromkeys('D Y V T M N Z W U O X'.split(), 'CELLULAR PROCESSES AND SIGNALING'),
    **dict.fromkeys('C G E F H I P Q'.split(), 'METABOLISM'),
}

def get_cog_super(cog):
    """
    Return the top-level COG group for a single-letter COG category.

    Parameters
    ----------
    cog : str
        One-letter COG category

    Returns
    -------
    str
        'INFORMATION STORAGE AND PROCESSING', 'CELLULAR PROCESSES AND SIGNALING',
        'METABOLISM', or 'POORLY CHARACTERIZED' for any other value
        (including 'R', 'S' and '-')
    """
    return _COG_SUPER_BY_LETTER.get(cog, 'POORLY CHARACTERIZED')

# Top-level COG group of every rare gene.
gene_lengths['cog_super'] = gene_lengths['cog'].apply(get_cog_super)

In [None]:
# Palette passed to the phylon boxenplot. It has 15 entries, the same length as
# characterized_order.
# NOTE(review): colours are presumably matched to phylons by position — confirm.
custom_colors = [
    # Shades of red/orange/yellow
    "Red",
    "IndianRed",
    "DarkRed",
    "FireBrick",
    "Tomato",
    "Gold",
    "DarkGoldenrod",
    "Goldenrod",
    # Other species
    "Green",
    "Blue",
    "Purple",
    "Cyan",
    "Magenta",
    "Lime",
    "Pink",
]

In [None]:
def get_strains(phylon, A_binarized = A_binarized):
    """
    Return the labels whose entry in row ``phylon`` of ``A_binarized`` equals 1.

    The default ``A_binarized`` is bound when this cell runs, i.e. to the matrix
    loaded earlier in the notebook.
    """
    is_member = A_binarized.loc[phylon] == 1
    return is_member[is_member].index

In [None]:
import warnings
# NOTE(review): this silences every warning for the rest of the session; kept so later
# cells behave as before, but a narrower filter would be safer.
warnings.simplefilter(action='ignore')

# Long-format table with one row per chromosomal rare-gene region
# (a strain without any region yields a single NaN row after explode()).
region_sizes = strain_contig_sizes_chrom.explode().reset_index()
region_sizes.columns = ['Genome', 'Region_Size']
# The log2 transform is applied once to the whole numeric column instead of assigning into
# a filtered slice inside the loop, which was chained assignment on a possible copy.
region_sizes['Region_Size'] = np.log2(region_sizes['Region_Size'].astype(float))

# One labelled frame per phylon, concatenated once (no repeated concat inside the loop).
# A genome that is a member of several phylons appears once per phylon.
plot_data = pd.concat(
    [
        region_sizes[region_sizes['Genome'].isin(get_strains(phylon))].assign(Phylon=phylon)
        for phylon in characterized_order
    ],
    ignore_index=True,
)[['Phylon', 'Genome', 'Region_Size']]

# Create the boxenplot
plt.figure(figsize=(6, 8))
sns.boxenplot(data=plot_data, x='Phylon', y='Region_Size', palette=custom_colors)
# An empty tick list hides the x tick labels
plt.xticks([],rotation=45, ha='right')
plt.xlabel('Phylon')
plt.ylabel('Log2 Rare Genes per Strain')
plt.title('Distribution of Chromosomal Rare Gene \n Region Sizes Across Strains in Each Phylon')
plt.tight_layout()
plt.show()

# TN Central

In [None]:
# Tabular BLAST hits against TnCentral (12 columns, no header row).
tncentral_hits = pd.read_csv('../../data/blastdbs/tn_central_enrichment.txt', sep='\t', header=None)
tncentral_hits.columns = ['query', 'target', 'identity', 'len', 'mismatch', 'gapopen', 'qstart', 'qend', 'tstart', 'tend',
                          'eval', 'bitscore']

# Keep hits above 80 % identity, reduce the query to its gene cluster ID and the target to
# the text before '(', then keep the hit with the lowest e-value for each gene.
blast_results_tncentral = (
    tncentral_hits.loc[tncentral_hits['identity'] > 80]
    .assign(
        query=lambda hits: hits['query'].map(lambda x: x.split('A')[0]),
        target=lambda hits: hits['target'].map(lambda x: x.split('(')[0]),
    )
    .sort_values(by=['query', 'eval'], ascending=[True, True])
    .drop_duplicates(subset='query', keep='first')
    .set_index('query')
)

# Subset of hits whose gene is in the rare genome.
blast_results_tncentral_rare = blast_results_tncentral.loc[
    blast_results_tncentral.index.isin(rare_genes)
]

In [None]:
# Chromosomal region statistics for rare genes with a TnCentral hit.
rare_genome_stats_chrom.loc[
    blast_results_tncentral_rare.index[
        blast_results_tncentral_rare.index.isin(rare_genome_stats_chrom.index)
    ]
].sort_values('sizes')

In [None]:
# Display the Bakta annotation table.
bakta_annotations

In [None]:
# Column sums over the rows of df_rare that have a TnCentral hit, ascending.
df_rare.loc[
    blast_results_tncentral.index[blast_results_tncentral.index.isin(df_rare.index)]
].sum().sort_values()

In [None]:
# Column sums over the rows of df_acc that have a TnCentral hit, ascending.
df_acc.loc[
    blast_results_tncentral.index[blast_results_tncentral.index.isin(df_acc.index)]
].sum().sort_values()

In [None]:
# Column sums over the rows of df_genes_complete that have a TnCentral hit, ascending.
df_genes_complete.loc[
    blast_results_tncentral.index[blast_results_tncentral.index.isin(df_genes_complete.index)]
].sum().sort_values()

In [None]:
# Split the chromosomal rare genes by whether Bakta calls them 'hypothetical protein'.
chrom_products = bakta_annotations.loc[rare_genome_stats_chrom.index, 'Product']
hypothetical_proteins = chrom_products[chrom_products == 'hypothetical protein'].index
others = [gene_id for gene_id in rare_genome_stats_chrom.index if gene_id not in hypothetical_proteins]

# Number of chromosomal regions each gene occurs in, one violin per group.
region_counts = rare_genome_stats_chrom['sizes'].apply(lambda x: len(x))
plt.violinplot(region_counts.loc[hypothetical_proteins].sort_values())
plt.violinplot(region_counts.loc[others].sort_values())

In [None]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

# Mean size of the chromosomal regions each rare gene occurs in.
mean_region_size = rare_genome_stats_chrom['sizes'].apply(lambda x: np.mean(x))
hypothetical_sizes = mean_region_size.loc[hypothetical_proteins].sort_values().tolist()
other_sizes = mean_region_size.loc[others].sort_values().tolist()

# Long-format frame: one row per gene, labelled with its group.
df_combined = pd.DataFrame({
    'Size': hypothetical_sizes + other_sizes,
    'Group': ['Hypothetical Protein'] * len(hypothetical_proteins) + ['Other'] * len(others),
})

# Violin plot comparing the two groups.
plt.figure(figsize=(8, 6))
sns.violinplot(x='Group', y='Size', data=df_combined)

plt.title('Violin Plot of Sizes for Hypothetical Proteins vs Others')
plt.xlabel('Group')
plt.ylabel('Size')

plt.show()



In [None]:
# 50 most frequent Bakta products among all rare genes.
bakta_annotations.loc[rare_genome_stats.index, 'Product'].value_counts().head(50)

In [None]:
# 50 most frequent Bakta products among rare genes seen on a main chromosome.
bakta_annotations.loc[rare_genome_stats_chrom.index, 'Product'].value_counts().head(50)

# AMR genes and virulence factors that are neighbors in at least one genome

In [None]:
# AMRFinder output: reduce each protein identifier to its gene cluster ID and keep, per
# gene, the row that sorts last by reference coverage.
amr = (
    pd.read_csv('../../data/processed/amrfinder/output', sep = '\t')
    .assign(**{
        'Protein identifier': lambda hits: hits['Protein identifier'].map(lambda x: x.split('A')[0])
    })
    .sort_values('% Coverage of reference sequence')
    .drop_duplicates(subset='Protein identifier', keep="last")
)

In [None]:
# Tabular BLAST hits against VFDB (12 columns, no header row).
vfdb_hits = pd.read_csv('../../data/blastdbs/vfdb.txt', sep='\t', header=None)
vfdb_hits.columns = ['query', 'target', 'identity', 'len', 'mismatch', 'gapopen', 'qstart', 'qend', 'tstart', 'tend',
                     'eval', 'bitscore']

# Keep hits above 80 % identity, reduce the query to its gene cluster ID and the target to
# the text before '(', then keep the hit with the lowest e-value for each gene.
blast_results = (
    vfdb_hits.loc[vfdb_hits['identity'] > 80]
    .assign(
        query=lambda hits: hits['query'].map(lambda x: x.split('A')[0]),
        target=lambda hits: hits['target'].map(lambda x: x.split('(')[0]),
    )
    .sort_values(by=['query', 'eval'], ascending=[True, True])
    .drop_duplicates(subset='query', keep='first')
)

def parse_fasta(file_path):
    """
    Collect the VFDB labels found on the header lines of a FASTA file.

    Parameters
    ----------
    file_path : str
        Path to the FASTA file

    Returns
    -------
    dict
        Maps each VFG label ('VFG' + 6 digits) to a tuple (VF label, VFC label).
        A label that is absent from the header is reported as "Unknown VF" or
        "Unknown VFC". Headers without a VFG label are skipped; if a VFG label
        occurs on several headers, the last one wins.
    """
    vfg_re = re.compile(r"VFG\d{6}")
    vf_re = re.compile(r"VF\d{4,6}")
    vfc_re = re.compile(r"VFC\d{4,6}")

    labels_by_vfg = {}
    with open(file_path, "r") as handle:
        for header in handle:
            # Only header lines carry labels; sequence lines are ignored.
            if not header.startswith(">"):
                continue

            vfg_hit = vfg_re.search(header)
            if vfg_hit is None:
                continue

            vf_hit = vf_re.search(header)
            vfc_hit = vfc_re.search(header)
            labels_by_vfg[vfg_hit.group(0)] = (
                "Unknown VF" if vf_hit is None else vf_hit.group(0),
                "Unknown VFC" if vfc_hit is None else vfc_hit.group(0),
            )

    return labels_by_vfg

# Parse the VFDB protein FASTA headers into VFG -> (VF, VFC) labels.
file_path = "../../data/blastdbs/vfdb/VFDB_setA_pro.fas"  # Replace with your FASTA file path
mapped_data = parse_fasta(file_path)

# VF annotation table, indexed by its first column.
vf_annots = pd.read_csv('../../data/blastdbs/vfdb/VFs.csv', skiprows=1, index_col=0)

# Annotate every hit with its VF ID, category, VF name and the Bakta gene name.
blast_results['vfid'] = blast_results['target'].map(lambda target: mapped_data[target][0])
blast_results['category'] = blast_results['vfid'].map(lambda vfid: vf_annots.loc[vfid, 'VFcategory'])
blast_results['vfname'] = blast_results['vfid'].map(lambda vfid: vf_annots.loc[vfid, 'VF_Name'])
blast_results['gene_name'] = [bakta_annotations.loc[gene_id, 'Name'] for gene_id in blast_results['query']]

In [None]:
# Rare genes flagged by AMRFinder or by a VFDB hit.
rare_amr_genes = [x for x in amr['Protein identifier'] if x in df_rare.index]
rare_vf_genes = [x for x in blast_results['query'] if x in df_rare.index]

# A gene present in both lists must appear only once: duplicate index labels would give
# duplicate rows and double-count that gene in the plots below. dict.fromkeys removes
# duplicates while keeping first-seen order.
total_virulence_genes = list(dict.fromkeys(rare_amr_genes + rare_vf_genes))
virulence_gene_set = set(total_virulence_genes)  # constant-time membership tests

vir_gene_stats = pd.DataFrame(index = total_virulence_genes, columns=['dir_neighbors', 'num_strains', 'neighbors_list'])
for gene in total_virulence_genes:
    # AMR/VF genes that share a contiguous rare-gene region with this gene in any genome
    direct_neighbors = [x for x in rare_genome_stats.loc[gene, 'all_neighbors'] if x in virulence_gene_set]

    # Row sum of df_rare for this gene
    num_strains = df_rare.loc[gene].sum()
    vir_gene_stats.loc[gene] = [len(direct_neighbors),  num_strains, direct_neighbors]

In [None]:
import matplotlib.pyplot as plt

df = vir_gene_stats

# Neighbor count per gene, computed here from 'neighbors_list'. The 'num_neighbors' column
# is only created by a later cell, so reading it here failed on a fresh top-to-bottom run.
neighbor_counts = df['neighbors_list'].apply(len)

# Integer bin edges 0..11: bin k holds genes with k neighbors. The last bin is closed, so
# 10 and 11 share it, and counts above 11 are not shown.
hist_data, bins = np.histogram(neighbor_counts, bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

# Create a horizontal bar plot. Bars sit at the left bin edge, so the bar for k neighbors
# is drawn at y = k (it was previously drawn at k + 1).
plt.figure(figsize=(6, 8))
plt.barh(bins[:-1], hist_data, edgecolor='black')

# Title and axis labels
plt.title("Histogram of AMR/VF Rare Gene Neighborhoods")
plt.ylabel("Number of Direct Neighbors")
plt.xlabel("Frequencies")

plt.savefig('virulence_neighbor_barplot.svg', format='svg')
# Show the plot
plt.show()



In [None]:
# Bin genes by the number of strains carrying them, then compare neighbor counts per bin.
# df is an alias of vir_gene_stats, so the new column is added to that frame.
df = vir_gene_stats
bins = [0, 2, 10, float('inf')]  # Adjust these bins as necessary
labels = [f"Low ({bins[0]+1}-{bins[1]} strains)", f"Medium ({bins[1]+1}-{bins[2]} strains)", f"High ({bins[2]+1}+ strains)"]
df['Strain Category'] = pd.cut(df['num_strains'], bins=bins, labels=labels)

fig, ax = plt.subplots(figsize=(10, 6))
sns.boxplot(x="Strain Category", y="dir_neighbors", data=df, palette="pastel", ax=ax)
ax.set_title("Distribution of Direct Neighbors Across Strain Categories")
ax.set_xlabel("Strain Category")
ax.set_ylabel("Number of Direct Neighbors")
plt.show()

In [None]:
# Length of each neighbor list, then the table ordered by strain count and neighbor count.
vir_gene_stats['num_neighbors'] = vir_gene_stats['neighbors_list'].apply(len)
vir_gene_stats.sort_values(by=['num_strains', 'num_neighbors'])

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl

df = vir_gene_stats

# Number of genes per (dir_neighbors, num_strains) pair, shown as log2(count + 1).
count_df = df.groupby(['dir_neighbors', 'num_strains']).size().reset_index(name='count')
count_df['count'] = np.log2(count_df['count'] + 1)

# One point per pair, coloured by the log-scaled frequency.
fig, ax = plt.subplots(figsize=(8, 8))
scatter = sns.scatterplot(x='dir_neighbors', y='num_strains', sizes=(50, 200),
                          hue='count', data=count_df, palette='viridis', legend=False, ax=ax)

# The legend is disabled, so a colorbar is built by hand over the same value range.
norm = mpl.colors.Normalize(vmin=count_df['count'].min(), vmax=count_df['count'].max())
sm = plt.cm.ScalarMappable(cmap='viridis', norm=norm)
sm.set_array([])
fig.colorbar(sm, ax=scatter, label='Log2 Frequency')

# Plot titles and labels
ax.set_title('Scatter Plot of dir_neighbors vs num_strains with Frequency-Based Coloring')
ax.set_xlabel('dir_neighbors')
ax.set_ylabel('num_strains')

fig.savefig('virulence_neighbor_scatter.svg', format='svg')
plt.show()



# Find locus of genes of interest for yersiniabactin production

In [None]:
# Region IDs ('<strain>_<region index>') of the 11 AMR/VF genes with the most neighbors.
top_vir_genes = vir_gene_stats.sort_values('dir_neighbors').tail(11).index
genome_and_locus = {
    region_id
    for region_ids in rare_genome_stats.loc[top_vir_genes].genomes.values
    for region_id in region_ids
}
for genome, locus in (region_id.split('_') for region_id in genome_and_locus):
    strain_gff = strain_vectors[genome]
    contig_locus = strain_contigs[genome][int(locus)]
    locus_rows = strain_gff[strain_gff.gene.isin(contig_locus)]

    # Positional columns 1 and 2 are 'start' and 'end'.
    # NOTE(review): start/end are overwritten on every iteration and not used in this cell.
    start = locus_rows.iloc[0,1]
    end = locus_rows.iloc[-1,2]

In [None]:
def rgba_to_hex(rgba):
    """
    Convert an RGBA colour with components in [0, 1] to a '#rrggbb' hex string.

    The input must unpack into exactly four values; the alpha component is ignored.
    Each channel is scaled by 255 and truncated to an integer.
    """
    red, green, blue, _alpha = rgba
    channels = [int(component * 255) for component in (red, green, blue)]
    return "#" + "".join(f"{channel:02x}" for channel in channels)

In [None]:
# Regions to draw, one diagram track each, as '<genome_id>_<region index>' where the index
# points into strain_contigs[genome_id]. This hand-ordered list replaces the set computed
# in the previous code cell.
genome_and_locus = [
    '158836.2032_100',
    '158836.1457_54',
    '2850492.3_41',
    '158836.1536_54',
    '158836.1537_54',
    '550.2451_23',
    '158836.434_43',
    '158836.416_44',
    '2077137.3_218'
] # same strains in genome and locus set but re-ordered, not necessary for plot but improves visual

In [None]:
import pandas as pd
from Bio.SeqRecord import SeqRecord
from Bio.SeqFeature import SeqFeature, FeatureLocation
from Bio.Graphics import GenomeDiagram


name = "test_fig"
gd_diagram = GenomeDiagram.Diagram(name)

# First pass over the selected regions: find the longest region (used as the common track
# length) and count in how many regions each gene appears.
max_len = 0
gene_counts = Counter()
for region_id in genome_and_locus:
    genome, locus = region_id.split('_')
    contig_locus = strain_contigs[genome][int(locus)]
    strain_gff = strain_vectors[genome]
    strain_gff = strain_gff[strain_gff.gene.isin(contig_locus)]

    # each gene counts once per region, however often it occurs there
    gene_counts.update(strain_gff.gene.unique())

    max_len = max(max_len, strain_gff.end.max() - strain_gff.start.min())

# Scale the counts to (0, 1]; fall back to a divisor of 1 when nothing was counted.
max_count = max(gene_counts.values()) if gene_counts else 1
normalized_counts = {gene: count / max_count for gene, count in gene_counts.items()}


# Grey-scale gradient from light grey (0.7, 0.7, 0.7) at 0 to black (0, 0, 0) at 1.
cmap = LinearSegmentedColormap.from_list("gradient", [(.7, .7, .7), (0, 0, 0)])

def get_color_from_count(count):
    """
    Return the hex colour for a gene based on its normalized region count.

    Despite its name, ``count`` is the gene ID used as the key into
    ``normalized_counts``; genes without an entry get the colour for 0.
    """
    return rgba_to_hex(cmap(normalized_counts.get(count, 0)))

# Gene used to orient every region: regions where it is not on the '+' strand are mirrored.
ANCHOR_GENE = 'Ebacter_C102'

# Index the metadata by genome once instead of re-indexing it twice per track.
metadata_by_genome = metadata.set_index('genome_id')

for genome, locus in [x.split('_') for x in genome_and_locus]:

    strain_gff = strain_vectors[genome]
    contig_locus = strain_contigs[genome][int(locus)]
    strain_gff = strain_gff[strain_gff.gene.isin(contig_locus)]

    # Red for AMR/virulence genes, otherwise a grey level based on how many of the drawn
    # regions contain the gene. Named feature_colors so it does not collide with the
    # reportlab `colors` module imported further down in this cell.
    feature_colors = [
        '#FF0000' if gene in vir_gene_stats.index else get_color_from_count(gene)
        for gene in strain_gff.gene.values
    ]

    # Centre the region within the common track length
    genome_region_len = strain_gff.end.max() - strain_gff.start.min()
    offset = int((max_len - genome_region_len) / 2)

    # Raises IndexError if the anchor gene is absent from the region
    if strain_gff[strain_gff.gene == ANCHOR_GENE].strand.iloc[0] != '+':
        # Mirror the region: reverse the row order, reflect the coordinates within
        # [min_coord, max_coord] and swap the strands.
        strain_gff = strain_gff.iloc[::-1]
        feature_colors = feature_colors[::-1]

        min_coord = strain_gff["start"].min()
        max_coord = strain_gff["end"].max()

        strain_gff_temp = strain_gff.copy()
        strain_gff_temp["start"] = max_coord - (strain_gff["end"] - min_coord)
        strain_gff_temp["end"] = max_coord - (strain_gff["start"] - min_coord)
        strain_gff_temp['strand'] = ['-' if x == '+' else '+' for x in strain_gff_temp.strand]
        strain_gff = strain_gff_temp

    # Track label: genome, isolation country, first four characters of the collection date,
    # and on a second line the phylon with the largest A_binarized entry for this genome.
    track_name = str(genome) + '\t' + str(metadata_by_genome.loc[genome, 'isolation_country']) + '\t' + \
        str(metadata_by_genome.loc[genome, 'collection_date'])[:4] + '\n' + A_binarized.loc[characterized_order, genome].idxmax()
    gd_track_for_features = gd_diagram.new_track(
            1, name=track_name, greytrack=True, start=0, end=max_len, height = .5, greytrack_labels = 1
    )
    gd_feature_set = gd_track_for_features.new_set()

    initial_locus = strain_gff.start.min()
    # Add features to the track
    for color, (index, row) in zip(feature_colors, strain_gff.iterrows()):
        start = int(row['start'] - initial_locus) + offset
        end = int(row['end'] - initial_locus) + offset
        gene_name = row['gene']
        strand = -1 if row.strand == '-' else +1

        # The strand is passed to the location; assigning SeqFeature.strand afterwards is
        # deprecated in recent Biopython releases.
        feature = SeqFeature(
            FeatureLocation(start, end, strand=strand),
            type="gene",
            qualifiers={"gene": gene_name}
        )

        gd_feature_set.add_feature(
            feature,
            sigil="BIGARROW",
            color=color,
            label=False,
            name=gene_name,
            label_position="start",
            label_size=6,
            label_angle=0,
        )


from Bio.Graphics.GenomeDiagram import CrossLink
from reportlab.lib import colors


# Link every pair of same-named features on adjacent tracks. Tracks are addressed by
# level, starting at 1.
for i in range(len(gd_diagram.tracks)-1):
    tracks1 = gd_diagram.tracks[i+1]
    tracks2 = gd_diagram.tracks[i+2]
    features_on_track2 = tracks2.get_sets()[0].get_features()

    for feature1 in tracks1.get_sets()[0].get_features():
        for feature2 in features_on_track2:
            if feature1.name != feature2.name:
                continue

            if feature1.name in vir_gene_stats.index:
                # AMR/virulence genes: light red fill with a white border
                color = colors.Color(red=(255/255),green=(200/255),blue=(200/255))
                color2  = colors.white
            else:
                # all other genes: light grey fill and border
                color = colors.Color(red=(240/255),green=(240/255),blue=(240/255))
                color2 = colors.Color(red=(240/255),green=(240/255),blue=(240/255))

            gd_diagram.cross_track_links.append(
                CrossLink(
                    (tracks1, feature1.location.start, feature1.location.end),
                    (tracks2, feature2.location.start, feature2.location.end),
                    color,
                    color2,
                )
            )


# Render all tracks on one linear page and write the figure as SVG.
gd_diagram.draw(format="linear", pagesize="A4", fragments=1, start=0, end=max_len)
gd_diagram.write(name + ".svg", "SVG")