In [1]:
#default_exp train3d

In [2]:
#export
from rsna_retro.imports import *
from rsna_retro.metadata import *
from rsna_retro.preprocess import *
from rsna_retro.train import *

Loading imports


In [3]:
# Make CUDA device index 2 the current device for this session.
# NOTE(review): the index is hard-coded; adjust it for the machine you run on.
torch.cuda.set_device(2)

In [4]:
# #export
# @IntToFloatTensor
# def encodes(self, o:TensorCTScan): return o.float().div_(self.div)

In [5]:
#export
class OpenCTs:
    "Callable that opens one slice, or a whole sequence of slices stacked into a `TensorCTScan`."
    def __init__(self, path, open_fn=get_pil_fn, tfms=None):
        # `open_fn(path)` yields the per-item loader used for every slice
        self.fn = open_fn(path)
        item_tfms = [] if tfms is None else tfms
        # `ToTensor` is always appended after the caller's item transforms
        self.tfms = Pipeline(item_tfms+[ToTensor])

    def __call__(self, item):
        load = lambda name: self.tfms(self.fn(name))
        # A single identifier (str/Path) yields a single transformed item
        if isinstance(item, (str, Path)): return load(item)
        # Any other input is iterated and the results are stacked along a new first dim
        slices = [load(name) for name in item]
        return TensorCTScan(torch.stack(slices))

In [6]:
# df_slice_count = Meta.df_comb.groupby(['SeriesInstanceUID']).agg(['count'])
# max(df_slice_count.PatientID.values), min(df_slice_count.PatientID.values)


In [7]:
#export
# Fixed first-dimension length that `pad_collate` pads to by default.
max_seq = 60
def pad_batch(x, pad_to=None, value=0):
    """Pad (or crop) `x` along its first dimension so that it has length `pad_to`.

    x: a tensor, or a tuple / dict of tensors (containers are handled recursively).
    pad_to: target size of dim 0; `None` means "leave `x` untouched".
    value: constant used to fill the added rows.

    Returns an object of the same type as `x`. When `x` is longer than `pad_to`
    the padding amount is negative, which makes `F.pad` crop the trailing rows.
    """
    # Previously `pad_to=None` (the default) crashed on `None - int`.
    if pad_to is None: return x
    if isinstance(x, tuple): return tuple(pad_batch(s, pad_to, value) for s in x)
    if isinstance(x, dict): return {k:pad_batch(item, pad_to, value) for k,item in x.items()}
    # F.pad lists (before, after) pairs starting from the LAST dim, so the final
    # entry is the "after" padding of dim 0.
    pad = [0]*(len(x.shape)*2)
    pad[-1] = pad_to-x.shape[0]
    # Re-wrap so tensor subclasses keep their type
    return type(x)(F.pad(x, pad=pad, value=value))

def pad_collate(items, values=[0,-1], pad_to=max_seq):
    """Pad every element of every sample in `items` to `pad_to` along dim 0.

    items: list of samples (rows); each row is a sequence such as `(x, y)`.
    values: fill value per row element, matched positionally via `zip`
            (by default 0 for the first element and -1 for the second).
    pad_to: target length of dim 0; defaults to `max_seq`.

    Returns a list of tuples, one per sample, with every element padded.
    """
    # The `pad_to` argument used to be overwritten by a hard-coded 60 inside the
    # body, so callers could not change it. The default (`max_seq`) is unchanged.
    return [tuple(pad_batch(x, pad_to, v) for x,v in zip(row, values)) for row in items]

In [8]:
# # Normalization stats for position
# pos = Meta.df_comb.ImagePositionPatient2.values
# pos.min(), pos.max(), pos.mean(), pos.std()
# # (pos - pos.min())/pos.max()

In [9]:
#export
# Hard-coded (min, max, mean, std) statistics for the slice position, saved so the
# test set can be normalized the same way as the training data.
# `pos_mean` / `pos_std` are the values `TfmSOP.x` uses to normalize `ImagePositionPatient2`.
pos_min, pos_max, pos_mean, pos_std = (-998.400024, 1794.01276, 167.08153131830622, 244.90964319136026)

