# Evaluate features

This notebook demonstrate how to evaluate the features stored in the anndata.obsm.
The task we are interested in is to predict the gene expression based on the cell_type label and the covariates. 

We evaluate different metric to show that the semantic features obtained from the self supervised learning (ssl) frameworks are biological relevant.

In [1]:
# TO REMOVE when notebook is stable

%load_ext autoreload
%autoreload 2

### Common Imports

In [2]:
import numpy
import torch
import seaborn
import tarfile
import os
import matplotlib
import matplotlib.pyplot as plt
from anndata import read_h5ad

# tissue_purifier import
import tissue_purifier as tp

### Download an  anndata object with extra features stored in .obsm.

Altenatively you can use the anndata file generated by running notebook2.

In [3]:
import tissue_purifier.io

bucket_name = "ld-data-bucket"
annotated_anndata_source_path = "tissue-purifier/annotated_slideseq_testis_anndata_h5ad.tar.gz"
annotated_anndata_dest_path = "./annotated_slideseq_testis_anndata_h5ad.tar.gz"
annotated_anndata_dest_folder = "./annotated_anndata"

#tp.io.download_from_bucket(bucket_name, annotated_anndata_source_path, annotated_anndata_dest_path)   
#with tarfile.open(annotated_anndata_dest_path, "r:gz") as fp:
#    fp.extractall(path=annotated_anndata_dest_folder)
    
# Make a list of all the h5ad files in the annotated_anndata_dest_folder
fname_list = []
for f in os.listdir(annotated_anndata_dest_folder):
    if f.endswith('.h5ad'):
        fname_list.append(f)
print(fname_list)

FileNotFoundError: [Errno 2] No such file or directory: './annotated_anndata'

### Decide how to filter the anndata object

In [None]:
# filter cells parameters
fc_bc_min_umi = 200
fc_bc_max_umi = 3000
fc_bc_min_n_genes_by_counts = 10
fc_bc_max_n_genes_by_counts = 2500
fc_bc_max_pct_counts_mt = 5

# filter genes parameters
fg_bc_min_cells_by_counts = 100

# filter rare cell types parameters
fctype_bc_min_cells_absolute = 100
fctype_bc_min_cells_frequency = 0.1

### Open the first anndata and compute some metrics

In [None]:
h5ad_file = fname_list[0]
adata = anndata.read_h5ad(filename=h5ad_file)

# mitocondria metrics
adata.var['mt'] = adata.var_names.str.startswith('mt-')  # annotate the group of mitochondrial genes as 'mt'
sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

# counts cells frequency
tmp = adata.obs[cell_type_key].values.describe()
mask1 = (tmp["counts"] > fctype_dt_min_cells_absolute)
mask2 = (tmp["freqs"] > fctype_dt_min_cells_frequency)
mask = mask1 * mask2
cell_type_keep = set(tmp[mask].index.values)
adata.obs["keep_ctype"] = adata.obs["cell_type"].apply(lambda x: x in cell_type_keep)

### Filter anndata

In [None]:
adata = adata[adata.obs["total_counts"] > fc_dt_min_umi, :] 
adata = adata[adata.obs["total_counts"] < fc_dt_max_umi, :] 
adata = adata[adata.obs["n_genes_by_counts"] > fc_dt_min_n_genes_by_counts, :] 
adata = adata[adata.obs["n_genes_by_counts"] < fc_dt_max_n_genes_by_counts, :] 
adata = adata[adata.obs["pct_counts_mt"] < fc_dt_max_pct_counts_mt, :]
adata = adata[adata.obs["keep_ctype"] == True, :]
adata = adata[:, adata.var["n_cells_by_counts"] > fg_dt_min_cells_by_counts]

### Make a gene dataset from the anndata and split it into train/test/val

In [None]:
covariate_keys = "ncv_10"

gene_dataset = make_gene_dataset_from_anndata(
        anndata=adata,
        cell_type_key='cell_type',
        covariate_key=covariate_key,
        preprocess_strategy='raw',
        apply_pca=False)
    
train_dataset, test_dataset, val_dataset = next(iter(train_test_val_split(gene_dataset)))

### train the gene regression model and test it 

In [None]:
gr = GeneRegression()
gr.configure_optimizer(optimizer_type='adam', lr=5E-3)

result = gr.train_and_test(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    test_num_samples=10,
    train_steps=50,
    train_print_frequency=5,
    use_covariates=True,
    l1_regularization_strength=0.1,
    l2_regularization_strength=None,
    eps_range=(1.0E-5, 1.0E-2),
    subsample_size_cells=None,
    subsample_size_genes=None,
    from_scratch=True)

### visualize the results