In [1]:
import warnings
warnings.filterwarnings("ignore")
import functools
import os, sys

import graph_tool.all as gt
import numpy as np
import pandas as pd
import cloudpickle as pickle
import scanpy as sc
import anndata as ad
import muon as mu

from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from scipy import sparse
from numba import njit

In [15]:
"""
Inherit hSBM code from https://github.com/martingerlach/hSBM_Topicmodel
"""

class bionsbm():
    """
    Class to run bionsbm
    """
    def __init__(self, obj, label=None, max_depth=6):
        """
        Build the document/feature graph from a MuData or AnnData object.

        Observations become document nodes and variables become feature nodes
        (each modality is transposed with ``to_df().T`` before the graph is built).

        :param obj: ``mu.MuData`` (first modality -> "words", every further
            modality -> one "keyword" branch) or ``ad.AnnData`` (single layer,
            registered under the modality name ``"Mod1"``). For any other type no
            graph is built.
        :param label: optional name of an ``.obs`` column annotating the documents.
            When given, ``self.node_type`` becomes an ``int`` vertex property of
            ``self.g``: documents get the rank of their label among the sorted
            distinct labels, and the feature nodes of modality ``i`` get
            ``max(document type) + 1 + i``.
        :param max_depth: upper bound on the number of hierarchy levels written by
            ``save_data``.
        :raises TypeError: if ``label`` is given but ``obj`` is neither MuData nor AnnData.
        """
        super().__init__()
        self.keywords = []
        self.nbranches = 1
        self.modalities = []
        self.max_depth = max_depth
        self.obj = obj

        # `metadata` is the per-document annotation table used by the label branch.
        metadata = None
        if isinstance(obj, mu.MuData):
            self.modalities = list(obj.mod.keys())
            dfs = [obj[key].to_df().T for key in self.modalities]
            self.make_graph_multiple_df(dfs[0], dfs[1:])
            metadata = obj[self.modalities[0]].obs
        elif isinstance(obj, ad.AnnData):
            self.modalities = ["Mod1"]
            self.make_graph_multiple_df(obj.to_df().T, [])
            # An AnnData has no modality sub-objects: its annotations live on obj.obs.
            metadata = obj.obs

        if label:
            if metadata is None:
                raise TypeError("`label` requires `obj` to be a MuData or AnnData object")
            print("Label found")
            categories = sorted(set(metadata[label]))
            code_of = {category: code for code, category in enumerate(categories)}
            docs_type = [code_of[metadata.loc[doc, label]] for doc in self.documents]

            # Vertices are laid out as [documents | words | keyword branches ...] and
            # carry kind 0 / 1 / 2.., so the type of a feature node is a shift of its
            # kind. This yields exactly one entry per vertex of the graph.
            first_feature_type = int(np.max(docs_type)) + 1
            kinds = np.asarray(self.g.vp["kind"].a)
            feature_type = [int(k) - 1 + first_feature_type for k in kinds[len(docs_type):]]

            # Create the property on self.g itself so it belongs to the graph that is fitted.
            node_type = self.g.new_vertex_property("int", docs_type + feature_type)
        else:
            node_type = None
        self.node_type = node_type

    def save_graph(self, filename="graph.xml.gz")->None:
        """
        Save the graph

        :param filename: name of the graph stored
        """
        self.g.save(filename)

    def dump_model(self, filename="bionsbm.pkl"):
        """
        Serialise this whole model object with ``pickle`` into ``filename``.

        :param filename: path of the binary file to (over)write
        """
        with open(filename, 'wb') as handle:
            pickle.dump(self, handle)
        
    def make_graph_multiple_df(self, df: pd.DataFrame, df_keyword_list: list)->None:
        """
        Create a graph from two dataframes one with words, others with keywords or other layers of information

        :param df: DataFrame with words on index and texts on columns
        :param df_keyword_list: list of DataFrames with keywords on index and texts on columns
        """
        df_all = df.copy(deep =True)
        for ikey,df_keyword in enumerate(df_keyword_list):
            df_keyword = df_keyword.reindex(columns=df.columns)
            df_keyword.index = ["".join(["#" for _ in range(ikey+1)])+str(keyword) for keyword in df_keyword.index]
            df_keyword["kind"] = ikey+2
            df_all = pd.concat((df_all,df_keyword), axis=0)

        def get_kind(word):
            return 1 if word in df.index else df_all.at[word,"kind"]

        self.nbranches = len(df_keyword_list)
       
        self.make_graph(df_all.drop("kind", axis=1, errors='ignore'), get_kind)


    def make_graph(self, df: pd.DataFrame, get_kind):
        """
        Build the undirected graph ``self.g`` from a feature-by-document count table.

        Vertex layout: the documents (columns of ``df``) come first with kind 0,
        followed by the features (rows of ``df``) in index order, whose kind is
        given by ``get_kind``. An edge document--feature is created for every entry
        whose integer-truncated value is positive; that value is stored in the
        internal edge property ``"count"``. Names and kinds are stored in the
        internal vertex properties ``"name"`` and ``"kind"``.

        Also sets ``self.documents``, ``self.words`` (kind 1) and ``self.keywords``
        (one index per branch, kinds 2, 3, ...). ``self.keywords`` is rebuilt on
        every call instead of being appended to, so rebuilding the graph does not
        duplicate branches.

        :param df: DataFrame with features on index and documents on columns
        :param get_kind: callable mapping a row label of ``df`` to its integer kind
        :raises ValueError: if no entry of ``df`` yields a positive edge weight
        """
        self.g = gt.Graph(directed=False)

        n_words, n_docs = df.shape

        # Add all vertices first
        self.g.add_vertex(n_docs + n_words)

        # Create vertex properties
        name = self.g.new_vp("string")
        kind = self.g.new_vp("int")
        self.g.vp["name"] = name
        self.g.vp["kind"] = kind

        # Names have to be assigned one by one (string property); kinds in bulk.
        for i, doc in enumerate(df.columns):
            name[self.g.vertex(i)] = doc
        for j, word in enumerate(df.index):
            name[self.g.vertex(n_docs + j)] = word

        word_kinds = np.array([get_kind(w) for w in df.index], dtype=int)
        kind.get_array()[:n_docs] = 0
        kind.get_array()[n_docs:] = word_kinds

        # Edge weights
        weight = self.g.new_ep("int")
        self.g.ep["count"] = weight

        # Sparse edge table: (document vertex, feature vertex, count).
        values = df.values
        rows, cols = values.nonzero()
        counts = values[rows, cols].astype(int)

        # Entries that truncate to a non-positive integer never become edges;
        # dropping them here replaces the former insert-then-purge pass.
        keep = counts > 0
        if not keep.any():
            raise ValueError("Empty graph")

        edge_table = np.column_stack((cols[keep], n_docs + rows[keep], counts[keep]))
        self.g.add_edge_list(edge_table, eprops=[weight])

        self.documents = df.columns
        self.words = df.index[word_kinds == 1]
        self.keywords = [df.index[word_kinds == ik] for ik in range(2, 2 + self.nbranches)]



    def fit(self, n_init=1, verbose=True, deg_corr=True, overlap=False, parallel=False, B_min=0, B_max=None, clabel=None, *args, **kwargs) -> None:
        """
        Fit using minimize_nested_blockmodel_dl
        
        :param n_init: number of initialisation. The best will be kept
        :param verbose: Print output
        :param deg_corr: use deg corrected model
        :param overlap: use overlapping model
        :param parallel: perform parallel moves
        :param  \*args: positional arguments to pass to gt.minimize_nested_blockmodel_dl
        :param  \*\*kwargs: keywords arguments to pass to gt.minimize_nested_blockmodel_dl
        """
        if clabel == None:
            clabel = self.g.vp['kind']
            state_args = {'clabel': clabel, 'pclabel': clabel}
        else:
            print(f"Clabel is {clabel}, assigning partitions to vertices", flush=True)
            state_args = {'clabel': clabel, 'pclabel': clabel}
    
        state_args["eweight"] = self.g.ep.count
        min_entropy = np.inf
        best_state = None
        state_args["deg_corr"] = deg_corr
        state_args["overlap"] = overlap

        if B_max is None:
            B_max = self.g.num_vertices()
            
        multilevel_mcmc_args={"B_min": B_min, "B_max": B_max, "verbose": verbose,"parallel" : parallel}

        print("multilevel_mcmc_args is \n", multilevel_mcmc_args, flush=True)
        print("state_args is \n", state_args, flush=True)

        for _ in range(n_init):
            print("Fit number:", _, flush=True)
            state = gt.minimize_nested_blockmodel_dl(self.g, state_args=state_args, multilevel_mcmc_args=multilevel_mcmc_args, *args, **kwargs)
            
            entropy = state.entropy()
            if entropy < min_entropy:
                min_entropy = entropy
                self.state = state
                
        self.mdl = min_entropy

        L = len(self.state.levels)
        self.L = L
        self.groups = {}


    def get_mdl(self):
        """
        Get minimum description length

        Proxy to self.state.entropy()
        """
        return self.mdl
            
    def _get_shape(self):
        """
        :return: list of tuples (number of documents, number of words, (number of keywords,...))
        """
        D = int(np.sum(self.g.vp['kind'].a == 0)) #documents
        W = int(np.sum(self.g.vp['kind'].a == 1)) #words
        K = [int(np.sum(self.g.vp['kind'].a == (k+2))) for k in range(self.nbranches)] #keywords
        return D, W, K

    # Helper functions      

    def get_groups(self, l=0):

    # --- Numba function for edge processing with list of arrays ---
        @njit
        def process_edges_numba_list(sources, targets, z1, z2, kinds, weights,
                                     D, W, K_arr, nbranches,
                                     n_db, n_wb, n_dbw, n_w_key_b_list, n_dbw_key_list):

            for i in range(len(sources)):
                v1 = sources[i]
                v2 = targets[i]
                w = weights[i]
                t1 = z1[i]
                t2 = z2[i]
                kind = kinds[i]

                n_db[v1, t1] += w

                if kind == 1:
                    n_wb[v2 - D, t2] += w
                    n_dbw[v1, t2] += w
                else:
                    ik = kind - 2
                    offset = D + W
                    for j in range(ik):
                        offset += K_arr[j]
                    n_w_key_b_list[ik][v2 - offset, t2] += w
                    n_dbw_key_list[ik][v1, t2] += w


        if l in self.groups:
            return self.groups[l]

        state_l = self.state.project_level(l).copy(overlap=True)
        state_l_edges = state_l.get_edge_blocks()
        B = state_l.get_B()
        D, W, K = self._get_shape()
        nbranches = self.nbranches

        # Preallocate arrays
        n_wb = np.zeros((W, B))
        n_db = np.zeros((D, B))
        n_dbw = np.zeros((D, B))

        # For branches, use list of arrays (one per branch) to avoid broadcasting issues
        n_w_key_b = [np.zeros((K[ik], B)) for ik in range(nbranches)]
        n_dbw_key = [np.zeros((D, B)) for _ in range(nbranches)]

        # Convert graph edges to arrays
        edges = list(self.g.edges())
        sources = np.array([e.source() for e in edges], dtype=np.int64)
        targets = np.array([e.target() for e in edges], dtype=np.int64)
        weights = np.array([self.g.ep["count"][e] for e in edges], dtype=np.float64)
        z1_arr = np.array([state_l_edges[e][0] for e in edges], dtype=np.int64)
        z2_arr = np.array([state_l_edges[e][1] for e in edges], dtype=np.int64)
        kinds = np.array([self.g.vp['kind'][v] for v in targets], dtype=np.int64)

        # --- Edge processing (Numba-accelerated) ---
        process_edges_numba_list(sources, targets, z1_arr, z2_arr, kinds, weights, D, W, K, nbranches, n_db, n_wb, n_dbw, n_w_key_b, n_dbw_key)

        # --- Keep only nonzero columns safely ---
        ind_d = np.where(np.sum(n_db, axis=0) > 0)[0]
        n_db = n_db[:, ind_d]
        Bd = len(ind_d)

        ind_w = np.where(np.sum(n_wb, axis=0) > 0)[0]
        n_wb = n_wb[:, ind_w]
        Bw = len(ind_w)

        ind_w2 = np.where(np.sum(n_dbw, axis=0) > 0)[0]
        n_dbw = n_dbw[:, ind_w2]

        Bk = []
        for ik in range(nbranches):
            ind_wk = np.where(np.sum(n_w_key_b[ik], axis=0) > 0)[0]
            n_w_key_b[ik] = n_w_key_b[ik][:, ind_wk].copy()
            Bk.append(len(ind_wk))

            ind_w2k = np.where(np.sum(n_dbw_key[ik], axis=0) > 0)[0]
            n_dbw_key[ik] = n_dbw_key[ik][:, ind_w2k].copy()

        # --- Compute probabilities ---
        p_tw_w = (n_wb / np.nansum(n_wb, axis=1)[:, None]).T
        p_tk_w_key = [(n_w_key_b[ik] / np.nansum(n_w_key_b[ik], axis=1)[:, None]).T
                      for ik in range(nbranches)]
        p_w_tw = n_wb / np.nansum(n_wb, axis=0)[None, :]
        p_w_key_tk = [n_w_key_b[ik] / np.nansum(n_w_key_b[ik], axis=0)[None, :]
                      for ik in range(nbranches)]
        p_tw_d = (n_dbw / np.nansum(n_dbw, axis=1)[:, None]).T
        p_tk_d = [(n_dbw_key[ik] / np.nansum(n_dbw_key[ik], axis=1)[:, None]).T
                  for ik in range(nbranches)]
        p_td_d = (n_db / np.nansum(n_db, axis=1)[:, None]).T

        result = {    'Bd': Bd, 'Bw': Bw, 'Bk': Bk,
                    'p_tw_w': p_tw_w,
                    'p_tk_w_key': p_tk_w_key,
                    'p_td_d': p_td_d,
                    'p_w_tw': p_w_tw,
                    'p_w_key_tk': p_w_key_tk,
                    'p_tw_d': p_tw_d,
                    'p_tk_d': p_tk_d}

        self.groups[l] = result
        return result

    def get_groups_fast_numba_bipartite_safe(self, l=0):
        """
        Numba-accelerated get_groups that is robust for bipartite graphs (nbranches == 0)
        and for arbitrary number of partitions.

        Edge weights are accumulated per (node, block) pair into count tables, empty
        block columns are dropped, and the tables are normalised into the same
        dictionary layout that ``get_groups`` returns (keys ``Bd``, ``Bw``, ``Bk``,
        ``p_tw_w``, ``p_tk_w_key``, ``p_td_d``, ``p_w_tw``, ``p_w_key_tk``,
        ``p_tw_d``, ``p_tk_d``). A zero denominator yields NaN entries.

        The result is written to ``self.groups[l]``, but this method never reads
        that cache: every call recomputes the level.

        :param l: level of the hierarchy
        :return: dict of group counts and probability matrices for level ``l``
        """
    
        # NOTE(review): the jitted helper is re-created on each call, so numba
        # presumably compiles it again every time - confirm and consider hoisting it.
        @njit
        def process_edges_numba_stack(sources, targets, z1, z2, kinds, weights,
                                      D, W, K_arr, nbranches,
                                      n_db, n_wb, n_dbw, n_w_key_b3, n_dbw_key3):
            """
            Numba-compiled loop that increments the stacked accumulator arrays.
            This function is defensive: if a 'kind' references a branch index out of range,
            or an index into keywords is out of range, it's ignored (so bipartite graphs keep working).
            """
            m = len(sources)
            for i in range(m):
                v1 = sources[i]
                v2 = targets[i]
                w = weights[i]
                t1 = z1[i]
                t2 = z2[i]
                kind = kinds[i]
        
                # update doc-group counts (always)
                n_db[v1, t1] += w
        
                if kind == 1:
                    # word node: its row in n_wb is the vertex index minus the D documents
                    idx_w = v2 - D
                    if idx_w >= 0 and idx_w < n_wb.shape[0]:
                        n_wb[idx_w, t2] += w
                    # update doc->word-group
                    n_dbw[v1, t2] += w
        
                elif kind >= 2:
                    ik = kind - 2
                    # guard: only process if ik is a valid branch index
                    if ik >= 0 and ik < nbranches:
                        # compute offset = D + W + sum(K_arr[:ik])
                        offset = D + W
                        for j in range(ik):
                            offset += K_arr[j]
                        idx_k = v2 - offset
                        # guard keyword index bounds
                        if idx_k >= 0 and idx_k < K_arr[ik]:
                            n_w_key_b3[ik, idx_k, t2] += w
                            n_dbw_key3[ik, v1, t2] += w
                        # else: out-of-range keyword index -> ignore to remain robust
                else:
                    # unexpected kind (<1): ignore for safety (original assumed only kind==1 or >=2)
                    pass
            
        # NOTE(review): a cache lookup (`if l in self.groups: return self.groups[l]`)
        # used to sit here as commented-out code; it is disabled, unlike in get_groups.
        
        state_l = self.state.project_level(l).copy(overlap=True)
        state_l_edges = state_l.get_edge_blocks()
        B = state_l.get_B()
        D, W, K = self._get_shape()
        nbranches = self.nbranches
    
        # Preallocate primary arrays (word/doc)
        n_wb = np.zeros((W, B), dtype=np.float64)    # words x word-groups
        n_db = np.zeros((D, B), dtype=np.float64)    # docs  x doc-groups
        n_dbw = np.zeros((D, B), dtype=np.float64)   # docs  x word-groups
    
        # Preallocate stacked branch arrays (shape: nbranches x max_K x B) and (nbranches x D x B).
        # A 3-D array (instead of a list of 2-D arrays) is what is passed to the jitted helper.
        if nbranches > 0:
            max_K = int(np.max(K))
            # If some K are zero, max_K will still be >=0; stack is safe
            n_w_key_b3 = np.zeros((nbranches, max_K, B), dtype=np.float64)
            n_dbw_key3 = np.zeros((nbranches, D, B), dtype=np.float64)
        else:
            # empty stacked arrays if no branches
            n_w_key_b3 = np.zeros((0, 0, B), dtype=np.float64)
            n_dbw_key3 = np.zeros((0, D, B), dtype=np.float64)
    
        # Convert graph edges to arrays
        edges = list(self.g.edges())
        m = len(edges)
        sources = np.empty(m, dtype=np.int64)
        targets = np.empty(m, dtype=np.int64)
        z1_arr = np.empty(m, dtype=np.int64)
        z2_arr = np.empty(m, dtype=np.int64)
        weights = np.empty(m, dtype=np.float64)
        kinds = np.empty(m, dtype=np.int64)
    
        # z1/z2 are the two block labels stored for each edge (entries 0 and 1);
        # `kinds` holds the kind of the edge's target vertex.
        for i, e in enumerate(edges):
            sources[i] = int(e.source())
            targets[i] = int(e.target())
            z1_arr[i] = int(state_l_edges[e][0])
            z2_arr[i] = int(state_l_edges[e][1])
            weights[i] = float(self.g.ep["count"][e])
            kinds[i] = int(self.g.vp['kind'][int(e.target())])
    
        K_arr = np.array(K, dtype=np.int64)  # can be empty if nbranches==0
    
        # --- Numba edge processing (single compiled function for all cases) ---
        process_edges_numba_stack(
            sources, targets, z1_arr, z2_arr, kinds, weights,
            D, W, K_arr, nbranches,
            n_db, n_wb, n_dbw, n_w_key_b3, n_dbw_key3
        )
    
        # --- Trim empty columns for doc/word arrays (same logic as original) ---
        ind_d = np.where(np.sum(n_db, axis=0) > 0)[0]
        n_db = n_db[:, ind_d]
        Bd = len(ind_d)
    
        ind_w = np.where(np.sum(n_wb, axis=0) > 0)[0]
        n_wb = n_wb[:, ind_w]
        Bw = len(ind_w)
    
        ind_w2 = np.where(np.sum(n_dbw, axis=0) > 0)[0]
        n_dbw = n_dbw[:, ind_w2]
    
        # --- Convert stacked branch arrays into per-branch lists (safe slicing) ---
        n_w_key_b_list = []
        n_dbw_key_list = []
        Bk = []
    
        for ik in range(nbranches):
            Kk = int(K_arr[ik])
            if Kk > 0:
                # compute which columns (groups) are non-zero; only the first Kk rows
                # of the stacked slab belong to this branch (the rest is padding)
                col_sums = np.sum(n_w_key_b3[ik, :Kk, :], axis=0)
                ind_wk = np.where(col_sums > 0)[0]
                # slice and copy into a per-branch array (Kk x Bk)
                if ind_wk.size > 0:
                    n_w_key_b_list.append(n_w_key_b3[ik, :Kk, :][:, ind_wk].copy())
                else:
                    # keep shape (Kk, 0) if there are no columns
                    n_w_key_b_list.append(np.zeros((Kk, 0), dtype=np.float64))
                Bk.append(len(ind_wk))
            else:
                # branch with 0 keywords
                n_w_key_b_list.append(np.zeros((0, 0), dtype=np.float64))
                Bk.append(0)
    
            # doc x keyword-groups for this branch
            col_sums_dbw = np.sum(n_dbw_key3[ik], axis=0)
            ind_w2k = np.where(col_sums_dbw > 0)[0]
            if ind_w2k.size > 0:
                n_dbw_key_list.append(n_dbw_key3[ik][:, ind_w2k].copy())
            else:
                n_dbw_key_list.append(np.zeros((D, 0), dtype=np.float64))
    
        # --- Compute probabilities exactly like the original (division -> NaN if denominator==0) ---
        # P(t_w | w)
        denom = np.sum(n_wb, axis=1, keepdims=True)  # (W,1)
        p_tw_w = (n_wb / denom).T
    
        # P(t_k | keyword) per branch
        p_tk_w_key = []
        for ik in range(nbranches):
            arr = n_w_key_b_list[ik]
            denom = np.sum(arr, axis=1, keepdims=True)
            p_tk_w_key.append((arr / denom).T)
    
        # P(w | t_w)
        denom = np.sum(n_wb, axis=0, keepdims=True)  # (1,Bw)
        p_w_tw = n_wb / denom
    
        # P(keyword | t_w_key) per branch
        p_w_key_tk = []
        for ik in range(nbranches):
            arr = n_w_key_b_list[ik]
            denom = np.sum(arr, axis=0, keepdims=True)
            p_w_key_tk.append(arr / denom)
    
        # P(t_w | d)
        denom = np.sum(n_dbw, axis=1, keepdims=True)
        p_tw_d = (n_dbw / denom).T
    
        # P(t_k | d) per branch
        p_tk_d = []
        for ik in range(nbranches):
            arr = n_dbw_key_list[ik]
            denom = np.sum(arr, axis=1, keepdims=True)
            p_tk_d.append((arr / denom).T)
    
        # P(t_d | d)
        denom = np.sum(n_db, axis=1, keepdims=True)
        p_td_d = (n_db / denom).T
    
        result = {
            'Bd': Bd,
            'Bw': Bw,
            'Bk': Bk,
            'p_tw_w': p_tw_w,
            'p_tk_w_key': p_tk_w_key,
            'p_td_d': p_td_d,
            'p_w_tw': p_w_tw,
            'p_w_key_tk': p_w_key_tk,
            'p_tw_d': p_tw_d,
            'p_tk_d': p_tk_d,
        }
    
        # Stored for later inspection; this method itself never reads it back.
        self.groups[l] = result
        return result


  
    def save_single_level(self, l: int, name: str) -> None:
        """
        Save per-level probability matrices (topics, clusters, documents) for the given level.

        Parameters
        ----------
        l : int
            The level index to save. Must be within the range of available model levels.
        name : str
            Base path (folder + prefix) where files will be written.
            Example: "results/mymodel" → files like:
                - results/mymodel_level_0_mainfeature_topics.tsv.gz
                - results/mymodel_level_0_clusters.tsv.gz
                - results/mymodel_level_0_mainfeature_topics_documents.tsv.gz
                - results/mymodel_level_0_metafeature_topics.tsv.gz
                - results/mymodel_level_0_metafeature_topics_documents.tsv.gz

        Notes
        -----
        - Files are written as tab-separated values (`.tsv.gz`) with gzip compression.
        - Handles both the main feature (`self.modalities[0]`) and any meta-features (`self.modalities[1:]`).
        - Raises RuntimeError if any file cannot be written.
        """

        # --- Validate inputs ---
        if not isinstance(l, int) or l < 0 or l >= len(self.state.levels):
            raise ValueError(f"Invalid level index {l}. Must be between 0 and {len(self.state.levels) - 1}.")
        if not isinstance(name, str) or not name.strip():
            raise ValueError("`name` must be a non-empty string path prefix.")

        main_feature = self.modalities[0]

        try:
            data = self.get_groups_fast_numba_bipartite_safe(l)
        except Exception as e:
            raise RuntimeError(f"Failed to get group data for level {l}: {e}") from e

        # Helper to safely save a DataFrame
        def _safe_save(df, filepath):
            try:
                Path(filepath).parent.mkdir(parents=True, exist_ok=True)
                df.to_csv(filepath, compression="gzip", sep="\t")
            except Exception as e:
                raise RuntimeError(f"Failed to save {filepath}: {e}") from e

        # Helper to safely save a DataFrame
        def _safe_save(df, filepath):
            try:
                Path(filepath).parent.mkdir(parents=True, exist_ok=True)
                df.to_csv(filepath, compression="gzip", sep="\t")
            except Exception as e:
                raise RuntimeError(f"Failed to save {filepath}: {e}") from e

        # --- P(document | cluster) ---
        clusters = pd.DataFrame(data=data["p_td_d"], columns=self.documents)
        _safe_save(clusters, f"{name}_level_{l}_clusters.tsv.gz")


        # --- P(main_feature | main_topic) ---
        p_w_tw = pd.DataFrame(data=data["p_w_tw"], index=self.words,
            columns=[f"{main_feature}_topic_{i}" for i in range(data["p_w_tw"].shape[1])])
        _safe_save(p_w_tw, f"{name}_level_{l}_{main_feature}_topics.tsv.gz")

        # --- P(main_topic | documents) ---
        p_tw_d = pd.DataFrame(data=data["p_tw_d"].T,index=self.documents,
            columns=[f"{main_feature}_topic_{i}" for i in range(data["p_w_tw"].shape[1])])
        _safe_save(p_tw_d, f"{name}_level_{l}_{main_feature}_topics_documents.tsv.gz")

        # --- P(meta_feature | meta_topic_feature), if any ---
        if len(self.modalities) > 1:
            for k, meta_features in enumerate(self.modalities[1:]):
                feat_topic = pd.DataFrame(data=data["p_w_key_tk"][k], index=self.keywords[k],
                    columns=[f"{meta_features}_topic_{i}" for i in range(data["p_w_key_tk"][k].shape[1])])
                _safe_save(feat_topic, f"{name}_level_{l}_{meta_features}_topics.tsv.gz")


            # --- P(meta_topic | document) ---
            for k, meta_features in enumerate(self.modalities[1:]):
                p_tk_d = pd.DataFrame(data=data["p_tk_d"][k].T, index=self.documents,
                    columns=[f"{meta_features}_topics_{i}" for i in range(data["p_w_key_tk"][k].shape[1])])
                _safe_save(p_tk_d, f"{name}_level_{l}_{meta_features}_topics_documents.tsv.gz")



    def save_data(self, name: str = "results/mymodel") -> None:
        """
        Save the global graph, model, state, and level-specific data for the current nSBM self.

        Parameters
        ----------
        name : str, optional
            Base path (folder + prefix) where all outputs will be saved.
            Example: "results/mymodel" will produce:
                - results/mymodel_graph.xml.gz
                - results/mymodel_model.pkl    
                - results/mymodel_entropy.txt
                - results/mymodel_state.pkl
                - results/mymodel_level_X_*.tsv.gz  (per level, up to 6 levels)

        Notes
        -----
        - The parent folder is created automatically if it does not exist.
        - Level saving is parallelized with threads for efficiency in I/O.
        - By default, at most 6 levels are saved, or fewer if the model has <6 levels.
        - Exceptions in parallel tasks are caught and reported without stopping other tasks.
        """

        # --- Validate name ---
        if not isinstance(name, str) or not name.strip():
            raise ValueError("`name` must be a non-empty string representing the save path.")

        # --- Ensure folder exists ---
        folder = os.path.dirname(name)
        if folder:
            Path(folder).mkdir(parents=True, exist_ok=True)

        # --- Save global files ---
        try:
            self.save_graph(filename=f"{name}_graph.xml.gz")
            self.dump_model(filename=f"{name}_model.pkl")

            with open(f"{name}_entropy.txt", "w") as f:
                f.write(str(self.state.entropy()))

            with open(f"{name}_state.pkl", "wb") as f:
                pickle.dump(self.state, f)

        except Exception as e:
            raise RuntimeError(f"Failed to save global files for model '{name}': {e}") from e


        # --- Save levels in parallel (threaded to avoid data duplication) ---
        L = min(len(self.state.levels), self.max_depth)
        if L == 0:
            print("Nothing to save")
            return  # nothing to save

        errors = []
        with ThreadPoolExecutor() as executor:
            futures = {executor.submit(self.save_single_level, l, name): l for l in range(L)}
            for future in as_completed(futures):
                l = futures[future]
                try:
                    future.result()
                except Exception as e:
                    errors.append((l, str(e)))

        if errors:
            msg = "; ".join([f"Level {l}: {err}" for l, err in errors])
            raise RuntimeError(f"Errors occurred while saving levels: {msg}")


    def get_V(self):
        '''
        return number of word-nodes == types
        '''
        return int(np.sum(self.g.vp['kind'].a == 1))  # no. of types

    def get_D(self):
        '''
        return number of doc-nodes == number of documents
        '''
        return int(np.sum(self.g.vp['kind'].a == 0))  # no. of types

    def get_N(self):
        '''
        return number of edges == tokens
        '''
        return int(self.g.num_edges())  # no. of types

