# Segmentation of white matter lesions from MRI with Faster RCNN

White matter lesions (also called white matter hyperintensities, silent cerebral infarcts, silent strokes, etc.) typically appear as bright spots on T2 FLAIR MRI images. Automatic identification and segmentation of these lesions, to derive their positions, volumes and stereological features, is an active area of research. The task is well suited to a region-based CNN (R-CNN) approach; if a precise, pixel-level segmentation of each lesion is needed, Mask RCNN is the appropriate choice.

My previous project proposed a seed-based semi-automatic segmentation approach built on Canny edge detection (cited in this [ref](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6192027/)). To avoid unnecessary training, I have therefore chosen **Faster RCNN** instead of Mask RCNN for this project. Faster RCNN yields only a bounding box for each lesion, not a segmentation mask. These bounding boxes can then be fed into the semi-automatic segmentation engine to produce the lesion masks.

Instead of implementing the network from scratch, I use the **IceVision** library (https://airctic.com/0.11.0/), which works with the fastai training framework, to train the model and generate predictions for the white matter lesions.

```python
import icevision
from icevision.all import *
from icevision.models.torchvision import faster_rcnn
from icevision.models.checkpoint import *
import pandas as pd
import numpy as np
import fastai
import pickle
import os
```

```python
already_trained = True  # load the saved checkpoint and predictions instead of training again
```

```python
data_root = '/Users/chauvu/Documents/Chau/DataScience/Proj_Lesions_RCNN/lesions/'  # local data folder
data_csv = os.path.join(data_root, 'bbox.csv')
data_info = pd.read_csv(data_csv, header=0)

# one class per distinct value of the 'type' column
class_map_list = data_info['type'].unique().tolist()
class_map = ClassMap(class_map_list)

data_dir = Path(data_root)
template_record = ObjectDetectionRecord()
# print a skeleton Parser subclass for this record type
Parser.generate_template(template_record)
```

Template printed by `Parser.generate_template`:

```
class MyParser(Parser):
    def __init__(self, template_record):
        super().__init__(template_record=template_record)
    def __iter__(self) -> Any:
    def __len__(self) -> int:
    def record_id(self, o: Any) -> Hashable:
    def parse_fields(self, o: Any, record: BaseRecord, is_new: bool):
        record.set_filepath(<Union[str, Path]>)
        record.set_img_size(<ImgSize>)
        record.detection.set_class_map(<ClassMap>)
        record.detection.add_labels(<Sequence[Hashable]>)
        record.detection.add_bboxes(<Sequence[BBox]>)
```


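A custom **Parser** subclass, filling in the template printed above, loads the training and testing data into the record format that `icevision` expects. Each row of `bbox.csv` describes one lesion box; rows that share an `image_id` are merged into a single record.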

```python
class LesionParser(Parser):
    """Parses bbox.csv (one row per lesion box) into IceVision object-detection records."""

    def __init__(self, template_record, data_dir):
        super().__init__(template_record=template_record)
        self.data_dir = data_dir
        self.df = pd.read_csv(data_dir / 'bbox.csv', header=0)
        self.class_map = ClassMap(self.df['type'].unique().tolist())

    def __iter__(self) -> Any:
        # each CSV row (one lesion box) is passed to parse_fields
        yield from self.df.itertuples()

    def __len__(self) -> int:
        return len(self.df)

    def record_id(self, o) -> Hashable:
        # rows with the same image_id are merged into one record
        return o.image_id

    def parse_fields(self, o, record, is_new):
        if is_new:
            # image-level fields are set only once, on the first row of each image
            filepath = self.data_dir / 'images' / f'{o.image_id}.jpg'
            record.set_filepath(filepath)
            # hard-coded 30 x 30 pixel size; must match the actual image files
            record.set_img_size(ImgSize(width=30, height=30))
            record.detection.set_class_map(self.class_map)
        # every row contributes one box and its label
        record.detection.add_bboxes([BBox.from_xyxy(o.xmin, o.ymin, o.xmax, o.ymax)])
        record.detection.add_labels([o.type])
```
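The parser relies on `bbox.csv` holding one row per lesion, with at least the columns `image_id`, `xmin`, `ymin`, `xmax`, `ymax` and `type`. The sketch below mimics that grouping with the standard library only; the CSV content, the `slice_…` ids and the `lesion` label are made up for illustration, and `group_boxes` is a hypothetical helper, not part of IceVision.

```python
import csv
import io
from collections import defaultdict

# Hypothetical bbox.csv content: one row per lesion box (values are made up).
CSV_TEXT = """image_id,xmin,ymin,xmax,ymax,type
slice_001,3,4,9,10,lesion
slice_001,15,12,20,18,lesion
slice_002,7,7,12,11,lesion
"""


def group_boxes(csv_text):
    """Group lesion boxes by image_id, mirroring how parse_fields uses is_new."""
    records = defaultdict(lambda: {"bboxes": [], "labels": []})
    for row in csv.DictReader(io.StringIO(csv_text)):
        record = records[row["image_id"]]  # first access creates the record (is_new)
        box = tuple(float(row[k]) for k in ("xmin", "ymin", "xmax", "ymax"))
        record["bboxes"].append(box)
        record["labels"].append(row["type"])
    return dict(records)


records = group_boxes(CSV_TEXT)
for image_id, rec in records.items():
    print(image_id, len(rec["bboxes"]), "box(es)")
```

An image with several lesions therefore becomes one training example with several boxes, rather than several examples with one box each.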

```python
parser = LesionParser(template_record, data_dir)
# random 80/10/10 split of the image records into train/validation/test, reproducible via the seed
dsplit = icevision.data.RandomSplitter([0.8, 0.1, 0.1], seed=123)
train_records, valid_records, test_records = parser.parse(data_splitter=dsplit, show_pbar=False, autofix=False)
```

```python
# show_records(train_records[0:2], ncols=2, display_label=False)
```

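**Data augmentation** (shifting, scaling, rotating and cropping, on top of the other defaults of `tfms.A.aug_tfms`) is applied to the training set only. The validation and test sets are not augmented; they are only resized and padded to the training size. All three datasets are normalized.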

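```python
presize = 512  # images are rescaled to this size before the random augmentations
size = 384     # final input size fed to the network
# random shift, scale and rotation (rotation limited to +/-10 degrees)
shift_scale_rotate = tfms.A.ShiftScaleRotate(rotate_limit=10)
# with p=0.5, crop a random region between size//2 and size pixels high and resize it back
crop_fn = partial(tfms.A.RandomSizedCrop, min_max_height=(size // 2, size), p=0.5)
train_tfms = tfms.A.Adapter(
    [
        *tfms.A.aug_tfms(
            size=size,
            presize=presize,
            shift_scale_rotate=shift_scale_rotate,
            crop_fn=crop_fn,
        ),
        tfms.A.Normalize(),
    ]
)
# validation and test images are only resized, padded and normalized
valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size=size), tfms.A.Normalize()])
test_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size=size), tfms.A.Normalize()])
train_ds = Dataset(train_records, train_tfms)
valid_ds = Dataset(valid_records, valid_tfms)
test_ds = Dataset(test_records, test_tfms)
```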
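Rotating an image also has to move its lesion boxes, and an axis-aligned box cannot simply rotate with the image. One common approach, shown here only as an illustration (Albumentations' exact bbox-rotation method depends on its version), is to rotate the four corners and take their enclosing axis-aligned box. The sketch below uses hypothetical helpers `rotate_bbox` and `area` to show how much even the small rotations allowed above enlarge a box.

```python
import math


def rotate_bbox(box, angle_deg, center):
    """Rotate an (xmin, ymin, xmax, ymax) box about `center` and return the
    axis-aligned box enclosing the four rotated corners."""
    xmin, ymin, xmax, ymax = box
    cx, cy = center
    a = math.radians(angle_deg)
    cos_a, sin_a = math.cos(a), math.sin(a)
    xs, ys = [], []
    for x, y in [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]:
        dx, dy = x - cx, y - cy
        xs.append(cx + dx * cos_a - dy * sin_a)
        ys.append(cy + dx * sin_a + dy * cos_a)
    return min(xs), min(ys), max(xs), max(ys)


def area(box):
    xmin, ymin, xmax, ymax = box
    return (xmax - xmin) * (ymax - ymin)


box = (100.0, 100.0, 140.0, 140.0)  # hypothetical 40 x 40 px lesion box
for angle in (0, 10):
    rotated = rotate_bbox(box, angle, center=(192.0, 192.0))  # centre of a 384 x 384 image
    print(angle, round(area(rotated) / area(box), 3))
```

For a square box the enclosing box grows by a factor of (|cos θ| + |sin θ|)² in area, about 1.34 at 10°, so keeping `rotate_limit` small keeps the box targets reasonably tight.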

As stated above, I use the **Faster RCNN** model from the IceVision library (the torchvision implementation), with a pretrained ResNet50 + FPN (feature pyramid network) backbone for feature extraction. The trained checkpoint, `chkpt_lesions.pth`, is available for download.

```python
model_type = faster_rcnn
backbone = model_type.backbones.resnet50_fpn
model = model_type.model(backbone=backbone(pretrained=True), num_classes=len(parser.class_map))

checkpoint_path = data_dir / 'chkpt_lesions.pth'
if already_trained:
    # replace the freshly built model with the trained weights from the checkpoint
    checkpoint = model_from_checkpoint(
        str(checkpoint_path),
        model_name='torchvision.faster_rcnn',
        backbone_name='resnet50_fpn',
    )
    model = checkpoint['model']

# the test set reuses the validation data loader (no shuffling)
train_dl = model_type.train_dl(train_ds, batch_size=16, num_workers=0, shuffle=True)
valid_dl = model_type.valid_dl(valid_ds, batch_size=16, num_workers=0, shuffle=False)
test_dl = model_type.valid_dl(test_ds, batch_size=16, num_workers=0, shuffle=False)
```

```
load checkpoint from local path: /Users/chauvu/Documents/Chau/DataScience/Proj_Lesions_RCNN/lesions/chkpt_lesions.pth
```


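The model was trained with fastai's `fine_tune` for 10 epochs (`fine_tune` also runs one initial epoch with the pretrained backbone frozen). Because the dataset is large and training ran on a CPU, it took two days to complete. The checkpoint was saved; the training and validation losses and the COCO bounding-box metric (`COCOMetric`) for the 10 epochs are: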

| epoch | train_loss | valid_loss | COCOMetric |
|------:|-----------:|-----------:|-----------:|
| 0 | 0.305333 | 0.292251 | 0.032314 |
| 1 | 0.307596 | 0.301169 | 0.050098 |
| 2 | 0.286604 | 0.293986 | 0.050614 |
| 3 | 0.298923 | 0.296777 | 0.101636 |
| 4 | 0.307370 | 0.286497 | 0.151044 |
| 5 | 0.280956 | 0.276258 | 0.182553 |
| 6 | 0.284154 | 0.275480 | 0.204247 |
| 7 | 0.258887 | 0.265800 | 0.222916 |
| 8 | 0.279240 | 0.264212 | 0.227343 |
| 9 | 0.278910 | 0.262303 | 0.236614 |

```python
metrics = [COCOMetric(metric_type=COCOMetricType.bbox)]
learn = model_type.fastai.learner(dls=[train_dl, valid_dl], model=model, metrics=metrics)

if not already_trained:
    learn.freeze()
    suggested_lr = learn.lr_find()  # learning-rate finder; .valley is its suggested learning rate
    cbs = [fastai.callback.tracker.SaveModelCallback(every_epoch=1)]  # save weights every epoch
    learn.fine_tune(10, suggested_lr.valley, cbs=cbs)

    # save a checkpoint that model_from_checkpoint can reload later
    save_icevision_checkpoint(
        model,
        model_name='torchvision.faster_rcnn',
        backbone_name='resnet50_fpn',
        classes=class_map_list,
        filename=str(checkpoint_path),
        meta={'icevision_version': '0.11.0'},
    )
```

```python
# Metrics copied from the training log above, so the table can be shown without retraining.
# Built in one call: DataFrame.append was deprecated and then removed in pandas 2.0.
training_metrics = pd.DataFrame(
    [
        (0, 0.305333, 0.292251, 0.032314),
        (1, 0.307596, 0.301169, 0.050098),
        (2, 0.286604, 0.293986, 0.050614),
        (3, 0.298923, 0.296777, 0.101636),
        (4, 0.307370, 0.286497, 0.151044),
        (5, 0.280956, 0.276258, 0.182553),
        (6, 0.284154, 0.275480, 0.204247),
        (7, 0.258887, 0.265800, 0.222916),
        (8, 0.279240, 0.264212, 0.227343),
        (9, 0.278910, 0.262303, 0.236614),
    ],
    columns=['epoch', 'train_loss', 'valid_loss', 'COCOMetric'],
)
training_metrics
```

| | epoch | train_loss | valid_loss | COCOMetric |
|--:|------:|-----------:|-----------:|-----------:|
| 0 | 0 | 0.305333 | 0.292251 | 0.032314 |
| 1 | 1 | 0.307596 | 0.301169 | 0.050098 |
| 2 | 2 | 0.286604 | 0.293986 | 0.050614 |
| 3 | 3 | 0.298923 | 0.296777 | 0.101636 |
| 4 | 4 | 0.307370 | 0.286497 | 0.151044 |
| 5 | 5 | 0.280956 | 0.276258 | 0.182553 |
| 6 | 6 | 0.284154 | 0.275480 | 0.204247 |
| 7 | 7 | 0.258887 | 0.265800 | 0.222916 |
| 8 | 8 | 0.279240 | 0.264212 | 0.227343 |
| 9 | 9 | 0.278910 | 0.262303 | 0.236614 |


```python
infer_dl = model_type.infer_dl(test_ds, batch_size=16)
if not already_trained:
    # predict on the test set and cache the predictions (images kept for plotting)
    preds = model_type.predict_from_dl(model, infer_dl, keep_images=True)
    with open(data_dir / 'preds_lesions.pkl', 'wb') as f:
        pickle.dump(preds, f)
else:
    # reload the cached test-set predictions
    with open(data_dir / 'preds_lesions.pkl', 'rb') as f:
        preds = pickle.load(f)

# show_preds(preds=preds[7:9])
# plt.show()
```

```python
# keep only test images with exactly one ground-truth lesion box
preds_gt1 = [p for p in preds if len(p.ground_truth.as_dict()['detection']['bboxes']) == 1]

preds_dsc = []
for p in preds_gt1:
    p_gt = p.ground_truth.as_dict()['detection']['bboxes'][0]
    p_prs = p.pred.as_dict()['detection']['bboxes']
    if not p_prs:
        # no predicted box: skip the image (averaging an empty list would give NaN)
        continue
    p_dsc = []
    for p_pr in p_prs:
        # intersection rectangle of the ground-truth and predicted boxes
        x_left = max(p_gt.xmin, p_pr.xmin)
        y_top = max(p_gt.ymin, p_pr.ymin)
        x_right = min(p_gt.xmax, p_pr.xmax)
        y_bottom = min(p_gt.ymax, p_pr.ymax)

        if x_right < x_left or y_bottom < y_top:
            tp = 0  # the boxes do not overlap
        else:
            tp = (x_right - x_left) * (y_bottom - y_top)

        fp = (p_pr.xmax - p_pr.xmin) * (p_pr.ymax - p_pr.ymin) - tp  # predicted area outside the ground truth
        fn = (p_gt.xmax - p_gt.xmin) * (p_gt.ymax - p_gt.ymin) - tp  # ground-truth area missed
        dsc = 2 * tp / (2 * tp + fp + fn)
        p_dsc.append(dsc)
    # average Dice over all boxes predicted for this image
    preds_dsc.append(np.mean(p_dsc))

print(np.mean(preds_dsc))
```

```
0.6393224606541984
```


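In the cell above, TP is the intersection area, FP the predicted area outside the ground truth and FN the ground-truth area outside the prediction, so 2TP + FP + FN is just the sum of the two box areas and the score is the usual Dice formula 2|A∩B| / (|A| + |B|). The standalone sketch below restates it with plain `(xmin, ymin, xmax, ymax)` tuples instead of IceVision `BBox` objects (an assumption for illustration; `box_areas`, `box_dice` and `box_iou` are hypothetical helpers) and checks the identity Dice = 2·IoU / (1 + IoU), which links it to the IoU-based COCO metric.

```python
def box_areas(a, b):
    """Return (intersection, area_a, area_b) for two (xmin, ymin, xmax, ymax) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter_w * inter_h, area_a, area_b


def box_dice(gt, pred):
    """Dice coefficient 2*TP / (2*TP + FP + FN) of two axis-aligned boxes."""
    tp, area_gt, area_pred = box_areas(gt, pred)
    fp = area_pred - tp  # predicted area outside the ground truth
    fn = area_gt - tp    # ground-truth area missed by the prediction
    return 2 * tp / (2 * tp + fp + fn)


def box_iou(gt, pred):
    """Intersection over union of two axis-aligned boxes."""
    inter, area_gt, area_pred = box_areas(gt, pred)
    return inter / (area_gt + area_pred - inter)


gt, pred = (0.0, 0.0, 4.0, 4.0), (2.0, 0.0, 6.0, 4.0)  # two half-overlapping boxes
d, iou = box_dice(gt, pred), box_iou(gt, pred)
print(d, iou, 2 * iou / (1 + iou))
```

The identity holds per box pair; averaging Dice per image, as the notebook does, is not the same as converting a mean IoU.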


## Conclusion
The mean box-level Dice Similarity Coefficient is 0.64, computed on the test images that contain exactly one ground-truth lesion and at least one predicted box. For this pipeline the exact extent of each box matters less than its location relative to the lesion: once a box is predicted, a central voxel inside it can be used as the seed for the semi-automatic lesion segmentation algorithm. A mean Dice of 0.64 indicates that the predicted boxes generally overlap substantially with the ground-truth boxes, so seeds chosen inside them are likely to fall on or near the lesions, which makes them a reasonable starting point for generating lesion masks.
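Choosing the seed could look like the sketch below. It rests on assumptions not confirmed by the notebook: predicted boxes are in the coordinates of the 384 × 384 network input, the original slices are square (30 × 30, matching the `ImgSize` hard-coded in the parser) so `resize_and_pad` only rescales them without padding, and the segmentation engine accepts a single `(row, col)` pixel seed. `box_center_seed` is a hypothetical helper.

```python
def box_center_seed(box, input_size=384, original_size=30):
    """Map the centre of a predicted (xmin, ymin, xmax, ymax) box from network-input
    coordinates back to the original square image and return an integer (row, col) seed."""
    xmin, ymin, xmax, ymax = box
    scale = input_size / original_size  # resize factor applied by resize_and_pad
    cx = (xmin + xmax) / 2 / scale
    cy = (ymin + ymax) / 2 / scale
    # truncate to a pixel index and clamp to the image bounds
    col = min(max(int(cx), 0), original_size - 1)
    row = min(max(int(cy), 0), original_size - 1)
    return row, col


print(box_center_seed((100.0, 120.0, 200.0, 220.0)))
```

For non-square slices, the padding offset added by `resize_and_pad` would also have to be subtracted before rescaling.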