In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

### Small function to get the name of a Python object ("variable") as a string

In [2]:
def varname(var):
    '''Little function that returns the name of a variable as a string'''
    import inspect

    frame = inspect.currentframe()
    var_id = id(var)

    for name in frame.f_back.f_locals.keys():
        try:
            if id(eval(name)) == var_id:
                return(name)
        except:
            pass

### Small function to plot a legend based on a dictionary

In [3]:
def create_legend(dictionary, name_of_dictionary=None, save_plot=False):
    '''Plots a legend from an input dictionary.

    Mandatory argument:
    - dictionary

    Optional arguments:
    - name_of_dictionary  name plotted (vertically) to the left of the legend. Names
                          longer than 14 characters are shortened with "...".
    - save_plot           True or False. Saves the legend as a PNG file. The file is
                          named after the (shortened) name_of_dictionary, with
                          characters that are awkward in file names replaced by "_",
                          or "legend.png" if no name was passed.

    The function needs a dictionary as input in the following form:
       my_dict = {code:(name, hexcolor)}
       where:
       - "code" is the key; it is printed between brackets behind the name
       - "name" is the text plotted for the entry (e.g. a lithology, Formation or Group)
       - "hexcolor" the color of the entry (e.g. #FF3424)

    An empty dictionary cannot be plotted: a message is printed and nothing is
    drawn or saved.
    '''
    n = len(dictionary)
    if n == 0:
        # the figure width is proportional to the number of entries, so without
        # entries there is nothing sensible to draw
        print('The dictionary passed is empty...\nNo legend can be made.')
        return

    plt.figure(figsize=(2/3*n, 2.5))
    plt.scatter(range(n), [1]*n, c=[entry[1] for entry in dictionary.values()],
                s=300, marker='s', edgecolor='k')
    for i, (code, entry) in enumerate(dictionary.items()):
        plt.annotate(f'{entry[0]} ({code})', xy=(i, 1), xytext=(i-0.1, 1.6), rotation=90)

    label = None
    if name_of_dictionary is not None:
        label = str(name_of_dictionary)
        if 14 < len(label) < 17:
            # hide the last few characters and print ... instead
            label = label[:13] + '...'
        elif len(label) >= 17:
            # replace the middle bit with ...
            label = label[:13] + '...' + label[-3:]
        plt.annotate(label, xy=(-1.2, 0.9), rotation=90, fontweight='bold', fontsize=14)

    plt.xlim([-1.5, n-0.5])
    plt.ylim([-0.5, 4])
    plt.axis('off')

    if save_plot == True:
        if label is None:
            plt.savefig('legend.png')
        else:
            # replace every awkward character (cumulatively) and save exactly once
            illegal = ['\\', '/', ':', '*', '"', '<', '>', '|', '#', '$', '£', '{', '}']
            title = label
            for char in illegal:
                title = title.replace(char, '_')
            try:
                plt.savefig(title + '.png')
            except Exception:
                # the name could not be used as a file name: use the default name
                plt.savefig('legend.png')

    plt.show()

### Function to plot a CPI

