Skip to content
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
320 lines (277 sloc) 9.75 KB
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/60_medical.imaging.ipynb (unless otherwise specified).
# Names exported by `from <this module> import *`.
__all__ = [
    'DcmDataset', 'DcmTag', 'DcmMultiValue',
    'dcmread', 'pixels', 'scaled_px', 'array_freqhist_bins',
    'dicom_windows', 'TensorCTScan', 'PILCTScan', 'show',
    'uniform_blur2d', 'gauss_blur2d', 'mask2bbox', 'crop_resize', 'shape',
]
# Cell
from ..basics import *
from import *
import pydicom,kornia,skimage
from pydicom.dataset import Dataset as DcmDataset
from pydicom.tag import BaseTag as DcmTag
from pydicom.multival import MultiValue as DcmMultiValue
from PIL import Image
import cv2
except: pass
# Cell
def dcmread(self:Path):
    "Load the DICOM file located at path `self` using `pydicom.dcmread`"
    fname = str(self)
    return pydicom.dcmread(fname)
# Cell
def pixels(self:DcmDataset):
    "The dataset's `pixel_array`, cast to float32 and wrapped as a tensor"
    arr = self.pixel_array.astype(np.float32)
    return tensor(arr)
# Cell
def scaled_px(self:DcmDataset):
    "`pixels` after applying the DICOM linear rescale: `pixels * RescaleSlope + RescaleIntercept`"
    raw = self.pixels
    slope = self.RescaleSlope
    intercept = self.RescaleIntercept
    return raw*slope + intercept
# Cell
def array_freqhist_bins(self, n_bins=100):
    "A numpy based function to split the range of pixel values into groups, such that each group has around the same number of pixels"
    # `self` is an array of pixel values of any shape; it is flattened and sorted.
    imsd = np.sort(self.flatten())
    # Quantile positions: 0.1%, the centre of each of the `n_bins` equal-probability bins, and 99.9%.
    t = np.array([0.001])
    t = np.append(t, np.arange(n_bins)/n_bins+(1/2/n_bins))
    t = np.append(t, 0.999)
    # Round each position to the nearest index. Clip because for arrays of <= 500 values the
    # top position rounds to len(imsd), one past the end.
    t = np.clip((len(imsd)*t+0.5).astype(int), 0, len(imsd)-1)
    # Sorted, de-duplicated pixel values at those quantiles.
    return np.unique(imsd[t])
# Cell
def freqhist_bins(self:Tensor, n_bins=100):
    "A function to split the range of pixel values into groups, such that each group has around the same number of pixels"
    # reshape (not view) so non-contiguous inputs work too
    imsd = self.reshape(-1).sort()[0]
    # Quantile positions: 0.1%, the centre of each of the `n_bins` equal-probability bins, and 99.9%
    t = torch.cat([torch.tensor([0.001]),
                   torch.arange(n_bins).float()/n_bins+(1/2/n_bins),
                   torch.tensor([0.999])])
    # Every position is < 1, so floor(len*t) is always a valid index
    t = (len(imsd)*t).long().to(imsd.device)
    # Sorted, de-duplicated values at those quantiles
    return imsd[t].unique()
# Cell
def hist_scaled_pt(self:Tensor, brks=None):
    "PyTorch-only `hist_scaled`: map `self` into [0,1] against breakpoints `brks` (default `freqhist_bins()`), staying on `self.device`"
    # Pytorch-only version - switch to this if/when interp_1d can be optimized
    if brks is None: brks = self.freqhist_bins()
    # accept Tensor or ndarray breakpoints and put them on the same device as the data
    brks = torch.as_tensor(brks, device=self.device)
    # breakpoint i maps to i/(len(brks)-1)
    ys = torch.linspace(0., 1., len(brks)).to(self.device)
    # NOTE(review): `interp_1d` is not defined in this file; presumably a piecewise-linear interpolation brks -> ys
    return self.flatten().interp_1d(brks, ys).reshape(self.shape).clamp(0.,1.)
