In [None]:
#default_exp medical.imaging

# Medical Imaging

> Helpers for working with DICOM files

In [None]:
#export 
from fastai2.basics import *
from fastai2.vision.all import *
from fastai2.data.transforms import *

import pydicom,kornia,skimage
from pydicom.dataset import Dataset as DcmDataset
from pydicom.tag import BaseTag as DcmTag
from pydicom.multival import MultiValue as DcmMultiValue
from PIL import Image

# OpenCV is optional: when it can be imported its thread count is set to 0.
# Any failure in this block (e.g. cv2 not installed) is deliberately ignored.
try:
    import cv2
    cv2.setNumThreads(0)
except: pass

In [None]:
from nbdev.showdoc import *
# Notebook-only setting: make 'bone' the default matplotlib image colormap
matplotlib.rcParams['image.cmap'] = 'bone'

In [None]:
#export
# Names listed for export in addition to the `#export` cells: the pydicom aliases imported above
# plus two helpers defined below (presumably consumed by nbdev to build `__all__` - confirm).
_all_ = ['DcmDataset', 'DcmTag', 'DcmMultiValue', 'dcmread', 'get_dicom_files']

## Patching

In [None]:
#export
def get_dicom_files(path, recurse=True, folders=None):
    "Collect the `.dcm` files under `path`, descending into sub-folders if `recurse` and limited to `folders` when given."
    dicom_exts = [".dcm"]
    return get_files(path, extensions=dicom_exts, recurse=recurse, folders=folders)

In [None]:
#export
@patch
def dcmread(fn:Path, force = False):
    "Open the `DICOM` file at `fn`; `force` is forwarded to `pydicom.dcmread`"
    # `force` has to be passed by keyword: the second positional parameter of
    # `pydicom.dcmread` is `defer_size`, so a positional `force` was never applied as `force`.
    return pydicom.dcmread(str(fn), force=force)

In [None]:
#export
class TensorDicom(TensorImage):
    "A `TensorImage` for DICOM pixel data, displayed with a gray colormap"
    _show_args = {'cmap':'gray'}

In [None]:
#export
class PILDicom(PILBase):
    "A `PILBase` image decoded from a DICOM file; converts to `TensorDicom`"
    _open_args,_tensor_cls,_show_args = {},TensorDicom,TensorDicom._show_args
    @classmethod
    def create(cls, fn:(Path,str,bytes), mode=None)->None:
        "Open a `DICOM file` from path `fn` or bytes `fn` and load it as a `PIL Image`"
        # Raw bytes are wrapped in a file-like object; `str` paths are accepted like `Path`
        if isinstance(fn,bytes): ds = pydicom.dcmread(pydicom.filebase.DicomBytesIO(fn))
        elif isinstance(fn,(Path,str)): ds = dcmread(Path(fn))
        else: raise TypeError(f"Expected a Path, str or bytes for `fn`, got {type(fn).__name__}")
        im = Image.fromarray(ds.pixel_array)
        im.load()
        im = im._new(im.im)
        # Optionally convert to the requested PIL `mode` before wrapping in `cls`
        return cls(im.convert(mode) if mode else im)

PILDicom._tensor_cls = TensorDicom

In [None]:
# #export
# @patch
# def png16read(self:Path): return array(Image.open(self), dtype=np.uint16)

In [None]:
# Sample DICOM file read once here and reused by the example cells below
TEST_DCM = Path('images/sample.dcm')
dcm = TEST_DCM.dcmread()

In [None]:
#export
@patch_property
def pixels(self:DcmDataset):
    "The dataset's `pixel_array`, cast to float32 and returned as a tensor"
    px = self.pixel_array
    return tensor(px.astype(np.float32))

In [None]:
#export
@patch_property
def scaled_px(self:DcmDataset):
    "`pixels` scaled by `RescaleSlope` and `RescaleIntercept`"
    img = self.pixels
    # "CR" images are returned unscaled; a dataset with no `Modality` is not treated as CR
    if getattr(self, 'Modality', None) == "CR": return img
    # Not every dataset carries the rescale attributes: a missing slope counts as 1 and a
    # missing intercept as 0, instead of raising `AttributeError`
    slope,intercept = getattr(self, 'RescaleSlope', None),getattr(self, 'RescaleIntercept', None)
    if slope is None and intercept is None: return img
    return img * (1 if slope is None else slope) + (0 if intercept is None else intercept)