In [10]:
#export
class TfmSOP:
    """Build the inputs (`x`) and targets (`y`) for one series from `df`.

    df: DataFrame indexed so that `df.loc[sid]` selects all rows of one series.
    open_fn: callable turning an array of SOPInstanceUIDs into the model input.
    test: when True, `y` returns all-zero placeholder targets.
    meta: when True, `x` also returns the normalized slice positions.
    """
    def __init__(self,df,open_fn,test=False,meta=False):
        store_attr(self, 'df,open_fn,test,meta')

    def x(self, sid):
        "Return the opened slices of series `sid`, plus normalized positions when `meta` is set."
        sids = self.df.SOPInstanceUID[sid].values
        imgs = self.open_fn(sids)
        if not self.meta: return imgs
        # One position per slice, as a column vector
        pos = self.df.ImagePositionPatient2[sid].values.reshape(-1, 1)
        pos = torch.from_numpy(pos).float()
        pos_norm = (pos - pos_mean)/pos_std
        return imgs, pos_norm

    def y(self, sid):
        "Return one row of 6 targets per slice of series `sid` (zeros in test mode)."
        # The unused `SOPInstanceUID` lookup that used to run here on every item was removed.
        if self.test: return torch.zeros((self.df.loc[sid].shape[0], 6)).float()
        vals = self.df.loc[sid,htypes].values
        return TensorMultiCategory(tensor(vals)).float()

In [11]:
#export
def get_3d_dsets(df, open_fn, grps=Meta.grps_stg1, cv_idx=0, tfms=None, column='SeriesInstanceUID', test=False, meta=False):
    """Create `Datasets` whose items are whole series (all rows sharing `column`).

    df: slice-level DataFrame; it is re-indexed by `column` and sorted by
        (`column`, "ImagePositionPatient2").
    open_fn: passed to `TfmSOP` to open the slices of one series.
    grps / cv_idx: groups of ids; `grps[cv_idx]` is the validation split and
        `group_cv(cv_idx, grps)` the training split (ids absent from `df` are skipped).
    tfms: extra item transforms appended after `TfmSOP.x`.
    test: when True both splits cover every series.
    meta: forwarded to `TfmSOP`.
    """
    sort_keys = [column, "ImagePositionPatient2"]
    df_series = df.reset_index().set_index(column).sort_values(sort_keys)
    sids = df_series.index.unique()
    sid2idx = {sid: i for i, sid in enumerate(sids)}

    # multi index is 10x faster (original author's note)
    df_series.index = pd.MultiIndex.from_tuples(df_series.index.str.split('|').tolist())
    tfm = TfmSOP(df_series, open_fn, test=test, meta=meta)

    if test:
        splits = [L.range(sids), L.range(sids)]
    else:
        train_idxs = [sid2idx[s] for s in group_cv(cv_idx,grps) if s in sid2idx]
        valid_idxs = [sid2idx[s] for s in grps[cv_idx] if s in sid2idx]
        splits = (train_idxs, valid_idxs)
    x_tfms, y_tfms = [tfm.x]+L(tfms), [tfm.y]
    return Datasets(sids, [x_tfms, y_tfms], splits=splits)

In [12]:
# #export
# _collate_types = (ndarray, Tensor, typing.Mapping, str)
# def test_collate(t):
#     b = t[0]
#     print('kdfjslfjlsdfk', [type(z) for z in t])
    
# #     print(t)
# #     pdb.set_trace()
#     if isinstance(b, tuple): 
# #         print('sdfds', type(t), type(b))
#         return [test_collate(s) for s in zip(*t)]
#     if isinstance(b, _collate_types): 
#         print('ctype', type(t), type(b))
# #         print('dsfjsdf', [type(x)(default_collate(x)) for x in t])
#         return type(b)(default_collate(t))
# #     print('else', type(t))
# #     return type(t)(default_collate(t))
        
# #     return (default_collate(t) if isinstance(b, _collate_types)
# #             else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
# #             else default_collate(t))

In [13]:
# dsets = get_3d_dsets(Meta.df_comb1, open_fn=OpenCTs(path_jpg), 
#                      grps=Meta.grps_stg1, test=True, meta=True)
# test_collate([dsets[0], dsets[0]])

