# Facies segmentation Python Demo

This demo demonstrates how to run facies classification using OpenVINO&trade;.

This model comes from seismic interpretation tasks. Facies are the overall characteristics of a rock unit that reflect its origin and differentiate the unit from others around it. Mineralogy and sedimentary source, fossil content, sedimentary structures, and texture distinguish one facies from another. The data are provided as 3D arrays.

## Demo Output

The application uses a Jupyter notebook to display a 3D itkwidgets viewer with the resulting facies classification masks.

## How It Works

On start-up, the demo application loads a network and a given dataset file into the Inference Engine plugin. When inference is done, the application displays a 3D itkwidgets viewer with the facies interpretation.

## Installation of dependencies

### Set up the environment

Step 1: Install the required packages with the following command:

In [None]:
# Install the prerequisites for the demo with the pip that belongs to the
# interpreter running this kernel (sys.executable), so packages land in the
# same environment the notebook imports from.
import sys
!{sys.executable} -m pip install -r requirements.txt

In [None]:
import sys
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import ipywidgets as widgets

from tqdm import tqdm
from itkwidgets import view
from collections import defaultdict

### Download model:

Step 1: Create the model folder and download the model files:

In [None]:
# NOTE(review): every `!` line runs in its own subshell, so this `cd` does not
# change the directory used by the following commands or later cells; the
# files below are written relative to the notebook's current working directory.
! cd notebooks/202-facies-segmentation
# `mkdir` prints an error if `model` already exists (e.g. on a re-run); the
# downloads below still run.
! mkdir model
# Fetch the OpenVINO IR files (.bin weights, .xml topology) and the .mapping file.
! wget -O model/facies-segmentation-deconvnet.bin https://www.dropbox.com/s/x0c7ao8kebxykj1/facies-segmentation-deconvnet.bin?dl=1 
! wget -O model/facies-segmentation-deconvnet.xml https://www.dropbox.com/s/g288xdcd7xumqm7/facies-segmentation-deconvnet.xml?dl=1
! wget -O model/facies-segmentation-deconvnet.mapping https://www.dropbox.com/s/a7kge25hfpjnhvf/facies-segmentation-deconvnet.mapping?dl=1

### Download dataset:

The dataset comes from the facies classification benchmark: https://github.com/yalaudah/facies_classification_benchmark

In [None]:
# Download dataset
! mkdir data
! wget -O data/test2_seismic.npy https://www.dropbox.com/s/sbj2atyukpjgssx/test2_seismic.npy?dl=1

### Define useful functions

In [None]:
def get_config():
    """Build the settings used throughout this demo.

    Returns:
        collections.defaultdict(str): the model IR path, the input volume path,
        the six facies class names, two opposite corners of the crop box and
        the inference device. Keys that were never set read as ``""``.
    """
    settings = defaultdict(str)
    settings["model"] = 'model/facies-segmentation-deconvnet.xml'
    settings["data_path"] = 'data/test2_seismic.npy'
    settings["name_classes"] = [
        'upper_ns', 'middle_ns', 'lower_ns',
        'rijnland_chalk', 'scruff', 'zechstein',
    ]
    settings["edge_one"] = (0, 30, 0)
    settings["edge_two"] = (500, 199, 244)
    settings["device"] = "CPU"
    return settings

In [None]:
def normalize(data, mu=0, std=1):
    """Standardize *data* over all of its elements, then rescale it.

    The values are shifted to zero mean and scaled to unit standard deviation
    (statistics taken over every element), then mapped to mean ``mu`` and
    standard deviation ``std``.

    Args:
        data: ndarray or array-like of numbers; non-ndarray input is converted
            with ``np.array``.
        mu: target mean of the result.
        std: target standard deviation of the result.

    Returns:
        An array with the same shape as *data*. If every element is equal
        (zero spread), the result is filled with ``mu`` instead of NaN.
    """
    if not isinstance(data, np.ndarray):
        data = np.array(data)
    # ravel() yields the same values as flatten() but avoids copying a
    # contiguous volume just to compute two statistics.
    flat = data.ravel()
    mean, spread = flat.mean(), flat.std()
    centered = data - mean
    if spread == 0:
        # All values are identical, so the standardized data is all zeros;
        # skip the 0 / 0 division that would otherwise produce NaN everywhere.
        return centered * std + mu
    return centered / spread * std + mu

