In [None]:
import os
import sys
import numpy as np
import segyio 
from matplotlib import pyplot as plt
from tkinter import *
from tkinter import filedialog
import ipywidgets as widgets
from ipywidgets import interact

import plotly.express as px
import plotly.graph_objects as go
import numpy as np

In [None]:
def load_segy_file():
    '''
    Ask the user to pick a SEG-Y file through a Tk file dialog.

    The value returned by the dialog is stored in the module-level ``filepath``
    variable and echoed to stdout; the function itself returns nothing.
    '''
    global filepath

    accepted_types = [('sgy files','*.sgy'),('segy files','*.segy'),('All files','*.*')]
    filepath = filedialog.askopenfilename(
        initialdir=os.getcwd(),
        title="Please select 2D/3D post-stack seismic data in segy format",
        filetypes=accepted_types,
    )
    print("File_path: {0}".format(filepath))
    
# Minimal Tk window: one button opens the SEG-Y file dialog, the other closes the window.
root = Tk()
root.geometry('300x200')
root.title('AB')

open_button1 = Button(root, text="Open a File", command=load_segy_file)
open_button2 = Button(root, text="Close the Window", command=root.destroy)
for _btn in (open_button1, open_button2):
    _btn.pack()
del _btn

# Enter the Tk event loop.
root.mainloop()

In [None]:
def _read_segy_candidate(filepath_in, iline_byte, xline_byte):
    """Open the file with one (inline, crossline) byte-location pair and collect the cube plus header summaries.

    Any failure (segyio cannot build the geometry, too few traces, ...) propagates to the caller,
    so a candidate is either read completely or not at all.
    """
    with segyio.open(filepath_in, iline=iline_byte, xline=xline_byte, ignore_geometry=False) as f:
        seismic_data = segyio.tools.cube(f)
        # Bulk-read each header field for all traces at once instead of one trace per call.
        trace_numbers = f.attributes(segyio.TraceField.TraceNumber)[:]
        return {
            'seismic_data': seismic_data,
            'n_traces': f.tracecount,
            # Traces per line, taken as the second-to-last TraceNumber + 1 (the original
            # `isinstance(tr, int)` check was always False for numpy values, so this was
            # always the path taken). Fewer than two traces raises IndexError -> candidate rejected.
            'tr': int(trace_numbers[-2]) + 1,
            'twt': f.samples,
            'sample_rate': segyio.tools.dt(f) / 1000,
            # Read while the file is still open.
            'n_offsets': len(f.offsets),
            'tracesequence': np.unique(f.attributes(segyio.TraceField.TRACE_SEQUENCE_FILE)[:]),  # byte 5
            'fieldrecord': np.unique(f.attributes(segyio.TraceField.FieldRecord)[:]),            # byte 9
            'tracefield': np.unique(trace_numbers),                                              # byte 13
            'cdpnumber': np.unique(f.attributes(segyio.TraceField.CDP)[:]),                      # byte 21
            'inline3d': np.unique(f.attributes(segyio.TraceField.INLINE_3D)[:]),                 # byte 189
            'crossline3d': np.unique(f.attributes(segyio.TraceField.CROSSLINE_3D)[:]),           # byte 193
        }


