# BM25

Approach to cluster matching built primarily around BM25, using material from fusion.ipynb.


In [1]:
#
# import libraries
#
import pandas as pd
from pathlib import Path
from rank_bm25 import BM25Okapi
import numpy as np

from functools import partial
from typing import Union
from loguru import logger

In [2]:
#
# set constants
#
# directory the input CSV is read from
input_path = Path("../results")
# output directory (currently the same location as the input)
output_path = Path("../results")

# fraction of rows sampled into the training set; the remainder forms the test set
training_fraction = 0.8

In [3]:
#
# load the data along with embeddings
#
# read the full table of cells; later cells use its "cluster" and "signature" columns
master_df = pd.read_csv(input_path / Path("ragsc_00_all_large.csv"))
master_n_cells = master_df.shape[0]

# NOTE(review): sample() is called without random_state, so the train/test
# split (and every number computed from it below) changes on each run.
train_df = master_df.sample(frac=training_fraction)
# the test set is every row whose index label was not drawn into the training sample
test_df = master_df.drop(train_df.index)  # .sample(frac=training_fraction)
print(f"total rows: {master_df.shape[0]}")
print(f"training set has {train_df.shape[0]} rows")
print(f"test set has {test_df.shape[0]} rows")

total rows: 9370
training set has 7496 rows
test set has 1874 rows


In [4]:
def get_gene_bags(df: pd.DataFrame, max_genes: int, sort_by_cluster_names=True) -> dict:
    """
    Build one "bag of words" per cluster for use as a BM25 document.

    Every row's space-separated ``signature`` is split into gene names, the
    first ``max_genes`` names are kept, and the kept names of all rows that
    share a ``cluster`` value are concatenated in row order.

    Args:
        df: dataframe with a ``cluster`` column and a ``signature`` column of
            space-separated gene names
        max_genes: number of leading gene names taken from each signature
        sort_by_cluster_names: when True the returned dictionary is ordered by
            sorted cluster name, otherwise by first appearance in ``df``

    Returns a dictionary mapping cluster name to its list of gene names.
    """
    bags = {}
    for cluster_name, members in df.groupby("cluster", sort=False):
        bag = []
        for signature in members.signature:
            # only the leading max_genes names of each signature go into the bag
            bag.extend(signature.split(" ")[:max_genes])
        bags[cluster_name] = bag
    if not sort_by_cluster_names:
        return bags
    return {name: bags[name] for name in sorted(bags)}

In [5]:
#
# chunking
#
def chunk(s: Union[str, list], size: int, step=1) -> list[str]:
    """
    Takes a string or list of strings and creates a list of chunks of a given size.

    A window of ``size`` words is taken at every ``step``-th position. Windows
    that run past the end of the input are truncated, so the trailing chunks
    can hold fewer than ``size`` words.

    Args
        s: A whitespace-separated string of words (gene names) or a list of words.
        size: The number of words (gene names) in each chunk.
        step: The number of words to advance before the next chunk (defaults to 1).
    Returns
        A list of strings representing the chunks.
    Raises
        ValueError: if ``step`` is smaller than 1.
    """
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")
    words = s.split() if isinstance(s, str) else s
    # The stride has to be given to range(); reassigning the loop variable
    # inside the body does not change which index comes next.
    # Slicing clamps at the end of the list, which yields the shorter tail chunks.
    return [" ".join(words[i : i + size]) for i in range(0, len(words), step)]


# overlapping pairs: window of 2 words, advancing one word at a time
chunk2 = partial(chunk, size=2)
# non-overlapping pairs: window of 2 words, advancing two words at a time
chunk3 = partial(chunk, size=2, step=2)

