In [6]:
import os
from glob import glob
import pickle

import numpy as np
import pandas as pd
import tensorflow as tf
from matplotlib import pyplot as plt, cycler
from scipy.spatial import cKDTree
from matplotlib.colors import LinearSegmentedColormap

from config import (SEP, PREDS_PATH, MODEL_NAME, ADDED_TOKENS_PER_SEQ, SEQ_CUTOFF, MODEL_PATH, MARKER_SCALE_ALL,
                    MARKER_SCALE_POS_ONLY)
from utils.dim_reduction import (calculate_umap, calculate_tsne)
from utils.tokenizer import tokenize_seqs
from typing import List, Dict, Tuple


# Each inner list is one combination of prediction-set names. Every
# combination is processed on its own: its sets are read, embedded together
# and drawn on a single plot, one scatter series per set.
VIS_COMBS = [
    [
        "NLReff_test",
        "bass_ntm_domain_test",
        "fass_ntm_domain_test",
        "fass_ctm_domain_test",
    ],
    [
        "bass_01_ntm_domain_test",
        "bass_02_ntm_domain_test",
        "bass_03_ntm_domain_test",
        "bass_06_ntm_domain_test",
        "het-s_ntm_domain_test",
        "sigma_ntm_domain_test",
        "pp_ntm_domain_test",
    ],
]

In [7]:
def create_and_save_plot(x, y, probabilities, title, filename, set_names, sets_sizes, base_colors, figsize=(12.8, 9.6), dpi=100):
    """
    Draw one scatter series per set, shaded by probability, and build a KD-tree
    over the plotted coordinates.

    Despite its name, this function does not write anything to disk: it only
    builds the figure and returns it, so the caller is responsible for saving.

    The points are expected to be ordered set by set: the first
    ``sets_sizes[set_names[0]]`` points belong to the first set, the next
    ``sets_sizes[set_names[1]]`` to the second, and so on.

    Args:
        x (list): X-coordinates of the data points
        y (list): Y-coordinates of the data points
        probabilities (list): Probability values for each point
        title (str): Title drawn on the axes
        filename (str): Not used by this function; kept so existing callers
            keep working
        set_names (list): Names of the sets to be plotted, in plotting order
        sets_sizes (dict): Dictionary mapping set names to their sizes
        base_colors (list): Base color for each set, matched to ``set_names``
            by position; extra colors are ignored
        figsize (tuple): Figure size in inches (default: (12.8, 9.6))
        dpi (int): Resolution of the created figure (default: 100)
    Returns:
        fig: The matplotlib Figure object
        ax: The matplotlib Axes object
        tree: KDTree object for efficient nearest neighbor search
    Raises:
        ValueError: If x, y and probabilities differ in length, if the set
            sizes do not add up to the number of points, or if there are
            fewer base colors than sets.
    """
    if len(x) != len(y) or len(x) != len(probabilities):
        raise ValueError("x, y, and probabilities must have the same length")

    if sum(sets_sizes.values()) != len(x):
        raise ValueError("Sum of set sizes must equal the number of data points")

    # zip() below would otherwise silently drop every set without a color.
    if len(base_colors) < len(set_names):
        raise ValueError(
            f"Need at least one base color per set: got {len(base_colors)} "
            f"colors for {len(set_names)} sets"
        )

    # Imported here, after validation, from its documented location rather
    # than reached through the incidental ``plt.cm.colors`` attribute.
    from matplotlib.colors import to_rgba

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    ax.axis("on")  # Turn the axis on to show coordinates

    start_index = 0
    for set_name, base_color in zip(set_names, base_colors):
        end_index = start_index + sets_sizes[set_name]

        # Two-stop gradient of the same hue: 30% opacity up to fully opaque.
        cmap = LinearSegmentedColormap.from_list(
            f"custom_{set_name}",
            [to_rgba(base_color, 0.3), to_rgba(base_color, 1.0)],
        )

        # NOTE(review): no vmin/vmax is given, so each series is color-scaled
        # to its own probability range; confirm that per-set scaling is wanted.
        ax.scatter(x[start_index:end_index], y[start_index:end_index],
                   c=probabilities[start_index:end_index], cmap=cmap,
                   s=20, label=set_name)

        start_index = end_index

    ax.legend()
    ax.set_title(title)

    # Create a KD-tree for efficient nearest neighbor search
    tree = cKDTree(list(zip(x, y)))

    return fig, ax, tree

