In [None]:
import os
import sys
import copy

import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
import logomaker
import colorsys
import matplotlib.colors as mcolors

sys.path.append("/home/akubaney/projects/na_mpnn/evaluation")
from na_eval_utils import read_json_file

In [None]:
font_path = "./ARIAL.TTF"

# Register the bundled TrueType file with Matplotlib's font manager, then make
# "Arial" the default font family for every figure in this notebook
# (assumes the file's internal family name is "Arial").
matplotlib.font_manager.fontManager.addfont(font_path)
plt.rcParams.update({'font.family': "Arial"})

In [None]:
# Per-letter colors used by the DNA logos: A green, C blue, G orange, T red.
DEFAULT_DNA_COLORS = dict(
    A='#00FF00',
    C='#0000FF',
    G='#FFA500',
    T='#FF0000',
)

# Baseline appearance for plot_seq_logo; callers override individual keys by
# passing a partial dict as `style`.
DEFAULT_SEQLOGO_STYLE = dict(
    # --- Figure ---
    figsize=(40 / 25.4, 40 / 25.4),  # 40 mm x 40 mm, converted to inches
    dpi=300,
    # --- Font sizes (None -> Matplotlib default) ---
    title_fontsize=None,
    axis_title_fontsize=None,
    tick_labelsize=6,
    letter_fontsize=12,
    # --- Letters ---
    letter_color_dict=DEFAULT_DNA_COLORS,
    # --- Title and axis labels ---
    show_title=False,
    title=None,
    x_label=None,  # e.g. 'Position'
    y_label=None,  # e.g. 'Information Content (bits)'
    show_axis_labels=True,
    show_ticks=True,
    background_color='white',
    # --- Logo geometry ---
    logo_type='information',      # 'information' or 'probability'
    stack_order='big_on_top',     # 'big_on_top' or 'small_on_top'
    vpad=0,                       # vertical padding between stacked letters
    width=1,                      # width of each position
    y_max=2,                      # upper y-limit of the logo
    one_index_positions=True,     # number x ticks from 1 instead of 0
    pad_left=0,                   # uniform columns added on the left
    pad_right=0,                  # uniform columns added on the right
    x_tick_labels_with_true_seq=True,  # A/C/G/T tick labels when a true sequence is given
    # --- Saving ---
    save_prefix=None,             # prefix for saved figure files
    save_name=None,
    # --- Axis styling shared with the other figures ---
    spine_linewidth=0.5,          # spine thickness
    tick_width=0.5,               # tick mark width
    hide_top_right_spines=True,   # hide the top and right spines
    desaturate=0.7,               # 0 = unchanged colors, 1 = fully grey
)

In [None]:
# NA-MPNN token indices for the four DNA residue types.
DNA_RESTYPE_TO_INT = dict(DA=21, DC=22, DG=23, DT=24)

# Inverse lookup: token index -> residue name.
DNA_INT_TO_RESTYPE = dict(zip(DNA_RESTYPE_TO_INT.values(), DNA_RESTYPE_TO_INT))

# Token indices in A, C, G, T order; used to select the DNA columns of a
# full NA-MPNN-format PPM.
DNA_INT_LIST = [DNA_RESTYPE_TO_INT[name] for name in ("DA", "DC", "DG", "DT")]

# Token index -> column index (0..3) after that column selection.
DNA_INT_TO_REMAPPED_INT = {token: column for column, token in enumerate(DNA_INT_LIST)}