def identify_seismic_data_parameters(filepath_in):
    """
    Identify a SEG-Y file as 2D/3D and post-/pre-stack, and return its amplitude
    traces together with geometry-related parameters.

    The file is opened with several candidate (inline, crossline) trace-header byte
    locations, in order, until one yields a usable geometry.

    Parameter:
    ----------
    filepath_in (str): file path of loaded segy file

    Returns:
    --------
    data_display (numpy.ndarray): seismic amplitude traces to plot; the segyio cube
        transposed so that samples are on axis 0
    data_type (str): 'Post-stack 2D' or 'Post-stack 3D'
    seismic_data_shape (tuple): data_display.shape
    cdp_no (numpy.ndarray): trace / crossline index used as horizontal axis
    sample_rate (float): segyio.tools.dt(f) / 1000
    twt (numpy.ndarray): sample axis values (f.samples)
    inline_number, xline_number (numpy.ndarray): line numbers of the chosen geometry
    diff_inline, diff_xline: difference between the first two line numbers (1 if fewer than two)

    Raises:
    -------
    ValueError: if no byte location can be read, or the data is pre-stack (not supported here).

    Author: Amir Abbas Babasafari (AB)
    """
    with segyio.open(filepath_in, ignore_geometry=True) as f:
        data_format = f.format

    # Supported inline and crossline byte locations, tried in this order
    inline_xline = [[189, 193], [9, 13], [9, 21], [5, 21]]
    # Header summaries used as (inline numbers, crossline numbers) for each candidate above
    geometry_keys = [('inline3d', 'crossline3d'), ('fieldrecord', 'tracefield'),
                     ('fieldrecord', 'cdpnumber'), ('tracesequence', 'cdpnumber')]

    info = None
    data_type = None
    for (iline_byte, xline_byte), (il_key, xl_key) in zip(inline_xline, geometry_keys):
        try:
            candidate = _read_segy_candidate(filepath_in, iline_byte, xline_byte)
        except Exception:
            # segyio could not use these byte locations; try the next pair
            continue

        # Identify data as 2D/3D and Post-stack/Pre-stack
        cube = candidate['seismic_data']
        if cube.ndim == 3:
            if cube.shape[0] != 1 or candidate['n_traces'] > candidate['tr'] > 1:
                candidate_type = 'Post-stack 3D'
            else:
                candidate_type = 'Post-stack 2D'
        elif candidate['n_offsets'] > 1:
            candidate_type = 'Pre-Stack 2D' if cube.shape[0] == 1 else 'Pre-Stack 3D'
        else:
            print('Error, Please check inline and crossline byte locations')
            continue

        # Commit this candidate only now that it was read and classified completely
        info, data_type = candidate, candidate_type
        inline_number, xline_number = candidate[il_key], candidate[xl_key]

        # A 3D volume with a single line in either direction hints at wrong byte
        # locations: keep it as a fallback but try the next pair.
        if data_type != 'Post-stack 3D' or (len(inline_number) > 1 and len(xline_number) > 1):
            break

    if info is None:
        print("Error, data was not loaded successfully, this could happen due to unsupported data format: {0}.".format(data_format))
        print("In addition, please check inline and crossline byte locations, that might not be supported in this script.")
        print("Data format 4-byte IBM float and 4-byte IEEE float are supported.")
        raise ValueError("Could not read {0} with any supported inline/crossline byte location.".format(filepath_in))
    if data_type not in ('Post-stack 2D', 'Post-stack 3D'):
        raise ValueError("{0} data is not supported; only post-stack 2D/3D volumes can be displayed.".format(data_type))

    seismic_data = info['seismic_data']
    n_traces, tr = info['n_traces'], info['tr']
    sample_rate, twt = info['sample_rate'], info['twt']

    # reshape seismic data to the corresponding format based on data type
    inline, cdp, samples = seismic_data.shape

    if data_type == 'Post-stack 2D':
        data_display = seismic_data.reshape(cdp, samples).T
        cdp_no = np.arange(n_traces)
        diff_inline = 1
        diff_xline = 1

        print('Data Type: {0}'.format(data_type))
        print('Seismic Data Shape (Time sample, CDP number) : {0}'.format(data_display.shape))
    else:
        if inline == 1 and tr > 1 and n_traces % tr == 0:
            # Whole survey came back as one "line": split it into n_traces // tr lines of tr traces
            inline_no = n_traces // tr
            data_display = seismic_data.reshape(inline_no, tr, samples).T
            xline_number = np.arange(tr)
            inline_number = np.arange(inline_no)
            cdp_no = xline_number
        else:
            data_display = seismic_data.reshape(inline, cdp, samples).T
            cdp_no = np.arange(cdp)

        # Line spacing; a single line has no spacing, fall back to 1
        diff_inline = np.diff(inline_number)[0] if len(inline_number) > 1 else 1
        diff_xline = np.diff(xline_number)[0] if len(xline_number) > 1 else 1

        print('Data Type: {0}'.format(data_type))
        print('Seismic Data Shape (Time sample, crossline number, inline number) : {0}'.format(data_display.shape))

    return data_display, data_type, data_display.shape, cdp_no, sample_rate, twt, inline_number, xline_number, diff_inline, diff_xline
