# Rank Determination

In [None]:
import os
import sys
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import argparse
import rich
import gseapy as gp
import ipywidgets as widgets


from IPython.display import display
from rich.panel import Panel
from gseapy import gseaplot
from statannotations.Annotator import Annotator
from pathlib import Path
from scBTF import SingleCellTensor, SingleCellBTF, FactorizationSet, Factorization

%config InlineBackend.figure_formats = ['png2x']
sc.logging.print_header()
sc.settings.njobs = 32

## 1. Load Config and Data

In [4]:
CONFIG_FILE = '.config_ipynb'

# Emulate a command line inside the notebook: when the config file exists, its
# whitespace-separated tokens become sys.argv; otherwise only a placeholder
# program name is installed, so every option below parses to None.
if not os.path.isfile(CONFIG_FILE):
    sys.argv = ['stub.py']
else:
    with open(CONFIG_FILE) as f:
        sys.argv = f.read().split()

parser = argparse.ArgumentParser()
parser.add_argument(
    "--adata_path",
    help="path to adata object, expects log normalized X",
)
parser.add_argument(
    "--consensus_factorization_path",
    help="path to scBTF consensus factorization result, .pkl format",
)
parser.add_argument(
    "--full_factorization_path",
    help="path to scBTF full factorization result, .pkl format",
)

# parse_args() reads sys.argv[1:], i.e. the tokens installed above.
args = parser.parse_args()
rich.print(args)

# Publish each parsed option as a notebook-level variable of the same name
# (adata_path, consensus_factorization_path, full_factorization_path).
globals().update(vars(args))

In [5]:
# Load both factorization sets from the paths supplied by the config cell,
# consensus first and full second (same order as before).
consensus_factorization, full_factorization = (
    FactorizationSet.load(pkl_path)
    for pkl_path in (consensus_factorization_path, full_factorization_path)
)
# Last expression of the cell: shows the full factorization set's repr.
full_factorization



## 2. Explained Variance across restarts for each rank

Explained variance is given by
$$
1 - \frac{\lVert T - T' \rVert^2}{\lVert T \rVert^2}
$$
where $T$ is the target tensor and $T'$ is the tensor reconstructed from the inferred factor matrices.


In [37]:
# One tab per rank; each tab plots the variance explained by every restart
# of the full factorization at that rank.
ranks = sorted(full_factorization.get_ranks())

children = []
for rank in ranks:
    out = widgets.Output()
    # Anything displayed inside this context is captured by the Output widget,
    # so the figure lands in this rank's tab.
    with out:
        fig, ax = plt.subplots(figsize=(5, 2))
        n_restarts = len(full_factorization.factorizations[rank])
        var_explained = [
            full_factorization.variance_explained(rank=rank, restart_index=i)
            for i in range(n_restarts)
        ]
        ax.plot(var_explained, 'g')
        ax.set_xlabel('Restart')
        ax.set_ylabel("Variance Explained")
        # Leave a small margin under the worst restart; cap the axis at 1.
        ax.set_ylim(min(var_explained) - 0.05, 1)
        # plt.show() takes no figure argument; it renders the open figure.
        plt.show()
    children.append(out)

tab = widgets.Tab(children=children)
# Title the tabs from the same rank list that produced them. The original
# read an undefined `reconstructed_all` here, which raised NameError.
for index, rank in enumerate(ranks):
    tab.set_title(index, f'Rank {rank}')
display(tab)

Tab(children=(Output(), Output(), Output()), _titles={'0': 'Rank 16', '1': 'Rank 20', '2': 'Rank 24'})

## 3. Fit and stability metrics across ranks

A critical parameter in tensor factorization algorithms is the factorization rank R which determines the number of factors used to approximate the target tensor. Choosing an optimal rank involves using both objective metrics about the quality of the factorization at a given rank and a more subjective assessment of the quality and resolution of the factors retrieved.

We utilize 3 objective metrics of factorization quality:

	1. Explained Variance
	2. Consensus matrix cophenetic correlation
	3. Silhouette Score


In [None]:
# Draw the rank-metrics plot for the full factorization set and keep the
# returned object in `fig`.
# NOTE(review): presumably this covers the three metrics listed in the section
# text (explained variance, cophenetic correlation, silhouette) — confirm
# against scBTF.
fig = full_factorization.rank_metrics_plot()

## 4. Summary plot of factors in each rank

The first panel shows the factor loadings for samples and the second panel shows the loadings for cell types.
The third panel shows the gene program extracted from the factor's gene loadings.


In [40]:
# One tab per rank; each tab holds the component summary plot of the
# consensus factorization at that rank.
ranks = sorted(consensus_factorization.get_ranks())

children = []
for rank in ranks:
    out = widgets.Output()
    # Capture the rendered figure inside this rank's Output widget.
    with out:
        # The original called this on an undefined `reconstructed_all`
        # (NameError); the loop and the tab titles both iterate the consensus
        # factorization, so that is the object plotted here.
        # NOTE(review): the keyword `plot_erichment_terms` is kept exactly as
        # written — confirm the spelling against the scBTF signature.
        fig = consensus_factorization.plot_components(
            rank=rank, restart_index=0, threshold=0.7, entropy=1, eps=0, sort_by='other',
            plot_erichment_terms=True, normalize_gene_factors=True, title=False
        )
        # plt.show() takes no figure argument; it renders the open figure.
        plt.show()
    children.append(out)

tab = widgets.Tab(children=children)
for index, rank in enumerate(ranks):
    tab.set_title(index, f'Rank {rank}')
display(tab)

Tab(children=(Output(), Output(), Output()), _titles={'0': 'Rank 16', '1': 'Rank 20', '2': 'Rank 24'})

## 5. Detailed Factor Analysis for each rank 

In [None]:
ARGS = 'stub.py --adata_path {} --consensus_factorization_path {} --rank {} --factor {}'
CONFIG_FILENAME = '.config_ipynb'
children = []
for rank in sorted(list(consensus_factorization.get_ranks())):
    children_ac = []
    for factor in range(2):
        out = widgets.Output()
        with out:
            with open(CONFIG_FILENAME,'w') as f:
                f.write(ARGS.format(adata_path, consensus_factorization_path, rank, factor))
            %run factor_analysis_template_small.ipynb
        children_ac.append(out)
    accordion = widgets.Accordion(children=children_ac)
    for factor in range(rank):
        accordion.set_title(factor, f'Factor {factor}')
    children.append(accordion)
tab = widgets.Tab(children = children)
for index, rank in enumerate(sorted(list(consensus_factorization.get_ranks()))):
    tab.set_title(index, f'Rank {rank}')
display(tab)

In [None]:
for factor in range(24):
    CONFIG_FILENAME = '.config_ipynb'

    with open(CONFIG_FILENAME,'w') as f:
        f.write(st.format(adata_path, consensus_factorization_path, rank, factor))
    %run factor_analysis_template.ipynb