In [1]:
#default_exp train3d

In [2]:
#export
from rsna_retro.imports import *
from rsna_retro.metadata import *
from rsna_retro.preprocess import *
from rsna_retro.train import *

Loading imports


In [3]:
# Make CUDA device index 3 the current device for this session.
# NOTE(review): the index is hard-coded and machine-specific; adjust on other hosts.
torch.cuda.set_device(3)

In [4]:
#export
class OpenCTs:
    """Callable opener for CT slices.

    A single `str`/`Path` is passed straight to the opener built by
    `get_pil_fn(path)`. Any other iterable is treated as a sequence of
    slice names: each one is opened, converted with `ToTensor()`, and the
    results are stacked into one `TensorCTScan`.
    """
    def __init__(self, path):
        self.fn = get_pil_fn(path)
        self.tt = ToTensor()

    def __call__(self, item):
        # One name -> return the opened object untouched (no tensor conversion).
        is_single = isinstance(item, (str, Path))
        if is_single:
            return self.fn(item)
        # Many names -> one tensor per slice, stacked along a new first dim.
        slices = []
        for name in item:
            slices.append(self.tt(self.fn(name)))
        return TensorCTScan(torch.stack(slices))

In [5]:
# df_slice_count = Meta.df_comb.groupby(['SeriesInstanceUID']).agg(['count'])
# max(df_slice_count.PatientID.values), min(df_slice_count.PatientID.values)


In [6]:
#export
# Default number of slices every series is padded to (also used by `Tfm6D`).
max_seq = 60

def pad_batch(x, pad_to=None, value=0):
    """Pad `x` along its first dimension up to `pad_to` entries using `value`.

    Tuples are handled recursively and returned as tuples. For a tensor,
    `pad_to - x.shape[0]` entries are appended at the end of dim 0.

    `pad_to` must be an integer for tensor inputs; the `None` default is kept
    only for signature compatibility.
    NOTE(review): if `x` is already longer than `pad_to` the pad amount is
    negative - confirm that the resulting `F.pad` behaviour is acceptable.
    """
    if isinstance(x, tuple): return tuple(pad_batch(s, pad_to, value) for s in x)
    n_missing = pad_to-x.shape[0]
    # F.pad takes (before, after) pairs starting from the LAST dim, so the
    # final list entry is the "after" amount for dim 0.
    pad = [0]*len(x.shape)*2
    pad[-1] = n_missing
    return F.pad(x, pad=pad, value=value)

def pad_collate(items, values=[0,-1], pad_to=max_seq):
    """Pad every element of every row in `items` to a common first-dim length.

    items:  list of rows (e.g. `(x, y)` tuples) of tensors / nested tuples.
    values: one pad value per row element, matched positionally (never mutated).
    pad_to: target length. Defaults to `max_seq`; pass `None` to pad to the
            longest item in this batch instead.

    Returns a list of tuples, one per input row.

    The previous implementation overwrote `pad_to` with a literal 60, so the
    argument was silently ignored; it is honoured now and the default result
    is unchanged.
    """
    def get_bs(x): return x[0].shape[0] if hasattr(x[0], 'shape') else get_bs(x[0])
    if pad_to is None:
        pad_to = max((get_bs(row) for row in items), default=0)
    return [tuple(pad_batch(x, pad_to, v) for x,v in zip(row, values)) for row in items]

In [7]:
#export
class TfmSOP:
    """Maps a series id to its inputs (`x`) and targets (`y`).

    df:      frame indexed by series id, with a `SOPInstanceUID` column and,
             when not in test mode, the label columns listed in `htypes`.
    open_fn: callable that receives the array of `SOPInstanceUID` values of
             one series and returns the model input.
    test:    when True, `y` returns an all-zero placeholder instead of labels.
    """
    def __init__(self,df,open_fn,test=False):
        self.open_fn = open_fn
        self.df = df
        self.test=test

    def x(self, sid):
        """Open every slice of series `sid` with `open_fn`."""
        sids = self.df.SOPInstanceUID[sid].values
        return self.open_fn(sids)

    def y(self, sid):
        """Return one float label row per slice of series `sid`.

        The previous version also looked up `SOPInstanceUID` here and threw the
        result away; that dead lookup is gone, so `y` only needs the columns
        it actually reads.
        """
        # Test mode: no labels available, emit zeros of width 6.
        # NOTE(review): 6 presumably equals len(htypes) - confirm.
        if self.test: return torch.zeros((self.df.loc[sid].shape[0], 6)).float()
        vals = self.df.loc[sid,htypes].values
        return TensorMultiCategory(tensor(vals)).float()