def find_nearest_points(tree, query_point, k=1):
    """
    Find the k nearest points to the query point.

    Args:
        tree (cKDTree): KD-tree of the plotted points
        query_point (tuple): (x, y) coordinates of the query point
        k (int): Number of nearest neighbors to find

    Returns:
        list of tuples: (distance, index) for the nearest points, closest
        first. Fewer than k pairs are returned when the tree holds fewer
        than k points.
    """
    distances, indices = tree.query(query_point, k=k)

    # For k=1 the query returns bare scalars, which zip() cannot iterate;
    # promote them to length-1 arrays so every k yields a list of pairs.
    distances = np.atleast_1d(distances)
    indices = np.atleast_1d(indices)

    if distances.ndim == 1:
        # Missing neighbors come back with an infinite distance and an
        # out-of-range index; drop them so callers can index safely.
        found = np.isfinite(distances)
        distances, indices = distances[found], indices[found]

    return list(zip(distances, indices))


In [8]:
def process_combinations(VIS_COMBS: List[Tuple[str, ...]], 
                         PREDS_PATH: str, 
                         MODEL_NAME: str, 
                         SEP: str, 
                         SEQ_CUTOFF: int, 
                         ADDED_TOKENS_PER_SEQ: int, 
                         MODEL_PATH: str):
    """
    Run the whole visualization pipeline once per combination of sets.

    For every combination the steps are: collect the fragments and their
    probabilities, pick the colors, tokenize the fragments, extract the
    model embeddings, project them with UMAP and save the resulting plot.

    Args:
        VIS_COMBS: Combinations of set names; each one produces one plot.
        PREDS_PATH: Directory holding the prediction CSV files.
        MODEL_NAME: Model name used inside the prediction file names.
        SEP: Column separator of the prediction CSV files.
        SEQ_CUTOFF: Added to ADDED_TOKENS_PER_SEQ to form the length passed
            to the tokenizer.
        ADDED_TOKENS_PER_SEQ: See SEQ_CUTOFF.
        MODEL_PATH: Directory holding the saved models.
    """
    for comb in VIS_COMBS:
        # The ids are pickled by collect_fragments itself, so they are not
        # needed again here.
        fragments, sets_sizes, probs, _ = collect_fragments(comb, PREDS_PATH, MODEL_NAME, SEP)
        palette = determine_colors(comb)

        tokens = tokenize_seqs(fragments, SEQ_CUTOFF + ADDED_TOKENS_PER_SEQ)
        embeddings = extract_embeddings(tokens, MODEL_PATH)

        umap_x, umap_y = calculate_umap(embeddings)

        save_plot(umap_x, umap_y, probs, comb, sets_sizes, palette)

