In [1]:
from collections import OrderedDict

import imageio
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import torch
from skimage import img_as_ubyte
from skimage.color import gray2rgb

In [None]:
def remove_module_from_state_dict(state_dict):
    """Strip the ``module.`` prefix found in checkpoints saved from parallel-wrapped models.

    The layout is detected from the first key of ``state_dict``:

    * ``module.<name>``: the leading ``module.`` is removed from every key
      that carries it; keys without the prefix are kept unchanged.
    * ``features.module.<name>``: the embedded ``module.`` is removed from
      every key that carries it, other ``features.*`` keys are kept
      unchanged, and keys outside ``features.`` are dropped.
    * anything else (including an empty mapping): the input is returned as is.

    Args:
        state_dict: mapping from parameter name (str) to value.

    Returns:
        A new ``OrderedDict`` with cleaned keys (original order preserved), or
        the original ``state_dict`` object when nothing had to be changed.
    """
    # An empty mapping has no first key to inspect.
    if not state_dict:
        return state_dict

    plain_prefix = 'module.'
    nested_root = 'features.'
    nested_prefix = nested_root + plain_prefix

    first_key = next(iter(state_dict))

    if first_key.startswith(plain_prefix):
        # Only strip keys that really start with the prefix; slicing blindly
        # would corrupt any key that does not.
        return OrderedDict(
            (key[len(plain_prefix):] if key.startswith(plain_prefix) else key, value)
            for key, value in state_dict.items()
        )

    if first_key.startswith(nested_prefix):
        cleaned = OrderedDict()
        for key, value in state_dict.items():
            if key.startswith(nested_prefix):
                cleaned[nested_root + key[len(nested_prefix):]] = value
            elif key.startswith(nested_root):
                cleaned[key] = value
            # keys outside `features.` are intentionally left out of the result
        return cleaned

    return state_dict