# Use the file chosen in the dialog above when there is one; otherwise fall back to a
# default path, which can be overridden with the SEGY_PATH environment variable.
DEFAULT_SEGY_PATH = os.environ.get("SEGY_PATH", "/Users/moyinolorunadegbie/Downloads/Nanuq 3D AK.SGY")
filepath = globals().get("filepath") or DEFAULT_SEGY_PATH
data_display, data_type, seismic_data_shape, cdp_no, sample_rate, twt, inline_number, xline_number, diff_inline, diff_xline = identify_seismic_data_parameters(filepath)

In [None]:
"""Copyright China University of Petroleum East China, Yimin Dou, Kewen Li

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License."""

import torch
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2


def normalization(data):
    """
    Min-max scale ``data`` to the range [0, 1].

    Parameters
    ----------
    data : array_like
        Values to rescale (any shape).

    Returns
    -------
    numpy.ndarray
        ``(data - min) / (max - min)``. A constant input (max == min) yields an
        all-zero array instead of the NaNs that 0/0 would produce.
    """
    data = np.asarray(data)
    lo = np.min(data)
    shifted = data - lo
    _range = np.max(data) - lo
    if _range == 0:
        # Flat input has no contrast; keep float inputs' dtype, promote integers to float64
        out_dtype = shifted.dtype if np.issubdtype(shifted.dtype, np.floating) else np.float64
        return np.zeros(shifted.shape, dtype=out_dtype)
    return shifted / _range


def normalization_tensor(data):
    """
    Min-max scale a torch tensor to [0, 1] (tensor counterpart of ``normalization``).

    A constant tensor (max == min) is returned as zeros instead of NaNs from 0/0.
    This matters in cubing_prediction, where a tile that lies entirely in the
    constant 0.5 padding would otherwise feed NaNs into the model.
    """
    lo = torch.min(data)
    _range = torch.max(data) - lo
    if _range == 0:
        return torch.zeros_like(data)
    return (data - lo) / _range


def z_score(data):
    """
    Standardize ``data`` to zero mean and unit (population, ddof=0) standard deviation.

    A constant input (std == 0) returns zeros instead of the NaNs that 0/0 would give.
    """
    centered = np.asarray(data) - np.mean(data)
    std = np.std(data)
    if std == 0:
        return np.zeros_like(centered)
    return centered / std