In [90]:
def chunk_sig(sig: Union[str, list], chunk_size: int, overlap=3) -> list[str]:
    """
    Split a signature into space-joined chunks of ``chunk_size`` gene names.

    Args
        sig: A signature, either as a single-space-separated string of gene
            names or as an already-split list of gene names.
        chunk_size: The number of gene names in each chunk.
        overlap: Despite its name, this is the stride: the number of gene names
            to advance between the starts of consecutive chunks (defaults to 3).
    Returns
        A list of chunk strings. Chunking stops after the first chunk that
        reaches past the end of the signature; that last chunk holds whatever
        gene names remain.
    """
    # Strings are split on single spaces; lists are taken as they are. The
    # input must not be re-split afterwards, or list input cannot be handled.
    items = sig.split(" ") if isinstance(sig, str) else list(sig)
    chunks = []
    for start in range(0, len(items), overlap):
        # slicing clamps at the end, so an over-long window yields the remainder
        chunks.append(" ".join(items[start : start + chunk_size]))
        if start + chunk_size > len(items):
            break
    return chunks

In [117]:
# one bag of gene names per training cluster, ordered by sorted cluster name
word_dict = get_gene_bags(train_df, max_genes=120)
# number of gene names per chunk; also read as a global by get_score
chunk_size = 3
# each cluster becomes one document whose tokens are chunk strings; the stride
# equals chunk_size, so the chunks within a document do not overlap
docs = [chunk_sig(" ".join(x), chunk_size, chunk_size) for x in word_dict.values()]
# document i of the index corresponds to the i-th key of word_dict
bm25_index=BM25Okapi(docs)

In [118]:
# show the first ten chunk tokens of the first cluster's document
docs[0][:10]

['RPL11 CAMK4 RIPOR2',
 'AC079793.1 SF3B1 RPSA',
 'RPL29 CBLB RPS12',
 'JAK2 SYNE2 PPP2R5C',
 'SLTM ANKRD12 RPRD1A',
 'CCNL2 RPL22 TNFRSF1B',
 'EPB41 ZMYM4 NDUFS5',
 'TM2D1 JAK1 USP33',
 'SELENOF CD53 CD2',
 'TXNIP RPS27 KIFAP3']

In [119]:
def get_score(bm25, gene_list, max_genes=25, normalized=True) -> np.ndarray:
    """
    Returns the BM25 scores of one signature against every indexed document.

    Args:
        bm25 - the index to query; its ``get_scores`` is called with a list of chunk strings
        gene_list - the signature to score (passed unchanged to ``chunk_sig``)
        max_genes - the maximum number of chunks kept for the query.
            NOTE(review): the cap is applied after chunking, so it limits
            chunks rather than individual gene names - confirm this is intended.
        normalized - when True the scores are min-max scaled to the range [0, 1]

    Returns an array with one score per indexed document. When all scores are
    equal a message is printed and the scores are returned unscaled.
    """
    # Chunk the query with stride == chunk_size, the same way the indexed
    # documents were chunked; the default stride of chunk_sig only matches
    # while chunk_size happens to be 3.
    query = chunk_sig(gene_list, chunk_size, chunk_size)[:max_genes]
    scores = bm25.get_scores(query)
    lowest, highest = np.min(scores), np.max(scores)
    if highest <= lowest:
        # min-max scaling would divide by zero here
        print("max and min are equal!")
        return scores
    if normalized:
        scores = (scores - lowest) / (highest - lowest)
    return scores

In [120]:
def create_score_column(df: pd.DataFrame, bm25, max_genes) -> pd.DataFrame:
    """
    Store the BM25 scores of every signature in a new "scores" column.

    Args:
        df - dataframe with a "signature" column; it is modified in place
        bm25 - the index each signature is scored against
        max_genes - forwarded to get_score as its max_genes argument

    Returns the same dataframe object that was passed in.
    """

    def _score(signature):
        # one entry per row: the value returned by get_score for that signature
        return get_score(bm25, signature, max_genes)

    df["scores"] = df.signature.apply(_score)
    return df

In [124]:
# adds the "scores" column to test_df in place; df_test is the same object as test_df
df_test = create_score_column(test_df, bm25_index, 120)

max and min are equal!
max and min are equal!
max and min are equal!
max and min are equal!
max and min are equal!


In [125]:
# scoring methods; only BM is implemented by calculate_summary_stats
BM = 0
VECTOR = 1
BOTH = 2


