In [None]:
#default_exp models
from nbdev.showdoc import show_doc

# Models

> Pytorch segmentation models.

In [None]:
#export
import torch, numpy as np
import cv2
import segmentation_models_pytorch as smp
from fastcore.basics import patch
from fastdownload import download_url
from pathlib import Path
import sys, subprocess
from pip._internal.operations import freeze
from transformers import SegformerForSemanticSegmentation

In [None]:
#hide
from fastcore.test import *

In [None]:
#export
# Architecture names accepted by `create_smp_model`. 'SegFormer' is the local
# Hugging Face wrapper class; the other names map to segmentation_models_pytorch
# model classes. 'MAnet' is currently left out of the list.
# https://github.com/qubvel/segmentation_models.pytorch#architectures-
ARCHITECTURES =  ['SegFormer', 'Unet', 'UnetPlusPlus', 'FPN', 'PAN', 'PSPNet', 'Linknet', 'DeepLabV3', 'DeepLabV3Plus'] #'MAnet',

# Names of all encoders registered in segmentation_models_pytorch.
# https://github.com/qubvel/segmentation_models.pytorch#encoders-
ENCODERS = [*smp.encoders.encoders.keys()]

## Transformers/Huggingface Integration

In [None]:
#export
class SegFormer(torch.nn.Module):
    """Wrapper around Hugging Face `SegformerForSemanticSegmentation`.

    The wrapped model is created from the "nvidia/mit-b0" checkpoint with
    `ignore_mismatched_sizes=True`, so the number of input channels and output
    labels can differ from the checkpoint.

    Args:
        classes: number of output labels (`num_labels`).
        in_channels: number of input channels (`num_channels`).
        **kwargs: accepted and ignored, so that the same keyword arguments
            used for the other architectures (e.g. `encoder_name`) can be passed.
    """

    def __init__(self, classes=2, in_channels=1, **kwargs):
        super().__init__()
        self.segformer = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/mit-b0",
            ignore_mismatched_sizes=True,
            num_channels=in_channels,
            num_labels=classes,
            #id2label=id2label, label2id=label2id,
            reshape_last_stage=True)

    def forward(self, pixel_values):
        """Return logits resized to the spatial size of `pixel_values`.

        The previous implementation upsampled with a fixed `scale_factor=4`,
        which yields an output larger than the input whenever the input height
        or width is not a multiple of 4. Interpolating to the input size gives
        the same result for multiples of 4 and a matching shape otherwise.
        """
        outputs = self.segformer(pixel_values=pixel_values)
        target_size = pixel_values.shape[-2:]
        return torch.nn.functional.interpolate(outputs.logits, size=target_size)

## Segmentation Models Pytorch Integration

From the website: 

- High level API (just two lines to create a neural network)
- 9 models architectures for binary and multi class segmentation (including legendary Unet)
- 104 available encoders
- All encoders have pre-trained weights for faster and better convergence

See https://github.com/qubvel/segmentation_models.pytorch for API details.

In [None]:
#export
def get_pretrained_options(encoder_name):
    'List the pretrained weight names registered for `encoder_name`, followed by `None`'
    pretrained_settings = smp.encoders.encoders[encoder_name]['pretrained_settings']
    choices = list(pretrained_settings.keys())
    # `None` is appended as an extra option after the registered weight names.
    choices.append(None)
    return choices

In [None]:
#export 
def create_smp_model(arch, **kwargs):
    'Build the model class named `arch` with `kwargs` and remember `kwargs` on the model'

    assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'

    # Architectures whose class lives in segmentation_models_pytorch under the same name.
    smp_class_names = ('Unet', 'UnetPlusPlus', 'MAnet', 'FPN', 'PAN', 'PSPNet',
                       'Linknet', 'DeepLabV3', 'DeepLabV3Plus')

    if arch in smp_class_names:
        model = getattr(smp, arch)(**kwargs)
    elif arch == "SegFormer":
        model = SegFormer(**kwargs)
    else:
        raise NotImplementedError

    # The constructor arguments are stored so `save_smp_model` can write them out.
    model.kwargs = kwargs
    return model

In [None]:
# Tests
# Smoke test: for every architecture, a forward pass on random input must return
# an output of shape [batch, classes, tile, tile], i.e. the same spatial size as the input.
bs = 2
# The trailing comments list further values — presumably alternatives that can be tried.
tile_shapes = [512] #1024
in_channels = [1] #1,3,4
classes = [2] # 2,5
# Only a single encoder (the second registered one) is exercised.
encoders = ENCODERS[1:2]#+ENCODERS[-1:]

for ts in tile_shapes:
    for in_c in in_channels:
        for c in classes:
            inp = torch.randn(bs, in_c, ts, ts)
            out_shape = [bs, c, ts, ts]
            for arch in ARCHITECTURES:
                for encoder_name in encoders:
                    # encoder_weights=None is passed so that no encoder weights are requested.
                    model = create_smp_model(arch=arch, 
                                             encoder_name=encoder_name,
                                             encoder_weights=None,
                                             in_channels=in_c, 
                                             classes=c)
                    out = model(inp)
                    test_eq(out.shape, out_shape)
