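### Prediction Notebook

Load the images selected in the assignment sheet, run a pretrained U-Net on them, and save the predicted masks as PNG files.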

Import packages

In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import pandas as pd
from skimage import io
import os
from deepflash import unet, preproc
import matplotlib.pyplot as plt
from tqdm import tqdm
%matplotlib inline

## Global Settings

In [None]:
# --- Model snapshot -------------------------------------------------------
# Snapshot name (without the '.h5' suffix) loaded from CHECKPOINT_PATH.
# Alternatives noted by the author: 'caffe/caffe_weights.h5' or None.
# NOTE: the spelling "PRETAINED" is referenced by later cells; do not rename here.
PRETAINED = 'sc_falk_cFOS_None_rohini.0100'
CHECKPOINT_PATH = 'checkpoints_sc'
BATCH_NORM = False

# --- Input selection ------------------------------------------------------
DATA_PATH = "01_data"
ASSIGNMENT_PATH = 'Zuordnung_aktuell.xlsx'  # "Zuordnung" = assignment (German)
IMAGE = 'red'        # image files are read as '<id>_red.tif'
CHANNELS_IMG = 1     # intended number of input image channels
MASK = 'cFOS'        # predictions are written as '<id>_cFOS.png'

# --- Tiling / scaling -----------------------------------------------------
TILE_SHAPE = (540, 540)
PADDING = (184, 184)
# Element size per pixel, labelled micrometers in the original notebook.
# NOTE(review): 635.9 um/pixel is unusually coarse for microscopy — confirm the unit.
EL_SIZE = [635.9, 635.9]

# --- Output ---------------------------------------------------------------
OUTPUT = 'predictions'

## Load Data

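Excel sheet that assigns each image number (`Nummer`) to its experimental metadata; it is filtered below to select the images to predict.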

In [None]:
# Read the full sheet. 'Nummer' is zero-padded to 4 characters (e.g. 12 -> '0012')
# so it matches the image file prefixes used when loading images.
assignment_raw = pd.read_excel(
    ASSIGNMENT_PATH,
    converters={'Nummer': lambda x: str(x).zfill(4)},
)

# Keep WT animals, region 'dHC' (presumably dorsal hippocampus), areas
# CA1/CA3/DG and experiments 1-4. Drop rows flagged in any of the exclusion
# columns; a blank cell means "not flagged". ("Ausschluss von Analyse" is
# German for "excluded from analysis".)
selection = (
    (assignment_raw['Genotyp'] == 'WT')
    & (assignment_raw['region'] == 'dHC')
    & assignment_raw['Area'].isin(['CA1', 'CA3', 'DG'])
    & assignment_raw['Experiment'].isin([1, 2, 3, 4])
    & assignment_raw['Cross-coder Training'].isna()
    & assignment_raw['Ausschluss von Analyse'].isna()
    & assignment_raw['broken'].isna()
)
assignment = assignment_raw[selection]

file_ids = assignment['Nummer'].tolist()
# Fail early: an empty selection would otherwise make every later cell a silent no-op.
assert file_ids, f"No rows in {ASSIGNMENT_PATH} match the selection filter"

Load the selected images

In [None]:
# Read '<id>_<IMAGE>.tif' for every selected id as grayscale and append a
# trailing channel axis (H, W) -> (H, W, 1).
image_list = []
for file_id in file_ids:
    tif_path = os.path.join(DATA_PATH, file_id + '_' + IMAGE + '.tif')
    gray = io.imread(tif_path, as_gray=True)
    image_list.append(np.expand_dims(gray, axis=2))

# Wrap each image in the dict format handed to preproc.TileGenerator below.
data = [{'rawdata': img, 'element_size_um': EL_SIZE} for img in image_list]

In [None]:
# One output folder per snapshot, so predictions from different models don't mix.
output_path = os.path.join(OUTPUT, PRETAINED)
# exist_ok=True makes the cell idempotent and avoids the isdir()/makedirs() race.
# It still raises FileExistsError if output_path exists as a regular file.
os.makedirs(output_path, exist_ok=True)

In [None]:
## Predict
# Build the U-Net and load weights from <CHECKPOINT_PATH>/<PRETAINED>.h5.
# The architecture arguments presumably must match those used at training
# time — TODO confirm against the training notebook.
pred_model = unet.Unet2D(
    snapshot=os.path.join(CHECKPOINT_PATH, PRETAINED + '.h5'),
    n_channels=CHANNELS_IMG,  # use the configured channel count instead of a hardcoded 1
    n_classes=2,
    n_levels=4,
    batch_norm=BATCH_NORM,
    upsample=False,
    relu_alpha=0.1,
    n_features=64,
    name="U-Net",
)

In [None]:
# Wrap the image dicts in deepflash's TileGenerator, configured with TILE_SHAPE
# and PADDING; it is consumed by pred_model.predict in the next cell.
tile_generator = preproc.TileGenerator(data, TILE_SHAPE, PADDING)

In [None]:
# Run the network over the tiled images. Only predictions[1] is used when
# saving; presumably it holds one output per input image — TODO confirm the
# return structure of Unet2D.predict.
predictions = pred_model.predict(tile_generator)

In [None]:
## Save
# Write one PNG per input image, named '<id>_<MASK>.png', into output_path.
# NOTE(review): assumes predictions[1] holds one image-like array per entry of
# file_ids, in the same order — TODO confirm against Unet2D.predict.
pred_maps = predictions[1]
# Pairing by position is only meaningful if the counts agree; fail before
# writing anything rather than with an IndexError halfway through.
assert len(pred_maps) == len(file_ids), (
    f"{len(pred_maps)} predictions for {len(file_ids)} input images"
)
for file_id, pred_map in tqdm(zip(file_ids, pred_maps), total=len(file_ids), desc="saving"):
    file_name = file_id + '_' + MASK + '.png'
    io.imsave(os.path.join(output_path, file_name), pred_map)