## Start

In [1]:
import torch
# Make CUDA device 3 the default device for this process.
torch.cuda.set_device(3)

# Star imports: later imports may shadow names from earlier ones, so keep this order.
from rsnautils import *
from fastai2.callback.data import *
from fastai2.patch_tables import patch_tables
from fastai2.test import *
# Apply the patch imported just above; must run after that import.
patch_tables()

# NOTE(review): presumably the DataLoader worker count (num_workers) — confirm.
nw = 8

In [2]:
# Feature-set / run prefix: selects the input and label sub-directories and names the checkpoint.
pre = "xrn34_wgtd_deep"

In [3]:
# Input arrays and (presumably, per `path_ct_lbls`) label arrays for this run, both under `pre`.
src_x, src_y = path_cts/pre, path_ct_lbls/pre

In [4]:
# Series ids are the file names of the saved input arrays. Directory listing order is
# arbitrary and filesystem-dependent, so sort to make `sids` (and everything indexed by
# it: splits, dataset order, prediction row order) reproducible across machines.
sids = L(sorted(src_x.ls().attrgot('name')))
test_eq(len(sids), 23017)

In [5]:
# Positions 0..len(sids)-1, used to build the train/valid index splits.
idx = L.range(sids)
# Set of validation series ids: unique SeriesInstanceUIDs of the validation SOP rows.
val_sid = set(df_comb.loc[val_sops].SeriesInstanceUID.unique())

In [6]:
# Boolean mask over `sids`: True where the series is in the validation set.
mask = np.array([sid in val_sid for sid in sids])
s_splits = (idx[~mask], idx[mask])
test_eq(len(s_splits[0]), 19058)  # training series
test_eq(len(s_splits[1]), 3959)   # validation series

In [7]:
def read_slice(s):
    "Load the saved input array for series `s` from `src_x`, wrapped as a `TensorCTScan`."
    arr = (src_x/s).load_array()
    return TensorCTScan(tensor(arr))
def read_slbls(s):
    "Load the saved label array for series `s` from `src_y` as a float tensor."
    arr = (src_y/s).load_array()
    return tensor(arr).float()

In [8]:
# One transform pipeline for the input (x) and one for the target (y).
tfms = [[read_slice], [read_slbls]]
dsets = Datasets(sids, tfms=tfms, splits=s_splits)

In [9]:
# DataLoaders are built by hand (rather than via `dsets.dataloaders`) so that no batch size
# is forced. NOTE(review): no `bs` is passed, so fastai2's DataLoader default applies; the
# sample batch below has a per-item leading dimension, suggesting one series per batch — confirm.
# Worker count comes from the `nw` config value instead of a duplicated literal.
dls = DataLoaders(
    DataLoader(dsets.train, after_batch=Cuda(), num_workers=nw, shuffle=True),
    DataLoader(dsets.valid, after_batch=Cuda(), num_workers=nw),
)

# Record the target device on the DataLoaders object.
dls.device = default_device()

In [10]:
# Peek at one validation batch; its channel count (dim 1) sizes the model's first layer.
x, y = dls.valid.one_batch()
n_final = x.size(1)
(x.shape, y.shape)

(torch.Size([28, 256, 1, 1]), torch.Size([28, 6]))

In [11]:
# `get_loss` comes from the star imports above (presumably rsnautils).
loss_func = get_loss()

In [12]:
class Batchify(Module):
    """Lay a stack of items out as one 1-d sequence.

    Indexes 0 on the last two axes, adds a leading size-1 axis, then swaps axes 1 and 2.
    For an input like the sample batch, (n, c, 1, 1), the result is (1, c, n).
    """
    def forward(self, x):
        feats = x[..., 0, 0]                 # drop the two trailing axes
        return feats[None].transpose(1, 2)   # -> (1, channels, sequence)

class DeBatchify(Module):
    """Inverse of the sequence layout: take element 0 of the leading axis and swap the next two.

    For a (1, c, n) input this gives (n, c): one row per sequence position.
    """
    def forward(self, x):
        first = x[0]                  # drop the leading axis by selecting its first entry
        return first.transpose(0, 1)  # (c, n) -> (n, c)

