# Run trained model

In [1]:
import json
import os
from typing import Tuple

import numpy as np
import skimage.graph as graph
import tensorflow as tf
from rockml.data.adapter.seismic.segy import PostStackDatum
from rockml.data.adapter.seismic.segy.poststack import PostStackAdapter2D
from rockml.data.pipeline import Pipeline
from rockml.data.transformations import Composer
from rockml.data.transformations.seismic.image import Crop2D, ScaleIntensity
from seisfast.io.horizon import Writer

import utils

# Pin TensorFlow to the first GPU when one is present. On a CPU-only machine the
# list is empty and indexing gpus[0] would raise IndexError, so skip the call.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')


In [10]:
def get_input_seed(label_img: tf.Tensor,
                   canvas: tf.Tensor,
                   seed_class: int,
                   column: int = 8) -> Tuple[int, tf.Variable]:
    """Find the first row whose label equals ``seed_class`` in one column and mark a seed.

    Args:
        label_img: label image indexed as ``label_img[row, column]``.
        canvas: values used to initialise the returned variable (it is copied,
            not modified).
        seed_class: label value searched for.
        column: column of ``label_img`` that is scanned (default 8, the value
            that used to be hard-coded).

    Returns:
        ``(input_seed, canvas_tile)``. ``input_seed`` is the first matching row,
        or 0 when no row matches. This is indistinguishable from a match in row 0.
        ``canvas_tile`` is a ``tf.Variable`` copy of ``canvas`` with
        ``[input_seed, 0]`` set to 1.
    """
    # Compare the whole column at once instead of one eager-tensor op per row.
    label_column = np.asarray(label_img[:, column])
    matches = np.flatnonzero(label_column == seed_class)
    input_seed = int(matches[0]) if matches.size else 0

    canvas_tile = tf.Variable(canvas)
    # The seed pixel is written in column 0 of the canvas (not in ``column``).
    canvas_tile[input_seed, 0].assign(1)
    return input_seed, canvas_tile


def get_prediction_tiles(predictions: np.ndarray) -> np.ndarray:
    """Collapse per-class scores into class indices.

    Returns the index of the largest value along axis 3. For example, a 4-D
    array of shape ``(b, h, w, c)`` yields an integer array of shape ``(b, h, w)``.
    """
    class_axis = 3
    scores = np.asarray(predictions)
    return scores.argmax(axis=class_axis)


