In [1]:
#default_exp trainfull3d

In [2]:
#export
from rsna_retro.imports import *
from rsna_retro.metadata import *
from rsna_retro.preprocess import *
from rsna_retro.train import *
from rsna_retro.train3d import *

Loading imports


In [3]:
# Select CUDA device index 3 for this notebook (hard-coded; change it to match the local GPU layout)
torch.cuda.set_device(3)

## 3D dataset with transforms

In [4]:
# Wrap each transform returned by `aug_transforms()` in `Wrap` before passing it to the 3D dataloaders.
# NOTE(review): presumably `Wrap` adapts the 2D image transforms to batches of slice stacks -- confirm against its definition.
wrapped_tfms = [Wrap(tfm) for tfm in aug_transforms()]

In [5]:
# Build the 3D dataloaders from `Meta.df_any` with batch size 8 and the wrapped transforms
dls = get_3d_dls(Meta.df_any, bs=8, tfms=wrapped_tfms)

## 3D model

In [8]:
# Pull a single batch and display the input/target shapes the model has to handle
xb,yb = dls.one_batch()
xb.shape, yb.shape

(torch.Size([8, 60, 3, 256, 256]), torch.Size([8, 60, 6]))

In [9]:
#export
def conv3d(ni,nf,ks=(5,3,3),s=(1,2,2),**kwargs):
    """Build a 3D `ConvLayer` mapping `ni` to `nf` channels, padded by half the kernel size.

    `ks` is the kernel size: either a tuple with one entry per convolved axis or a single
    int used for every axis (e.g. `ks=1` for a pointwise conv). `s` is forwarded as the
    stride. Padding is `k//2` for each kernel entry. Remaining `kwargs` go to `ConvLayer`;
    `ndim` defaults to 3 when the caller does not supply it.
    """
    # An int kernel size is not iterable, so compute its padding directly instead of looping
    p = ks//2 if isinstance(ks, int) else tuple(k//2 for k in ks)
    # This helper is 3D-specific; only fill in `ndim` when the caller did not pass one
    kwargs.setdefault('ndim', 3)
    return ConvLayer(ni, nf, ks, stride=s, padding=p, **kwargs)

In [10]:

class ResBlock3D(nn.Module):
    "3D Resnet block from `ni` to `nf` channels (both scaled by `expansion`); `stride` applies to the last two axes only"
    @delegates(ConvLayer.__init__)
    def __init__(self, expansion, ni, nf, stride=1, groups=1, reduction=None, nh1=None, nh2=None, dw=False, g2=1,
                 sa=False, sym=False, norm_type=NormType.Batch, act_cls=defaults.activation, ndim=3,
                 pool=AvgPool, pool_first=True, **kwargs):
        super().__init__()
        # The final conv of the residual path uses the `*Zero` variant of the chosen norm type
        norm2 = (NormType.BatchZero if norm_type==NormType.Batch else
                 NormType.InstanceZero if norm_type==NormType.Instance else norm_type)
        # Hidden widths default to the (unexpanded) output width
        if nh2 is None: nh2 = nf
        if nh1 is None: nh1 = nh2
        ks = (3,3,3)
        # Pointwise kernels are spelled as tuples: `conv3d` derives its padding by iterating over `ks`
        ks1 = (1,1,1)
        # The first axis is never strided, so only the two trailing axes are downsampled
        s_down = (1,stride,stride)
        nf,ni = nf*expansion,ni*expansion
        k0 = dict(norm_type=norm_type, act_cls=act_cls, ndim=ndim, **kwargs)
        k1 = dict(norm_type=norm2, act_cls=None, ndim=ndim, **kwargs)
        if expansion == 1:
            # Basic block: two 3x3x3 convs, the first one carrying the stride
            layers = [conv3d(ni,  nh2, ks=ks, s=s_down, groups=ni if dw else groups, **k0),
                      conv3d(nh2,  nf, ks=ks, s=1, groups=g2, **k1)]
        else:
            # Bottleneck block: 1x1x1 -> 3x3x3 (strided) -> 1x1x1
            layers = [conv3d(ni,  nh1, ks=ks1, s=1, **k0),
                      conv3d(nh1, nh2, ks=ks, s=s_down, groups=nh1 if dw else groups, **k0),
                      conv3d(nh2,  nf, ks=ks1, s=1, groups=g2, **k1)]
        self.convs = nn.Sequential(*layers)
        convpath = [self.convs]
        # NOTE(review): `SEModule` and `SimpleSelfAttention` are defined outside this file --
        # confirm they accept 5D activations before enabling `reduction` or `sa`.
        if reduction: convpath.append(SEModule(nf, reduction=reduction, act_cls=act_cls))
        if sa: convpath.append(SimpleSelfAttention(nf,ks=1,sym=sym))
        self.convpath = nn.Sequential(*convpath)
        # Identity path: a 1x1 conv when the channel count changes, pooling when the block is strided
        idpath = []
        if ni!=nf: idpath.append(ConvLayer(ni, nf, 1, act_cls=None, ndim=ndim, **kwargs))
        # `pool_first` puts the pooling at index 0 (before the 1x1 conv), otherwise at index 1 (after it)
        if stride!=1: idpath.insert((1,0)[pool_first], pool(ks=(1,2,2), stride=(1,2,2), ndim=ndim, ceil_mode=True))
        self.idpath = nn.Sequential(*idpath)
        self.act = defaults.activation(inplace=True) if act_cls is defaults.activation else act_cls()

    def forward(self, x):
        "Apply the activation to the sum of the residual path and the identity path"
        return self.act(self.convpath(x) + self.idpath(x))

