# Adata

> Helpers for extracting dense arrays and labelled DataFrames from `AnnData` objects.

In [None]:
#| default_exp utils.adata

In [None]:
#| hide
from nbdev.showdoc import *

In [1]:
#| export
from dataclasses import dataclass, field
import numpy as np, pandas as pd

from typing import List, Any, Optional

In [None]:
#| export
from iza.types import AnnData, ndarray, DataFrame
from iza.static import X_MAGIC, PHATE, X_PHATE

### Adata

In [None]:
#| export
@dataclass
class AdataExtractor:
    """Pull dense matrices and labelled DataFrames out of an `AnnData` object.

    Attributes
    ----------
    adata : AnnData
        Source annotated data object.
    layer : str, optional
        Key in `adata.layers` to read values from; when the key is missing
        (or `None`) the extractor falls back to `adata.X`.
    x_emb : str, optional
        Key in `adata.obsm` holding the embedding.
    dim_str : str, optional
        Prefix for embedding column names; derived from `x_emb` when unset.
    use_hvg : bool, optional
        When truthy and `adata.var` has a `highly_variable` column, restrict
        all extraction to those variables.
    """
    adata: AnnData
    layer: Optional[str] = X_MAGIC
    x_emb: Optional[str] = X_PHATE

    dim_str: Optional[str] = None
    use_hvg: Optional[bool] = True

    @property
    def has_hvg(self):
        """Whether `adata.var` exposes a `highly_variable` attribute/column."""
        return hasattr(self.adata, 'var') and hasattr(self.adata.var, 'highly_variable')

    @property
    def has_emb(self):
        """Whether `adata.obsm` contains the `x_emb` key."""
        return hasattr(self.adata, 'obsm') and self.x_emb in self.adata.obsm.keys()

    @staticmethod
    def _to_ndarray(arr: Any) -> ndarray:
        """Convert a DataFrame, sparse matrix or array-like into a plain `np.ndarray`."""
        # A local helper is used because `to_ndarray` (from the guards notebook)
        # was referenced here without being imported, raising NameError.
        if isinstance(arr, pd.DataFrame):
            return arr.to_numpy()
        if hasattr(arr, 'toarray'):
            return np.asarray(arr.toarray())
        if hasattr(arr, 'todense'):
            # `todense()` yields `np.matrix`; `np.asarray` below unwraps it.
            arr = arr.todense()
        return np.asarray(arr)

    def get_layer(self) -> ndarray:
        """Return `layers[self.layer]` of `sdata()` (or its `X`) as a dense ndarray."""
        sub = self.sdata()
        values = sub.layers.get(self.layer, None)
        if values is None:
            values = sub.X
        return self._to_ndarray(values)

    def get_emb(self) -> ndarray:
        """Return `obsm[self.x_emb]` of `sdata()` as a dense ndarray.

        Raises
        ------
        ValueError
            If the embedding key is not present in `obsm`.
        """
        obsm = self.sdata().obsm
        emb = obsm.get(self.x_emb, None)
        if emb is None:
            raise ValueError(f'No embedding found in adata.obsm {obsm.keys()}')
        return self._to_ndarray(emb)

    @property
    def axis_str(self):
        """Column-name prefix: `dim_str` if set, else `x_emb` minus a leading `X_`, upper-cased."""
        if self.dim_str:
            return self.dim_str
        name = self.x_emb
        # Strip only the leading prefix; `str.replace` would drop every `X_`.
        if name.startswith('X_'):
            name = name[len('X_'):]
        return name.upper()

    def _emb_cols_for(self, ndim: int) -> List[str]:
        """Build 1-based embedding column names, e.g. `PHATE_1 … PHATE_ndim`."""
        prefix = self.axis_str
        return [f'{prefix}_{i}' for i in range(1, ndim + 1)]

    @property
    def emb_cols(self):
        """Embedding column names sized to the embedding's second dimension."""
        return self._emb_cols_for(self.get_emb().shape[1])

    def sdata(self):
        """Return `adata` restricted to highly variable genes when requested and available."""
        if self.use_hvg and self.has_hvg:
            return self.adata[:, self.adata.var.highly_variable]
        return self.adata

    def get_df_cnt(self) -> DataFrame:
        """Dense layer values as a DataFrame indexed by obs names, columns by var names."""
        sub = self.sdata()
        return pd.DataFrame(self.get_layer(), index=sub.obs.index, columns=sub.var.index)

    def get_df_emb(self) -> DataFrame:
        """Embedding as a DataFrame indexed by obs names with `axis_str`-prefixed columns."""
        emb = self.get_emb()
        # Size the columns from the array already in hand instead of re-extracting it.
        cols = self._emb_cols_for(emb.shape[1])
        return pd.DataFrame(emb, index=self.sdata().obs.index, columns=cols)

    @property
    def df_cnt(self):
        """Shortcut for `get_df_cnt()`."""
        return self.get_df_cnt()

    @property
    def df_emb(self):
        """Shortcut for `get_df_emb()`."""
        return self.get_df_emb()

In [None]:
#| hide
# Write this notebook's `#| export` cells out to the module named by `default_exp`.
import nbdev; nbdev.nbdev_export()