In [None]:
#default_exp models
from nbdev.showdoc import show_doc

# Models

> PyTorch segmentation models.

In [None]:
#export
import torch
import segmentation_models_pytorch as smp

In [None]:
#hide
from fastcore.test import *

In [None]:
#export
# Architecture names accepted by this module, see
# https://github.com/qubvel/segmentation_models.pytorch#architectures-
# ('MAnet' is currently left out of the list.)
ARCHITECTURES = [
    'Unet', 'UnetPlusPlus', 'FPN', 'PAN',
    'PSPNet', 'Linknet', 'DeepLabV3', 'DeepLabV3Plus',
]  # 'MAnet',

# Names of every encoder registered in segmentation_models_pytorch, see
# https://github.com/qubvel/segmentation_models.pytorch#encoders-
ENCODERS = list(smp.encoders.encoders.keys())

## Segmentation Models Pytorch Integration

From the website: 

- High level API (just two lines to create a neural network)
- 9 model architectures for binary and multi-class segmentation (including the legendary Unet)
- 104 available encoders
- All encoders have pre-trained weights for faster and better convergence

See https://github.com/qubvel/segmentation_models.pytorch for API details.

In [None]:
#export
def get_pretrained_options(encoder_name):
    """List the pretrained-weight options registered for `encoder_name`.

    The options are the keys of the encoder's 'pretrained_settings' entry in
    `smp.encoders.encoders`. `None` is appended as the last element
    (presumably meaning "no pretrained weights" -- confirm against smp).
    """
    encoder_cfg = smp.encoders.encoders[encoder_name]
    weight_names = list(encoder_cfg['pretrained_settings'].keys())
    weight_names.append(None)
    return weight_names

In [None]:
#export 
def create_smp_model(arch, **kwargs):
    """Create a segmentation_models_pytorch model.

    `arch` must be one of `ARCHITECTURES`. The class with the same name is
    looked up on the `smp` package and instantiated with `**kwargs`
    (e.g. `encoder_name`, `encoder_weights`, `in_channels`, `classes`).

    The keyword arguments are stored on the returned model as `model.kwargs`,
    which `save_smp_model` reads when writing a checkpoint.

    Raises `AssertionError` if `arch` is not in `ARCHITECTURES` and
    `NotImplementedError` if `smp` exposes no attribute of that name.
    """
    assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'

    # The names in ARCHITECTURES are the smp class names, so resolve the class
    # by name instead of keeping a parallel if/elif chain in sync with the list.
    model_cls = getattr(smp, arch, None)
    if model_cls is None:
        raise NotImplementedError(f'{arch} is not provided by segmentation_models_pytorch')

    model = model_cls(**kwargs)
    # Remember the constructor arguments so the model can be rebuilt from a checkpoint.
    model.kwargs = kwargs
    return model

In [None]:
# Tests: every architecture must map a (bs, in_c, ts, ts) batch
# to an output of shape (bs, c, ts, ts).
bs = 2
tile_shapes = [512] #1024
in_channels = [1] #1,3,4
classes = [2] # 2,5
encoders = ENCODERS[1:2]#+ENCODERS[-1:]

# Enumerate the input and model configurations up front (same order as nested loops).
input_configs = [(ts, in_c, c)
                 for ts in tile_shapes
                 for in_c in in_channels
                 for c in classes]
model_configs = [(arch, encoder_name)
                 for arch in ARCHITECTURES
                 for encoder_name in encoders]

for ts, in_c, c in input_configs:
    inp = torch.randn(bs, in_c, ts, ts)
    out_shape = [bs, c, ts, ts]
    for arch, encoder_name in model_configs:
        model = create_smp_model(arch=arch,
                                 encoder_name=encoder_name,
                                 encoder_weights=None,
                                 in_channels=in_c,
                                 classes=c)
        out = model(inp)
        test_eq(out.shape, out_shape)

In [None]:
#export
def save_smp_model(model, arch, file, stats=None, pickle_protocol=2):
    """Write an smp model checkpoint to `file` using `torch.save`.

    The checkpoint is one dict holding the model's `state_dict()` under
    'model', the architecture name under 'arch', `stats` under 'stats', and
    every entry of `model.kwargs` as a further top-level key (entries of
    `model.kwargs` win if a key is repeated).

    `pickle_protocol` is forwarded to `torch.save`; the legacy (non-zipfile)
    serialization format is requested.
    """
    checkpoint = dict(model=model.state_dict(), arch=arch, stats=stats)
    checkpoint.update(model.kwargs)
    torch.save(checkpoint, file,
               pickle_protocol=pickle_protocol,
               _use_new_zipfile_serialization=False)

In [None]:
# Fixture for the save/load round trip: build a Unet and write it to disk.
arch, file = 'Unet', 'tst.pth'
stats = (1, 1)
kwargs = dict(encoder_name='resnet34')
tst = create_smp_model(arch, **kwargs)
save_smp_model(tst, arch, file, stats=stats)

In [None]:
#export
def load_smp_model(file, device=None, strict=True, **kwargs):
    """Rebuild an smp model from a checkpoint file; returns `(model, stats)`.

    `file` is read with `torch.load` and must contain the keys 'model'
    (the state dict) and 'stats'. All remaining entries (e.g. 'arch') are
    passed to `create_smp_model` as keyword arguments.

    `device` determines the `map_location` used for loading: `None` means
    'cpu', an int is treated as a CUDA device index, anything else is passed
    through unchanged. `strict` is handed to `load_state_dict`.
    Additional `**kwargs` are accepted but not used by this function.
    """
    if device is None:
        device = 'cpu'
    elif isinstance(device, int):
        device = torch.device('cuda', device)

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(file, map_location=device)
    weights = checkpoint.pop('model')
    stats = checkpoint.pop('stats')

    # Whatever is left in the checkpoint are the model constructor arguments.
    net = create_smp_model(**checkpoint)
    net.load_state_dict(weights, strict=strict)
    return net, stats

In [None]:
# The reloaded model must carry the same parameters and stats as the saved one.
tst2, stats2 = load_smp_model(file)
param_pairs = zip(tst.parameters(), tst2.parameters())
for saved_param, loaded_param in param_pairs:
    test_eq(saved_param.detach(), loaded_param.detach())
test_eq(stats, stats2)

# Export -

In [None]:
#hide
# Import only what is needed instead of `import *`, so the cell does not
# pollute the notebook namespace; then run nbdev's notebook-to-module export.
from nbdev.export import notebook2script
notebook2script()

Converted 00_learner.ipynb.
Converted 01_models.ipynb.
Converted 02_data.ipynb.
Converted 03_metrics.ipynb.
Converted 05_losses.ipynb.
Converted 06_utils.ipynb.
Converted 07_tta.ipynb.
Converted 08_gui.ipynb.
Converted 09_gt.ipynb.
Converted add_information.ipynb.
Converted deepflash2.ipynb.
Converted gt_estimation.ipynb.
Converted index.ipynb.
Converted model_library.ipynb.
Converted predict.ipynb.
Converted train.ipynb.
Converted tutorial.ipynb.
