In [None]:
import logging
import re
import urllib
from io import StringIO
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import gzip
import pickle
from tqdm.notebook import tqdm, trange
import multiprocessing
from IPython.display import display, HTML
import itertools
from statistics import mode

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [None]:
# Root of the project data tree. Relative path, so it resolves against the
# kernel's working directory (normally the notebook's own folder).
PATH_TO_DATA = "../../data"

In [None]:
# Input file locations, all rooted at PATH_TO_DATA.
DF_GENES = PATH_TO_DATA + '/processed/cd-hit-results/sim80/Ebacter_strain_by_gene.pickle.gz'
ENRICHED_METADATA = PATH_TO_DATA + '/metadata/enriched_metadata.csv'
# Directory with one Bakta output folder per genome (note the trailing slash).
GFF_Paths = PATH_TO_DATA + '/processed/bakta/'
DF_EGGNOG = PATH_TO_DATA + '/processed/df_eggnog.csv'
DF_CORE_COMPLETE = PATH_TO_DATA + '/processed/CAR_genomes/df_core_complete.pickle'
L_BINARIZED = PATH_TO_DATA + '/processed/nmf-outputs/L_binarized.csv'

In [None]:
# eggNOG annotation table, indexed by its 'gene' column.
df_eggnog = (
    pd.read_csv(DF_EGGNOG, low_memory=False)
    .set_index('gene')
)

In [None]:
# Genome metadata; every column is read as a string (dtype='object').
metadata = pd.read_csv(ENRICHED_METADATA, index_col=0, dtype='object')
# Keep only genomes whose genome_status is exactly 'Complete'.
complete_metadata = metadata.loc[metadata['genome_status'].eq('Complete')]

In [None]:
# Both tables are pandas pickles. Only load trusted files: unpickling can run code.
df_genes = pd.read_pickle(DF_GENES)
df_core = pd.read_pickle(DF_CORE_COMPLETE)
df_genes

In [None]:
# Binarized NMF L matrix; its row labels are gene IDs (they are used to index df_genes rows).
L_binarized = pd.read_csv(filepath_or_buffer=L_BINARIZED, index_col=0)
L_binarized

In [None]:
# genome_id -> ordered list of gene IDs on that genome's main contig (filled by the GFF loop).
strain_vectors = dict()

In [None]:
# Lookup from CDS header to allele ID; h2a() queries it with Bakta locus tags.
# The file name suggests the 80%-identity CD-HIT clustering. Trusted input only:
# unpickling can execute code.
with gzip.open(f'{PATH_TO_DATA}/processed/cd-hit-results/header_to_allele_80.pickle.gz', mode='rb') as handle:
    header_to_allele = pickle.load(handle)

## Functions to parse GFF

In [None]:
def extract_contig_sizes(gff_file, index=None):
    """
    Read contig lengths from the ``region`` features of GFF file(s).

    Parameters
    ----------
    gff_file : str or list
        Path(s) to GFF file
    index : optional
        Unused; kept so existing calls keep working.

    Returns
    -------
    dict
        Maps each ``region`` row's accession to its length, taken as the
        region's ``end`` coordinate. With several files, a later file
        overwrites a duplicate accession from an earlier one.
    """
    if isinstance(gff_file, str):
        gff_file = [gff_file]

    names = [
        "accession",
        "source",
        "feature",
        "start",
        "end",
        "score",
        "strand",
        "phase",
        "attributes",
    ]

    sizes = {}
    for gff in gff_file:
        with open(gff, "r") as f:
            # Heuristic header skip: (number of lines starting with '#') - 2.
            # Unskipped '#' lines are read as rows without a feature value and
            # are ignored by the 'region' filter below.
            skiprow = max(sum(line.startswith("#") for line in f) - 2, 0)

        DF_gff = pd.read_csv(gff, sep="\t", skiprows=skiprow, names=names, header=None, low_memory=False)

        region = DF_gff[DF_gff.feature == "region"]
        for accession, end in zip(region.accession, region.end):
            sizes[accession] = int(end)

    return sizes