def run_model(canvas: tf.Tensor,
              model: tf.keras.Model,
              inline_datum: Tuple[tf.Tensor, tf.Tensor],
              break_tiles_info) -> np.ndarray:
    """Propagate a horizon across a line by sliding a square window left to right.

    At each horizontal step, the window's features and the current canvas
    content are fed to ``model`` as a two-channel image. The argmax of the
    prediction then overwrites that canvas window. The vertical window position
    for the next step is re-centred on the first set canvas row of a look-ahead
    column, provided the jump is at most ``size`` rows.

    Args:
        canvas: initial canvas (presumably holding the seed pixel -- confirm);
            copied into a ``tf.Variable``.
        model: Keras model called on a ``(1, size, size, 2)`` input.
        inline_datum: ``(features, label)`` pair. Only the features are used,
            reshaped to ``(height, width)``.
        break_tiles_info: ``(seed, size, stride)``. These are the initial top row
            of the window, the window side length and the horizontal stride.

    Returns:
        The final canvas as a NumPy array.
    """
    seed, size, stride = break_tiles_info[0], break_tiles_info[1], break_tiles_info[2]
    half = size // 2

    feat, _ = inline_datum
    feat = tf.reshape(feat, [feat.shape[0], feat.shape[1]])
    height, width = feat.shape[0], feat.shape[1]

    canvas_tile = tf.Variable(canvas)

    v_seed = seed
    # Track the window *centre*; ``seed`` is the window's top row.
    old_v_center = seed + half

    # ``+ 1`` so the last full window (starting at ``width - size``) is predicted too.
    for h_seed in range(0, width - size + 1, stride):
        # Clamp before slicing. A negative start would wrap around (Python slice
        # semantics), and a start past ``height - size`` would give a truncated
        # window. Either case used to end the tracking prematurely.
        v_seed = max(0, min(v_seed, height - size))

        rows = slice(v_seed, v_seed + size)
        cols = slice(h_seed, h_seed + size)

        feat_window = feat[rows, cols]
        if feat_window.shape != (size, size):
            break

        # Two-channel input (features, current canvas) plus a batch axis:
        # shape (1, size, size, 2).
        canvas_window = tf.reshape(canvas_tile[rows, cols], [size, size])
        model_input = tf.stack([feat_window, canvas_window], axis=-1)[tf.newaxis, ...]

        prediction = model(model_input)
        output_tile = tf.reshape(get_prediction_tiles(prediction.numpy()), [size, size])
        canvas_tile[rows, cols].assign(tf.cast(output_tile, tf.uint8))

        # Re-centre the next window on the first set row of the column ``stride``
        # past the current window centre. It is clamped to the last column to
        # stay in bounds.
        look_col = min(h_seed + half + stride, width - 1)
        hits = np.asarray(canvas_tile[:, look_col]) > 0
        # np.argmax of an all-False column is 0, which would be misread as a hit
        # in row 0; keep the current window position when nothing is set.
        if hits.any():
            # NOTE(review): the ``+ 1`` offset is inherited; confirm its intent.
            next_v_center = int(np.argmax(hits)) + 1
            # Ignore jumps larger than one window height.
            if abs(old_v_center - next_v_center) <= size:
                v_seed = next_v_center - half
                old_v_center = next_v_center

    return canvas_tile.numpy()


def export_horizons(segy: PostStackAdapter2D,
                    horizons: dict,
                    path: str) -> None:
    """Write every horizon's picks to ``<path>/<name>.xyz`` with a seisfast ``Writer``.

    Args:
        segy: adapter whose ``segy_raw_data`` is passed to the writer.
        horizons: mapping from horizon name to a list of picks (such as the
            ``[line_number, column, row]`` lists built by ``get_horizons_from_max``).
            Each list is converted to a float32 array.
        path: output directory.
    """
    for name, picks in horizons.items():
        # NOTE(review): a new Writer is created on every call, so calling this once
        # per inline presumably rewrites each file with that inline only. Confirm
        # the Writer semantics, or pass picks accumulated over all inlines.
        writer = Writer(os.path.join(path, f'{name}.xyz'))
        writer.write('inlines', np.asarray(picks, dtype=np.float32), segy.segy_raw_data)


def get_horizons_from_max(line_number: int,
                          amplitudes: np.ndarray,
                          horizons: np.ndarray,
                          params: dict) -> dict:
    """Turn the horizon masks of one line into lists of picked points.

    For each horizon mask and each column that has at least one marked pixel,
    the marked row with the largest amplitude is kept. On ties, the topmost row
    wins. Coordinates are shifted by the crop offsets.

    Args:
        line_number: value stored as the first coordinate of every pick.
        amplitudes: array indexed as ``amplitudes[row, column]``.
        horizons: masks indexed as ``horizons[horizon, row, column]``; nonzero
            pixels are candidates.
        params: configuration. It reads ``dataset_info['crop']``: index 0 is added
            to columns and index 2 to rows. It also reads
            ``dataset_info['horizons_path_list']``, whose file basenames (up to
            the first '.') become the keys.

    Returns:
        Dict mapping horizon name to ``[line_number, column + crop_left,
        row + crop_top]`` picks ordered by column. Horizons with an empty mask
        are omitted, and columns without a marked pixel are skipped.
    """
    dataset_info = params['dataset_info']
    # Plain ints: a stray trailing comma used to make crop_left a 1-tuple.
    crop_left = dataset_info['crop'][0]
    crop_top = dataset_info['crop'][2]
    n_columns = horizons.shape[2]

    # dict.fromkeys keeps first-seen order and collapses duplicate names, exactly
    # like building a dict keyed by these names.
    names = list(dict.fromkeys(
        os.path.basename(p).split('.')[0] for p in dataset_info['horizons_path_list']
    ))

    horizon_dict = {}
    for idx, name in enumerate(names):
        mask = horizons[idx]
        if not np.any(mask):
            continue

        picks = [None] * n_columns
        best_amp = [None] * n_columns
        # Transposing makes np.where return hits sorted by column, then by row.
        columns, rows = np.where(np.transpose(mask))
        for column, row in zip(columns, rows):
            amp = amplitudes[row, column]
            # The maximum is tracked per column; the old code shared a single
            # running maximum across all columns of a horizon.
            if picks[column] is None or amp > best_amp[column]:
                picks[column] = [line_number, column + crop_left, row + crop_top]
                best_amp[column] = amp

        # Drop columns without a pick (discontinuities).
        horizon_dict[name] = [pick for pick in picks if pick is not None]

    return horizon_dict