# Drop the reference to the last model that was built.
del model

Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_fuse.weight', 'decode_head.batch_norm.weight', 'decode_head.linear_c.1.proj.weight', 'decode_head.batch_n

In [None]:
#export
def save_smp_model(model, arch, path, stats=None, pickle_protocol=2):
    'Write the model weights, `arch`, `stats` and the stored `model.kwargs` to `path`; return `path` as a `Path`'
    target = Path(path)
    checkpoint = dict(model=model.state_dict(), arch=arch, stats=stats)
    # The constructor kwargs are stored flat next to the fixed keys.
    checkpoint.update(model.kwargs)
    torch.save(
        checkpoint,
        target,
        pickle_protocol=pickle_protocol,
        _use_new_zipfile_serialization=False,
    )
    return target

In [None]:
# Build a Unet and save it to 'tst.pth'; `path` is rebound to the value
# returned by `save_smp_model`.
arch = 'Unet'
path = 'tst.pth'
stats = (1,1)
kwargs = {'encoder_name': 'resnet34'}
tst = create_smp_model(arch, **kwargs)
path = save_smp_model(tst, arch, path, stats=stats)

In [None]:
# Same as for the Unet, but with the SegFormer wrapper; this rebinds `arch`, `tst`
# and `path` and writes to the same file name.
# NOTE(review): `encoder_name` is accepted by SegFormer via **kwargs but not used
# there — presumably carried over from the Unet cell; confirm it is intended.
arch = 'SegFormer'
path = 'tst.pth'
stats = (1,1)
kwargs = {'encoder_name': 'resnet34'}
tst = create_smp_model(arch, **kwargs)
path = save_smp_model(tst, arch, path, stats=stats)

Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_fuse.weight', 'decode_head.batch_norm.weight', 'decode_head.linear_c.1.proj.weight', 'decode_head.batch_n

In [None]:
#export
def load_smp_model(path, device=None, strict=True, **kwargs):
    'Rebuild a model from a checkpoint file and return it together with the stored stats'
    # Resolve where the tensors are mapped to: 'cpu' by default, an int selects a CUDA device index.
    if device is None:
        map_location = 'cpu'
    elif isinstance(device, int):
        map_location = torch.device('cuda', device)
    else:
        map_location = device

    checkpoint = torch.load(Path(path), map_location=map_location)
    weights = checkpoint.pop('model')
    stats = checkpoint.pop('stats')

    # Whatever is left in the checkpoint ('arch' plus the saved kwargs) recreates the model.
    model = create_smp_model(**checkpoint)
    model.load_state_dict(weights, strict=strict)
    return model, stats

In [None]:
# Round trip: the reloaded model must have the same parameters and stats as the
# model that was saved last (`tst`).
tst2, stats2 = load_smp_model(path)
for p1, p2 in zip(tst.parameters(), tst2.parameters()):
    test_eq(p1.detach(), p2.detach())
test_eq(stats, stats2)
# Remove the temporary checkpoint file.
path.unlink()

Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_fuse.weight', 'decode_head.batch_norm.weight', 'decode_head.linear_c.1.proj.weight', 'decode_head.batch_n

## Cellpose integration