In [14]:
#export
def get_dls(dsets, bs, batch_tfms, num_workers=8):
    """Wrap `dsets.train` / `dsets.valid` in `TfmdDL`s and return them as `DataLoaders`.

    Samples are padded with `pad_collate` before batching unless `bs == 1`.
    `batch_tfms` is used as `after_batch`; only the training loader is shuffled.
    """
    before_batch = [] if bs == 1 else [pad_collate]
    shared = dict(bs=bs, before_batch=before_batch, after_batch=batch_tfms, num_workers=num_workers)
    train_dl = TfmdDL(dsets.train, shuffle=True, **shared)
    valid_dl = TfmdDL(dsets.valid, **shared)
    dls = DataLoaders(train_dl, valid_dl)
    dls.device = default_device()
    dls.c = 6  # number of target columns
    return dls

In [15]:
#export
def get_3d_dls(df, path=path_jpg256, bs=1, num_workers=8, tfms=None, **kwargs):
    """Series-level `DataLoaders` over the images under `path`.

    Batch transforms are `IntToFloatTensor`, normalization with `mean`/`std`,
    then any extra `tfms`. `kwargs` are forwarded to `get_3d_dsets`.
    """
    opener = OpenCTs(path)
    dsets = get_3d_dsets(df, open_fn=opener, **kwargs)

    nrm = Normalize.from_stats(mean,std)
    extra_tfms = L(tfms)
    batch_tfms = L(IntToFloatTensor(), nrm, *extra_tfms)

    return get_dls(dsets, bs, batch_tfms, num_workers)


In [16]:
# Smoke test: build the loaders with batch size 2 and time fetching one validation batch.
dls = get_3d_dls(Meta.df_comb, path=path_jpg256, bs=2)
%time xb,yb = dls.valid.one_batch();
# # Wall time: 4.69 s

CPU times: user 166 ms, sys: 48.7 ms, total: 215 ms
Wall time: 215 ms


In [17]:
#export
class Tfm5D(Transform):
    "Collapse all leading dims of a `TensorImage` into one, keeping only the last three dims."
    order = 0
    def encodes(self, o:TensorImage):
        img_dims = o.shape[-3:]
        return o.view(-1, *img_dims)
class Tfm6D(Transform):
    "View a `TensorImage` as (-1, `max_seq`, *last three dims), restoring the sequence dim."
    order = 200
    def encodes(self, o:TensorImage):
        img_dims = o.shape[-3:]
        return o.view(-1, max_seq, *img_dims)

In [18]:
#export
class PipeMeta(Pipeline):
    "Pipeline that, for tuple input, transforms only the first element and passes the second through."
    def __call__(self, o):
        # Non-tuple input goes through the regular pipeline
        if not isinstance(o, tuple): return super().__call__(o)
        scan, extra = o[0], o[1]
        # Only the scan (re-wrapped as `TensorCTScan`) is transformed
        return super().__call__(TensorCTScan(scan)), extra

In [19]:
class MetaTfm(Transform):
    "Transform that runs an inner `Pipeline` of `tfms` over each element of its input."
    def __init__(self, tfms):
        super().__init__()
        self.pipe = Pipeline(tfms)

    def encodes(self, o):
        # Returns a list with one pipeline result per element of `o`
        return list(map(self.pipe, o))

In [20]:
#export
def get_3d_dls_aug(df, sz=None, path=path_jpg256, bs=1, num_workers=8, grps=Meta.grps_stg1, test=False, meta=False):
    """Series-level `DataLoaders` with batch augmentation.

    Batch transforms: `Tfm5D`, `IntToFloatTensor`, `Tfm6D`, `aug_transforms()`,
    an optional `RandomResizedCropGPU(sz, ...)` when `sz` is given, then normalization.
    With `meta=True` the transforms are wrapped in a single `PipeMeta`.
    """
    dsets = get_3d_dsets(df, open_fn=OpenCTs(path), grps=grps, test=test, meta=meta)

    reshape_tfms = [Tfm5D(), IntToFloatTensor(), Tfm6D()]
    tfms = reshape_tfms+aug_transforms()
    if sz is not None:
        crop = RandomResizedCropGPU(sz, min_scale=0.7, ratio=(1.,1.), valid_scale=0.9)
        tfms = tfms+[crop]

    nrm = Normalize.from_stats(mean,std)
    batch_tfms = tfms+L(nrm)

    after_batch = [PipeMeta(batch_tfms)] if meta else batch_tfms
    return get_dls(dsets, bs, after_batch, num_workers)