In [16]:
# Load the test dataset from an .h5mu file (relative path) and display its summary.
mdata=mu.read_h5mu("../bionsbm/Test_data.h5mu")
mdata

In [23]:
# Build the graph from `mdata` and fit once (n_init defaults to 1) without verbose output.
# NOTE(review): the name says "single" but the full `mdata` is passed - confirm
# whether a single modality was intended here.
model_single = bionsbm(mdata)
model_single.fit(verbose=False)

multilevel_mcmc_args is 
 {'B_min': 0, 'B_max': 3200, 'verbose': False, 'parallel': False}
state_args is 
 {'clabel': <VertexPropertyMap object with value type 'int32_t', for Graph 0x751310303e90, at 0x75127c7e9690>, 'pclabel': <VertexPropertyMap object with value type 'int32_t', for Graph 0x751310303e90, at 0x75127c7e9690>, 'eweight': <EdgePropertyMap object with value type 'int32_t', for Graph 0x751310303e90, at 0x7512a81f51d0>, 'deg_corr': True, 'overlap': False}
Fit number: 0


In [24]:
# Write graph, pickled model/state, entropy and per-level tables under the
# default prefix "results/mymodel".
model_single.save_data()

In [29]:
# Inspect the level-1 topic table of the `lncRNA` modality (file name follows the
# pattern written by save_single_level; presumably produced by the save_data call above).
pd.read_csv("results/mymodel_level_1_lncRNA_topics.tsv.gz", sep="\t", index_col=0)