def cubing_prediction(model, data, device, infer_size):
    """
    Predict a large 3D volume tile by tile and stitch the tiles back together.

    The volume is padded with the value 0.5 so that it is covered by whole tiles
    of size ``infer_size``. Neighbouring tiles overlap by ``ol`` (= 1) voxel along
    each axis, and outputs in overlapping voxels are summed (this accumulation is
    kept as-is from the original implementation).

    Parameters
    ----------
    model : torch module
        Called on a (1, 1, n1, n2, n3) tensor; ``Y[0, 0]`` of its output is added
        at the tile's location. It is therefore assumed to return the tile's
        spatial size — TODO confirm for the loaded network.
    data : numpy.ndarray
        3D volume of shape (m1, m2, m3), converted with ``torch.from_numpy``.
    device : torch.device
        Inference device; tiles are cast to half precision when it is not the CPU.
    infer_size : sequence of three ints
        Tile size (n1, n2, n3); every entry must be larger than the overlap.

    Returns
    -------
    numpy.ndarray
        float32 prediction of shape (m1, m2, m3).

    Raises
    ------
    ValueError
        If a tile dimension is not larger than the overlap (the tile count would
        divide by zero or become negative).
    """
    ol = 1
    # Tile size comes from the infer_size argument
    n1, n2, n3 = infer_size
    if min(n1, n2, n3) <= ol:
        raise ValueError("every infer_size entry must be > {0}, got {1}".format(ol, tuple(infer_size)))

    with torch.no_grad():
        model.eval()
        input_tensor = torch.from_numpy(data)
        m1, m2, m3 = data.shape

        # Number of overlapping tiles needed to cover each axis
        c1 = int(np.ceil((m1 + ol) / (n1 - ol)))
        c2 = int(np.ceil((m2 + ol) / (n2 - ol)))
        c3 = int(np.ceil((m3 + ol) / (n3 - ol)))
        print([c1, c2, c3], "<<<<< dim ")

        # Padded extents that those tiles cover exactly
        p1 = (n1 - ol) * c1 + ol
        p2 = (n2 - ol) * c2 + ol
        p3 = (n3 - ol) * c3 + ol
        gp = torch.zeros((p1, p2, p3)).float() + 0.5
        gy = np.zeros((p1, p2, p3), dtype=np.single)
        gp[:m1, :m2, :m3] = input_tensor
        if device.type != 'cpu':
            gp = gp.half()

        for k1 in range(c1):
            b1 = k1 * (n1 - ol)
            e1 = b1 + n1
            for k2 in range(c2):
                b2 = k2 * (n2 - ol)
                e2 = b2 + n2
                for k3 in range(c3):
                    b3 = k3 * (n3 - ol)
                    e3 = b3 + n3
                    gs = gp[b1:e1, b2:e2, b3:e3]
                    gs = normalization_tensor(gs[None, None, :, :, :]).to(device)
                    Y = model(gs).cpu().numpy()
                    gy[b1:e1, b2:e2, b3:e3] = gy[b1:e1, b2:e2, b3:e3] + Y[0, 0, :, :, :]
    return gy[:m1, :m2, :m3]


def prediction(model, data, device):
    """
    Run ``model`` once on the whole (min-max normalized) 3D volume.

    Every axis is padded with the value 0.5 up to the next multiple of 16
    (presumably the network's down-sampling factor — TODO confirm). The input is
    float32 on CPU and half precision otherwise. The first batch/channel of the
    output is cropped back to ``data.shape`` and returned as a numpy array.
    """
    model.eval()
    scaled = normalization(data)
    m1, m2, m3 = scaled.shape
    padded_shape = tuple(int(np.ceil(extent / 16) * 16) for extent in (m1, m2, m3))

    canvas = np.full(padded_shape, 0.5, dtype=np.float32)
    canvas[:m1, :m2, :m3] = scaled

    batch = torch.from_numpy(canvas)[None, None].to(device)
    batch = batch.float() if device.type == 'cpu' else batch.half()

    with torch.no_grad():
        output = model(batch)
        return output.cpu().numpy()[0, 0, :m1, :m2, :m3]


def write_data(results, geo_cube, out_path, input_file, axis=0):
    """
    Save one PNG per slice of ``geo_cube`` / ``results`` along ``axis``.

    Each image is laid out from left to right:
    - the seismic slice through the 'seismic' colormap;
    - a 50-pixel white separator;
    - the slice through 'bone', recoloured with 'jet' wherever the clipped
      prediction exceeds 0.5.

    Files go to ``out_path/<input file name>/{tline,xline,iline}`` for axis 0, 1
    and 2 respectively (the folders create_out_dir makes). They are named
    ``{axis}_{index:05d}_.png``.
    """
    file_name = os.path.split(input_file)[-1]
    geo_cube = normalization(geo_cube)
    assert axis == 0 or axis == 1 or axis == 2
    folder_for_axis = {0: 'tline', 1: 'xline', 2: 'iline'}

    for i in range(geo_cube.shape[axis]):
        # Select index i along `axis` (a view, like explicit [i, :, :]-style indexing)
        picker = [slice(None)] * 3
        picker[axis] = i
        picker = tuple(picker)
        res_slice = results[picker]
        geo_slice = geo_cube[picker]

        shaded = plt.get_cmap('bone')(geo_slice)[:, :, :-1]
        seismic_rgb = plt.get_cmap('seismic')(geo_slice)[:, :, :-1]
        probs = np.clip((res_slice[:, :, None]), a_min=0, a_max=1)
        fault_rgb = plt.get_cmap('jet')(probs[:, :, 0])[:, :, :-1]
        overlay = np.where(probs > 0.5, fault_rgb, shaded)
        separator = np.ones((seismic_rgb.shape[0], 50, 3))
        panel = (np.concatenate((seismic_rgb, separator, overlay), axis=1) * 255).astype(np.uint8)

        subdir = folder_for_axis.get(axis)
        if subdir is not None:
            cv2.imwrite(os.path.join(out_path, file_name, subdir, f'{axis}_%05d_.png' % i),
                        cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))