In [1]:
def create_CPI(df, well, well_col, depth_col,
               log_dict, log_shading_dict=None,
               top_depth=None, bottom_depth=None,
               plot_litho=False,
               plot_full_litho_legend=False,
               save_plot=False,
               plot_strat_zones=None,
               figsize=(10,10)):

    ''' Creates a CPI plot of data, either plotting the full range, or a user-
    defined interval.
    Mandatory arguments:
    - df:             dataframe containing the data
    - well:           wellname, used for both filtering the relevant data from
                      df as well as title on plot
    - well_col:       name of the column containing the well name
    - depth_col:      name of the depth column in the df. Will be used as
                      vertical scale (so can be MD or TVDSS)
    - log_dict:       dictionary containing the logs to be plotted, plus the
                      "cosmetics". The format of the log_dict is:
                          log:(min-scale, max-scale, colour, linestyle, scale, track)
                          (note the value is a tuple with 6 elements)
                      Example:
                          log_dict = {'GR':(0,150,'#004D00','--','linear',0)}
                          - log:       'GR' (string: name of column in df)
                          - min-scale:  0 (integer or float; None = curve minimum)
                          - max-scale:  150 (integer or float; None = curve maximum)
                          - color:     '#004D00' (hexcolor passed as string)
                          - linestyle: '--' (matplotlib style, passed as string)
                          - scale:     'linear' or 'log' (matplotlib style; the
                                       shorthand 'lin' is accepted for 'linear')
                          - track:      0 (integer!). First track = 0!

    Optional arguments:
    - top_depth:      defines the top of the interval to be plotted.
    - bottom_depth:   defines the base of the interval to be plotted.
                      top_depth and bottom depth can be set to either:
                        - "None": the shallowest resp deepest depth in df will
                           be used for top_depth resp bottom_depth
                        - set to a specific depth
                        - set to a zone-top by using a (python) dictionary:
                          {'My Formation X': 'NAME_OF_FORMATION_COLUMN'}
                          in which:
                          - 'My Formation X' is name (value) of the top
                          - 'NAME_OF_FORMATION_COLUMN' is the name of the column
                            that contains the value 'My Formation X'
                          If 'My Formation X' is not present in
                          'NAME_OF_FORMATION_COLUMN', the code will loop through
                          all dictionaries in the "plot_litho" argument to find in
                          which it occurs and will then find the first older*
                          value that occurs in 'NAME_OF_FORMATION_COLUMN'. If still
                          unsuccessful, the minimum (top_depth) or maximum depth
                          (bottom_depth) will be used instead.
                          * it is assumed the dictionary is (stratigraphically)
                          sorted!
    - plot_litho:     defines whether "lithology" type of columns should be plotted.
                      Pass a dictionary in the following format:
                      {name_of_log1: accompanying_dictionary1,
                       name_of_log2: accompanying_dictionary2}, or set to "False"
                      if you do not want to plot. Each accompanying dictionary has
                      the form {numeric code: (name, hexcolor)}.
    - plot_full_litho_legend: "True" will plot the full legend (even though not all
                      values might occur in the well/interval) whilst "False" will
                      only plot the items actually present. "None" will not print
                      a legend at all.
    - log_shading_dict: defines whether there is any shading to logs. Pass the name
                      of the dictionary to apply shading, pass "False"/"None" or omit
                      for no shading. The code will automatically find the right track
                      to apply shading to.
                      The format of the log_shading_dict is:
                          log_shading_dict = {left:(right, color)}
                          (note the value is a tuple with 2 elements)
                      The left and right can be a log or a constant, but at least one
                      or both should be a log. The color is in hex.
                      Example:
                          log_shading_dict = {'NPHI':('RHOB','#825000'),
                                               0    :('GR'  ,'#48FF92')}
    - save_plot:      set to "True", the plot will be saved to a PNG file in addition
                      to displaying it on the screen. With "False" nothing is saved.
    - plot_strat_zones: "None" if no zonation needs to be plotted across all tracks. If
                      this is required, a dictionary should be passed as argument.
                      This dictionary should look like:
                          zonation_cosmetics = {'name_of_column_containing_zonation':
                                                (linecolor, linewidth, linestyle)}
                      Linecolor and -style are matplotlib-style; linewidth is any
                      number >=0.
                      The dictionary can contain multiple zonations with each their
                      cosmetics.
    - figsize:        Default size of figure is (10,10) but another size may be passed
                      through this optional argument

    Returns:
    - an empty tuple if df holds no rows for the well (a message is printed and
      nothing is plotted), otherwise None.
    '''

    def is_number(value):
        '''True for plain Python numbers (used to tell constants from log names)'''
        return isinstance(value, (int, float))


    def normalise_scale(scale):
        '''Maps the shorthand "lin" onto matplotlib's "linear"; anything else
        is passed through untouched'''
        return 'linear' if scale == 'lin' else scale


    def customize_dictionary(data, column_in_df, dictionary):
        '''Small function to create a "local" version of the passed dictionary
        (leaving out items not present in the current well). The dictionary
        passed is not modified.'''
        present = data[column_in_df].unique()
        return {k: v for k, v in dictionary.items() if k in present}


    def get_tracks_shading(log_shading_dict):
        '''gets all track numbers where shading should be added'''
        tracks = []
        for left, spec in log_shading_dict.items():
            try:
                # check which track: if two curves, they need to be in the same track
                # if curve and constant, only the correct track number needs to be found
                # it could be the left or right hand-side boundary that is a constant (but not both)
                right = spec[0]
                if is_number(left) and (log_dict.get(right) is not None):
                    tracks.append(log_dict[right][5])
                elif is_number(right) and (log_dict.get(left) is not None):
                    tracks.append(log_dict[left][5])
                elif (log_dict.get(left) is not None) and (log_dict.get(right) is not None):
                    # two curves in different tracks: nothing is registered
                    if log_dict[left][5] == log_dict[right][5]:
                        tracks.append(log_dict[left][5])
                else:
                    tracks.append(None)
            except Exception:
                # malformed entry, or curve not plotted/in log_dict
                tracks.append(None)
        return tracks


    def transform_scale(old_min, old_max, new_min, new_max, log):
        '''Transforms a log from its own scale into another (needed for shading:
        invisible logs at the same scale as the main log in the track)'''
        return (new_max-new_min)*(log-old_min)/(old_max-old_min)+new_min


    def zone_tops(strat_col):
        '''Returns a list of (zone, shallowest depth) for every zone of strat_col
        that is present in the (sliced) data'''
        zones = data.loc[data[strat_col].notna(), strat_col].unique()
        return [(fm, min(data.loc[(data[strat_col]==fm), depth_col])) for fm in zones]


    def get_zone_depth(zone_depth, pick='top'):
        '''Small function to get the top (or base) of the interval to plot. This can be:
        - top of the dataset (zone_depth is None)
        - top of a stratigraphic element (zone_depth is a dictionary {zone: column})
        - a user-entered depth (zone_depth is an int or float)
        Optional argument: pick. Valid values: "top" or "bottom". Default="top".
        "top" is typically used for the top of the interval to be plotted,
        whilst "bottom" will give the deepest depth encountered for the zone entered.

        Returns a depth, or a tuple (label, depth) when the zone asked for had to be
        replaced by another zone or by the minimum/maximum depth.'''

        # default pick is "top". Use in case of erroneous input
        if pick != 'top' and pick != 'bottom':
            pick = 'top'
        extreme = min if pick == 'top' else max
        depths = data[depth_col]

        try:
            if zone_depth is None:
                # use max range (i.e. use top=min, bottom=max)
                return extreme(depths)

            elif isinstance(zone_depth, (float, int)):
                # numeric value: snap to the nearest depth inside the interval
                if pick == 'top':
                    return min(depths[depths>=zone_depth])
                return max(depths[depths<=zone_depth])

            elif isinstance(zone_depth, dict):
                zone_name = list(zone_depth.keys())[0]
                zone_col = list(zone_depth.values())[0]

                if zone_name in data[zone_col].values:
                    # zone present in dataset
                    return extreme(depths[data[zone_col]==zone_name])

                # zone does not occur in the dataset. Find out whether it occurs in any of the
                # dictionaries passed as argument. If so, use that particular dictionary to find the
                # first (deeper) zone in the dictionary that is also present in the dataset
                # (if nothing, use min/max depth of dataset)
                for strat_dict in litho_dicts.values():
                    if zone_name in [x[0] for x in strat_dict.values()]:
                        for code, entry in strat_dict.items():
                            if entry[0] == zone_name:
                                # get the first top further down
                                # NB: STRATIGRAPHIC DICTIONARY IS ASSUMED TO BE ORDERED BY AGE/DEPTH!
                                # NOTE(review): the keys are walked as consecutive integers; with
                                # other keys the lookup fails and the min/max fall-back below
                                # (in the except) is used - confirm this is acceptable
                                for j in range(code, len(strat_dict)+1):
                                    if strat_dict[j][0] in data[zone_col].values:
                                        print(f'...Top {zone_name} does not occur in this well. Using the first available top based on dictionary values: {strat_dict[j][0]}\n(it is assumed the dictionary is stratigraphically sorted from young to old!)')
                                        # Shift by a very small number to end up just outside the
                                        # replacement zone, otherwise the legend will contain an
                                        # entry for this zone also.
                                        if pick == 'top':
                                            depth = min(depths[data[zone_col]==strat_dict[j][0]])-0.000001
                                        else:
                                            depth = max(depths[data[zone_col]==strat_dict[j][0]])+0.000001
                                        return (strat_dict[j][0], depth)

                # if code comes here, no match in dictionary was found: use fall-back min/max
                if pick == 'top':
                    print(f'...Top {zone_name} neither occurs in this well, nor has an entry in the dictionary. Using minimum depth instead...')
                    return ('min depth', min(depths))
                print(f'...Top {zone_name} neither occurs in this well, nor has an entry in the dictionary. Using maximum depth instead...')
                return ('max depth', max(depths))

            else:
                # last fall-back: use min/max
                return extreme(depths)

        except Exception:
            # fall-back: use min/max
            return extreme(depths)


    #########################################
    ### THE MAIN CODE STARTS HERE ###
    #########################################
    # preserve the original input top_depth and bottom_depth to use as annotation
    if isinstance(top_depth, dict):
        topint = f'top {list(top_depth.keys())[0]}'
    elif isinstance(top_depth, (float, int)):
        topint = top_depth
    else:
        topint = 'min depth'

    if isinstance(bottom_depth, dict):
        botint = f'base {list(bottom_depth.keys())[0]}'
    elif isinstance(bottom_depth, (float, int)):
        botint = bottom_depth
    else:
        botint = 'max depth'

    # get the maximum number of tracks defined in the log dictionary
    ncols_logs = max([x[5] for x in log_dict.values()]) + 1

    # need extra column(s) for lithology?
    litho_dicts = plot_litho if isinstance(plot_litho, dict) else {}
    ncols_litho = len(litho_dicts)

    # slice df to just the well in question
    data = df[df[well_col]==well]
    if len(data)==0:
        print(f'There is no data in the dataset passed for well {well}...\nNo plot can be made.')
        return()

    # a tuple is returned in case the user entered a top that does not exist in the
    # dataframe: the first element then holds the label of what is used instead
    top_depth = get_zone_depth(top_depth, pick='top')
    if isinstance(top_depth, tuple):
        topint, top_depth = top_depth[0], top_depth[1]
    bottom_depth = get_zone_depth(bottom_depth, pick='bottom')
    if isinstance(bottom_depth, tuple):
        botint, bottom_depth = bottom_depth[0], bottom_depth[1]

    # slice to the interval to be plotted
    data = data.loc[((data[well_col]==well)&(data[depth_col]>=top_depth)&(data[depth_col]<=bottom_depth))]

    # default figsize
    if figsize is None:
        figsize = (10,10)

    # check whether any shading should be applied. If so, create a list with tracks
    # (automatically extracted using log_dict). An empty list means: no shading.
    alpha = 0.3
    tracks = []
    if isinstance(log_shading_dict, dict) and len(log_shading_dict)>0:
        tracks = get_tracks_shading(log_shading_dict)

    # set up the figure
    ncols = ncols_logs + ncols_litho

    # need an extra column for zonation lines/labels?
    if isinstance(plot_strat_zones, dict):
        ncols += len(plot_strat_zones)

    # squeeze=False guarantees an array of axes, also for a single-column figure
    fig, ax_grid = plt.subplots(nrows=1, ncols=ncols, figsize=figsize, sharey=True, squeeze=False)
    ax = ax_grid[0]
    fig.suptitle(f'\nWell {well}\n({str(topint)} - {str(botint)})', fontweight='bold', fontsize=15)

    for t in range(ncols_logs):
        # get the logs for this round/track:
        logs = [name for name, cosmetics in log_dict.items() if cosmetics[5]==t]

        # set up subplot
        plt.subplot(1,ncols,t+1)

        # hide depth-ticks for all but first subplot
        if t==0:
            plt.ylabel(depth_col,fontweight='bold')
        elif t>0:
            plt.yticks([])

        if not logs:
            # no log was assigned to this track number: leave the track empty
            plt.xticks([])
            plt.ylim([bottom_depth, top_depth])
            continue

        ########## FIRST LOG IN THE TRACK: DEFINES THE SCALE OF THE TRACK ##########
        first_log = logs[0]
        cosmetics = log_dict[first_log]

        try:
            color = cosmetics[2]
        except Exception:
            color = 'black' # default color

        try:
            min_scale = cosmetics[0]
            if min_scale is None:
                min_scale = min(data[first_log]) # use curve minimum in case None
        except Exception:
            min_scale = min(data[first_log]) # use curve minimum in case error

        try:
            max_scale = cosmetics[1]
            if max_scale is None:
                max_scale = max(data[first_log]) # use curve maximum in case None
        except Exception:
            max_scale = max(data[first_log]) # use curve maximum in case error

        try:
            linestyle = cosmetics[3]
        except Exception:
            linestyle = '-' # default

        try:
            xscale = normalise_scale(cosmetics[4])
        except Exception:
            xscale = 'linear' # default

        try:
            plt.plot(data[first_log],data[depth_col], color=color, linestyle=linestyle, linewidth=0.5)
        except Exception:
            pass

        # check whether shading should be applied. Loop through all shadings every time (but only
        # those with a boundary in this track will be plotted). A shading that cannot be drawn is
        # skipped without affecting the other shadings.
        if len(tracks)>0:
            for left_key, spec in log_shading_dict.items():
                try:
                    right_key = spec[0]
                    shcolor = spec[1]
                    if (left_key not in logs) and (right_key not in logs):
                        continue # neither boundary is a log of this track

                    if left_key == first_log:
                        # the left hand boundary is the main log of the track
                        left_hand_log = data[left_key]
                        if is_number(right_key):
                            # right hand-side boundary is number: make log first (no re-scaling)
                            right_hand_log = np.linspace(right_key, right_key, len(left_hand_log))
                        else:
                            right_hand_log = transform_scale(log_dict[right_key][0], log_dict[right_key][1],
                                                             min_scale, max_scale, data[right_key])

                    elif right_key == first_log:
                        # the right hand boundary is the main log of the track
                        right_hand_log = data[right_key]
                        if is_number(left_key):
                            # left hand-side boundary is number: make log first (no re-scaling)
                            left_hand_log = np.linspace(left_key, left_key, len(right_hand_log))
                        else:
                            left_hand_log = transform_scale(log_dict[left_key][0], log_dict[left_key][1],
                                                            min_scale, max_scale, data[left_key])

                    else:
                        # shading for a log that is not the main log of the track: logs are
                        # re-scaled to the main log, constants are put on the edge of the track
                        if is_number(left_key):
                            left_hand_log = np.linspace(min_scale, min_scale, len(data))
                        else:
                            left_hand_log = transform_scale(log_dict[left_key][0], log_dict[left_key][1],
                                                            min_scale, max_scale, data[left_key])
                        if is_number(right_key):
                            right_hand_log = np.linspace(max_scale, max_scale, len(data))
                        else:
                            right_hand_log = transform_scale(log_dict[right_key][0], log_dict[right_key][1],
                                                             min_scale, max_scale, data[right_key])

                    # add invisible versions of the curves in the track
                    plt.plot(left_hand_log, data[depth_col], linewidth=0, marker=None)
                    plt.plot(right_hand_log, data[depth_col], linewidth=0, marker=None)
                    # on a reversed scale (min > max) "left of" means "larger than"
                    if min_scale < max_scale:
                        plt.fill_betweenx(data[depth_col], right_hand_log, left_hand_log, where=left_hand_log<right_hand_log, color=shcolor, alpha=alpha)
                    else:
                        plt.fill_betweenx(data[depth_col], right_hand_log, left_hand_log, where=left_hand_log>right_hand_log, color=shcolor, alpha=alpha)
                except Exception:
                    pass

        plt.xlabel(first_log, color=color)
        plt.xticks(color=color)
        plt.xscale(xscale)
        plt.xlim([min_scale,max_scale])
        plt.ylim([bottom_depth, top_depth])

        # add zonation lines on log-tracks:
        if isinstance(plot_strat_zones, dict):
            try:
                for strat_col, zone_cosmetics in plot_strat_zones.items():
                    zone_color = zone_cosmetics[0]
                    zone_linewidth = zone_cosmetics[1]
                    zone_linestyle = zone_cosmetics[2]
                    for fm, fm_top in zone_tops(strat_col):
                        plt.plot([min_scale, max_scale], [fm_top, fm_top],
                                 color=zone_color, linewidth=zone_linewidth, linestyle=zone_linestyle)
            except Exception:
                pass

        # first log already plotted - add the other logs of the track, each on its own
        # x-axis which is stacked above the track
        for u in range(1,len(logs)):
            try:
                extra_color = log_dict[logs[u]][2]
                extra_min = log_dict[logs[u]][0]
                extra_max = log_dict[logs[u]][1]
                extra_linestyle = log_dict[logs[u]][3]
                extra_xscale = normalise_scale(log_dict[logs[u]][4])

                twin = ax[t].twiny()
                twin.spines['top'].set_position(('axes', (-.15)-(u-1)*0.08))   # magic numbers to get the offset
                twin.set_frame_on(True)
                twin.patch.set_visible(False)

                twin.plot(data[logs[u]],data[depth_col], color=extra_color, linestyle=extra_linestyle, linewidth=0.5)
                locator = plt.MaxNLocator(prune='both', nbins=2)
                twin.xaxis.set_major_locator(locator)
                twin.set_xscale(extra_xscale)
                twin.set_xlabel(logs[u], color=extra_color)
                twin.set_xlim([extra_min,extra_max])
                twin.tick_params(axis='x', colors=extra_color)
                twin.set_ylim([bottom_depth, top_depth])
                twin.set_yticks([])
            except Exception:
                pass


    for i, (col_name, litho_dict) in enumerate(litho_dicts.items()):
        # any litho columns (lithology, Group, Formation, &c)
        ax1 = plt.subplot(1, ncols, ncols_logs+i+1)

        ax1.plot(data[col_name], data[depth_col])
        # x-range of the track lies below the smallest code, so every code fills the full width
        lithoscale = min(litho_dict)-2

        for li in litho_dict.keys():
            ax1.fill_betweenx(data[depth_col], lithoscale-2, data[col_name], where=data[col_name]==li, color=litho_dict[li][1])

        if isinstance(plot_strat_zones, dict):
            try:
                for strat_col, zone_cosmetics in plot_strat_zones.items():
                    zone_color = zone_cosmetics[0]
                    zone_linewidth = zone_cosmetics[1]
                    zone_linestyle = zone_cosmetics[2]
                    for fm, fm_top in zone_tops(strat_col):
                        ax1.plot([lithoscale-2,lithoscale-1], [fm_top, fm_top],
                                 color=zone_color, linewidth=zone_linewidth, linestyle=zone_linestyle)
            except Exception:
                pass

        ax1.set_xticks([])
        ax1.set_xlim([lithoscale-2,lithoscale-1]) # just be on the safe side
        ax1.set_xlabel(col_name, fontweight='bold', rotation=90)
        ax1.set_ylim([bottom_depth,top_depth])
        ax1.set_yticks([])


    if isinstance(plot_strat_zones, dict):
        # last (i.e. right hand side) column(s) of the CPI. Use for stratigraphy labels
        try:
            first_label_col = ncols-len(plot_strat_zones)
            for z, (strat_col, zone_cosmetics) in enumerate(plot_strat_zones.items()):
                zone_color = zone_cosmetics[0]
                zone_linewidth = zone_cosmetics[1]
                zone_linestyle = zone_cosmetics[2]
                label_ax = ax[first_label_col+z]
                for fm, fm_top in zone_tops(strat_col):
                    # label is put at the average depth of the zone
                    avg = data.loc[((data[strat_col]==fm)&(data[depth_col].notna())),depth_col].mean()
                    label_ax.plot([0,1], [fm_top, fm_top],
                                  color=zone_color, linewidth=zone_linewidth, linestyle=zone_linestyle)
                    label_ax.annotate(fm, xy=(0.05, avg), xytext=(0.05, avg),
                                      va='center', fontsize=8, color=zone_color)
                # remove the frame
                label_ax.axis('off')
        except Exception:
            pass

    # save only on request, and only once the figure is complete (including the
    # stratigraphy labels)
    if save_plot == True:
        title = str(well).replace('/','_') + ' ' + str(topint) + ' - ' + str(botint)
        plt.savefig(title+'.png')

    # display the CPI
    plt.show()


    # add legends for litho-tracks ("None" for plot_full_litho_legend: no legend)
    for col_name, litho_dict in litho_dicts.items():
        if plot_full_litho_legend==False:
            # only the items actually present in the plotted interval
            create_legend(customize_dictionary(data, col_name, litho_dict), col_name)
        elif plot_full_litho_legend==True:
            create_legend(litho_dict, col_name)

