In [4]:
import slim_conservation_scoring.seqtools.general_utils as tools
import pandas as pd
import numpy as np
import slim_conservation_scoring.pipeline.group_conservation_objects as group_tools
import pairk
from pathlib import Path
from Bio import AlignIO, Seq, SeqIO, Align
import seaborn as sns
from sklearn import metrics
from slim_conservation_scoring.seqtools import pssms
import logomaker as lm
from ast import literal_eval
import slim_conservation_scoring.conservation_scores.tools.pairwise_tools as pairwise_tools
import slim_conservation_scoring.conservation_scores.tools.score_plots as score_plots
import slim_conservation_scoring.conservation_scores.tools.basic_plotting as basic_plots
from dataclasses import dataclass

import matplotlib.pyplot as plt
plt.style.use('custom_standard')
import seaborn as sns

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
def convert_jsonfile_to_relative(
    json_file,
    old_prefix="/home/jch/Documents/08-benchmark/",
    new_prefix="../../",
):
    """Rewrite an absolute benchmark path so it resolves relative to this notebook.

    Paths recorded in the benchmark table (and handed to ``filepath_converter``)
    start with the original machine's ``old_prefix``. Substituting ``new_prefix``
    lets the notebook locate the same files from its own directory.

    Parameters
    ----------
    json_file : str or os.PathLike
        Path to rewrite. ``Path`` objects are converted to ``str`` first.
    old_prefix : str
        Substring to replace. The default preserves the original behavior.
    new_prefix : str
        Replacement for each occurrence of ``old_prefix``.

    Returns
    -------
    str
        The rewritten path. Equal to the input string if ``old_prefix`` does not occur.
    """
    import os

    # os.fspath turns Path objects into str. Calling `.replace(a, b)` directly on
    # a Path would hit Path.replace (a file rename) and raise TypeError.
    return os.fspath(json_file).replace(old_prefix, new_prefix)

In [5]:
# Per-motif position weights: one entry per residue of the motif match (the
# trailing comment on each line is that motif's regex). Nonzero entries mark the
# positions to weight, with relative magnitudes (e.g. enah's first position is 2).
position_weights = {
    "DOC_WW_Pin1_4": np.array([0, 0, 0, 1, 1, 0]),  # ...([ST])P.
    "LIG_AP2alpha_2": np.array([1, 1, 1]),  # DP[FW]
    "LIG_EH_1": np.array([0, 1, 1, 1, 0]),  # .NPF.
    "LIG_SH2_GRB2like": np.array([1, 1, 1, 0]),  # (Y)([EDST]|[MLIVAFYHQW])N.
    "LIG_SH3_CIN85_PxpxPR_1": np.array([1, 0, 1, 0, 1, 1]),  # P.[AP].PR
    "enah_LPPPP_FPPPP": np.array([2, 1, 0, 1, 1]),  # [FWYL]P.[AFILTVYWP]P
    "TRAF6": np.array([0, 0, 0, 1, 0, 1, 0, 0, 1]),  # ...P.E..[FYWDE]
}

table_file = (
    "../../benchmark/benchmark_v4/p3_conservation/benchmark_table_ANNOTATED.csv"
)
df = (
    pd.read_csv(table_file)
    .loc[
        # LIG_14-3-3_CanoR_1 has a variable-length regex, so a fixed-length
        # position-weight array can't be applied to it.
        lambda frame: frame["ELM_motif_class"] != "LIG_14-3-3_CanoR_1",
        [
            "reference_index",
            "ELM_motif_class",
            "Organism",
            "UniprotID",
            "regex",
            "hit_sequence",
            "gene_id",
            "hit start position",
            "hit end position",
            "verified interaction",
            "name",
            "json_file",
            "critical_error",
        ],
    ]
    # Keep only rows with no recorded critical error.
    .loc[lambda frame: frame["critical_error"].isna()]
    .assign(
        json_file=lambda frame: frame["json_file"].apply(convert_jsonfile_to_relative),
        # Motif classes missing from `position_weights` map to NaN.
        weight_array=lambda frame: frame["ELM_motif_class"].map(position_weights),
    )
)