In [39]:
# NOTE(review): leftover IPython source introspection; consider removing before sharing.
??Normalize

In [58]:
# Check the shape of the first value returned by `broadcast_vec(0,3,mean)`.
broadcast_vec(0,3,mean)[0].shape

torch.Size([3, 1, 1])

In [59]:
#export
def get_3d_dls_album(df, sz=None, path=path_jpg256, bs=1, num_workers=8, grps=Meta.grps_stg1, test=False, meta=False):
    """Series-level `DataLoaders` where each slice is transformed at item level.

    Each slice goes through `ABTfms((sz,sz))`, `image2tensor`, `IntToFloatTensor`
    and normalization inside `OpenCTs`; no batch transforms are applied.
    """
    nrm = Normalize.from_stats(mean,std,cuda=False,dim=0,ndim=3)
    item_tfms = [ABTfms((sz,sz)), image2tensor, IntToFloatTensor(), nrm]
    opener = OpenCTs(path, get_cv2_fn, tfms=item_tfms)
    dsets = get_3d_dsets(df, open_fn=opener, grps=grps, test=test, meta=meta)
    # `meta` only affects the datasets: the loaders are built identically either way
    batch_tfms = None
    return get_dls(dsets, bs, batch_tfms, num_workers)


In [60]:
# Try the item-level transform loaders at 128px with batch size 2 and slice positions enabled.
dls = get_3d_dls_album(Meta.df_comb, sz=128, bs=2, grps=Meta.grps_stg1, meta=True)

In [61]:
# Pull a single batch from the loaders built in the previous cell.
xb, yb = dls.one_batch()

In [62]:
# With meta=True `xb` holds two parts; show the shape of the first one.
xb[0].shape

torch.Size([2, 60, 3, 128, 128])

In [None]:
# Same settings as the album demo, but using the batch-augmentation loaders.
dls = get_3d_dls_aug(Meta.df_comb, sz=128, bs=2, grps=Meta.grps_stg1, meta=True)

In [69]:
# Test-set loaders at 384px with batch size 4; time fetching the first validation batch.
dls = get_3d_dls_aug(Meta.df_tst, sz=384, path=path_tst_jpg, bs=4, test=True, meta=True)
%time xb,yb = next(iter(dls.valid)); xb[0].shape
# Wall time: 14.2 s - without Tfm5D

CPU times: user 190 ms, sys: 592 ms, total: 782 ms
Wall time: 4.12 s


torch.Size([4, 60, 3, 384, 384])

In [22]:
#export
class Wrap():
    """Apply a 4D (batch) transform `tfm` to 5D input by folding the second dim into the batch.

    Attributes that `Wrap` does not define itself are looked up on the wrapped `tfm`.
    `tfm_all` is accepted for compatibility but is not used.
    """
    def __init__(self, tfm, tfm_all=True): self.tfm = tfm
    def __getattr__(self, x):
        # `__getattr__` only runs when normal lookup fails. If `tfm` itself is
        # missing (e.g. on the bare instance created by copy/pickle before state is
        # restored), `self.tfm` would re-enter this method forever: fail fast instead.
        if x == 'tfm': raise AttributeError(x)
        return getattr(self.tfm, x)
    def __call__(self, *args, **kwargs): return self.encodes(*args, **kwargs)
    def encodes(self, x:TensorImage): return self.reshape(x, self.tfm)
    def decodes(self, x:TensorImage): return self.reshape(x, self.tfm.decodes)

    def reshape(self, x, func):
        "Run `func` on `x` viewed as 4D and restore the two leading dims; non-5D input is returned as is."
        if len(x.shape) != 5: return x
        bs,ts,ch,w,h = x.shape
        x = x.reshape(-1,ch,w,h)
        out = func(x)
        # `func` may change the last three dims (e.g. a crop), so take them from `out`
        return out.reshape(bs,ts, *out.shape[-3:])
        

# wrapped_tfms = [Wrap(tfm) for tfm in aug_transforms()]

In [23]:
# dsets = get_3d_dsets(Meta.df_any, open_fn=OpenCTs(path_jpg256))
# x,y = dsets[0]
# x.shape, y.shape

# dls = get_3d_dls(df_any, bs=10)
# x,y = dls.one_batch()
# x.shape, y.shape

## Features