Cellpose is integrated for reliable cell and nucleus segmentation. Visit [cellpose](https://github.com/MouseLand/cellpose) for more information.

Cellpose integration for deepflash2 is tested on version 0.6.6.dev13+g316927e

In [None]:
#export
def check_cellpose_installation(show_progress=True):
    """Make sure the pinned cellpose build is installed, installing it if needed.

    The `pip freeze` entries are searched for one starting with 'cellpose' and
    its last 15 characters are compared with those of the pinned tarball name.
    If there is no such entry, the suffix differs, or the check itself fails,
    the tarball is downloaded to `~/.deepflash2` and installed with
    `pip install --no-deps`.

    Args:
        show_progress: forwarded to `download_url`.

    Raises:
        subprocess.CalledProcessError: if the pip installation fails.
    """
    tarball = 'cellpose-0.6.6.dev13+g316927e.tar.gz' # '316927eff7ad2201391957909a2114c68baee309'
    expected_suffix = tarball[-15:]
    try:
        installed = [x for x in freeze.freeze() if x.startswith('cellpose')]
        # Explicit comparison instead of `assert`, which is stripped under `python -O`.
        up_to_date = bool(installed) and installed[0][-15:] == expected_suffix
    except Exception:
        # Any failure while inspecting the environment is treated as "not installed";
        # KeyboardInterrupt/SystemExit are no longer swallowed.
        up_to_date = False
    if up_to_date:
        return

    print('Installing cellpose. Please wait.')
    home_dir = Path.home()/'.deepflash2'
    home_dir.mkdir(exist_ok=True, parents=True)
    url = f'https://github.com/matjesg/deepflash2/releases/download/0.1.4/{tarball}'
    file = download_url(url, home_dir, show_progress=show_progress)
    subprocess.check_call([sys.executable, "-m", "pip", "install", '--no-deps', file.as_posix()])

In [None]:
#export
def get_diameters(masks):
    'Mean (truncated to int) of the per-mask diameters computed by cellpose'
    from cellpose import utils

    def _mask_diameter(mask):
        # Label the 4-connected components, then take the first value returned by cellpose.
        labelled = cv2.connectedComponents(mask.astype('uint8'), connectivity=4)[1]
        return utils.diameters(labelled)[0]

    per_mask = [_mask_diameter(m) for m in masks]
    return int(np.array(per_mask).mean())

In [None]:
#export
def run_cellpose(probs, masks, model_type='nuclei', diameter=0, min_size=-1, gpu=True):
    """Run cellpose on deepflash2 predictions.

    Args:
        probs: iterable of inputs passed one by one to `Cellpose.eval`
            (presumably per-image probability maps — confirm against callers).
        masks: iterable of deepflash2 masks, paired with `probs` via `zip`.
        model_type: cellpose model type.
        diameter: object diameter; when 0 it is estimated with `get_diameters(masks)`.
        min_size: forwarded to `Cellpose.eval`.
        gpu: forwarded to the `Cellpose` constructor.

    Returns:
        List with the first value returned by `Cellpose.eval` for each input.
    """
    check_cellpose_installation()

    if diameter==0:
        diameter = get_diameters(masks)
    print(f'Using diameter of {diameter}')

    # `transforms` was missing from this import, so the patched `_compute_masks`
    # raised NameError whenever it was called with `resize` set.
    from cellpose import models, dynamics, utils, transforms

    @patch
    def _compute_masks(self:models.CellposeModel, dP, cellprob, p=None, niter=200,
                        flow_threshold=0.4, interp=True, do_3D=False, min_size=15, resize=None, **kwargs):
        """ compute masks using dynamics from dP and the deepflash2 mask """
        # `mask` is not a parameter: it is resolved from the enclosing `run_cellpose`
        # scope, i.e. the mask of the current iteration of the loop below.
        # `cellprob` is accepted but not used.
        if p is None:
            p = dynamics.follow_flows(-1 * dP * mask / 5., niter=niter, interp=interp, use_gpu=self.gpu)
        maski = dynamics.get_masks(p, iscell=mask, flows=dP, threshold=flow_threshold if not do_3D else None)
        maski = utils.fill_holes_and_remove_small_masks(maski, min_size=min_size)
        if resize is not None:
            maski = transforms.resize_image(maski, resize[0], resize[1],
                                            interpolation=cv2.INTER_NEAREST)
        return maski, p

    model = models.Cellpose(gpu=gpu, model_type=model_type)
    cp_masks = []
    for prob, mask in zip(probs, masks):
        cp_pred, _, _, _ = model.eval(prob,
                                       net_avg=True,
                                       augment=True,
                                       diameter=diameter,
                                       normalize=False,
                                       min_size=min_size,
                                       resample=True,
                                       channels=[0,0])
        cp_masks.append(cp_pred)
    return cp_masks

In [None]:
# Smoke test on one random 512x512 array; the mask is True wherever the value is > 0.
# A fixed diameter is passed, so `get_diameters` is not called.
probs = [np.random.rand(512,512)]
masks = [x>0. for x in probs]
cp_preds = run_cellpose(probs, masks, diameter=17.)
# The cellpose prediction must have the same shape as the input.
test_eq(probs[0].shape, cp_preds[0].shape)

Using diameter of 17.0
2021-12-08 13:48:42,625 [INFO] WRITING LOG OUTPUT TO /media/data/home/mag01ud/.cellpose/run.log
2021-12-08 13:48:43,818 [INFO] ** TORCH CUDA version installed and working. **
2021-12-08 13:48:43,818 [INFO] >>>> using GPU
2021-12-08 13:48:43,861 [INFO] ~~~ FINDING MASKS ~~~
2021-12-08 13:48:46,569 [INFO] >>>> TOTAL TIME 2.71 sec


# Export -

In [None]:
#hide
# Run nbdev's notebook2script; the recorded output below lists the notebooks it converted.
from nbdev.export import *
notebook2script()

Converted 00_learner.ipynb.
Converted 01_models.ipynb.
Converted 02_data.ipynb.
Converted 05_losses.ipynb.
Converted 06_utils.ipynb.
Converted 07_tta.ipynb.
Converted 08_gui.ipynb.
Converted 09_gt.ipynb.
Converted add_information.ipynb.
Converted index.ipynb.
Converted model_library.ipynb.
Converted tutorial.ipynb.
Converted tutorial_gt.ipynb.
Converted tutorial_pred.ipynb.
Converted tutorial_train.ipynb.