def _get_attr(attributes, attr_id, ignore=False):
    """
    Helper function for parsing GFF annotations

    Parameters
    ----------
    attributes : str
        Attribute string
    attr_id : str
        Attribute ID
    ignore : bool
        If true, ignore errors if ID is not in attributes (default: False)

    Returns
    -------
    str, optional
        Value of attribute
    """

    try:
        return re.search(attr_id + "=(.*?)(;|$)", attributes).group(1)
    except AttributeError:
        if ignore:
            return None
        else:
            raise ValueError("{} not in attributes: {}".format(attr_id, attributes))

def gff2pandas(gff_file, feature=("CDS",), index=None):
    """
    Converts GFF file(s) to a Pandas DataFrame

    Parameters
    ----------
    gff_file : str or list
        Path(s) to GFF file
    feature : str or list-like
        Name(s) of features to keep (default = "CDS")
    index : str, optional
        Column to use as index (e.g. "locus_tag"). Rows with a duplicated
        value are dropped (first kept) with a warning. The column is also kept
        as a regular column.

    Returns
    -------
    df_gff : ~pandas.DataFrame
        Kept features with columns accession, start, end, locus_tag, sorted
        by start within each file.
    region_len : int
        ``end`` of the first ``region`` feature in the last file read.
    oric : int
        Origin-of-replication position; currently always 0.

    Raises
    ------
    ValueError
        If no file is given, a file has no ``region`` feature, or a kept
        feature lacks a ``locus_tag`` attribute.
    """

    # Argument checking
    if isinstance(gff_file, str):
        gff_file = [gff_file]
    gff_file = list(gff_file)
    if not gff_file:
        raise ValueError("gff2pandas needs at least one GFF file")

    if isinstance(feature, str):
        feature = [feature]

    names = [
        "accession",
        "source",
        "feature",
        "start",
        "end",
        "score",
        "strand",
        "phase",
        "attributes",
    ]

    # oriC lookup is disabled, so 0 is always reported.
    oric = 0

    result = []
    for gff in gff_file:
        with open(gff, "r") as f:
            # Heuristic header skip: (number of lines starting with '#') - 2.
            # Unskipped '#' lines are read as rows without a feature value and
            # are ignored by the feature filters below.
            skiprow = max(sum(line.startswith("#") for line in f) - 2, 0)

        DF_gff = pd.read_csv(gff, sep="\t", skiprows=skiprow, names=names, header=None, low_memory=False)

        region = DF_gff[DF_gff.feature == "region"]
        if region.empty:
            raise ValueError("No 'region' feature found in {}".format(gff))
        region_len = int(region.iloc[0].end)

        # Keep requested features sorted by start; .copy() so the new column
        # is written to an independent frame, not a view of DF_gff.
        DF_cds = DF_gff[DF_gff.feature.isin(feature)].sort_values("start").copy()
        DF_cds["locus_tag"] = DF_cds.attributes.apply(_get_attr, attr_id="locus_tag")

        result.append(DF_cds)

    DF_gff = pd.concat(result)

    if index:
        if DF_gff[index].duplicated().any():
            logging.warning("Duplicate {} detected. Dropping duplicates.".format(index))
            DF_gff = DF_gff.drop_duplicates(index)
        # Use the requested column and keep it, so the column selection below
        # (which includes locus_tag) still works.
        DF_gff = DF_gff.set_index(index, drop=False)

    return DF_gff[["accession", "start", "end", "locus_tag"]], region_len, oric

## Build ordered gene lists for each complete genome's main contig (the contig with the most CDSs, treated as the chromosome)

In [None]:
def h2a(x):
    """
    Map a CDS header (Bakta locus tag) to its gene ID via header_to_allele.

    Parameters
    ----------
    x : str
        Header / locus tag to look up.

    Returns
    -------
    str or None
        Text before the first 'A' of the allele ID, or None when the header
        has no entry in header_to_allele.
    """
    # NOTE(review): assumes allele IDs look like '<gene>A<n>' and gene IDs
    # never contain an uppercase 'A' -- TODO confirm the naming scheme.
    try:
        allele = header_to_allele[x]
    except KeyError:
        # Unmapped header. Any other error is a real problem and propagates.
        return None
    return allele.split('A')[0]


