# 2- Background segmentation

![title](../scratch/pipeline_diagram-2.png)

The purpose of this notebook is to train a neural network (UNET) for segmenting animals from the background.
The steps include:

* 2.1 Build an initial training set
* 2.2 Train the neural network (UNET)
* 2.3 Review the results and, if needed, augment the training set

In [1]:
import os  # used below to create directories and list files
import sys
sys.path.append('../')
from utils.io import *
import numpy as np
import cv2
import time

## 2.1 Build an initial training set

Scan through all recordings and save a subset of frames for manual annotation. The recordings are assumed to have been generated with, or converted to the correct format by, the SoSeq-acquisition repository.

### Input parameters

In [2]:
data_directory = '../../SoSeq-acquire/data_cropped/' # path to directory containing video recordings
                               # it is assumed that the recordings are all named by the convention:
                               # <data_directory>/<name>_color.mp4
                               # <data_directory>/<name>_depth.avi
            
frame_size = (540,480) # the dimensions of the video files
            
save_directory_train = data_directory + '/segmentation_train' # training images and annotations will be saved here
save_directory_test = data_directory + '/segmentation_test' # test images and annotations will be saved here
            
number_test_frames = 100 # number of test images to annotate - this test set will generally be fixed over time
number_train_frames = 200 # number of train images to annotate - this train set may be augmented over time

# it is faster to load a chunk of consecutive frames than to load each frame separately
buffer_size = 100 # the number of consecutive frames to load at once
buffer_skip = 20 # the gap between frames to be saved in the train and test sets
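
As a rough check on these settings, the sketch below works out how many frames each buffer contributes and how many buffer reads the sampling loop will need. It only uses the parameter values above; the names `frames_per_buffer` and `reads_needed` are illustrative and do not appear elsewhere in the notebook.

```python
import math

# Values copied from the input parameters above
buffer_size = 100
buffer_skip = 20
number_train_frames = 200
number_test_frames = 100

# One frame is kept every `buffer_skip` frames within each buffer
frames_per_buffer = len(range(0, buffer_size, buffer_skip))

# Number of random buffer reads needed to collect all train and test frames
total_frames = number_train_frames + number_test_frames
reads_needed = math.ceil(total_frames / frames_per_buffer)

print(frames_per_buffer, reads_needed)  # 5 60
```

A larger `buffer_skip` gives less correlated frames per read but requires more reads for the same number of images.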

### Load video lengths and set up save directories

In [3]:
# get the length of each video
video_lengths = get_video_lengths(data_directory)

# create the training/testing directories
for directory in [save_directory_train,save_directory_test]:
    if not os.path.exists(directory): os.makedirs(directory)

### Randomly sample frames

In [4]:
videos = list(video_lengths.keys())
num_saved = 0
while num_saved < number_train_frames+number_test_frames:
    video = np.random.choice(videos)
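    # choose a start frame so that the whole buffer stays inside the video
    # (assumes every video has at least buffer_size frames)
    frame = np.random.randint(0, video_lengths[video]-buffer_size+1)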
    t = time.time()
    frame_buffer = read_color_frames(data_directory+'/'+video, range(frame,frame+buffer_size), frame_size=frame_size).squeeze()
    
    for i in range(0,buffer_size, buffer_skip):
        save_dir = (save_directory_train if num_saved < number_train_frames else save_directory_test)
        cv2.imwrite(save_dir+'/'+video+'_'+str(frame+i)+'.png',frame_buffer[i,:,:,::-1])
        num_saved += 1
    
    if num_saved < number_train_frames: print('\rSaved',num_saved,'out of',number_train_frames,'train frames',end='')
    elif num_saved == number_train_frames: print('')
    else: print('\rSaved',num_saved-number_train_frames,'out of',number_test_frames,'test frames',end='')

Saved 195 out of 200 train frames
Saved 100 out of 100 test frames

### Annotate frames

Frames can be annotated using labelme, as described below. If you are using a computer with a graphical interface, labelme can be started by running ```labelme``` in the soseq environment. Otherwise, copy the train and test images to your local computer and run labelme from there. Instructions for installing labelme can be found on the [labelme GitHub page](https://github.com/wkentaro/labelme). The train and test images that have just been saved are in the following directories:

```
{{save_directory_train}}
{{save_directory_test}}
```

Once you have opened labelme, load the train/test images directory.

![title](../scratch/labelme1-01.png)

Click "Create Polygons" and then click around the outline of one of the animals, finishing on the initial vertex.

![title](../scratch/labelme2-01.png)

Upon completing the polygon, enter a label such as the species name. Use the same label for every image and every animal instance. You will be asked to enter this label during network training.

![title](../scratch/labelme3-01.png)

Outline each of the animals in frame and then click "Save" or use the save keyboard shortcut (command-shift-S on a mac). Labelme will automatically suggest a directory for saving. Do not change this directory. 

![title](../scratch/labelme4-01.png)

When two animals overlap, the outline of the occluded animal should only include its visible parts. Do not try to predict the outline of the hidden portions, and avoid overlap between the polygons drawn for each animal. If one animal is fully bisected by the occlusion from another, draw a separate polygon for each of the visible portions.

![title](../scratch/labelme5-01.png)

When you have annotated all train and test frames, the image directories should contain a .json file for each original .png file. If annotations were performed on a separate computer, copy them back to the original image directories:

```
{{save_directory_train}}
{{save_directory_test}}
```
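
Before moving on, it can help to confirm that no image was skipped. The sketch below lists the .png files that have no matching .json file. It assumes labelme saved each annotation next to its image under the same file stem, which is what keeping the suggested save directory should give; `find_unannotated` is a hypothetical helper and not part of the SoSeq utilities.

```python
import os

def find_unannotated(image_directory):
    """Return the .png files in image_directory that have no matching .json file."""
    files = set(os.listdir(image_directory))
    missing = []
    for name in sorted(files):
        stem, ext = os.path.splitext(name)
        # assumption: the annotation for <stem>.png is stored as <stem>.json
        if ext == '.png' and stem + '.json' not in files:
            missing.append(name)
    return missing
```

Calling `find_unannotated(save_directory_train)` and `find_unannotated(save_directory_test)` should both return empty lists once annotation is complete.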

## 2.2 Train the neural network (UNET)

Aggregate all annotations into a pair of h5 files, then build and train the neural network.

### Input parameters

In [None]:
animal_name = 'mouse' # this string should match the name you used for the polygons in labelme
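
Because only polygons whose label equals `animal_name` are used, a misspelled label silently drops that animal. The sketch below counts how often each label occurs across a directory of labelme files so that stray spellings stand out. It relies only on the `shapes` and `label` fields that the loading code in this notebook also reads; `count_labels` is a hypothetical helper.

```python
import json
import os
from collections import Counter

def count_labels(annotations_directory):
    """Count how often each polygon label occurs across all labelme .json files."""
    counts = Counter()
    for name in os.listdir(annotations_directory):
        if name.endswith('.json'):
            with open(os.path.join(annotations_directory, name)) as f:
                shapes = json.load(f)['shapes']
            counts.update(shape['label'] for shape in shapes)
    return counts
```

For example, `count_labels(save_directory_train)` would ideally return a single key equal to `animal_name`.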


In [None]:
from skimage.draw import polygon
import json
import numpy as np
import h5py

def load_annotation(annotation_path, frame_size):
    segmentations = []
    labels = []
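    with open(annotation_path) as f:
        shapes = json.load(f)['shapes']
    for shape in shapes:
        # draw on a square canvas large enough for either axis, then crop to the frame
        seg = np.zeros((np.max(frame_size),np.max(frame_size)))
        poly = np.array(shape['points'])  # labelme vertices as (x, y) pairs
        rr,cc = polygon(poly[:,0], poly[:,1], (np.max(frame_size),np.max(frame_size)))
        seg[cc,rr] = 255  # rr holds x and cc holds y, so index as [y, x]
        segmentations.append(seg[:frame_size[1],:frame_size[0],None])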
        labels.append(shape['label'])
    segmentations = np.concatenate(segmentations, axis=2)
    return segmentations.astype(np.uint8), labels


def aggregate_annotations_background_segmentation(annotations_directory, frame_size, animal_name):
    dataset = h5py.File(annotations_directory+'/dataset.h5','w')
    for f in os.listdir(annotations_directory):
        if f.endswith('.json'):
            # load_annotation needs the full path, not just the file name
            segmentation,labels = load_annotation(os.path.join(annotations_directory, f), frame_size)
            use_regions = np.array([label==animal_name for label in labels])
            if use_regions.sum() > 0:
                mask = np.any(segmentation[:,:,use_regions],axis=2)
                
                video_prefix = data_directory+'/'+f.split('color.mp4')[0]
                # annotation files are named <video>_<frame>.json
                frame = int(f.split('_')[-1].split('.json')[0])
                color = read_color_frames(video_prefix+'color.mp4', [frame], frame_size=frame_size).squeeze()
                depth = read_depth_frames16(video_prefix+'depth.avi', [frame], frame_size=frame_size).squeeze()
                
                

#    19_11_20-C57_GRIN5-MR-000371392012_color.mp4_3861.png
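
The saved images are named `<video>_<frame>.png`, and the video name itself contains underscores, so the frame index has to be taken from the last underscore only. A minimal sketch of that parsing, using the example file name above, is shown here; `parse_frame_filename` is a hypothetical helper.

```python
import os

def parse_frame_filename(filename):
    """Split '<video>_<frame>.<ext>' into (video, frame)."""
    stem, _ = os.path.splitext(filename)   # drop the final extension ('.png' or '.json')
    video, frame = stem.rsplit('_', 1)     # split on the last underscore only
    return video, int(frame)

video, frame = parse_frame_filename('19_11_20-C57_GRIN5-MR-000371392012_color.mp4_3861.png')
print(video, frame)
```

The same function works for the matching `.json` annotation file, since only the final extension is removed.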

In [None]:
import cv2
import numpy as np, os
import UNET.u_net as unet
from keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, TensorBoard
import matplotlib.pyplot as plt, h5py
from keras.preprocessing.image import img_to_array

### Load the model

In [None]:
model = unet.get_unet_512(num_classes=1, input_shape=(512, 512, 4))
model.summary()

In [None]:
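# Target size for cv2.resize; assumed here to match the model input_shape of (512, 512)
SIZE = (512, 512)

# Processing function for the training data
def train_process(data):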
    img, mask = data
    img = cv2.resize(img, SIZE).astype(float)
    mask = cv2.resize(mask, SIZE).astype(float)
    img = img/255
    
    ff = img[:,:,3].flatten()>0
    for i in range(4):
        img[:,:,i] = img[:,:,i] / img[:,:,i].flatten()[ff].mean()

    mask = np.expand_dims(mask, axis=2)/255
    return (img, mask)
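
The normalization step divides each channel by its mean over the pixels where the fourth channel is non-zero. The sketch below reproduces that step on a small random array in vectorized form. It assumes the fourth channel is depth and that zero marks a pixel without a depth reading, which the notebook does not state explicitly; `normalize_channels` is a hypothetical name.

```python
import numpy as np

def normalize_channels(img):
    """Divide each channel by its mean over pixels where the last channel is non-zero."""
    img = img.astype(float)
    valid = img[:, :, -1] > 0          # assumed: pixels with a depth reading
    means = img[valid].mean(axis=0)    # per-channel mean over the valid pixels
    return img / means                 # broadcasts over the channel axis

rng = np.random.default_rng(0)
img = rng.uniform(0.1, 1.0, size=(8, 8, 4))
img[0, 0, 3] = 0                       # simulate one missing depth reading
out = normalize_channels(img)
valid = img[:, :, 3] > 0
print(out[valid].mean(axis=0))         # each entry is approximately 1
```

After this step every channel has a mean of about 1 over the valid pixels, so the colour and depth channels reach the network on a comparable scale.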

In [None]:
# Processing function for the validation data, no data augmentation
def validation_process(data):
    img, mask = data
    img = cv2.resize(img, SIZE).astype(float)
    mask = cv2.resize(mask, SIZE).astype(float)
    img = img/255.
    
    ff = img[:,:,3].flatten()>0
    for i in range(4):
        img[:,:,i] = img[:,:,i] / img[:,:,i].flatten()[ff].mean()
    
    mask = np.expand_dims(mask, axis=2)/255
    return (img, mask)

In [None]:
BATCH_SIZE = 1
training_dir = 'training_images_UNET/'

In [None]:
def train_generator_maker(training_paths):
    while True:
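        # the first 500 paths are reserved for validation; sample training examples from the rest
        ix = np.random.randint(500,len(training_paths),BATCH_SIZE)
        out_img = []
        out_mask = []
        for i in ix:
            # each file stores a pickled dict, which recent NumPy versions only load with allow_pickle=True
            data = np.load(training_paths[i], allow_pickle=True)['arr_0'].item()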
            img = data['image'].astype(float)
            mask = data['masks'].squeeze().astype(np.uint8)*255
            img,mask = train_process((img,mask))
            out_img.append(img); out_mask.append(mask)
        yield np.array(out_img), np.array(out_mask)

In [None]:
def validation_generator_maker(training_paths):
    while True:
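        # validation examples are drawn from the first 500 paths only
        ix = np.random.randint(0,500,BATCH_SIZE)
        out_img = []
        out_mask = []
        for i in ix:
            # each file stores a pickled dict, which recent NumPy versions only load with allow_pickle=True
            data = np.load(training_paths[i], allow_pickle=True)['arr_0'].item()
            img = data['image'].astype(float)
            mask = data['masks'].squeeze().astype(np.uint8)*255
            img,mask = validation_process((img,mask))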
            out_img.append(img); out_mask.append(mask)
        yield np.array(out_img), np.array(out_mask)
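
The two generators share one list of paths and split it by index: the first 500 entries are used for validation and the rest for training. If the list has 500 entries or fewer, the training generator has nothing to sample from. A minimal sketch of the same split with an explicit size check follows; `split_paths` is a hypothetical helper, and the default of 500 is taken from the generators above.

```python
def split_paths(paths, n_validation=500):
    """Split a list by position into (validation, training) parts."""
    if len(paths) <= n_validation:
        raise ValueError('need more than %d examples, got %d' % (n_validation, len(paths)))
    return paths[:n_validation], paths[n_validation:]

validation_paths, train_paths = split_paths(['example_%d.npz' % i for i in range(600)])
print(len(validation_paths), len(train_paths))  # 500 100
```

Splitting by position keeps the validation set fixed across runs as long as the order of the path list does not change.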

In [None]:
callbacks = [EarlyStopping(monitor='val_loss',
                           patience=10,
                           verbose=1,
                           min_delta=1e-4),
             ReduceLROnPlateau(monitor='val_loss',
                               factor=0.1,
                               patience=10,
                               verbose=1,
                               epsilon=1e-4),  # note: newer Keras releases name this argument min_delta
             ModelCheckpoint(monitor='val_loss',
                             filepath='UNET/weights/best_weights.hdf5',
                             save_best_only=True,  # keep only the weights with the lowest val_loss, as the file name implies
                             save_weights_only=True)]

In [None]:
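# Assumption: training_dir contains one .npz file per training example
training_paths = sorted(os.path.join(training_dir, f) for f in os.listdir(training_dir) if f.endswith('.npz'))
train_generator = train_generator_maker(training_paths)
validation_generator = validation_generator_maker(training_paths)

epochs=150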
model.fit_generator(generator=train_generator,
                    steps_per_epoch=2000,
                    epochs=epochs,
                    callbacks=callbacks,
                    validation_data=validation_generator,
                    validation_steps=100)