In [None]:
#export
def array_freqhist_bins(self, n_bins=100):
    "A numpy based function to split the range of pixel values into groups, such that each group has around the same number of pixels"
    imsd = np.sort(self.flatten())
    n = len(imsd)
    # Nothing to bin: return the (empty) sorted array instead of failing on the index lookup
    if n == 0: return imsd
    # Quantile positions: 0.1%, the centre of each of the `n_bins` equal-population bins, 99.9%
    t = np.concatenate([[0.001], np.arange(n_bins)/n_bins+(1/2/n_bins), [0.999]])
    # `np.int` no longer exists in current NumPy, so use an explicit integer dtype
    idxs = (n*t+0.5).astype(np.int64)
    # The +0.5 rounding can push the last index to `n` for small inputs; keep it in range
    idxs = np.clip(idxs, 0, n-1)
    return np.unique(imsd[idxs])

In [None]:
#export
@patch
def freqhist_bins(self:Tensor, n_bins=100):
    "A function to split the range of pixel values into groups, such that each group has around the same number of pixels"
    # `reshape` rather than `view`, so non-contiguous tensors (slices, transposes) also work
    imsd = self.reshape(-1).sort()[0]
    # Nothing to bin: return the empty tensor instead of failing on the index lookup
    if len(imsd) == 0: return imsd
    # Quantile positions: 0.1%, the centre of each of the `n_bins` equal-population bins, 99.9%
    t = torch.cat([tensor([0.001]),
                   torch.arange(n_bins).float()/n_bins+(1/2/n_bins),
                   tensor([0.999])])
    t = (len(imsd)*t).long()
    return imsd[t].unique()

In [None]:
#export
@patch
def hist_scaled_pt(self:Tensor, brks=None):
    "Map values onto 0..1 by interpolating between breakpoints `brks` (default: `freqhist_bins()`), on the tensor's device"
    # Pytorch-only version - switch to this if/when interp_1d can be optimized
    if brks is None: brks = self.freqhist_bins()
    dev = self.device
    brks = brks.to(dev)
    # Breakpoints are mapped to evenly spaced targets between 0 and 1
    ys = torch.linspace(0., 1., len(brks)).to(dev)
    scaled = self.flatten().interp_1d(brks, ys)
    return scaled.reshape(self.shape).clamp(0.,1.)

In [None]:
#export
@patch
def hist_scaled(self:Tensor, brks=None):
    "Map values onto 0..1 by interpolating between breakpoints `brks` (default: `freqhist_bins()`)"
    # CUDA tensors take the pure-pytorch path; everything else goes through `np.interp`
    if self.device.type=='cuda': return self.hist_scaled_pt(brks)
    if brks is None: brks = self.freqhist_bins()
    targets = np.linspace(0., 1., len(brks))
    flat = self.numpy().flatten()
    mapped = np.interp(flat, brks.numpy(), targets)
    return tensor(mapped).reshape(self.shape).clamp(0.,1.)

In [None]:
#export
@patch
def hist_scaled(self:DcmDataset, brks=None, min_px=None, max_px=None):
    "Histogram-scale `scaled_px` using `brks`, after limiting values to `min_px`/`max_px` when given"
    px = self.scaled_px
    if min_px is not None:
        too_low = px<min_px
        px[too_low] = min_px
    if max_px is not None:
        too_high = px>max_px
        px[too_high] = max_px
    return px.hist_scaled(brks=brks)

In [None]:
#export
@patch
def windowed(self:Tensor, w, l):
    "Limit a copy of the tensor to `l - w//2 .. l + w//2` and rescale it so those bounds map to 0 and 1"
    half = w//2
    lo,hi = l-half,l+half
    # Work on a clone so the original tensor is left untouched
    res = self.clone()
    res[res<lo] = lo
    res[res>hi] = hi
    return (res-lo) / (hi-lo)

In [None]:
#export
@patch
def windowed(self:DcmDataset, w, l):
    "Apply `Tensor.windowed` with width `w` and level `l` to this dataset's `scaled_px`"
    px = self.scaled_px
    return px.windowed(w,l)