def create_out_dir(output_dir, input_file):
    """
    Create ``output_dir/<input file name>/{iline,xline,tline}``, the folders that
    write_data saves into.

    Missing parent directories are created as well; existing directories are
    left untouched, so calling this repeatedly is safe.

    Parameters
    ----------
    output_dir : str
        Root output directory (may be nested and not yet exist).
    input_file : str
        Path of the input file; only its last path component is used.
    """
    file_name = os.path.split(input_file)[-1]
    for sub in ('iline', 'xline', 'tline'):
        os.makedirs(os.path.join(output_dir, file_name, sub), exist_ok=True)

    

    

In [None]:
# NOTE: this cell is only a reminder; the string below is never executed.
# Install the dependency from a notebook cell with `%pip install thop` before the next cell imports it.
"""
pip install thop
"""

In [None]:




#from utils import prediction,normalization
import time
#import matplotlib.pyplot as plt
from thop import profile
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# Pretrained TorchScript FaultNet weights; set FAULTNET_MODEL to override the default path.
MODEL_PATH = os.environ.get('FAULTNET_MODEL', '/Users/moyinolorunadegbie/Downloads/FaultNet_Gamma0.7.pt')
# map_location places the stored tensors on `device`, so weights saved on a GPU
# also load on a CPU-only machine.
model = torch.jit.load(MODEL_PATH, map_location=device).to(device)
if device.type != 'cpu':
    # Half precision on GPU; prediction() casts its input to half to match.
    model = model.half()
# count_parameters counts parameter elements, not bytes: report millions ("M"), not "MB".
print('parameters_count:', count_parameters(model) / 10**6, 'M')

In [None]:
# Crop the sub-volume used for inference. data_display is assumed to be indexed
# (time, crossline, inline), as reshaped in identify_seismic_data_parameters
# (TODO confirm for crossline-sorted files). The crop goes into a new name and
# data_display stays intact, so re-running this cell does not crop twice.
TIME_WINDOW = slice(250, 750)
XLINE_WINDOW = slice(0, 200)
ILINE_WINDOW = slice(0, 384)
data = data_display[TIME_WINDOW, XLINE_WINDOW, ILINE_WINDOW]
infer_size = data.shape
print('infer size:', infer_size)

start = time.perf_counter()
result = prediction(model, data, device)
end = time.perf_counter()
# Report the device actually used instead of a hard-coded GPU model name.
print('{0} infer time:'.format(device), end - start, 's')


In [None]:
def get_result_img(geo_slice, result_slice):
    """
    Build the two RGB panels for one 2D slice.

    Returns
    -------
    (seismic_rgb, overlay)
        seismic_rgb : ``geo_slice`` through the 'seismic' colormap, alpha dropped.
        overlay     : ``geo_slice`` through 'bone', replaced by the 'jet' colour of
                      the clipped prediction wherever ``result_slice`` > 0.5.

    ``geo_slice`` is presumably already normalized to [0, 1], since it is passed
    straight to matplotlib colormaps — TODO confirm at call sites.
    """
    probs = np.clip((result_slice[:, :, None]), a_min=0, a_max=1)
    fault_rgb = plt.get_cmap('jet')(probs[:, :, 0])[:, :, :-1]
    shaded = plt.get_cmap('bone')(geo_slice)[:, :, :-1]
    overlay = np.where(probs > 0.5, fault_rgb, shaded)
    seismic_rgb = plt.get_cmap('seismic')(geo_slice)[:, :, :-1]
    return seismic_rgb, overlay