In [6]:
def json_2_z_score_list(json_file, level, scorekey):
    """Fetch the stored hit z-scores for one score at one taxonomic level.

    Parameters
    ----------
    json_file : str
        Path to the gene's json file (paths inside it are rewritten with
        ``convert_jsonfile_to_relative``).
    level : str
        Taxonomic level, e.g. "Vertebrata" or "Metazoa".
    scorekey : str
        Key into the level object's ``conservation_scores``.

    Returns
    -------
    The ``"hit_z_scores"`` entry, or None if the level did not pass filters, the
    score is absent, or the score has no ``"hit_z_scores"``.
    """
    gene = group_tools.ConserGene(
        json_file, filepath_converter=convert_jsonfile_to_relative
    )
    if level not in gene.levels_passing_filters:
        return None
    level_obj = gene.get_level_obj(
        level, filepath_converter=convert_jsonfile_to_relative
    )
    scores = level_obj.conservation_scores
    if scorekey in scores and "hit_z_scores" in scores[scorekey]:
        return scores[scorekey]["hit_z_scores"]
    return None

def add_scorelist_2_df(df, level, scorekey):
    """Add a ``{level}_{scorekey}_z_scores`` column of per-row hit z-scores.

    The column is written into ``df`` in place, and the same frame is returned
    for chaining. Rows without available scores get None.
    """
    def _z_scores_for(json_file):
        return json_2_z_score_list(json_file, level, scorekey)

    df[f"{level}_{scorekey}_z_scores"] = df["json_file"].apply(_z_scores_for)
    return df

In [7]:
from attrs import asdict, define, field, validators


@define
class PairwiseScoreResults:
    """Pairwise (pairk) conservation score record for one hit.

    Built in this notebook as ``PairwiseScoreResults(**score_dict)`` from an entry
    of a level object's ``conservation_scores``. The field names must therefore
    match that dict's keys exactly; the attrs-generated ``__init__`` rejects
    unknown keyword arguments.
    """

    # Field meanings below are inferred from their names.
    # NOTE(review): confirm against the code that writes the json files.
    flanked_hit: str
    flanked_hit_start_position_in_idr: int
    original_hit_st_in_flanked_hit: int
    original_hit_end_in_flanked_hit: int
    function_name: str
    function_params: dict
    lflank: int
    rflank: int
    kmer_aln_file: str | Path  # normalised to a relative str in __attrs_post_init__
    flanked_hit_sequence: str
    flanked_hit_scores: list
    flanked_hit_z_scores: list
    hit_sequence: str
    hit_scores: list
    hit_z_scores: list
    pairk_conservation_params: dict
    bg_std: float

    def __attrs_post_init__(self):
        # The stored path is absolute to the original machine; rewrite it relative
        # to this notebook. str() first because the field may hold a Path, and
        # Path.replace is the file-rename method, not string substitution.
        self.kmer_aln_file = convert_jsonfile_to_relative(str(self.kmer_aln_file))


@dataclass
class AlnScoreResults:
    """Plain record of one alignment-based conservation score for a hit.

    Field meanings below are inferred from their names.
    NOTE(review): confirm against the code that populates this class.
    """

    file: str  # presumably the path of the source score/alignment file
    function_name: str  # presumably the name of the scoring function
    function_params: dict  # presumably the parameters passed to that function
    hit_scores: list  # presumably per-position scores over the hit
    hit_z_scores: list  # presumably per-position z-scores over the hit


def slice_aln_scores(lvlo: group_tools.LevelAlnScore, aln_start, aln_end):
    """Cut scores, z-scores and the query's aligned sequence to a column range.

    ``aln_start`` and ``aln_end`` are alignment column indices. Both ends are
    included.

    Returns
    -------
    tuple
        ``(scores, z_scores, query_aln_sequence)``, each sliced to the range.
    """
    # +1 because aln_end is inclusive
    window = slice(aln_start, aln_end + 1)
    return (
        lvlo.scores[window],
        lvlo.z_scores[window],
        lvlo.query_aln_sequence[window],
    )