def _data_format(path):
    """Return the extension of *path* (without the dot), validating the name."""
    import os

    # Split only the file name, so dots in directories ("./data", "v1.2/")
    # are not mistaken for the extension.
    stem, extension = os.path.splitext(os.path.basename(path))
    data_format = extension[1:]
    if not stem or not data_format:
        raise AssertionError(f'Invalid path to data file: {path}')
    return data_format


def _read_volume(path, data_format):
    """Read the raw array stored at *path* according to *data_format*."""
    if data_format == 'npy':
        return np.load(path)
    if data_format == 'dat':
        # NOTE(review): np.fromfile returns a flat array with no shape
        # information, so .dat input will fail the 3-D check in the caller.
        return np.fromfile(path)
    if data_format == 'segy':
        import segyio  # optional dependency, only needed for SEG-Y input
        volume = segyio.tools.cube(path)
        volume = np.moveaxis(volume, -1, 0)
        return np.ascontiguousarray(volume, 'float32')
    raise AssertionError(f'Unsupported data format: {data_format}')


def _crop_volume(data, edge_one, edge_two):
    """Return the box of *data* spanned by two corner points (upper end exclusive)."""
    if data.ndim != 3:
        raise ValueError(f"Expected a 3-D volume, got shape {data.shape}")
    lows = [min(edge_one[i], edge_two[i]) for i in range(3)]
    highs = [max(edge_one[i], edge_two[i]) for i in range(3)]
    if min(lows) < 0:
        raise AssertionError("Invalid edges")
    # NOTE(review): the bound check allows highs up to shape - 1 while the
    # slice below excludes the high index, so the last index of each axis is
    # never included; confirm this is intended.
    if any(high >= limit for high, limit in zip(highs, data.shape)):
        raise AssertionError("Invalid edges")
    (x_min, y_min, z_min), (x_max, y_max, z_max) = lows, highs
    return data[x_min:x_max, y_min:y_max, z_min:z_max]


def load_data(config):
    """Load the seismic volume named in *config*, normalize it and crop it.

    Args:
        config: mapping with ``"data_path"`` (str or os.PathLike to a ``.npy``,
            ``.dat`` or ``.segy`` file) and ``"edge_one"`` / ``"edge_two"``
            (two opposite corners of the crop box as ``(x, y, z)`` indices).

    Returns:
        np.ndarray: the normalized sub-volume
        ``data[x_min:x_max, y_min:y_max, z_min:z_max]``.

    Raises:
        AssertionError: invalid path, unsupported extension, or corners outside
            the volume. Raised explicitly so the checks survive ``python -O``.
        ValueError: the loaded data is not three-dimensional.
    """
    import os

    path = os.fspath(config["data_path"])
    data = _read_volume(path, _data_format(path))

    data = normalize(data, mu=1e-8, std=0.2097654)
    print(f"[INFO] Dataset has been loaded, shape is {data.shape}")
    print(f"[INFO] Dataset mean is {data.ravel().mean():.5f}, std {data.ravel().std():.5f}")

    return _crop_volume(data, config["edge_one"], config["edge_two"])

