# Training

> Notebook to train deep learning models or ensembles for segmentation of fluorescent labels in microscopy images.

This notebook is optimized to be executed on [Google Colab](https://colab.research.google.com).

* Press the *play* button to execute a cell. It appears between \[     \] on the left side of each code cell.
* Run the cells consecutively. Skip cells that do not apply to your case.
* Use Firefox or Google Chrome if you want to upload files.

In [None]:
#@title Set up environment
#@markdown Please run this cell to get started.
%load_ext autoreload
%autoreload 2
try:
    from google.colab import files, drive
except ImportError:
    pass
try:
    import deepflash2
except ImportError:
    !pip install -q deepflash2
import zipfile
import imageio
from sklearn.model_selection import KFold, train_test_split
from fastai.vision.all import *
from deepflash2.all import *

## Provide Training Data

### Required data structure

__Structure__

- __One folder for training images__
- __One folder for segmentation masks__

_Example structure:_

* [folder] images
  * [file] 0001.tif
  * [file] 0002.tif
* [folder] masks
  * [file] 0001_mask.png
  * [file] 0002_mask.png

__Naming__

- Image names must have a unique ID
    - _ID: 0001 -> 0001.tif; ID: img_1 -> img_1.png, ..._ 
- Mask names must start with the ID, followed by a mask suffix
    - _0001 -> 0001_mask.png (mask_suffix = "_mask.png")_
    - _0001 -> 0001.png (mask_suffix = ".png")_
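The naming rule above can be sketched in a few lines of Python. This is only an illustration: `mask_path_for` is a hypothetical helper, not part of deepflash2, although it follows the same ID-plus-suffix idea as the `label_fn` defined later in this notebook.

```python
from pathlib import Path

def mask_path_for(image_path, mask_dir, mask_suffix):
    """Return the expected mask path: <mask_dir>/<image ID><mask_suffix>."""
    image_id = Path(image_path).stem  # file name without its extension
    return Path(mask_dir) / f"{image_id}{mask_suffix}"

# Folder and file names are taken from the example structure above.
first = mask_path_for("images/0001.tif", "masks", "_mask.png")
second = mask_path_for("images/img_1.png", "masks", ".png")
print(first.as_posix())   # masks/0001_mask.png
print(second.as_posix())  # masks/img_1.png
```

Because the ID is the file name without its extension, images and masks may use different file types as long as the ID matches.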

### Colab (recommended)

Working on _Google Colab_, this section allows you to upload a *zip* folder or connect to your _Google Drive_.

#### Upload _zip_ file

- The *zip* file must contain all images and segmentations in the correct folder structure.
- See [here](https://www.hellotech.com/guide/for/how-to-zip-a-file-mac-windows-pc) how to _zip_ files on Windows or Mac.
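Before uploading, it can help to check that the archive really contains both folders. The following is a minimal sketch using only the Python standard library. The helper `top_level_folders` is hypothetical, the folder names `images` and `masks` are assumed from the example structure above, and the archive is built in memory with empty placeholder files purely for demonstration.

```python
import io
import zipfile

def top_level_folders(zip_file):
    """Return the set of top-level folder names inside a zip archive."""
    with zipfile.ZipFile(zip_file) as zf:
        return {name.split("/")[0] for name in zf.namelist() if "/" in name}

# Build a tiny in-memory archive with empty placeholder files.
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zf:
    zf.writestr("images/0001.tif", b"")
    zf.writestr("masks/0001_mask.png", b"")
buffer.seek(0)

folders = top_level_folders(buffer)
missing = {"images", "masks"} - folders
print("Missing folders:", sorted(missing) if missing else "none")
```

For a real file, pass its path (for example a local `data.zip`) instead of the in-memory buffer.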

In [None]:
#@markdown Run to upload a *zip* file
path = Path('data')
try:
    u_dict = files.upload()
    for key in u_dict.keys():
        # The context manager closes the archive even if extraction fails
        with zipfile.ZipFile(key, 'r') as zip_ref:
            zip_ref.extractall(path)
except Exception:
    print("Warning: File upload only works on Google Colab.")

#### Connect to _Google Drive_

- The folder in your drive must contain all images and segmentations in the correct folder structure.
- See [here](https://support.google.com/drive/answer/2375091?co=GENIE.Platform%3DDesktop&hl=en) how to organize your files in _Google Drive_.
- See this [Stack Overflow post](https://stackoverflow.com/questions/46986398/import-data-into-google-colaboratory) for browsing files with the file browser.

In [None]:
#@markdown Provide the path to the folder on your _Google Drive_
try:
    drive.mount('/content/drive')
    path = "/content/drive/My Drive/data" #@param {type:"string"}
    path = Path(path)
    #@markdown Example: "/content/drive/My Drive/data"
except:
    print("Warning: Connecting to Google Drive only works on Google Colab.")
    pass

### Local Installation

If you're working on your local machine or server, provide a path to the correct folder.

In [None]:
#@markdown Provide path (either relative to notebook or absolute) and run cell
path = "my_data" #@param {type:"string"}
path = Path(path)
#@markdown Example: "expert_segmentations"

### Try with sample data

If you don't have any data available yet, try our sample data.

In [None]:
#@markdown Run to use sample files
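import urllib.request  # imported explicitly in case the star imports above do not provide it
path = Path('sample_data_train')
url = "https://github.com/matjesg/deepflash2/releases/download/model_library/wue1_cFOS_small.zip"
urllib.request.urlretrieve(url, 'sample_data.zip')
# The context manager closes the archive after extraction
with zipfile.ZipFile('sample_data.zip', 'r') as zip_ref:
    zip_ref.extractall(path)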

## Check and load data

In [None]:
#@markdown Provide your parameters according to your provided data
image_folder = "images" #@param {type:"string"}
mask_folder = "masks" #@param {type:"string"}
mask_suffix = "_cFOS.png" #@param {type:"string"}

Set [mask weights](https://matjesg.github.io/deepflash2/data.html#Weight-Calculation) parameters for training.
- The default values should work for most datasets.
- However, this choice can significantly change the model performance later on.

In [None]:
#@markdown Run to set mask weight parameters
border_weight_sigma=6 #@param {type:"number"}
foreground_dist_sigma=1 #@param {type:"number"}
border_weight_factor=10 #@param {type:"number"}
foreground_background_ratio=0.1 #@param {type:"number"}

mw_dict = {'bws': border_weight_sigma,
           'fds': foreground_dist_sigma, 
           'bwf': border_weight_factor,
           'fbr' : foreground_background_ratio}

In [None]:
#@markdown **Check and load data**
files = get_image_files(path/image_folder)
label_fn = lambda o: path/mask_folder/f'{o.stem}{mask_suffix}'
#Check if corresponding masks exist
mask_check = [os.path.isfile(label_fn(x)) for x in files]
if len(files)==sum(mask_check):
    print(f'Found {len(files)} images and {sum(mask_check)} masks on path {path}.')
    #@markdown Number of classes: e.g., 2 for binary segmentation (foreground and background class)
    n_classes = 2 #@param {type:"integer"}
    #@markdown Check if you are providing instance labels (class-aware and instance-aware)
    instance_labels = False #@param {type:"boolean"}
    ds = RandomTileDataset(files, label_fn, n_classes=n_classes, instance_labels=instance_labels)
else:
    print(f'IMAGE/MASK MISMATCH! Found {len(files)} images and {sum(mask_check)} masks on path {path}.')
    print('Please check the steps above.')

In [None]:
#@markdown Run to show data. { run: "auto" }
#@markdown Use the slider to control the number of displayed images
first_n = 6 #@param {type:"slider", min:1, max:100, step:1}
ds.show_data(first_n, figsize=(15,15), overlay=False)

## Model Definition

Select a [model architecture](https://matjesg.github.io/deepflash2/models.html#U-Net-architectures).

In [None]:
#@markdown { run: "auto" }
model_arch = 'unet_deepflash2' #@param ["unet_deepflash2",  "unet_falk2019", "unet_ronnberger2015"]

Select [pretrained](https://matjesg.github.io/deepflash2/) model weights
- See here for data description
- Select 'new' to use an untrained model (no pretrained weights)


In [None]:
pretrained_weights = "new" #@param ["new", "cFOS", "Parv"]
pre = pretrained_weights != "new"  # True only when pretrained weights were selected
n_channels = ds.get_data(max_n=1)[0].shape[-1]
model = torch.hub.load('matjesg/deepflash2', model_arch, pretrained=pre, n_classes=ds.c, in_channels=n_channels)
if pretrained_weights=="new": apply_init(model)

## Model Training

In [None]:
#@markdown Run to setup model for training.
cbs = [SaveModelCallback(monitor='iou'), ElasticDeformCallback]
metrics = [Dice_f1(), Iou()]
loss_fn = WeightedSoftmaxCrossEntropy(axis=1)
dls = DataLoaders.from_dsets(ds,ds, bs=4)
if torch.cuda.is_available(): dls.cuda(), model.cuda()
learn = Learner(dls, model, metrics = metrics, wd=0.001, loss_func=loss_fn, cbs=cbs)#.to_fp16()

### Setting training parameters

**Finding a good learning rate**

According to the [fastai docs](https://docs.fast.ai/callback.schedule#LRFinder), a good value for the learning rate is either:

- one tenth of the minimum before the divergence
- when the slope is the steepest

In our experiments, we found that a **maximum learning rate of 5e-4** (i.e., 0.0005) yielded the best results.
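To make the two heuristics concrete, here is a minimal sketch that applies them to a recorded loss curve. It is not fastai's implementation: `suggest_learning_rates` is a hypothetical helper, the learning rates and losses are made-up toy values, and fastai's own finder may smooth the losses or compute the slope differently.

```python
import math

def suggest_learning_rates(lrs, losses):
    """Return (lr at the minimum loss / 10, lr where the loss falls fastest)."""
    # Heuristic 1: one tenth of the learning rate with the lowest loss.
    i_min = min(range(len(losses)), key=losses.__getitem__)
    lr_min = lrs[i_min] / 10
    # Heuristic 2: most negative slope of loss vs. log10(lr), central differences.
    log_lrs = [math.log10(lr) for lr in lrs]
    slopes = [
        (losses[i + 1] - losses[i - 1]) / (log_lrs[i + 1] - log_lrs[i - 1])
        for i in range(1, len(lrs) - 1)
    ]
    i_steep = 1 + min(range(len(slopes)), key=slopes.__getitem__)
    return lr_min, lrs[i_steep]

# Toy curve: the loss falls, bottoms out near 1e-2, then diverges.
lrs = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
losses = [1.0, 0.9, 0.3, 0.25, 2.0]
lr_min, lr_steep = suggest_learning_rates(lrs, losses)
print(f"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}")
```

On this toy curve the two suggestions differ by a factor of ten, which is common in practice. A value between them is usually a reasonable starting point.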

In [None]:
#@markdown Run to find a suitable learning rate
lr_min,lr_steep = learn.lr_find()
print(f"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}")

In [None]:
#@title Set training parameters { run: "auto" }
#@markdown Set `max_lr` according to the learning rate finder.
max_lr = 5e-4 #@param {type:"number"}
#@markdown Number of models to train. If you're experimenting with parameters, try only one model first.
n_models = 1 #@param {type:"slider", min:1, max:5, step:1}
#@markdown One epoch is one pass of the entire (augmented) dataset through the model during training.
#@markdown We found that about 30 epochs are sufficient to train a model with 36 images. 
#@markdown Select more epochs for smaller datasets.
epochs = 30 #@param {type:"slider", min:1, max:100, step:1}
batch_size = 4

### Train models
- Using a train-test-split of 0.75/0.25 for one model
- Using _k-fold cross validation_ for model ensembles
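The k-fold idea can be illustrated without any dependencies. The sketch below is a plain-Python stand-in that partitions indices into consecutive, unshuffled folds. The notebook itself relies on scikit-learn's `KFold`, so `kfold_indices` and the file names here are hypothetical and only meant to show how each image ends up in exactly one validation set.

```python
def kfold_indices(n_samples, n_splits):
    """Yield (train_idx, val_idx) lists for consecutive, unshuffled folds."""
    base, extra = divmod(n_samples, n_splits)
    start = 0
    for fold in range(n_splits):
        # The first `extra` folds receive one additional sample.
        size = base + (1 if fold < extra else 0)
        val_idx = list(range(start, start + size))
        train_idx = [i for i in range(n_samples) if i < start or i >= start + size]
        yield train_idx, val_idx
        start += size

file_names = [f"{i:04d}.tif" for i in range(1, 9)]  # hypothetical image names
folds = list(kfold_indices(len(file_names), n_splits=4))
for fold, (train_idx, val_idx) in enumerate(folds, start=1):
    val_names = [file_names[i] for i in val_idx]
    print(f"model{fold}: {len(train_idx)} training images, validation on {val_names}")
```

With four models and eight images, every model trains on six images and is validated on the remaining two, so the validation results of the ensemble cover the whole dataset.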

In [None]:
#@markdown Run to train model(s).<br/> **THIS CAN TAKE A FEW HOURS FOR MULTIPLE MODELS!**
kf = KFold(n_splits=max(n_models,2))
model_path = path/'models'
model_path.mkdir(parents=True, exist_ok=True)
res, res_mc = {}, {}
fold = 0
for train_idx, val_idx in kf.split(files):
    fold += 1
    name = f'model{fold}'
    print('Train', name)
    if n_models==1:
        # Single model: random 0.75/0.25 train/validation split (train_test_split default)
        files_train, files_val = train_test_split(files)
    else:
        files_train, files_val = files[train_idx], files[val_idx]
        
    train_ds = RandomTileDataset(files_train, label_fn)
    valid_ds = TileDataset(files_val, label_fn)
    
    dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=batch_size)
    dls_valid = DataLoaders.from_dsets(valid_ds, batch_size=batch_size ,shuffle=False, drop_last=False)
    model = torch.hub.load('matjesg/deepflash2', model_arch, pretrained=pre, n_classes=ds.c, in_channels=n_channels)
    if pretrained_weights=="new": apply_init(model)
    if torch.cuda.is_available(): dls.cuda(), model.cuda(), dls_valid.cuda()
        
    learn = Learner(dls, model, metrics = metrics, wd=0.001, loss_func=loss_fn, cbs=cbs)
    learn.fit_one_cycle(epochs, max_lr)
    save_model(model_path/f'{name}.pth', learn.model, opt=None)
    
    smxs, segs, _ = learn.predict_tiles(dl=dls_valid.train)    
    smxs_mc, segs_mc, std = learn.predict_tiles(dl=dls_valid.train, mc_dropout=True, n_times=10)
    
    for i, file in enumerate(files_val):
        res[(name, file)] = smxs[i], segs[i]
        res_mc[(name, file)] = smxs_mc[i], segs_mc[i], std[i]
    
    if n_models==1:
        break

## Validate models and ensembles

Here you can validate your results. 
If you choose to only train one model (`n_models = 1`), ensemble and model results will be the same.
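The validation cells below report the intersection over union (IoU) for each file. As a reminder of what this metric measures, here is a minimal sketch for binary masks given as nested lists. `iou_binary` is a hypothetical helper; the `iou` function used in the notebook may differ in details such as how empty masks are handled.

```python
def iou_binary(mask, pred):
    """Intersection over union of two equally sized binary masks (nested lists of 0/1)."""
    intersection = union = 0
    for mask_row, pred_row in zip(mask, pred):
        for m, p in zip(mask_row, pred_row):
            intersection += 1 if (m and p) else 0
            union += 1 if (m or p) else 0
    # Convention chosen for this sketch: two empty masks count as a perfect match.
    return intersection / union if union else 1.0

mask = [[0, 1, 1],
        [0, 1, 0]]
pred = [[0, 1, 0],
        [1, 1, 0]]
score = iou_binary(mask, pred)
print(score)  # 2 overlapping pixels / 4 pixels in the union = 0.5
```

An IoU of 1.0 means the prediction and the mask agree on every foreground pixel, while 0.0 means they do not overlap at all.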

In [None]:
#@markdown Create folders to save the results. They will be created at your provided 'path'.
pred_dir = 'val_preds' #@param {type:"string"}
pred_path = path/pred_dir/'ensemble'
pred_path.mkdir(parents=True, exist_ok=True)
uncertainty_dir = 'val_uncertainties' #@param {type:"string"}
uncertainty_path = path/uncertainty_dir/'ensemble'
uncertainty_path.mkdir(parents=True, exist_ok=True)

#@markdown Define `filetype` to save the predictions and uncertainties. All common filetypes are supported.
filetype = 'png' #@param {type:"string"}

In [None]:
#@markdown Show and save ensemble results
val_files = set([f for _, f in res.keys()])
res_list = []
for file in val_files:
    img = ds.get_data(file)[0]
    msk = ds.get_data(file, mask=True)[0]
    pred = ensemble_results(res, file)
    pred_std = ensemble_results(res_mc, file, idx=2)
    df_tmp = pd.Series({'file' : file.name, 'iou': iou(msk, pred) ,'entropy': mean_entropy(pred_std)})
    plot_results(img, msk, pred, pred_std, df=df_tmp)
    res_list.append(df_tmp)
    imageio.imsave(pred_path/f'{file.name}_pred.{filetype}', pred.astype(np.uint8) if np.max(pred)>1 else pred.astype(np.uint8)*255)
    imageio.imsave(uncertainty_path/f'{file.name}_uncertainty.{filetype}', pred_std.astype(np.uint8)*255)
df_res = pd.DataFrame(res_list)
df_res.to_csv(path/'val_ensemble_results.csv', index=False)

In [None]:
#@markdown Show and save (single) model results { run: "auto" }
model_number = 1 #@param {type:"slider", min:1, max:5, step:1}
model_name = f'model{model_number}'
val_files = [f for mod , f in res.keys() if mod == model_name]
pred_path = path/pred_dir/model_name
pred_path.mkdir(parents=True, exist_ok=True)  # create the folder before predictions are saved
uncertainty_path = path/uncertainty_dir/model_name
uncertainty_path.mkdir(parents=True, exist_ok=True)
res_list = []
for file in val_files:
    img = ds.get_data(file)[0]
    msk = ds.get_data(file, mask=True)[0]
    pred = res[(model_name,file)][1]
    pred_std = res_mc[(model_name,file)][2][...,0]
    df_tmp = pd.Series({'file' : file.name, 'iou': iou(msk, pred) ,'entropy': mean_entropy(pred_std)})
    plot_results(img, msk, pred, pred_std, df=df_tmp)
    res_list.append(df_tmp)
    imageio.imsave(pred_path/f'{file.name}_pred.{filetype}', pred.astype(np.uint8) if np.max(pred)>1 else pred.astype(np.uint8)*255)
    imageio.imsave(uncertainty_path/f'{file.name}_uncertainty.{filetype}', pred_std.astype(np.uint8)*255)
df_res = pd.DataFrame(res_list)
df_res.to_csv(path/f'val_{model_name}_results.csv', index=False)

## Download Section

- The models will always be the _last_ version trained in Section _Model Training_
- To download validation predictions and uncertainties, you first need to execute Section _Validate models and ensembles_.

_Note: If you're connected to *Google Drive*, the models are automatically saved to your drive._

In [None]:
#@title Download models { run: "auto" }
model_number = "1" #@param ["1", "2", "3", "4", "5"]
model_path = path/'models'/f'model{model_number}.pth'
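try:
    # `files` was reassigned to the image list above, so import the Colab helper under its own name
    from google.colab import files as colab_files
    colab_files.download(str(model_path))
except Exception:
    print("Warning: File download only works on Google Colab.")
    print(f"Models are saved at {model_path.parent}")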

In [None]:
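#@markdown Download validation predictions { run: "auto" }
with zipfile.ZipFile('val_predictions.zip', 'w') as zip_obj:
    for f in get_image_files(path/pred_dir):
        zip_obj.write(f)
try:
    # Import under a separate name because `files` now holds the image list
    from google.colab import files as colab_files
    colab_files.download('val_predictions.zip')  # download the archive, not the model
except Exception:
    print("Warning: File download only works on Google Colab.")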

In [None]:
#@markdown Download validation uncertainties
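with zipfile.ZipFile('val_uncertainties.zip', 'w') as zip_obj:
    for f in get_image_files(path/uncertainty_dir):
        zip_obj.write(f)
try:
    # Import under a separate name because `files` now holds the image list
    from google.colab import files as colab_files
    colab_files.download('val_uncertainties.zip')  # download the archive, not the model
except Exception:
    print("Warning: File download only works on Google Colab.")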

In [None]:
#@markdown Download result analysis '.csv' files
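with zipfile.ZipFile('val_results.zip', 'w') as zip_obj:
    for f in get_files(path, extensions='.csv'):
        zip_obj.write(f)
try:
    # Import under a separate name because `files` now holds the image list
    from google.colab import files as colab_files
    colab_files.download('val_results.zip')  # download the archive, not the model
except Exception:
    print("Warning: File download only works on Google Colab.")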