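# Classification Model

> Wraps a 3D CNN built by `create_cnn_model_3d` so it can classify a sequence of volumes. Each volume is encoded by the shared CNN body, and the stacked features are classified by the CNN's head.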

In [1]:
# export
# default_exp models
import sys
sys.path.append('..')
from fastai.basics import *
from fastai.layers import *
from fastai.vision.learner import create_body as create_body_2d

from faimed3d.learner import create_head, create_body, create_cnn_model_3d
from faimed3d.layers import AdaptiveConcatPool3d

In [2]:
from faimed3d.models.resnet import resnet18_3d

In [3]:
# export
class SequentialModel(Module):
    "Encode each element of a sequence with a shared 3D CNN body, then classify the stacked features with the CNN's head."
    def __init__(self, model):
        # `model` is expected to be indexable as (body, head), e.g. the output of
        # `create_cnn_model_3d` as used in the cell below.
        body, head = model[0], model[1]
        # Encoder: CNN body followed by a global max-pool, which reduces every
        # spatial dimension of the feature map to size 1.
        self.body = nn.Sequential(body, nn.AdaptiveMaxPool3d(1))
        # Classifier: concat-pool over the stacked per-element features, then reuse
        # the original head without its first layer.
        # NOTE(review): presumably head[0] is the head's own pooling layer, which is
        # replaced here. TODO confirm against `create_head`.
        self.head = nn.Sequential(AdaptiveConcatPool3d(1), head[1:])

    def forward(self, x):
        "Run the body on each slice `x[:, i]` (given a channel axis), concatenate along dim 2, and apply the head."
        # NOTE(review): x appears to be laid out as (batch, seq, D, H, W) with
        # single-channel volumes. The demo cell passes (2, 2, 1, 16, 16). Confirm with callers.
        per_item = []
        for item in x.unbind(dim=1):
            per_item.append(self.body(item.unsqueeze(1)))
        stacked = torch.cat(per_item, dim=2)
        return self.head(stacked)

In [4]:
# Build a single-channel, two-class 3D ResNet-18 and wrap it so it accepts a sequence of volumes.
cnn = create_cnn_model_3d(resnet18_3d, n_in = 1, n_out = 2)
seq_model = SequentialModel(cnn)
# Smoke test: batch of 2, sequence length 2, each element a 1x16x16 volume.
# Expect one row of n_out=2 outputs per batch item.
out = seq_model(torch.randn(2, 2, 1, 16, 16))
assert out.shape == (2, 2), f"unexpected output shape {out.shape}"
out.shape

torch.Size([2, 2])

In [5]:
# hide
# Export all `# export` cells of the project's notebooks to the library modules.
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 01_preprocessing.ipynb.
Converted 02_augment.ipynb.
Converted 03_noise-reduction.ipynb.
Converted 04_classification_models.ipynb.
Converted 05_classification.ipynb.
Converted index.ipynb.