In [8]:
#export
def get_3d_dsets(df, open_fn, grps=Meta.grps_stg1, cv_idx=0, tfms=None, column='SeriesInstanceUID', test=False):
    """Build a `Datasets` whose items are the unique values of `column`.

    `df` is re-indexed by `column` and sorted by `column` then
    "ImagePositionPatient2". Inputs come from `TfmSOP.x` followed by `tfms`,
    targets from `TfmSOP.y`.

    Splits:
      * test=True  -> both splits cover every item;
      * otherwise  -> train = ids from `group_cv(cv_idx, grps)`,
                      valid = ids from `grps[cv_idx]`,
                      each restricted to ids present in `df`.
    """
    sort_keys = [column, "ImagePositionPatient2"]
    df_series = df.reset_index().set_index(column).sort_values(sort_keys)
    sids = df_series.index.unique()

    # multi index is 10x faster
    index_tuples = df_series.index.str.split('|').tolist()
    df_series.index = pd.MultiIndex.from_tuples(index_tuples)
    tfm = TfmSOP(df_series, open_fn, test=test)

    if test:
        splits = [L.range(sids), L.range(sids)]
    else:
        # position of each id inside `sids`
        sid2idx = {sid: pos for pos, sid in enumerate(sids)}
        train_idx = [sid2idx[sid] for sid in group_cv(cv_idx,grps) if sid in sid2idx]
        valid_idx = [sid2idx[sid] for sid in grps[cv_idx] if sid in sid2idx]
        splits = (train_idx, valid_idx)

    x_tfms = [tfm.x]+L(tfms)
    y_tfms = [tfm.y]
    return Datasets(sids, [x_tfms, y_tfms], splits=splits)

In [9]:
#export
def get_dls(dsets, bs, batch_tfms, num_workers=8):
    """Wrap `dsets.train` / `dsets.valid` in `TfmdDL`s and return `DataLoaders`.

    `pad_collate` runs before batching unless `bs == 1`. `batch_tfms` run
    after batching. Only the training loader is shuffled. The result is
    placed on `default_device()` and tagged with `c = 6`.
    """
    collate = [] if bs == 1 else [pad_collate]
    shared = dict(bs=bs, before_batch=collate, after_batch=batch_tfms, num_workers=num_workers)
    train_dl = TfmdDL(dsets.train, shuffle=True, **shared)
    valid_dl = TfmdDL(dsets.valid, **shared)
    dls = DataLoaders(train_dl, valid_dl)
    dls.device = default_device()
    dls.c = 6
    return dls

In [10]:
#export
def get_3d_dls(df, path=path_jpg256, bs=1, num_workers=8, tfms=None):
    """Image `DataLoaders` for whole series read from `path`.

    `tfms` are appended to the batch transforms (after `Normalize` and
    `IntToFloatTensor` in the list); they are not passed to the datasets.
    """
    opener = OpenCTs(path)
    dsets = get_3d_dsets(df, open_fn=opener)

    normalize = Normalize.from_stats(mean,std)
    to_float = IntToFloatTensor()
    batch_tfms = L(normalize, to_float, *L(tfms))

    return get_dls(dsets, bs, batch_tfms, num_workers)


In [11]:
# Test-set datasets: with test=True both splits cover every series and targets are zero placeholders, so grps is not used.
dsets = get_3d_dsets(Meta.df_tst, open_fn=OpenCTs(path_tst_jpg), tfms=[IntToFloatTensor()], grps=None, test=True)

In [12]:
# dls = get_3d_dls(Meta.df_comb, path=path_jpg, bs=4)
# %time xb,yb = next(iter(dls.valid)); xb.shape, yb.shape
# # Wall time: 4.69 s

In [13]:
#export
class Tfm5D(Transform):
    """Collapse every leading dimension of a `TensorImage` into one, keeping the last three dims."""
    order = 0
    def encodes(self, o:TensorImage):
        trailing = o.shape[-3:]
        return o.view(-1, *trailing)