In [8]:
def json2logoplot_alnscore(
    jsonfile, score_key, with_gaps=False, axes=None, level="Vertebrata", flank=5
):
    """Plot z-score bars and a sequence logo for a flanked hit using an alignment score.

    Parameters
    ----------
    jsonfile : str
        Path to the gene's json file.
    score_key : str
        Alignment-based score to load at ``level``.
    with_gaps : bool
        If False, the alignment slice is passed through
        ``score_plots.strip_gaps_from_slice``. Only z-scores at the returned
        column indices are kept.
    axes : sequence of two Axes, optional
        ``axes[0]`` receives the bar plot and ``axes[1]`` the logo. When None, a
        new 2x1 figure is created; it is not returned.
    level : str
        Taxonomic level to read the score from.
    flank : int
        Residues added on each side of the hit before slicing.

    Returns
    -------
    The result of ``pssms.alignment_2_counts`` for the plotted sequences.
    """
    gene = group_tools.ConserGene(
        jsonfile, filepath_converter=convert_jsonfile_to_relative
    )
    aln_score = gene.get_aln_score_obj(
        level, score_key, filepath_converter=convert_jsonfile_to_relative
    )
    # Flanked hit bounds, in IDR (unaligned) coordinates.
    flank_start, flank_end, _flanked_hit = tools.pad_hit(
        gene.query_idr_sequence,
        gene.hit_st_in_idr,
        gene.hit_end_in_idr,
        l_flank=flank,
        r_flank=flank,
    )
    # Alignment columns spanned by the IDR (end inclusive).
    idr_cols = slice(aln_score.idr_aln_start, aln_score.idr_aln_end + 1)
    _unaligned_idr, idr_to_aln = tools.reindex_alignment_str(
        aln_score.query_aln_sequence[idr_cols]
    )
    # Map the flanked hit bounds to columns within the IDR part of the alignment.
    aln_first = idr_to_aln[flank_start]
    aln_last = idr_to_aln[flank_end]
    offset = aln_score.idr_aln_start
    _raw_scores, z_scores, query_aln_slice = slice_aln_scores(
        aln_score, aln_first + offset, aln_last + offset
    )
    idr_aln = aln_score.aln[:, idr_cols]
    flanked_aln = idr_aln[:, aln_first : aln_last + 1]

    if with_gaps:
        seqlist = [str(record.seq) for record in list(flanked_aln)]
        tick_labels = query_aln_slice
        bar_scores = z_scores
    else:
        seqlist, tick_labels, kept_cols = score_plots.strip_gaps_from_slice(
            flanked_aln, query_aln_slice
        )
        bar_scores = list(np.array(z_scores)[kept_cols])

    if axes is None:
        _fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(10, 4))
    basic_plots.plot_score_bar_plot(
        ax=axes[0],
        score_list=bar_scores,
        query_seq=tick_labels,
    )
    basic_plots.plot_logo(
        ax=axes[1],
        str_list=seqlist,
        tick_label_str=tick_labels,
    )
    return pssms.alignment_2_counts(seqlist, show_plot=False, heatmap=False)

In [9]:
def json2logoplot(jsonfile, score_key, axes=None, level="Vertebrata"):
    """Plot pairk z-score bars and the ortholog k-mer logo for a flanked hit.

    Parameters
    ----------
    jsonfile : str
        Path to the gene's json file.
    score_key : str
        Key of the pairk entry in the level object's ``conservation_scores``.
    axes : sequence of two Axes, optional
        ``axes[0]`` receives the bar plot and ``axes[1]`` the logo. When None, a
        new 2x1 figure is created; it is not returned.
    level : str
        Taxonomic level to read the score from.

    Returns
    -------
    tuple
        ``(counts, result)``: the output of ``pssms.alignment_2_counts`` for the
        plotted k-mers, and the parsed ``PairwiseScoreResults``.
    """
    gene = group_tools.ConserGene(
        jsonfile, filepath_converter=convert_jsonfile_to_relative
    )
    level_obj = gene.get_level_obj(
        level, filepath_converter=convert_jsonfile_to_relative
    )
    result = PairwiseScoreResults(**level_obj.conservation_scores[score_key])
    kmer_aln = pairk.PairkAln.from_file(result.kmer_aln_file)
    # Row of the ortholog k-mer matrix selected by the flanked hit's IDR start
    # (presumably rows are indexed by k-mer start position; confirm in pairk).
    orthokmers = kmer_aln.orthokmer_matrix.loc[
        result.flanked_hit_start_position_in_idr
    ].to_list()

    if axes is None:
        _fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(10, 4))
    labels = result.flanked_hit_sequence
    basic_plots.plot_score_bar_plot(
        ax=axes[0],
        score_list=result.flanked_hit_z_scores,
        query_seq=labels,
    )
    basic_plots.plot_logo(ax=axes[1], str_list=orthokmers, tick_label_str=labels)
    counts = pssms.alignment_2_counts(orthokmers, show_plot=False, heatmap=False)
    return counts, result

