In [1]:
import os
import numpy as np
import pandas as pd
from pyfaidx import Fasta
import logomaker

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import matplotlib.patches as mpatches
# Global matplotlib settings applied to every figure created in this notebook.
# - pdf/ps fonttype 42 controls how fonts are embedded in saved PDF/PS output.
# - Text is always drawn with a sans-serif family, and Arial is the sans-serif font.
#   If Arial is reported missing, see:
#   https://alexanderlabwhoi.github.io/post/2021-03-missingfont/
# - Small base font with 1pt axes lines and 5pt-long, 1pt-wide major ticks.
matplotlib.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.sans-serif': "Arial",
    'font.family': "sans-serif",
    'font.size': 6,
    'axes.linewidth': 1,
    'xtick.major.width': 1,
    'xtick.major.size': 5,
    'ytick.major.width': 1,
    'ytick.major.size': 5,
})
from matplotlib.backends.backend_pdf import PdfPages



In [6]:
class ExampleAnnotator:
    """Pick confidently predicted positive examples and plot their per-base attributions.

    On construction the prediction table, attribution array, feature array and
    location array are loaded, the genome FASTA is opened, and the top
    predictions are split into "single peak" and "multiple peak" examples
    (see ``get_top_preds``). ``get_seq_figures`` then writes one sequence-logo
    page per example into a PDF.
    """

    def __init__(self, pred_file, seq_arr_file, seq_feat_file, loc_arr_file, genome_fasta, seq_attr_min=0.1, neighbor_attr_min=0.1, selected_motif_idx=None):
        """Load all inputs and pre-compute the single/multi peak example lists.

        Args:
            pred_file: headerless CSV read as columns
                ``ypred, ytarget, chrm, start, end``.
            seq_arr_file: ``.npy`` file with per-example attributions; each
                example is summed over its first axis to get one value per
                position (presumably shaped (n_examples, channels, seq_len) —
                TODO confirm).
            seq_feat_file: ``.npy`` file loaded into ``self.feat_arr``
                (not used by the methods of this class).
            loc_arr_file: ``.npy`` file loaded into ``self.loc_arr``
                (not used by the methods of this class).
            genome_fasta: path of the FASTA file opened with pyfaidx.
            seq_attr_min: forwarded to ``get_top_preds``.
            neighbor_attr_min: forwarded to ``get_top_preds``.
            selected_motif_idx: optional collection of example indices plotted
                by ``get_seq_figures(mode="selected")``. Defaults to an empty
                list.
        """
        self.pred_df = pd.read_csv(pred_file, header=None, names=["ypred", "ytarget", "chrm", "start", "end"])
        self.seq_arr = np.load(seq_arr_file)
        self.feat_arr = np.load(seq_feat_file)
        self.loc_arr = np.load(loc_arr_file)
        self.genome = Fasta(genome_fasta, as_raw=True)
        self.single_motif_idx, self.multi_motif_idx = self.get_top_preds(seq_attr_min=seq_attr_min, neighbor_attr_min=neighbor_attr_min)
        # A fresh list per instance; `is None` (not truthiness) so numpy arrays are accepted.
        self.selected_motif_idx = [] if selected_motif_idx is None else selected_motif_idx

    def create_seq_attr_df(self, seq, seq_attr):
        """Build a (len(seq), 4) frame holding each position's attribution under its base.

        Args:
            seq: iterable of single-character bases.
            seq_attr: iterable of per-position attribution values, paired with
                ``seq`` position by position (extra items on either side are
                ignored by ``zip``).

        Returns:
            DataFrame with columns ``A, T, G, C``. Row ``i`` is zero everywhere
            except the column of base ``seq[i]``, which holds ``seq_attr[i]``.
            Bases are matched case-insensitively; any other symbol (e.g. ``N``)
            leaves its row at zero.
        """
        data_arr = np.zeros((len(seq), 4))
        nts = ["A", "T", "G", "C"]
        seq_idx = dict(zip(nts, range(len(nts))))
        for i, (s, sq) in enumerate(zip(seq, seq_attr)):
            # Soft-masked (lowercase) bases map to their uppercase column;
            # symbols outside A/T/G/C have no column and stay all-zero.
            col = seq_idx.get(str(s).upper())
            if col is not None:
                data_arr[i, col] = sq
        data_df = pd.DataFrame(data_arr, columns=nts)
        return data_df

    def get_complement(self, seq):
        """Return the base-wise complement of ``seq`` as a string.

        The order of bases is NOT reversed, so the result stays position-aligned
        with the input. Upper- and lowercase A/C/G/T/N are supported and keep
        their case.

        Raises:
            KeyError: if ``seq`` contains any other symbol.
        """
        complement_dict = {
            "A": "T",
            "G": "C",
            "T": "A",
            "C": "G",
            "N": "N",
            "a": "t",
            "g": "c",
            "t": "a",
            "c": "g",
            "n": "n",
            }
        complement_seq = "".join([complement_dict[s] for s in seq])
        return complement_seq

    def create_seqattr_figure(self, seq, seq_attr, title="", nrows=5, figsize=(15,10)):
        """Draw the attribution logo of one sequence, wrapped over ``nrows`` stacked axes.

        Args:
            seq: sequence of single-character bases.
            seq_attr: per-position attribution values (same order as ``seq``).
            title: figure super-title.
            nrows: number of stacked rows the sequence is split into.
            figsize: matplotlib figure size.

        Returns:
            ``(fig, ax)`` where ``ax`` is a 1-D array of ``nrows`` axes.

        Raises:
            ValueError: if ``nrows`` is smaller than 1.
        """
        if nrows < 1:
            raise ValueError(f"nrows must be a positive integer, got {nrows}")
        seq_df = self.create_seq_attr_df(seq, seq_attr)
        # Ceiling division: floor division would silently drop the trailing
        # len(seq) % nrows bases when the length is not a multiple of nrows.
        bp_per_row = -(-len(seq_df) // nrows)
        # squeeze=False keeps an array even for nrows == 1, so ax[row] always works.
        fig, ax = plt.subplots(nrows, 1, figsize=figsize, sharey=True, squeeze=False)
        ax = ax.ravel()
        y_min, y_max = np.min(seq_attr), np.max(seq_attr)
        for row in range(nrows):
            start = row * bp_per_row
            row_df = seq_df.iloc[start:start + bp_per_row]
            if len(row_df) > 0:
                # create Logo object
                logomaker.Logo(row_df, ax=ax[row],)
                if len(row_df) < bp_per_row:
                    # Short last row: keep the same horizontal scale as full rows.
                    ax[row].set_xlim(start - 0.5, start + bp_per_row - 0.5)
            ax[row].set_xticks([])
            ax[row].spines[['right', 'top', "bottom"]].set_visible(False)
            ax[row].xaxis.set_visible(False)
            ax[row].set_ylim(y_min, y_max)
        fig.suptitle(title)
        return fig, ax

    def _get_seq_attr(self, idx):
        """Per-position attribution of example ``idx``: its array summed over the first axis."""
        return np.sum(self.seq_arr[idx], axis=0)

    def get_seq_info(self, idx, reverse=False):
        """Fetch the sequence, attribution vector and coordinates of example ``idx``.

        Args:
            idx: row position in ``self.pred_df`` / first-axis index of
                ``self.seq_arr``.
            reverse: when True the sequence is replaced by its base-wise
                complement (a string, not reversed); the attribution vector is
                left unchanged.

        Returns:
            ``(seq, seq_attr, (chrm, start, end))``. ``seq`` is a list of
            characters, or a string when ``reverse`` is True.
        """
        chrm, start, end = self.pred_df.iloc[idx].loc[["chrm", "start", "end"]].values
        # NOTE(review): start/end are handed to pyfaidx unchanged — confirm the table
        # uses the coordinate convention that Fasta.get_seq expects.
        seq = list(self.genome.get_seq(chrm, start, end))
        if reverse:
            seq = self.get_complement(seq)
        seq_attr = self._get_seq_attr(idx)
        return seq, seq_attr, (chrm, start, end)

    def get_top_preds(self, seq_attr_min=0.1, neighbor_attr_min=0.1, pred_min=0.9, peak_window=20):
        """Split confident positive examples into single-peak and multi-peak lists.

        Candidates are rows with ``ytarget == 1`` and ``ypred > pred_min``,
        visited in order of decreasing ``ypred``. A candidate is kept only if
        its maximum attribution exceeds ``seq_attr_min``. It counts as a single
        peak when every position with attribution above ``neighbor_attr_min``
        lies strictly within ``peak_window`` positions of the maximum;
        otherwise it is treated as having multiple peaks.

        Only the attribution array is read here; the genome is not accessed.

        Args:
            seq_attr_min: minimum peak attribution for an example to be kept.
            neighbor_attr_min: attribution above which a position is "important".
            pred_min: prediction score a candidate must exceed.
            peak_window: half-width (exclusive) of the window around the maximum.

        Returns:
            ``(single_peaks, multiple_peaks)``: two lists of example indices.
        """
        top_pred_idxs = self.pred_df.loc[
            (self.pred_df.ytarget==1)&(self.pred_df.ypred>pred_min)
            ].sort_values("ypred", ascending=False).index
        single_peaks = []
        multiple_peaks = []
        for tpi in top_pred_idxs:
            seq_attr = self._get_seq_attr(tpi)
            if np.max(seq_attr)>seq_attr_min:
                highest_point = np.argmax(seq_attr)
                lower_threshold = highest_point - peak_window
                higher_threshold = highest_point + peak_window
                important_indices = np.argwhere(seq_attr>neighbor_attr_min).flatten()
                if np.all((important_indices>lower_threshold) & (important_indices<higher_threshold)):
                    # there is only one peak
                    single_peaks.append(tpi)
                else:
                    # there might be multiple peaks
                    multiple_peaks.append(tpi)
        return single_peaks, multiple_peaks

    def get_seq_figures(self, save_file, mode="multiple"):
        """Write one attribution-logo page per example into the PDF ``save_file``.

        Args:
            save_file: output PDF path; its parent directory is created if the
                path has one.
            mode: which example list to plot — ``"multiple"``
                (``self.multi_motif_idx``), ``"single"``
                (``self.single_motif_idx``) or ``"selected"``
                (``self.selected_motif_idx``).

        Raises:
            ValueError: if ``mode`` is not one of the three names. Raised
                before any directory or file is created.
        """
        mode_to_attr = {
            "multiple": "multi_motif_idx",
            "single": "single_motif_idx",
            "selected": "selected_motif_idx",
        }
        if mode not in mode_to_attr:
            raise ValueError(f"Mode: {mode} is incorrect, must be one of single or multiple or selected")
        # dirname is "" for a bare file name, and os.makedirs("") raises.
        save_dir = os.path.dirname(save_file)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with PdfPages(save_file) as pdf:
            for idx in getattr(self, mode_to_attr[mode]):
                s, sa, (chrm, start, end) = self.get_seq_info(idx, reverse=False)
                fig, ax = self.create_seqattr_figure(s, sa, title=f"{chrm}: {start}-{end}", nrows=5, figsize=(10,3))
                pdf.savefig(fig, bbox_inches='tight',dpi=100)
                plt.close(fig)  # Close to save memory
        return


In [None]:

# Libraries to process and the genome FASTA their sequences are fetched from.
libs = ["CC", "ATF2", "CTCF", "FOXA1", "LEF1", "SCRT1", "TCF7L2", "16P12_1"]
genome_fasta = "/data5/deepro/genomes/hg38/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta"

for lib in libs:
    print(lib)

    # Input files of this library.
    nlinear_pred_file = f"/data7/deepro/starrseq/4_ml_classification_fragment_category/data/{lib}/resnet_mlp.csv.gz"
    nlinear_seq_arr_file = f"/data7/deepro/starrseq/4_ml_classification_fragment_category/data/{lib}/seq_attr.npy"
    nlinear_seq_feat_file = f"/data7/deepro/starrseq/4_ml_classification_fragment_category/data/{lib}/seq_feat.npy"
    nlinear_loc_arr_file = f"/data7/deepro/starrseq/4_ml_classification_fragment_category/data/{lib}/locations.npy"
    compare_file = f"/data7/deepro/starrseq/4_ml_classification_fragment_category/data/{lib}/cov_conf_compare.csv.gz"

    # Headerless prediction table (integer column labels); pred_idx records each row's index.
    pred_df = pd.read_csv(nlinear_pred_file, header=None).assign(pred_idx=lambda frame: frame.index)
    compare_df = pd.read_csv(compare_file)

    # Keep target==1 rows in coverage decile 9 where the resnet score is above 0.9
    # but the linear score is below 0.9, then look up their rows in the prediction
    # table by (chrm, start, end) == prediction columns (2, 3, 4).
    target_in_decile = (compare_df.ytarget == 1) & (compare_df.cov_decile == 9)
    resnet_not_linear = (compare_df.ypred_resnet > 0.9) & (compare_df.ypred_linear < 0.9)
    hcov_hconf_df = compare_df.loc[target_in_decile & resnet_not_linear].merge(
        pred_df,
        left_on=["chrm", "start", "end"],
        right_on=[2, 3, 4],
    )
    selected_index = hcov_hconf_df.pred_idx.values

    ea = ExampleAnnotator(
        pred_file=nlinear_pred_file,
        seq_arr_file=nlinear_seq_arr_file,
        seq_feat_file=nlinear_seq_feat_file,
        loc_arr_file=nlinear_loc_arr_file,
        genome_fasta=genome_fasta,
        selected_motif_idx=selected_index,
    )
    print(len(ea.single_motif_idx), len(ea.multi_motif_idx), len(ea.selected_motif_idx))

    # takes 268 mins to run
    # One PDF per example list of this library.
    for mode in ("single", "multiple", "selected"):
        save_file = f"/data7/deepro/starrseq/4_ml_classification_fragment_category/data/{lib}/motif_{mode}.pdf"
        ea.get_seq_figures(save_file, mode=mode)
