<a href="https://colab.research.google.com/github/matjesg/DeepFLaSH2/blob/master/Deepflash_Predict.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prediction Notebook

### Colab options

Connect to dfpredict

In [0]:
try:
    from google.colab import drive
    drive.mount('/content/drive')
    !git clone https://github.com/matjesg/DeepFLaSH2.git /content/drive/My\ Drive/DeepFLaSH2
    %cd /content/drive/My\ Drive/DeepFLaSH2
    !git pull
except:
    pass

Import packages

In [0]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
from skimage import io
import os
from deepflash import unet, preproc
import matplotlib.pyplot as plt
from tqdm import tqdm
%matplotlib inline

## Global Settings

In [0]:
# --- Model / checkpoint -----------------------------------------------------
# NOTE(review): the name looks like a typo of "PRETRAINED"; it is kept because
# other cells reference it. Used for the checkpoint file name and output folder.
PRETAINED = "all_falk_cFOS_triangular.0100"
CHECKPOINT_PATH = "checkpoints_sc"
BATCH_NORM = False

# --- Input data -------------------------------------------------------------
DATA_PATH = "data"
ASSIGNMENT_PATH = "Zuordnung_aktuell.xlsx"  # Excel sheet listing the images
MASK = "cFOS"    # sub-folder of DATA_PATH that is read, and output file suffix
IMAGE = "red"    # input files are named <Nummer>_<IMAGE>.tif
CHANNELS_IMG = 1  # presumably the number of image channels -- TODO confirm

# --- Tiling -----------------------------------------------------------------
# Both tuples are handed to preproc.TileGenerator unchanged.
TILE_SHAPE = (540, 540)
PADDING = (184, 184)
EL_SIZE = [635.9, 635.9]  # micrometers; passed on as 'element_size_um'

# --- Output -----------------------------------------------------------------
OUTPUT = "predictions"

## Load Data

Excel list with assignments

In [0]:
# Read the assignment sheet; 'Nummer' is converted to a string that is
# left-padded with zeros to at least four characters.
assignment = pd.read_excel(
    ASSIGNMENT_PATH,
    converters={'Nummer': lambda number: str(number).zfill(4)},
)

# One boolean mask per inclusion criterion.
is_wild_type = assignment['Genotyp'] == 'WT'
is_dhc_region = assignment['region'] == 'dHC'
in_selected_area = assignment['Area'].isin(['CA1', 'CA3', 'DG'])
in_selected_experiment = assignment['Experiment'].isin([1, 2, 3, 4])
# Rows are kept only when these three columns are empty.
no_cross_coder_training = assignment['Cross-coder Training'].isna()
not_excluded = assignment['Ausschluss von Analyse'].isna()
not_broken = assignment['broken'].isna()

keep_row = (
    is_wild_type
    & is_dhc_region
    & in_selected_area
    & in_selected_experiment
    & no_cross_coder_training
    & not_excluded
    & not_broken
)
assignment = assignment[keep_row]

# Zero-padded image numbers of the selected rows, in sheet order.
file_ids = assignment['Nummer'].tolist()

Images

In [0]:
# Input files live at <DATA_PATH>/<MASK>/<Nummer>_<IMAGE>.tif.
image_paths = [
    os.path.join(DATA_PATH, MASK, file_id + '_' + IMAGE + '.tif')
    for file_id in file_ids
]

# Read each image as grayscale and append a trailing axis (axis=2).
image_list = [
    np.expand_dims(io.imread(path, as_gray=True), axis=2)
    for path in image_paths
]

# Wrap every image in the dict layout handed to the tile generator.
data = [
    {'rawdata': image, 'element_size_um': EL_SIZE}
    for image in image_list
]

In [0]:
# Predictions are written to <OUTPUT>/<PRETAINED>/.
output_path = os.path.join(OUTPUT, PRETAINED)
# exist_ok=True makes the cell safe to re-run and avoids the race between
# checking for the directory and creating it.
os.makedirs(output_path, exist_ok=True)

## Predict

In [0]:
# Build the 2D U-Net from the snapshot at <CHECKPOINT_PATH>/<PRETAINED>.h5.
pred_model = unet.Unet2D(
    snapshot=os.path.join(CHECKPOINT_PATH, PRETAINED + '.h5'),
    # Taken from the global settings instead of a hard-coded 1, so the
    # CHANNELS_IMG constant actually controls the model.
    n_channels=CHANNELS_IMG,
    n_classes=2,
    n_levels=4,
    batch_norm=BATCH_NORM,
    upsample=False,
    relu_alpha=0.1,
    n_features=64,
    name="U-Net",
)

In [0]:
# Build the tile generator over the loaded images using the tile shape and
# padding from the global settings.
# NOTE(review): presumably this cuts each image into padded tiles for the
# network -- confirm against preproc.TileGenerator.
tile_generator = preproc.TileGenerator(data, TILE_SHAPE, PADDING)

In [0]:
# Run the model on the tile generator. The save cell reads predictions[1]
# and indexes it once per entry of file_ids.
# NOTE(review): the meaning of the other elements of `predictions` is not
# visible here -- confirm against unet.Unet2D.predict.
predictions = pred_model.predict(tile_generator)

## Save

In [0]:
# Write one PNG per input image to output_path, named <Nummer>_<MASK>.png.
# predictions[1] is assumed to hold one array per image, in the same order as
# file_ids -- TODO confirm against unet.Unet2D.predict.
predicted_images = predictions[1]

# Outputs are matched to file ids purely by position, so a length mismatch
# would silently mislabel or drop files; fail loudly instead.
if len(predicted_images) != len(file_ids):
    raise ValueError(
        'Got {} predictions for {} input files'.format(
            len(predicted_images), len(file_ids)))

for file_id, predicted in zip(file_ids, predicted_images):
    io.imsave(os.path.join(output_path, file_id + '_' + MASK + '.png'), predicted)