In [None]:
#export
# From https://radiopaedia.org/articles/windowing-ct
# Named windows, each a `(w, l)` tuple in the order taken by `windowed(w, l)` and
# `pct_in_window(w, l)`, so they can be passed as `*dicom_windows.brain`.
dicom_windows = types.SimpleNamespace(
    brain=(80,40),
    subdural=(254,100),
    stroke=(8,32),
    brain_bone=(2800,600),
    brain_soft=(375,40),
    lungs=(1500,-600),
    mediastinum=(350,50),
    abdomen_soft=(400,50),
    liver=(150,30),
    spine_soft=(250,50),
    spine_bone=(1800,400)
)

In [None]:
#export
class TensorCTScan(TensorImageBW):
    "A `TensorImageBW` for CT scan data, displayed with the 'bone' colormap"
    _show_args = {'cmap':'bone'}

In [None]:
#export
class PILCTScan(PILBase):
    "A `PILBase` image for CT scans; converts to `TensorCTScan` and shares its display arguments"
    _open_args,_tensor_cls,_show_args = {},TensorCTScan,TensorCTScan._show_args

In [None]:
#export
@patch
@delegates(show_image)
def show(self:DcmDataset, scale=True, cmap=plt.cm.bone, min_px=-1100, max_px=None, **kwargs):
    "Display the image; `scale` is a `(w,l)` window tuple, an array/tensor of breakpoints, or a bool toggling histogram scaling"
    if isinstance(scale,tuple):
        # A tuple is treated as a `(w,l)` window
        px = self.windowed(*scale)
    elif isinstance(scale,(ndarray,Tensor)):
        # An array or tensor is used as the histogram breakpoints
        px = self.hist_scaled(min_px=min_px,max_px=max_px,brks=scale)
    elif scale:
        px = self.hist_scaled(min_px=min_px,max_px=max_px)
    else:
        px = self.scaled_px
    show_image(px, cmap=cmap, **kwargs)

In [None]:
# Show the sample image once per `scale` option of `show`: no scaling, histogram scaling,
# and two `(w,l)` windows from `dicom_windows`
scales = False, True, dicom_windows.brain, dicom_windows.subdural
titles = 'raw','normalized','brain windowed','subdural windowed'
for s,a,t in zip(scales, subplots(2,2,imsize=4)[1].flat, titles):
    dcm.show(scale=s, ax=a, title=t)

In [None]:
# Same image with the default histogram scaling but a different colormap
dcm.show(cmap=plt.cm.gist_ncar, figsize=(6,6))

In [None]:
#export
@patch
def pct_in_window(dcm:DcmDataset, w, l):
    "Fraction (between 0 and 1) of `scaled_px` values lying strictly inside the window `(w,l)`"
    half = w//2
    px = dcm.scaled_px
    inside = (px > l-half) & (px < l+half)
    return inside.float().mean().item()

In [None]:
# Fraction of the sample image's pixels that fall inside the brain window
dcm.pct_in_window(*dicom_windows.brain)

In [None]:
#export
def uniform_blur2d(x,s):
    "Blur `x` with a uniform (box) kernel of length `s`, applied along the last dimension and then along the one before it"
    kernel = x.new_ones(1,1,1,s)/s
    pad = s//2
    # Add leading dimensions until `x` is 4d, as `F.conv2d` expects
    x4d = unsqueeze(x, dim=0, n=4-x.dim())
    # Factor 2d conv into 2 1d convs
    # NOTE(review): `padding=pad` pads both spatial dims in each conv, so the output looks
    # larger than the input - confirm this is intended
    horiz = F.conv2d(x4d, kernel, padding=pad)
    both = F.conv2d(horiz, kernel.transpose(-1,-2), padding=pad)
    return both.cpu()[:,0].squeeze()

In [None]:
# Compare the histogram-scaled image with its uniform blur using a kernel of length 50
ims = dcm.hist_scaled(), uniform_blur2d(dcm.hist_scaled(),50)
show_images(ims, titles=('orig', 'blurred'))

In [None]:
#export
def gauss_blur2d(x,s):
    "Gaussian-blur `x` via kornia with sigma `s` in both directions and an odd kernel size derived from `s`"
    # `int(s/4)*2+1` is always odd
    ksize = int(s/4)*2+1
    # Add leading dimensions until `x` is 4d
    batched = unsqueeze(x, dim=0, n=4-x.dim())
    blurred = kornia.filters.gaussian_blur2d(batched, (ksize,ksize), (s,s), 'replicate')
    return blurred.squeeze()