In [10]:
def composite_plot(
    s,
    level,
    pairkey="pairk_aln_lf5_rf5_edssmat50",
    flank=5,
    aln_score_key="aln_property_entropy",
    ylim=(-4, 4),
):
    """Build a 2x3 figure comparing alignment-based and pairk conservation for one hit.

    Columns:
      0. The alignment score with gaps kept.
      1. The same score with gaps stripped.
      2. The pairk score ``pairkey``.

    The top row shows z-score bars and the bottom row shows sequence logos.

    Parameters
    ----------
    s : pandas.Series
        A benchmark-table row; only ``s["json_file"]`` is read.
    level : str
        Taxonomic level passed to the plotting helpers.
    pairkey : str
        Pairk conservation score plotted in the third column.
    flank : int
        Flank length for the alignment-score columns. Presumably this should match
        the flank encoded in ``pairkey`` (e.g. ``lf5_rf5`` with ``flank=5``) so the
        panels line up.
    aln_score_key : str
        Alignment-based score for the first two columns. The default preserves
        the original behavior.
    ylim : tuple of float
        Y-limits applied to every bar plot in the top row.

    Returns
    -------
    tuple
        ``(fig, ax)`` where ``ax`` is the 2x3 array of Axes.
    """
    jsonfile = s["json_file"]
    fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(20, 4))
    # Columns 0 and 1: the same alignment score with and without gaps. The
    # scoring parameters are whatever is stored in the json files.
    for col, keep_gaps in enumerate((True, False)):
        json2logoplot_alnscore(
            jsonfile,
            aln_score_key,
            axes=ax[:, col],
            level=level,
            with_gaps=keep_gaps,
            flank=flank,
        )
    json2logoplot(
        jsonfile,
        pairkey,
        axes=ax[:, 2],
        level=level,
    )
    for axi in ax[0, :]:
        axi.set_ylim(list(ylim))
    # Lay out this figure explicitly rather than relying on pyplot's current figure.
    fig.tight_layout()
    return fig, ax

### Finding the hits to render as logo plots

In [11]:
OUTPUT_FOLDER = Path("logos")
OUTPUT_FOLDER.mkdir(exist_ok=True)
instances = [
    ("9606_0:00294e", 553, "Vertebrata", {}),
    ("9606_0:000643", 266, "Metazoa", {}),
    ("9606_0:000fe9", 1377, "Metazoa", {}),
    ("9606_0:003dd5", 1241, "Vertebrata", dict(pairkey="pairk_aln_lf0_rf0_edssmat50", flank=0)),
]
for i in instances:
    temp = df[(df["gene_id"] == i[0]) & (df["hit start position"] == i[1])].copy()
    assert len(temp) == 1
    temp = temp.iloc[0]
    fig, ax = composite_plot(temp, i[2], **i[3])
    filename = f"{temp['reference_index']}-{temp['name']}-{temp['UniprotID']}-{i[2]}-{temp['gene_id'].replace(':','-')}.png"
    fig.savefig(
        OUTPUT_FOLDER / filename,
        bbox_inches="tight",
        dpi=300,
    )
    plt.close(fig)

In [50]:
OUTPUT_FOLDER = Path("../figure4/logos/")
OUTPUT_FOLDER.mkdir(exist_ok=True)
instances = [
    ("9606_0:000b76", 1201, "Vertebrata"),
    ("9606_0:000b76", 1201, "Metazoa"),
]
for i in instances:
    temp = df[(df["gene_id"] == i[0]) & (df["hit start position"] == i[1])].copy()
    assert len(temp) == 1
    temp = temp.iloc[0]
    fig, ax = composite_plot(temp, i[2])
    filename = f"{temp['reference_index']}-{temp['name']}-{temp['UniprotID']}-{i[2]}-{temp['gene_id'].replace(':','-')}.png"
    fig.savefig(
        OUTPUT_FOLDER / filename,
        bbox_inches="tight",
        dpi=300,
    )
    plt.close(fig)