In [None]:
import numpy as np
import pandas as pd
import sys
import matplotlib.pyplot as plt
import os
import numpy as np
sys.path.append('..')
from load_data import load_data, Dataset
from config import xvertseg_dir, verse2019_dir, resolution, patch_size

In [None]:
# Load images, masks and score tables for both source data sets.
# Each call reads from the data directory configured in `config` and uses the
# shared `resolution` setting so the two sets are loaded the same way.
xvertseg_imgs, xvertseg_msks, xvertseg_scores = load_data(
    xvertseg_dir,
    resolution,
)
verse2019_imgs, verse2019_msks, verse2019_scores = load_data(
    verse2019_dir,
    resolution,
)

In [None]:
# Stack the two data sets together (xVertSeg first, then Verse2019).
imgs = np.concatenate((xvertseg_imgs, verse2019_imgs))
msks = np.concatenate((xvertseg_msks, verse2019_msks))

# `DataFrame.append` was deprecated in pandas 1.4 and removed in pandas 2.0;
# `pd.concat` is the supported equivalent and, like `append`, keeps the
# original index of each input table.
scores = pd.concat([xvertseg_scores, verse2019_scores])

display(scores)

In [None]:
# Build the combined dataset: the score table is converted to a plain numpy
# array and passed to `Dataset` together with the stacked images and masks
# and the configured patch size.
np_scores = scores.to_numpy()
full_dataset = Dataset(
    np_scores,
    imgs,
    msks,
    patch_size,
)

### Statistics about this dataset
This section summarises properties of the combined dataset, such as:
* the distribution of vertebra types
* the frequency of fractures (and their distribution across the different cases and grades, where available)

In [None]:
# Pull the per-patch arrays out of the dataset once; the cells below index
# all of them with the same patch index.
patches, sources, ids, vertebrae, fractured = (
    full_dataset.get_patches(),
    full_dataset.get_sources(),
    full_dataset.get_ids(),
    full_dataset.get_vertebrae(),
    full_dataset.get_scores(),
)

In [None]:
# Human-readable vertebra names used as tick labels / titles when plotting:
# twelve thoracic names (T1..T12) followed by six lumbar names (L1..L6).
thoracic_names = ['T{}'.format(level) for level in range(1, 13)]
lumbar_names = ['L{}'.format(level) for level in range(1, 7)]
vert_names = np.asarray(thoracic_names + lumbar_names)

In [None]:
# Vertebra labels are shifted by this offset relative to positions in
# `vert_names` (the patch sanity-check cell uses the same `vertebra - 8` mapping).
VERT_LABEL_OFFSET = 8


def _count_per_label(values, labels):
    """Return, for each entry of `labels`, how often it occurs in `values`.

    Labels that do not occur in `values` get a count of 0, so the result is
    always aligned element-by-element with `labels`.
    """
    flat_values = np.ravel(values)
    return np.array([np.count_nonzero(flat_values == label) for label in labels], dtype=int)


def _names_for_labels(labels):
    """Map vertebra labels to their display names in `vert_names`."""
    return vert_names[np.asarray(labels).astype(int) - VERT_LABEL_OFFSET]


def _draw_stacked_counts(ax, labels, counts_frac, counts_non_frac, title, tick_labels):
    """Draw one stacked bar chart (fractured at the bottom, healthy on top) on `ax`.

    `tick_labels` are the vertebra labels at which x ticks are placed; they may
    cover more labels than are plotted so panels can share the same x axis.
    """
    ax.bar(labels, counts_frac, align='center', label='Fractured')
    ax.bar(labels, counts_non_frac, align='center', label='Healthy', bottom=counts_frac)
    ax.set_title(title)
    ax.set_ylabel('Nr of occurences')
    ax.legend()
    ax.set_xticks(tick_labels)
    ax.set_xticklabels(_names_for_labels(tick_labels))
    ax.set_ylim((0, 100))


ind_xvert = np.argwhere(sources == 'xvertseg')
ind_verse = np.argwhere(sources == 'verse2019')
ind_fractured = np.argwhere(fractured)

# total counts per vertebra label
labels_xvert, counts_xvert = np.unique(vertebrae[ind_xvert], return_counts=True)
labels_verse, counts_verse = np.unique(vertebrae[ind_verse], return_counts=True)
labels_both, counts_both = np.unique(vertebrae, return_counts=True)

# Fractured counts, aligned with the label arrays above. Counting per label
# (instead of a second `np.unique` padded with a trailing 0) stays correct
# whichever vertebrae happen to have no fractured samples.
counts_xvert_frac = _count_per_label(vertebrae[np.intersect1d(ind_xvert, ind_fractured)], labels_xvert)
counts_verse_frac = _count_per_label(vertebrae[np.intersect1d(ind_verse, ind_fractured)], labels_verse)
counts_both_frac = _count_per_label(vertebrae[ind_fractured], labels_both)

counts_xvert_non_frac = counts_xvert - counts_xvert_frac
counts_verse_non_frac = counts_verse - counts_verse_frac
counts_both_non_frac = counts_both - counts_both_frac

fig, axes = plt.subplots(1, 3, figsize=(30, 6))

# The Verse2019 panel is ticked with the labels of the combined set so that
# it shares its x axis with the "both" panel.
_draw_stacked_counts(axes[0], labels_xvert, counts_xvert_frac, counts_xvert_non_frac,
                     'Histogram: type of vertebrae in xVertSeg', labels_xvert)
_draw_stacked_counts(axes[1], labels_verse, counts_verse_frac, counts_verse_non_frac,
                     'Histogram: type of vertebrae in Verse2019', labels_both)
_draw_stacked_counts(axes[2], labels_both, counts_both_frac, counts_both_non_frac,
                     'Histogram: type of vertebrae in both', labels_both)

plt.show()

## Sanity check: plot some patches in the set

In [None]:
# Draw a 5 x 5 grid of randomly chosen patches as a visual sanity check.
# Sampling is intentionally unseeded, so every run shows a different selection.
fig, axes = plt.subplots(5, 5, figsize=(20, 20))

for ax in axes.ravel():
    # use the built-in len() rather than calling the dunder method directly
    i = np.random.randint(len(full_dataset))

    patch = patches[i]
    source = sources[i]
    patch_id = ids[i]
    vertebra = vertebrae[i]
    frac = fractured[i]

    # first entry of a patch is the image, second the mask
    img, msk = patch[0], patch[1]
    # hide mask entries equal to 0 so only non-zero mask values are overlaid
    plt_msk = np.ma.masked_where(msk == 0, msk)
    # show the middle slice along the first axis
    mid_slice = img.shape[0] // 2
    # vertebra labels are offset by 8 relative to positions in `vert_names`
    ind_label = vertebra - 8

    ax.imshow(img[mid_slice, :, :], cmap='gray')
    ax.imshow(plt_msk[mid_slice, :, :], alpha=0.3, vmin=8, vmax=25)
    description = '{} ID: {} \n vert {} with label {} '.format(source, patch_id, vert_names[ind_label], vertebra)
    description += 'fractured' if frac else 'healthy'
    ax.set_title(description)

fig.tight_layout()
plt.show()