In [None]:
def load_predicted_pwm_and_true_sequence(json_path, num_chains_to_plot):
    """
    Load the aligned reference PPM, predicted PPM and true sequence for one
    structure, restricted to the DNA positions that can be compared.

    The score JSON at `json_path` provides "aligned_ppm", "ppm_mask" and
    "subject_path". The JSON at "subject_path" provides the predicted PPM,
    "mask", "dna_mask", "chain_labels" and the true sequence.

    A position is kept when:
    - its reference mask, prediction mask and DNA mask are all 1, and
    - its chain label is among the first `num_chains_to_plot` labels returned
      by np.unique for those positions. np.unique sorts, so these are the
      smallest label values, not necessarily the first chains encountered.

    Args:
        json_path (str): Path to the score JSON file.
        num_chains_to_plot (int or None): Number of DNA chains to keep;
            None keeps every DNA chain.
    Returns:
        tuple: (reference_aligned_ppm, subject_predicted_ppm,
            subject_true_sequence). Both PPMs keep only the columns in
            DNA_INT_LIST (DA, DC, DG, DT order); the true sequence stays in
            NA-MPNN int encoding.
    """
    score = read_json_file(json_path)
    prediction = read_json_file(score["subject_path"])

    def _as_array(source, key, dtype):
        return np.array(source[key], dtype = dtype)

    reference_ppm = _as_array(score, "aligned_ppm", np.float64)
    reference_mask = _as_array(score, "ppm_mask", np.int32)

    predicted_ppm = _as_array(
        prediction, "predicted_ppm_na_mpnn_format", np.float64
    )
    prediction_mask = _as_array(prediction, "mask", np.int32)
    dna_mask = _as_array(prediction, "dna_mask", np.int32)
    true_sequence = _as_array(
        prediction, "true_sequence_na_mpnn_format", np.int32
    )

    # Positions present in both PPMs and belonging to DNA.
    comparable = (reference_mask == 1) & (prediction_mask == 1) & (dna_mask == 1)

    chain_labels = _as_array(prediction, "chain_labels", np.int32)
    # Slicing with a None stop keeps every label.
    selected_chains = np.unique(chain_labels[comparable])[:num_chains_to_plot]
    keep = comparable & np.isin(chain_labels, selected_chains)

    return (
        reference_ppm[keep][:, DNA_INT_LIST],
        predicted_ppm[keep][:, DNA_INT_LIST],
        true_sequence[keep],
    )

In [None]:
def _pad_ppm_and_sequence(ppm, true_sequence, pad_left, pad_right):
    """
    Add uniform columns (rows filled with 1 / num_bases) to either end of a PPM.

    Args:
        ppm (array-like): Position probability matrix with shape
            (positions, num_bases).
        true_sequence (array-like or None): NA-MPNN int tokens aligned to the
            rows of `ppm`.
        pad_left (int): Number of uniform rows to prepend (negatives -> 0).
        pad_right (int): Number of uniform rows to append (negatives -> 0).
    Returns:
        tuple: (ppm, true_sequence).
            - Nothing is changed when both pads are zero.
            - Otherwise the sequence is padded only when its length differs
              from the padded PPM.
            - A sequence that already matches the padded length is returned
              unchanged, e.g. one taken from a longer PPM that the padding is
              meant to line up with.
    """
    ppm = np.asarray(ppm)
    pad_left, pad_right = max(0, pad_left), max(0, pad_right)
    if pad_left == 0 and pad_right == 0:
        return ppm, true_sequence

    num_bases = ppm.shape[1]
    uniform = 1.0 / num_bases
    padded_ppm = np.vstack((
        np.full((pad_left, num_bases), uniform),
        ppm,
        np.full((pad_right, num_bases), uniform),
    ))

    if true_sequence is not None and len(true_sequence) != len(padded_ppm):
        # -1 is not a key of DNA_INT_TO_RESTYPE, so padded positions render
        # as blank tick labels. Filling an int array with NaN would rely on
        # an invalid NaN -> int cast.
        true_sequence = np.concatenate((
            np.full(pad_left, -1, dtype = np.int64),
            np.asarray(true_sequence),
            np.full(pad_right, -1, dtype = np.int64),
        ))

    return padded_ppm, true_sequence


def _desaturate_hex(color, factor):
    """
    Scale a color's HLS saturation by (1 - factor) and return it as a hex
    string. Colors Matplotlib cannot parse are returned unchanged.
    """
    try:
        rgb = mcolors.to_rgb(color)
    except (ValueError, TypeError):
        return color
    hue, lightness, saturation = colorsys.rgb_to_hls(*rgb)
    saturation = max(0.0, saturation * (1.0 - float(factor)))
    return mcolors.to_hex(colorsys.hls_to_rgb(hue, lightness, saturation))