Unnamed: 0,lncRNA_topic_0
##LINC00566,0.0
##LINC02191,0.0
##LINC00858,0.0
##EHHADH-AS1,0.0
##LINC02447,0.0
...,...
##LINC02366,0.0
##NFIA-AS2,0.0
##LINC01311,0.0
##CDIPTOSP,0.0


In [45]:
# Group statistics of the finest level (l=0) for `model_single`.
res_single = model_single.get_groups_fast_numba_bipartite_safe(l=0)
res_single

{'Bd': 5,
 'Bw': 13,
 'Bk': [],
 'p_tw_w': array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 1., 0., 1.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]),
 'p_tk_w_key': [],
 'p_td_d': array([[0., 0., 0., ..., 0., 1., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 1., 0., 0.],
        [0., 1., 1., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.]]),
 'p_w_tw': array([[0.        , 0.        , 0.        , ..., 0.        , 0.        ,
         0.        ],
        [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
         0.        ],
        [0.        , 0.0263789 , 0.        , ..., 0.        , 0.        ,
         0.        ],
        ...,
        [0.        , 0.        , 0.01289802, ..., 0.        , 0.        ,
         0.        ],
        [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
  

In [46]:
# Build the graph from the full `mdata` object and fit once without verbose output.
model_multi = bionsbm(mdata)
model_multi.fit(verbose=False)

multilevel_mcmc_args is 
 {'B_min': 0, 'B_max': 3200, 'verbose': False, 'parallel': False}
state_args is 
 {'clabel': <VertexPropertyMap object with value type 'int32_t', for Graph 0x7a05207e1050, at 0x7a0531fb2010>, 'pclabel': <VertexPropertyMap object with value type 'int32_t', for Graph 0x7a05207e1050, at 0x7a0531fb2010>, 'eweight': <EdgePropertyMap object with value type 'int32_t', for Graph 0x7a05207e1050, at 0x7a0536ecaad0>, 'deg_corr': True, 'overlap': False}
Fit number: 0


In [47]:
# Group statistics of the finest level (l=0) for `model_multi`.
res_multi = model_multi.get_groups_fast_numba_bipartite_safe(l=0)
res_multi

{'Bd': 6,
 'Bw': 14,
 'Bk': [6, 1],
 'p_tw_w': array([[0., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 1., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]),
 'p_tk_w_key': [array([[ 1.,  0., nan, ...,  1.,  1.,  0.],
         [ 0.,  0., nan, ...,  0.,  0.,  0.],
         [ 0.,  0., nan, ...,  0.,  0.,  0.],
         [ 0.,  0., nan, ...,  0.,  0.,  0.],
         [ 0.,  0., nan, ...,  0.,  0.,  0.],
         [ 0.,  1., nan, ...,  0.,  0.,  1.]]),
  array([[nan, nan, nan, nan, nan,  1.,  1., nan,  1., nan, nan, nan, nan,
          nan,  1.,  1., nan, nan, nan, nan, nan, nan, nan,  1., nan, nan,
          nan, nan, nan, nan,  1.,  1., nan, nan,  1., nan, nan, nan, nan,
          nan, nan,  1., nan, nan, nan, nan, nan, nan,  1.,  1., nan, nan,
           1., nan, nan, nan, nan, nan,  1., nan, nan, nan, nan,  1.,  1.,
          nan, nan,  1.,  1., nan, nan

In [37]:
def _assert_arrays_match(a, b, atol, label):
    """Assert two arrays agree in shape and values (NaN == NaN); print up to 20 mismatches."""
    assert a.shape == b.shape, f"Shape mismatch for {label}: {a.shape} vs {b.shape}"
    mismatch = ~np.isclose(a, b, atol=atol, equal_nan=True)
    if not np.any(mismatch):
        return

    positions = np.argwhere(mismatch)
    n_diff = len(positions)
    print(f"{label}: {n_diff} differences found")
    for position in positions[:20]:
        # Works for any number of dimensions, not only for matrices.
        position = tuple(int(p) for p in position)
        coords = ",".join(str(p) for p in position)
        print(f"{label}[{coords}]: {a[position]} vs {b[position]}")
    if n_diff > 20:
        print(f"... and {n_diff - 20} more differences")
    raise AssertionError(f"{label} arrays differ")


def compare_groups_results(r1, r2, atol=1e-12, name="root"):
    """
    Compare two get_groups-like results.

    Both dictionaries must have the same keys. Values are compared by kind:
    NumPy arrays by shape and by value within ``atol`` (NaN equals NaN), lists
    element by element (arrays as above, anything else with ``==``), and every
    other value with ``==``.

    :param r1: first result dictionary
    :param r2: second result dictionary
    :param atol: absolute tolerance used for array comparison
    :param name: prefix used in the messages to identify the compared object
    :raises AssertionError: on the first mismatch; details about differing array
        entries are printed before raising
    """
    only_in_r1 = sorted(set(r1) - set(r2), key=str)
    only_in_r2 = sorted(set(r2) - set(r1), key=str)
    assert not only_in_r1 and not only_in_r2, (
        f"Key mismatch for {name}: only in first {only_in_r1}, only in second {only_in_r2}"
    )

    for key in r1:
        val1 = r1[key]
        val2 = r2[key]
        full_name = f"{name}.{key}"

        # Check type mismatch
        assert type(val1) == type(val2), f"Type mismatch for {full_name}: {type(val1)} vs {type(val2)}"

        if isinstance(val1, list):
            assert len(val1) == len(val2), f"List length mismatch for {full_name}"
            for i, (a, b) in enumerate(zip(val1, val2)):
                elem_name = f"{full_name}[{i}]"
                if isinstance(a, np.ndarray):
                    assert isinstance(b, np.ndarray), f"Type mismatch for {elem_name}: {type(a)} vs {type(b)}"
                    _assert_arrays_match(a, b, atol, elem_name)
                else:  # scalar in list
                    assert a == b, f"Scalar mismatch in {elem_name}: {a} vs {b}"

        elif isinstance(val1, np.ndarray):
            _assert_arrays_match(val1, val2, atol, full_name)

        else:  # scalar
            assert val1 == val2, f"Scalar mismatch for {full_name}: {val1} vs {val2}"

    print("All keys match exactly (within tolerance and NaN equality).")
    
# NOTE(review): `res_orig` and `res_numba_fixed` are not defined in any visible cell;
# this call relies on leftover kernel state and will fail after Restart & Run All.
compare_groups_results(res_orig, res_numba_fixed)

All keys match exactly (within tolerance and NaN equality).


In [16]:
# NOTE(review): `model` is not defined in any visible cell (hidden kernel state).
# The recorded output below is the "empty list" numba failure of the group computation.
model.save_data()

RuntimeError: Errors occurred while saving levels: Level 1: Failed to get group data for level 1: cannot compute fingerprint of empty list; Level 4: Failed to get group data for level 4: cannot compute fingerprint of empty list; Level 5: Failed to get group data for level 5: cannot compute fingerprint of empty list; Level 3: Failed to get group data for level 3: cannot compute fingerprint of empty list; Level 2: Failed to get group data for level 2: cannot compute fingerprint of empty list; Level 0: Failed to get group data for level 0: cannot compute fingerprint of empty list

In [18]:
# Name of the first modality of `model` (the one used as main feature by save_single_level).
main_feature = model.modalities[0]
main_feature

'Mod1'

In [20]:
# Reproduces the failure seen in save_data directly through get_groups (see the output below).
data = model.get_groups(0)

ValueError: cannot compute fingerprint of empty list

In [35]:
# Build a bionsbm graph from the "Peak" modality only and display the graph object.
model_bionsbm=bionsbm(mdata["Peak"])
model_bionsbm.g

<Graph object, undirected, with 2000 vertices and 19777 edges, 2 internal vertex properties, 1 internal edge property, at 0x729ca92dde50>

In [38]:
# Build the same graph with the reference `sbmtm` implementation for comparison.
# NOTE(review): this import sits in the middle of the notebook; consider moving it
# to the import cell at the top.
from sbmtm import sbmtm
model_sbmtm=sbmtm()
model_sbmtm.make_graph_from_BoW_df(mdata["Peak"].to_df().T)
model_sbmtm.g

<Graph object, undirected, with 2000 vertices and 19777 edges, 2 internal vertex properties, 1 internal edge property, at 0x729ca972ba50>

In [42]:
def compare_models(m1, m2):
    """
    Assert that two models hold the same graph and the same document/word lists.

    Checked in this order: number of vertices, number of edges, the vertex
    properties "name" and "kind" (in vertex order), the sorted list of
    (source, target, count) edges, ``documents`` and ``words``.
    Keyword branches are not compared.

    :return: True when every check passes
    :raises AssertionError: on the first difference found
    """
    def vertex_values(graph, prop_name):
        prop = graph.vp[prop_name]
        return [prop[v] for v in graph.vertices()]

    def weighted_edges(graph):
        counts = graph.ep["count"]
        triples = [(int(e.source()), int(e.target()), int(counts[e])) for e in graph.edges()]
        triples.sort()
        return triples

    graph_a = m1.g
    graph_b = m2.g

    assert graph_a.num_vertices() == graph_b.num_vertices(), "Different number of vertices"
    assert graph_a.num_edges() == graph_b.num_edges(), "Different number of edges"

    for prop in ("name", "kind"):
        assert vertex_values(graph_a, prop) == vertex_values(graph_b, prop), f"Vertex property {prop} differs"

    assert weighted_edges(graph_a) == weighted_edges(graph_b), "Edges or weights differ"

    assert list(m1.documents) == list(m2.documents), "Documents differ"
    assert np.array_equal(m1.words, m2.words), "Words differ"

    return True

In [43]:
# Check that bionsbm and sbmtm build the same graph from the "Peak" table.
compare_models(model_bionsbm, model_sbmtm)

True