In [None]:
#export
@patch
def mask_from_blur(x:Tensor, window, sigma=0.3, thresh=0.05, remove_max=True):
    "Boolean mask: window `x`, optionally zero values equal to 1, Gaussian-blur, and keep pixels above `thresh`"
    win = x.windowed(*window)
    # Values at the top of the window (exactly 1) are dropped when `remove_max`
    if remove_max: win[win==1] = 0
    # The blur's sigma is `sigma` scaled by the size of the last dimension
    blur_sigma = sigma*x.shape[-1]
    return gauss_blur2d(win, s=blur_sigma)>thresh

In [None]:
#export
@patch
def mask_from_blur(x:DcmDataset, window, sigma=0.3, thresh=0.05, remove_max=True):
    "Apply `Tensor.mask_from_blur` to this dataset's `scaled_px`, after passing it through `to_device`"
    px = to_device(x.scaled_px)
    return px.mask_from_blur(window, sigma, thresh, remove_max=remove_max)

In [None]:
# Build a mask from the brain window and overlay it (semi-transparent red) on the windowed image
mask = dcm.mask_from_blur(dicom_windows.brain)
wind = dcm.windowed(*dicom_windows.brain)

_,ax = subplots(1,1)
show_image(wind, ax=ax[0])
show_image(mask, alpha=0.5, cmap=plt.cm.Reds, ax=ax[0]);

In [None]:
#export
def _px_bounds(x, dim):
    c = x.sum(dim).nonzero().cpu()
    idxs,vals = torch.unique(c[:,0],return_counts=True)
    vs = torch.split_with_sizes(c[:,1],tuple(vals))
    d = {k.item():v for k,v in zip(idxs,vs)}
    default_u = tensor([0,x.shape[-1]-1])
    b = [d.get(o,default_u) for o in range(x.shape[0])]
    b = [tensor([o.min(),o.max()]) for o in b]
    return torch.stack(b)

In [None]:
#export
def mask2bbox(mask):
    "Bounding box of the nonzero region of `mask` (2d, or batched 3d), indexed as `[lo/hi, row/col]` plus a trailing batch dim when batched"
    single = mask.dim()==2
    # A single 2d mask is handled as a batch of one
    if single: mask = mask[None]
    row_bounds = _px_bounds(mask,-1).t()
    col_bounds = _px_bounds(mask,-2).t()
    boxes = torch.stack([row_bounds,col_bounds],dim=1).to(mask.device)
    return boxes[...,0] if single else boxes

In [None]:
# Crop the windowed image to the bounding box of the mask
bbs = mask2bbox(mask)
lo,hi = bbs
show_image(wind[lo[0]:hi[0],lo[1]:hi[1]]);

In [None]:
#export
def _bbs2sizes(crops, init_sz, use_square=True):
    bb = crops.flip(1)
    szs = (bb[1]-bb[0])
    if use_square: szs = szs.max(0)[0][None].repeat((2,1))
    overs = (szs+bb[0])>init_sz
    bb[0][overs] = init_sz-szs[overs]
    lows = (bb[0]/float(init_sz))
    return lows,szs/float(init_sz)

In [None]:
#export
def crop_resize(x, crops, new_sz):
    "Crop each item of batch `x` to its box in `crops` and resample it to `new_sz` with `F.grid_sample`"
    # NB assumes square inputs. Not tested for non-square anythings!
    bs = x.shape[0]
    lows,szs = _bbs2sizes(crops, x.shape[-1])
    out_sz = new_sz if isinstance(new_sz,(list,tuple)) else (new_sz,new_sz)
    # One identity affine matrix per batch item
    identity = tensor([[1.,0,0],[0,1,0]])[None].repeat((bs,1,1)).to(x.device)
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=UserWarning)
        # Identity grid shifted by 1, then scaled and offset per item using the box size and corner
        base = F.affine_grid(identity, (bs,1,*out_sz))+1.
        scale = unsqueeze(szs.t(),1,n=2)
        shift = unsqueeze(lows.t()*2.,1,n=2)
        grid = base*scale+shift
        # The shift by 1 is undone before sampling
        return F.grid_sample(x.unsqueeze(1), grid-1)

In [None]:
# Crop the windowed image to the mask's bounding box and resample it to size 128
px256 = crop_resize(to_device(wind[None]), bbs[...,None], 128)[0]
show_image(px256)
px256.shape