def plot_seq_logo(ppm, style = None, true_sequence=None):
    """
    Create a sequence logo from a position probability matrix (PPM).

    Args:
        ppm (numpy.ndarray): Position probability matrix with shape
            (positions, 4); columns are A, C, G, T in that order. Rows are
            renormalized to sum to 1.
        style (dict, optional): Overrides for DEFAULT_SEQLOGO_STYLE. Keys
            read here include:
            - 'pad_left' / 'pad_right': uniform columns added before plotting.
            - 'logo_type': 'information' plots bits; 'probability' plots raw
              probabilities, and you will usually also want 'y_max': 1.
            - 'desaturate', 'letter_color_dict', 'y_max', title/label/tick
              options, spine and tick widths.
            - 'save_name': if set, the figure is saved to this path.
        true_sequence (array-like, optional): Sequence in NA-MPNN int encoding
            aligned to ppm rows. It is used for A/C/G/T x tick labels when
            'x_tick_labels_with_true_seq' is enabled; tokens not in
            DNA_INT_TO_RESTYPE (e.g. padding) become blank labels.
    Returns:
        matplotlib.figure.Figure: The generated sequence logo figure.
    Raises:
        ValueError: If 'logo_type' is neither 'information' nor 'probability'.
    """
    merged_style = copy.deepcopy(DEFAULT_SEQLOGO_STYLE)
    if style is not None:
        merged_style.update(style)

    # Optional uniform padding on either side (keeps true_sequence aligned).
    ppm, true_sequence = _pad_ppm_and_sequence(
        ppm,
        true_sequence,
        int(merged_style.get('pad_left', 0)),
        int(merged_style.get('pad_right', 0)),
    )

    # Renormalize rows; the stored PPMs can drift from 1 by float error.
    ppm = np.asarray(ppm)
    ppm = ppm / np.sum(ppm, axis = -1, keepdims = True)

    # Build the logomaker matrix in the requested units.
    df = pd.DataFrame(ppm, columns = ['A', 'C', 'G', 'T'])
    logo_type = merged_style.get('logo_type', 'information')
    if logo_type == 'information':
        df = logomaker.transform_matrix(
            df,
            from_type = 'probability',
            to_type = 'information',
            normalize_values = False
        )
    elif logo_type != 'probability':
        raise ValueError(
            f"logo_type must be 'information' or 'probability', got {logo_type!r}"
        )

    # Letter colors, optionally desaturated.
    desat_factor = float(merged_style.get('desaturate', 0.0) or 0.0)
    base_colors = (
        merged_style.get('letter_color_dict', DEFAULT_DNA_COLORS)
        or DEFAULT_DNA_COLORS
    )
    if desat_factor > 1e-12:
        letter_color_dict = {
            letter: _desaturate_hex(color, desat_factor)
            for letter, color in base_colors.items()
        }
    else:
        letter_color_dict = base_colors

    # Figure with a single, manually placed axes.
    fig = plt.figure(
        figsize = merged_style['figsize'], dpi = merged_style['dpi'],
        constrained_layout = True
    )
    ax = fig.add_axes([0.15, 0.15, 0.8, 0.825])

    logo = logomaker.Logo(
        df,
        ax = ax,
        color_scheme = letter_color_dict,
        stack_order = merged_style['stack_order'],
        vpad = merged_style['vpad'],
        width = merged_style['width']
    )
    logo.style_glyphs(
        color_scheme = letter_color_dict,
        fontsize = merged_style['letter_fontsize']
    )

    if merged_style['y_max'] is not None:
        ax.set_ylim(0, merged_style['y_max'])

    if merged_style['show_title'] and merged_style['title']:
        ax.set_title(
            merged_style['title'], fontsize = merged_style['title_fontsize']
        )

    # Axis titles.
    if merged_style['show_axis_labels']:
        if merged_style['x_label']:
            ax.set_xlabel(
                merged_style['x_label'],
                fontsize = merged_style['axis_title_fontsize']
            )
        if merged_style['y_label']:
            ax.set_ylabel(
                merged_style['y_label'],
                fontsize = merged_style['axis_title_fontsize']
            )
    else:
        ax.set_xlabel("")
        ax.set_ylabel("")

    # Ticks: either hidden, true-sequence letters, or position numbers.
    if not merged_style['show_ticks']:
        ax.set_xticks([])
        ax.set_yticks([])
    else:
        num_pos = df.shape[0]
        positions = np.arange(num_pos)
        ax.set_xticks(positions)
        use_true_seq = (
            bool(merged_style.get('x_tick_labels_with_true_seq', False))
            and true_sequence is not None
        )
        if use_true_seq:
            # "DA" -> "A", etc.; unknown tokens become a blank label.
            labels = [
                DNA_INT_TO_RESTYPE.get(res_int, " ")[-1]
                for res_int in true_sequence
            ]
        elif merged_style.get('one_index_positions', True):
            labels = positions + 1
        else:
            labels = positions
        ax.set_xticklabels(labels, fontsize = merged_style['tick_labelsize'])
        ax.tick_params(
            axis = "both", labelsize = merged_style['tick_labelsize']
        )

    ax.set_facecolor(merged_style['background_color'])
    fig.patch.set_facecolor(merged_style['background_color'])

    # Spine/tick widths and optional removal of the top/right spines.
    spine_lw = merged_style.get('spine_linewidth', None)
    tick_w = merged_style.get('tick_width', None)
    if spine_lw is not None:
        for spine in ax.spines.values():
            spine.set_linewidth(spine_lw)
    if tick_w is not None:
        ax.tick_params(width = tick_w)
    if merged_style.get('hide_top_right_spines', False):
        for side in ('top', 'right'):
            if side in ax.spines:
                ax.spines[side].set_visible(False)

    # Save this specific figure (not whatever pyplot considers "current").
    if merged_style['save_name']:
        fig.savefig(
            merged_style['save_name'],
            dpi = merged_style['dpi'],
            pad_inches = 0
        )

    return fig