In [None]:
def load_model(model_path, class_values=None):
    """Rebuild a segmentation_models_pytorch model from a checkpoint file.

    The checkpoint is read with ``torch.load`` and must contain the keys
    ``'decoder'`` (name of the model class looked up on ``smp``),
    ``'encoder'`` (encoder name passed to that class) and ``'state_dict'``.

    Args:
        model_path: path to the checkpoint file.
        class_values: optional mapping of class name to label index. Only its
            length is used (number of output classes and choice of
            activation). Defaults to background / oxide / crack.

    Returns:
        ``(model, preprocessing_fn)``: the model in eval mode, placed on
        ``cuda:0`` when CUDA is available and on the CPU otherwise, and the
        preprocessing function matching the encoder.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    # map_location lets a checkpoint saved on a GPU be opened on a CPU-only host.
    model_data = torch.load(model_path, map_location=device)
    decoder = model_data['decoder']
    encoder = model_data['encoder']

    if class_values is None:
        class_values = {'background': 0,
                        'oxide': 1,
                        'crack': 2}

    # 'softmax2d' for multiclass segmentation, 'sigmoid' for a single class
    activation = 'softmax2d' if len(class_values) > 1 else 'sigmoid'

    # Fall back to the 'imagenet+5k' settings for encoders that have no
    # plain 'imagenet' entry.
    try:
        preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')
    except ValueError:
        preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet+5k')

    # encoder_weights=None: the weights come from the checkpoint below.
    model = getattr(smp, decoder)(encoder_name=encoder,
                                  encoder_weights=None,
                                  classes=len(class_values),
                                  activation=activation)

    model.load_state_dict(remove_module_from_state_dict(model_data['state_dict']))
    # Put the model on the same device that inference selects by default,
    # so inputs and weights do not end up on different devices.
    model.to(device)
    model.eval()
    return model, preprocessing_fn

In [1]:
# https://github.com/choosehappy/PytorchDigitalPathology
def segmentation_models_inference(io, model, preprocessing_fn, device = None, batch_size = 8, patch_size = 512,
                                  num_classes=3, probabilities=None):
    """Segment an image by running the model on overlapping tiles.

    The image is reflect-padded, cut into ``patch_size`` x ``patch_size``
    tiles taken every ``patch_size // 2`` pixels, and each tile is passed
    through ``model.predict``. Only the central half of every predicted tile
    is kept, and the centres are stitched back into a full-size mask.

    The first class (index 0) is not returned; it is assumed to be wherever
    the other classes are not.

    Args:
        io: image array of shape (rows, cols, channels).
        model: object exposing ``to(device)`` and ``predict(batch)``, where
            ``predict`` returns a tensor of shape
            (batch, num_classes, patch_size, patch_size).
        preprocessing_fn: callable applied to ``io`` before tiling.
        device: torch device for the model and inputs. Defaults to ``cuda:0``
            when available, else the CPU.
        batch_size: number of tiles per forward pass.
        patch_size: tile edge length; must be a positive multiple of 4 so
            that the kept tile centres line up exactly.
        num_classes: number of channels produced by the model.
        probabilities: ``None`` to round the model output, or an indexable
            of per-class thresholds where ``probabilities[i]`` applies to
            class ``i + 1``.

    Returns:
        Boolean array of shape (rows, cols, num_classes - 1).

    Raises:
        ValueError: for an invalid ``batch_size`` / ``patch_size`` or when
            the model output does not have ``num_classes`` channels.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    if patch_size < 4 or patch_size % 4 != 0:
        raise ValueError("patch_size must be a positive multiple of 4")

    io = preprocessing_fn(io)
    orig_rows, orig_cols = io.shape[:2]
    stride_size = patch_size // 2
    margin = stride_size // 2  # border of each predicted tile that is discarded

    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The model must live on the same device as the tiles sent to it.
    model.to(device)

    # add half the stride as padding around the image, so that we can crop it away later
    io = np.pad(io, [(margin, margin), (margin, margin), (0, 0)], mode="reflect")

    # pad to an exact multiple of the patch size, otherwise the last row/column is lost
    npad0 = -io.shape[0] % patch_size
    npad1 = -io.shape[1] % patch_size
    io = np.pad(io, [(0, npad0), (0, npad1), (0, 0)], mode="constant")

    # top-left corners of all tiles, in row-major order
    row_starts = range(0, io.shape[0] - patch_size + 1, stride_size)
    col_starts = range(0, io.shape[1] - patch_size + 1, stride_size)
    origins = [(r, c) for r in row_starts for c in col_starts]

    # Tiles are materialised one batch at a time and results are collected in
    # a list, so neither the full tile stack nor repeated array copies are needed.
    batch_outputs = []
    with torch.no_grad():
        for first in range(0, len(origins), batch_size):
            batch_arr = np.stack([io[r:r + patch_size, c:c + patch_size]
                                  for r, c in origins[first:first + batch_size]])
            arr_out_gpu = torch.from_numpy(batch_arr.transpose(0, 3, 1, 2).astype('float32')).to(device)

            # ---- get results and pull them back from the device
            output_batch = model.predict(arr_out_gpu).detach().cpu().numpy()
            if output_batch.shape[1] != num_classes:
                raise ValueError("model returned %d channels, expected num_classes=%d"
                                 % (output_batch.shape[1], num_classes))
            if probabilities is None:
                output_batch = output_batch.round()
            batch_outputs.append(output_batch)

    output = np.concatenate(batch_outputs, axis=0)
    output = output.transpose((0, 2, 3, 1))

    # turn from a single list into a matrix of tiles
    output = output.reshape(len(row_starts), len(col_starts), patch_size, patch_size, output.shape[3])

    # remove the padding from each tile, we only keep the center
    output = output[:, :, margin:patch_size - margin, margin:patch_size - margin, :]

    # turn all the tiles into an image
    output = np.concatenate(np.concatenate(output, 1), 1)

    # in case there was extra padding to get a multiple of patch size, crop back
    output = output[0:orig_rows, 0:orig_cols, :]

    foreground = output[:, :, 1:]  # the background class is not reported
    if probabilities is None:
        return foreground.astype('bool')

    # Compare in float64 so thresholds are not rounded to float32 first.
    thresholds = np.asarray([probabilities[i] for i in range(num_classes - 1)], dtype='float64')
    return foreground > thresholds

In [None]:
# Load the trained model together with the preprocessing that matches its encoder.
model_path = './segmentation_models/ebc_exp1/Unet__inceptionresnetv2__microscopynet__200__0.981.pth.tar'
model, preprocessing_fn = load_model(model_path)

# Read the image and bring it to 8-bit RGB before inference.
im_path = 'image.tif'
im = imageio.imread(im_path)
im = gray2rgb(im)  # convert to color
im = img_as_ubyte(im)

segmentation = segmentation_models_inference(im, model, preprocessing_fn, batch_size=4, patch_size=512,
                                             probabilities=None)

# One boolean mask per non-background class (last axis of `segmentation`).
for i in range(segmentation.shape[2]):
    plt.imshow(segmentation[:,:,i], cmap=plt.cm.gray)
    plt.show()