### Segmentation of the blood vessels and calculation of vessel density and branching point density

*Step (b) of the block diagram in the Methods section.*

In [None]:
# Initial imports and device setting

from pathlib import Path
from functools import partial
import os
import random

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import imgaug.augmenters as iaa

from torchtrainer.imagedataset import ImageSegmentationDataset
from torchtrainer import img_util
from torchtrainer import transforms
from torchtrainer.models.resunet import ResUNet
from torchtrainer.learner import Learner
from torchtrainer import perf_funcs

import pyvane
from pyvane import pipeline, image

# Run on the current CUDA device when one is available, otherwise on the CPU.
# NOTE(review): `device` is not passed to the model or the Learner later in this
# notebook — confirm torchtrainer picks the device itself, or wire it in.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### CNN training

In [3]:
def name_to_label_map(img_filename):
    '''Return the label file name that corresponds to an image file name.

    Every occurrence of '.tiff' in `img_filename` is replaced by '.png',
    e.g. 'img_01.tiff' -> 'img_01.png'. Names without '.tiff' are returned unchanged.
    '''
    image_ext, label_ext = '.tiff', '.png'
    return img_filename.replace(image_ext, label_ext)

def get_train_val_files(img_dir, val_split, seed=None):
    '''Randomly split the entries of a directory into training and validation lists.

    Parameters
    ----------
    img_dir : str or path-like
        Directory whose entries (as listed by ``os.listdir``) are split.
    val_split : float
        Fraction of the entries that go to the validation list. The count is
        ``int(len(entries) * val_split)``, i.e. truncated.
    seed : int, optional
        Seed for a private ``random.Random`` instance. This makes the split
        reproducible without reseeding the global ``random`` module. When None,
        the global ``random`` module is used.

    Returns
    -------
    train_files, val_files : list of str
        Disjoint lists that together contain every entry of `img_dir`.
        `train_files` is in sorted order.

    Raises
    ------
    ValueError
        From ``random.sample`` when the computed validation count is negative
        or larger than the number of entries.
    '''
    # os.listdir order is arbitrary and filesystem-dependent. Sorting makes a
    # given seed select the same files on every machine.
    files = sorted(os.listdir(img_dir))

    rng = random if seed is None else random.Random(seed)
    num_val_files = int(len(files)*val_split)
    val_files = rng.sample(files, num_val_files)

    # A set keeps the membership test linear. A list (not a lazy, single-use
    # filter object) can be iterated repeatedly, indexed and measured with len().
    val_set = set(val_files)
    train_files = [name for name in files if name not in val_set]

    return train_files, val_files

root_dir = Path('vessels')
img_dir = root_dir/'images'     # Original images
label_dir = root_dir/'labels'   # Labels

img_shape = (1104, 1376)        # Full image shape; not referenced by the code below
patch_shape = (256, 256)        # (height, width) of the random training crops
valid_shape = (768, 768)        # (height, width) of the central validation crop
bs = 10                         # Batch size
lr = 0.01                       # Learning rate
epochs = 10

# Intensity mean and standard deviation calculated over the whole dataset.
# Training and validation are whitened with the same values so their inputs match.
img_mean, img_std = 67.576, 37.556
# CLAHE settings shared by the training and validation pipelines.
clahe_params = dict(clip_limit=6, tile_grid_size_px=12)

# Training augmentations: random patch crop, blur, one randomly chosen intensity
# change, additive noise, flips, a small extra crop and CLAHE.
train_imgaug_seq = transforms.translate_imagaug_seq(iaa.Sequential([
    iaa.CropToFixedSize(width=patch_shape[1], height=patch_shape[0]),
    # NOTE(review): in imgaug a list means "pick one of these values" (sigma is
    # exactly 1 or 2); a tuple (1, 2) would sample uniformly from the range.
    # Confirm which is intended.
    iaa.GaussianBlur(sigma=[1, 2]),
    iaa.OneOf([
        iaa.GammaContrast(gamma=(0.5, 1.5)),
        iaa.pillike.EnhanceBrightness(factor=(0.1, 3.)),
        iaa.pillike.EnhanceContrast(factor=(0.2, 3.)),
        iaa.Add(value=(-20, 30))
    ]),
    iaa.AdditiveGaussianNoise(loc=0, scale=(0, 20)),
    iaa.Fliplr(0.5),
    iaa.Flipud(0.5),
    iaa.Crop(px=(0, 20)),
    iaa.CLAHE(**clahe_params)
]))
train_transforms = [transforms.TransfToImgaug(), train_imgaug_seq, transforms.TransfToTensor(),
                    transforms.TransfWhitten(img_mean, img_std)]

# Validation: central crop of fixed size followed by the same CLAHE as in training.
valid_imgaug_seq = transforms.translate_imagaug_seq(iaa.Sequential([
    iaa.CenterCropToFixedSize(width=valid_shape[1], height=valid_shape[0]),
    iaa.CLAHE(**clahe_params)
]))
valid_transforms = [transforms.TransfToImgaug(), valid_imgaug_seq, transforms.TransfToTensor(),
                    transforms.TransfWhitten(img_mean, img_std)]

img_opener_partial = partial(img_util.pil_img_opener, channel=None)
label_opener_partial = partial(img_util.pil_img_opener, is_label=True)

# Create ImageDataset instance
# NOTE(review): cache_size is presumably in bytes (10 GB) — confirm in torchtrainer.
dataset = ImageSegmentationDataset(img_dir, label_dir, name_to_label_map=name_to_label_map,
                                   img_opener=img_opener_partial, label_opener=label_opener_partial,
                                   cache_size=10*10**9)

# Hold out part of the images for validation (0.2, presumably a fraction); the
# fixed seed keeps the split the same across runs.
train_ds, valid_ds = dataset.split_train_val(0.2, seed=10)
train_ds.set_transforms(train_transforms)
valid_ds.set_transforms(valid_transforms)
train_dl = train_ds.dataloader(batch_size=bs, shuffle=True)
valid_dl = valid_ds.dataloader(batch_size=1, shuffle=False)

In [None]:
%%time
loss_func = perf_funcs.DiceLossRaw()

model = ResUNet(num_channels=1, num_classes=2) 
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, epochs=epochs, 
                                                steps_per_epoch=len(train_dl), pct_start=0.1)

learner = Learner(model, loss_func, optimizer, train_dl, valid_dl, scheduler, 
                       perf_funcs=perf_funcs.build_segmentation_accuracy_dict(), checkpoint_file='learner_vessel.tar',
                       main_perf_func='iou', scheduler_step_epoch=False)

learner.fit(epochs)

### Medial lines and graph creation

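The code below extracts the medial lines (skeletons) of the CNN-segmented blood vessels, builds a graph from them, and runs the measurements on that graph.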

In [None]:
# Some definitions

def load_roi(file):
    '''Return the tissue mask (ROI) for a P0 sample, or None for any other sample.

    P0 tissue does not fill the whole sample, so only regions with tissue must be
    considered. For paths containing 'P0', the mask path is derived from the image
    path:
    - the 'CD31(vessels)' directory is replaced by 'masks';
    - the 'CD31.tif' suffix is replaced by 'GFP.png'.

    Parameters
    ----------
    file : str or path-like
        Path of the vessel (CD31) image being analyzed.

    Returns
    -------
    numpy.ndarray of uint8 or None
        1 where the mask image is non-zero and 0 elsewhere, for P0 samples;
        None otherwise.
    '''
    path = str(file)
    # NOTE(review): this matches any path containing 'P0' (e.g. also 'P01');
    # confirm against the sample naming scheme.
    if 'P0' not in path:
        return None

    file_roi = path.replace('CD31(vessels)', 'masks').replace('CD31.tif', 'GFP.png')
    # plt.imread returns PNG pixels as floats in [0, 1]. Casting those straight to
    # uint8 zeroes every value below 1.0 (e.g. a mask stored as 0/1 instead of
    # 0/255 would become empty). Thresholding keeps every non-zero pixel.
    img_roi = plt.imread(file_roi)
    return (img_roi > 0).astype(np.uint8)

root = 'Astrocytes/'

# Configuration consumed by the pipeline cell below.
params = {
    # Pipeline params
    # NOTE(review): presumably the intermediate steps written to output_path.
    'save_steps': ('skeletonization', 'network'),
    'roi_process': None,            # No ROI applied when reading the images
    'roi_analysis': load_roi,       # Tissue-mask loader used at analysis time
    # File params
    'batch_name': '2D analysis astrocytes',
    # `root` already ends with '/'. Prefixing the subfolders with another '/'
    # produced paths like 'Astrocytes//CD31(vessels)/'.
    'input_path': root + 'CD31(vessels)/',
    'output_path': root + 'pipeline_vessels/',
    'name_filter': None,
    'channel_to_use': None,
    # Skeletonization params
    'num_threads': 7,
    # Graph params
    'length_threshold': 9,    # Graph pruning length in pixels.
    # Measurement params
    'tortuosity_scale': 40,   # NOTE(review): units presumably pixels — confirm in pyvane.
}

In [None]:
class CNNSegmentation(pyvane.pipeline.BaseProcessor):
    '''Fake segmenter that reads masks already produced by the CNN model trained above.

    The network is not run here. For an input image whose path contains '.tif',
    the pre-computed segmentation is read from the same path with '.tif' replaced
    by '.png'.
    '''

    def __init__(self, model, checkpoint_file):
        # Accepted only to mirror a real CNN segmenter's constructor; apply()
        # uses neither argument.
        pass

    def apply(self, img, file):
        '''Return the pre-computed CNN segmentation for `file`.

        Parameters
        ----------
        img : image.Image
            Input image (presumably a pyvane image); only its ``path`` and
            ``pix_size`` are carried over to the result.
        file : str or path-like
            Path of the input image. Every '.tif' in it is replaced by '.png'
            to locate the segmentation file.

        Returns
        -------
        image.Image
            uint8 image holding ``pixel_value // 255`` of the PNG (0/1 for a
            mask stored as 0/255).
        '''
        # str() so a pathlib.Path is accepted too, as load_roi already does.
        file_binary = str(file).replace('.tif', '.png')
        # The context manager closes the PNG file as soon as the pixels are copied.
        with Image.open(file_binary) as pil_img:
            img_final = np.array(pil_img)
        # NOTE(review): integer division assumes the mask is stored as 0/255; a
        # mask saved as 0/1 would become all zeros — confirm the export format.
        img_final = img_final//255

        return image.Image(img_final.astype(np.uint8), img.path, pix_size=img.pix_size)
        
# Image reader configured from params (channel selection and optional ROI).
img_reader = partial(
    pipeline.read_and_adjust_img,
    channel=params['channel_to_use'],
    roi=params['roi_process'],
)

# NOTE(review): this rebinds `model` to a fresh, untrained ResUNet, replacing the
# one trained above. CNNSegmentation ignores both arguments, so they only fill
# its constructor.
checkpoint_file = 'learner_vessel.tar'
model = ResUNet(num_channels=1, num_classes=2)
segmenter = CNNSegmentation(model, checkpoint_file)

skeleton_builder = pipeline.DefaultSkeletonBuilder(num_threads=params['num_threads'])
network_builder = pipeline.DefaultNetworkBuilder(length_threshold=params['length_threshold'])
analyzer = pipeline.DefaultAnalyzer(tortuosity_scale=params['tortuosity_scale'])
# Attach the ROI loader configured in params to the analyzer.
analyzer.load_roi = params['roi_analysis']

pipe = pipeline.BasePipeline(
    params['input_path'],
    img_reader,
    output_path=params['output_path'],
    batch_name=params['batch_name'],
    name_filter=params['name_filter'],
)
# Processing order: segmentation -> skeleton -> graph -> measurements.
pipe.set_processors(segmenter, skeleton_builder, network_builder, analyzer)

res = pipe.run()