In [None]:
#export
@patch
def to_nchan(x:Tensor, wins, bins=None):
    "Stack one windowed copy of `x` per entry of `wins`, plus a histogram-scaled channel unless `bins` is 0, into a `TensorCTScan`"
    chans = [x.windowed(*win) for win in wins]
    # An integer `bins` equal to 0 switches the histogram-scaled channel off
    skip_hist = isinstance(bins,int) and bins==0
    if not skip_hist: chans.append(x.hist_scaled(bins).clamp(0,1))
    # 3d input is stacked along dim 1, anything else along dim 0
    dim = 1 if x.dim()==3 else 0
    return TensorCTScan(torch.stack(chans, dim=dim))

In [None]:
#export
@patch
def to_nchan(x:DcmDataset, wins, bins=None):
    "Apply `Tensor.to_nchan` to this dataset's `scaled_px`"
    px = x.scaled_px
    return px.to_nchan(wins, bins)

In [None]:
#export
@patch
def to_3chan(x:Tensor, win1, win2, bins=None):
    "Shortcut for `to_nchan` with the two windows `win1` and `win2`"
    wins = [win1,win2]
    return x.to_nchan(wins,bins=bins)

In [None]:
#export
@patch
def to_3chan(x:DcmDataset, win1, win2, bins=None):
    "Apply `Tensor.to_3chan` to this dataset's `scaled_px`"
    px = x.scaled_px
    return px.to_3chan(win1, win2, bins)

In [None]:
# Three windowed channels plus the histogram-scaled channel, shown side by side
show_images(dcm.to_nchan([dicom_windows.brain,dicom_windows.subdural,dicom_windows.abdomen_soft]))

In [None]:
#export
@patch
def save_jpg(x:(Tensor,DcmDataset), path, wins, bins=None, quality=90):
    "Save the `to_nchan(wins, bins)` channels as a JPEG at `path` (suffix forced to `.jpg`) with the given `quality`"
    fn = Path(path).with_suffix('.jpg')
    # Scale to 0..255 bytes; `.cpu()` so tensors living on a GPU can be converted to numpy
    x = (x.to_nchan(wins, bins)*255).byte().cpu()
    # Channels-last layout for PIL; 4 channels are written as CMYK, otherwise RGB is assumed
    im = Image.fromarray(x.permute(1,2,0).numpy(), mode=['RGB','CMYK'][x.shape[0]==4])
    im.save(fn, quality=quality)

In [None]:
#export
@patch
def to_uint16(x:(Tensor,DcmDataset), bins=None):
    "Histogram-scale with `bins` and return the result as a `uint16` numpy array"
    d = x.hist_scaled(bins).clamp(0,1) * 2**16
    # A scaled value of exactly 1 becomes 65536, which does not fit in uint16: cap at the max
    d = d.clamp(max=2**16-1)
    # `.cpu()` so tensors living on a GPU can be converted to numpy
    return d.cpu().numpy().astype(np.uint16)

In [None]:
#export
@patch
def save_tif16(x:(Tensor,DcmDataset), path, bins=None, compress=True):
    "Save `to_uint16(bins)` as a TIFF at `path` (suffix forced to `.tif`), deflate-compressed when `compress`"
    fn = Path(path).with_suffix('.tif')
    compression = 'tiff_deflate' if compress else None
    im = Image.fromarray(x.to_uint16(bins))
    im.save(str(fn), compression=compression)

In [None]:
# Round-trip check: save the sample as JPEG and 16-bit TIFF into a temporary directory,
# then reopen and display both files
_,axs=subplots(1,2)
with tempfile.TemporaryDirectory() as f:
    f = Path(f)
    dcm.save_jpg(f/'test.jpg', [dicom_windows.brain,dicom_windows.subdural])
    show_image(Image.open(f/'test.jpg'), ax=axs[0])
    dcm.save_tif16(f/'test.tif')
    show_image(Image.open(str(f/'test.tif')), ax=axs[1]);

In [None]:
#export
@patch
def set_pixels(self:DcmDataset, px):
    "Store array `px` as this dataset's `PixelData` and update `Rows`/`Columns` from its shape"
    self.PixelData = px.tobytes()
    n_rows,n_cols = px.shape
    self.Rows,self.Columns = n_rows,n_cols