def predict_line(model: tf.keras.Model,
                 params: dict,
                 seismic_datum: PostStackDatum) -> Tuple[np.ndarray, np.ndarray]:
    """Track every configured horizon on one seismic line.

    For each horizon index, the steps are:
    1. Search the label image for a seed row.
    2. Propagate the horizon with ``run_model``.
    3. Reduce the result to a single path with ``get_graph``.

    Args:
        model: trained model passed to ``run_model``.
        params: configuration. It reads ``dataset_info['tile_shape'][0]``,
            ``dataset_info['stride_shape'][0]`` and
            ``dataset_info['horizons_path_list']``.
        seismic_datum: datum providing ``features`` and ``label``.

    Returns:
        ``(raw_amplitudes, horizons)``. ``raw_amplitudes`` is
        ``seismic_datum.features`` unchanged. ``horizons`` is a ``uint8`` stack
        with one mask per horizon: path pixels are 175, and the mask is all
        zeros when no seed row was found.
    """
    size = params['dataset_info']['tile_shape'][0]
    stride = params['dataset_info']['stride_shape'][0]
    n_horizons = len(params['dataset_info']['horizons_path_list'])

    feat = tf.cast(seismic_datum.features, tf.uint8)
    label_img = tf.cast(seismic_datum.label, tf.uint8)
    inline_datum = (feat, label_img)
    # Immutable; get_input_seed copies it into a fresh tf.Variable on each call.
    empty_canvas = tf.zeros(label_img.shape, tf.uint8)

    # Same dtype as the found masks, so the stacked result is always uint8.
    empty_mask = np.zeros([feat.shape[0], feat.shape[1]], dtype=np.uint8)
    horizon_masks = []
    for idx in range(n_horizons):
        input_seed, canvas = get_input_seed(label_img, empty_canvas, idx)
        canvas = utils.clean_label(canvas)

        # get_input_seed returns 0 both for "not found" and for a match in row 0,
        # so row 0 is treated as missing.
        if input_seed > 0:
            break_tiles_info = (input_seed - size // 2, size, stride)
            output_horizon = run_model(canvas, model, inline_datum, break_tiles_info)
            # 175 is the intensity used to mark path pixels in the mask.
            horizon_masks.append(get_graph(output_horizon) * 175)
        else:
            horizon_masks.append(empty_mask)

    raw_amplitudes = seismic_datum.features
    return raw_amplitudes, np.stack(horizon_masks)


def get_graph(output_horizon: np.ndarray, rng=None) -> np.ndarray:
    """Reduce a predicted horizon mask to one minimum-cost path spanning the last axis.

    Costs are assigned per pixel:
    - Pixels equal to 0 cost 255 times a random integer in ``[1, n_rows]``.
    - Any other value ``v`` costs ``(v - 1)`` times that random integer, so 1 is free.

    ``skimage.graph.shortest_path`` (``reach=1``, ``axis=-1``) then finds the path,
    which therefore prefers predicted pixels.

    Args:
        output_horizon: 2-D prediction array (presumably 0/1 values -- confirm).
        rng: optional ``numpy.random.Generator`` for reproducible weights. When
            omitted, NumPy's global random state is used as before.

    Returns:
        ``uint8`` array of the same shape, with 1 on the path pixels.
    """
    n_rows = output_horizon.shape[0]
    if rng is None:
        # Equivalent to the deprecated np.random.random_integers(1, n_rows, ...).
        weights = np.random.randint(1, n_rows + 1, size=output_horizon.shape)
    else:
        weights = rng.integers(1, n_rows, size=output_horizon.shape, endpoint=True)

    # ``output_horizon - 1`` previously relied on uint8 wrap-around to turn 0
    # into 255; state that explicitly so non-uint8 inputs get the same costs.
    base_cost = np.where(output_horizon == 0, 255, output_horizon - 1)
    cost = base_cost * weights

    path, _ = graph.shortest_path(cost, reach=1, axis=-1, output_indexlist=True)
    path_mask = np.zeros(output_horizon.shape, dtype=np.uint8)
    for index in path:
        # tuple() so each entry is a multi-dimensional index, never fancy indexing.
        path_mask[tuple(index)] = 1
    return path_mask

In [2]:
# Run configuration. NOTE(review): absolute, user-specific paths; consider making
# them configurable (e.g. environment variables) before sharing the notebook.
params = {
    'model_path': '/home/sallesd/ffn_model/best.h5',
    'output_dir': '/home/sallesd/ffn_model/hrz',
}
# Close the metadata file deterministically instead of leaking the handle.
with open('/home/sallesd/ffn_dataset/info.json') as info_file:
    params['dataset_info'] = json.load(info_file)

In [3]:
dataset_info = params['dataset_info']
crop = dataset_info['crop']

# SEG-Y adapter; data_dict presumably restricts processing to inlines 150-250.
segy = PostStackAdapter2D(
    segy_path=dataset_info['segy_info']['segy_path'],
    horizons_path_list=dataset_info['horizons_path_list'],
    data_dict={'inline': [[150, 250]]},
)
scan_result = segy.initial_scan()
first_inline, last_inline = scan_result['range_inlines']

# Pre-processing: crop the section, then rescale intensities.
pre_proc = [
    Crop2D(crop_left=crop[0], crop_right=crop[1], crop_top=crop[2], crop_bottom=crop[3]),
    ScaleIntensity(gray_levels=dataset_info['gray_levels'],
                   percentile=dataset_info['percentile']),
]
composer = Composer(transformations=pre_proc)

# NOTE(review): the meaning of the two positional 80s is not visible here;
# confirm against Pipeline.build_dataset.
dataset = Pipeline(composer=composer).build_dataset(segy, 80, 80)

print(composer)
print(f'Number of lines: {len(dataset)}')





In [14]:
utils.makedir(params['output_dir'])
model = tf.keras.models.load_model(params['model_path'])

# Picks from every inline, keyed by horizon name. They are exported once after
# the loop. export_horizons creates a new Writer per file on each call, so
# exporting inside the loop presumably leaves only the last inline in each .xyz
# file (TODO confirm seisfast Writer semantics).
all_horizons = {}

for datum in dataset:
    print(f"Processing inline: {datum.line_number}")

    # Run prediction through number of horizons
    amplitudes, horizons = predict_line(model, params, datum)

    # Collect this inline's picks, grouped by horizon name.
    horizon_list = get_horizons_from_max(datum.line_number, amplitudes, horizons, params)
    for name, picks in horizon_list.items():
        all_horizons.setdefault(name, []).extend(picks)

# Write .xyz files
export_horizons(segy, all_horizons, params['output_dir'])

Processing inline: 150
Processing inline: 151
Processing inline: 152
Processing inline: 153
Processing inline: 154