# Cell
def hist_scaled(self:Tensor, brks=None):
    "Map `self` into [0,1] by piecewise-linear interpolation against the breakpoints `brks` (default `freqhist_bins()`)"
    # `brks` may arrive as an ndarray (e.g. from `show`); normalise to a Tensor so both code paths accept it
    if brks is not None and not isinstance(brks, Tensor): brks = torch.as_tensor(brks)
    if self.device.type=='cuda': return self.hist_scaled_pt(brks)
    if brks is None: brks = self.freqhist_bins()
    # breakpoint i maps to i/(len(brks)-1): with `freqhist_bins` breakpoints, equal-population bins become equal-width
    ys = np.linspace(0., 1., len(brks))
    # detach/cpu so grad-tracking or non-CPU tensors can be converted to numpy
    x = self.detach().cpu().numpy().flatten()
    x = np.interp(x, brks.cpu().numpy(), ys)
    # np.interp returns float64; return float32
    return torch.from_numpy(x).float().reshape(self.shape).clamp(0.,1.)
# Cell
def hist_scaled(self:DcmDataset, brks=None, min_px=None, max_px=None):
    "Histogram-scale this dataset's `scaled_px`, first clipping it (in place) to `min_px` / `max_px` when those are given"
    img = self.scaled_px
    if min_px is not None:
        img.masked_fill_(img < min_px, min_px)
    if max_px is not None:
        img.masked_fill_(img > max_px, max_px)
    return img.hist_scaled(brks=brks)
# Cell
def windowed(self:Tensor, w, l):
    "Clip a copy of `self` to the window of width `w` centred on level `l` (half-width `w//2`), then rescale that window to [0,1]"
    lo, hi = l - w//2, l + w//2
    out = self.clone()
    out.masked_fill_(out < lo, lo)
    out.masked_fill_(out > hi, hi)
    return (out - lo) / (hi - lo)
# Cell
def windowed(self:DcmDataset, w, l):
    "Apply `windowed(w, l)` to this dataset's `scaled_px`"
    px = self.scaled_px
    return px.windowed(w, l)
# Cell
# Common CT window presets as (width, level) pairs, for `windowed(*win)` / `show(scale=win)`.
# NOTE(review): the source reference on the original "# From" line was lost; these values are
# presumably Radiopaedia's "Windowing (CT)" presets — TODO confirm.
dicom_windows = types.SimpleNamespace(
    brain=(80,40),
    subdural=(254,100),
    stroke=(8,32),
    brain_bone=(2800,600),
    brain_soft=(375,40),
    lungs=(1500,-600),
    mediastinum=(350,50),
    abdomen_soft=(400,50),
    liver=(150,30),
    spine_soft=(250,50),
    spine_bone=(1800,400)
)
# Cell
class TensorCTScan(TensorImageBW):
    "A `TensorImageBW` for CT scans whose `_show_args` select the 'bone' colormap"
    _show_args = {'cmap':'bone'}
# Cell
class PILCTScan(PILBase):
    "A `PILBase` image for CT scans: converts to `TensorCTScan` and reuses its `_show_args`"
    _open_args = {}
    _tensor_cls = TensorCTScan
    _show_args = TensorCTScan._show_args
# Cell
def show(self:DcmDataset, scale=True, cmap='bone', min_px=-1100, max_px=None, **kwargs):
    """Display `scaled_px` with `show_image`, mapping intensities according to `scale`:

    - a `(w, l)` tuple: `windowed(w, l)`
    - an ndarray/Tensor: `hist_scaled` using those values as breakpoints
    - any other truthy value: `hist_scaled` with breakpoints computed from the data
    - falsy: the raw `scaled_px`

    `min_px`/`max_px` clip pixels before histogram scaling. `cmap` (default 'bone', the same
    colormap `TensorCTScan` uses) and `**kwargs` are passed to `show_image`.
    """
    # Test the array case before plain truthiness: bool() of a multi-element ndarray/Tensor raises.
    if isinstance(scale, tuple): px = self.windowed(*scale)
    elif isinstance(scale, (ndarray,Tensor)): px = self.hist_scaled(min_px=min_px, max_px=max_px, brks=scale)
    elif scale: px = self.hist_scaled(min_px=min_px, max_px=max_px)
    else: px = self.scaled_px
    show_image(px, cmap=cmap, **kwargs)
# Cell
def pct_in_window(dcm:DcmDataset, w, l):
    "Fraction (0-1) of `scaled_px` values lying strictly inside the window of width `w` and level `l`"
    px = dcm.scaled_px
    lo, hi = l - w//2, l + w//2
    inside = (px > lo) & (px < hi)
    return inside.float().mean().item()