def calculate_summary_stats(
    clusters_df: pd.DataFrame, method: int
) -> dict[int, np.ndarray]:
    """
    Collect, per cluster, the score each member row gives to its own cluster.

    Args:
        clusters_df - dataframe with an integer "cluster" column and a "scores"
            column holding, per row, a sequence indexable by cluster number
        method - one of BM, VECTOR, BOTH; only BM is supported. For any other
            value an error is logged and the arrays are left filled with zeros.

    Returns a dictionary with one entry for every cluster number from 0 to the
    largest cluster in clusters_df (inclusive). Each value is a float array
    with one element per row of that cluster; clusters without rows map to an
    empty array. An empty dataframe gives an empty dictionary.
    """
    if clusters_df.empty:
        return {}
    if method != BM:
        # report once instead of once per row
        logger.error("Only supports BM25 stats")
    # size the table from the dataframe that was passed in (not a notebook
    # global), and add 1 so the highest cluster number is included
    cluster_count = int(clusters_df.cluster.max()) + 1
    table: dict[int, np.ndarray] = {
        k: np.zeros(0, dtype=float) for k in range(cluster_count)
    }
    for cluster_no, cluster_df in clusters_df.groupby("cluster"):
        if method == BM:
            # the score vector is indexed by cluster number; pick the entry
            # belonging to the row's own cluster
            values = np.array(
                [scores[cluster_no] for scores in cluster_df.scores], dtype=float
            )
        else:
            values = np.zeros(cluster_df.shape[0], dtype=float)
        table[int(cluster_no)] = values
    return table

In [126]:
from scipy import stats

# per-cluster arrays of the score each test row gives to its own cluster
table_match = calculate_summary_stats(df_test, BM)
# columns printed: cluster number, mean, standard deviation, standard error of the mean, (row count)
for cluster in table_match:
    print(
        f"{cluster:02} {table_match[cluster].mean():8.3f} {table_match[cluster].std():8.3f} {stats.sem(table_match[cluster]):8.3f} ({table_match[cluster].size})"
    )

00    0.553    0.377    0.023 (259)
01    0.361    0.245    0.017 (203)
02    0.562    0.354    0.029 (155)
03    0.574    0.321    0.026 (155)
04    0.433    0.319    0.027 (144)
05    0.610    0.384    0.035 (122)
06    0.495    0.342    0.036 (89)
07    0.408    0.346    0.034 (106)
08    0.504    0.401    0.038 (112)
09    0.613    0.384    0.042 (84)
10    0.558    0.375    0.045 (70)
11    0.533    0.420    0.050 (72)
12    0.261    0.336    0.043 (63)
13    0.386    0.388    0.047 (69)
14    0.224    0.325    0.051 (42)
15    0.315    0.394    0.064 (39)
16    0.160    0.268    0.042 (41)
17    0.226    0.348    0.067 (28)
18    0.117    0.290    0.065 (21)


In [54]:
def create_ranking_dict(a: np.ndarray) -> dict[int, float]:
    """Map each position of ``a`` to the value stored at that position."""
    return dict(enumerate(a))


def sort_categories_by_values(categories: dict[int, float]) -> dict[int, float]:
    """Return a copy of ``categories`` ordered from highest to lowest value."""
    # sorting the keys by their value keeps tied entries in their original order
    ordered_keys = sorted(categories, key=categories.get, reverse=True)
    return {key: categories[key] for key in ordered_keys}


# work on a copy so the extra columns do not end up in test_df
scores_df = test_df[["cluster", "scores"]].copy()
# rank_dict: position in the score vector -> score
scores_df["rank_dict"] = scores_df.scores.apply(lambda x: create_ranking_dict(x))
# ranked_dict: the same mapping, ordered from highest to lowest score
scores_df["ranked_dict"] = scores_df.rank_dict.apply(
    lambda x: sort_categories_by_values(x)
)
scores_df

