# Comparing Modlyn & Scanpy feature selection methods

In [None]:
# pip install lamindb modlyn scanpy seaborn
import lamindb as ln
import modlyn as mn
import scanpy as sc
import pandas as pd
import seaborn as sns
# Apply seaborn's default theme to all subsequent plots.
sns.set_theme()
# Ask the inline backend to render figures as SVG.
%config InlineBackend.figure_formats = ['svg']

In [None]:
# Start lamindb tracking for this notebook; the matching ln.finish() call is
# in the last cell.
ln.track()

## Prepare dataset

In [None]:
# Look up the dataset artifact by its uid in the
# "laminlabs/arrayloader-benchmarks" instance and load it into memory.
benchmark_artifacts = ln.Artifact.using("laminlabs/arrayloader-benchmarks")
artifact = benchmark_artifacts.get("JNaxQe8zbljesdbK0000")
adata = artifact.load()

# The return value is not used, so this relies on scanpy transforming
# `adata` in place.
sc.pp.log1p(adata)

# Display the loaded object.
adata

In [None]:
# Keep only cell lines that have more than 3 observations.
cell_line_counts = adata.obs["cell_line"].value_counts()
keep = cell_line_counts[cell_line_counts > 3].index

# Subset to the retained cell lines; .copy() turns the view into an
# independent object.
is_kept = adata.obs["cell_line"].isin(keep)
adata = adata[is_kept].copy()
adata

In [None]:
# Show the last rows of the per-cell-line counts (the least frequent cell
# lines, since value_counts sorts in descending order by default).
adata.obs["cell_line"].value_counts().tail()

## Train LogReg with Modlyn

In [None]:
# Options forwarded to the training dataloader.
train_loader_options = {
    "batch_size": 128,
    "drop_last": True,
    "num_workers": 4,
}

# Logistic-regression classifier with "cell_line" as the label column.
logreg = mn.models.SimpleLogReg(
    adata=adata,
    label_column="cell_line",
    learning_rate=1e-1,
    weight_decay=1e-3,
)

# NOTE(review): the validation set is the first 20 observations of the
# training data, so validation metrics are not a held-out estimate.
logreg.fit(
    adata_train=adata,
    adata_val=adata[:20],
    train_dataloader_kwargs=train_loader_options,
    max_epochs=5,
)

In [None]:
# Plot the model's losses (presumably those recorded during the fit above).
logreg.plot_losses()

In [None]:
# Classification report on `adata`, which is the same data the model was
# trained on.
logreg.plot_classification_report(adata)

## Get feature scores from different methods

In [None]:
# Weights of the fitted Modlyn model, used below as this method's feature
# scores (presumably one row per cell line and one column per feature —
# TODO confirm).
df_modlyn_logreg = logreg.get_weights()
df_modlyn_logreg.head()

In [None]:
# Rank genes per cell line with scanpy's logistic-regression method and
# store the result under the "sc_logreg" key.
sc.tl.rank_genes_groups(
    adata,
    "cell_line",
    method="logreg",
    key_added="sc_logreg",
)

# Long-format results for all groups (group=None), reshaped to a
# group x gene table of scores.
ranked_logreg = sc.get.rank_genes_groups_df(adata, group=None, key="sc_logreg")
df_scanpy_logreg = ranked_logreg.pivot(
    index="group",
    columns="names",
    values="scores",
)

# Label the table with the method that produced it.
df_scanpy_logreg.attrs["method_name"] = "scanpy_logreg"
df_scanpy_logreg.head()

In [None]:
# Rank genes per cell line with scanpy's Wilcoxon method and store the
# result under the "sc_wilcoxon" key.
sc.tl.rank_genes_groups(
    adata,
    "cell_line",
    method="wilcoxon",
    key_added="sc_wilcoxon",
)

# Long-format results for all groups (group=None), reshaped to a
# group x gene table of scores.
ranked_wilcoxon = sc.get.rank_genes_groups_df(adata, group=None, key="sc_wilcoxon")
df_scanpy_wilcoxon = ranked_wilcoxon.pivot(
    index="group",
    columns="names",
    values="scores",
)

# Label the table with the method that produced it.
df_scanpy_wilcoxon.attrs["method_name"] = "scanpy_wilcoxon"
df_scanpy_wilcoxon.head()

## Compare feature selection results

In [None]:
# Score tables from the three methods, compared at several top-N cutoffs.
score_tables = [
    df_modlyn_logreg,
    df_scanpy_logreg,
    df_scanpy_wilcoxon,
]
compare = mn.eval.CompareScoresJaccard(score_tables, n_top_values=[5, 10, 25])

In [None]:
# Heatmap view of the score tables passed to the comparison object.
compare.plot_heatmaps()

In [None]:
# Compute the Jaccard comparison, then plot it. The plot is called after the
# compute step (presumably it relies on its results — TODO confirm).
compare.compute_jaccard_comparison()
compare.plot_jaccard_comparison()

In [None]:
# Close the lamindb run that ln.track() started at the top of the notebook.
ln.finish()