# Cell
def uniform_blur2d(x,s):
    "Box-blur the last two dims of `x` with an `s`x`s` mean filter (zero padding at the borders); same spatial size as `x`, returned on the CPU"
    n_rows, n_cols = x.shape[-2], x.shape[-1]
    kernel = x.new_ones(1,1,1,s)/s
    # conv2d needs a 4-D (N,C,H,W) input: prepend singleton dims
    while x.dim() < 4: x = x.unsqueeze(0)
    # Factor the 2d box filter into a horizontal then a vertical 1d conv. Pad only the axis each
    # kernel spans; padding both axes would grow the output by 2*(s//2) and shift it off the input.
    p = s//2
    r = F.conv2d(x, kernel, padding=(0,p))
    r = F.conv2d(r, kernel.transpose(-1,-2), padding=(p,0))
    # For even `s` each conv yields one extra row/col; crop back to the input size
    r = r[..., :n_rows, :n_cols].cpu()[:,0]
    return r.squeeze()
# Cell
def gauss_blur2d(x,s):
    "Gaussian-blur `x` with kornia, using sigma `s` on both axes and 'replicate' borders"
    # odd kernel size of about s/2
    ksize = int(s/4)*2 + 1
    x4 = x
    # kornia expects (N,C,H,W): prepend singleton dims
    while x4.dim() < 4: x4 = x4.unsqueeze(0)
    out = kornia.filters.gaussian_blur2d(x4, (ksize,ksize), (s,s), 'replicate')
    return out.squeeze()
# Cell
def mask_from_blur(x:Tensor, window, sigma=0.3, thresh=0.05, remove_max=True):
    "Boolean mask of where the windowed, Gaussian-blurred image exceeds `thresh`; blur sigma is `sigma` times the image width"
    win_px = x.windowed(*window)
    if remove_max:
        # zero pixels saturated at the top of the window (windowed maps them to 1)
        win_px[win_px == 1] = 0
    blurred = gauss_blur2d(win_px, s=sigma*x.shape[-1])
    return blurred > thresh
# Cell
def mask_from_blur(x:DcmDataset, window, sigma=0.3, thresh=0.05, remove_max=True):
    "Run the Tensor `mask_from_blur` on this dataset's `scaled_px` after moving it with `to_device`"
    px = to_device(x.scaled_px)
    return px.mask_from_blur(window, sigma, thresh, remove_max=remove_max)
# Cell
def _px_bounds(x, dim):
    "Per batch item of `x`, the `[min, max]` index of positions that are nonzero after summing `x` over `dim`"
    # Assumes `x` is 3-D with the batch first, so `s` is 2-D: (batch, remaining axis)
    s = x.sum(dim)
    c = s.nonzero().cpu()    # rows of (batch_idx, position)
    d = {}
    if len(c):
        # nonzero() is row-major, so each item's positions are contiguous in `c`
        idxs,counts = torch.unique(c[:,0], return_counts=True)
        d = {k.item():v for k,v in zip(idxs, c[:,1].split(counts.tolist()))}
    # Items with nothing nonzero span the whole remaining axis, whose length is s.shape[-1]
    # (not x.shape[-1], which is the wrong axis when bounding rows of a non-square mask)
    default_u = torch.tensor([0, s.shape[-1]-1])
    b = [d.get(o, default_u) for o in range(x.shape[0])]
    return torch.stack([torch.stack([o.min(), o.max()]) for o in b])
# Cell
def mask2bbox(mask):
    """Bounding boxes of the nonzero region of `mask`.

    Result has shape (2, 2, batch): index 0 holds (row_min, col_min), index 1 holds
    (row_max, col_max). For a 2-D `mask` the trailing batch axis is dropped.
    """
    squeeze_batch = mask.dim() == 2
    if squeeze_batch: mask = mask[None]
    rows = _px_bounds(mask, -1).t()
    cols = _px_bounds(mask, -2).t()
    bbox = torch.stack([rows, cols], dim=1).to(mask.device)
    if squeeze_batch: return bbox[..., 0]
    return bbox
# Cell
def _bbs2sizes(crops, init_sz, use_square=True):
    "Turn boxes `crops` into (lower corners, sizes) as fractions of `init_sz`, shifting boxes that overrun the image back inside"
    # swap the (row, col) axis to (x, y) order; flip copies, so `crops` is not modified
    xy = crops.flip(1)
    lo, hi = xy[0], xy[1]
    sizes = hi - lo
    if use_square:
        # largest side per item, used for both axes
        sizes = sizes.max(0).values[None].repeat((2,1))
    spill = (lo + sizes) > init_sz
    lo[spill] = init_sz - sizes[spill]
    denom = float(init_sz)
    return lo/denom, sizes/denom