### Function to plot a bar-chart

In [5]:
def bar_chart_percent_lithology(df, strat_col, litho_col, strat_dict, litho_dict, normalized=True, width=1,
                                figsize=(12,8), save_plot=False):
    '''Plots a bar chart, stacking the different lithologies per unique value in stratigraphy.
    
    Mandatory arguments:
    - df:         pandas DataFrame
    - strat_col:  name of the stratigraphy column
    - litho_col:  name of the lithology column
    - strat_dict: stratigraphic dictionary
    - litho_dict: lithology dictionary
    
    Dictionaries must have the following format:
    numeric key: (alpha-numeric description, hexcolor) 
    (i.e. the value is a tuple containing a description and a colour)
    
    How the dictionaries are matched against the data:
    - strat_dict: the *description* (first element of the tuple) is compared with the
                  values in strat_col and is also used as the bar label on the x-axis.
                  The colour stored in strat_dict is not used by this function.
    - litho_dict: the *key* is compared with the values in litho_col. The description
                  is shown in the legend and the colour is used for the bar segment.
    
    Optional arguments:
    - normalized: if True, then normalized to 100%, if False the bars will have height proportional
                  to the number of datapoints
    - width:      fractional width, between 0 and 1. A value of 1 means no gap between "neighbouring bars"
    - figsize:    (tuple) in case a larger figure is needed. Default is (12,8)
    - save_plot:  if True, the figure is saved as 'Lithology per <strat_col>_normalized.png'
                  (normalized) or 'Lithology per <strat_col>.png' (not normalized)
    
    Returns nothing; the figure is displayed with plt.show().
    '''
    # anything that does not compare equal to True or False falls back to the default
    if normalized != True and normalized != False:
        normalized = True

    # the bar width must be a number in [0, 1]; anything else falls back to 1
    if not (isinstance(width, (int, float)) and 0 <= width <= 1):
        width = 1

    try:
        fig, ax = plt.subplots(figsize=figsize)
    except (TypeError, ValueError):
        # unusable figsize: fall back to the documented default size
        fig, ax = plt.subplots(figsize=(12, 8))

    strat_names = [entry[0] for entry in strat_dict.values()]

    # one boolean mask per stratigraphic unit, evaluated once and re-used for every lithology
    strat_masks = [df[strat_col] == name for name in strat_names]
    strat_totals = [int(mask.sum()) for mask in strat_masks]

    # running top of each stacked bar; updated after every lithology so the next one sits on top
    bottoms = np.zeros(len(strat_names))

    for li in litho_dict.keys():
        litho_mask = df[litho_col] == li
        counts = [int((mask & litho_mask).sum()) for mask in strat_masks]

        if normalized == True:
            # a unit without any datapoints contributes 0 % instead of dividing by zero
            liths = [100 * count / total if total else 0
                     for count, total in zip(counts, strat_totals)]
        else:
            liths = counts

        # lithologies that never occur are drawn with zero height and kept out of the legend
        if sum(liths) == 0:
            label = None
        else:
            label = str(li) + ' (' + litho_dict[li][0] + ')'

        ax.bar(strat_names, liths,
               width, bottom=bottoms, color=litho_dict[li][1], label=label, edgecolor='k', linewidth=0.5)
        bottoms = bottoms + np.asarray(liths, dtype=float)

    n_strat = len(strat_names)
    ax.set_xticks(range(n_strat))
    ax.set_xlim([-0.5, n_strat - 0.5])
    ax.set_xticklabels(strat_names, rotation=90)
    if normalized == True:
        ax.set_ylabel(f'% lithology in {strat_col}', fontweight='bold')
        ax.set_ylim([0,100])
    else:
        ax.set_ylabel(f' number of lithology points in {strat_col}', fontweight='bold')
    ax.set_title(f'Lithology per {strat_col}', fontweight='bold')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title='$\\bf{lithologies}$', fancybox = True)

    if save_plot == True:
        if normalized == True:
            plt.savefig(f'Lithology per {strat_col}_normalized.png')
        else:
            plt.savefig(f'Lithology per {strat_col}.png')
    plt.show()