class Tfm6D(Transform):
    """Regroup a flattened `TensorImage` into chunks of `max_seq` along a new second dim, keeping the last three dims."""
    order = 200
    def encodes(self, o:TensorImage):
        trailing = o.shape[-3:]
        return o.view(-1, max_seq, *trailing)

In [14]:
#export
def get_3d_dls_aug(df, sz=None, path=path_jpg256, bs=1, num_workers=8, grps=Meta.grps_stg1, test=False):
    """Image `DataLoaders` for whole series with GPU augmentation.

    Batch transforms: `Tfm5D`, `IntToFloatTensor`, `Tfm6D`, `aug_transforms()`,
    an optional `RandomResizedCropGPU(sz, ...)` when `sz` is given, then
    `Normalize` built from `mean`/`std`.
    """
    dsets = get_3d_dsets(df, open_fn=OpenCTs(path), grps=grps, test=test)

    flatten_unflatten = [Tfm5D(), IntToFloatTensor(), Tfm6D()]
    tfms = flatten_unflatten+aug_transforms()
    if sz is not None:
        crop = RandomResizedCropGPU(sz, min_scale=0.7, ratio=(1.,1.), valid_scale=0.9)
        tfms = tfms+[crop]

    normalize = Normalize.from_stats(mean,std)
    return get_dls(dsets, bs, tfms+L(normalize), num_workers)


In [15]:
# Smoke test on the test set: augmented loaders cropped to 384, batch size 4 (so pad_collate is active).
dls = get_3d_dls_aug(Meta.df_tst, sz=384, path=path_tst_jpg, bs=4, test=True)
%time xb,yb = next(iter(dls.valid)); xb.shape
# Wall time: 14.2 s - without Tfm5D

CPU times: user 350 ms, sys: 551 ms, total: 902 ms
Wall time: 3.28 s


torch.Size([4, 60, 3, 384, 384])

In [16]:
#export
class Wrap():
    """Adapter that lets a per-image transform run on 5-D `(bs, ts, ch, w, h)` batches.

    The two leading dims are merged before calling the wrapped transform and
    restored afterwards. Inputs that are not 5-D are returned untouched (the
    wrapped transform is NOT applied to them). Unknown attributes are looked
    up on the wrapped transform.

    tfm:     the wrapped transform (callable; `decodes` is used for decoding).
    tfm_all: accepted for compatibility, currently unused.
    """
    def __init__(self, tfm, tfm_all=True): self.tfm = tfm

    def __getattr__(self, x):
        # Only reached when normal lookup fails. If `tfm` itself is missing
        # (e.g. the instance was just re-created by copy/pickle and its state
        # is not restored yet), `self.tfm` would re-enter this method forever.
        # Fail with AttributeError so `hasattr`-style probes behave.
        if x == 'tfm': raise AttributeError(x)
        return getattr(self.tfm, x)

    def __call__(self, *args, **kwargs): return self.encodes(*args, **kwargs)
    def encodes(self, x:TensorImage): return self.reshape(x, self.tfm)
    def decodes(self, x:TensorImage): return self.reshape(x, self.tfm.decodes)

    def reshape(self, x, func):
        """Apply `func` to `x` with its first two dims merged, then split them again."""
        if len(x.shape) != 5: return x
        bs,ts,ch,w,h = x.shape
        x = x.reshape(-1,ch,w,h)
        out = func(x)
        # func may change the last three dims (e.g. a crop), so read them from `out`.
        return out.reshape(bs,ts, *out.shape[-3:])
        

# wrapped_tfms = [Wrap(tfm) for tfm in aug_transforms()]

In [17]:
# dsets = get_3d_dsets(Meta.df_any, open_fn=OpenCTs(path_jpg256))
# x,y = dsets[0]
# x.shape, y.shape

# dls = get_3d_dls(df_any, bs=10)
# x,y = dls.one_batch()
# x.shape, y.shape

## Features

In [18]:
#export
def get_np_fn(p):
    """Return a loader that reads `<p>/<name>.npy` and returns it as a torch tensor."""
    def _load(fn):
        array = np.load(str(p/f'{fn}.npy'))
        return torch.from_numpy(array)
    return _load