# Cell
def crop_resize(x, crops, new_sz):
    "Crop each image of the batch `x` to its box in `crops` (presumably as produced by `mask2bbox`) and resample it to `new_sz`"
    # NB assumes square inputs. Not tested for non-square anythings!
    n_items = x.shape[0]
    lows, szs = _bbs2sizes(crops, x.shape[-1])
    out_sz = new_sz if isinstance(new_sz, (list, tuple)) else (new_sz, new_sz)
    theta = tensor([[1.,0,0],[0,1,0]])[None].repeat((n_items,1,1)).to(x.device)
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=UserWarning)
        # identity sampling grid, shifted from [-1,1] to [0,2]
        base = F.affine_grid(theta, (n_items,1,*out_sz)) + 1.
    # per-item scale and offset: (n,2) -> (n,1,1,2) to broadcast over the (n,H,W,2) grid
    grid = base*szs.t()[:,None,None] + (lows.t()*2.)[:,None,None]
    # shift back to [-1,1] sampling coordinates
    return F.grid_sample(x.unsqueeze(1), grid-1)
# Cell
def to_nchan(x:Tensor, wins, bins=None):
    "Stack one channel per `(w,l)` window in `wins`, plus a histogram-scaled channel unless `bins` is the int 0"
    chans = [x.windowed(*win) for win in wins]
    skip_hist = isinstance(bins, int) and bins == 0
    if not skip_hist:
        chans.append(x.hist_scaled(bins).clamp(0,1))
    # channels go on dim 1 when `x` has 3 dims (presumably a batch), otherwise on dim 0
    stack_dim = 1 if x.dim() == 3 else 0
    return TensorCTScan(torch.stack(chans, dim=stack_dim))
# Cell
def to_nchan(x:DcmDataset, wins, bins=None):
    "Multi-channel image built by `to_nchan(wins, bins)` on this dataset's `scaled_px`"
    px = x.scaled_px
    return px.to_nchan(wins, bins)
# Cell
def to_3chan(x:Tensor, win1, win2, bins=None):
    "Delegate to `to_nchan` with the two windows `win1` and `win2`"
    wins = [win1, win2]
    return x.to_nchan(wins, bins=bins)
# Cell
def to_3chan(x:DcmDataset, win1, win2, bins=None):
    "Apply `to_3chan(win1, win2, bins)` to this dataset's `scaled_px`"
    px = x.scaled_px
    return px.to_3chan(win1, win2, bins)
# Cell
def save_jpg(x:(Tensor,DcmDataset), path, wins, bins=None, quality=90):
    "Save `x.to_nchan(wins, bins)` as an 8-bit JPEG at `path` (suffix forced to .jpg); 4 channels are written as CMYK, otherwise RGB"
    fn = Path(path).with_suffix('.jpg')
    x = (x.to_nchan(wins, bins)*255).byte()
    # PIL needs a CPU array laid out (H,W,C)
    arr = x.permute(1,2,0).cpu().numpy()
    im = Image.fromarray(arr, mode=['RGB','CMYK'][x.shape[0]==4])
    im.save(str(fn), quality=quality)
# Cell
def to_uint16(x:(Tensor,DcmDataset), bins=None):
    "Histogram-scale `x` and quantise it to the full uint16 range [0, 65535], returned as a numpy array"
    # Scale by 2**16-1, not 2**16: exactly 1.0 would otherwise become 65536, which overflows uint16
    d = x.hist_scaled(bins).clamp(0,1) * (2**16-1)
    # .cpu() so CUDA results can be converted to numpy
    return d.cpu().numpy().astype(np.uint16)
# Cell
def save_tif16(x:(Tensor,DcmDataset), path, bins=None, compress=True):
    "Write `x.to_uint16(bins)` to `path` as a 16-bit TIFF (suffix forced to .tif), deflate-compressed when `compress`"
    fn = Path(path).with_suffix('.tif')
    arr = x.to_uint16(bins)
    compression = 'tiff_deflate' if compress else None
    Image.fromarray(arr).save(str(fn), compression=compression)