# For each complete genome, record its main-contig gene IDs in start order.
for strain in tqdm(complete_metadata.genome_id):
    DF_gff, size, oric = gff2pandas(f'{GFF_Paths}{strain}/{strain}.gff3')
    # The accession carrying the most CDSs is presumably the chromosome.
    main_accession = DF_gff.accession.value_counts().idxmax()
    DF_gff = DF_gff[DF_gff.accession == main_accession].assign(
        gene=lambda df: df.locus_tag.map(h2a)
    )
    gene_order = DF_gff.sort_values('start').gene.to_list()

    strain_vectors[strain] = gene_order

In [None]:
# For every accessory gene (row of L_binarized): the fraction of complete
# genomes carrying it in which it sits on the genome's main contig.
# NOTE(review): assumes df_genes holds 0/1 presence values, and that
# df_core.columns and strain_vectors cover the same genomes -- TODO confirm.
gene_totals = df_genes.loc[L_binarized.index, df_core.columns].sum(axis=1)

# Count each gene at most once per genome. One pass over the genomes replaces
# a nested loop that ran a linear `in` scan of every genome's list for every gene.
chrom_hits = pd.Series(
    [gene for genes in strain_vectors.values() for gene in set(genes)]
).value_counts()

# Genes in no complete genome give 0/0 = NaN, as before.
gene_chrom_counts = (
    chrom_hits.reindex(L_binarized.index, fill_value=0) / gene_totals
).to_dict()

In [None]:
# 'chrom' when the gene is on the main contig in more than half of the complete
# genomes carrying it, otherwise 'plasmid'. Genes carried by no complete genome
# have chrom_presence NaN (0/0). They are dropped, as in the all-genes table,
# rather than being mislabelled 'plasmid'.
gene_locs = (
    pd.DataFrame.from_dict(gene_chrom_counts, orient='index', columns=['chrom_presence'])
    .dropna()
)
gene_locs = gene_locs.assign(
    location=np.where(gene_locs['chrom_presence'] > .5, 'chrom', 'plasmid')
)
gene_locs

In [None]:
# Written relative to the kernel's working directory, not under PATH_TO_DATA.
gene_locs.to_csv(path_or_buf='acc_gene_location.csv')

## Chromosomal vs. plasmid location of all genes

In [None]:
# Same statistic as for the accessory genes, now for every gene in df_genes:
# the fraction of complete genomes carrying the gene that have it on the main contig.
# NOTE(review): assumes df_genes holds 0/1 presence values, and that
# df_core.columns and strain_vectors cover the same genomes -- TODO confirm.
gene_totals = df_genes.loc[:, df_core.columns].sum(axis=1)

# Count each gene at most once per genome, in a single pass, instead of a
# linear `in` scan of every genome's gene list for every gene.
chrom_hits = pd.Series(
    [gene for genes in strain_vectors.values() for gene in set(genes)]
).value_counts()

# Genes in no complete genome give 0/0 = NaN (dropped in the next cell).
gene_chrom_counts = (
    chrom_hits.reindex(df_genes.index, fill_value=0) / gene_totals
).to_dict()

In [None]:
# Drop genes with no defined fraction (NaN), then label each remaining gene
# 'chrom' if the fraction exceeds 0.5, else 'plasmid'.
gene_locs = (
    pd.DataFrame.from_dict(gene_chrom_counts, orient='index', columns=['chrom_presence'])
    .dropna()
)
gene_locs['location'] = [
    'chrom' if frac > .5 else 'plasmid' for frac in gene_locs['chrom_presence']
]
gene_locs

In [None]:
# Written relative to the kernel's working directory, not under PATH_TO_DATA.
gene_locs.to_csv(path_or_buf='complete_gene_location.csv')