In [19]:
#export
class OpenFeats:
    """Callable opener for saved `.npy` feature files under `path`.

    A single `str`/`Path` returns the loaded tensor. Any other iterable of
    names is loaded one by one and stacked into a `TensorCTScan`.
    """
    def __init__(self, path):
        self.fn = get_np_fn(path)

    def __call__(self, item):
        is_single = isinstance(item, (str, Path))
        if is_single:
            return self.fn(item)
        feats = []
        for name in item:
            feats.append(self.fn(name))
        return TensorCTScan(torch.stack(feats))

In [20]:
#export
def get_3d_dls_feat(df, path=path/'features_256', bs=1, num_workers=8):
    """`DataLoaders` over saved feature files in `path`, with no batch transforms."""
    opener = OpenFeats(path)
    feat_dsets = get_3d_dsets(df, open_fn=opener)
    no_batch_tfms = []
    return get_dls(feat_dsets, bs, no_batch_tfms, num_workers)


In [21]:
# dls_feat = get_3d_dls_feat(df_any, bs=10)
# xb,yb = dls_feat.one_batch()
# xb.shape, yb.shape

## Model

In [22]:
#export
class ReshapeBodyHook():
    """Hooks that let `body` run on inputs carrying two leading batch-like dims.

    Before the forward pass, the first two dims of the input are merged into
    one. After it, the output's first dim is split back into those two dims.
    Call `deregister()` to remove both hooks.
    """
    def __init__(self, body):
        super().__init__()
        self.pre_reg = body.register_forward_pre_hook(self.pre_hook)
        self.reg = body.register_forward_hook(self.forward_hook)
        # Shape of the most recent input, recorded by `pre_hook`.
        self.shape = None

    def deregister(self):
        """Detach both hooks from the module."""
        for handle in (self.reg, self.pre_reg):
            handle.remove()

    def pre_hook(self, module, input):
        """Remember the incoming shape and merge dims 0 and 1."""
        first = input[0]
        self.shape = first.shape
        rest = first.shape[2:]
        merged = first.view(-1, *rest)
        return (merged,)

    def forward_hook(self, module, input, x):
        """Split the output's first dim back into the two remembered leading dims."""
        lead = self.shape[:2]
        rest = x.shape[1:]
        return x.view(*lead, *rest)

In [23]:
#export
def conv3(ni,nf,stride=1):
    """3-D `ConvLayer` with a (5,3,3) kernel; `stride` applies to the last two dims only."""
    kernel = (5,3,3)
    strides = (1,stride,stride)
    # padding of (2,1,1) is half of each kernel dim
    return ConvLayer(ni, nf, kernel, stride=strides, ndim=3, padding=(2,1,1))

In [24]:
#export
class Batchify(Module):
    """Swap dims 1 and 2 of the input."""
    def forward(self, x):
        swapped = x.transpose(1,2)
        return swapped

class DeBatchify(Module):
    """Swap dims 1 and 2, then merge the first two dims into one."""
    def forward(self, x):
        # contiguous() is required before view() because transpose only changes strides
        swapped = x.transpose(1,2).contiguous()
        tail = swapped.shape[2:]
        return swapped.view(-1, *tail)

def get_3d_head(concat_pool=True):
    """Head made of three strided `conv3` layers (512 -> 256 -> 128 -> 64 channels), 2-D pooling and a 6-way linear layer.

    With `concat_pool` the pooling is `AdaptiveConcatPool2d(1)` and the linear
    layer takes 128 features; otherwise average pooling and 64 features.
    The model is initialised with `init_cnn` before being returned.
    """
    if concat_pool:
        pool, feat = AdaptiveConcatPool2d(1), 64*2
    else:
        pool, feat = nn.AdaptiveAvgPool2d(1), 64
    layers = [
        Batchify(),
        conv3(512,256,2), # 8
        conv3(256,128,2), # 4
        conv3(128, 64,2), # 2
        DeBatchify(),
        pool,
        Flatten(),
        nn.Linear(feat,6),
    ]
    m = nn.Sequential(*layers)
    init_cnn(m)
    return m

## Ignore Padding in Loss