### Function to plot a confusion matrix

In [6]:
def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None, 
                          cmap=plt.cm.Blues, zero_out_diagonal=False, save_plot=False):
    '''Plots a confusion matrix as an annotated heat map and returns the matplotlib Axes.

    Mandatory arguments:
    - y_true             true labels (first argument passed to confusion_matrix)
    - y_pred             predicted labels (second argument passed to confusion_matrix)
    - classes            tick labels used on both axes; their order must follow the
                         row/column order of the matrix

    Optional arguments:
    - normalize          if True, every row is divided by its row sum, so each cell is the
                         fraction of that true label. Rows that sum to zero stay zero.
    - title              title of the plot (and name of the PNG file when saved). When empty,
                         'Normalized confusion matrix' or 'Confusion matrix' is used.
    - cmap               matplotlib colormap for the heat map
    - zero_out_diagonal  if True, the diagonal (correct predictions) is set to 0 before
                         normalizing and plotting, so only the confusions stand out
    - save_plot          True or False. Saves as '<title>.png'.

    NOTE(review): confusion_matrix is not defined in this function and must be available
    in the notebook namespace - presumably sklearn.metrics.confusion_matrix; confirm the import.
    '''
    if not title:
        title = 'Normalized confusion matrix' if normalize else 'Confusion matrix'

    cm = confusion_matrix(y_true, y_pred)

    if zero_out_diagonal == True:
        np.fill_diagonal(cm, 0)

    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # rows without any counts stay 0 instead of turning into NaN (a single NaN would
        # also make the colour threshold below NaN and break the text colouring)
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # write the value in every cell; white text on dark cells, black text on light cells
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()

    # pin the limits to the size of the plotted matrix so the outer rows/columns are shown
    # in full (y-axis inverted: first row on top)
    ax.set_xlim(-0.5, cm.shape[1]-0.5)
    ax.set_ylim(cm.shape[0]-0.5, -0.5)

    # kept for backward compatibility: this changes NumPy's global print precision
    np.set_printoptions(precision=2)

    if save_plot == True:
        plt.savefig(title+'.png')
    return ax