In [11]:
class Flat3d(Module):
    "Flatten every axis after the first two: `(d0, d1, *rest)` -> `(d0, d1, prod(rest))`"
    def forward(self, x):
        # `reshape` returns a view when the layout allows it and copies otherwise, so
        # non-contiguous inputs (e.g. after a transpose) work where `view` would raise
        return x.reshape(*x.shape[:2], -1)

In [12]:

#export
class XResNet(nn.Sequential):
    """Fully-3D XResNet: conv stem, residual stages built from `block`, and a linear head.

    Every conv and pooling stride is `(1, s, s)`, so the first convolved axis keeps its
    length through the network. The head pools the two trailing axes to 1x1, flattens
    with `Flat3d` and applies dropout `p` followed by `nn.Linear` to `c_out` outputs.

    `layers` gives the number of blocks per stage, `stem_szs` the output channels of each
    stem conv (any length), and `widen` scales the stage widths. Extra `kwargs` are
    forwarded to `block`.
    """
    # NOTE(review): this cell is marked `#export`, but the cells defining `ResBlock3D` and
    # `Flat3d` in this notebook carry no `#export` marker -- confirm the generated module
    # can resolve both names.
    @delegates(ResBlock3D)
    def __init__(self, block, expansion, layers, p=0.0, c_in=3, c_out=1000, stem_szs=(32,32,64),
                 widen=1.0, sa=False, act_cls=defaults.activation, **kwargs):
        store_attr(self, 'block,expansion,act_cls')
        stem_szs = [c_in, *stem_szs]
        # One conv per consecutive pair of stem sizes; only the first conv is strided
        stem = [conv3d(stem_szs[i], stem_szs[i+1], s=(1,2,2) if i==0 else (1,1,1), act_cls=act_cls, ndim=3)
                for i in range(len(stem_szs)-1)]

        # Stage widths: the four standard sizes, then 256 for any stages beyond the fourth
        block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]
        # Blocks multiply `ni` by `expansion`, so divide the stem's output width to make them match
        block_szs = [stem_szs[-1]//expansion] + block_szs
        blocks = [self._make_layer(ni=block_szs[i], nf=block_szs[i+1], blocks=l,
                                   stride=1 if i==0 else 2, sa=sa and i==len(layers)-4, **kwargs)
                  for i,l in enumerate(layers)]
        super().__init__(
            # `Batchify` is defined outside this file; the `learn.summary()` output in this
            # notebook shows it turning an 8x60x3x256x256 input into 8x3x60x256x256
            Batchify(),
            *stem, nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)),
            *blocks,
            # `None` keeps the first pooled axis at its input size; the other two collapse to 1x1
            nn.AdaptiveAvgPool3d((None, 1, 1)), Batchify(), Flat3d(), nn.Dropout(p),
            nn.Linear(block_szs[-1]*expansion, c_out),
        )
        init_cnn(self)

    def _make_layer(self, ni, nf, blocks, stride, sa, **kwargs):
        "Stack `blocks` blocks; only the first changes width and stride, only the last may get `sa`"
        return nn.Sequential(
            *[self.block(self.expansion, ni if i==0 else nf, nf, stride=stride if i==0 else 1,
                      sa=sa and i==(blocks-1), act_cls=self.act_cls, **kwargs)
              for i in range(blocks)])

In [13]:
def xres3d(layers=(2,2,2,2), expansion=1, c_out=6, **kwargs):
    """Build the fully-3D `XResNet` from `ResBlock3D` blocks.

    `layers` is the number of blocks per stage (e.g. `(1,1,1,1)` for a shallower net),
    `expansion` the block expansion factor and `c_out` the number of outputs. The defaults
    reproduce the previously hard-coded configuration. Remaining `kwargs` are forwarded
    to `XResNet`.
    """
    m = XResNet(ResBlock3D, expansion=expansion, layers=layers, c_out=c_out, **kwargs)
    # `XResNet.__init__` already calls `init_cnn(self)`; the call is kept so the weights
    # are initialised exactly as they were before this function took parameters
    init_cnn(m)
    return m

