# User guide

In [1]:
from src.database import Database
from src.protocol import Protocol, LABELS, CONDITIONS, RBPS
from src import utils
from src.dataset import CustomDataset

## Main concepts

There are 3 classes in this package:

* Class `Database`: wraps a pandas DataFrame of image annotations, from which specific images are selected according to a protocol.

* Class `Protocol`: defines the rules to select images in the database. 

* Class `CustomDataset`: defines a custom PyTorch dataset, built according to a protocol, on which a model is trained or evaluated.

With these 3 classes and the functions in `src.utils`, you can train and/or evaluate a model.

### Database

A `Database` object has several attributes: 
* `annotations`: annotations for the fluoMNs dataset, listing all available images and their specifications. One row corresponds to one field of view of a well, with several channels and several planes. The `exclude` column indicates whether the sample should be excluded from the experiments, for example because of poor experimental conditions.
* `all_samples`: all *valid* samples (annotations with `exclude == 'no'`)
* `protocols`: all available protocols defined for this database
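
The relationship between `annotations` and `all_samples` can be sketched with plain Python dictionaries standing in for DataFrame rows. This is only an illustration, assuming the `exclude` column stores the strings `'yes'`/`'no'`; the helper `valid_samples` is hypothetical and not part of `src.database`.

```python
from typing import Dict, List

Row = Dict[str, str]


def valid_samples(annotations: List[Row]) -> List[Row]:
    """Keep only rows not flagged for exclusion (hypothetical helper).

    Mirrors the documented rule: all_samples = annotations with exclude == 'no'.
    """
    return [row for row in annotations if row["exclude"] == "no"]


# Toy rows mimicking a few columns of the annotations table
annotations = [
    {"experiment": "screenA", "plate": "P1", "fov": "1", "exclude": "no"},
    {"experiment": "screenA", "plate": "P1", "fov": "2", "exclude": "yes"},  # assumed value for an excluded sample
    {"experiment": "screenI", "plate": "P30", "fov": "7", "exclude": "no"},
]

samples = valid_samples(annotations)
print([row["fov"] for row in samples])  # ['1', '7']
```

With the real `Database`, the same filter would be a pandas boolean mask such as `db.annotations[db.annotations['exclude'] == 'no']`, again assuming the column holds the string `'no'`.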

In [2]:
db=Database()
db.annotations

| | experiment | plate | neuron_type | condition | stress_label | cell_line | als_label | well_row | well_col | number_of_channels | channel_1 | channel_2 | channel_3 | channel_4 | fov | number_of_planes | exclude |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | screenA | P1 | D17 | untreated | no_stress | CTRL1 | control | 2 | 2 | 3 | DAPI | SFPQ | PABP | none | 1 | 4 | no |
| 1 | screenA | P1 | D17 | untreated | no_stress | CTRL1 | control | 2 | 2 | 3 | DAPI | SFPQ | PABP | none | 2 | 4 | no |
| 2 | screenA | P1 | D17 | untreated | no_stress | CTRL1 | control | 2 | 2 | 3 | DAPI | SFPQ | PABP | none | 3 | 4 | no |
| 3 | screenA | P1 | D17 | untreated | no_stress | CTRL1 | control | 2 | 2 | 3 | DAPI | SFPQ | PABP | none | 4 | 4 | no |
| 4 | screenA | P1 | D17 | untreated | no_stress | CTRL1 | control | 2 | 2 | 3 | DAPI | SFPQ | PABP | none | 5 | 4 | no |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 15815 | screenI | P30 | D17 | untreated | no_stress | MUT6 | als | 7 | 11 | 4 | DAPI | BIII | FUS | Caspase-3 | 7 | 5 | no |
| 15816 | screenI | P30 | D17 | untreated | no_stress | MUT6 | als | 7 | 11 | 4 | DAPI | BIII | FUS | Caspase-3 | 8 | 5 | no |
| 15817 | screenI | P30 | D17 | untreated | no_stress | MUT6 | als | 7 | 11 | 4 | DAPI | BIII | FUS | Caspase-3 | 9 | 5 | no |
| 15818 | screenI | P30 | D17 | untreated | no_stress | MUT6 | als | 7 | 11 | 4 | DAPI | BIII | FUS | Caspase-3 | 10 | 5 | no |


In [3]:
for protocol in db.protocols:
    print(protocol.name)

als_untreated_SFPQ
als_untreated_TDP-43
als_untreated_FUS
als_untreated_hnRNPA1
als_untreated_hnRNPK
als_untreated_all
als_osmotic_SFPQ
als_osmotic_TDP-43
als_osmotic_FUS
als_osmotic_hnRNPA1
als_osmotic_hnRNPK
als_osmotic_all
als_oxidative_SFPQ
als_oxidative_TDP-43
als_oxidative_FUS
als_oxidative_hnRNPA1
als_oxidative_hnRNPK
als_oxidative_all
als_heat_SFPQ
als_heat_TDP-43
als_heat_FUS
als_heat_hnRNPA1
als_heat_hnRNPK
als_heat_all
als_untreated_osmotic_SFPQ
als_untreated_osmotic_TDP-43
als_untreated_osmotic_FUS
als_untreated_osmotic_hnRNPA1
als_untreated_osmotic_hnRNPK
als_untreated_osmotic_all
als_untreated_oxidative_SFPQ
als_untreated_oxidative_TDP-43
als_untreated_oxidative_FUS
als_untreated_oxidative_hnRNPA1
als_untreated_oxidative_hnRNPK
als_untreated_oxidative_all
als_untreated_heat_SFPQ
als_untreated_heat_TDP-43
als_untreated_heat_FUS
als_untreated_heat_hnRNPA1
als_untreated_heat_hnRNPK
als_untreated_heat_all
als_untreated_heat_2h_SFPQ
als_untreated_heat_2h_TDP-43
als_untreated_h

### Protocol

A `Protocol` defines which data are selected in the database. When initializing a protocol, we specify which ***labels***, which ***conditions*** and which ***RBP*** we want. 

The **label** tells whether the cell line in the culture is healthy (`'control'`) or ALS-mutant (`'als'`). The **condition** is the stress condition to which the culture is subjected (e.g. `'untreated'` if it is not put under stress). The **RBP** (RNA-binding protein) is the fluorescent marker of interest. `'all'` stands for all RBPs: use it when you only need the DAPI or BIII markers, since every RBP subset includes the DAPI and BIII channels.

Here are the different possibilities: 

In [4]:
print('labels: ',LABELS)
print('conditions: ', CONDITIONS)
print('rbps: ', RBPS)

labels:  ['control', 'als']
conditions:  ['untreated', 'oxidative', 'heat', 'heat_2h', 'osmotic', 'osmotic_1h', 'osmotic_2h', 'osmotic_6h']
rbps:  ['SFPQ', 'TDP-43', 'FUS', 'hnRNPA1', 'hnRNPK', 'all']


A protocol can have either one or two labels, and one or two conditions. 

To classify between healthy and als-mutant cells, you need both `'als'` and `'control'` in your protocol. To classify between unstressed and stressed cells, you need both `'untreated'` and a stressor such as `'osmotic'` in your protocol. 
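
The rules above can be sketched as a small validation function that also reports which binary classification a given combination supports. This is a hypothetical helper (`supported_tasks` is not part of `src.protocol`) that encodes only the rules stated in this guide, not the checks actually performed by `Protocol`; the lists are copies of the values printed above.

```python
from typing import List, Sequence

# Copies of the values printed above (LABELS, CONDITIONS, RBPS from src.protocol)
LABELS = ['control', 'als']
CONDITIONS = ['untreated', 'oxidative', 'heat', 'heat_2h',
              'osmotic', 'osmotic_1h', 'osmotic_2h', 'osmotic_6h']
RBPS = ['SFPQ', 'TDP-43', 'FUS', 'hnRNPA1', 'hnRNPK', 'all']


def _check_choices(name: str, values: Sequence[str], allowed: Sequence[str]) -> None:
    """Require one or two distinct values, all taken from `allowed`."""
    if not 1 <= len(values) <= 2 or len(set(values)) != len(values):
        raise ValueError(f"{name} must contain one or two distinct values: {values}")
    unknown = set(values) - set(allowed)
    if unknown:
        raise ValueError(f"unknown {name}: {sorted(unknown)}")


def supported_tasks(labels: Sequence[str], conditions: Sequence[str], rbp: str) -> List[str]:
    """Hypothetical check of the protocol rules stated in this guide."""
    _check_choices("labels", labels, LABELS)
    _check_choices("conditions", conditions, CONDITIONS)
    if rbp not in RBPS:
        raise ValueError(f"unknown rbp: {rbp}")

    tasks = []
    # als vs control needs both labels
    if set(labels) == {'control', 'als'}:
        tasks.append('als vs control')
    # untreated vs stress needs 'untreated' plus one stressor
    stressors = [c for c in conditions if c != 'untreated']
    if 'untreated' in conditions and stressors:
        tasks.append(f'untreated vs {stressors[0]}')
    return tasks


print(supported_tasks(['control', 'als'], ['untreated'], 'TDP-43'))  # ['als vs control']
print(supported_tasks(['als'], ['untreated', 'osmotic'], 'all'))     # ['untreated vs osmotic']
```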

Here is an example `protocol`, with images of control and als untreated motor neurons, where the ``'TDP-43'`` channel must be available. **Note**: this protocol corresponds to the images available in the data subset which can be downloaded [**here**](https://zenodo.org/record/4664177). 

In [11]:
protocol = Protocol(['control', 'als'], ['untreated'], 'TDP-43')

Each `protocol` has a `name` attribute of the form `labels_conditions_rbp`. Note that labels and conditions are always sorted in reverse alphabetical order.

In [12]:
protocol.name

'control_als_untreated_TDP-43'
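
The naming convention can be reproduced in a few lines, which is handy when you need a protocol name for the `-p` option of the command-line tools below. `protocol_name` is a hypothetical helper written from the convention described above, not the actual `Protocol.name` implementation.

```python
from typing import Sequence


def protocol_name(labels: Sequence[str], conditions: Sequence[str], rbp: str) -> str:
    """Rebuild a protocol name as labels_conditions_rbp (hypothetical helper).

    Labels and conditions are sorted in reverse alphabetical order,
    as documented for `Protocol.name`.
    """
    parts = sorted(labels, reverse=True) + sorted(conditions, reverse=True) + [rbp]
    return "_".join(parts)


print(protocol_name(["als", "control"], ["untreated"], "TDP-43"))
# control_als_untreated_TDP-43
print(protocol_name(["als"], ["osmotic", "untreated"], "all"))
# als_untreated_osmotic_all
```

The second name matches one of the entries in `db.protocols` listed earlier.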

Here are the corresponding annotations from the database for this protocol: 

In [13]:
db.get_protocol_data(protocol)

| | experiment | plate | neuron_type | condition | stress_label | cell_line | als_label | well_row | well_col | number_of_channels | channel_1 | channel_2 | channel_3 | channel_4 | fov | number_of_planes | exclude |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | screenE | P11 | D6 | untreated | no_stress | CTRL2 | control | 2 | 3 | 4 | DAPI | BIII | SFPQ | TDP-43 | 1 | 4 | no |
| 1 | screenE | P11 | D6 | untreated | no_stress | CTRL2 | control | 2 | 3 | 4 | DAPI | BIII | SFPQ | TDP-43 | 2 | 4 | no |
| 2 | screenE | P11 | D6 | untreated | no_stress | CTRL2 | control | 2 | 3 | 4 | DAPI | BIII | SFPQ | TDP-43 | 3 | 4 | no |
| 3 | screenE | P11 | D6 | untreated | no_stress | CTRL2 | control | 2 | 3 | 4 | DAPI | BIII | SFPQ | TDP-43 | 4 | 4 | no |
| 4 | screenE | P11 | D6 | untreated | no_stress | CTRL2 | control | 2 | 3 | 4 | DAPI | BIII | SFPQ | TDP-43 | 5 | 4 | no |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 975 | screenI | P29 | D6 | untreated | no_stress | MUT2 | als | 7 | 9 | 4 | DAPI | BIII | SFPQ | TDP-43 | 7 | 5 | no |
| 976 | screenI | P29 | D6 | untreated | no_stress | MUT2 | als | 7 | 9 | 4 | DAPI | BIII | SFPQ | TDP-43 | 8 | 5 | no |
| 977 | screenI | P29 | D6 | untreated | no_stress | MUT2 | als | 7 | 9 | 4 | DAPI | BIII | SFPQ | TDP-43 | 9 | 5 | no |
| 978 | screenI | P29 | D6 | untreated | no_stress | MUT2 | als | 7 | 9 | 4 | DAPI | BIII | SFPQ | TDP-43 | 10 | 5 | no |
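
Conceptually, selecting data for a protocol means keeping the valid rows whose label, condition and channels match it. The sketch below is a hypothetical, simplified version of that filter on toy rows (it is not the code of `db.get_protocol_data`); it assumes the RBP is looked up in the `channel_1` to `channel_4` columns and that `'all'` imposes no RBP constraint.

```python
from typing import Dict, List, Sequence

Row = Dict[str, str]
CHANNEL_COLUMNS = ["channel_1", "channel_2", "channel_3", "channel_4"]


def make_row(experiment: str, plate: str, als_label: str, channels: Sequence[str],
             condition: str = "untreated", exclude: str = "no") -> Row:
    """Build a toy annotation row with a subset of the real columns."""
    row = {"experiment": experiment, "plate": plate, "als_label": als_label,
           "condition": condition, "exclude": exclude}
    row.update(zip(CHANNEL_COLUMNS, channels))
    return row


def select_rows(annotations: List[Row], labels: Sequence[str],
                conditions: Sequence[str], rbp: str) -> List[Row]:
    """Keep valid rows matching the protocol's labels, conditions and RBP."""
    selected = []
    for row in annotations:
        if row["exclude"] != "no":
            continue  # assumed: excluded samples are never selected
        if row["als_label"] not in labels or row["condition"] not in conditions:
            continue
        channels = {row[col] for col in CHANNEL_COLUMNS}
        # assumed: 'all' does not require any particular RBP channel
        if rbp != "all" and rbp not in channels:
            continue
        selected.append(row)
    return selected


annotations = [
    make_row("screenA", "P1", "control", ["DAPI", "SFPQ", "PABP", "none"]),
    make_row("screenE", "P11", "control", ["DAPI", "BIII", "SFPQ", "TDP-43"]),
    make_row("screenI", "P29", "als", ["DAPI", "BIII", "SFPQ", "TDP-43"]),
    make_row("screenI", "P30", "als", ["DAPI", "BIII", "FUS", "Caspase-3"]),
]
selected = select_rows(annotations, ["control", "als"], ["untreated"], "TDP-43")
print([(row["experiment"], row["plate"]) for row in selected])
# [('screenE', 'P11'), ('screenI', 'P29')]
```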


### CustomDataset

When using PyTorch for training and testing, you need a dataset class over which PyTorch iterates to load your images. See this tutorial for an example: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html. Our `CustomDataset` represents the actual dataset of images used for either training or testing.

In short, this class selects a set of indices from a protocol (either training-set or test-set indices), finds the corresponding images in the source directory, preprocesses them and stores the results in a target directory.

**CAUTION**: every time you create a `CustomDataset`, a new target directory is created, and preprocessed images are stored in it as you access elements of the dataset. This can take up a lot of disk space, so remember to call the `delete` method to remove the target directory once your experiment is done.
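
PyTorch only requires a map-style dataset to implement `__len__` and `__getitem__`. The lazy caching and clean-up behaviour described above can be sketched with the standard library alone. The class `CachedImageDataset` and its `preprocess` argument are hypothetical and do not reproduce the actual `CustomDataset` code; in particular, "preprocessing" here is just a byte transform rather than real image processing.

```python
import shutil
import tempfile
from pathlib import Path
from typing import Callable, List


class CachedImageDataset:
    """Minimal stand-in for the caching behaviour described for CustomDataset (hypothetical)."""

    def __init__(self, source_dir: Path, filenames: List[str],
                 preprocess: Callable[[bytes], bytes]):
        self.source_dir = Path(source_dir)
        self.filenames = list(filenames)
        self.preprocess = preprocess
        # A new target directory is created for every dataset instance
        self.target_dir = Path(tempfile.mkdtemp(prefix="custom_dataset_"))

    def __len__(self) -> int:
        return len(self.filenames)

    def __getitem__(self, idx: int) -> bytes:
        target = self.target_dir / self.filenames[idx]
        if not target.exists():
            # First access: preprocess the source file and cache the result
            raw = (self.source_dir / self.filenames[idx]).read_bytes()
            target.write_bytes(self.preprocess(raw))
        return target.read_bytes()

    def delete(self) -> None:
        """Remove the target directory and everything cached in it."""
        shutil.rmtree(self.target_dir, ignore_errors=True)


# Usage: wrap the experiment in try/finally so the cache is always cleaned up
source = Path(tempfile.mkdtemp())
(source / "img0.bin").write_bytes(b"abc")
dataset = CachedImageDataset(source, ["img0.bin"], preprocess=bytes.upper)
try:
    print(dataset[0])  # b'ABC'
finally:
    dataset.delete()
    shutil.rmtree(source)
```

The `try`/`finally` pattern guarantees that `delete` runs even if training or evaluation raises an exception, which is the easiest way to avoid leaving large cache directories behind.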

To actually use this dataset for training and testing, we iterate over its elements with PyTorch dataloaders (defined in `src.utils`).

## Training

There are 52 models in total (trained models are available [**here**](https://zenodo.org/record/4664252)), 13 for each of the following binary classifications:
* ``'als'`` vs ``'control'``, 
* ``'untreated'`` vs ``'osmotic'`` stress, 
* ``'untreated'`` vs ``'oxidative'`` stress, and 
* ``'untreated'`` vs ``'heat'`` stress

The 13 models correspond to the different channel combinations that we want to compare:
* ``'DAPI'``, 
* ``'BIII'``, 
* ``'DAPI-BIII'``, 
* ``'SFPQ'``, 
* ``'FUS'``, 
* ``'TDP-43'``, 
* ``'hnRNPA1'``, 
* ``'hnRNPK'``, 
* ``'DAPI-BIII-SFPQ'``, 
* ``'DAPI-BIII-FUS'``, 
* ``'DAPI-BIII-TDP-43'``, 
* ``'DAPI-BIII-hnRNPA1'``, 
* ``'DAPI-BIII-hnRNPK'``. 
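
The count of 52 follows directly from 4 classifications times 13 channel combinations. The sketch below enumerates them in the order listed above; the constant names are illustrative only. It also shows how a combination maps to the `--expert` values of `predict_model.py` (see its `--help` output below), where channels are joined with `_` while `TDP-43` keeps its own hyphen.

```python
from itertools import product

CLASSIFIERS = ["als", "osmotic", "oxidative", "heat"]
RBP_CHANNELS = ["SFPQ", "FUS", "TDP-43", "hnRNPA1", "hnRNPK"]

# 13 combinations: DAPI, BIII, DAPI+BIII, each RBP alone, DAPI+BIII+each RBP
COMBINATIONS = (
    [("DAPI",), ("BIII",), ("DAPI", "BIII")]
    + [(rbp,) for rbp in RBP_CHANNELS]
    + [("DAPI", "BIII", rbp) for rbp in RBP_CHANNELS]
)


def display_name(channels):
    """Name used in this guide, e.g. 'DAPI-BIII-TDP-43'."""
    return "-".join(channels)


def expert_name(channels):
    """Value for --expert, e.g. 'DAPI_BIII_TDP-43'."""
    return "_".join(channels)


models = list(product(CLASSIFIERS, COMBINATIONS))
print(len(COMBINATIONS), len(models))           # 13 52
print(expert_name(("DAPI", "BIII", "TDP-43")))  # DAPI_BIII_TDP-43
```

Building names from channel tuples, rather than replacing `-` with `_` in a string, avoids mangling the hyphen inside `TDP-43`.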

If you have GPU resources (strongly recommended), you can retrain the models using `train_model.py`. It is a command-line application that you can call from the terminal with the desired options: 

In [1]:
%run ../src/train_model.py --help

Usage: train_model.py [OPTIONS]

  Train a model

Options:
  -c, --config_name TEXT          Name of the configuration file  [required]
  -cl, --classification [stress|als]
                                  Classify stress vs untreated (choose
                                  'stress') or als vs control (choose 'als')
                                  [required]

  -p, --protocol_name TEXT        Name of the protocol for training (available
                                  protocols are listed in database.protocols)
                                  [required]

  -ch, --channels [DAPI|BIII|SFPQ|FUS|TDP-43|hnRNPA1|hnRNPK]
                                  List of channels used for training
                                  [required]

  -f, --fold INTEGER              Fold for 10-fold cross validation (between 0
                                  and 9)  [required]

  -s, --save_state_dict BOOLEAN   Save state dict of the trained model
  --dry_run                       Dry run for test

For example, if you have downloaded the data subset from Zenodo ([**here**](https://zenodo.org/record/4664177)), you can run the following line. It trains a model to classify images of ``'als'`` and ``'control'`` cultures of untreated motor neurons, using only the ``'TDP-43'`` channel. With the images available in this subset, you can also use ``'DAPI'``, ``'BIII'``, ``'DAPI-BIII'`` or ``'DAPI-BIII-TDP-43'``.

In [37]:
%run ../src/train_model.py -c user -cl als -p control_als_untreated_TDP-43 -ch TDP-43 -f 0 -s False 

Number of training samples : 59424
Number of test samples : 1104


Training a model saves its results in `results` > `auc.csv`. This file contains the AUC (area under the ROC curve, used as the performance measure) on the test set for each fold of the 10-fold cross-validation, for each combination of protocol and channels.
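
Since `--fold` takes a single fold index, covering the whole 10-fold cross-validation means running the script ten times. The sketch below builds those command lines using only options shown in the `--help` output above. The helper `train_commands` is hypothetical, the script path assumes you run from the same directory as this notebook, and how several channels are passed to `-ch` is not covered here.

```python
import shlex
from typing import Iterable, List


def train_commands(config: str, classification: str, protocol: str, channel: str,
                   folds: Iterable[int] = range(10)) -> List[List[str]]:
    """Build one train_model.py command per cross-validation fold (hypothetical helper)."""
    return [
        ["python", "../src/train_model.py",
         "-c", config, "-cl", classification, "-p", protocol,
         "-ch", channel, "-f", str(fold), "-s", "False"]
        for fold in folds
    ]


commands = train_commands("user", "als", "control_als_untreated_TDP-43", "TDP-43")
print(len(commands))  # 10
print(shlex.join(commands[0]))
# python ../src/train_model.py -c user -cl als -p control_als_untreated_TDP-43 -ch TDP-43 -f 0 -s False

# To launch them one after another (requires GPU resources and the data):
#     import subprocess
#     for cmd in commands:
#         subprocess.run(cmd, check=True)
```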

## Predictions

If you don't have GPU resources or you simply want to evaluate already trained models on some data, you can use `predict_model.py`. It is a command-line application that you can call from the terminal with the desired options: 

In [2]:
%run ../src/predict_model.py --help

Usage: predict_model.py [OPTIONS]

  Evaluate a model

Options:
  -c, --config_name TEXT          Name of the configuration file  [required]
  -cl, --classifier [als|osmotic|oxidative|heat]
                                  Choose classifier: als vs control (choose
                                  'als'), osmotic vs untreated (choose
                                  'osmotic'), oxidative  vs untreated (choose
                                  'oxidative') or heat vs untreated (choose
                                  'heat')  [required]

  -e, --expert [DAPI|BIII|DAPI_BIII|SFPQ|FUS|TDP-43|hnRNPA1|hnRNPK|DAPI_BIII_SFPQ|DAPI_BIII_FUS|DAPI_BIII_TDP-43|DAPI_BIII_hnRNPA1|DAPI_BIII_hnRNPK]
                                  Choose one of the 13 experts (corresponds to
                                  a combination of channels)  [required]

  -la, --label [control|als]      Label of images on which the expert will be
                                  evaluated  [required]

  -co, --conditio

For example, if you have downloaded the data subset from Zenodo ([**here**](https://zenodo.org/record/4664177)) and the trained model named ``'state_dict_control_als_untreated_TDP-43_TDP-43_fold_0.pt'``, you can run the following line. It evaluates the model trained to classify images of ``'als'`` and ``'control'`` cultures of untreated motor neurons using the ``'TDP-43'`` channel, on images of ``'control'`` cultures of untreated motor neurons that were not seen during training. You can also evaluate it on images of ``'als'`` cultures of untreated motor neurons.

In [36]:
%run ../src/predict_model.py -c user -cl als -e TDP-43 -la control -co untreated 

expert:control_als_untreated_TDP-43_TDP-43
data:control_untreated_TDP-43
---------------------------------------------
Number of test samples : 1600
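
The expert identifier printed above and the trained-model filename appear to follow a simple pattern: protocol name, then channels, then fold. The sketch below rebuilds both from their parts. The pattern is inferred from this single TDP-43 example only, so it may not hold for multi-channel experts such as `DAPI_BIII`; `expert_id` and `state_dict_filename` are hypothetical helpers.

```python
def expert_id(protocol_name: str, expert: str) -> str:
    """Identifier printed after 'expert:' (pattern inferred from the output above)."""
    return f"{protocol_name}_{expert}"


def state_dict_filename(protocol_name: str, expert: str, fold: int) -> str:
    """Assumed pattern: state_dict_<protocol>_<expert>_fold_<fold>.pt"""
    return f"state_dict_{expert_id(protocol_name, expert)}_fold_{fold}.pt"


print(expert_id("control_als_untreated_TDP-43", "TDP-43"))
# control_als_untreated_TDP-43_TDP-43
print(state_dict_filename("control_als_untreated_TDP-43", "TDP-43", 0))
# state_dict_control_als_untreated_TDP-43_TDP-43_fold_0.pt
```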


Evaluating a model saves its results in `results` > `image_probabilities.csv`. This file contains the output probabilities of each of the 52 models on each *valid* image of the dataset (i.e. models trained on images with the SFPQ channel are only evaluated on images containing this channel).
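
The notion of a *valid* image for a given model can be sketched as a subset test: every channel the expert was trained on must be present among the image's channels. `is_eligible` and `expert_channels` are hypothetical helpers, assuming channels are read from the `channel_1` to `channel_4` columns of the annotations and that expert names use the `_`-separated format of `--expert`.

```python
from typing import Dict, Set

CHANNEL_COLUMNS = ["channel_1", "channel_2", "channel_3", "channel_4"]


def expert_channels(expert: str) -> Set[str]:
    """Channels used by an expert; '_' separates channels, 'TDP-43' keeps its hyphen."""
    return set(expert.split("_"))


def is_eligible(row: Dict[str, str], expert: str) -> bool:
    """True if the image provides every channel the expert was trained on."""
    available = {row[col] for col in CHANNEL_COLUMNS}
    return expert_channels(expert) <= available


# Channels of two images from the annotation tables above
tdp43_image = dict(zip(CHANNEL_COLUMNS, ["DAPI", "BIII", "SFPQ", "TDP-43"]))
fus_image = dict(zip(CHANNEL_COLUMNS, ["DAPI", "BIII", "FUS", "Caspase-3"]))

print(is_eligible(tdp43_image, "DAPI_BIII_TDP-43"))  # True
print(is_eligible(tdp43_image, "FUS"))               # False
print(is_eligible(fus_image, "DAPI_BIII_FUS"))       # True
```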

## Figures

Results from our experiments are stored in `results` > `auc.csv` and `results` > `image_probabilities.csv`. To reproduce some of the figures, you can use the functions in `visualization` > `visualize_utils.py`, or simply generate figures from the terminal with ``visualize.py``:

In [31]:
%run ../src/visualization/visualize.py --help

Usage: visualize.py [OPTIONS]

Options:
  -f, --figure [1D|2A|2D|3A|3C|4A|2C|3D|4B|5A|5B|6A]
                                  Choose the figure to generate  [required]
  --help                          Show this message and exit.