### Function to plot a matrix of histograms

In [1]:
import math

def histogram_matrix(df, log_dict, var, by, by_dict, by_subset_list=None, figsize=None, save_plot=False):
    '''Creates a "matrix" of histograms of one variable, one histogram per unique
    value of a categorical/discrete variable. Every histogram (density) is overlain
    by a red cumulative curve on a secondary y-axis without ticks.
    Optionally, only a subset of the categories can be plotted.
    
    Mandatory arguments:
    - df               the pandas DataFrame
    - log_dict         the (python) dictionary containing info on the logs.
                       log_dict[var][0] and log_dict[var][1] are used as the lower
                       and upper limit of both the bins and the x-axis
    - var              the variable for plotting in the histogram
    - by               used to "split" the dataset (every unique value in "by" 
                       will appear in its own histogram)  
    - by_dict          dictionary containing info on the categorical variable 
                       used to split the data. The format of the (python) dictionary
                       is as follows {key1:(description1, hexcolor1), 
                                      key2:(description2, hexcolor2),
                                      ....}
                       The keys must match the values in column "by".
    
    Optional arguments:
    - by_subset_list   if "None", all categories in column "by" will be plotted.
                       Alternatively, a list with values (occurring in "by") can
                       be passed. Only those will be plotted in that case.
    - figsize          tuple defining the size of the figure. Enter "None" to 
                       set automatically (5 wide and 4 high per subplot)
    - save_plot        True or False. Saves as '<var>_histogram_matrix.png'.
    
    Rows without a value in column "by" are ignored. A ValueError is raised when 
    there is no category left to plot.
    '''
    
    # number of bin edges over the range of the log (i.e. one bin less than this number)
    no_bins_over_range = 100

    # check if df needs slicing
    data = df if by_subset_list is None else df[df[by].isin(by_subset_list)]

    # one histogram per category; missing values cannot be looked up in by_dict
    categories = data[by].dropna().unique()
    n = len(categories)
    if n == 0:
        raise ValueError(f'No values found in column "{by}" to plot a histogram for')

    # find the number of rows and columns: at most 3 histograms next to each other
    if n > 3:
        ncols = 3
        nrows = math.ceil(n/ncols)
    else:
        ncols = n
        nrows = 1

    if figsize is None:
        figsize = (ncols*5, nrows*4)

    x_min, x_max = log_dict[var][0], log_dict[var][1]
    bins = np.linspace(x_min, x_max, no_bins_over_range)

    # squeeze=False always returns a 2D array of axes, so a single histogram needs
    # no separate code path; ravel() turns the grid into a flat sequence
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, sharex=True, squeeze=False)
    axes = axes.ravel()

    for ax, li in zip(axes, categories):
        values = data[var][data[by]==li]

        ax.hist(values,
                color=by_dict[li][1],
                bins=bins,
                stacked=True, density=True)
        ax.set_title(by_dict[li][0], fontweight='bold')

        # cumulative curve on its own (unlabelled) y-axis
        ax2 = ax.twinx()
        ax2.hist(values,
                 color='red', histtype='step',
                 bins=bins,
                 cumulative=True)
        ax2.set_yticks([])

    # hide the subplots that are too many in the matrix-grid
    for unused_ax in axes[n:]:
        unused_ax.set_visible(False)

    plt.suptitle(var, fontweight='bold')
    # the x-axes are shared, so setting the limits once is sufficient
    plt.xlim([x_min, x_max])

    if save_plot == True:
        plt.savefig(f'{var}_histogram_matrix.png')

    plt.show()