In [None]:
# Normalize to [0, 1] for the colormaps. Re-running is harmless: min-max scaling
# data whose min is 0 and max is 1 leaves it unchanged.
data = normalization(data)

# Slice positions inside the cropped cube (time, crossline, inline). Named
# constants replace `_105` / `_16` / `_20`, which collide with IPython's
# output-history variables (`_<n>` holds Out[n]).
TIME_INDEX = 250
INLINE_INDEX = 344
XLINE_INDEX = 64


def _side_by_side(geo_slice, result_slice):
    """Seismic panel and fault overlay for one slice, concatenated left-to-right."""
    return np.concatenate(get_result_img(geo_slice, result_slice), axis=1)


tline = _side_by_side(data[TIME_INDEX, :, :], result[TIME_INDEX, :, :])
iline = _side_by_side(data[:, :, INLINE_INDEX], result[:, :, INLINE_INDEX])
xline = _side_by_side(data[:, XLINE_INDEX, :], result[:, XLINE_INDEX, :])
#plt.imshow(tline)

In [None]:
#plt.imshow(iline)

In [None]:
#plt.imshow(xline)

In [None]:
# type(  tline )

In [None]:
# What the SEG-Y file holds: 'seismic_amplitude' (default) or 'property'.
# cmp is the matching `segy` argument for plot(): 'seismic' or 'property'.
sgy_file = 'seismic_amplitude'
cmp = 'seismic' if sgy_file == 'seismic_amplitude' else 'property'

In [None]:
def plot(seismic_data, direction = None, segy = 'seismic'):
    '''
    Display one seismic section with plotly.

    Parameters
    ----------
    seismic_data : numpy.ndarray
        Either a 2D amplitude section, or an RGB image of shape (H, W, 3) with
        channels in [0, 1], such as the side-by-side panels built from get_result_img.
    direction : str
        'inline', 'xline', 'time-slice' or '2D Line'; selects axis labels and title.
    segy : str
        'seismic' ('Picnic' colour scale) or 'property' ('jet'); only used for 2D input.

    Raises
    ------
    ValueError
        If ``direction`` or ``segy`` is not one of the supported values.

    The title uses the module-level ``filepath``.
    '''
    color_scales = {'seismic': 'Picnic', 'property': 'jet'}
    if segy not in color_scales:
        raise ValueError("segy must be 'seismic' or 'property', got {0!r}".format(segy))

    # direction -> (x-axis label, y-axis label, title)
    views = {
        'inline': ("Crossline No.", "Time (ms)", 'Interactive In-line Visualization'),
        'xline': ("Inline No.", "Time (ms)", 'Interactive Cross-line Visualization'),
        'time-slice': ("Inline No.", "Crossline No.", 'Interactive Time-Slice Visualization'),
        '2D Line': ("CDP No.", "Time (ms)", '2D Line Visualization'),
    }
    if direction not in views:
        raise ValueError("direction must be one of {0}, got {1!r}".format(sorted(views), direction))
    xlabel, ylabel, label = views[direction]

    seismic_data = np.asarray(seismic_data)
    if seismic_data.ndim == 3 and seismic_data.shape[-1] == 3:
        # Pre-coloured RGB image: plotly rescales every channel from [zmin, zmax],
        # so a symmetric [-max, max] range would wash the colours out. Use [0, 1].
        fig = px.imshow(seismic_data, zmin=0, zmax=1)
    else:
        # Amplitudes: symmetric range so zero amplitude sits at the centre of the colour scale.
        amp_max = np.max(seismic_data)
        fig = px.imshow(seismic_data, color_continuous_scale=color_scales[segy],
                        zmin=-amp_max, zmax=amp_max)

    file_stem = os.path.splitext(os.path.basename(filepath))[0]
    fig.update_layout(
        # plotly titles are HTML: '<br>' (not '\n') starts a new line
        title=dict(text="{0} <br> Seismic file name: {1}".format(label, file_stem)),
        xaxis=dict(title=dict(text=xlabel), showgrid=True),
        yaxis=dict(title=dict(text=ylabel), showgrid=True),
        coloraxis_showscale=True,
        autosize=False,
        width=1150,
        height=950,
        margin=dict(l=65, r=50, b=65, t=90),
    )
    fig.show()