In [None]:
# Draw L_binarized (rows: gene_order, columns: phylon_order) twice on the same
# axes: cells where highlight_mask is True use a white->black colormap, the
# rest use white->red. The legend calls black 'Chromosomal' and red 'Plasmid'.
# NOTE(review): this cell cannot run on a fresh kernel as written:
#   * `sns` (seaborn) is not imported in any visible cell;
#   * `phylon_order`, `clr` and `highlight_mask` are not defined in any visible cell;
#   * unless redefined elsewhere, `gene_order` is whatever the GFF loop left behind
#     (the gene list of the last genome processed, which may contain None).
#     TODO confirm the intended row order.
from scipy.cluster.hierarchy import linkage, leaves_list
from matplotlib.patches import Patch


# Main sorted clustermap
# The all-white colormap makes this first layer invisible. It only supplies the
# figure layout, the column-colour strip and the heatmap axes reused below.
custom_cmap = matplotlib.colors.LinearSegmentedColormap.from_list('custom_cmap', ['white', 'white'])

# NOTE(review): col_cluster is left at seaborn's default, so columns (and the
# col_colors strip) may be reordered by clustering, while the overlays below draw
# columns in phylon_order. The strip may then not line up with the overlay
# columns; consider col_cluster=False. TODO confirm.
# col_colors assumes clr's value order matches phylon_order -- TODO confirm.
g = sns.clustermap(
    L_binarized.loc[gene_order, phylon_order],
    method='ward',
    metric='euclidean',
    row_cluster=False,
    yticklabels=False,
    cmap=custom_cmap,
    col_colors=list(clr.values()),
    cbar_pos=None,
    figsize=(10,12)
);

# Perform hierarchical clustering
# NOTE(review): linkage_matrix and row_order are computed but never used below.
linkage_matrix = linkage(L_binarized.loc[gene_order], method='ward', metric='euclidean')

# Get the order of rows based on the clustering
row_order = leaves_list(linkage_matrix)

# Data for the overlays: rows in gene_order, columns in phylon_order
# (the clustering order row_order is NOT applied here).
data_ordered = L_binarized.loc[gene_order, phylon_order]

# Masks for the two overlays. sns.heatmap hides cells where the mask is True, so
# mask1 leaves only highlight_mask==True cells visible and mask2 only the others.
mask1 = ~highlight_mask.loc[gene_order, phylon_order]
mask2 = highlight_mask.loc[gene_order, phylon_order]

# Reuse the clustermap's heatmap axes for the overlays (no new figure is created).
# fig, ax = plt.subplots(figsize=(10, 10))
ax = g.ax_heatmap

custom_cmap = matplotlib.colors.LinearSegmentedColormap.from_list('custom_cmap', ['white', 'black'])
# First overlay: highlight_mask==True cells, white->black (legend: Chromosomal)
sns.heatmap(data_ordered, ax=ax, cmap=custom_cmap, mask=mask1, cbar=False, yticklabels=False)

custom_cmap = matplotlib.colors.LinearSegmentedColormap.from_list('custom_cmap', ['white', 'red'])
# Second overlay: remaining cells, white->red (legend: Plasmid)
sns.heatmap(data_ordered, ax=ax, cmap=custom_cmap, mask=mask2, cbar=False, yticklabels=False)

# # Add a colorbar for each heatmap
# norm1 = plt.Normalize(vmin=data_ordered.min().min(), vmax=data_ordered.max().max())
# sm1 = plt.cm.ScalarMappable(cmap='Greys', norm=norm1)
# sm1.set_array([])
# cbar1 = fig.colorbar(sm1, ax=ax, orientation='vertical', fraction=0.05, pad=0.02)
# cbar1.set_label('Non-Highlighted Data')

# sm2 = plt.cm.ScalarMappable(cmap='Blues', norm=norm1)  # Use same norm for consistency
# sm2.set_array([])
# cbar2 = fig.colorbar(sm2, ax=ax, orientation='vertical', fraction=0.05, pad=0.02)
# cbar2.set_label('Highlighted Data')

# Manual legend matching the two overlay colours.
legend_elements = [
    Patch(facecolor='black', edgecolor='black', label='Chromosomal'),
    Patch(facecolor='red', edgecolor='red', label='Plasmid')
]
ax.legend(handles=legend_elements, loc='upper right')


plt.show()