### Function to plot a box-and-whiskers plot

In [8]:
def box_plot(df, log_dict, var, by, by_dict, addendum_title=None, by_subset_list=None, figsize=None, save_plot=False):
    '''Creates a horizontal box-and-whiskers plot for a variable, split by a 
    categorical/discrete variable. The boxes span the 25th-75th percentile; the notch
    and the red line in the box mark the median. No "whis" argument is passed to 
    matplotlib, so the whiskers follow the matplotlib default (1.5 x the interquartile 
    range according to the matplotlib documentation). Points beyond the whiskers are 
    plotted as individual markers.
    
    Mandatory arguments:
    - df               the pandas DataFrame
    - log_dict         the (python) dictionary containing info on the logs.
                       log_dict[var][0] and log_dict[var][1] are used as the lower
                       and upper limit of the x-axis
    - var              the variable for plotting in the box plot
    - by               used to "split" the dataset (every unique value in "by" 
                       will appear as its own box)  
    - by_dict          dictionary containing info on the categorical variable 
                       used to split the data. The format of the (python) dictionary
                       is as follows {key1:(description1, hexcolor1), 
                                      key2:(description2, hexcolor2),
                                      ....}
                       Values of "by" are looked up as-is; if not found, their 
                       integer form is tried.
    
    Optional arguments:
    - addendum_title   if "None", nothing will be added to the default title. 
                       When not "None", the string will be added on a new line.
                       This can be useful e.g. when creating plots in batch 
                       (e.g. per well, etc)
    - by_subset_list   if "None", all categories in column "by" will be plotted.
                       Alternatively, a list with values (occurring in "by") can
                       be passed. Only those will be plotted in that case.
    - figsize          tuple defining the size of the figure. Enter "None" to 
                       set automatically (12,8)
    - save_plot        True or False. Saves as '<var>_boxplot.png'.
    
    Rows without a value in "by" or in "var" are ignored.'''

    # define the size of the plot
    if figsize is None:
        figsize = (12,8)

    # check if df needs slicing
    subset = df if by_subset_list is None else df[df[by].isin(by_subset_list)]

    # one box per category; missing values cannot be looked up in by_dict
    categories = subset[by].dropna().unique()

    def _category_info(key):
        # prefer the value as found in the data; fall back to its integer form
        # (e.g. when the codes are stored as text in the DataFrame)
        if key in by_dict:
            return by_dict[key]
        return by_dict[int(key)]

    title = f'{var} by {by}'
    if addendum_title is not None:
        title += f'\n{addendum_title}'

    labels = [_category_info(c)[0] for c in categories]
    colors = [_category_info(c)[1] for c in categories]

    # one array of samples per category; missing values are removed because they
    # cannot contribute to the percentiles of a box
    samples = [np.array(subset.loc[subset[by]==c, var].dropna()) for c in categories]

    # setting up the figure
    fig = plt.figure(figsize=figsize) 
    ax = fig.add_subplot(111)

    # horizontal, notched boxes; patch_artist is needed to be able to fill the boxes
    bp = ax.boxplot(samples, patch_artist=True, notch=True, vert=False)

    # fill every box with the colour of its category
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)

    # changing color and linewidth of whiskers 
    for whisker in bp['whiskers']: 
        whisker.set(color ='#000000', linewidth = 0.5, linestyle ='-') 

    # changing color and linewidth of caps 
    for cap in bp['caps']: 
        cap.set(color ='#FFFFFF', linewidth = 2) 

    # changing color and linewidth of medians 
    for median in bp['medians']: 
        median.set(color ='red', linewidth = 3) 

    # changing style of fliers 
    for flier in bp['fliers']: 
        flier.set(marker ='D', color ='#e7298a', alpha = 0.5) 

    # y-axis labels (the boxes are horizontal, so the categories are on the y-axis)
    ax.set_yticklabels(labels, fontweight='bold')

    # Adding title  
    plt.title(title, fontweight='bold') 

    # Removing top axes and right axes ticks 
    ax.get_xaxis().tick_bottom() 
    ax.get_yaxis().tick_left() 

    plt.xlim([log_dict[var][0],log_dict[var][1]])

    if save_plot == True:
        plt.savefig(f'{var}_boxplot.png')
    plt.show() 