def plot_seq_logo_comparison(
    id, 
    na_mpnn_num_chains_to_plot = 1, 
    deeppbs_num_chains_to_plot = 1, 
    style = None, 
    show_all_titles = False,
    na_mpnn_pad_left=0,
    na_mpnn_pad_right=0,
    deeppbs_pad_left=0,
    deeppbs_pad_right=0,
    base_folder=None,
):
    """
    Plot four sequence logos for one structure.

    The panels are, in order: NA-MPNN reference, DeepPBS reference, NA-MPNN
    prediction, and DeepPBS prediction.

    Args:
        id (str): Identifier of the structure; used in the score-file path,
            the titles and the saved file names.
        na_mpnn_num_chains_to_plot (int): Number of DNA chains to plot for
            NA-MPNN.
        deeppbs_num_chains_to_plot (int): Number of DNA chains to plot for
            DeepPBS.
        style (dict, optional): Style overrides passed to plot_seq_logo. If
            it contains 'save_prefix', each panel is saved as
            f"{save_prefix}_{id}_<panel>.svg".
        show_all_titles (bool, optional): Whether to show panel titles.
        na_mpnn_pad_left (int, optional): Uniform columns added on the left
            of the NA-MPNN logos.
        na_mpnn_pad_right (int, optional): Uniform columns added on the right
            of the NA-MPNN logos.
        deeppbs_pad_left (int, optional): Uniform columns added on the left
            of the DeepPBS logos.
        deeppbs_pad_right (int, optional): Uniform columns added on the right
            of the DeepPBS logos.
        base_folder (str, optional): Root of the specificity score outputs;
            defaults to the original evaluation_outputs location.
    Returns:
        list: The four matplotlib figures, in the panel order above.
    """
    if base_folder is None:
        base_folder = os.path.join(
            "/",
            "home",
            "akubaney",
            "projects",
            "na_mpnn",
            "evaluation",
            "evaluation_outputs",
            "specificity_test_scores",
        )

    def _load(method, num_chains):
        # Score files live at <base_folder>/<method>/<id>/<id>.json.
        score_path = os.path.join(base_folder, method, id, f"{id}.json")
        return load_predicted_pwm_and_true_sequence(score_path, num_chains)

    ref_ppm_na_mpnn, pred_ppm_na_mpnn, true_seq_na_mpnn = _load(
        "na_mpnn", na_mpnn_num_chains_to_plot
    )
    ref_ppm_deeppbs, pred_ppm_deeppbs, true_seq_deeppbs = _load(
        "deeppbs", deeppbs_num_chains_to_plot
    )

    def _to_restypes(sequence):
        return [DNA_INT_TO_RESTYPE[res_int] for res_int in sequence]

    print(f"NA-MPNN true sequence: {_to_restypes(true_seq_na_mpnn)}")
    print(f"DeepPBS true sequence: {_to_restypes(true_seq_deeppbs)}")

    save_prefix = style.get('save_prefix', None) if style else None
    na_mpnn_pads = {'pad_left': na_mpnn_pad_left, 'pad_right': na_mpnn_pad_right}
    deeppbs_pads = {'pad_left': deeppbs_pad_left, 'pad_right': deeppbs_pad_right}

    # (title, ppm, padding, file-name suffix) for each panel, in output order.
    panels = [
        ("NA-MPNN Reference", ref_ppm_na_mpnn, na_mpnn_pads, "na_mpnn_reference"),
        ("DeepPBS Reference", ref_ppm_deeppbs, deeppbs_pads, "deep_pbs_reference"),
        ("NA-MPNN Prediction", pred_ppm_na_mpnn, na_mpnn_pads, "na_mpnn_prediction"),
        ("DeepPBS Prediction", pred_ppm_deeppbs, deeppbs_pads, "deep_pbs_prediction"),
    ]

    figures = []
    for title, ppm, pads, suffix in panels:
        panel_style = copy.deepcopy(style) if style else {}
        panel_style.update({
            'show_title': show_all_titles,
            'title': f"{title} - {id}",
            **pads,
            'save_name': f"{save_prefix}_{id}_{suffix}.svg" if save_prefix else None,
        })
        print(panel_style["title"])
        # NOTE(review): every panel, including DeepPBS, is labelled with the
        # NA-MPNN true sequence so the x axes share labels. The DeepPBS pads
        # are expected to bring its PPM to the NA-MPNN length; otherwise
        # plot_seq_logo pads the sequence as well and the label count may not
        # match the tick count.
        figures.append(
            plot_seq_logo(ppm, panel_style, true_sequence=true_seq_na_mpnn)
        )

    return figures