In [25]:
#export
class DePadLoss(Callback):
    """Callback that removes padded rows from predictions and targets after the forward pass.

    A row counts as padding when the last column of its target equals `pad_idx`.
    """
    def __init__(self, pad_idx=-1):
        super().__init__()
        store_attr(self, 'pad_idx')

    def after_pred(self):
        learn = self.learn
        # Merge the first two target dims into one row dim.
        y = learn.yb[0]
        targ = y.view(-1, *y.shape[2:])

        # Flatten the predictions the same way, unless they already have one row per target row.
        pred = learn.pred
        if targ.shape[0] != self.pred.shape[0]:
            pred = pred.view(-1, *pred.shape[2:])

        keep = targ[:,-1] != self.pad_idx

        learn.pred = pred[keep]
        learn.yb = (targ[keep],)

## Training Features - By Batch

In [None]:
# Feature loaders from the default feature directory; bs != 1, so get_dls pads each series with pad_collate.
dls_feat = get_3d_dls_feat(Meta.df_any, bs=16)

In [None]:
# Use the 3-D conv head on its own as the model for the feature loaders.
m = get_3d_head()
learn = get_learner(dls_feat, m)

In [29]:
# Drop padded rows (target == -1) from predictions and targets right after the forward pass.
learn.add_cb(DePadLoss())

<fastai2.learner.Learner at 0x7fd0316ad6d0>

In [None]:
# NOTE(review): do_fit comes from the rsna_retro star imports; presumably 1 epoch at lr 1e-2 - confirm.
do_fit(learn, 1, 1e-2)

In [None]:
# Inspect the learner's model summary.
learn.summary()

## Train on Slice

In [29]:
# Fresh 3-D head, used below as the custom head of the image model.
m = get_3d_head()

In [32]:
# Augmented image loaders (default path, no crop since sz is not given), batch size 10.
dls = get_3d_dls_aug(Meta.df_any, bs=10)
# resnet18 with the 3-D head `m` passed as custom_head.
config=dict(custom_head=m, init=None)
learn = get_learner(dls, resnet18, get_loss(), config=config)

In [33]:
# learn.model[0] = ReshapeCNNBody(learn.model[0])

In [34]:
# Merge the two leading dims before learn.model[0] and split them again after it; keep `hook` so it can be deregistered.
hook = ReshapeBodyHook(learn.model[0])

In [35]:
# Drop padded rows (target == -1) from predictions and targets right after the forward pass.
learn.add_cb(DePadLoss())

<fastai2.learner.Learner at 0x7f5be0f68710>

In [None]:
# NOTE(review): do_fit comes from the rsna_retro star imports; presumably 1 epoch at lr 1e-2 - confirm.
do_fit(learn, 1, 1e-2)

## Export

In [3]:
#hide
# Convert the project's notebooks with nbdev (see the "Converted ..." log below).
from nbdev.export import notebook2script
notebook2script()

Converted 00_metadata.ipynb.
Converted 01_preprocess.ipynb.
Converted 01_preprocess_mean_std.ipynb.
Converted 02_train.ipynb.
Converted 03_train3d.ipynb.
Converted 03_train3d_01_train3d.ipynb.
Converted 03_train3d_01b_train_lstm.ipynb.
Converted 03_train3d_02_train_3d_head.ipynb.
Converted 03_train3d_02_train_lstm_head.ipynb.
Converted 03_trainfull3d.ipynb.
Converted 04_trainSeq_01_lstm.ipynb.
Converted 04_trainSeq_02_transformer.ipynb.
Converted 04_trainSeq_03_lstm_seutao.ipynb.
Converted 05_train_adjacent.ipynb.
Converted 05_train_adjacent_01_5c_windowed.ipynb.
Converted 05_train_adjacent_01_5slice.ipynb.
Converted 05_train_adjacent_02_3c.ipynb.
Converted 05_train_adjacent_02_3c_stg1.ipynb.
Converted 06_seutao_features.ipynb.
Converted 06_seutao_features_01_simple_lstm_20ep.ipynb.
Converted 06_seutao_features_01b_simple_lstm_10ep.ipynb.
Converted 06_seutao_features_01c_simple_lstm_meta.ipynb.
Converted 06_seutao_features_01d_simple_lstm_meta_full.ipynb.
Converted 06_seutao_features_0