# Rebuild `pixel_array` as a property that keeps the existing getter and uses `set_pixels` as its setter
DcmDataset.pixel_array = property(DcmDataset.pixel_array.fget, set_pixels)

In [None]:
#export
@patch
def zoom(self:DcmDataset, ratio):
    "Resample the pixel data by `ratio` with `ndimage.zoom` and assign the result back to `pixel_array`"
    with warnings.catch_warnings():
        # `UserWarning`s raised while zooming or assigning are silenced
        warnings.simplefilter("ignore", UserWarning)
        zoomed = ndimage.zoom(self.pixel_array, ratio)
        self.pixel_array = zoomed

In [None]:
#export
@patch
def zoom_to(self:DcmDataset, sz):
    "Zoom the pixel data to `sz`, given as `(rows, cols)` or a single number used for both"
    rows,cols = sz if isinstance(sz,(list,tuple)) else (sz,sz)
    ratio = (rows/self.Rows,cols/self.Columns)
    self.zoom(ratio)

In [None]:
#export
@patch_property
def shape(self:DcmDataset):
    "The dataset's `(Rows, Columns)`"
    return self.Rows,self.Columns

In [None]:
# Zoom a fresh copy of the sample to 90x90 and check the reported shape
dcm2 = TEST_DCM.dcmread()
dcm2.zoom_to(90)
test_eq(dcm2.shape, (90,90))

In [None]:
# Zoom a fresh copy of the sample by a ratio of 0.25 and display it
dcm2 = TEST_DCM.dcmread()
dcm2.zoom(0.25)
dcm2.show()

In [None]:
#export
def _cast_dicom_special(x):
    cls = type(x)
    if not cls.__module__.startswith('pydicom'): return x
    if cls.__base__ == object: return x
    return cls.__base__(x)

def _split_elem(res,k,v):
    "If `v` is a `DcmMultiValue`, flag it in `res` as `Multi{k}` and store its items under `k`, `k1`, `k2`, ..."
    if isinstance(v,DcmMultiValue):
        res[f'Multi{k}'] = 1
        for i,o in enumerate(v):
            # The first item keeps the bare key; later ones get their index appended
            suffix = "" if i==0 else i
            res[f'{k}{suffix}'] = o

In [None]:
#export
@patch
def as_dict(self:DcmDataset, px_summ=True, window=dicom_windows.brain):
    "Dict of the dataset's elements keyed by keyword (pixel data excluded), plus `fname` and, if `px_summ`, pixel statistics"
    pixel_tag = (0x7fe0,0x0010)
    elems = [self[t] for t in self.keys() if t != pixel_tag]
    pairs = [(e.keyword,e.value) for e in elems]
    res = dict(pairs)
    res['fname'] = self.filename
    # Multi-valued elements are additionally split into one entry per item
    for kw,val in pairs: _split_elem(res,kw,val)
    # NOTE: this early return skips the `_cast_dicom_special` pass at the end
    if not px_summ: return res
    stat_names = 'min','max','mean','std'
    try:
        arr = self.pixel_array
        for s in stat_names: res['img_'+s] = getattr(arr,s)()
        res['img_pct_window'] = self.pct_in_window(*window)
    except Exception as e:
        # On any failure the statistics are set to 0 and the problem is printed
        for s in stat_names: res['img_'+s] = 0
        print(res,e)
    for kw in res: res[kw] = _cast_dicom_special(res[kw])
    return res

In [None]:
#export
def _dcm2dict(fn, **kwargs): return fn.dcmread().as_dict(**kwargs)

In [None]:
#export
@delegates(parallel)
def _from_dicoms(cls, fns, n_workers=0, **kwargs):
    "Build a `DataFrame` with one row per file in `fns`, each row produced by `_dcm2dict` run through `parallel`"
    rows = parallel(_dcm2dict, fns, n_workers=n_workers, **kwargs)
    return pd.DataFrame(rows)
# Exposed as the classmethod `pd.DataFrame.from_dicoms`
pd.DataFrame.from_dicoms = classmethod(_from_dicoms)

## Export -

In [None]:
#hide
# Run nbdev's notebook-to-module conversion; the "Converted ..." lines below are its output
from nbdev.export import notebook2script
notebook2script()

Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
Converted 20a_distributed.ipynb.
Converted 21_vision.learner.ipynb.
Converted 22_tutorial.ima