def collect_fragments(comb: Tuple[str, ...], 
                      PREDS_PATH: str, 
                      MODEL_NAME: str, 
                      SEP: str) -> Tuple[List[str], Dict[str, int], List[float], List[str]]:
    """
    Read the prediction file of every set in a combination and gather its rows.

    For each set the file ``<set_name>.<MODEL_NAME>comb123456.csv`` is read
    from PREDS_PATH. Rows are kept when ``|class - prob| <= 1.0`` (rows where
    that difference is NaN are dropped). The collected ids are also pickled to
    ``ids.<set names joined by '.'>.pkl`` in the current working directory.

    Args:
        comb: Names of the sets to read, in order.
        PREDS_PATH: Directory holding the prediction CSV files.
        MODEL_NAME: Model name used inside the file names.
        SEP: Column separator of the CSV files.

    Returns:
        A tuple ``(frags, sets_sizes, marker_sizes, all_ids)``: the "frag"
        values, the number of kept rows per set, the "prob" values and the
        "id" values, all concatenated in the order of ``comb``.
    """
    fragments = []
    kept_per_set = {}
    probs = []
    ids = []

    for set_name in comb:
        csv_path = os.path.join(PREDS_PATH, f'{set_name}.{MODEL_NAME}comb123456.csv')
        pred = pd.read_csv(csv_path, sep=SEP)

        keep = (pred["class"] - pred["prob"]).abs() <= 1.0
        kept_per_set[set_name] = keep.sum()

        kept_rows = pred.loc[keep]
        fragments.extend(kept_rows["frag"])
        probs.extend(kept_rows["prob"])
        ids.extend(kept_rows["id"])

    ids_file = f'ids.{".".join(comb)}.pkl'
    with open(ids_file, 'wb') as handle:
        pickle.dump(ids, handle)

    return fragments, kept_per_set, probs, ids

def determine_colors(comb: Tuple[str, ...]) -> List[str]:
    """
    Choose the list of base colors for a combination of sets.

    The palette depends only on the first two set names:
    - first name contains "PB40" and second contains "NLReff": two grays
      followed by blue, green and red (5 colors);
    - first name contains "NLReff": one gray followed by 7 colors;
    - anything else, including an empty combination: the 7 colors alone.

    Args:
        comb: Names of the sets, in plotting order. May be empty or hold a
            single name.

    Returns:
        List of matplotlib color names, to be matched to the sets by position.
    """
    # Missing names are treated as empty strings so short combinations fall
    # through to a later branch instead of raising IndexError.
    first = comb[0] if len(comb) > 0 else ""
    second = comb[1] if len(comb) > 1 else ""

    if "PB40" in first and "NLReff" in second:
        return ["gray", "darkgray", "blue", "green", "red"]
    if "NLReff" in first:
        return ["darkgray", "blue", "green", "red", "deeppink", "orange", "cyan", "black"]
    return ["blue", "green", "red", "deeppink", "orange", "cyan", "black"]

def extract_embeddings(frags: List[str], MODEL_PATH: str,
                       layer_name: str = "dropout",
                       aux_input_width: int = 8943) -> np.ndarray:
    """
    Collect the output of one layer from every saved model and join them.

    Every entry found directly under MODEL_PATH is loaded as a Keras model.
    Each model is evaluated up to ``layer_name`` on two inputs: ``frags`` and
    an all-zero int8 array. The per-model outputs are concatenated along
    axis 1.

    Args:
        frags: First model input (the tokenized fragments from the caller).
        MODEL_PATH: Directory whose entries are the saved models.
        layer_name: Name of the layer whose output is extracted.
        aux_input_width: Number of columns of the all-zero second input.
            NOTE(review): what this second input stands for is not visible
            here; confirm against the model definition.

    Returns:
        Array with one row per fragment, holding the layer outputs of all
        models side by side.

    Raises:
        ValueError: If MODEL_PATH contains no entries.
    """
    print(f"Extracting embeddings from specified layer: {layer_name}")

    # glob() returns paths in arbitrary order; sort them so the column order
    # of the result is the same on every run and machine.
    model_filepaths = sorted(glob(os.path.join(MODEL_PATH, "*")))
    if not model_filepaths:
        raise ValueError(f"No model files found in {MODEL_PATH!r}")

    # The zero input is identical for every model, so build it once.
    aux_input = np.zeros((len(frags), aux_input_width), dtype=np.int8)

    mdim_rep = []
    for model_filepath in model_filepaths:
        model = tf.keras.models.load_model(model_filepath)
        layer_out = model.get_layer(layer_name).output
        fun = tf.keras.backend.function(model.input, layer_out)
        mdim_rep.append(fun([frags, aux_input]))

    return np.concatenate(mdim_rep, axis=1)

