<a href="https://colab.research.google.com/github/matjesg/deepflash2/blob/master/paper/ood_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# deepflash2 - Out-of-distribution detection

> This notebook reproduces the results of the deepflash2 [paper](https://arxiv.org/abs/2111.06693) for out-of-distribution detection.

- **Data and models**: Data and trained models are available on [Google Drive](https://drive.google.com/drive/folders/1r9AqP9qW9JThbMIvT0jhoA5mPxWEeIjs?usp=sharing). To use the data in Google Colab, create a [shortcut](https://support.google.com/drive/answer/9700156?hl=en&co=GENIE.Platform%3DDesktop) of the data folder in your personal Google Drive.

*Source files created with this notebook*:

`ood_detection.csv`

*References*:

Griebel, M., Segebarth, D., Stein, N., Schukraft, N., Tovote, P., Blum, R., & Flath, C. M. (2021). Deep-learning in the bioimaging wild: Handling ambiguous data with deepflash2. arXiv preprint arXiv:2111.06693.


## Setup

- Install dependencies
- Connect to drive

In [None]:
!pip install deepflash2

In [None]:
# Imports
import numpy as np
import pandas as pd
from pathlib import Path
from deepflash2.all import *
from deepflash2.data import _read_msk

In [None]:
# Connect to drive
from google.colab import drive
drive.mount('/gdrive')

## Settings


In [None]:
IN_DATASET = 'cFOS_in_HC'
OOD_DATASETS = ['PV_in_HC', 'mScarlet_in_PAG', 'YFP_in_CTX', 'GFAP_in_HC']
OUTPUT_PATH = Path('/content')
DATA_PATH = Path('/gdrive/MyDrive/deepflash2-paper/data')
TRAINED_MODEL_PATH = Path('/gdrive/MyDrive/deepflash2-paper/models/')
MODEL_NO = '1'
SOURCE_DATA_URL = 'https://github.com/matjesg/deepflash2/releases/download/paper_source_data/'
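
If the Drive shortcut is missing, the paths above will not resolve and the later cells fail late. The following is a minimal sketch of an early check, not part of deepflash2: `missing_paths` is a hypothetical helper, and the demonstration uses a temporary folder in place of the Drive paths.

```python
from pathlib import Path
import tempfile

def missing_paths(paths):
    """Return the subset of `paths` that does not exist on disk (hypothetical helper)."""
    return [p for p in paths if not Path(p).exists()]

# Demonstration with a temporary folder instead of the Drive paths
with tempfile.TemporaryDirectory() as tmp:
    existing = Path(tmp)
    absent = existing / 'not_there'
    result = missing_paths([existing, absent])
    print(result)
```

In this notebook, the same helper could be called with `DATA_PATH/IN_DATASET` and `TRAINED_MODEL_PATH/IN_DATASET/MODEL_NO` after mounting the drive; an empty list would mean both folders were found.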

## Analysis

1. Collect files
  - 280 in-distribution and partly out-of-distribution files from the `cFOS_in_HC` additional data
  - 32 fully out-of-distribution files from the test sets of the other datasets
2. Predict segmentations and uncertainty scores with model trained on `cFOS_in_HC` data

See `deepflash2_figures-and-tables.ipynb` for plots of the data.
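
The analysis cell recovers each file's dataset name from its position in the folder layout. Here is a minimal sketch of that rule, assuming the layout given by the settings; the file names are hypothetical and `dataset_label` is an illustrative helper, not a deepflash2 function.

```python
from pathlib import PurePosixPath

# Hypothetical file names; only the folder layout mirrors the notebook settings
DATA_PATH = PurePosixPath('/gdrive/MyDrive/deepflash2-paper/data')
example_files = [
    DATA_PATH / 'cFOS_in_HC' / 'additional_images' / '0001.tif',
    DATA_PATH / 'PV_in_HC' / 'test' / 'images' / '0001.tif',
]

def dataset_label(path, in_dataset='cFOS_in_HC'):
    """Return the dataset name encoded in the folder layout."""
    label = path.parent.parent.parent.name
    # Additional images sit one level higher, so their third parent is 'data'
    return in_dataset if label == 'data' else label

labels = [dataset_label(f) for f in example_files]
print(labels)
```

Test images are nested as `<dataset>/test/images/<file>`, so the third parent is the dataset folder itself; the additional images need the explicit fallback to the in-distribution dataset name.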

In [None]:
# Collect files
add_data_path = DATA_PATH/IN_DATASET/'additional_images'
files = [f for f in add_data_path.iterdir()]

for dataset in OOD_DATASETS:
  test_data_path = DATA_PATH/dataset/'test'
  files += [f for f in (test_data_path/'images').iterdir()]

print(len(files))  # 280 + 32 = 312 files expected

# Predict segmentations and uncertainty scores 
ensemble_path = TRAINED_MODEL_PATH/IN_DATASET/MODEL_NO

# Note: test_data_path still points to the last dataset in OOD_DATASETS
el_pred = EnsembleLearner('images', # We will not use this data
                          path=test_data_path,
                          ensemble_path=ensemble_path)

# Predict and save semantic segmentation masks
_ = el_pred.get_ensemble_results(files, use_tta=True)

# Merge predictions with information
merge_cols = ['idx', 'dataset']
df = el_pred.df_ens
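# Derive the dataset name from the folder layout: <dataset>/test/images/<file>
df['dataset'] = df['image_path'].apply(lambda x: x.parent.parent.parent.name)
# Additional images sit one level higher (<IN_DATASET>/additional_images/<file>) and resolve to 'data'
df.loc[df['dataset']=='data', 'dataset'] = IN_DATASET
df['idx'] = df['file'].str.split('.').str[0]
# Save the raw scores; this file is overwritten with the merged table at the end of the cell
df[['idx', 'dataset', 'uncertainty_score']].to_csv('ood_detection.csv', index=False)
df_ood_scores = df[['idx', 'dataset', 'uncertainty_score']].set_index(merge_cols)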
df_ood_info = pd.read_csv(SOURCE_DATA_URL+'ood_information.csv').set_index(merge_cols)
df_ood = df_ood_scores.join(df_ood_info).reset_index()

# Sort
df_ood = df_ood.sort_values('uncertainty_score', ascending=False).reset_index(drop=True)
# 0-based rank, highest uncertainty first
df_ood['rank'] = df_ood.index

# Save
df_ood.to_csv('ood_detection.csv', index=False)
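
To make the merge, sort, and rank steps easier to follow, here is a toy sketch with the same operations on synthetic data. The uncertainty scores are invented for illustration and the `is_ood` column is a hypothetical stand-in, since the actual columns of `ood_information.csv` are not shown in this notebook.

```python
import pandas as pd

merge_cols = ['idx', 'dataset']

# Made-up uncertainty scores, for illustration only
scores = pd.DataFrame({
    'idx': ['0001', '0002', '0001'],
    'dataset': ['cFOS_in_HC', 'cFOS_in_HC', 'PV_in_HC'],
    'uncertainty_score': [0.05, 0.20, 0.60],
}).set_index(merge_cols)

# Hypothetical metadata column standing in for ood_information.csv
info = pd.DataFrame({
    'idx': ['0001', '0002', '0001'],
    'dataset': ['cFOS_in_HC', 'cFOS_in_HC', 'PV_in_HC'],
    'is_ood': [False, False, True],
}).set_index(merge_cols)

# Join on the shared (idx, dataset) index, then rank by descending uncertainty
toy = scores.join(info).reset_index()
toy = toy.sort_values('uncertainty_score', ascending=False).reset_index(drop=True)
toy['rank'] = toy.index  # 0-based rank, highest uncertainty first
print(toy)
```

Joining on both `idx` and `dataset` matters because file names can repeat across datasets, as `0001` does in this toy table.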