Unnamed: 0,cluster,scores,rank_dict,ranked_dict
12,0,"[0.8317565093984453, 0.8898878992585331, 0.600...","{0: 0.8317565093984453, 1: 0.8898878992585331,...","{3: 1.0, 1: 0.8898878992585331, 0: 0.831756509..."
16,0,"[1.0, 0.7101401168618983, 0.7080276093736951, ...","{0: 1.0, 1: 0.7101401168618983, 2: 0.708027609...","{0: 1.0, 1: 0.7101401168618983, 2: 0.708027609..."
17,0,"[1.0, 0.6582289843957945, 0.7078091830170901, ...","{0: 1.0, 1: 0.6582289843957945, 2: 0.707809183...","{0: 1.0, 2: 0.7078091830170901, 1: 0.658228984..."
18,0,"[0.9293668223785776, 1.0, 0.8694028139139268, ...","{0: 0.9293668223785776, 1: 1.0, 2: 0.869402813...","{1: 1.0, 0: 0.9293668223785776, 2: 0.869402813..."
25,0,"[1.0, 0.8597027800415953, 0.6926143798962348, ...","{0: 1.0, 1: 0.8597027800415953, 2: 0.692614379...","{0: 1.0, 1: 0.8597027800415953, 3: 0.693894467..."
...,...,...,...,...
9347,18,"[0.530859651465067, 0.5230712082127268, 0.4165...","{0: 0.530859651465067, 1: 0.5230712082127268, ...","{9: 1.0, 0: 0.530859651465067, 1: 0.5230712082..."
9349,18,"[0.7309076969354257, 1.0, 0.7209908954735146, ...","{0: 0.7309076969354257, 1: 1.0, 2: 0.720990895...","{1: 1.0, 3: 0.9982068555527103, 4: 0.902158389..."
9361,18,"[0.8926251895220338, 0.6465804935882411, 0.721...","{0: 0.8926251895220338, 1: 0.6465804935882411,...","{9: 1.0, 0: 0.8926251895220338, 3: 0.776211615..."
9365,18,"[0.758851489432163, 0.41861073375775426, 0.572...","{0: 0.758851489432163, 1: 0.41861073375775426,...","{9: 1.0, 0: 0.758851489432163, 5: 0.7457135315..."


In [49]:
# For every true cluster, average the value that each member row's ranked_dict
# holds for that cluster.
# NOTE(review): ranked_dict maps cluster index -> score, so "cluster_rank" and
# "ranks" hold scores, not ordinal rank positions.
clusters = scores_df.groupby("cluster")
ranks = {}
for cluster in clusters:
    # each cluster is a tuple (cluster number, cluster dataframe)
    cluster_no = cluster[0]
    cluster_df = cluster[1]
    # NOTE(review): cluster_rank_total is assigned but never read
    cluster_rank_total = 0
    # if cluster_no > 0:
    # break
    cluster_df["cluster_rank"] = cluster_df.ranked_dict.apply(lambda x: x[cluster_no])
    ranks[cluster_no] = cluster_df.cluster_rank.mean()

ranks

{0: 0.9962871030604834,
 1: 0.9746683182640282,
 2: 0.8886388823094608,
 3: 0.8839024385738254,
 4: 0.9238078474246274,
 5: 0.9707897654283968,
 6: 0.994977867743239,
 7: 0.8134663432133056,
 8: 0.9379237116862214,
 9: 0.9908841366830422,
 10: 0.9758680944604559,
 11: 0.9849142655251698,
 12: 0.6571824004441252,
 13: 0.8753256053082086,
 14: 0.9075333920329774,
 15: 0.8192226911974294,
 16: 0.7413395902884095,
 17: 0.43202495874237967,
 18: 0.48441904200836944}

In [33]:
#
# need to calculate the rank agreement between the actual cluster and the top ranked predicted cluster
#


def correct_ranking(cluster_no, ranked_dict) -> int:
    """
    Report whether the first key of ``ranked_dict`` is ``cluster_no``.

    Returns 1 when the first key equals ``cluster_no`` and 0 otherwise.
    """
    first_key = list(ranked_dict)[0]
    return 1 if first_key == cluster_no else 0


clusters = scores_df.groupby("cluster")
# NOTE(review): ranks is reset here but not used in this cell
ranks = {}
for cluster in clusters:
    # each cluster is a tuple (cluster number, cluster dataframe)
    cluster_no = cluster[0]
    cluster_df = cluster[1]
    # 1 when the first key of the row's ranked_dict is the row's own cluster, else 0
    cluster_df["ranked_correctly"] = cluster_df.apply(
        lambda row: correct_ranking(cluster_no, row["ranked_dict"]), axis=1
    )  # type:ignore

    # columns printed: cluster number, correctly ranked rows, total rows, fraction correct
    print(
        f"{cluster_no:02} {cluster_df.ranked_correctly.sum():8} {cluster_df.shape[0]:8} {cluster_df.ranked_correctly.mean():8.2f}"
    )