def save_plot(x_umap: np.ndarray, 
              y_umap: np.ndarray, 
              marker_sizes: List[float], 
              comb: Tuple[str, ...], 
              sets_sizes: Dict[str, int], 
              base_colors: List[str]):
    """
    Build the UMAP scatter plot for one combination and write it to disk.

    Two files are written, both named after the set names joined by '.':
    ``tree.<names>.pkl`` (the pickled KD-tree, in the current directory) and
    ``plots/UMAP Visualization/gradient.<names>.pdf`` (the figure). The
    figure is closed afterwards.

    Args:
        x_umap: X-coordinates of the projected points.
        y_umap: Y-coordinates of the projected points.
        marker_sizes: Per-point probabilities used for the color gradient.
        comb: Names of the sets in the combination.
        sets_sizes: Number of points per set.
        base_colors: Base color per set.
    """
    print("Saving plots...")

    point_probs = np.array(marker_sizes)
    comb_tag = ".".join(comb)
    plot_title = "UMAP Visualization"
    pdf_name = f'gradient.{comb_tag}.pdf'

    fig, ax, tree = create_and_save_plot(
        x_umap, y_umap, point_probs, plot_title, pdf_name,
        comb, sets_sizes, base_colors,
    )

    # Keep the KD-tree so neighbors can be looked up later without replotting.
    with open(f'tree.{comb_tag}.pkl', 'wb') as tree_file:
        pickle.dump(tree, tree_file)

    target = os.path.join("plots", plot_title, pdf_name)
    os.makedirs(os.path.dirname(target), exist_ok=True)
    fig.savefig(target, dpi=100)
    plt.close(fig)


In [9]:
# Run the pipeline for every combination in VIS_COMBS. For each combination
# this writes ids.<names>.pkl and tree.<names>.pkl to the working directory
# and a PDF under plots/UMAP Visualization/.
process_combinations(VIS_COMBS, PREDS_PATH, MODEL_NAME, SEP, SEQ_CUTOFF, ADDED_TOKENS_PER_SEQ, MODEL_PATH)

Extracting embeddings from specified layer: dropout
Calculating 2dim UMAP of 2686 elements...
Saving plots...
Extracting embeddings from specified layer: dropout
Calculating 2dim UMAP of 157 elements...
Saving plots...


In [12]:
# Look up the points closest to a chosen location on a saved plot and print
# their ids with the distance to that location.
NAME = "NLReff_test.bass_ntm_domain_test.fass_ntm_domain_test.fass_ctm_domain_test"
NEIGHBORS = 15
POINT = (-2, 6)

# Both pickles are the ones written by the pipeline above for this
# combination. pickle.load can execute arbitrary code, so only load files
# produced by this notebook.
with open(f'tree.{NAME}.pkl', 'rb') as tree_file:
    tree = pickle.load(tree_file)
with open(f'ids.{NAME}.pkl', 'rb') as ids_file:
    all_ids = pickle.load(ids_file)

nearest_points = find_nearest_points(tree, POINT, k=NEIGHBORS)
for distance, index in nearest_points:
    print(f"{all_ids[index]}\t{distance}")

CVK84517.1_96_233	0.3305907492791523
CCA76343.1_24_211	0.41619419845340305
XP_016226918.1_48_153	0.4244747083051056
RDW58735.1_191_274	0.46037088742927185
PHH91864.1_172_254	0.47269391373893777
KFY62020.1_209_282	0.5017582704566611
KFY22657.1_187_257	0.5104078537862463
OIW25313.1_86_214	0.5658803037160223
RFU23947.1_60_189	0.5995960817434965
KNG45163.1_59_188	0.6008283486438596
DAA73946.1_100_265	0.6051093400496077
XP_006695265.1_86_238	0.6074174618183669
OQE82057.1_122_255	0.6146237916105046
XP_014169034.1_54_185	0.6164515188287562
KLU86620.1_104_239	0.6234190383295741
