# Parallel UNet

In [None]:
import sys
sys.path.append('..')

In [None]:
# export
# default_exp parallel
from faimed3d.all import *
from fastai.vision.all import create_body, create_unet_model, hook_outputs, DynamicUnet
from fastai.vision.learner import _default_meta, _add_norm, model_meta 
from fastai.vision.models.unet import  _get_sz_change_idxs



In [None]:
# Load the prostate training table and keep only the T2 image / mask columns.
csv_path = '../../dl-prostate-mapping/data/prostata-train.csv'
prostate = pd.read_csv(csv_path, sep = ',')
prostate = prostate[['t2_cropped_dcm', 't2_mask_cropped']]

# Segmentation class names, in label order.
seg_codes = ['void', 'transitional', 'peripheral', 'cancer']
# Per-item crop + resize; with perc_crop=True the crop_by values are presumably
# fractions per axis (TODO confirm against faimed3d's ResizeCrop3D).
resize_crop = ResizeCrop3D(crop_by = (0., 0.1, 0.1), resize_to = (20, 100, 100), perc_crop = True)
dls = SegmentationDataLoaders3D.from_df(prostate,
                                        path='/media/..',
                                        codes = seg_codes,
                                        item_tfms = resize_crop,
                                        bs = 2,
                                        val_bs = 2)

In [None]:
# export
class UnetBlock3D(nn.Module):
    """A quasi-UNet decoder block for volumes, using `ConvTranspose3D` for upsampling.

    `forward(up_in, lwr_features)` upsamples `up_in` to `up_in_c // 2` channels,
    resizes it (nearest-neighbour) if its last three dimensions differ from those of
    `lwr_features`, concatenates it along dim 1 with the batch-normed `lwr_features`,
    applies the activation and then two `ConvLayer`s.

    Args:
        up_in_c: channels of the tensor to upsample (`up_in`).
        x_in_c: channels of the skip-connection tensor (`lwr_features`).
        final_div: if True the block outputs `up_in_c // 2 + x_in_c` channels,
            otherwise half of that (integer division).
        blur: forwarded to `ConvTranspose3D`.
        act_cls: activation class used by every sub-layer and after the concatenation.
        self_attention: if True, `SelfAttention` is appended to the second conv layer.
        init: initializer applied to both `ConvLayer`s via `apply_init`.
        norm_type: normalization type forwarded to every sub-layer.
        **kwargs: forwarded to `ConvTranspose3D` and both `ConvLayer`s.
    """
    @delegates(ConvLayer.__init__)
    def __init__(self, up_in_c, x_in_c, final_div=True, blur=False, act_cls=defaults.activation,
                 self_attention=False, init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        # Plain `nn.Module` (unlike fastai's `Module`) does not run its __init__
        # implicitly; without this call the first submodule assignment below raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        self.up = ConvTranspose3D(up_in_c, up_in_c//2, blur=blur, act_cls=act_cls, norm_type=norm_type, **kwargs)
        self.bn = BatchNorm(x_in_c, ndim=3)
        ni = up_in_c//2 + x_in_c          # channels after concatenation
        nf = ni if final_div else ni//2   # output channels of the block
        self.conv1 = ConvLayer(ni, nf, ndim=3, act_cls=act_cls, norm_type=norm_type, **kwargs)
        self.conv2 = ConvLayer(nf, nf, ndim=3, act_cls=act_cls, norm_type=norm_type,
                               xtra=SelfAttention(nf) if self_attention else None, **kwargs)
        self.relu = act_cls()
        apply_init(nn.Sequential(self.conv1, self.conv2), init)

    def forward(self, up_in, lwr_features):
        "Upsample `up_in`, fuse it with the skip features `lwr_features`, and convolve."
        up_out = self.up(up_in)
        ssh = lwr_features.shape[-3:]
        if ssh != up_out.shape[-3:]:
            # Match the skip connection's spatial size so the two can be concatenated.
            up_out = F.interpolate(up_out, ssh, mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(lwr_features)], dim=1))
        return self.conv2(self.conv1(cat_x))
        

In [None]:
def _multiscale_forward(model, x):
    "Run `x` through `model.stem` and `layer1`..`layer4` in sequence, returning all five outputs."
    x1 = model.stem(x)
    x2 = model.layer1(x1)
    x3 = model.layer2(x2)
    x4 = model.layer3(x3)
    x5 = model.layer4(x4)
    return x1, x2, x3, x4, x5


def build_backbone(backbone, ni, **kwargs):
    """Instantiate `backbone(ni, **kwargs)` and make it return its multi-scale features.

    The instance's `forward` is replaced so that calling the model returns the tuple of
    outputs of `stem`, `layer1`, `layer2`, `layer3` and `layer4`, each stage being fed
    the previous stage's output. The instance must expose those five attributes.

    Args:
        backbone: callable (e.g. a model class) invoked as `backbone(ni, **kwargs)`.
        ni: first positional argument for `backbone` (presumably the number of
            input channels — TODO confirm).
        **kwargs: forwarded to `backbone`.

    Returns:
        The backbone instance, with its `forward` patched as described above.
    """
    from functools import partial

    model = backbone(ni, **kwargs)
    # Bind the model with `partial` over a module-level function instead of a local
    # closure: `copy.deepcopy` treats plain functions as atomic, so a closure would
    # keep pointing at the ORIGINAL model and a deep copy would silently run the
    # original's layers. A `partial` is deep-copied through its arguments, so the
    # copy's forward is rebound to the copy. (A local closure is also unpicklable.)
    model.forward = partial(_multiscale_forward, model)
    return model

In [None]:
class ParallelUnet(nn.Module):
    """Model assembled from the stages of an already-built `encoder`.

    In the code visible here, `__init__` only registers the encoder's `stem` and
    `layer1`..`layer4` as submodules of this model. `n_out`, `img_size` and the
    remaining options (`blur`, `blur_final`, `self_attention`, `y_range`,
    `last_cross`, `bottle`, `act_cls`, `init`, `norm_type`, `**kwargs`) are not
    used yet, and no `forward` is defined in the visible code.
    NOTE(review): looks unfinished — presumably a decoder (e.g. of `UnetBlock3D`s)
    is meant to be added; confirm before using this class.

    Args:
        encoder: module exposing `stem` and `layer1`..`layer4` attributes.
    """
    def __init__(self, encoder, n_out, img_size, blur=False, blur_final=True, self_attention=False,
                 y_range=None, last_cross=True, bottle=False, act_cls=defaults.activation,
                 init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        super(ParallelUnet, self).__init__()

        # The encoder's stage modules are shared (same objects, same parameters),
        # not copied, so training this model also updates `encoder`.
        self.stem = encoder.stem
        self.layer1 = encoder.layer1
        self.layer2 = encoder.layer2
        self.layer3 = encoder.layer3
        self.layer4 = encoder.layer4
        
    
        
    