In [None]:
def reshape_model(net, shape, axis=None):
    """Reshape the network input so it accepts one 2-D slice of a volume.

    The slice shape is *shape* with dimension ``axis`` removed (the smallest
    dimension when ``axis`` is None). If it differs from the last two
    dimensions of the network's first input, the network is reshaped to
    ``[1, 1, *slice_shape]``. Progress is reported with ``print``.

    Args:
        net: network object exposing ``input_info`` and ``reshape`` (the
            Inference Engine ``IENetwork`` in this notebook).
        shape: shape of the full volume, e.g. ``data.shape``.
        axis: the axis that will be sliced during inference.
    """
    index_of_dim = np.argmin(shape) if axis is None else axis
    # Plain ints, so neither the log line nor the reshape request carries
    # NumPy scalar types.
    input_data_shape = [int(dim) for dim in shape]
    del input_data_shape[index_of_dim]

    input_net_info = net.input_info
    input_name = next(iter(input_net_info))
    input_net_shape = input_net_info[input_name].input_data.shape

    print(f"[INFO] Infer should be on {input_data_shape} resolution")
    # Compare list to list: a list never equals a tuple, so a tuple-valued
    # network shape would otherwise force a needless reshape every time.
    if input_data_shape != list(input_net_shape[-2:]):
        net.reshape({input_name: [1, 1, *input_data_shape]})
        print(f"[INFO] Reshaping model to fit for slice shape: {input_data_shape}")
    else:
        print("[INFO] Use not reshaped model")

In [None]:
def infer_cube(exec_net, data, axis=None, input_name='input', output_name='output'):
    """Segment a 3-D volume slice by slice and return per-voxel class indices.

    Every 2-D slice across ``axis`` is sent through ``exec_net.infer``. The
    argmax over axis 1 of the returned scores is written back into the matching
    slice of the result. A tqdm progress bar tracks the slices.

    Args:
        exec_net: executable network whose ``infer(inputs=...)`` returns a
            mapping from output names to arrays.
        data: 3-D array to segment.
        axis: axis to slice across (0, 1 or 2; -3..-1 count from the end).
            Defaults to the shortest axis of ``data``.
        input_name: network input that receives each slice.
        output_name: network output holding the class scores. NOTE(review):
            the argmax over axis 1 assumes a (batch, classes, H, W)-style
            layout; confirm against the model.

    Returns:
        np.ndarray (float64) of ``data.shape`` with the predicted class indices.

    Raises:
        ValueError: if ``axis`` is not a valid axis of a 3-D volume. Previously
            such an axis silently returned an uninitialized array.
    """
    index_of_dim = int(np.argmin(data.shape)) if axis is None else axis
    if not -3 <= index_of_dim < 3:
        raise ValueError(f"axis must be in [-3, 2], got {axis}")
    index_of_dim %= 3

    predicted_cube = np.empty(data.shape)
    for slice_index in tqdm(range(data.shape[index_of_dim])):
        # Equivalent to data[i, :, :], data[:, i, :] or data[:, :, i].
        selector = [slice(None)] * 3
        selector[index_of_dim] = slice_index
        selector = tuple(selector)
        scores = exec_net.infer(inputs={input_name: data[selector]})[output_name]
        predicted_cube[selector] = np.argmax(scores, axis=1).squeeze()
    return predicted_cube