In [13]:
def conv3(ni,nf,stride=1, norm_type=NormType.Batch):
    """3-d `ConvLayer` from `ni` to `nf` channels with a (5, 3, 3) kernel.

    `stride` applies only to the last two axes; the first axis always has stride 1.
    Padding is half the kernel on each axis, i.e. (2, 1, 1).
    """
    kernel = (5, 3, 3)
    pad = tuple(k // 2 for k in kernel)
    return ConvLayer(ni, nf, kernel, stride=(1, stride, stride), ndim=3,
                     padding=pad, norm_type=norm_type)

In [14]:
def conv1(ni,nf,stride=1, norm_type=NormType.Batch, ks=5):
    """1-d `ConvLayer` from `ni` to `nf` channels.

    ni, nf:    input / output channel counts.
    stride:    convolution stride, forwarded to `ConvLayer` (default 1).
    norm_type: passed to `ConvLayer`; `None` means no normalization layer.
    ks:        kernel size (default 5). Padding is `ks//2`, so for odd `ks` with
               `stride=1` the sequence length is preserved.
    """
    return ConvLayer(ni, nf, ks, stride=stride, ndim=1, padding=ks//2, norm_type=norm_type)

# Sequence head over one loaded item: Batchify lays the leading axis out as a 1-d sequence,
# three 1-d convs (stride 1, no norm layer) mix neighbouring positions while keeping the
# length, DeBatchify restores one row per position, and a linear layer gives 6 outputs per row.
m = nn.Sequential(
    Batchify(),
    *[conv1(ni, nf, 1, norm_type=None)
      for ni, nf in zip((n_final, 256, 128), (256, 128, 64))],
    DeBatchify(),
    Flatten(),
    nn.Linear(64, 6),
)

init_cnn(m)

In [16]:
learn = Learner(dls, m,
                loss_func=loss_func,
                metrics=metrics)
# Initialise the final Linear layer's bias to logit(avg_lbls); presumably avg_lbls holds the
# per-label base rates, so initial outputs start near those rates.
learn.model[-1].bias.data = to_device(logit(avg_lbls))

In [17]:
# Run label: ks=5,3,3 xrn34_wgtd_deep
# NOTE(review): `m` uses 1-d `conv1` (ks=5), not `conv3`'s (5,3,3) kernel — confirm this label.
learn.fit_one_cycle(2, lr_max=1e-3)

epoch,train_loss,valid_loss,accuracy_multi,accuracy_any,None,accum,time
0,0.047986,0.078214,0.976942,0.960953,0.078214,0.070074,07:36
1,0.046641,0.06409,0.979135,0.961831,0.06409,0.063143,05:30


In [18]:
# Checkpoint name: the run prefix plus a '-3d' suffix.
learn.save(pre + '-3d')

In [20]:
preds, targs = learn.get_preds()
# Re-score the validation predictions on the default device. NOTE(review): `logit` is applied
# on the assumption that get_preds returned activated (probability) outputs — confirm.
p, t = to_device(preds), to_device(targs)
loss_func(logit(p), t)

tensor(0.0636, device='cuda:3')

In [21]:
# (rows, 6): one row per sequence position across all validation items.
preds.size()

torch.Size([136785, 6])

In [None]:
#export
class XRN3dHead(nn.Sequential):
    """3-d residual head: ResBlock stages, global average pooling, then a linear layer.

    block_szs: channel sizes; stage `i` maps `block_szs[i]` -> `block_szs[i+1]`.
    layers:    number of `ResBlock`s per stage; must satisfy len(block_szs) == len(layers)+1.
    c_in:      accepted for interface compatibility but unused (input channels are `block_szs[0]`).
    c_out:     output features of the final linear layer.
    act_cls:   activation class passed to every `ResBlock`.

    The first stage uses stride 1, later stages stride 2. The output is average-pooled to
    size 1 over all three spatial axes, flattened, and projected to `c_out`.

    Raises ValueError when `block_szs` and `layers` lengths do not match.
    """
    def __init__(self, block_szs, layers, c_in=3, c_out=1000, act_cls=defaults.activation):
        if len(block_szs) != len(layers) + 1:
            raise ValueError(f"need len(block_szs) == len(layers)+1, got {len(block_szs)} and {len(layers)}")
        blocks = [self._make_layer(block_szs[i], block_szs[i+1], l, 1 if i==0 else 2, act_cls=act_cls)
                  for i,l in enumerate(layers)]
        super().__init__(*blocks,
            # torch.nn has no `AdaptiveAvgPool`; use the explicit 3-d module.
            nn.AdaptiveAvgPool3d(1), Flatten(),
            nn.Linear(block_szs[-1], c_out),
        )
        init_cnn(self)

    def _make_layer(self, ni, nf, blocks, stride, act_cls=defaults.activation, expansion=1):
        "Stack `blocks` 3-d `ResBlock`s from `ni` to `nf` channels; only the first uses `stride`."
        # expansion=1 so ResBlock (which scales ni/nf by `expansion`) uses the sizes in `block_szs`.
        return nn.Sequential(
            *[ResBlock(expansion, ni if i==0 else nf, nf, stride if i==0 else 1, ndim=3, act_cls=act_cls)
              for i in range(blocks)])