# Cell
def set_pixels(self:DcmDataset, px):
    "Setter for `pixel_array`: store the raw bytes of the 2-D array `px` in `PixelData` and update `Rows`/`Columns`"
    # NOTE(review): other pixel-description tags (e.g. BitsAllocated) are not touched; presumably px's dtype must already match them
    self.PixelData = px.tobytes()
    n_rows, n_cols = px.shape
    self.Rows = n_rows
    self.Columns = n_cols

# Make `pixel_array` assignable: keep pydicom's getter, add `set_pixels` as the setter
DcmDataset.pixel_array = property(DcmDataset.pixel_array.fget, set_pixels)
# Cell
def zoom(self:DcmDataset, ratio):
    "Resample `pixel_array` in place with `scipy.ndimage.zoom` by `ratio` (a scalar or one factor per axis)"
    # `ndimage` is not imported in this file's visible imports; import it where it is used
    from scipy import ndimage
    with warnings.catch_warnings():
        # silence UserWarnings raised while zooming
        warnings.simplefilter("ignore", UserWarning)
        self.pixel_array = ndimage.zoom(self.pixel_array, ratio)
# Cell
def zoom_to(self:DcmDataset, sz):
    "Zoom `pixel_array` to `sz` pixels: an int for a square target, or a `(rows, cols)` pair"
    if not isinstance(sz,(list,tuple)): sz=(sz,sz)
    rows,cols = sz
    # per-axis ratio from the current Rows/Columns to the target size
    self.zoom((rows/self.Rows, cols/self.Columns))
# Cell
def shape(self:DcmDataset):
    "Image size as the tuple `(Rows, Columns)`"
    return (self.Rows, self.Columns)
# Cell
def _cast_dicom_special(x):
    "If `x`'s type comes from a `pydicom` module and has a base class other than `object`, convert `x` to that base; otherwise return `x` unchanged"
    cls = type(x)
    from_pydicom = cls.__module__.startswith('pydicom')
    if from_pydicom and cls.__base__ != object:
        return cls.__base__(x)
    return x
def _split_elem(res,k,v):
    "For a pydicom MultiValue `v`: set `res['Multi'+k] = 1` and store its items under `k`, `k+'1'`, `k+'2'`, ..."
    if isinstance(v, DcmMultiValue):
        res[f'Multi{k}'] = 1
        for i, item in enumerate(v):
            key = f'{k}' if i == 0 else f'{k}{i}'
            res[key] = item
# Cell
def as_dict(self:DcmDataset, px_summ=True, window=dicom_windows.brain):
    "Flatten all elements except Pixel Data into a dict keyed by keyword, plus `fname` and (if `px_summ`) pixel statistics"
    pxdata = (0x7fe0,0x0010)   # tag of the Pixel Data element
    vals = [self[o] for o in self.keys() if o != pxdata]
    its = [(v.keyword,v.value) for v in vals]
    res = dict(its)
    res['fname'] = self.filename
    # expand multi-valued elements into one entry per item
    for k,v in its: _split_elem(res,k,v)
    if not px_summ: return res
    stats = 'min','max','mean','std'
    try:
        pxs = self.pixel_array
        for f in stats: res['img_'+f] = getattr(pxs,f)()
        res['img_pct_window'] = self.pct_in_window(*window)
    except Exception as e:
        # Best effort: one unreadable image shouldn't abort a whole metadata scan, but report it
        warnings.warn(f"as_dict: pixel summary failed for {res['fname']}: {e!r}")
        for f in stats: res['img_'+f] = 0
    # convert pydicom-specific value types to their plain base types
    for k in res: res[k] = _cast_dicom_special(res[k])
    return res
# Cell
def _dcm2dict(fn, **kwargs):
    "Read the DICOM file `fn` and return its `as_dict(**kwargs)` summary"
    ds = fn.dcmread()
    return ds.as_dict(**kwargs)
# Cell
def _from_dicoms(cls, fns, n_workers=0, **kwargs):
    "DataFrame with one row of `_dcm2dict` metadata per file in `fns`, computed via `parallel` with `n_workers` workers"
    # NOTE(review): `parallel` is not defined in this file; presumably it forwards **kwargs to `_dcm2dict`
    rows = parallel(_dcm2dict, fns, n_workers=n_workers, **kwargs)
    return pd.DataFrame(rows)

# Expose as an alternate constructor: pd.DataFrame.from_dicoms(fns, ...)
pd.DataFrame.from_dicoms = classmethod(_from_dicoms)
You can’t perform that action at this time.