# Getting started with astir

## 0. Load necessary libraries

In [28]:
from astir.data_readers import from_csv_yaml
import pandas as pd

## 1. Load data

We start by reading expression data from a CSV file and marker gene information from a YAML file:

In [29]:
expression_mat_path = "../../../astir/tests/test-data/sce.csv"
yaml_marker_path = "../../../astir/tests/test-data/jackson-2020-markers.yml"

.. note:: 
    Expression data should already be cleaned and normalized, for example through a log transformation and winsorization.
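
As an illustration of this kind of preprocessing, the sketch below applies `log1p` and then caps each marker at its 99th percentile, using NumPy on a synthetic matrix. Both choices (the `log1p` transform and a one-sided cut-off at the 99th percentile) are assumptions made for the example, not requirements stated by astir, and the `normalize_expression` helper is hypothetical.

```python
import numpy as np


def normalize_expression(raw: np.ndarray, upper_quantile: float = 0.99) -> np.ndarray:
    """Hypothetical helper: log-transform, then winsorize a cells x markers matrix."""
    # Assumed preprocessing: log(1 + x) keeps zeros at zero and compresses large values
    logged = np.log1p(raw)
    # One cut-off per marker (column), at the chosen upper quantile
    caps = np.quantile(logged, upper_quantile, axis=0)
    # Winsorize the upper tail only: values above the cut-off are set to it
    return np.minimum(logged, caps)


rng = np.random.default_rng(0)
raw = rng.gamma(shape=2.0, scale=5.0, size=(100, 4))  # synthetic intensities
raw[0, 0] = 1e4  # a single extreme outlier
norm = normalize_expression(raw)
print(norm.shape)  # (100, 4)
```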

We can view both the expression data and marker data:

In [30]:
!head -n 20 ../../../astir/tests/test-data/jackson-2020-markers.yml


cell_states:
  RTK_signalling:
    - Her2
    - EGFR
  proliferation:
    - Ki-67
    - phospho Histone
  mTOR_signalling:
    - phospho mTOR
    - phospho S6
  apoptosis:
    - cleaved PARP
    - Cleaved Caspase3

cell_types:
  Stromal:
    - Vimentin
    - Fibronectin
  B cells:


In [31]:
pd.read_csv(expression_mat_path, index_col=0)[['EGFR','E-Cadherin', 'CD45', 'Cytokeratin 5']].head()

| | EGFR | E-Cadherin | CD45 | Cytokeratin 5 |
|---|---|---|---|---|
| BaselTMA_SP41_186_X5Y4_3679 | 0.346787 | 0.938354 | 0.22773 | 0.095283 |
| BaselTMA_SP41_153_X7Y5_246 | 0.833752 | 1.364884 | 0.068526 | 0.124031 |
| BaselTMA_SP41_20_X12Y5_197 | 0.110006 | 0.177361 | 0.301222 | 0.05275 |
| BaselTMA_SP41_14_X1Y8_84 | 0.282666 | 1.122174 | 0.606941 | 0.093352 |
| BaselTMA_SP41_166_X15Y4_266 | 0.209066 | 0.402554 | 0.588273 | 0.064545 |
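
The marker file is a nested mapping: top-level `cell_states` and `cell_types` sections, each mapping a state or type name to a list of marker names, which are expected to correspond to column names in the expression CSV. The sketch below hard-codes the part of the file shown above as a Python dict (roughly what a YAML loader such as PyYAML's `safe_load` would return) and lists the markers that have no matching column. The `missing_markers` helper and the `expression_columns` header are hypothetical, for illustration only.

```python
from typing import Dict, List

# Part of the marker file shown above, as a YAML loader would return it
markers: Dict[str, Dict[str, List[str]]] = {
    "cell_states": {
        "RTK_signalling": ["Her2", "EGFR"],
        "proliferation": ["Ki-67", "phospho Histone"],
        "mTOR_signalling": ["phospho mTOR", "phospho S6"],
        "apoptosis": ["cleaved PARP", "Cleaved Caspase3"],
    },
    "cell_types": {
        "Stromal": ["Vimentin", "Fibronectin"],
    },
}


def missing_markers(markers: Dict[str, Dict[str, List[str]]], columns: List[str]) -> List[str]:
    """Hypothetical helper: marker names from the marker file that are not expression columns."""
    available = set(columns)
    needed = {
        marker
        for section in markers.values()  # "cell_states", "cell_types"
        for group in section.values()    # marker list of one state or type
        for marker in group
    }
    return sorted(needed - available)


# Hypothetical expression-matrix header: "Cleaved Caspase 3" differs from the marker file
expression_columns = [
    "Her2", "EGFR", "Ki-67", "phospho Histone", "phospho mTOR", "phospho S6",
    "cleaved PARP", "Cleaved Caspase 3", "Vimentin", "Fibronectin",
]
print(missing_markers(markers, expression_columns))  # ['Cleaved Caspase3']
```

Because the comparison is an exact string match, differences in spacing or capitalization show up as missing markers.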


Then we can create an astir object using the `from_csv_yaml` function. For more data loading options, see the data loading tutorial.

In [32]:
ast = from_csv_yaml(expression_mat_path, marker_yaml=yaml_marker_path)
print(ast)

Astir object with 6 columns of cell types, 4 columns of cell states and 100 rows.


## 2. Fitting cell types

To fit cell types, simply call

In [33]:
ast.fit_type()

training astir:   0%|          | 0/10 [00:00<?, ?epochs/s]

---------- 1/5 Cell Type Classification ----------


training astir: 100%|██████████| 10/10 [00:00<00:00, 49.98epochs/s]
training astir: 100%|██████████| 10/10 [00:00<00:00, 61.60epochs/s]
training astir:   0%|          | 0/10 [00:00<?, ?epochs/s]

Done!
---------- 2/5 Cell Type Classification ----------
Done!
---------- 3/5 Cell Type Classification ----------


training astir: 100%|██████████| 10/10 [00:00<00:00, 50.30epochs/s]
training astir: 100%|██████████| 10/10 [00:00<00:00, 58.09epochs/s]
training astir:   0%|          | 0/10 [00:00<?, ?epochs/s]

Done!
---------- 4/5 Cell Type Classification ----------
Done!
---------- 5/5 Cell Type Classification ----------


training astir: 100%|██████████| 10/10 [00:00<00:00, 57.31epochs/s]

Done!





.. note:: 
    **Controlling inference**

    The `fit_type` function has several options for controlling inference, including:

    - `max_epochs`: maximum number of epochs to train
    - `learning_rate`: learning rate of the Adam optimizer
    - `batch_size`: minibatch size
    - `delta_loss`: training stops once the change in loss falls below this value
    - `n_inits`: number of restarts from random initializations

    For full details, see the function documentation.
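
To make the roles of these options concrete, the sketch below runs plain gradient descent on a toy one-parameter loss rather than on astir's model: each of `n_inits` restarts begins from a random initialization and stops after `max_epochs` epochs or as soon as the loss changes by less than `delta_loss` between epochs, and the restart with the lowest final loss is kept. Keeping the best restart, the toy loss, and the use of plain gradient descent instead of Adam (with no minibatching, so `batch_size` is omitted) are assumptions of this sketch, not a description of astir's implementation.

```python
import random
from typing import Tuple


def loss(w: float) -> float:
    # Toy loss with its minimum at w = 3
    return (w - 3.0) ** 2


def grad(w: float) -> float:
    return 2.0 * (w - 3.0)


def fit_once(w: float, max_epochs: int, learning_rate: float,
             delta_loss: float) -> Tuple[float, float, int]:
    """Gradient descent from w; stop early once |change in loss| < delta_loss."""
    prev = loss(w)
    for epoch in range(1, max_epochs + 1):
        w -= learning_rate * grad(w)
        cur = loss(w)
        if abs(prev - cur) < delta_loss:
            return w, cur, epoch  # converged before max_epochs
        prev = cur
    return w, prev, max_epochs  # hit the epoch limit


def fit_with_restarts(n_inits: int, max_epochs: int = 100, learning_rate: float = 0.1,
                      delta_loss: float = 1e-6, seed: int = 0) -> Tuple[float, float]:
    rng = random.Random(seed)
    best_w, best_loss = 0.0, float("inf")
    for _ in range(n_inits):
        w0 = rng.uniform(-10.0, 10.0)  # random initialization
        w, final_loss, _ = fit_once(w0, max_epochs, learning_rate, delta_loss)
        # Assumption: the best restart (lowest final loss) is kept
        if final_loss < best_loss:
            best_w, best_loss = w, final_loss
    return best_w, best_loss


w, final = fit_with_restarts(n_inits=5)
print(round(w, 2))  # 3.0
```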

We can then get cell type assignment probabilities by calling

In [34]:
assignments = ast.get_celltype_probabilities()
assignments

| | Stromal | B cells | T cells | Macrophage | Epithelial (basal) | Epithelial (luminal) | Other |
|---|---|---|---|---|---|---|---|
| BaselTMA_SP41_186_X5Y4_3679 | 0.218146 | 0.260286 | 0.078100 | 0.300784 | 0.020030 | 0.009929 | 0.112726 |
| BaselTMA_SP41_153_X7Y5_246 | 0.007433 | 0.037888 | 0.013495 | 0.019616 | 0.171420 | 0.721267 | 0.028881 |
| BaselTMA_SP41_20_X12Y5_197 | 0.323534 | 0.161405 | 0.194188 | 0.204150 | 0.014537 | 0.003601 | 0.098586 |
| BaselTMA_SP41_14_X1Y8_84 | 0.114065 | 0.185562 | 0.170916 | 0.125238 | 0.089930 | 0.239260 | 0.075029 |
| BaselTMA_SP41_166_X15Y4_266 | 0.307228 | 0.243561 | 0.163690 | 0.217983 | 0.010894 | 0.004304 | 0.052340 |
| ... | ... | ... | ... | ... | ... | ... | ... |
| BaselTMA_SP41_114_X13Y4_1057 | 0.019108 | 0.071844 | 0.034305 | 0.040122 | 0.172589 | 0.621588 | 0.040443 |
| BaselTMA_SP41_141_X11Y2_2596 | 0.016712 | 0.061177 | 0.028747 | 0.035444 | 0.171387 | 0.639128 | 0.047405 |
| BaselTMA_SP41_100_X15Y5_170 | 0.169132 | 0.199137 | 0.129711 | 0.207781 | 0.061975 | 0.070236 | 0.162027 |
| BaselTMA_SP41_14_X1Y8_2604 | 0.251748 | 0.244889 | 0.077674 | 0.311692 | 0.012417 | 0.003847 | 0.097733 |


where each row corresponds to a cell and each column to a cell type, and each entry is the probability that the cell belongs to that cell type.
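
Because each row is a probability distribution over the cell types (including `Other`), its entries should sum to one, and the largest entry gives a rough measure of how confident the assignment is. The sketch below checks this with pandas on the first two rows of the table above, copied by hand; with real output you would apply the same calls to `assignments` directly. Reading the maximum probability as a confidence score is an interpretation for illustration, not an astir feature.

```python
import pandas as pd

# First two rows of the probability table above, copied by hand
probs = pd.DataFrame(
    [[0.218146, 0.260286, 0.078100, 0.300784, 0.020030, 0.009929, 0.112726],
     [0.007433, 0.037888, 0.013495, 0.019616, 0.171420, 0.721267, 0.028881]],
    index=["BaselTMA_SP41_186_X5Y4_3679", "BaselTMA_SP41_153_X7Y5_246"],
    columns=["Stromal", "B cells", "T cells", "Macrophage",
             "Epithelial (basal)", "Epithelial (luminal)", "Other"],
)

row_sums = probs.sum(axis=1)    # ~1.0 per cell, up to rounding of the displayed values
confidence = probs.max(axis=1)  # probability of the most likely type
print(pd.DataFrame({"sum": row_sums.round(4), "confidence": confidence}))
```

Here the first cell is spread across several types (largest probability about 0.30), while the second is assigned much more confidently (about 0.72).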

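To get the most likely cell type for each cell, one option is to take, for each row, the column with the highest probability using pandas:

In [35]:
celltypes = assignments.idxmax(axis=1)  # most likely cell type per cell, as a pandas Series
celltypes.to_numpy()  # the same assignments as a NumPy array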

## 3. Fitting cell state

As before, to fit cell states, call

In [36]:
ast.fit_state()

training astir: 100%|██████████| 100/100 [00:00<00:00, 733.79epochs/s]
training astir:   0%|          | 0/100 [00:00<?, ?epochs/s]

---------- 1/5 Cell State Classification ----------
---------- 2/5 Cell State Classification ----------


training astir: 100%|██████████| 100/100 [00:00<00:00, 862.17epochs/s]
training astir: 100%|██████████| 100/100 [00:00<00:00, 748.04epochs/s]
training astir:   0%|          | 0/100 [00:00<?, ?epochs/s]

---------- 3/5 Cell State Classification ----------
---------- 4/5 Cell State Classification ----------


training astir: 100%|██████████| 100/100 [00:00<00:00, 812.77epochs/s]
training astir: 100%|██████████| 100/100 [00:00<00:00, 863.41epochs/s]
training astir: 0epochs [00:00, ?epochs/s]

---------- 5/5 Cell State Classification ----------





and cell state assignments can be inferred via

In [37]:
states = ast.get_cellstates()
states

| | RTK_signalling | proliferation | mTOR_signalling | apoptosis |
|---|---|---|---|---|
| BaselTMA_SP41_186_X5Y4_3679 | -0.055859 | -0.248295 | 0.081537 | -0.440844 |
| BaselTMA_SP41_153_X7Y5_246 | 0.026950 | 0.334187 | -0.029829 | -0.233749 |
| BaselTMA_SP41_20_X12Y5_197 | -0.107626 | 0.441981 | -0.379208 | -0.184434 |
| BaselTMA_SP41_14_X1Y8_84 | -0.171197 | -0.700986 | 0.160324 | -0.510929 |
| BaselTMA_SP41_166_X15Y4_266 | 0.399416 | -0.046144 | -0.352460 | -0.801181 |
| ... | ... | ... | ... | ... |
| BaselTMA_SP41_114_X13Y4_1057 | 0.602089 | -0.176759 | 0.170743 | 0.660372 |
| BaselTMA_SP41_141_X11Y2_2596 | -0.267246 | 1.100570 | -0.607341 | -0.432088 |
| BaselTMA_SP41_100_X15Y5_170 | -0.506588 | -0.183379 | 0.653195 | -1.094012 |
| BaselTMA_SP41_14_X1Y8_2604 | -1.027609 | 0.369216 | 0.472711 | -0.676119 |
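
Since the cell state scores and the cell type probabilities are indexed by the same cell IDs, pandas can combine them by index for downstream analysis. A minimal sketch on the first two cells, with values copied by hand from the tables above; the names `states_demo` and `types_demo` are illustrative and chosen so as not to overwrite the notebook's `states` and `assignments` variables.

```python
import pandas as pd

cells = ["BaselTMA_SP41_186_X5Y4_3679", "BaselTMA_SP41_153_X7Y5_246"]

# First two rows of the cell state table above, copied by hand
states_demo = pd.DataFrame(
    [[-0.055859, -0.248295, 0.081537, -0.440844],
     [0.026950, 0.334187, -0.029829, -0.233749]],
    index=cells,
    columns=["RTK_signalling", "proliferation", "mTOR_signalling", "apoptosis"],
)

# Most likely type of the same cells, read off the probability table.
# The order is reversed on purpose: pandas aligns on the cell ID index, not on position.
types_demo = pd.Series(["Epithelial (luminal)", "Macrophage"], index=cells[::-1])

combined = states_demo.assign(cell_type=types_demo)
print(combined[["proliferation", "cell_type"]])

# Average state score per assigned cell type
print(combined.groupby("cell_type")["proliferation"].mean())
```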


## 4. Saving results

Both cell type and cell state information can easily be saved to disk via

In [38]:
ast.type_to_csv("cell-types.csv")
ast.state_to_csv("cell-states.csv")

In [39]:
!head -n 3 cell-types.csv

,Stromal,B cells,T cells,Macrophage,Epithelial (basal),Epithelial (luminal),Other
BaselTMA_SP41_186_X5Y4_3679,0.2181457999075033,0.26028579493395343,0.0780996469571123,0.3007839852656864,0.02002957230964695,0.009929141071520391,0.11272605955457736
BaselTMA_SP41_153_X7Y5_246,0.007432725661534547,0.03788808526809271,0.013495344143837877,0.01961641415481618,0.1714198088912464,0.721266630941644,0.028880990938828413


In [40]:
!head -n 3 cell-states.csv

,RTK_signalling,proliferation,mTOR_signalling,apoptosis
BaselTMA_SP41_186_X5Y4_3679,-0.055859283,-0.24829505,0.08153685,-0.44084388
BaselTMA_SP41_153_X7Y5_246,0.026950097,0.33418685,-0.029828804,-0.23374897


where the first (unnamed) column always corresponds to the cell name/ID.
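
Because the cell IDs are written as the unnamed first column, reading a saved file back with `index_col=0` (as when the expression data was loaded earlier) restores them as the row index; without it, pandas keeps the IDs as an ordinary column named `Unnamed: 0`. A minimal sketch that uses an in-memory buffer holding the first lines of `cell-states.csv` shown above instead of the file itself:

```python
import io

import pandas as pd

# First lines of cell-states.csv as printed above
csv_text = """,RTK_signalling,proliferation,mTOR_signalling,apoptosis
BaselTMA_SP41_186_X5Y4_3679,-0.055859283,-0.24829505,0.08153685,-0.44084388
BaselTMA_SP41_153_X7Y5_246,0.026950097,0.33418685,-0.029828804,-0.23374897
"""

# index_col=0 turns the unnamed first column into the row index (cell IDs)
saved_states = pd.read_csv(io.StringIO(csv_text), index_col=0)
print(saved_states.index.tolist())
print(saved_states.loc["BaselTMA_SP41_153_X7Y5_246", "proliferation"])
```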

## 5. Accessing internal functions and data

Data stored in `astir` objects takes the form of an `SCDataset`. These can be retrieved via

In [41]:
celltype_data = ast.get_type_dataset()
celltype_data

<astir.models.scdataset.SCDataset at 0x14461a7c0>

and similarly for cell state via `ast.get_state_dataset()`.

These provide several helper methods for retrieving information about the dataset:

In [42]:
celltype_data.get_cells()[0:4] # cell names

['BaselTMA_SP41_186_X5Y4_3679',
 'BaselTMA_SP41_153_X7Y5_246',
 'BaselTMA_SP41_20_X12Y5_197',
 'BaselTMA_SP41_14_X1Y8_84']

In [43]:
celltype_data.get_classes() # cell type names

['Stromal',
 'B cells',
 'T cells',
 'Macrophage',
 'Epithelial (basal)',
 'Epithelial (luminal)']

In [44]:
print(celltype_data.get_n_classes()) # number of cell types
print(celltype_data.get_n_features()) # number of features / proteins

6
14


In [45]:
celltype_data.get_exprs() # Return a torch tensor corresponding to the expression data used

tensor([[0.2277, 0.0571, 2.2273,  ..., 2.2151, 0.0953, 0.1909],
        [0.0685, 0.4853, 0.2083,  ..., 0.5026, 0.1240, 0.6859],
        [0.3012, 0.0359, 0.5816,  ..., 0.8102, 0.0527, 0.1160],
        ...,
        [0.0869, 0.0000, 0.5382,  ..., 0.7593, 0.0965, 0.3212],
        [0.2395, 0.0823, 1.9671,  ..., 2.2464, 0.0557, 0.1326],
        [0.2476, 0.0000, 0.3162,  ..., 3.1238, 0.0803, 0.1450]],
       dtype=torch.float64)