# Training of U-Net

This tutorial explains how to create training data and train a U-Net neural network model that predicts sarcomere Z-bands from microscopy images. An analogous approach is used for the prediction of a cell mask. SarcAsM uses our package `bio-image-unet`, see [https://github.com/danihae/bio-image-unet](https://github.com/danihae/bio-image-unet). We strongly recommend using a GPU-equipped workstation or server for training. Make sure that the [CUDA toolkit](https://developer.nvidia.com/cuda-toolkit-archive) and the matching version of [PyTorch](https://pytorch.org/get-started/locally/) are installed, and verify the installation with:
```python 
import torch
torch.cuda.is_available()
```

## Creation of training data set

For training a custom model, we recommend selecting a representative training dataset of 20-50 images. If the images have multiple channels or frames, they should be converted to individual single-channel grayscale images.
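
As an illustration of this conversion, the following minimal sketch splits a multi-channel movie into individual 2D grayscale frames. The helper `split_into_grayscale_frames`, the axis order `(frames, channels, y, x)` and the chosen channel index are assumptions for this example and need to be adapted to your data.

```python
import numpy as np


def split_into_grayscale_frames(data, channel=0):
    """Return a list of 2D arrays from a (frames, channels, y, x) array.

    The axis order and the `channel` index are assumptions; adapt them to your data.
    """
    data = np.asarray(data)
    if data.ndim != 4:
        raise ValueError(f'expected 4 dimensions (frames, channels, y, x), got {data.ndim}')
    # select one channel, then split along the frame axis
    return [frame for frame in data[:, channel]]


# synthetic movie: 3 frames, 2 channels, 64 x 48 pixels
movie = np.random.default_rng(0).integers(0, 255, size=(3, 2, 64, 48), dtype=np.uint8)
frames = split_into_grayscale_frames(movie, channel=1)
print(len(frames), frames[0].shape)  # 3 (64, 48)
```

Each returned frame can then be saved as its own `.tif` file in the image folder.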

We recommend the following folder structure for images and labels, with identical filenames for images and labels:

```
training_data/sarcomere_z_bands
|
├── image
│   ├── 105.tif
│   ├── 111.tif
│   ├── image123xyz.tif
│   ├── 121.tif
│   ├── 1.tif
│   ├── 2.tif
│   └── 83.tif
└── label
    ├── 105.tif
    ├── 111.tif
    ├── image123xyz.tif
    ├── 121.tif
    ├── 1.tif
    ├── 2.tif
    └── 83.tif
```

In [None]:
import numpy as np

# set paths of folders 
folder_images = '../../training_data_tutorial/image/'
folder_labels = '../../training_data_tutorial/label/'

### Generate small patches from large images or stacks (optional)

To make training more efficient and to speed up annotation, randomly crop patches from large images or stacks. Random cropping samples different regions of each image, which helps to keep the selection of training samples representative.

In [None]:
import glob
import os
import tifffile
import numpy as np

# create folders for patches
folder_images_patches = '../../training_data_tutorial/image_patch/'
folder_labels_patches = '../../training_data_tutorial/label_patch/'
os.makedirs(folder_images_patches, exist_ok=True)
os.makedirs(folder_labels_patches, exist_ok=True)

# list images or stacks
list_images = glob.glob(folder_images + '*.tif')

# patch size (width, height) and number of patches per image
patch_size = (512, 512)
n_patches = 6

# iterate through images and create random patches
# (assumes 2D images with axes (y, x) that are at least as large as the patch)
np.random.seed(0)  # fixed seed for reproducible patch positions
for image in list_images:
    data = tifffile.imread(image)
    # x runs along the columns (axis 1), y along the rows (axis 0)
    x_patches = np.random.randint(0, data.shape[1] - patch_size[0] + 1, size=n_patches)
    y_patches = np.random.randint(0, data.shape[0] - patch_size[1] + 1, size=n_patches)
    for x, y in zip(x_patches, y_patches):
        patch = data[y:y + patch_size[1], x:x + patch_size[0]]
        name = os.path.join(folder_images_patches,
                            os.path.splitext(os.path.basename(image))[0] + f'_patch_{y}x{x}.tif')
        tifffile.imwrite(name, patch)
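
For stacks, one option is to draw a random frame for each patch in addition to a random position. The following is a minimal sketch under the assumption that stacks are ordered `(frames, y, x)`; `random_patches` is a hypothetical helper, and `patch_size` is given as `(width, height)` as in the cell above.

```python
import numpy as np


def random_patches(data, patch_size=(512, 512), n_patches=6, rng=None):
    """Crop random 2D patches from a 2D image (y, x) or a stack (frames, y, x).

    patch_size is given as (width, height). Returns a list of 2D arrays.
    """
    rng = np.random.default_rng() if rng is None else rng
    data = np.asarray(data)
    if data.ndim not in (2, 3):
        raise ValueError('expected a 2D image or a 3D stack')
    patch_w, patch_h = patch_size
    height, width = data.shape[-2:]
    if height < patch_h or width < patch_w:
        raise ValueError('image is smaller than the patch size')
    patches = []
    for _ in range(n_patches):
        # top-left corner, chosen so that the patch stays inside the image
        y = rng.integers(0, height - patch_h + 1)
        x = rng.integers(0, width - patch_w + 1)
        # for stacks, draw a random frame as well
        frame = data[rng.integers(0, data.shape[0])] if data.ndim == 3 else data
        patches.append(frame[y:y + patch_h, x:x + patch_w])
    return patches


stack = np.zeros((10, 600, 800), dtype=np.uint16)  # synthetic stack
stack_patches = random_patches(stack, patch_size=(512, 512), n_patches=4,
                               rng=np.random.default_rng(0))
print(len(stack_patches), stack_patches[0].shape)  # 4 (512, 512)
```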

### Annotation of images

For the annotation of sarcomere Z-bands or other targets, a drawing application on a tablet with a pen or a bio-image viewer (e.g., ImageJ or napari) can be used to create a binary mask of the sarcomere Z-bands, see Figure below. Here we demonstrate the annotation using a custom script built on [napari](https://napari.org/stable/) and included in our package [bio-image-unet](https://github.com/danihae/bio-image-unet). The script iterates through all images. When an annotation is finished, press "Save and Next" to proceed to the next image.

In [None]:
from biu.utils.image_annotator import ImageAnnotator

# Annotate all images, press "Save and Next" when image is finished
annotator = ImageAnnotator(
    folder_images=folder_images_patches,  
    output_folder=folder_labels_patches,  
    label_name='Z-bands',  
    brush_size=4
)
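
After annotation, it may be worth checking that each label matches its image in shape and is a binary mask. The sketch below assumes that "binary" means at most two distinct pixel values; the exact values written by the annotation tool are not assumed here. `check_label` is a hypothetical helper, demonstrated on synthetic arrays.

```python
import numpy as np


def check_label(image, label):
    """Return a list of warnings for one image/label pair (empty if none)."""
    warnings = []
    if image.shape != label.shape:
        warnings.append(f'shape mismatch: image {image.shape}, label {label.shape}')
    values = np.unique(label)
    if len(values) > 2:
        warnings.append(f'label is not binary: {len(values)} distinct values')
    elif len(values) == 1:
        warnings.append('label contains a single value (empty or completely filled)')
    return warnings


check_image = np.zeros((64, 64), dtype=np.uint8)
check_mask = np.zeros((64, 64), dtype=np.uint8)
check_mask[20:24, 10:50] = 255  # one annotated band
print(check_label(check_image, check_mask))  # []
```

To apply the check to real data, load each image and its label with `tifffile.imread` and pass both arrays to the helper.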

## Training

### Prepare and process training data

Prior to training, the training images and labels are processed and augmented. For the available processing and augmentation options (add noise, blur, adjust contrast, ...), see the docstring of `DataProcess`.

In [None]:
import os

import biu.unet as unet

# select folders with images and labels, see above
folder_images_patches = '../../training_data_tutorial/image_patch/'
folder_labels_patches = '../../training_data_tutorial/label_patch/'

# temp folder
folder_temp = '../../training_data_tutorial/temp/'
os.makedirs(folder_temp, exist_ok=True)

# prepare and process training data
data = unet.DataProcess(folder_images_patches, folder_labels_patches, aug_factor=6, data_path=folder_temp)
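
To illustrate the general idea of augmenting image/label pairs, here is a minimal NumPy sketch. It is not the `DataProcess` implementation: the helper `augment_pair`, the choice of transformations (random 90-degree rotation, flip, Gaussian noise) and the assumption of 8-bit input images are made only for this example. The point it shows is that geometric changes are applied to image and label alike, while intensity changes affect the image only.

```python
import numpy as np


def augment_pair(image, label, rng, noise_std=0.05):
    """Apply one random rotation/flip to image and label, and add noise to the image only."""
    k = rng.integers(0, 4)  # number of 90-degree rotations
    image_aug, label_aug = np.rot90(image, k), np.rot90(label, k)
    if rng.random() < 0.5:
        image_aug, label_aug = np.fliplr(image_aug), np.fliplr(label_aug)
    # intensity changes are applied to the image only, the label stays unchanged
    image_aug = image_aug.astype(np.float32) / 255  # assumes 8-bit input
    image_aug = np.clip(image_aug + rng.normal(0, noise_std, image_aug.shape), 0, 1)
    return image_aug, label_aug


aug_rng = np.random.default_rng(0)
aug_image = aug_rng.integers(0, 256, size=(64, 64), dtype=np.uint8)
aug_label = np.zeros((64, 64), dtype=np.uint8)
aug_label[30:34, :] = 255
image_aug, label_aug = augment_pair(aug_image, aug_label, aug_rng)
print(image_aug.shape, label_aug.shape)  # (64, 64) (64, 64)
```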

### Set training parameters and train

For the available training parameters, see the docstring of `unet.Trainer`.

In [None]:
# set training parameters
trainer = unet.Trainer(dataset=data, num_epochs=100, loss_function='BCEDice')

# start training
trainer.start()

After training is completed, the model parameters are stored as `model.pth` in `folder_temp`.
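
The training cell above sets `loss_function='BCEDice'`. Assuming that this name denotes a combination of binary cross-entropy and Dice loss, the following NumPy sketch shows how such a loss behaves. The unweighted sum and the smoothing term `eps` are assumptions for illustration; the definition used in `bio-image-unet` may differ.

```python
import numpy as np


def bce_dice_loss(pred, target, eps=1e-7):
    """Sum of binary cross-entropy and Dice loss for predicted probabilities in [0, 1]."""
    pred = np.clip(np.asarray(pred, dtype=np.float64), eps, 1 - eps)
    target = np.asarray(target, dtype=np.float64)
    # pixel-wise binary cross-entropy, averaged over the image
    bce = -np.mean(target * np.log(pred) + (1 - target) * np.log(1 - pred))
    # Dice coefficient measures the overlap between prediction and target
    intersection = np.sum(pred * target)
    dice = (2 * intersection + eps) / (np.sum(pred) + np.sum(target) + eps)
    return bce + (1 - dice)


loss_target = np.zeros((32, 32))
loss_target[10:14, :] = 1                             # one horizontal band
good_pred = np.where(loss_target == 1, 0.9, 0.1)      # confident, correct prediction
flat_pred = np.full_like(loss_target, 0.5)            # uninformative prediction
print(bce_dice_loss(good_pred, loss_target) < bce_dice_loss(flat_pred, loss_target))  # True
```

The Dice term depends on the overlap with the annotated pixels rather than on the pixel-wise average, which is one reason such combinations are often used for thin, sparse structures.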