# Training

> Notebook to train deep learning models or ensembles for segmentation of fluorescent labels in microscopy images.

This notebook is optimized to be executed on [Google Colab](https://colab.research.google.com).

* Press the *play* button to execute the cells. It will show up between \[     \] on the left side of the code cells. 
* Run the cells consecutively. Skip cells that do not apply to your case.
* Use Firefox or Google Chrome if you want to upload files.

In [None]:
#@title Set up environment
#@markdown Please run this cell to get started.
%load_ext autoreload
%autoreload 2
try:
    from google.colab import files, drive
except ImportError:
    pass
try:
    import deepflash2
except ImportError:
    !pip install -q deepflash2
import zipfile
import shutil
import urllib.request  # needed to download the sample data
import imageio
from sklearn.model_selection import KFold, train_test_split
from fastai.vision.all import *
from deepflash2.all import *

## Provide Training Data

### Required data structure

__Structure__

- __One folder for training images__
- __One folder for segmentation masks__

_Exemplary structure:_

* [folder] images
  * [file] 0001.tif
  * [file] 0002.tif
* [folder] masks
  * [file] 0001_mask.png
  * [file] 0002_mask.png

__Naming__

- Image names must have a unique ID
    - _ID: 0001 -> 0001.tif; ID: img_1 -> img_1.png, ..._ 
- Masks must start with the ID, followed by a mask suffix
    - _0001 -> 0001_mask.png (mask_suffix = "_mask.png")_
    - _0001 -> 0001.png (mask_suffix = ".png")_
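
The naming rule can be checked before uploading any data. The following is a minimal sketch, assuming the `images`/`masks` layout and the `_mask.png` suffix from the example above; `expected_mask_path` and `find_missing_masks` are hypothetical helpers for illustration and are not part of deepflash2.

```python
from pathlib import Path


def expected_mask_path(image_path, mask_dir, mask_suffix="_mask.png"):
    """Return the mask path implied by the naming rule: ID + mask suffix."""
    image_id = Path(image_path).stem  # '0001.tif' -> '0001'
    return Path(mask_dir) / f"{image_id}{mask_suffix}"


def find_missing_masks(image_dir, mask_dir, mask_suffix="_mask.png"):
    """List image files that have no mask following the naming rule."""
    images = sorted(p for p in Path(image_dir).iterdir() if p.is_file())
    return [p for p in images
            if not expected_mask_path(p, mask_dir, mask_suffix).is_file()]
```

Calling `find_missing_masks("data/images", "data/masks")` returns an empty list when every image has a matching mask.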

### Working on _Google Colab_ (recommended)

This section allows you to upload a *zip* file or connect to your _Google Drive_.

#### Upload _zip_ file

- The *zip* file must contain all images and segmentations in the correct folder structure. 
- See [here](https://www.hellotech.com/guide/for/how-to-zip-a-file-mac-windows-pc) how to _zip_ files on Windows or Mac.
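
The archive can also be created with Python's standard library. This is a sketch under the assumption that the image and mask folders should sit at the top level of the archive; `zip_training_data` and `list_archive` are hypothetical helper names chosen for this example.

```python
import shutil
import zipfile
from pathlib import Path


def zip_training_data(data_dir, archive_name="data"):
    """Zip the contents of `data_dir`; its subfolders end up at the archive's top level."""
    # root_dir makes stored paths relative to data_dir, e.g. 'images/0001.tif'.
    # Keep the archive outside data_dir so it does not include itself.
    archive = shutil.make_archive(str(archive_name), "zip", root_dir=str(data_dir))
    return Path(archive)


def list_archive(archive_path):
    """Return the names stored in the archive for a quick structure check."""
    with zipfile.ZipFile(archive_path) as zf:
        return sorted(zf.namelist())
```

Inspecting the result of `list_archive` before uploading helps catch an extra nesting level, such as `data/images/...` instead of `images/...`.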

In [None]:
#@markdown Run to upload a *zip* file
path = Path('data')
try:
    u_dict = files.upload()
    for key in u_dict.keys():
        unzip(path, key)
    print('Path contains the following files and folders: \n', L(os.listdir(path)))
except Exception:
    # `files` is only defined on Google Colab
    print("Warning: File upload only works on Google Colab.")

#### Connect to _Google Drive_

- The folder in your drive must contain all images and segmentations in the correct folder structure. 
- See [here](https://support.google.com/drive/answer/2375091?co=GENIE.Platform%3DDesktop&hl=en) how to organize your files in _Google Drive_.
- See this [stackoverflow post](https://stackoverflow.com/questions/46986398/import-data-into-google-colaboratory) for browsing files with the file browser.

In [None]:
#@markdown Provide the path to the folder on your _Google Drive_
try:
    drive.mount('/content/drive')
    path = "/content/drive/My Drive/data" #@param {type:"string"}
    path = Path(path)
    print('Path contains the following files and folders: \n', L(os.listdir(path)))
    #@markdown Example: "/content/drive/My Drive/data"
except Exception:
    # `drive` is only defined on Google Colab
    print("Warning: Connecting to Google Drive only works on Google Colab.")

### Local Installation

If you're working on your local machine or server, provide a path to the correct folder.

In [None]:
#@markdown Provide path (either relative to notebook or absolute) and run cell
path = "data" #@param {type:"string"}
path = Path(path)
print('Path contains the following files and folders: \n', L(os.listdir(path)))

### Try with sample data

If you don't have any data available yet, try our sample data.

In [None]:
#@markdown Run to use sample files
path = Path('sample_data_cFOS')
url = "https://github.com/matjesg/deepflash2/releases/download/model_library/wue1_cFOS_small.zip"
urllib.request.urlretrieve(url, 'sample_data_cFOS.zip')
unzip(path, 'sample_data_cFOS.zip')

## Check and load data

In [None]:
#@markdown Provide your parameters according to your provided data
image_folder = "images" #@param {type:"string"}
mask_folder = "masks" #@param {type:"string"}
mask_suffix = "_mask.png" #@param {type:"string"}
#@markdown Check if you are providing instance labels (class-aware and instance-aware)
instance_labels = False #@param {type:"boolean"}

f_names = get_image_files(path/image_folder)
label_fn = lambda o: path/mask_folder/f'{o.stem}{mask_suffix}'
#Check if corresponding masks exist
mask_check = [os.path.isfile(label_fn(x)) for x in f_names]
if len(f_names)==sum(mask_check) and len(f_names)>0:
    print(f'Found {len(f_names)} images and {sum(mask_check)} masks in "{path}".')
else:
    print(f'IMAGE/MASK MISMATCH! Found {len(f_names)} images and {sum(mask_check)} masks in "{path}".')
    print('Please check the steps above.')

Set [mask weights](https://matjesg.github.io/deepflash2/data.html#Weight-Calculation) parameters for training.
- Default values should work for most of the data. 
- However, this choice can significantly change the model performance later on.

In [None]:
#@markdown Run to set parameters and load data
border_weight_sigma=6 #@param {type:"number"}
foreground_dist_sigma=1 #@param {type:"number"}
border_weight_factor=50 #@param {type:"number"}
foreground_background_ratio=0.1 #@param {type:"number"}

mw_dict = {'bws': border_weight_sigma,
           'fds': foreground_dist_sigma, 
           'bwf': border_weight_factor,
           'fbr' : foreground_background_ratio}

#@markdown Number of classes: e.g., 2 for binary segmentation (foreground and background class)
n_classes = 2 #@param {type:"integer"}
ds = RandomTileDataset(f_names, label_fn, n_classes=n_classes, instance_labels=instance_labels, **mw_dict)

In [None]:
#@markdown Run to show data. { run: "auto" }
#@markdown Use the slider to control the number of displayed images
first_n = 3 #@param {type:"slider", min:1, max:100, step:1}
ds.show_data(max_n = first_n, figsize=(15,15), overlay=False)

## Model Definition

- Select one of the available [model architectures](https://matjesg.github.io/deepflash2/models.html#U-Net-architectures).

In [None]:
#@markdown { run: "auto" }
model_arch = 'unet_deepflash2' #@param ["unet_deepflash2",  "unet_falk2019", "unet_ronnberger2015"]

- Select [pretrained](https://matjesg.github.io/deepflash2/) model weights
    - See here for data description
    - Select 'new' to use an untrained model (no pretrained weights)

In [None]:
pretrained_weights = "new" #@param ["new", "cFOS", "Parv"]
pre = pretrained_weights != "new"
n_channels = ds.get_data(max_n=1)[0].shape[-1]
model = torch.hub.load('matjesg/deepflash2', model_arch, pretrained=pre, n_classes=ds.c, in_channels=n_channels)
if pretrained_weights=="new": apply_init(model)

## Model Training

### Setting training parameters

- *mixed_precision_training*: enables [Mixed precision training](https://docs.fast.ai/callback.fp16#A-little-bit-of-theory)
    - decreases memory usage and speeds up training
    - may affect model accuracy
- *batch_size*: the number of samples that will be propagated through the network during one iteration
    - 4 works best in our experiments
    - 4-8 works well for [mixed precision training](https://docs.fast.ai/callback.fp16#A-little-bit-of-theory)
    

In [None]:
#@markdown Run to set up the model for training.
mixed_precision_training = False #@param {type:"boolean"}
batch_size = 4 #@param {type:"slider", min:2, max:8, step:2}
cbs = [SaveModelCallback(monitor='iou'), ElasticDeformCallback]
metrics = [Dice_f1(), Iou()]
loss_fn = WeightedSoftmaxCrossEntropy(axis=1)
dls = DataLoaders.from_dsets(ds,ds, bs=batch_size)
if torch.cuda.is_available(): dls.cuda(), model.cuda()
learn = Learner(dls, model, metrics = metrics, wd=0.001, loss_func=loss_fn, cbs=cbs)
if mixed_precision_training: learn.to_fp16()

- **Learning rate**: controls how quickly or slowly a neural network model learns. <br>
We need to set the __maximum learning rate__ `max_lr` when training with the **[One-Cycle-Policy](https://fastai1.fast.ai/callbacks.one_cycle.html#The-1cycle-policy)**. Using the **Learning Rate Finder** below, a good value for `max_lr` lies in the range between:
    - one tenth of the minimum before the divergence (_Minimum/10_)
    - the point where the slope is steepest (_steepest point_) 
    - In our experiments, a **maximum learning rate of 5e-4** (i.e., 0.0005) yielded the best results.
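
To illustrate what `max_lr` controls, the sketch below computes a one-cycle schedule in plain Python. It assumes the default shape of fastai's `fit_one_cycle` (cosine warm-up from `max_lr/div` over the first `pct_start` of training, then cosine decay to `max_lr/div_final`, with `div=25`, `div_final=1e5`, `pct_start=0.25`); `cosine_anneal` and `one_cycle_lr` are hypothetical names for illustration, and the training cell relies on fastai's own scheduler.

```python
import math


def cosine_anneal(start, end, pos):
    """Cosine interpolation from `start` (pos=0) to `end` (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def one_cycle_lr(pct, max_lr=5e-4, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress `pct` in [0, 1]."""
    if pct < pct_start:
        # Warm-up phase: rise from max_lr/div to max_lr.
        return cosine_anneal(max_lr / div, max_lr, pct / pct_start)
    # Annealing phase: decay from max_lr to max_lr/div_final.
    return cosine_anneal(max_lr, max_lr / div_final,
                         (pct - pct_start) / (1 - pct_start))


# Sample the schedule at a few points of the training run.
schedule = [one_cycle_lr(i / 10) for i in range(11)]
```

The peak of the schedule equals `max_lr`, so choosing it too high makes the middle of training unstable, while choosing it too low slows convergence.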

In [None]:
#@markdown Run to find a suitable learning rate
lr_min,lr_steep = learn.lr_find()
print(f"Minimum/10: {lr_min:.2e}, steepest point: {lr_steep:.2e}")

- *n_models*: Number of models to train.
    - If you're experimenting with parameters, try only one model first.
    - Depending on the data, ensembles should comprise 3-5 models.
    - Affects the train-validation split:
        - `n_models = 1` leads to a train-validation split of 0.75/0.25
        - `n_models > 1` (model ensembles) leads to _[k-fold cross validation](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html)_
- *epochs*: One epoch is one pass of the entire (augmented) dataset through the model during training.
    - Epochs need to be adjusted depending on the size and number of images.
    - We found that choosing the number of epochs such that the network parameters are updated about 1000 times (iterations) leads to satisfying results in most cases.
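
The 1000-iteration rule of thumb can be turned into a rough estimate of the number of epochs. The sketch below is not the implementation of `calc_iterations`; it is a hypothetical helper that assumes one iteration per batch, an incomplete last batch being dropped, and the training fractions described above (0.75 for a single model, (k-1)/k for k-fold cross validation).

```python
import math


def epochs_for_iterations(n_samples, batch_size, n_models=1, target_iterations=1000):
    """Estimate the epochs needed for about `target_iterations` weight updates per model."""
    if n_models == 1:
        n_train = n_samples * 3 // 4  # 0.75/0.25 train-validation split
    else:
        n_train = n_samples * (n_models - 1) // n_models  # k-fold with k = n_models
    # Assumption: incomplete batches are dropped, so use floor division.
    iterations_per_epoch = max(1, n_train // batch_size)
    return math.ceil(target_iterations / iterations_per_epoch)
```

Here `n_samples` stands for the number of samples the training dataset yields per epoch, which may differ from the number of image files when tiles are sampled randomly.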

In [None]:
#@title Set training parameters { run: "auto" }
#@markdown Set `max_lr` according to the learning rate finder.
max_lr = 5e-4 #@param {type:"number"}
#@markdown 
n_models = 1 #@param {type:"slider", min:1, max:5, step:1}
epochs = 30 #@param {type:"slider", min:1, max:100, step:1}
print("Suggested epochs for 1000 iterations:", calc_iterations(len(ds), batch_size, n_models))

### Train models

In [None]:
#@markdown Run to train model(s).<br/> **THIS CAN TAKE A FEW HOURS FOR MULTIPLE MODELS!**
kf = KFold(n_splits=max(n_models,2))
model_path = path/'models'
model_path.mkdir(parents=True, exist_ok=True)
res, res_mc = {}, {}
fold = 0
for train_idx, val_idx in kf.split(f_names):
    fold += 1
    name = f'model{fold}'
    print('Train', name)
    if n_models==1:
        files_train, files_val = train_test_split(f_names)
    else:
        files_train, files_val = f_names[train_idx], f_names[val_idx]
    print(f'Validation Images: {files_val}')    
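    # Pass the same class and label settings that were used when loading the data above
    train_ds = RandomTileDataset(files_train, label_fn, n_classes=n_classes, instance_labels=instance_labels, **mw_dict)
    valid_ds = TileDataset(files_val, label_fn, n_classes=n_classes, instance_labels=instance_labels, **mw_dict)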
    
    dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=batch_size)
    dls_valid = DataLoaders.from_dsets(valid_ds, batch_size=batch_size ,shuffle=False, drop_last=False)
    model = torch.hub.load('matjesg/deepflash2', model_arch, pretrained=pre, n_classes=ds.c, in_channels=n_channels)
    if pretrained_weights=="new": apply_init(model)
    if torch.cuda.is_available(): dls.cuda(), model.cuda(), dls_valid.cuda()
        
    learn = Learner(dls, model, metrics = metrics, wd=0.001, loss_func=loss_fn, cbs=cbs)
    if mixed_precision_training: learn.to_fp16()
    learn.fit_one_cycle(epochs, max_lr)
    # save_model(model_path/f'{name}.pth', learn.model, opt=None)
    torch.save(learn.model.state_dict(), model_path/f'{name}.pth', _use_new_zipfile_serialization=False)
    
    smxs, segs, _ = learn.predict_tiles(dl=dls_valid.train)    
    smxs_mc, segs_mc, std = learn.predict_tiles(dl=dls_valid.train, mc_dropout=True, n_times=10)
    
    for i, file in enumerate(files_val):
        res[(name, file)] = smxs[i], segs[i]
        res_mc[(name, file)] = smxs_mc[i], segs_mc[i], std[i]
    
    if n_models==1:
        break

## Validate models

Here you can validate your models. To avoid information leakage, only predictions on the respective models' validation set are made.
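
How the folds keep training and validation files apart can be sketched in plain Python. The helper below mimics consecutive, non-shuffled k-fold splitting as used for `n_models > 1`; `kfold_indices` is a hypothetical name for illustration, and the training cell itself uses scikit-learn's `KFold`.

```python
def kfold_indices(n_items, n_splits):
    """Yield (train_idx, val_idx) index lists for consecutive, non-shuffled folds."""
    # The first n_items % n_splits folds receive one extra item.
    fold_sizes = [n_items // n_splits + (1 if i < n_items % n_splits else 0)
                  for i in range(n_splits)]
    start = 0
    for size in fold_sizes:
        stop = start + size
        val_idx = list(range(start, stop))
        train_idx = [i for i in range(n_items) if i < start or i >= stop]
        yield train_idx, val_idx
        start = stop


# Each file index appears in exactly one validation fold.
folds = list(kfold_indices(10, 3))
```

Because every file is validated only by the model that never saw it during training, the per-file scores are not inflated by leakage. For `n_models = 1`, a single random 0.75/0.25 split is used instead.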

In [None]:
#@markdown Create folders to save the results. They will be created at your provided 'path'.
pred_dir = 'val_preds' #@param {type:"string"}
pred_path = path/pred_dir/'ensemble'
pred_path.mkdir(parents=True, exist_ok=True)
uncertainty_dir = 'val_uncertainties' #@param {type:"string"}
uncertainty_path = path/uncertainty_dir/'ensemble'
uncertainty_path.mkdir(parents=True, exist_ok=True)
result_path = path/'results'
result_path.mkdir(exist_ok=True)

#@markdown Define `filetype` to save the predictions and uncertainties. All common filetypes are supported.
filetype = 'png' #@param {type:"string"}

In [None]:
#@markdown Show and save results
res_list = []
for model_number in range(1,n_models+1):
    model_name = f'model{model_number}'
    val_files = [f for mod , f in res.keys() if mod == model_name]
    print(f'Validating {model_name}')
    pred_path = path/pred_dir/model_name
    pred_path.mkdir(parents=True, exist_ok=True)
    uncertainty_path = path/uncertainty_dir/model_name
    uncertainty_path.mkdir(parents=True, exist_ok=True)
    for file in val_files:
        img = ds.get_data(file)[0]
        msk = ds.get_data(file, mask=True)[0]
        pred = res[(model_name,file)][1]
        pred_std = res_mc[(model_name,file)][2][...,0]
        df_tmp = pd.Series({'file' : file.name,
                            'model' : model_name,
                            'iou': iou(msk, pred),
                            'entropy': mean_entropy(pred_std)})
        plot_results(img, msk, pred, pred_std, df=df_tmp)
        res_list.append(df_tmp)
        imageio.imsave(pred_path/f'{file.stem}_pred.{filetype}', pred.astype(np.uint8) if np.max(pred)>1 else pred.astype(np.uint8)*255)
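        # Scale to 0-255 before casting; casting first would truncate fractional values to 0
        imageio.imsave(uncertainty_path/f'{file.stem}_uncertainty.{filetype}', (np.clip(pred_std, 0, 1)*255).astype(np.uint8))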
df_res = pd.DataFrame(res_list)
df_res.to_csv(result_path/'val_results.csv', index=False)
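
The `iou` column in `val_results.csv` is an intersection-over-union score between the mask and the prediction. The following is a minimal sketch of that metric for binary masks in plain Python; `binary_iou` is a hypothetical helper, and the `iou` function used in the cell above comes from deepflash2 and may handle edge cases differently.

```python
def binary_iou(mask_a, mask_b):
    """IoU of two binary masks given as equally shaped nested lists of 0/1 values."""
    intersection = 0
    union = 0
    for row_a, row_b in zip(mask_a, mask_b):
        for a, b in zip(row_a, row_b):
            if a and b:
                intersection += 1
            if a or b:
                union += 1
    # Assumption: two empty masks count as a perfect match.
    return intersection / union if union else 1.0
```

A value of 1 means the prediction and the mask overlap completely, while 0 means they share no foreground pixel.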

## Download Section

- The models will always be the _last_ version trained in Section _Model Training_
- To download validation predictions and uncertainties, you first need to execute Section _Validate models_.

_Note: If you're connected to *Google Drive*, the models are automatically saved to your drive._

In [None]:
#@title Download models { run: "auto" }
model_number = "1" #@param ["1", "2", "3", "4", "5"]
model_path = path/'models'/f'model{model_number}.pth'
try:
    files.download(model_path)
except Exception:
    # `files` is only defined on Google Colab
    print("Warning: File download only works on Google Colab.")
    print(f"Models are saved at {model_path.parent}")

In [None]:
#@markdown Download validation predictions { run: "auto" }
out_name = 'val_predictions'
shutil.make_archive(path/out_name, 'zip', path/pred_dir)
try:
    files.download(path/f'{out_name}.zip')
except Exception:
    # `files` is only defined on Google Colab
    print("Warning: File download only works on Google Colab.")

In [None]:
#@markdown Download validation uncertainties
out_name = 'val_uncertainties'
shutil.make_archive(path/out_name, 'zip', path/uncertainty_dir)
try:
    files.download(path/f'{out_name}.zip')
except Exception:
    # `files` is only defined on Google Colab
    print("Warning: File download only works on Google Colab.")

In [None]:
#@markdown Download result analysis '.csv' files
try:
    files.download(result_path/'val_results.csv')
except Exception:
    # `files` is only defined on Google Colab
    print("Warning: File download only works on Google Colab.")