In [14]:
# Instantiate the fully-3D model with its default configuration and move it to the current CUDA device
m = xres3d().cuda()

In [15]:
# Create the Learner from the 3D dataloaders, the model and the loss returned by `get_loss()`
learn = get_learner(dls, m, get_loss())
# Mixed precision is currently disabled:
# learn.to_fp16()

In [16]:
# Register the `DePadLoss` callback on the learner.
# NOTE(review): presumably it excludes padded slices from the loss -- confirm against its definition.
learn.add_cb(DePadLoss())

<fastai2.learner.Learner at 0x7f7e98054e90>

In [17]:
# for xb,yb in progress_bar(dls.train): pass

In [18]:
# Display the module tree to check the layer configuration (kernel sizes, strides, padding)
learn.model

XResNet(
  (0): Batchify()
  (1): ConvLayer(
    (0): Conv3d(3, 32, kernel_size=(5, 3, 3), stride=(1, 2, 2), padding=(2, 1, 1), bias=False)
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (2): ConvLayer(
    (0): Conv3d(32, 32, kernel_size=(5, 3, 3), stride=(1, 1, 1), padding=(2, 1, 1), bias=False)
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (3): ConvLayer(
    (0): Conv3d(32, 64, kernel_size=(5, 3, 3), stride=(1, 1, 1), padding=(2, 1, 1), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (4): MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), dilation=1, ceil_mode=False)
  (5): Sequential(
    (0): ResBlock3D(
      (convs): Sequential(
        (0): ConvLayer(
          (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      

In [19]:
# Print per-layer output shapes and parameter counts for one batch
learn.summary()

XResNet (Input shape: ['8 x 60 x 3 x 256 x 256'])
Layer (type)         Output Shape         Param #    Trainable 
Batchify             8 x 3 x 60 x 256 x 256 0          False     
________________________________________________________________
Conv3d               8 x 32 x 60 x 128 x 128 4,320      True      
________________________________________________________________
BatchNorm3d          8 x 32 x 60 x 128 x 128 64         True      
________________________________________________________________
ReLU                 8 x 32 x 60 x 128 x 128 0          False     
________________________________________________________________
Conv3d               8 x 32 x 60 x 128 x 128 46,080     True      
________________________________________________________________
BatchNorm3d          8 x 32 x 60 x 128 x 128 64         True      
________________________________________________________________
ReLU                 8 x 32 x 60 x 128 x 128 0          False     
____________________________

In [20]:
# Run `do_fit` with arguments 1 and 1e-2 (presumably epochs and learning rate -- confirm against `do_fit`);
# the run recorded below was interrupted manually before validation finished
do_fit(learn, 1, 1e-2)

epoch,train_loss,valid_loss,accuracy_multi,accuracy_any,time
0,5344.695801,01:36,,,


KeyboardInterrupt: 

In [None]:
# Timing fully3d
# with torch.no_grad():
#     %timeit -n 4 learn2.model(xb)

In [29]:
# Timing the old semi-3d model
# with torch.no_grad():
#     %timeit -n 6 m(xb)

# m2 = get_3d_head()
# config=dict(custom_head=m2)
# learn2 = get_learner(dls, xresnet34, get_loss(), config=config)
# hook = ReshapeBodyHook(learn2.model[0])
# learn2.add_cb(DePadLoss())

237 ms ± 24.2 ms per loop (mean ± std. dev. of 7 runs, 6 loops each)


## Export

In [1]:
#hide
# Export the notebooks to their library modules with nbdev
from nbdev.export import notebook2script
notebook2script()

Converted 00_metadata.ipynb.
Converted 01_preprocess.ipynb.
Converted 01_preprocess_mean_std.ipynb.
Converted 02_train.ipynb.
Converted 03_train3d.ipynb.
Converted 03_train3d_01_train3d.ipynb.
Converted 03_train3d_02_train_head.ipynb.
Converted 04_trainSeq_01_lstm.ipynb.
Converted 04_trainSeq_02_transformer.ipynb.
Converted 04_trainSeq_03_lstm_seutao.ipynb.
Converted 05_train_adjacent.ipynb.
Converted 05_train_adjacent_01_5c.ipynb.
Converted 05_train_adjacent_02_3c.ipynb.
Converted 06_seutao_features.ipynb.
This cell doesn't have an export destination and was ignored:
e
Converted 06_seutao_features_01_2ndPlace.ipynb.
This cell doesn't have an export destination and was ignored:
e
Converted 06_seutao_features_02_1stPlace.ipynb.