00      234      247     0.95
01      158      206     0.77
02       59      166     0.36
03       62      153     0.41
04       86      160     0.54
05       99      127     0.78
06       88       93     0.95
07       17       99     0.17
08       66       95     0.69
09       75       81     0.93
10       76       90     0.84
11       68       73     0.93
12        1       55     0.02
13       27       51     0.53
14       26       52     0.50
15       18       50     0.36
16        9       32     0.28
17        0       30     0.00
18        0       14     0.00


In [61]:
#
# test chunking approach
#
# take the signature of the first test row as sample input
sig1 = test_df.signature.iloc[0]
# sig1 is a str, so this length is a character count, not a number of gene names
print(len(sig1))
print(type(sig1))

1212
<class 'str'>


In [70]:
def chunk_sig(sig: str, chunk_size: int) -> list[str]:
    """
    Split a signature into overlapping chunks of ``chunk_size`` gene names.

    Args
        sig: A single-space-separated string of gene names.
        chunk_size: The number of gene names in each chunk.
    Returns
        A list of chunk strings, each starting one gene name after the previous
        one. Chunking stops after the first chunk that reaches past the end of
        the signature; that last chunk holds whatever gene names remain.
    """
    # split the argument itself rather than a notebook-level sample variable
    items = sig.split(" ")
    chunks = []
    for start in range(len(items)):
        # slicing clamps at the end, so an over-long window yields the remainder
        chunks.append(" ".join(items[start : start + chunk_size]))
        if start + chunk_size > len(items):
            break
    return chunks


# chunk the sample signature into windows of four gene names and show the result
chunks = chunk_sig(sig1, 4)
print(chunks)

['SKAP1 HS2ST1 RPL5 EIF2AK3', 'HS2ST1 RPL5 EIF2AK3 GLS', 'RPL5 EIF2AK3 GLS LRRFIP2', 'EIF2AK3 GLS LRRFIP2 HCLS1', 'GLS LRRFIP2 HCLS1 STAG1', 'LRRFIP2 HCLS1 STAG1 ERCC8', 'HCLS1 STAG1 ERCC8 HLA-DRB1', 'STAG1 ERCC8 HLA-DRB1 UBE3C', 'ERCC8 HLA-DRB1 UBE3C RPL30', 'HLA-DRB1 UBE3C RPL30 WAPL', 'UBE3C RPL30 WAPL PICALM', 'RPL30 WAPL PICALM CLEC2D', 'WAPL PICALM CLEC2D SYNE2', 'PICALM CLEC2D SYNE2 CCL5', 'CLEC2D SYNE2 CCL5 SMCHD1', 'SYNE2 CCL5 SMCHD1 RNF138', 'CCL5 SMCHD1 RNF138 STK4', 'SMCHD1 RNF138 STK4 KDM1A', 'RNF138 STK4 KDM1A RPL11', 'STK4 KDM1A RPL11 SRSF4', 'KDM1A RPL11 SRSF4 PTP4A2', 'RPL11 SRSF4 PTP4A2 KHDRBS1', 'SRSF4 PTP4A2 KHDRBS1 RPS8', 'PTP4A2 KHDRBS1 RPS8 NASP', 'KHDRBS1 RPS8 NASP FAF1', 'RPS8 NASP FAF1 SSBP3', 'NASP FAF1 SSBP3 PRKACB', 'FAF1 SSBP3 PRKACB SLC30A7', 'SSBP3 PRKACB SLC30A7 PRPF38B', 'PRKACB SLC30A7 PRPF38B CD53', 'SLC30A7 PRPF38B CD53 CD2', 'PRPF38B CD53 CD2 RCOR3', 'CD53 CD2 RCOR3 ANGEL2', 'CD2 RCOR3 ANGEL2 LPIN1', 'RCOR3 ANGEL2 LPIN1 EML4', 'ANGEL2 LPIN1 EML4 PS