Similar to exploration2, but this time including an additional dataset (Shenzhen) and not separating the right and left lungs.

# Datasets

## Some publicly available datasets

- [JSRT](https://www.ajronline.org/doi/pdf/10.2214/ajr.174.1.1740071)
	- 247 chest X-rays, 154 have lung nodules. Has lung and heart seg.
	- [Get here](http://db.jsrt.or.jp/eng.php) (register at bottom of page)
- [Montgomery and Shenzhen](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4256233/)
	- Montgomery contains 138 chest X-rays, 80 healthy, 58 tuberculosis. Has lung seg.
	- Shenzhen contains 662 chest X-rays, 326 healthy, 336 tuberculosis. Has lung seg.
	- [Get both here](https://openi.nlm.nih.gov/faq?it=xg#collection). Look for "tuberculosis collection"

That’s 1047 images with lung segmentation labels. There are larger datasets that have no segmentation labels:

- [NIH ChestX-ray8](https://arxiv.org/abs/1705.02315)
	- 108,948 CXRs of 32,717 patients with eight text-mined disease labels
	- [this might be a way to download](https://nihcc.app.box.com/v/ChestXray-NIHCC)
- [NLST](https://www.nejm.org/doi/10.1056/NEJMoa1102873)
	- There's [this link](https://cdas.cancer.gov/publications/320/), which eventually leads [here](https://cdas.cancer.gov/datasets/nlst/), but I don't see any actual CXR images being made available.
 

### JSRT

See [this guide](JSRT_UsersGuide.pdf) to the data for details.

- `JPCLN***.IMG` for chest lung nodule images, and `JPCNN***.IMG` for non-nodule images. These are important classes to keep in mind for the purpose of proportional train/val/test split.
- Coordinates of the upper left of the image are `(0,0)`
- Image type: 16-bit Unsigned
- Width: 2048 pixels
- Height: 2048 pixels
- Offset to First Image: 0 bytes
- Gap Between Images: 0 bytes

You can load the images using [ImageJ](https://imagej.nih.gov/ij/).
Just import as "RAW" and put in the settings specified by the JSRT guide.
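
The same raw import can be sketched in Python with NumPy. The image size, bit depth and zero offset come from the list above; the big-endian byte order is an assumption (the list does not state it), and `load_jsrt_img` is a hypothetical helper name.

```python
import numpy as np

def load_jsrt_img(path, size=2048, byteorder=">"):
    """Read a headerless 16-bit unsigned JSRT .IMG file into a (size, size) array.

    byteorder=">" (big-endian) is an assumption; try "<" if the result looks like noise.
    """
    pixels = np.fromfile(path, dtype=np.dtype(f"{byteorder}u2"))
    if pixels.size != size * size:
        raise ValueError(f"expected {size * size} pixels, found {pixels.size}")
    # Row 0, column 0 is the upper-left corner of the image
    return pixels.reshape(size, size)
```

Calling `load_jsrt_img("JPCLN001.IMG")` on a downloaded file should give a 2048 x 2048 array that can be shown with `plt.imshow(..., cmap="gray")`.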

Hmm, stuck on this for now... also, where are the segmentations? When I load the RAW data into ImageJ I only see an X-ray image, and I see no reference to segmentation in the JSRT guide. The JSRT download page doesn't say anything about segmentation labels either.

### Shenzhen

[The readme](NLM-ChinaCXRSet-ReadMe.pdf).

- 336 cases with manifestation of tuberculosis, and 
- 326 normal cases.

- Format: PNG
- Image size varies for each X-ray. It is approximately 3K x 3K.

- Image file names are coded as `CHNCXR_#####_0/1.png`, where ‘0’ represents the normal and ‘1’ represents the abnormal lung.
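
The naming scheme can be parsed with a regular expression, which fails loudly on unexpected names. This is only a sketch: `parse_cxr_name` is a hypothetical helper, the four-digit ID matches the `[7:11]` slice used in the loading cell of this notebook, and tolerating a suffix such as `_mask` before the extension is an assumption about how mask files might be named.

```python
import re

# Matches names such as CHNCXR_0001_0.png or MCUCXR_0001_1.png.
# An optional suffix like "_mask" before ".png" is tolerated (assumption).
CXR_NAME = re.compile(r"^(CHN|MCU)CXR_(\d{4})_([01])(?:_\w+)?\.png$")

def parse_cxr_name(filename):
    """Return (dataset prefix, 4-digit ID, abnormal flag) or raise ValueError."""
    m = CXR_NAME.match(filename)
    if m is None:
        raise ValueError(f"unexpected file name: {filename}")
    prefix, case_id, flag = m.groups()
    return prefix, case_id, flag == "1"

print(parse_cxr_name("CHNCXR_0001_0.png"))
```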

Segmentation can be obtained separately [here](https://www.kaggle.com/yoctoman/shcxr-lung-mask), and it was done manually by: "students and teachers of Computer Engineering Department, Faculty of Informatics and Computer Engineering, National Technical University of Ukraine "Igor Sikorsky Kyiv Polytechnic Institute", Kyiv, Ukraine." So, not necessarily medical experts.

### Montgomery

[The readme](NLM-MontgomeryCXRSet-ReadMe.pdf).

- 58 cases with manifestation of tuberculosis, and 80 normal cases.
- Image file names are coded as `MCUCXR_#####_0/1.png`, where ‘0’ represents the normal and ‘1’ represents the abnormal lung. These are important classes to keep in mind for the purpose of proportional train/val/test split.

---

- Format: PNG
- Matrix size is 4020 x 4892, or 4892 x 4020.
- The pixel spacing in vertical and horizontal directions is 0.0875 mm.
- Number of gray levels is 12 bits.

---

Segmentation:
> We manually generated the “gold standard” segmentations for the chest X-ray under the supervision of a radiologist. We used the following conventions for outlining the lung boundaries: Both posterior and anterior ribs are readily visible in the CXRs; the part of the lung behind the heart is excluded. We follow anatomical landmarks such as the boundary of the heart, aortic arc/line, and pericardium line; and sharp costophrenic angle that follow the diaphragm boundary. We draw an inferred boundary when the pathology is severe and affects the morphological appearance of the lungs. The lung boundaries (left and right) are in binary image format and have the same file name as chest X-rays (e.g. `…/left/MCUCXR_#####_0/1.png` or `…/right/MCUCXR_#####_0/1.png`).
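
Because the left and right lung masks come as separate binary images, producing a single lung label takes a pixelwise OR. Below is a minimal NumPy sketch of that step. `union_to_one_hot` is a hypothetical helper, not the `UnionMasksD` transform imported from `util` later on, and the (background, lung) channel order is an assumption.

```python
import numpy as np

def union_to_one_hot(masks):
    """Combine binary masks of equal shape into a 2-channel one-hot array.

    Channel 0 is background and channel 1 is lung (assumed ordering).
    """
    lung = np.zeros(np.asarray(masks[0]).shape, dtype=bool)
    for m in masks:
        lung |= np.asarray(m) != 0  # any nonzero pixel counts as lung
    return np.stack([~lung, lung]).astype(np.int8)

# Toy 2x3 masks standing in for the left and right mask images
left = np.array([[255, 0, 0], [255, 0, 0]])
right = np.array([[0, 0, 255], [0, 0, 255]])
seg = union_to_one_hot([left, right])
print(seg.shape)  # (2, 2, 3)
```

Passing a single mask in the list gives a plain one-hot conversion with no union involved.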

### Data used here

We will use Montgomery and Shenzhen.

In [None]:
import os, glob

data_base_path = '/home/ebrahim/data/chest_xrays'

montgomery_imgs_path = os.path.join(data_base_path, 'MontgomerySet/CXR_png')
montgomery_segs_path_left = os.path.join(data_base_path, 'MontgomerySet/ManualMask/leftMask')
montgomery_segs_path_right = os.path.join(data_base_path, 'MontgomerySet/ManualMask/rightMask')

shenzhen_imgs_path = os.path.join(data_base_path, 'ChinaSet_AllFiles/CXR')
shenzhen_segs_path = os.path.join(data_base_path, 'ChinaSet_AllFiles/CXR_segs')

montgomery_imgs = glob.glob(os.path.join(montgomery_imgs_path, '*.png'))
montgomery_segs_left = glob.glob(os.path.join(montgomery_segs_path_left, '*.png'))
montgomery_segs_right = glob.glob(os.path.join(montgomery_segs_path_right, '*.png'))

shenzhen_imgs = glob.glob(os.path.join(shenzhen_imgs_path, '*.png'))
shenzhen_segs = glob.glob(os.path.join(shenzhen_segs_path, '*.png'))

# These happen to work the same way for both montgomery and shenzhen datasets
file_path_to_ID = lambda p : os.path.basename(p)[7:11]
file_path_to_abnormality = lambda p : bool(int(os.path.basename(p)[12]))

montgomery_img_ids = list(map(file_path_to_ID,montgomery_imgs))
montgomery_seg_ids_left = list(map(file_path_to_ID,montgomery_segs_left))
montgomery_seg_ids_right = list(map(file_path_to_ID,montgomery_segs_right))

shenzhen_img_ids = list(map(file_path_to_ID,shenzhen_imgs))
shenzhen_seg_ids = list(map(file_path_to_ID,shenzhen_segs))

# See "look over data to sanity check" section below.
# I used that cell to actually look over the segmentations by hand, because some of them were bad
# I excluded segmentations that ignored the heart boundary, since that is not consistent with montgomery set segs
# I excluded segmentations that were clearly impossible given how the diaphragm can be arranged
shenzhen_excluded_ids = [
    '0043', '0242', '0439', '0518', '0506', '0635', '0337',
    '0603', '0455', '0254', '0610', '0028', '0511', '0059',
    '0031', '0297', '0032', '0327', '0569', '0030', '0178',
    '0295', '0298', '0003', '0022', '0020', '0098', '0033', 
    '0250', '0423', '0066', '0464', '0381', '0289', '0253',
    '0338', '0076', '0621', '0243', '0016', '0241', '0339',
    '0324', '0007', '0660', '0285', '0294', '0286', '0281', 
    '0047', '0527', '0013', '0002', '0046', '0370', '0293', 
    '0372', '0361', '0290', '0474', '0513', '0090', 
]

data = []
for img in montgomery_imgs:
    img_id = file_path_to_ID(img)
    seg_left = montgomery_segs_left[montgomery_seg_ids_left.index(img_id)]
    seg_right = montgomery_segs_right[montgomery_seg_ids_right.index(img_id)]
    tuberculosis = file_path_to_abnormality(img)
    data.append({
        'img' : img,
        'mo_seg_left' : seg_left, # mo for montgomery
        'mo_seg_right' : seg_right,
        'tuberculosis' : tuberculosis,
        'id' : 'montgomery:'+img_id,
        'source' : "montgomery"
    })
skipped_no_seg = 0
skipped_bad = 0
for img in shenzhen_imgs:
    img_id = file_path_to_ID(img)
    if img_id not in shenzhen_seg_ids:
        skipped_no_seg += 1
        continue
    if img_id in shenzhen_excluded_ids:
        skipped_bad += 1
        continue
    seg = shenzhen_segs[shenzhen_seg_ids.index(img_id)]
    tuberculosis = file_path_to_abnormality(img)
    data.append({
        'img' : img,
        'sh_seg' : seg, # sh for shenzhen
        'tuberculosis' : tuberculosis,
        'id' : 'shenzhen:'+img_id,
        'source' : "shenzhen"
    })
if skipped_no_seg>0:
    print(f"{skipped_no_seg} of the shenzhen images do not have an associated segmentation, and they were skipped.")
if skipped_bad>0:
    print(f"{skipped_bad} of the shenzhen images were marked for exclusion due to questionable segmentation.")

In [None]:
import monai
import matplotlib.pyplot as plt
import numpy as np
import torch
from util import UnionMasksD, rgb_to_grayscale, list_data_collate_no_meta
from segmentation_post_processing import SegmentationPostProcessing

monai.utils.misc.set_determinism(seed=9274)

In [None]:
data_train, data_valid = monai.data.utils.partition_dataset_classes(
    data,
    classes = list(map(lambda d : str(d['tuberculosis'])+d['source'], data)),
    ratios = (8,2)
)
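
One way to confirm that the split is proportional is to compare class fractions between the two partitions. The sketch below uses toy stand-ins for `data_train` and `data_valid`; `class_proportions` is a hypothetical helper that groups by the same (tuberculosis, source) pair used as the class label above.

```python
from collections import Counter

def class_proportions(items):
    """Fraction of items per (tuberculosis, source) class."""
    counts = Counter((d["tuberculosis"], d["source"]) for d in items)
    total = sum(counts.values())
    return {k: v / total for k, v in counts.items()}

# Toy stand-ins for data_train and data_valid, only to show the comparison
toy_train = ([{"tuberculosis": True, "source": "shenzhen"}] * 8
             + [{"tuberculosis": False, "source": "montgomery"}] * 4)
toy_valid = ([{"tuberculosis": True, "source": "shenzhen"}] * 2
             + [{"tuberculosis": False, "source": "montgomery"}] * 1)

print(class_proportions(toy_train))
print(class_proportions(toy_valid))
```

On the real data, `class_proportions(data_train)` and `class_proportions(data_valid)` should come out close to each other if the stratification worked.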

In [None]:
image_size = 256

keys_to_delete = ['mo_seg_left', 'mo_seg_right', 'sh_seg']
keys_to_delete += [k+"_meta_dict" for k in keys_to_delete] + [k+"_transforms" for k in keys_to_delete]

transform_valid = monai.transforms.Compose([
    monai.transforms.LoadImageD(reader='itkreader',keys = ['img']), # A few shenzhen images get mysteriously value-inverted with readers other than itkreader
    monai.transforms.LambdaD(keys=['img'], func = rgb_to_grayscale), # A few of the shenzhen imgs are randomly RGB encoded rather than grayscale colormap
    monai.transforms.LoadImageD(keys = ['mo_seg_left', 'mo_seg_right', 'sh_seg'], dtype="int8", allow_missing_keys=True),
    monai.transforms.TransposeD(keys = ['img', 'mo_seg_left', 'mo_seg_right', 'sh_seg'], indices = (1,0), allow_missing_keys=True),
    monai.transforms.AddChannelD(keys = ['img']),
    UnionMasksD(keys = ['mo_seg_left', 'mo_seg_right'], keyList=['mo_seg_left', 'mo_seg_right'], newKeyName='seg'),
    UnionMasksD(keys = ['sh_seg',], keyList=['sh_seg'], newKeyName='seg'), # using for one-hot conversion, not "union"
    monai.transforms.DeleteItemsD(keys = keys_to_delete),
    monai.transforms.ResizeD(
        keys = ['img', 'seg'],
        spatial_size=(image_size,image_size),
        mode = ['bilinear', 'nearest']
    ),
    monai.transforms.ToTensorD(keys = ['img', 'seg']),
])

transform_train = monai.transforms.Compose([
    transform_valid,
    monai.transforms.RandCoarseDropoutd(
        keys = ['img'],
        holes = 1,
        max_holes=3,
        spatial_size=image_size//32,
        max_spatial_size=image_size//4,
        prob=0.5,
        fill_value=255
    ),
    monai.transforms.RandCoarseDropoutd(
        keys = ['img'],
        holes = 1,
        max_holes=3,
        spatial_size=image_size//32,
        max_spatial_size=image_size//4,
        prob=0.5,
        fill_value=0
    ),
    monai.transforms.RandZoomD(
        keys = ['img', 'seg'],
        mode = ['bilinear', 'nearest'],
        prob=1.,
        padding_mode="constant",
        min_zoom = 0.7,
        max_zoom=1.3,
    ),
    monai.transforms.RandRotateD(
        keys = ['img', 'seg'],
        mode = ['bilinear', 'nearest'],
        prob=1.,
        range_x = np.pi/8,
        padding_mode="zeros",
    ),
    monai.transforms.RandGaussianSmoothD(
        keys = ['img'],
        prob = 0.4
    ),
    monai.transforms.RandAdjustContrastD(
        keys = ['img'],
        prob=0.4,
    ),
    monai.transforms.ToNumpyD(keys=['img']),
    monai.transforms.RandHistogramShiftD(
        keys = ['img'],
        prob=0.0,
    ),
    monai.transforms.ToTensorD(keys=['img', 'seg']),
])
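
The `rgb_to_grayscale` function used in `transform_valid` is imported from `util` and is not shown in this notebook. A hypothetical version, assuming a channel-last array and a plain average over the color channels, could look like the sketch below; the real helper may use a different layout or weighting.

```python
import numpy as np

def rgb_to_grayscale_sketch(x):
    """Collapse a trailing RGB(A) channel axis by averaging; pass 2D input through.

    Assumes channel-last layout, (H, W, 3) or (H, W, 4). Hypothetical stand-in
    for util.rgb_to_grayscale, whose actual behavior is not shown here.
    """
    x = np.asarray(x)
    if x.ndim == 3 and x.shape[-1] in (3, 4):
        return x[..., :3].mean(axis=-1)  # ignore alpha if present
    return x
```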

# Look over data to sanity check

There are some problematic segmentations in the Shenzhen set, e.g. some include the heart and some are very messy.

Catch them here and add them to the exclusion list above.

In [None]:
#Initialization cell for data preview procedure -- run this once

binary_mask = lambda x : (x!=0).astype('float')
bdry = lambda s : binary_mask((np.abs(np.diff(s, axis=0, prepend=0)) + np.abs(np.diff(s, axis=1, prepend=0)))!=0)
bdry_thick = lambda s : binary_mask(bdry(bdry(s)) + bdry(s))

i = 0
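
To see what `bdry` does, it helps to run it on a tiny mask. It marks every pixel whose value differs from the pixel above it or to its left, so the top and left edges land on the mask itself while the bottom and right edges land one pixel outside it. The sketch below repeats the two lambdas so it runs on its own.

```python
import numpy as np

binary_mask = lambda x: (x != 0).astype("float")
bdry = lambda s: binary_mask(
    (np.abs(np.diff(s, axis=0, prepend=0)) + np.abs(np.diff(s, axis=1, prepend=0))) != 0
)

square = np.zeros((5, 5))
square[1:4, 1:4] = 1  # a 3x3 block of "lung" pixels
edges = bdry(square)
print(edges)
```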

In [None]:
# Iteration cell -- run this several times till you have run through all the data or you are satisfied
num_rows = 2
fig, axs = plt.subplots(num_rows,5,figsize=(20,5*num_rows))
for ax in axs.reshape(-1):
    if i>=len(data): break

    d = data[i]
    d = transform_valid(d)
    
    im = d['img'].expand((3,)+d['img'].shape[1:])
    im = im/im.max()
    seg = d['seg'].float()
    seg_bdry = bdry_thick(seg[1])
    im[0,seg_bdry==1.] = 1 # R
    im[1,seg_bdry==1.] = 0 # G
    im[2,seg_bdry==1.] = 0 # B
    im = np.transpose(im,axes=(1,2,0))
    ax.imshow(im, cmap='bone')
    
    ax.set_title(f"i={i}, id={d['id']}")
    
    
    i+=1
plt.plot();

# Previewing

In [None]:
def preview(data_item, show_bdry = False):
    fig = plt.figure(figsize=(7,7))
    im = data_item['img'].expand((3,)+data_item['img'].shape[1:])
    im = im/im.max()
    seg = data_item['seg'].float()
    im[1,:,:] *= 1-0.3*seg[1,:,:]
    if show_bdry:
        seg_bdry = bdry(seg[1])
        mask = (seg_bdry == 1.)
        im[0,mask], im[1,mask], im[2,mask] = 1,0,0 # R, G, B
    im = np.transpose(im,axes=(1,2,0))
    plt.imshow(im, cmap='bone')
    plt.plot();

In [None]:
dataset_train = monai.data.CacheDataset(data_train, transform_train)
dataset_valid = monai.data.CacheDataset(data_valid, transform_valid)
# dataset_train = monai.data.Dataset(data_train, transform_train)
# dataset_valid = monai.data.Dataset(data_valid, transform_valid)

In [None]:
import random
i = random.choice(range(len(dataset_train)))
d = dataset_train[i]
preview(d, show_bdry=True)
print(d['id'])

# seg net

Structure of U-Net is inspired by this paper: https://arxiv.org/abs/1703.08770

But it's not exactly the same. In the paper there's one giant deconvolution step at the end, instead of having a symmetric looking unet.

In [None]:
spatial_dims = 2
image_channels = 1
seg_channels = 2 # lung, background
seg_net_channel_seq = (8,16,32,32,32,64,64,64)
# MONAI's UNet expects one fewer stride than channels: each stride sits between two consecutive levels,
# and the bottom layer always uses a stride of 1.
stride_seq = (2,2,2,2,1,2,1)
dropout_seg_net = 0.5
num_res_units = 2

seg_net = monai.networks.nets.UNet(
    spatial_dims = spatial_dims,
    in_channels = image_channels,
    out_channels = seg_channels, 
    channels = seg_net_channel_seq,
    strides = stride_seq,
    dropout = dropout_seg_net,
    num_res_units = num_res_units
)

num_params = sum(p.numel() for p in seg_net.parameters())
print(f"seg_net has {num_params} parameters")
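
A quick arithmetic check of how far the encoder shrinks the image, assuming each strided convolution divides the spatial size exactly (which holds for 256 and strides of 1 or 2):

```python
import math

image_size = 256
stride_seq = (2, 2, 2, 2, 1, 2, 1)

total_downsampling = math.prod(stride_seq)
bottom_size = image_size // total_downsampling
print(total_downsampling, bottom_size)  # 32 8
```

So under that assumption the deepest feature maps are 8 x 8, with 64 channels according to `seg_net_channel_seq`.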

# Loss

In [None]:
dice_loss = monai.losses.DiceLoss(
    to_onehot_y = False, # the segs we pass in are already in one-hot form
    softmax = True, # Note that our segmentation network is missing the softmax at the end
)
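
For intuition, here is a rough NumPy sketch of a soft Dice loss with a softmax over the channel axis and a one-hot target. It is not MONAI's implementation, which has more options (smoothing terms, reductions, background handling); `soft_dice_loss` and its `eps` value are assumptions made for illustration.

```python
import numpy as np

def soft_dice_loss(logits, target, eps=1e-5):
    """Mean soft Dice loss over batch and channels.

    logits: raw network output, shape (B, C, H, W)
    target: one-hot labels, same shape
    """
    z = logits - logits.max(axis=1, keepdims=True)  # stabilize the softmax
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    inter = (probs * target).sum(axis=(2, 3))
    denom = probs.sum(axis=(2, 3)) + target.sum(axis=(2, 3))
    dice = (2 * inter + eps) / (denom + eps)
    return float((1 - dice).mean())
```

A confident, correct prediction gives a loss near 0, and a confident, wrong one gives a loss near 1.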

In [None]:
# Test drive
data_item = dataset_train[42]
seg_pred = seg_net(data_item['img'].unsqueeze(0)) # shape is (1,2,256,256), which is (B,N,H,W)

dice_loss(
    seg_pred, # input, raw logits (DiceLoss applies the softmax)
    data_item['seg'].unsqueeze(0), # target, one-hot
)

# Previewing seg net outputs

In [None]:
def preview_seg_net(data_item, figsize=(15,10), print_loss = True, show_heatmap = False, show_bdry=False, show_post_processing=0):
    """
    Preview seg net prediction
    
    Args:
        data_item: A data item to input into seg_net.
        figsize: figure size to be used at each matplotlib plotting call
        print_loss: show Dice loss
        show_heatmap: whether to show class probability image
        show_bdry: whether to draw the boundary
        show_post_processing: 0 to not show it,
            1 to show post processed result,
            2 to show post processed result and intermediate steps
    """
    
    seg_net.eval()
    
    with torch.no_grad():
        im_device = data_item['img'].to(next(seg_net.parameters()).device)
        seg_pred = seg_net(im_device.unsqueeze(0))[0].cpu()
        _, max_indices = seg_pred.max(dim=0)
        seg_pred_mask = (max_indices==1).type(torch.uint8)

        f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize)

        im = data_item['img'].expand((3,)+data_item['img'].shape[1:])
        im = im/im.max()

        seg_true = data_item['seg'].float()
        im_true = im.clone()
        im_true[1,:,:] *= 1-0.4*seg_true[1,:,:]
        if show_bdry:
            seg_true_bdry = bdry(seg_true[1])
            mask = (seg_true_bdry == 1.)
            im_true[0,mask], im_true[1,mask], im_true[2,mask] = 1,0,0 # R, G, B
        im_true = np.transpose(im_true,axes=(1,2,0))
        ax1.imshow(im_true, cmap='bone')
        ax1.set_title("true seg overlay")
        ax1.axis('off')

        ax2.imshow(max_indices)
        ax2.set_title("predicted seg")
        ax2.axis('off')

        im_pred = im.clone()
        im_pred[1,:,:] *= 1-0.4*seg_pred_mask
        if show_bdry:
            seg_pred_bdry = bdry(seg_pred_mask)
            mask = (seg_pred_bdry == 1.)
            im_pred[0,mask], im_pred[1,mask], im_pred[2,mask] = 1,0,0 # R, G, B
        im_pred = np.transpose(im_pred,axes=(1,2,0))
        ax3.imshow(im_pred, cmap='bone')
        ax3.set_title("predicted seg overlay")
        ax3.axis('off')

        plt.show();
        
        if show_heatmap:
            f, ax1 = plt.subplots(1, 1, figsize=figsize)
            ax1.imshow(seg_pred.softmax(dim=0)[1])
            ax1.axis('off')
            print("predicted seg class probability maps:")
            plt.show()
        
        if show_post_processing!=0:
            plt.figure(figsize = figsize)
            seg_post_process = SegmentationPostProcessing()
            seg_pred_processed = seg_post_process(seg_pred_mask)
            im_pred = im.clone()
            im_pred[1,:,:] *= 1-0.4*(seg_pred_processed==1)
            im_pred[0,:,:] *= 1-0.4*(seg_pred_processed==2)
            if show_bdry:
                seg_pred_bdry1 = bdry(seg_pred_processed==1)
                seg_pred_bdry2 = bdry(seg_pred_processed==2)
                mask1 = (seg_pred_bdry1 == 1.)
                mask2 = (seg_pred_bdry2 == 1.)
                im_pred[0,mask1], im_pred[1,mask1], im_pred[2,mask1] = 1,0,0 # R, G, B
                im_pred[0,mask2], im_pred[1,mask2], im_pred[2,mask2] = 0,1,0 # R, G, B
            im_pred = np.transpose(im_pred,axes=(1,2,0))
            plt.imshow(im_pred, cmap='bone')
            plt.title("post-processed segmentation overlay")
            plt.axis('off')
            plt.show()
            if show_post_processing>1:
                seg_post_process.preview_intermediate_steps()

        if print_loss:
            loss = dice_loss(
                seg_pred.unsqueeze(0),
                data_item['seg'].unsqueeze(0),
            )
            print(f"Dice loss: {loss.item()}")

In [None]:
# Try seg_net on a random image.
preview_seg_net(random.choice(dataset_train), show_bdry=True);

In [None]:
# verify that I haven't messed up the segmentation labels in my transforms
dataset_train[0]['seg'].unique()
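
A slightly stricter check than `unique()` is to confirm that exactly one channel is set at every pixel. The sketch below is a hypothetical helper written against NumPy arrays; a CPU tensor such as `dataset_train[0]['seg']` can be passed through `np.asarray`. For training items, the zero padding introduced by the random zoom and rotation may leave pixels with no channel set, which this check would flag.

```python
import numpy as np

def is_valid_one_hot(seg):
    """True if seg (C, H, W) holds only 0/1 and exactly one channel is set per pixel."""
    seg = np.asarray(seg)
    only_binary = np.isin(seg, (0, 1)).all()
    one_per_pixel = (seg.sum(axis=0) == 1).all()
    return bool(only_binary and one_per_pixel)
```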

# Training

In [None]:
seg_net.to('cuda')

dataloader_train = monai.data.DataLoader(
    dataset_train,
    batch_size=16,
    num_workers=8,
    shuffle=True,
    collate_fn = list_data_collate_no_meta
)

dataloader_valid = monai.data.DataLoader(
    dataset_valid,
    batch_size=64,
    num_workers=8,
    shuffle=False,
    collate_fn = list_data_collate_no_meta
)

learning_rate = 1e-3
optimizer = torch.optim.Adam(seg_net.parameters(), learning_rate)

epoch_number = 0
training_losses = [] 
validation_losses = []
preview_index = random.choice(range(len(dataset_valid)))
best_validation_loss = float('inf')

In [None]:
max_epochs = 40
while epoch_number < max_epochs:
    
    print(f"Epoch {epoch_number+1}/{max_epochs} ...")
    
    if (epoch_number%5==0):
        preview_seg_net(dataset_valid[preview_index], figsize=(6,6), print_loss=False);
    
    seg_net.train()
    losses = []
    for batch in dataloader_train:
        imgs = batch['img'].to('cuda')
        true_segs = batch['seg'].to('cuda')

        optimizer.zero_grad()
        predicted_segs = seg_net(imgs)
        loss = dice_loss(predicted_segs, true_segs)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
    
    training_loss = np.mean(losses)
    training_losses.append([epoch_number, training_loss])
    
    print(f"\ttraining loss: {training_loss}")

    if (epoch_number%5==0):
    
        seg_net.eval()
        losses = []
        with torch.no_grad():
            for batch in dataloader_valid:
                imgs = batch['img'].to('cuda')
                true_segs = batch['seg'].to('cuda')
                predicted_segs = seg_net(imgs)
                loss = dice_loss(predicted_segs, true_segs)
                losses.append(loss.item())
            validation_loss = np.mean(losses)

        print(f"\tvalidation loss: {validation_loss}")
        
        validation_losses.append([epoch_number, validation_loss])
        
        if validation_loss < best_validation_loss:
            best_validation_loss = validation_loss
            torch.save(seg_net.state_dict(),'seg_net_bestval.pth')
    
    epoch_number +=1

del imgs, true_segs, predicted_segs, loss
torch.cuda.empty_cache()
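
Note that the loop only validates when `epoch_number % 5 == 0`. The short sketch below lists which epochs that covers for `max_epochs = 40`:

```python
max_epochs = 40
validated_epochs = [e for e in range(max_epochs) if e % 5 == 0]
print(validated_epochs)      # [0, 5, 10, 15, 20, 25, 30, 35]
print(validated_epochs[-1])  # 35
```

The last four epochs (36 to 39) therefore train without a validation pass, so `seg_net_bestval.pth` can reflect epoch index 35 at the latest.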

In [None]:
#Try on a random validation image
data_item_index = random.choice(range(len(dataset_valid)))
print(data_item_index)
data_item = dataset_valid[data_item_index]
with torch.no_grad():
    preview_seg_net(data_item, show_heatmap=False, show_bdry=True, show_post_processing=1);

In [None]:
run_id = '0016'
if (os.path.exists(f'seg_net{run_id}.pth')):
    del run_id
    raise Exception("Please change run_id so you don't overwrite things.")
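
Instead of editing `run_id` by hand, the next unused ID could be looked up automatically. This is only a sketch: `next_free_run_id` is a hypothetical helper, and it assumes the `seg_net<ID>.pth` naming and four-digit zero padding used in this notebook.

```python
import os

def next_free_run_id(directory=".", start=0, width=4):
    """Return the first zero-padded run ID with no seg_net<ID>.pth in directory."""
    n = start
    while os.path.exists(os.path.join(directory, f"seg_net{n:0{width}d}.pth")):
        n += 1
    return f"{n:0{width}d}"
```

For example, `run_id = next_free_run_id(start=16)` would return `'0016'` if that file does not exist yet, and the next free number otherwise.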

In [None]:
# CHECKPOINT CELL; LOAD
# seg_net.load_state_dict(torch.load(f'seg_net_bestval.pth'))
# seg_net.load_state_dict(torch.load(f'seg_net{run_id}.pth'))

In [None]:
# CHECKPOINT CELL; SAVE
torch.save(seg_net.state_dict(),f'seg_net{run_id}.pth') # Save parameters dict
torch.save(seg_net,f'seg_net_model{run_id}.pth') # Save entire model

In [None]:
def plot_against_epoch_numbers(epoch_value_pairs, label):
    array = np.array(epoch_value_pairs)
    plt.plot(array[:,0], array[:,1], label=label)

plot_against_epoch_numbers(training_losses, label="training")
plot_against_epoch_numbers(validation_losses, label="validation")
plt.legend()
plt.xlabel('epoch')
plt.ylabel('dice loss')
plt.title('seg net training')
plt.savefig(f'seg_net_losses{run_id}.png')
plt.show()