In [24]:
#export
def get_np_fn(p):
    "Return a loader that reads `p/<fn>.npy` and converts the array to a tensor."
    def _f(fn):
        arr = np.load(str(p/f'{fn}.npy'))
        return torch.from_numpy(arr)
    return _f

In [35]:
#export
def get_3d_dls_feat(df, path=path_feat_384avg, bs=1, num_workers=8, test=False, meta=False):
    "Series-level `DataLoaders` over saved `.npy` features under `path`, with no batch transforms."
    feat_opener = OpenCTs(path, get_np_fn)
    dsets = get_3d_dsets(df, open_fn=feat_opener, test=test, meta=meta)
    no_batch_tfms = []
    return get_dls(dsets, bs, no_batch_tfms, num_workers)


In [43]:
# Feature loaders with slice positions enabled (meta=True); grab one batch to inspect.
dls_feat = get_3d_dls_feat(Meta.df_comb, bs=10, meta=True)
xb,yb = dls_feat.one_batch()
# xb.shape, yb.shape

In [48]:
# Second part of `xb`: presumably the normalized slice positions from `TfmSOP.x` — confirm.
# NOTE(review): this dumps a full tensor into the notebook output; consider showing a slice.
xb[1]

tensor([[[-0.6698],
         [-0.6538],
         [-0.6379],
         [-0.6220],
         [-0.6061],
         [-0.5901],
         [-0.5742],
         [-0.5583],
         [-0.5423],
         [-0.5264],
         [-0.5105],
         [-0.4946],
         [-0.4738],
         [-0.4525],
         [-0.4313],
         [-0.4101],
         [-0.3888],
         [-0.3676],
         [-0.3464],
         [-0.3251],
         [-0.3039],
         [-0.2826],
         [-0.2614],
         [-0.2402],
         [-0.2189],
         [-0.1977],
         [-0.1765],
         [-0.1552],
         [-0.1340],
         [-0.1127],
         [-0.0915],
         [-0.0703],
         [-0.0490],
         [-0.0278],
         [-0.0066],
         [ 0.0147],
         [ 0.0000],
         [ 0.0000],
         [ 0.0000],
         [ 0.0000],
         [ 0.0000],
         [ 0.0000],
         [ 0.0000],
         [ 0.0000],
         [ 0.0000],
         [ 0.0000],
         [ 0.0000],
         [ 0.0000],
         [ 0.0000],
         [ 0.0000],


## Model

In [27]:
#export
class ReshapeBodyHook():
    """Hooks on `body` that merge the first two input dims before the forward pass
    and split them again on the output."""
    def __init__(self, body):
        super().__init__()
        self.pre_reg = body.register_forward_pre_hook(self.pre_hook)
        self.reg = body.register_forward_hook(self.forward_hook)
        # Shape of the most recent input, recorded by `pre_hook`
        self.shape = None

    def deregister(self):
        "Remove both hooks from the module."
        for handle in (self.reg, self.pre_reg): handle.remove()

    def pre_hook(self, module, input):
        "Remember the input shape and pass on the first input with dims 0 and 1 merged."
        inp = input[0]
        self.shape = inp.shape
        merged = inp.view(-1, *inp.shape[2:])
        return (merged,)

    def forward_hook(self, module, input, x):
        "Restore the two leading dims recorded by `pre_hook` on the module output."
        lead = self.shape[:2]
        return x.view(*lead, *x.shape[1:])

In [28]:
#export
def conv3(ni,nf,stride=1):
    "3D `ConvLayer` with a (5,3,3) kernel; `stride` applies to the last two dims only."
    kernel, padding = (5,3,3), (2,1,1)
    strides = (1,stride,stride)
    return ConvLayer(ni, nf, kernel, stride=strides, ndim=3, padding=padding)

In [29]:
#export
class Batchify(Module):
    "Swap dims 1 and 2 of the input."
    def forward(self, x):
        swapped = x.transpose(1,2)
        return swapped

class DeBatchify(Module):
    "Swap dims 1 and 2, then merge dims 0 and 1 into a single leading dim."
    def forward(self, x):
        # `contiguous()` is needed because `view` cannot merge the transposed dims in place
        swapped = x.transpose(1,2).contiguous()
        return swapped.view(-1, *swapped.shape[2:])

def get_3d_head(concat_pool=True):
    """Build the 3D-conv head: three strided `conv3` layers (512->256->128->64),
    2D pooling, flatten and a linear layer with 6 outputs.

    concat_pool: use `AdaptiveConcatPool2d` (doubles the feature count to 128)
                 instead of `nn.AdaptiveAvgPool2d`.
    """
    if concat_pool: pool, feat = AdaptiveConcatPool2d(1), 64*2
    else:           pool, feat = nn.AdaptiveAvgPool2d(1), 64
    layers = [
        Batchify(),
        conv3(512,256,2), # 8 (original author's note)
        conv3(256,128,2), # 4
        conv3(128, 64,2), # 2
        DeBatchify(), pool, Flatten(), nn.Linear(feat,6),
    ]
    m = nn.Sequential(*layers)
    init_cnn(m)
    return m

## Ignore Padding in Loss

In [30]:
#export
class DePadLoss(Callback):
    """Remove padded rows from predictions and targets after the forward pass.

    A row counts as padding when the last column of its target equals `pad_idx`.
    """
    def __init__(self, pad_idx=-1): 
        super().__init__()
        store_attr(self, 'pad_idx')

    def after_pred(self):
        learn = self.learn
        # No targets in this batch: nothing to mask against
        # (`learn.yb[0]` used to raise IndexError here).
        if not len(learn.yb): return
        y = learn.yb[0]
        # Merge the first two dims of the targets into one
        targ = y.view(-1, *y.shape[2:])
        pred = learn.pred
        # Flatten predictions the same way unless they already have one row per target row
        if targ.shape[0] != pred.shape[0]:
            pred = pred.view(-1, *pred.shape[2:])

        mask = targ[:,-1] != self.pad_idx

        learn.pred = pred[mask]
        learn.yb = (targ[mask],)

## Training Features - By Batch

In [None]:
# Feature loaders over `Meta.df_any` with 16 series per batch (no slice positions).
dls_feat = get_3d_dls_feat(Meta.df_any, bs=16)

In [None]:
# Train only the 3D head on the saved features.
m = get_3d_head()
learn = get_learner(dls_feat, m)

In [29]:
# Attach DePadLoss so padded rows are dropped from predictions and targets.
learn.add_cb(DePadLoss())

<fastai2.learner.Learner at 0x7fd0316ad6d0>

In [None]:
# presumably 1 epoch at learning rate 1e-2 — confirm against `do_fit` in rsna_retro.train
do_fit(learn, 1, 1e-2)

In [None]:
# Show the learner summary for the feature-trained head.
learn.summary()

## Train on Slice

In [29]:
# Fresh 3D head, used as the custom head of the slice-level learner below.
m = get_3d_head()

In [32]:
# Augmented image loaders (10 series per batch) and a resnet18 learner that uses `m` as its head.
dls = get_3d_dls_aug(Meta.df_any, bs=10)
config=dict(custom_head=m, init=None)
learn = get_learner(dls, resnet18, get_loss(), config=config)

In [33]:
# learn.model[0] = ReshapeCNNBody(learn.model[0])

In [34]:
# Hook the first model component so its input has dims 0 and 1 merged and its output gets them back.
hook = ReshapeBodyHook(learn.model[0])

In [35]:
# Attach DePadLoss so padded rows are dropped from predictions and targets.
learn.add_cb(DePadLoss())

<fastai2.learner.Learner at 0x7f5be0f68710>

In [None]:
# presumably 1 epoch at learning rate 1e-2 — confirm against `do_fit` in rsna_retro.train
do_fit(learn, 1, 1e-2)

## Export

In [64]:
#hide
from nbdev.export import notebook2script
# Run nbdev's notebook-to-script export (the output below lists every converted notebook).
notebook2script()

Converted 00_metadata.ipynb.
Converted 01_preprocess.ipynb.
Converted 01_preprocess_mean_std.ipynb.
Converted 02_train.ipynb.
Converted 02_train_01_save_features.ipynb.
Converted 03_train3d.ipynb.
Converted 04_trainfull3d_deprecated.ipynb.
Converted 04_trainfull3d_labels.ipynb.
Converted 05_train_adjacent.ipynb.
Converted 06_seutao_features.ipynb.
Converted Tabular_02_FeatureImportance.ipynb.
Converted Untitled.ipynb.