# Fig 4a (Data augmentation demonstration)

In [None]:
# Named `structure_id` rather than `id` to avoid shadowing the builtin id().
structure_id = "DDB_G0278225_AAATGCCA"

base_folder = os.path.join(
    "/",
    "home",
    "akubaney",
    "projects",
    "na_mpnn",
    "evaluation",
    "evaluation_outputs",
    "specificity_test_scores",
)

# Load the aligned reference PPM, predicted PPM and true sequence from the
# NA-MPNN score file, keeping only the first DNA chain.
na_mpnn_score_path = os.path.join(
    base_folder, "na_mpnn", structure_id, f"{structure_id}.json"
)
ref_ppm_na_mpnn, pred_ppm_na_mpnn, true_seq_na_mpnn = \
    load_predicted_pwm_and_true_sequence(
        na_mpnn_score_path, num_chains_to_plot = 1
    )

In [None]:
[DNA_INT_TO_RESTYPE[res_int] for res_int in true_seq_na_mpnn]

In [None]:
# Columns follow DNA_INT_LIST order (DA, DC, DG, DT).
ref_ppm_na_mpnn_df = pd.DataFrame(data=ref_ppm_na_mpnn, columns=["DA", "DC", "DG", "DT"])
ref_ppm_na_mpnn_df

In [None]:
# Data augmentation demonstration: the aligned reference PPM, default style.
fig = plot_seq_logo(ref_ppm_na_mpnn, style=None, true_sequence=true_seq_na_mpnn)

# Fig 4b (Best Distillation)

In [None]:
plot_seq_logo_comparison(
    id="SCHCODRAFT_80572_AAAGCCAC",
    style={
        "save_prefix": "/home/akubaney/projects/na_mpnn/figures/matplotlib/best_from_distillation",
    },
)

In [None]:
plot_seq_logo_comparison(
    id="NCU09387_AAACAAAG",
    # Two uniform columns appended on the right of the DeepPBS logos.
    deeppbs_pad_right=2,
    style={
        "save_prefix": "/home/akubaney/projects/na_mpnn/figures/matplotlib/best_from_distillation",
    },
)

In [None]:
plot_seq_logo_comparison(
    id="DDB_G0278225_AAATGCCA",
    style={
        "save_prefix": "/home/akubaney/projects/na_mpnn/figures/matplotlib/best_from_distillation",
    },
)

# Fig 6 (Same Protein)

In [None]:
plot_seq_logo_comparison(
    id="SCHCODRAFT_80572_AAAGCCAC",
    style={
        "save_prefix": "/home/akubaney/projects/na_mpnn/figures/matplotlib/same_protein",
    },
    # One uniform column on the left of both methods' logos.
    na_mpnn_pad_left=1,
    deeppbs_pad_left=1,
)

In [None]:
plot_seq_logo_comparison(
    id="SCHCODRAFT_80572_AACGCCAT",
    style={
        "save_prefix": "/home/akubaney/projects/na_mpnn/figures/matplotlib/same_protein",
    },
    # One uniform column on the right of both methods' logos.
    na_mpnn_pad_right=1,
    deeppbs_pad_right=1,
)

In [None]:
plot_seq_logo_comparison(
    id="SCHCODRAFT_80572_AACGCCAC",
    style={
        "save_prefix": "/home/akubaney/projects/na_mpnn/figures/matplotlib/same_protein",
    },
)

# Fig 7a (Best Crystal)

In [None]:
plot_seq_logo_comparison(
    id="1am9",
    style={
        "save_prefix": "/home/akubaney/projects/na_mpnn/figures/matplotlib/best_from_crystal",
        # 80 mm x 40 mm (twice the default width), applied to all four panels.
        "figsize": (80 / 25.4, 40 / 25.4),
    },
    na_mpnn_num_chains_to_plot=2,
    deeppbs_num_chains_to_plot=1,
)

In [None]:
plot_seq_logo_comparison(
    id="6u81",
    # Four uniform columns prepended to the DeepPBS logos.
    deeppbs_pad_left=4,
    style={
        "save_prefix": "/home/akubaney/projects/na_mpnn/figures/matplotlib/best_from_crystal",
    },
)