In [None]:
# tline is a time slice (data[t, :, :]), so it gets the time-slice labels.
plot(tline, direction='time-slice', segy=cmp)

In [None]:
# iline is an inline section (data[:, :, i]): time vs crossline.
plot(seismic_data=iline, direction='inline', segy=cmp)

In [None]:
# xline is a crossline section (data[:, x, :]): time vs inline.
plot(xline, direction='xline', segy=cmp)

In [None]:
#!/usr/bin/env python3

def plot(seismic_data_, result, direction = None, segy = 'seismic'):
    '''
    Browse a 3D cube slice by slice next to its fault-prediction overlay.

    Every slice along the chosen axis is rendered with get_result_img (seismic
    panel | fault overlay, side by side). All frames are shown in one plotly
    figure with an animation slider.

    Parameters
    ----------
    seismic_data_ : numpy.ndarray
        3D cube, assumed indexed (time, crossline, inline). This matches
        data_display's `.reshape(inline, cdp, samples).T` and the
        `data[:, :, i]` inline slices used earlier in this notebook (TODO confirm
        for crossline-sorted files). Values should already be in [0, 1] because
        they go through matplotlib colormaps.
    result : numpy.ndarray
        Fault-prediction cube with the same shape as ``seismic_data_``.
    direction : str
        'inline'     - one frame per inline (axis 2): time vs crossline
        'xline'      - one frame per crossline (axis 1): time vs inline
        'time-slice' - one frame per time sample (axis 0): crossline rows, inline columns
    segy : str
        'seismic' or 'property'. Frames are RGB images, so no colour scale is
        applied; the value is only validated.

    Raises
    ------
    ValueError
        For an unsupported ``direction`` (including '2D Line' and None) or ``segy``.
    '''
    if segy not in ('seismic', 'property'):
        raise ValueError("segy must be 'seismic' or 'property', got {0!r}".format(segy))

    # direction -> (axis stepped by the slider, x label, y label, slider label, console message)
    views = {
        'inline': (2, "Crossline Index", "TWT", "inline number", "Inlines"),
        'xline': (1, "Inline Index", "TWT", "xline number", "Crosslines"),
        'time-slice': (0, "Inline Index", "Crossline Index", "TWT number", "TWT's"),
    }
    if direction not in views:
        raise ValueError("direction must be one of {0}, got {1!r}".format(sorted(views), direction))
    axis, xlabel, ylabel, frame_label, message = views[direction]
    print(message)

    # One RGB frame per slice: seismic panel and fault overlay concatenated left-to-right
    frames = np.array([
        np.concatenate(get_result_img(np.take(seismic_data_, i, axis=axis),
                                      np.take(result, i, axis=axis)), axis=1)
        for i in range(seismic_data_.shape[axis])
    ])

    # Frames are RGB in [0, 1]; pin that range so plotly does not rescale the colours
    fig = px.imshow(frames, animation_frame=0, zmin=0, zmax=1,
                    labels={"animation_frame": frame_label})
    fig.update_xaxes(title=xlabel, showticklabels=True)
    fig.update_yaxes(title=ylabel, showticklabels=True)
    fig.update_layout(
        title="3D Seismic Cube – Scroll with Slider",
        coloraxis_showscale=True,
        autosize=False,
        width=1150,
        height=800,
        margin=dict(l=65, r=50, b=65, t=90),
    )
    fig.show()
		

In [None]:
plot(seismic_data_=data, result=result, direction='inline', segy=cmp)

In [None]:
plot(seismic_data_=data, result=result, direction='xline', segy=cmp)

In [None]:

plot(seismic_data_=data, result=result, direction='time-slice', segy=cmp)