In [None]:
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Args:
        N: number of discrete colors.
        base_cmap: colormap name, ``Colormap`` instance, or None for
            matplotlib's default colormap.

    Returns:
        A ``LinearSegmentedColormap`` named ``<base name><N>`` with N colors
        sampled evenly from *base_cmap*.
    """
    try:
        # Colormap-registry lookup (matplotlib >= 3.6). plt.cm.get_cmap was
        # deprecated in 3.7 and removed in 3.9.
        base = mpl.colormaps.get_cmap(base_cmap)
    except AttributeError:
        # Older matplotlib without the registry's get_cmap method.
        base = plt.cm.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    # from_list lives on LinearSegmentedColormap only. Calling it through
    # `base` failed for ListedColormap bases such as the default 'viridis'.
    return mpl.colors.LinearSegmentedColormap.from_list(cmap_name, color_list, N)

def print_legend_colors(N, cmap_name):
    """Print the N RGBA colors sampled evenly from colormap *cmap_name*.

    This was originally a first ``show_legend(N, cmap_name)``. The
    ``show_legend(labels, cmap)`` defined just below in the same cell replaced
    it before it could ever be called, so it now has its own name.
    """
    try:
        # Colormap-registry lookup (matplotlib >= 3.6); plt.cm.get_cmap was
        # removed in matplotlib 3.9.
        base = mpl.colormaps.get_cmap(cmap_name)
    except AttributeError:
        base = plt.cm.get_cmap(cmap_name)
    color_list = base(np.linspace(0, 1, N))
    print(color_list)
    
def show_legend(labels, cmap):
    """Draw a horizontal colorbar that names each discrete color of *cmap*.

    Tick ``k`` sits at the middle of the k-th of ``len(labels)`` equal-width
    bins along the bar and is labelled ``labels[k]``. NOTE(review): this
    assumes the colorbar spans [0, 1], since no norm is passed.
    """
    count = len(labels)
    figure = plt.figure(figsize=(12, 6))
    legend_axes = figure.add_axes([0.05, 0.80, 0.9, 0.15])
    # Bin midpoints: k/count + half a bin width.
    tick_positions = np.arange(0, count, 1) / count + 1 / (2 * count)
    colorbar = mpl.colorbar.ColorbarBase(
        legend_axes,
        cmap=cmap,
        ticks=tick_positions,
        orientation='horizontal',
    )
    colorbar.ax.set_xticklabels(labels, fontsize=20)
    colorbar.set_label('Legend', fontsize=24)
    plt.show()

### Get config and load dataset

In [None]:
# Read the demo settings, then load, normalize and crop the volume they name.
config = get_config()
data = load_data(config=config)

## Running

### Load model

In [None]:
from openvino_extensions import get_extensions_path
from openvino.inference_engine import IECore

In [None]:
# Create the Inference Engine core, register the extension library on CPU,
# then read the IR model whose .xml path is stored in the config.
ie = IECore()
ie.add_extension(extension_path=get_extensions_path(), device_name="CPU")
net = ie.read_network(model=config["model"])

### Prepare model

In [None]:
# Match the network input to a slice taken across axis 1, then compile the
# network for the configured device.
reshape_model(net, shape=data.shape, axis=1)
exec_net = ie.load_network(net, config["device"])

### Run model

In [None]:
# Segment the volume one axis-1 slice at a time (same axis the model was reshaped for).
predicted_data = infer_cube(exec_net=exec_net, data=data, axis=1)

The cell above runs model inference: slices along axis 1 are fed to the network one at a time, and the predictions are combined into an output cube (`predicted_data`).

### Visualize original and predicted data

* Prepare the original data viewer

In [None]:
count_of_greys = 100
viewer_orig_data = view(data, shadow=False)
# Grey ramp starting at black: row i is the RGB triple (i/N, i/N, i/N) for
# i = 0 .. N-1, with N = count_of_greys.
viewer_orig_data.cmap = np.repeat(
    np.arange(count_of_greys)[:, np.newaxis] / count_of_greys, 3, axis=1)

* Prepare predicted data viewer

In [None]:
# Viewer for the predicted classes, coloured with one 'jet' colour per facies
# class, plus a legend naming each colour.
viewer_interpret_data = view(predicted_data, shadow=False)
cmap = discrete_cmap(len(config["name_classes"]), 'jet')
viewer_interpret_data.cmap = cmap
show_legend(config["name_classes"], cmap)
# Keep the two viewers' camera traits in sync, so rotating or zooming one
# moves the other.
widgets.link((viewer_interpret_data, 'camera'), (viewer_orig_data, 'camera'))
pass  # presumably ends the cell on a statement so the link's return value is not echoed

* Run render

In [None]:
# Render the predicted-facies viewer (the cell's last expression is displayed).
viewer_interpret_data

In [None]:
# Render the raw seismic viewer; its camera is linked to the predicted-facies viewer.
viewer_orig_data

You can now see a visualization of the interpreted and raw seismic data. Use your mouse to rotate, zoom in, and interactively explore the interpreted data.